X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=PLC%2FIpAddresses.py;fp=PLC%2FIpAddresses.py;h=6be9b7d41a0a0b77c99cabf9ae3bbc2ebf6cd4e5;hb=f1b4415be5ba4be9941d25b531c46696b377fa3a;hp=0000000000000000000000000000000000000000;hpb=cb77137d884be296fe1e0b15eabe69d18ba443ae;p=plcapi.git diff --git a/PLC/IpAddresses.py b/PLC/IpAddresses.py new file mode 100644 index 0000000..6be9b7d --- /dev/null +++ b/PLC/IpAddresses.py @@ -0,0 +1,232 @@ +# +# Functions for interacting with the interfaces table in the database +# +# Mark Huang +# Copyright (C) 2006 The Trustees of Princeton University +# + +from types import StringTypes +import socket +import struct + +from PLC.Faults import * +from PLC.Parameter import Parameter +from PLC.Filter import Filter +from PLC.Debug import profile +from PLC.Table import Row, Table +from PLC.NetworkTypes import NetworkType, NetworkTypes +from PLC.NetworkMethods import NetworkMethod, NetworkMethods +import PLC.Nodes + +# a class for validating IP address strings and applying netmasks +class SimpleAddress: + def __init__(self, addrStr, type=None, parts=None): + if (not type): + if "." in addrStr: + type="ipv4" + elif ":" in addrStr: + type="ipv6" + else: + raise ValueError, "Unable to determine type of address: " + str(addrStr) + + self.type = type + + if (type=="ipv4"): + # e.g. 1.2.3.4 + self.delim = "." + self.count = 4 + self.base = 10 + self.fieldMask = 0xFF + elif (type=="ipv6"): + # e.g. 2001:db8:85a3:0:0:8a2e:370:7334 + self.delim = ":" + self.count = 8 + self.base = 16 + self.fieldMask = 0xFFFF + else: + raise ValueError, "Unknown type of address: " + str(type) + + if addrStr: + parts = addrStr.split(self.delim) + + # deal with ipv6 groups of zeros notation + # :: represents a group of 0:0:0... fields + if ('' in parts) and (self.type=="ipv6"): + expanded_parts = [] + parts = [elem for i, elem in enumerate(parts) if i == 0 or elem!="" or parts[i-1] != elem] + for i,part in enumerate(parts): + if part=="": + for j in range(0, self.count-len(parts)+1): + expanded_parts.append("0") + else: + expanded_parts.append(part) + parts = expanded_parts + + parts = [int(x, self.base) for x in parts] + else: + if (parts==None): + raise ValueError, "Must supply either addrStr or parts to SimpleAddress" + + if len(parts)!=self.count: + raise ValueError, "Wrong number of fields in address: " + str(addrStr) + + self.parts = parts + + def as_str(self): + if (self.base == 16): + textParts = ["%x"%x for x in self.parts] + else: # self.base == 10 + textParts = [str(x) for x in self.parts] + return self.delim.join(textParts) + + def compute_network(self, netmask): + if (type(netmask)==str) or (type(netmask)==unicode): + netmask = SimpleAddress(netmask) + + if (self.type != netmask.type): + raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address" + + result = [] + for i in range(0, self.count): + result.append(self.parts[i] & netmask.parts[i]) + + return SimpleAddress(addrStr=None, type=self.type, parts = result) + + def compute_broadcast(self, netmask): + if (type(netmask)==str) or (type(netmask)==unicode): + netmask = SimpleAddress(netmask) + + if (self.type != netmask.type): + raise ValueError, "Cannot apply " + netmask.type + " netmask to " + self.type + " ip address" + + result = [] + for i in range(0, self.count): + result.append(self.parts[i] | (~netmask.parts[i] & self.fieldMask)) + + return SimpleAddress(addrStr=None, type=self.type, parts = result) + +class Subnet: + def __init__(self, subnetStr): + if "/" in subnetStr: + (addrStr, netBitCountStr) = subnetStr.split("/") + self.addr = SimpleAddress(addrStr) + self.netBitCount = int(netBitCountStr) + else: + self.addr = SimpleAddress(subnetStr) + self.netBitCount = None + + def as_str(self): + if self.netBitCount is not None: + return self.addr.as_str() + "/" + str(self.netBitCount) + else: + return self.addr.as_str() + +class IpAddress(Row): + """ + Representation of a row in the ip_addresses table. To use, optionally + instantiate with a dict of values. Update as you would a + dict. Commit to the database with sync(). + """ + + table_name = 'ip_addresses' + primary_key = 'ip_address_id' + join_tables = [] + fields = { + 'ip_address_id': Parameter(int, "IP Address identifier"), + 'interface_id': Parameter(int, "Interface associated with this address"), + 'type': Parameter(str, "Address type (e.g., 'ipv4')"), + 'ip_addr': Parameter(str, "IP Address", nullok = False), + 'netmask': Parameter(str, "Subnet mask", nullok = False), + 'last_updated': Parameter(int, "Date and time when node entry was created", ro = True), + } + + tags = {} + + def validate_ip_addr(self, ip): + # SimpleAddress with throw exceptions if the ip + SimpleAddress(ip, self["type"]) + return ip + + validate_netmask = validate_ip_addr + + def validate_interface_id(self, interface_id): + interfaces = PLC.Interfaces.Interfaces(self.api, [interface_id]) + if not interfaces: + raise PLCInvalidArgument, "No such interface %d"%interface_id + + return interface_id + + def validate(self): + """ + Flush changes back to the database. + """ + + # Basic validation + Row.validate(self) + + validate_last_updated = Row.validate_timestamp + + def update_timestamp(self, col_name, commit = True): + """ + Update col_name field with current time + """ + + assert 'ip_address_id' in self + assert self.table_name + + self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \ + " where ip_address_id = %d" % (self['ip_address_id']) ) + self.sync(commit) + + def update_last_updated(self, commit = True): + self.update_timestamp('last_updated', commit) + + def delete(self,commit=True): + Row.delete(self) + + def get_network(self): + return SimpleAddress(self["ip_addr"], self["type"]).compute_network(self["netmask"]).as_str() + + def get_broadcast(self): + return SimpleAddress(self["ip_addr"], self["type"]).compute_broadcast(self["netmask"]).as_str() + +class IpAddresses(Table): + """ + Representation of row(s) from the ip_addresses table in the + database. + """ + + def __init__(self, api, ip_address_filter = None, columns = None): + Table.__init__(self, api, IpAddress, columns) + + # the view that we're selecting upon: start with view_ip_addresses + view = "view_ip_addresses" + # as many left joins as requested tags + for tagname in self.tag_columns: + view= "%s left join %s using (%s)"%(view,IpAddress.tagvalue_view_name(tagname), + IpAddress.primary_key) + + sql = "SELECT %s FROM %s WHERE True" % \ + (", ".join(self.columns.keys()+self.tag_columns.keys()),view) + + if ip_address_filter is not None: + if isinstance(ip_address_filter, (list, tuple, set)): + # Separate the list into integers and strings + ints = filter(lambda x: isinstance(x, (int, long)), ip_address_filter) + strs = filter(lambda x: isinstance(x, StringTypes), ip_address_filter) + ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': ints, 'ip_addr': strs}) + sql += " AND (%s) %s" % ip_address_filter.sql(api, "OR") + elif isinstance(ip_address_filter, dict): + allowed_fields=dict(IpAddress.fields.items()+IpAddress.tags.items()) + ip_address_filter = Filter(allowed_fields, ip_address_filter) + sql += " AND (%s) %s" % ip_address_filter.sql(api) + elif isinstance(ip_address_filter, int): + ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': [ip_address_filter]}) + sql += " AND (%s) %s" % ip_address_filter.sql(api) + elif isinstance (ip_address_filter, StringTypes): + ip_address_filter = Filter(IpAddresses.fields, {'ip':[ip_address_filter]}) + sql += " AND (%s) %s" % ip_address_filter.sql(api, "AND") + else: + raise PLCInvalidArgument, "Wrong ip_address filter %r"%ip_address_filter + + self.selectall(sql)