--- /dev/null
+#
+# Functions for interacting with the interfaces table in the database
+#
+# Mark Huang <mlhuang@cs.princeton.edu>
+# Copyright (C) 2006 The Trustees of Princeton University
+#
+
+from types import StringTypes
+import socket
+import struct
+
+from PLC.Faults import *
+from PLC.Parameter import Parameter
+from PLC.Filter import Filter
+from PLC.Debug import profile
+from PLC.Table import Row, Table
+from PLC.NetworkTypes import NetworkType, NetworkTypes
+from PLC.NetworkMethods import NetworkMethod, NetworkMethods
+import PLC.Nodes
+
+# a class for validating IP address strings and applying netmasks
+class SimpleAddress:
+ def __init__(self, addrStr, type=None, parts=None):
+ if (not type):
+ if "." in addrStr:
+ type="ipv4"
+ elif ":" in addrStr:
+ type="ipv6"
+ else:
+ raise ValueError, "Unable to determine type of address: " + str(addrStr)
+
+ self.type = type
+
+ if (type=="ipv4"):
+ # e.g. 1.2.3.4
+ self.delim = "."
+ self.count = 4
+ self.base = 10
+ self.fieldMask = 0xFF
+ elif (type=="ipv6"):
+ # e.g. 2001:db8:85a3:0:0:8a2e:370:7334
+ self.delim = ":"
+ self.count = 8
+ self.base = 16
+ self.fieldMask = 0xFFFF
+ else:
+ raise ValueError, "Unknown type of address: " + str(type)
+
+ if addrStr:
+ parts = addrStr.split(self.delim)
+
+ # deal with ipv6 groups of zeros notation
+ # :: represents a group of 0:0:0... fields
+ if ('' in parts) and (self.type=="ipv6"):
+ expanded_parts = []
+ parts = [elem for i, elem in enumerate(parts) if i == 0 or elem!="" or parts[i-1] != elem]
+ for i,part in enumerate(parts):
+ if part=="":
+ for j in range(0, self.count-len(parts)+1):
+ expanded_parts.append("0")
+ else:
+ expanded_parts.append(part)
+ parts = expanded_parts
+
+ parts = [int(x, self.base) for x in parts]
+ else:
+ if (parts==None):
+ raise ValueError, "Must supply either addrStr or parts to SimpleAddress"
+
+ if len(parts)!=self.count:
+ raise ValueError, "Wrong number of fields in address: " + str(addrStr)
+
+ self.parts = parts
+
+ def as_str(self):
+ if (self.base == 16):
+ textParts = ["%x"%x for x in self.parts]
+ else: # self.base == 10
+ textParts = [str(x) for x in self.parts]
+ return self.delim.join(textParts)
+
+ def compute_network(self, netmask):
+ if (type(netmask)==str) or (type(netmask)==unicode):
+ netmask = SimpleAddress(netmask)
+
+ if (self.type != netmask.type):
+ raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address"
+
+ result = []
+ for i in range(0, self.count):
+ result.append(self.parts[i] & netmask.parts[i])
+
+ return SimpleAddress(addrStr=None, type=self.type, parts = result)
+
+ def compute_broadcast(self, netmask):
+ if (type(netmask)==str) or (type(netmask)==unicode):
+ netmask = SimpleAddress(netmask)
+
+ if (self.type != netmask.type):
+ raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address"
+
+ result = []
+ for i in range(0, self.count):
+ result.append(self.parts[i] | (~netmask.parts[i] & self.fieldMask))
+
+ return SimpleAddress(addrStr=None, type=self.type, parts = result)
+
+class Subnet:
+ def __init__(self, subnetStr):
+ if "/" in subnetStr:
+ (addrStr, netBitCountStr) = subnetStr.split("/")
+ self.addr = SimpleAddress(addrStr)
+ self.netBitCount = int(netBitCountStr)
+ else:
+ self.addr = SimpleAddress(subnetStr)
+ self.netBitCount = None
+
+ def as_str(self):
+ if self.netBitCount is not None:
+ return self.addr.as_str() + "/" + str(self.netBitCount)
+ else:
+ return self.addr.as_str()
+
+class IpAddress(Row):
+ """
+ Representation of a row in the ip_addresses table. To use, optionally
+ instantiate with a dict of values. Update as you would a
+ dict. Commit to the database with sync().
+ """
+
+ table_name = 'ip_addresses'
+ primary_key = 'ip_address_id'
+ join_tables = []
+ fields = {
+ 'ip_address_id': Parameter(int, "IP Address identifier"),
+ 'interface_id': Parameter(int, "Interface associated with this address"),
+ 'type': Parameter(str, "Address type (e.g., 'ipv4')"),
+ 'ip_addr': Parameter(str, "IP Address", nullok = False),
+ 'netmask': Parameter(str, "Subnet mask", nullok = False),
+ 'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
+ }
+
+ tags = {}
+
+ def validate_ip_addr(self, ip):
+ # SimpleAddress with throw exceptions if the ip
+ SimpleAddress(ip, self["type"])
+ return ip
+
+ validate_netmask = validate_ip_addr
+
+ def validate_interface_id(self, interface_id):
+ interfaces = PLC.Interfaces.Interfaces(self.api, [interface_id])
+ if not interfaces:
+ raise PLCInvalidArgument, "No such interface %d"%interface_id
+
+ return interface_id
+
+ def validate(self):
+ """
+ Flush changes back to the database.
+ """
+
+ # Basic validation
+ Row.validate(self)
+
+ validate_last_updated = Row.validate_timestamp
+
+ def update_timestamp(self, col_name, commit = True):
+ """
+ Update col_name field with current time
+ """
+
+ assert 'ip_address_id' in self
+ assert self.table_name
+
+ self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
+ " where ip_address_id = %d" % (self['ip_address_id']) )
+ self.sync(commit)
+
+ def update_last_updated(self, commit = True):
+ self.update_timestamp('last_updated', commit)
+
+ def delete(self,commit=True):
+ Row.delete(self)
+
+ def get_network(self):
+ return SimpleAddress(self["ip_addr"], self["type"]).compute_network(self["netmask"]).as_str()
+
+ def get_broadcast(self):
+ return SimpleAddress(self["ip_addr"], self["type"]).compute_broadcast(self["netmask"]).as_str()
+
+class IpAddresses(Table):
+ """
+ Representation of row(s) from the ip_addresses table in the
+ database.
+ """
+
+ def __init__(self, api, ip_address_filter = None, columns = None):
+ Table.__init__(self, api, IpAddress, columns)
+
+ # the view that we're selecting upon: start with view_ip_addresses
+ view = "view_ip_addresses"
+ # as many left joins as requested tags
+ for tagname in self.tag_columns:
+ view= "%s left join %s using (%s)"%(view,IpAddress.tagvalue_view_name(tagname),
+ IpAddress.primary_key)
+
+ sql = "SELECT %s FROM %s WHERE True" % \
+ (", ".join(self.columns.keys()+self.tag_columns.keys()),view)
+
+ if ip_address_filter is not None:
+ if isinstance(ip_address_filter, (list, tuple, set)):
+ # Separate the list into integers and strings
+ ints = filter(lambda x: isinstance(x, (int, long)), ip_address_filter)
+ strs = filter(lambda x: isinstance(x, StringTypes), ip_address_filter)
+ ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': ints, 'ip_addr': strs})
+ sql += " AND (%s) %s" % ip_address_filter.sql(api, "OR")
+ elif isinstance(ip_address_filter, dict):
+ allowed_fields=dict(IpAddress.fields.items()+IpAddress.tags.items())
+ ip_address_filter = Filter(allowed_fields, ip_address_filter)
+ sql += " AND (%s) %s" % ip_address_filter.sql(api)
+ elif isinstance(ip_address_filter, int):
+ ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': [ip_address_filter]})
+ sql += " AND (%s) %s" % ip_address_filter.sql(api)
+ elif isinstance (ip_address_filter, StringTypes):
+ ip_address_filter = Filter(IpAddresses.fields, {'ip':[ip_address_filter]})
+ sql += " AND (%s) %s" % ip_address_filter.sql(api, "AND")
+ else:
+ raise PLCInvalidArgument, "Wrong ip_address filter %r"%ip_address_filter
+
+ self.selectall(sql)