bring over newinterface branch from Verivue
[plcapi.git] / PLC / IpAddresses.py
diff --git a/PLC/IpAddresses.py b/PLC/IpAddresses.py
new file mode 100644 (file)
index 0000000..6be9b7d
--- /dev/null
@@ -0,0 +1,232 @@
+#
+# Functions for interacting with the interfaces table in the database
+#
+# Mark Huang <mlhuang@cs.princeton.edu>
+# Copyright (C) 2006 The Trustees of Princeton University
+#
+
+from types import StringTypes
+import socket
+import struct
+
+from PLC.Faults import *
+from PLC.Parameter import Parameter
+from PLC.Filter import Filter
+from PLC.Debug import profile
+from PLC.Table import Row, Table
+from PLC.NetworkTypes import NetworkType, NetworkTypes
+from PLC.NetworkMethods import NetworkMethod, NetworkMethods
+import PLC.Nodes
+
+# a class for validating IP address strings and applying netmasks
+class SimpleAddress:
+    def __init__(self, addrStr, type=None, parts=None):
+        if (not type):
+            if "." in addrStr:
+                type="ipv4"
+            elif ":" in addrStr:
+                type="ipv6"
+            else:
+                raise ValueError, "Unable to determine type of address: " + str(addrStr)
+
+        self.type = type
+
+        if (type=="ipv4"):
+            # e.g. 1.2.3.4
+            self.delim = "."
+            self.count = 4
+            self.base = 10
+            self.fieldMask = 0xFF
+        elif (type=="ipv6"):
+            # e.g. 2001:db8:85a3:0:0:8a2e:370:7334
+            self.delim = ":"
+            self.count = 8
+            self.base = 16
+            self.fieldMask = 0xFFFF
+        else:
+            raise ValueError, "Unknown type of address: " + str(type)
+
+        if addrStr:
+            parts = addrStr.split(self.delim)
+
+            # deal with ipv6 groups of zeros notation
+            #  :: represents a group of 0:0:0... fields
+            if ('' in parts) and (self.type=="ipv6"):
+               expanded_parts = []
+               parts = [elem for i, elem in enumerate(parts) if i == 0 or elem!="" or parts[i-1] != elem]
+               for i,part in enumerate(parts):
+                  if part=="":
+                     for j in range(0, self.count-len(parts)+1):
+                         expanded_parts.append("0")
+                  else:
+                     expanded_parts.append(part)
+               parts = expanded_parts
+
+            parts = [int(x, self.base) for x in parts]
+        else:
+            if (parts==None):
+                raise ValueError, "Must supply either addrStr or parts to SimpleAddress"
+
+        if len(parts)!=self.count:
+            raise ValueError, "Wrong number of fields in address: " + str(addrStr)
+
+        self.parts = parts
+
+    def as_str(self):
+        if (self.base == 16):
+            textParts = ["%x"%x for x in self.parts]
+        else:   # self.base == 10
+            textParts = [str(x) for x in self.parts]
+        return self.delim.join(textParts)
+
+    def compute_network(self, netmask):
+        if (type(netmask)==str) or (type(netmask)==unicode):
+            netmask = SimpleAddress(netmask)
+
+        if (self.type != netmask.type):
+            raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address"
+
+        result = []
+        for i in range(0, self.count):
+            result.append(self.parts[i] & netmask.parts[i])
+
+        return SimpleAddress(addrStr=None, type=self.type, parts = result)
+
+    def compute_broadcast(self, netmask):
+        if (type(netmask)==str) or (type(netmask)==unicode):
+            netmask = SimpleAddress(netmask)
+
+        if (self.type != netmask.type):
+            raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address"
+
+        result = []
+        for i in range(0, self.count):
+            result.append(self.parts[i] | (~netmask.parts[i] & self.fieldMask))
+
+        return SimpleAddress(addrStr=None, type=self.type, parts = result)
+
+class Subnet:
+    def __init__(self, subnetStr):
+        if "/" in subnetStr:
+            (addrStr, netBitCountStr) = subnetStr.split("/")
+            self.addr = SimpleAddress(addrStr)
+            self.netBitCount = int(netBitCountStr)
+        else:
+            self.addr = SimpleAddress(subnetStr)
+            self.netBitCount = None
+
+    def as_str(self):
+        if self.netBitCount is not None:
+            return self.addr.as_str() + "/" + str(self.netBitCount)
+        else:
+            return self.addr.as_str()
+
+class IpAddress(Row):
+    """
+    Representation of a row in the ip_addresses table. To use, optionally
+    instantiate with a dict of values. Update as you would a
+    dict. Commit to the database with sync().
+    """
+
+    table_name = 'ip_addresses'
+    primary_key = 'ip_address_id'
+    join_tables = []
+    fields = {
+        'ip_address_id': Parameter(int, "IP Address identifier"),
+        'interface_id': Parameter(int, "Interface associated with this address"),
+        'type': Parameter(str, "Address type (e.g., 'ipv4')"),
+        'ip_addr': Parameter(str, "IP Address", nullok = False),
+        'netmask': Parameter(str, "Subnet mask", nullok = False),
+        'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
+        }
+
+    tags = {}
+
+    def validate_ip_addr(self, ip):
+        # SimpleAddress with throw exceptions if the ip
+        SimpleAddress(ip, self["type"])
+        return ip
+
+    validate_netmask = validate_ip_addr
+
+    def validate_interface_id(self, interface_id):
+        interfaces = PLC.Interfaces.Interfaces(self.api, [interface_id])
+        if not interfaces:
+            raise PLCInvalidArgument, "No such interface %d"%interface_id
+
+        return interface_id
+
+    def validate(self):
+        """
+        Flush changes back to the database.
+        """
+
+        # Basic validation
+        Row.validate(self)
+
+    validate_last_updated = Row.validate_timestamp
+
+    def update_timestamp(self, col_name, commit = True):
+        """
+        Update col_name field with current time
+        """
+
+        assert 'ip_address_id' in self
+        assert self.table_name
+
+        self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
+                       " where ip_address_id = %d" % (self['ip_address_id']) )
+        self.sync(commit)
+
+    def update_last_updated(self, commit = True):
+        self.update_timestamp('last_updated', commit)
+
+    def delete(self,commit=True):
+        Row.delete(self)
+
+    def get_network(self):
+        return SimpleAddress(self["ip_addr"], self["type"]).compute_network(self["netmask"]).as_str()
+
+    def get_broadcast(self):
+        return SimpleAddress(self["ip_addr"], self["type"]).compute_broadcast(self["netmask"]).as_str()
+
+class IpAddresses(Table):
+    """
+    Representation of row(s) from the ip_addresses table in the
+    database.
+    """
+
+    def __init__(self, api, ip_address_filter = None, columns = None):
+        Table.__init__(self, api, IpAddress, columns)
+
+        # the view that we're selecting upon: start with view_ip_addresses
+        view = "view_ip_addresses"
+        # as many left joins as requested tags
+        for tagname in self.tag_columns:
+            view= "%s left join %s using (%s)"%(view,IpAddress.tagvalue_view_name(tagname),
+                                                IpAddress.primary_key)
+
+        sql = "SELECT %s FROM %s WHERE True" % \
+            (", ".join(self.columns.keys()+self.tag_columns.keys()),view)
+
+        if ip_address_filter is not None:
+            if isinstance(ip_address_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), ip_address_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), ip_address_filter)
+                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': ints, 'ip_addr': strs})
+                sql += " AND (%s) %s" % ip_address_filter.sql(api, "OR")
+            elif isinstance(ip_address_filter, dict):
+                allowed_fields=dict(IpAddress.fields.items()+IpAddress.tags.items())
+                ip_address_filter = Filter(allowed_fields, ip_address_filter)
+                sql += " AND (%s) %s" % ip_address_filter.sql(api)
+            elif isinstance(ip_address_filter, int):
+                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': [ip_address_filter]})
+                sql += " AND (%s) %s" % ip_address_filter.sql(api)
+            elif isinstance (ip_address_filter, StringTypes):
+                ip_address_filter = Filter(IpAddresses.fields, {'ip':[ip_address_filter]})
+                sql += " AND (%s) %s" % ip_address_filter.sql(api, "AND")
+            else:
+                raise PLCInvalidArgument, "Wrong ip_address filter %r"%ip_address_filter
+
+        self.selectall(sql)