#
# Functions for interacting with the ip_addresses table in the database
#
# Mark Huang
# Copyright (C) 2006 The Trustees of Princeton University
#

from types import StringTypes
import socket
import struct

from PLC.Faults import *
from PLC.Parameter import Parameter
from PLC.Filter import Filter
from PLC.Debug import profile
from PLC.Table import Row, Table
from PLC.NetworkTypes import NetworkType, NetworkTypes
from PLC.NetworkMethods import NetworkMethod, NetworkMethods
import PLC.Nodes


# a class for validating IP address strings and applying netmasks
class SimpleAddress:
    """
    An IPv4 or IPv6 address held as a list of integer fields.

    "ipv4" addresses have 4 decimal fields separated by "."; "ipv6"
    addresses have 8 hexadecimal fields separated by ":", where "::"
    stands for a run of zero fields.

    Attributes set by the constructor: type, delim, count, base,
    fieldMask (largest value a single field may hold) and parts (the
    list of integer fields).
    """

    def __init__(self, addrStr, type=None, parts=None):
        """
        Build an address from a string or from pre-parsed fields.

        addrStr -- textual address, or None/"" when parts is supplied
        type    -- "ipv4" or "ipv6"; when omitted it is guessed from
                   addrStr ("." means ipv4, otherwise ":" means ipv6)
        parts   -- list of integer fields, used only when addrStr is empty

        Raises ValueError if the type is unknown or cannot be guessed,
        if neither addrStr nor parts is given, if a field is not a
        number in the expected base, if the number of fields is wrong,
        or if a field does not fit in fieldMask.
        """
        if not type:
            # addrStr may be None when the caller only passes parts; treat
            # that like an empty string so we raise ValueError, not TypeError
            text = addrStr or ""
            if "." in text:
                type = "ipv4"
            elif ":" in text:
                type = "ipv6"
            else:
                raise ValueError("Unable to determine type of address: " + str(addrStr))

        self.type = type
        if type == "ipv4":
            # e.g. 1.2.3.4
            self.delim = "."
            self.count = 4
            self.base = 10
            self.fieldMask = 0xFF
        elif type == "ipv6":
            # e.g. 2001:db8:85a3:0:0:8a2e:370:7334
            self.delim = ":"
            self.count = 8
            self.base = 16
            self.fieldMask = 0xFFFF
        else:
            raise ValueError("Unknown type of address: " + str(type))

        if addrStr:
            parts = self._parse(addrStr)
        elif parts is None:
            raise ValueError("Must supply either addrStr or parts to SimpleAddress")

        if len(parts) != self.count:
            # report the parts when there is no string to show
            raise ValueError("Wrong number of fields in address: " + str(addrStr or parts))

        # every field must fit in the field width, otherwise things like
        # "999.1.1.1" or "12345::1" would be accepted as valid addresses
        for value in parts:
            if not 0 <= value <= self.fieldMask:
                raise ValueError("Address field out of range in address: " + str(addrStr or parts))

        self.parts = parts

    def _parse(self, addrStr):
        """
        Split addrStr on the delimiter and return the fields as integers.

        Raises ValueError (from int()) on an empty or non-numeric field.
        """
        fields = addrStr.split(self.delim)

        # deal with ipv6 groups of zeros notation
        # :: represents a group of 0:0:0... fields
        if self.type == "ipv6" and "" in fields:
            # "::" splits into consecutive empty strings; keep one per run
            collapsed = [field for i, field in enumerate(fields)
                         if i == 0 or field != "" or fields[i - 1] != ""]
            # each remaining gap is replaced by as many zero fields as are
            # needed to bring a single-gap address up to self.count fields
            zeros = ["0"] * (self.count - len(collapsed) + 1)
            fields = []
            for field in collapsed:
                if field == "":
                    fields.extend(zeros)
                else:
                    fields.append(field)

        return [int(field, self.base) for field in fields]

    def _coerce_netmask(self, netmask):
        """
        Return netmask as a SimpleAddress of the same type as self.

        netmask may be a string (parsed with automatic type detection) or
        an existing SimpleAddress. Raises ValueError on a type mismatch.
        """
        if isinstance(netmask, StringTypes):
            netmask = SimpleAddress(netmask)

        if self.type != netmask.type:
            raise ValueError("Cannot apply " + netmask.type + " netmask to " + self.type + " ip address")

        return netmask

    def as_str(self):
        """
        Return the address as text: lowercase hex fields for base 16,
        decimal fields otherwise, joined by the delimiter. Zero fields
        are written out; "::" compression is not applied.
        """
        if self.base == 16:
            textParts = ["%x" % x for x in self.parts]
        else:
            # self.base == 10
            textParts = [str(x) for x in self.parts]

        return self.delim.join(textParts)

    def compute_network(self, netmask):
        """
        Return a new SimpleAddress holding the network address, i.e. each
        field ANDed with the matching netmask field.

        netmask -- string or SimpleAddress of the same type as self
        Raises ValueError if the netmask is invalid or of another type.
        """
        netmask = self._coerce_netmask(netmask)
        result = [field & mask for field, mask in zip(self.parts, netmask.parts)]
        return SimpleAddress(addrStr=None, type=self.type, parts=result)

    def compute_broadcast(self, netmask):
        """
        Return a new SimpleAddress holding the broadcast address, i.e.
        each field ORed with the inverted netmask field.

        netmask -- string or SimpleAddress of the same type as self
        Raises ValueError if the netmask is invalid or of another type.
        """
        netmask = self._coerce_netmask(netmask)
        # mask the inverted field so the result stays within the field width
        result = [field | (~mask & self.fieldMask)
                  for field, mask in zip(self.parts, netmask.parts)]
        return SimpleAddress(addrStr=None, type=self.type, parts=result)


class Subnet:
    """
    An address with an optional bit count, written "address/bits".

    Attributes: addr (a SimpleAddress) and netBitCount (an int, or None
    when the string has no "/").
    """

    def __init__(self, subnetStr):
        """
        Parse subnetStr, e.g. "10.0.0.0/8" or "10.0.0.0".

        Raises ValueError if the address part is invalid, if the bit
        count is not an integer, or if there is more than one "/".
        """
        if "/" in subnetStr:
            (addrStr, netBitCountStr) = subnetStr.split("/")
            self.addr = SimpleAddress(addrStr)
            self.netBitCount = int(netBitCountStr)
        else:
            self.addr = SimpleAddress(subnetStr)
            self.netBitCount = None

    def as_str(self):
        """
        Return "address/bits", or just the address when no bit count
        was given.
        """
        if self.netBitCount is not None:
            return self.addr.as_str() + "/" + str(self.netBitCount)
        else:
            return self.addr.as_str()
class IpAddress(Row):
    """
    Representation of a row in the ip_addresses table. To use,
    optionally instantiate with a dict of values. Update as you would a
    dict. Commit to the database with sync().
    """

    table_name = 'ip_addresses'
    primary_key = 'ip_address_id'
    join_tables = []
    fields = {
        'ip_address_id': Parameter(int, "IP Address identifier"),
        'interface_id': Parameter(int, "Interface associated with this address"),
        'type': Parameter(str, "Address type (e.g., 'ipv4')"),
        'ip_addr': Parameter(str, "IP Address", nullok = False),
        'netmask': Parameter(str, "Subnet mask", nullok = False),
        'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
        }

    tags = {}

    def validate_ip_addr(self, ip):
        """
        Check that ip parses as an address of this row's 'type' and
        return it unchanged.
        """
        # SimpleAddress will throw an exception if the ip is not a
        # well-formed address of the given type
        SimpleAddress(ip, self["type"])
        return ip

    # netmasks are written in the same notation as addresses
    validate_netmask = validate_ip_addr

    def validate_interface_id(self, interface_id):
        """
        Check that interface_id refers to an existing interface and
        return it unchanged. Raises PLCInvalidArgument otherwise.
        """
        # This module only imports PLC.Nodes at top level, so make sure
        # PLC.Interfaces is actually loaded before using it.
        import PLC.Interfaces

        interfaces = PLC.Interfaces.Interfaces(self.api, [interface_id])
        if not interfaces:
            raise PLCInvalidArgument("No such interface %d" % interface_id)

        return interface_id

    def validate(self):
        """
        Validate the fields of this row; delegates to Row.validate().
        """

        # Basic validation
        Row.validate(self)

    validate_last_updated = Row.validate_timestamp

    def update_timestamp(self, col_name, commit = True):
        """
        Update col_name field with current time
        """

        assert 'ip_address_id' in self
        assert self.table_name

        self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
                       " where ip_address_id = %d" % (self['ip_address_id']) )
        self.sync(commit)

    def update_last_updated(self, commit = True):
        """
        Set the last_updated column to the current time.
        """
        self.update_timestamp('last_updated', commit)

    def delete(self, commit = True):
        """
        Delete this row via Row.delete().
        """
        # Forward commit instead of silently dropping it.
        # NOTE(review): assumes Row.delete accepts a commit argument - confirm.
        Row.delete(self, commit)

    def get_network(self):
        """
        Return the network address (ip_addr masked by netmask) as a string.
        """
        address = SimpleAddress(self["ip_addr"], self["type"])
        return address.compute_network(self["netmask"]).as_str()

    def get_broadcast(self):
        """
        Return the broadcast address for ip_addr/netmask as a string.
        """
        address = SimpleAddress(self["ip_addr"], self["type"])
        return address.compute_broadcast(self["netmask"]).as_str()


class IpAddresses(Table):
    """
    Representation of row(s) from the ip_addresses table in the
    database.
    """

    def __init__(self, api, ip_address_filter = None, columns = None):
        """
        Select rows from view_ip_addresses.

        ip_address_filter may be:
          - None: select every row
          - a list, tuple or set mixing ip_address_id integers and
            ip_addr strings; rows matching any of them are selected
          - a dict, handed to Filter with the IpAddress fields and tags
          - a single integer ip_address_id
          - a single ip_addr string
        Anything else raises PLCInvalidArgument.

        columns is passed through to Table.__init__.
        """
        Table.__init__(self, api, IpAddress, columns)

        # the view that we're selecting upon: start with view_ip_addresses
        view = "view_ip_addresses"
        # as many left joins as requested tags
        for tagname in self.tag_columns:
            view = "%s left join %s using (%s)" % (view, IpAddress.tagvalue_view_name(tagname),
                                                   IpAddress.primary_key)

        selected = list(self.columns.keys()) + list(self.tag_columns.keys())
        sql = "SELECT %s FROM %s WHERE True" % (", ".join(selected), view)

        if ip_address_filter is not None:
            if isinstance(ip_address_filter, (list, tuple, set)):
                # Separate the list into integers and strings
                ints = [x for x in ip_address_filter if isinstance(x, (int, long))]
                strs = [x for x in ip_address_filter if isinstance(x, StringTypes)]
                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': ints, 'ip_addr': strs})
                sql += " AND (%s) %s" % ip_address_filter.sql(api, "OR")
            elif isinstance(ip_address_filter, dict):
                allowed_fields = dict(IpAddress.fields)
                allowed_fields.update(IpAddress.tags)
                ip_address_filter = Filter(allowed_fields, ip_address_filter)
                sql += " AND (%s) %s" % ip_address_filter.sql(api)
            elif isinstance(ip_address_filter, (int, long)):
                # accept long as well, like the list branch does
                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': [ip_address_filter]})
                sql += " AND (%s) %s" % ip_address_filter.sql(api)
            elif isinstance(ip_address_filter, StringTypes):
                # The field definitions live on the row class (IpAddress),
                # and the address column is 'ip_addr'.
                ip_address_filter = Filter(IpAddress.fields, {'ip_addr': [ip_address_filter]})
                sql += " AND (%s) %s" % ip_address_filter.sql(api, "AND")
            else:
                raise PLCInvalidArgument("Wrong ip_address filter %r" % ip_address_filter)

        self.selectall(sql)