StringTypes has gone
[plcapi.git] / PLC / Interfaces.py
1 #
2 # Functions for interacting with the interfaces table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7
8 import socket
9 import struct
10
11 from PLC.Faults import *
12 from PLC.Parameter import Parameter
13 from PLC.Filter import Filter
14 from PLC.Debug import profile
15 from PLC.Table import Row, Table
16 from PLC.NetworkTypes import NetworkType, NetworkTypes
17 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
18 import PLC.Nodes
19
20 def valid_ipv4(ip):
21     try:
22         ip = socket.inet_ntoa(socket.inet_aton(ip))
23         return True
24     except socket.error:
25         return False
26
27 def valid_ipv6(ip):
28     try:
29         ip = socket.inet_ntop(socket.AF_INET6, socket.inet_pton(socket.AF_INET6, ip))
30         return True
31     except socket.error:
32         return False
33
34 def valid_ip(ip):
35     return valid_ipv4(ip) or valid_ipv6(ip)
36
37 def in_same_network_ipv4(address1, address2, netmask):
38     """
39     Returns True if two IPv4 addresses are in the same network. Faults
40     if an address is invalid.
41     """
42     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
43     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
44     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
45
46     return (address1 & netmask) == (address2 & netmask)
47
48 def in_same_network_ipv6(address1, address2, netmask):
49     """
50     Returns True if two IPv6 addresses are in the same network. Faults
51     if an address is invalid.
52     """
53     address1 = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, address1))[0]
54     address2 = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, address2))[0]
55     netmask = struct.unpack('>2Q', socket.inet_pton(socket.AF_INET6, netmask))[0]
56
57     return (address1 & netmask) == (address2 & netmask)
58
59 def in_same_network(address1, address2, netmask):
60     return in_same_network_ipv4(address1, address2, netmask) or \
61            in_same_network_ipv6(address1, address2, netmask)
62
63 class Interface(Row):
64     """
65     Representation of a row in the interfaces table. To use, optionally
66     instantiate with a dict of values. Update as you would a
67     dict. Commit to the database with sync().
68     """
69
70     table_name = 'interfaces'
71     primary_key = 'interface_id'
72     join_tables = ['interface_tag']
73     fields = {
74         'interface_id': Parameter(int, "Node interface identifier"),
75         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
76         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
77         'ip': Parameter(str, "IP address", nullok = True),
78         'mac': Parameter(str, "MAC address", nullok = True),
79         'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
80         'network': Parameter(str, "Subnet address", nullok = True),
81         'broadcast': Parameter(str, "Network broadcast address", nullok = True),
82         'netmask': Parameter(str, "Subnet mask", nullok = True),
83         'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
84         'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
85         'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
86         'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
87         'node_id': Parameter(int, "Node associated with this interface"),
88         'is_primary': Parameter(bool, "Is the primary interface for this node"),
89         'interface_tag_ids' : Parameter([int], "List of interface settings"),
90         'last_updated': Parameter(int, "Date and time when node entry was created", ro = True),
91         }
92
93     view_tags_name = "view_interface_tags"
94     tags = {}
95
96     def validate_method(self, method):
97         network_methods = [row['method'] for row in NetworkMethods(self.api)]
98         if method not in network_methods:
99             raise PLCInvalidArgument("Invalid addressing method %s"%method)
100         return method
101
102     def validate_type(self, type):
103         network_types = [row['type'] for row in NetworkTypes(self.api)]
104         if type not in network_types:
105             raise PLCInvalidArgument("Invalid address type %s"%type)
106         return type
107
108     def validate_ip(self, ip):
109         if ip and not valid_ip(ip):
110             raise PLCInvalidArgument("Invalid IP address %s"%ip)
111         return ip
112
113     def validate_mac(self, mac):
114         if not mac:
115             return mac
116
117         try:
118             bytes = mac.split(":")
119             if len(bytes) < 6:
120                 raise Exception
121             for i, byte in enumerate(bytes):
122                 byte = int(byte, 16)
123                 if byte < 0 or byte > 255:
124                     raise Exception
125                 bytes[i] = "%02x" % byte
126             mac = ":".join(bytes)
127         except:
128             raise PLCInvalidArgument("Invalid MAC address %s"%mac)
129
130         return mac
131
132     validate_gateway = validate_ip
133     validate_network = validate_ip
134     validate_broadcast = validate_ip
135     validate_netmask = validate_ip
136     validate_dns1 = validate_ip
137     validate_dns2 = validate_ip
138
139     def validate_bwlimit(self, bwlimit):
140         if not bwlimit:
141             return bwlimit
142
143         if bwlimit < 500000:
144             raise PLCInvalidArgument('Minimum bw is 500 kbs')
145
146         return bwlimit
147
148     def validate_hostname(self, hostname):
149         # Optional
150         if not hostname:
151             return hostname
152
153         if not PLC.Nodes.valid_hostname(hostname):
154             raise PLCInvalidArgument("Invalid hostname %s"%hostname)
155
156         return hostname
157
158     def validate_node_id(self, node_id):
159         nodes = PLC.Nodes.Nodes(self.api, [node_id])
160         if not nodes:
161             raise PLCInvalidArgument("No such node %d"%node_id)
162
163         return node_id
164
165     def validate_is_primary(self, is_primary):
166         """
167         Set this interface to be the primary one.
168         """
169
170         if is_primary:
171             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']])
172             if not nodes:
173                 raise PLCInvalidArgument("No such node %d"%node_id)
174             node = nodes[0]
175
176             if node['interface_ids']:
177                 conflicts = Interfaces(self.api, node['interface_ids'])
178                 for interface in conflicts:
179                     if ('interface_id' not in self or \
180                         self['interface_id'] != interface['interface_id']) and \
181                        interface['is_primary']:
182                         raise PLCInvalidArgument("Can only set one primary interface per node")
183
184         return is_primary
185
186     def validate(self):
187         """
188         Flush changes back to the database.
189         """
190
191         # Basic validation
192         Row.validate(self)
193
194         assert 'method' in self
195         method = self['method']
196
197         if method == "proxy" or method == "tap":
198             if 'mac' in self and self['mac']:
199                 raise PLCInvalidArgument("For %s method, mac should not be specified" % method)
200             if 'ip' not in self or not self['ip']:
201                 raise PLCInvalidArgument("For %s method, ip is required" % method)
202             if method == "tap" and ('gateway' not in self or not self['gateway']):
203                 raise PLCInvalidArgument("For tap method, gateway is required and should be " \
204                       "the IP address of the node that proxies for this address")
205             # Should check that the proxy address is reachable, but
206             # there's no way to tell if the only primary interface is
207             # DHCP!
208
209         elif method == "static":
210             if self['type'] == 'ipv4':
211                 for key in ['gateway', 'dns1']:
212                     if key not in self or not self[key]:
213                         if 'is_primary' in self and self['is_primary'] is True:
214                             raise PLCInvalidArgument("For static method primary network, %s is required" % key)
215                     else:
216                         globals()[key] = self[key]
217                 for key in ['ip', 'network', 'broadcast', 'netmask']:
218                     if key not in self or not self[key]:
219                         raise PLCInvalidArgument("For static method, %s is required" % key)
220                     globals()[key] = self[key]
221                 if not in_same_network(ip, network, netmask):
222                     raise PLCInvalidArgument("IP address %s is inconsistent with network %s/%s" % \
223                           (ip, network, netmask))
224                 if not in_same_network(broadcast, network, netmask):
225                     raise PLCInvalidArgument("Broadcast address %s is inconsistent with network %s/%s" % \
226                           (broadcast, network, netmask))
227                 if 'gateway' in globals() and not in_same_network(ip, gateway, netmask):
228                     raise PLCInvalidArgument("Gateway %s is not reachable from %s/%s" % \
229                           (gateway, ip, netmask))
230             elif self['type'] == 'ipv6':
231                 for key in ['ip', 'gateway']:
232                     if key not in self or not self[key]:
233                         raise PLCInvalidArgument("For static ipv6 method, %s is required" % key)
234                     globals()[key] = self[key]
235         elif method == "ipmi":
236             if 'ip' not in self or not self['ip']:
237                 raise PLCInvalidArgument("For ipmi method, ip is required")
238
239     validate_last_updated = Row.validate_timestamp
240
241     def update_timestamp(self, col_name, commit = True):
242         """
243         Update col_name field with current time
244         """
245
246         assert 'interface_id' in self
247         assert self.table_name
248
249         self.api.db.do("UPDATE %s SET %s = CURRENT_TIMESTAMP " % (self.table_name, col_name) + \
250                        " where interface_id = %d" % (self['interface_id']) )
251         self.sync(commit)
252
253     def update_last_updated(self, commit = True):
254         self.update_timestamp('last_updated', commit)
255
256     def delete(self,commit=True):
257         ### need to cleanup ilinks
258         self.api.db.do("DELETE FROM ilink WHERE src_interface_id=%d OR dst_interface_id=%d" % \
259                            (self['interface_id'],self['interface_id']))
260
261         Row.delete(self)
262
263 class Interfaces(Table):
264     """
265     Representation of row(s) from the interfaces table in the
266     database.
267     """
268
269     def __init__(self, api, interface_filter = None, columns = None):
270         Table.__init__(self, api, Interface, columns)
271
272         # the view that we're selecting upon: start with view_nodes
273         view = "view_interfaces"
274         # as many left joins as requested tags
275         for tagname in self.tag_columns:
276             view= "%s left join %s using (%s)"%(view,Interface.tagvalue_view_name(tagname),
277                                                 Interface.primary_key)
278
279         sql = "SELECT %s FROM %s WHERE True" % \
280             (", ".join(list(self.columns.keys())+list(self.tag_columns.keys())),view)
281
282         if interface_filter is not None:
283             if isinstance(interface_filter, (list, tuple, set)):
284                 # Separate the list into integers and strings
285                 ints = [x for x in interface_filter if isinstance(x, int)]
286                 strs = [x for x in interface_filter if isinstance(x, str)]
287                 interface_filter = Filter(Interface.fields, {'interface_id': ints, 'ip': strs})
288                 sql += " AND (%s) %s" % interface_filter.sql(api, "OR")
289             elif isinstance(interface_filter, dict):
290                 allowed_fields=dict(list(Interface.fields.items())+list(Interface.tags.items()))
291                 interface_filter = Filter(allowed_fields, interface_filter)
292                 sql += " AND (%s) %s" % interface_filter.sql(api)
293             elif isinstance(interface_filter, int):
294                 interface_filter = Filter(Interface.fields, {'interface_id': [interface_filter]})
295                 sql += " AND (%s) %s" % interface_filter.sql(api)
296             elif isinstance (interface_filter, str):
297                 interface_filter = Filter(Interface.fields, {'ip':[interface_filter]})
298                 sql += " AND (%s) %s" % interface_filter.sql(api, "AND")
299             else:
300                 raise PLCInvalidArgument("Wrong interface filter %r"%interface_filter)
301
302         self.selectall(sql)