support validating ipv6 addresses
[plcapi.git] / PLC / Interfaces.py
1 #
2 # Functions for interacting with the interfaces table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7
8 from types import StringTypes
9 import socket
10 import struct
11
12 from PLC.Faults import *
13 from PLC.Parameter import Parameter
14 from PLC.Filter import Filter
15 from PLC.Debug import profile
16 from PLC.Table import Row, Table
17 from PLC.NetworkTypes import NetworkType, NetworkTypes
18 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
19 import PLC.Nodes
20
21 def valid_ipv4(ip):
22     try:
23         ip = socket.inet_ntoa(socket.inet_aton(ip))
24         return True
25     except socket.error:
26         return False
27
28 def valid_ipv6(ip):
29     try:
30         ip = socket.inet_ntop(socket.AF_INET6, socket.inet_pton(socket.AF_INET6, ip)
31         return True
32     except socket.error:
33         return False   
34
35 def valid_ip(ip):
36     return valid_ipv4(ip) or valid_ipv6(ip)
37
38 def in_same_network_ipv4(address1, address2, netmask):
39     """
40     Returns True if two IPv4 addresses are in the same network. Faults
41     if an address is invalid.
42     """
43     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
44     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
45     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
46
47     return (address1 & netmask) == (address2 & netmask)
48
49 def in_same_network_ipv6(address1, address2, netmask):
50     """
51     Returns True if two IPv6 addresses are in the same network. Faults
52     if an address is invalid.
53     """
54     address1 = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, address1))[0]
55     address2 = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, address2))[0]
56     netmask = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, netmask))[0]
57
58     return (address1 & netmask) == (address2 & netmask)
59
60 def in_same_network(address1, address2, netmask):
61     return in_same_network_ipv4(address1, address2, netmask) or \
62            in_same_network_ipv6(address1, address2, netmask) 
63
64 class Interface(Row):
65     """
66     Representation of a row in the interfaces table. To use, optionally
67     instantiate with a dict of values. Update as you would a
68     dict. Commit to the database with sync().
69     """
70
71     table_name = 'interfaces'
72     primary_key = 'interface_id'
73     join_tables = ['interface_tag']
74     fields = {
75         'interface_id': Parameter(int, "Node interface identifier"),
76         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
77         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
78         'ip': Parameter(str, "IP address", nullok = True),
79         'mac': Parameter(str, "MAC address", nullok = True),
80         'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
81         'network': Parameter(str, "Subnet address", nullok = True),
82         'broadcast': Parameter(str, "Network broadcast address", nullok = True),
83         'netmask': Parameter(str, "Subnet mask", nullok = True),
84         'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
85         'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
86         'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
87         'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
88         'node_id': Parameter(int, "Node associated with this interface"),
89         'is_primary': Parameter(bool, "Is the primary interface for this node"),
90         'interface_tag_ids' : Parameter([int], "List of interface settings"),
91         'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
92         }
93
94     view_tags_name = "view_interface_tags"
95     tags = {}
96
97     def validate_method(self, method):
98         network_methods = [row['method'] for row in NetworkMethods(self.api)]
99         if method not in network_methods:
100             raise PLCInvalidArgument, "Invalid addressing method %s"%method
101         return method
102
103     def validate_type(self, type):
104         network_types = [row['type'] for row in NetworkTypes(self.api)]
105         if type not in network_types:
106             raise PLCInvalidArgument, "Invalid address type %s"%type
107         return type
108
109     def validate_ip(self, ip):
110         if ip and not valid_ip(ip):
111             raise PLCInvalidArgument, "Invalid IP address %s"%ip
112         return ip
113
114     def validate_mac(self, mac):
115         if not mac:
116             return mac
117
118         try:
119             bytes = mac.split(":")
120             if len(bytes) < 6:
121                 raise Exception
122             for i, byte in enumerate(bytes):
123                 byte = int(byte, 16)
124                 if byte < 0 or byte > 255:
125                     raise Exception
126                 bytes[i] = "%02x" % byte
127             mac = ":".join(bytes)
128         except:
129             raise PLCInvalidArgument, "Invalid MAC address %s"%mac
130
131         return mac
132
133     validate_gateway = validate_ip
134     validate_network = validate_ip
135     validate_broadcast = validate_ip
136     validate_netmask = validate_ip
137     validate_dns1 = validate_ip
138     validate_dns2 = validate_ip
139
140     def validate_bwlimit(self, bwlimit):
141         if not bwlimit:
142             return bwlimit
143
144         if bwlimit < 500000:
145             raise PLCInvalidArgument, 'Minimum bw is 500 kbs'
146
147         return bwlimit
148
149     def validate_hostname(self, hostname):
150         # Optional
151         if not hostname:
152             return hostname
153
154         if not PLC.Nodes.valid_hostname(hostname):
155             raise PLCInvalidArgument, "Invalid hostname %s"%hostname
156
157         return hostname
158
159     def validate_node_id(self, node_id):
160         nodes = PLC.Nodes.Nodes(self.api, [node_id])
161         if not nodes:
162             raise PLCInvalidArgument, "No such node %d"%node_id
163
164         return node_id
165
166     def validate_is_primary(self, is_primary):
167         """
168         Set this interface to be the primary one.
169         """
170
171         if is_primary:
172             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']])
173             if not nodes:
174                 raise PLCInvalidArgument, "No such node %d"%node_id
175             node = nodes[0]
176
177             if node['interface_ids']:
178                 conflicts = Interfaces(self.api, node['interface_ids'])
179                 for interface in conflicts:
180                     if ('interface_id' not in self or \
181                         self['interface_id'] != interface['interface_id']) and \
182                        interface['is_primary']:
183                         raise PLCInvalidArgument, "Can only set one primary interface per node"
184
185         return is_primary
186
187     def validate(self):
188         """
189         Flush changes back to the database.
190         """
191
192         # Basic validation
193         Row.validate(self)
194
195         assert 'method' in self
196         method = self['method']
197
198         if method == "proxy" or method == "tap":
199             if 'mac' in self and self['mac']:
200                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
201             if 'ip' not in self or not self['ip']:
202                 raise PLCInvalidArgument, "For %s method, ip is required" % method
203             if method == "tap" and ('gateway' not in self or not self['gateway']):
204                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
205                       "the IP address of the node that proxies for this address"
206             # Should check that the proxy address is reachable, but
207             # there's no way to tell if the only primary interface is
208             # DHCP!
209
210         elif method == "static":
211             for key in ['gateway', 'dns1']:
212                 if key not in self or not self[key]:
213                     if 'is_primary' in self and self['is_primary'] is True:
214                         raise PLCInvalidArgument, "For static method primary network, %s is required" % key
215                 else:
216                     globals()[key] = self[key]
217             for key in ['ip', 'network', 'broadcast', 'netmask']:
218                 if key not in self or not self[key]:
219                     raise PLCInvalidArgument, "For static method, %s is required" % key
220                 globals()[key] = self[key]
221             if not in_same_network(ip, network, netmask):
222                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
223                       (ip, network, netmask)
224             if not in_same_network(broadcast, network, netmask):
225                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
226                       (broadcast, network, netmask)
227             if 'gateway' in globals() and not in_same_network(ip, gateway, netmask):
228                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
229                       (gateway, ip, netmask)
230
231         elif method == "ipmi":
232             if 'ip' not in self or not self['ip']:
233                 raise PLCInvalidArgument, "For ipmi method, ip is required"
234
235     validate_last_updated = Row.validate_timestamp
236
237     def update_timestamp(self, col_name, commit = True):
238         """
239         Update col_name field with current time
240         """
241
242         assert 'interface_id' in self
243         assert self.table_name
244
245         self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
246                        " where interface_id = %d" % (self['interface_id']) )
247         self.sync(commit)
248
249     def update_last_updated(self, commit = True):
250         self.update_timestamp('last_updated', commit)
251
252     def delete(self,commit=True):
253         ### need to cleanup ilinks
254         self.api.db.do("DELETE FROM ilink WHERE src_interface_id=%d OR dst_interface_id=%d" % \
255                            (self['interface_id'],self['interface_id']))
256         
257         Row.delete(self)
258
259 class Interfaces(Table):
260     """
261     Representation of row(s) from the interfaces table in the
262     database.
263     """
264
265     def __init__(self, api, interface_filter = None, columns = None):
266         Table.__init__(self, api, Interface, columns)
267
268         # the view that we're selecting upon: start with view_nodes
269         view = "view_interfaces"
270         # as many left joins as requested tags
271         for tagname in self.tag_columns:
272             view= "%s left join %s using (%s)"%(view,Interface.tagvalue_view_name(tagname),
273                                                 Interface.primary_key)
274
275         sql = "SELECT %s FROM %s WHERE True" % \
276             (", ".join(self.columns.keys()+self.tag_columns.keys()),view)
277
278         if interface_filter is not None:
279             if isinstance(interface_filter, (list, tuple, set)):
280                 # Separate the list into integers and strings
281                 ints = filter(lambda x: isinstance(x, (int, long)), interface_filter)
282                 strs = filter(lambda x: isinstance(x, StringTypes), interface_filter)
283                 interface_filter = Filter(Interface.fields, {'interface_id': ints, 'ip': strs})
284                 sql += " AND (%s) %s" % interface_filter.sql(api, "OR")
285             elif isinstance(interface_filter, dict):
286                 allowed_fields=dict(Interface.fields.items()+Interface.tags.items())
287                 interface_filter = Filter(allowed_fields, interface_filter)
288                 sql += " AND (%s) %s" % interface_filter.sql(api)
289             elif isinstance(interface_filter, int):
290                 interface_filter = Filter(Interface.fields, {'interface_id': [interface_filter]})
291                 sql += " AND (%s) %s" % interface_filter.sql(api)
292             elif isinstance (interface_filter, StringTypes):
293                 interface_filter = Filter(Interface.fields, {'ip':[interface_filter]})
294                 sql += " AND (%s) %s" % interface_filter.sql(api, "AND")
295             else:
296                 raise PLCInvalidArgument, "Wrong interface filter %r"%interface_filter
297
298         self.selectall(sql)