bring over newinterface branch from Verivue
[plcapi.git] / PLC / IpAddresses.py
1 #
2 # Functions for interacting with the ip_addresses table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7
8 from types import StringTypes
9 import socket
10 import struct
11
12 from PLC.Faults import *
13 from PLC.Parameter import Parameter
14 from PLC.Filter import Filter
15 from PLC.Debug import profile
16 from PLC.Table import Row, Table
17 from PLC.NetworkTypes import NetworkType, NetworkTypes
18 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
19 import PLC.Nodes
20
21 # a class for validating IP address strings and applying netmasks
class SimpleAddress:
    """
    Parse and validate a textual IPv4 or IPv6 address and apply netmasks
    to it.

    Attributes set by the constructor:
      type      -- "ipv4" or "ipv6"
      delim     -- field separator ("." or ":")
      count     -- number of fields (4 or 8)
      base      -- numeric base of each field (10 or 16)
      fieldMask -- largest value a single field may hold (0xFF or 0xFFFF)
      parts     -- list of `count` integers, one per field
    """

    def __init__(self, addrStr, type=None, parts=None):
        """
        Build an address from its textual form or from pre-split fields.

        addrStr -- address text, e.g. "1.2.3.4" or "2001:db8::1"; may be
                   None/empty when `parts` is supplied
        type    -- "ipv4" or "ipv6"; guessed from the delimiter found in
                   addrStr when omitted
        parts   -- list of integer fields, only used when addrStr is empty

        Raises ValueError if the type cannot be determined or is unknown,
        if neither addrStr nor parts is given, if a field is not a number
        in the expected base, if the number of fields is wrong, or if a
        field lies outside 0..fieldMask.
        """
        if not type:
            # guard on addrStr so that a missing string yields ValueError
            # rather than a TypeError from the `in` test
            if addrStr and "." in addrStr:
                type = "ipv4"
            elif addrStr and ":" in addrStr:
                type = "ipv6"
            else:
                raise ValueError("Unable to determine type of address: " + str(addrStr))

        self.type = type

        if type == "ipv4":
            # e.g. 1.2.3.4
            self.delim = "."
            self.count = 4
            self.base = 10
            self.fieldMask = 0xFF
        elif type == "ipv6":
            # e.g. 2001:db8:85a3:0:0:8a2e:370:7334
            self.delim = ":"
            self.count = 8
            self.base = 16
            self.fieldMask = 0xFFFF
        else:
            raise ValueError("Unknown type of address: " + str(type))

        if addrStr:
            parts = addrStr.split(self.delim)

            # deal with ipv6 groups of zeros notation
            #  :: represents a group of 0:0:0... fields
            if self.type == "ipv6" and "" in parts:
                # a leading or trailing "::" splits into two adjacent empty
                # strings; keep only the first of each such run
                collapsed = [elem for i, elem in enumerate(parts)
                             if i == 0 or elem != "" or parts[i - 1] != ""]
                # the single remaining empty string stands for however many
                # zero fields are needed to reach self.count fields
                zeros = ["0"] * (self.count - len(collapsed) + 1)
                expanded = []
                for part in collapsed:
                    if part == "":
                        expanded.extend(zeros)
                    else:
                        expanded.append(part)
                parts = expanded

            # int() raises ValueError on empty or non-numeric fields
            parts = [int(x, self.base) for x in parts]
        elif parts is None:
            raise ValueError("Must supply either addrStr or parts to SimpleAddress")

        if len(parts) != self.count:
            raise ValueError("Wrong number of fields in address: " + str(addrStr))

        # reject fields such as 256 (ipv4), 12345 (ipv6) or negative numbers
        for value in parts:
            if value < 0 or value > self.fieldMask:
                if addrStr:
                    shown = str(addrStr)
                else:
                    shown = str(parts)
                raise ValueError("Address field out of range in address: " + shown)

        self.parts = parts

    def as_str(self):
        """
        Return the address as text, fully expanded (no "::" shorthand and
        no zero padding), using the delimiter and base of its type.
        """
        if self.base == 16:
            textParts = ["%x" % x for x in self.parts]
        else:   # self.base == 10
            textParts = [str(x) for x in self.parts]
        return self.delim.join(textParts)

    def _coerce_netmask(self, netmask):
        """
        Return `netmask` as a SimpleAddress of the same type as self.

        Anything that is not already a SimpleAddress is parsed as address
        text. Raises ValueError if the netmask cannot be parsed or its type
        differs from this address's type.
        """
        if not isinstance(netmask, SimpleAddress):
            netmask = SimpleAddress(netmask)

        if self.type != netmask.type:
            raise ValueError("Cannot apply " + netmask.type + " netmask to " + self.type + " ip address")

        return netmask

    def compute_network(self, netmask):
        """
        Return a new SimpleAddress holding the network address obtained by
        AND-ing every field with the matching field of `netmask` (address
        text or SimpleAddress).

        Raises ValueError on a netmask of a different address type.
        """
        netmask = self._coerce_netmask(netmask)
        result = [self.parts[i] & netmask.parts[i] for i in range(0, self.count)]
        return SimpleAddress(addrStr=None, type=self.type, parts=result)

    def compute_broadcast(self, netmask):
        """
        Return a new SimpleAddress holding the broadcast address obtained
        by setting every host bit, i.e. OR-ing each field with the inverted
        netmask field.

        Raises ValueError on a netmask of a different address type.
        """
        netmask = self._coerce_netmask(netmask)
        # `& self.fieldMask` trims the inverted mask back to the field width,
        # since ~x on a Python integer is negative
        result = [self.parts[i] | (~netmask.parts[i] & self.fieldMask)
                  for i in range(0, self.count)]
        return SimpleAddress(addrStr=None, type=self.type, parts=result)
107
class Subnet:
    """
    An address with an optional prefix length, written "addr/bits"
    (e.g. "10.0.0.0/8") or just "addr".

    Attributes:
      addr        -- SimpleAddress parsed from the address part
      netBitCount -- integer prefix length, or None when no "/bits"
                     suffix was given
    """

    def __init__(self, subnetStr):
        """
        Parse subnetStr.

        Raises ValueError if the address part is invalid, if the prefix
        length is not an integer, if more than one "/" is present, or if
        the prefix length does not fit the address type (0..32 for ipv4,
        0..128 for ipv6).
        """
        if "/" in subnetStr:
            (addrStr, netBitCountStr) = subnetStr.split("/")
            self.addr = SimpleAddress(addrStr)
            self.netBitCount = int(netBitCountStr)

            # SimpleAddress only ever produces these two types
            maxBits = {"ipv4": 32, "ipv6": 128}[self.addr.type]
            if self.netBitCount < 0 or self.netBitCount > maxBits:
                raise ValueError("Prefix length out of range in subnet: " + str(subnetStr))
        else:
            self.addr = SimpleAddress(subnetStr)
            self.netBitCount = None

    def as_str(self):
        """
        Return the subnet as text: "addr/bits" when a prefix length was
        given, otherwise just the address.
        """
        if self.netBitCount is not None:
            return self.addr.as_str() + "/" + str(self.netBitCount)
        else:
            return self.addr.as_str()
123
class IpAddress(Row):
    """
    Representation of a row in the ip_addresses table. To use, optionally
    instantiate with a dict of values. Update as you would a
    dict. Commit to the database with sync().
    """

    table_name = 'ip_addresses'
    primary_key = 'ip_address_id'
    join_tables = []
    fields = {
        'ip_address_id': Parameter(int, "IP Address identifier"),
        'interface_id': Parameter(int, "Interface associated with this address"),
        'type': Parameter(str, "Address type (e.g., 'ipv4')"),
        'ip_addr': Parameter(str, "IP Address", nullok = False),
        'netmask': Parameter(str, "Subnet mask", nullok = False),
        # NOTE(review): this description looks copied from the nodes table;
        # left untouched because the text is emitted at runtime
        'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
        }

    tags = {}

    def validate_ip_addr(self, ip):
        """
        Check that ip parses as an address of this row's "type" and return
        it unchanged.

        SimpleAddress raises ValueError if the string is not a valid
        address of that type.
        """
        SimpleAddress(ip, self["type"])
        return ip

    # netmasks are written like addresses, so the same check applies
    validate_netmask = validate_ip_addr

    def validate_interface_id(self, interface_id):
        """
        Check that interface_id names an existing interface and return it.

        Raises PLCInvalidArgument when the lookup returns nothing.
        """
        interfaces = PLC.Interfaces.Interfaces(self.api, [interface_id])
        if not interfaces:
            raise PLCInvalidArgument("No such interface %d" % interface_id)

        return interface_id

    def validate(self):
        """
        Validate the values held by this row.

        Only the generic Row validation is run; there are no additional
        IpAddress-specific checks here.
        """

        # Basic validation
        Row.validate(self)

    validate_last_updated = Row.validate_timestamp

    def update_timestamp(self, col_name, commit = True):
        """
        Update col_name field with current time
        """

        assert 'ip_address_id' in self
        assert self.table_name

        self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
                       " where ip_address_id = %d" % (self['ip_address_id']) )
        self.sync(commit)

    def update_last_updated(self, commit = True):
        """
        Set the last_updated column of this row to the current time.
        """
        self.update_timestamp('last_updated', commit)

    def delete(self, commit = True):
        """
        Delete this row via Row.delete.

        commit is forwarded so that callers passing commit=False are
        honoured (it was previously dropped).
        NOTE(review): assumes Row.delete accepts a commit argument -- confirm.
        """
        Row.delete(self, commit)

    def get_network(self):
        """
        Return, as text, the network address obtained by applying this
        row's netmask to its ip_addr.
        """
        address = SimpleAddress(self["ip_addr"], self["type"])
        return address.compute_network(self["netmask"]).as_str()

    def get_broadcast(self):
        """
        Return, as text, the broadcast address obtained from this row's
        ip_addr and netmask.
        """
        address = SimpleAddress(self["ip_addr"], self["type"])
        return address.compute_broadcast(self["netmask"]).as_str()
192
class IpAddresses(Table):
    """
    Representation of row(s) from the ip_addresses table in the
    database.
    """

    def __init__(self, api, ip_address_filter = None, columns = None):
        """
        Select rows from view_ip_addresses, optionally restricted by
        ip_address_filter.

        ip_address_filter may be:
          - None: no restriction
          - a list, tuple or set mixing integers (matched against
            ip_address_id) and strings (matched against ip_addr)
          - a dict, handed to Filter over the row fields and tags
          - a single integer ip_address_id
          - a single ip_addr string
        Any other type raises PLCInvalidArgument.
        """
        Table.__init__(self, api, IpAddress, columns)

        # the view that we're selecting upon: start with view_ip_addresses
        view = "view_ip_addresses"
        # as many left joins as requested tags
        for tagname in self.tag_columns:
            view = "%s left join %s using (%s)" % (view,
                                                   IpAddress.tagvalue_view_name(tagname),
                                                   IpAddress.primary_key)

        selected = list(self.columns.keys()) + list(self.tag_columns.keys())
        sql = "SELECT %s FROM %s WHERE True" % (", ".join(selected), view)

        if ip_address_filter is not None:
            if isinstance(ip_address_filter, (list, tuple, set)):
                # Separate the list into integers and strings
                ints = [x for x in ip_address_filter if isinstance(x, (int, long))]
                strs = [x for x in ip_address_filter if isinstance(x, StringTypes)]
                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': ints, 'ip_addr': strs})
                sql += " AND (%s) %s" % ip_address_filter.sql(api, "OR")
            elif isinstance(ip_address_filter, dict):
                allowed_fields = dict(IpAddress.fields)
                allowed_fields.update(IpAddress.tags)
                ip_address_filter = Filter(allowed_fields, ip_address_filter)
                sql += " AND (%s) %s" % ip_address_filter.sql(api)
            elif isinstance(ip_address_filter, (int, long)):
                # accept long as well, matching the list branch above
                ip_address_filter = Filter(IpAddress.fields, {'ip_address_id': [ip_address_filter]})
                sql += " AND (%s) %s" % ip_address_filter.sql(api)
            elif isinstance(ip_address_filter, StringTypes):
                # the field dict lives on the row class (IpAddress), and the
                # address column is named 'ip_addr'
                ip_address_filter = Filter(IpAddress.fields, {'ip_addr': [ip_address_filter]})
                sql += " AND (%s) %s" % ip_address_filter.sql(api, "AND")
            else:
                raise PLCInvalidArgument("Wrong ip_address filter %r" % (ip_address_filter,))

        self.selectall(sql)