- fix hostname checking
[plcapi.git] / PLC / NodeNetworks.py
1 #
2 # Functions for interacting with the nodenetworks table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeNetworks.py,v 1.6 2006/10/10 20:27:13 mlhuang Exp $
8 #
9
10 from types import StringTypes
11 import socket
12 import struct
13
14 from PLC.Faults import *
15 from PLC.Parameter import Parameter
16 from PLC.Debug import profile
17 from PLC.Table import Row, Table
18 from PLC.NetworkTypes import NetworkType, NetworkTypes
19 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
20 import PLC.Nodes
21
22 def valid_ip(ip):
23     try:
24         ip = socket.inet_ntoa(socket.inet_aton(ip))
25         return True
26     except socket.error:
27         return False
28
29 def in_same_network(address1, address2, netmask):
30     """
31     Returns True if two IPv4 addresses are in the same network. Faults
32     if an address is invalid.
33     """
34
35     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
36     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
37     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
38
39     return (address1 & netmask) == (address2 & netmask)
40
41 class NodeNetwork(Row):
42     """
43     Representation of a row in the nodenetworks table. To use, optionally
44     instantiate with a dict of values. Update as you would a
45     dict. Commit to the database with sync().
46     """
47
48     table_name = 'nodenetworks'
49     primary_key = 'nodenetwork_id'
50     fields = {
51         'nodenetwork_id': Parameter(int, "Node interface identifier"),
52         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
53         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
54         'ip': Parameter(str, "IP address"),
55         'mac': Parameter(str, "MAC address"),
56         'gateway': Parameter(str, "IP address of primary gateway"),
57         'network': Parameter(str, "Subnet address"),
58         'broadcast': Parameter(str, "Network broadcast address"),
59         'netmask': Parameter(str, "Subnet mask"),
60         'dns1': Parameter(str, "IP address of primary DNS server"),
61         'dns2': Parameter(str, "IP address of secondary DNS server"),
62         # XXX Should be an int (bps)
63         'bwlimit': Parameter(str, "Bandwidth limit"),
64         'hostname': Parameter(str, "(Optional) Hostname"),
65         'node_id': Parameter(int, "Node associated with this interface (if any)"),
66         'is_primary': Parameter(bool, "Is the primary interface for this node"),
67         }
68
69     bwlimits = ['-1',
70                 '100kbit', '250kbit', '500kbit',
71                 '1mbit', '2mbit', '5mbit',
72                 '10mbit', '20mbit', '50mbit',
73                 '100mbit']
74
75     def __init__(self, api, fields = {}):
76         Row.__init__(self, fields)
77         self.api = api
78
79     def validate_method(self, method):
80         if method not in NetworkMethods(self.api):
81             raise PLCInvalidArgument, "Invalid addressing method"
82         return method
83
84     def validate_type(self, type):
85         if type not in NetworkTypes(self.api):
86             raise PLCInvalidArgument, "Invalid address type"
87         return type
88
89     def validate_ip(self, ip):
90         if ip and not valid_ip(ip):
91             raise PLCInvalidArgument, "Invalid IP address " + ip
92         return ip
93
94     def validate_mac(self, mac):
95         try:
96             bytes = mac.split(":")
97             if len(bytes) < 6:
98                 raise Exception
99             for i, byte in enumerate(bytes):
100                 byte = int(byte, 16)
101                 if byte < 0 or byte > 255:
102                     raise Exception
103                 bytes[i] = "%02x" % byte
104             mac = ":".join(bytes)
105         except:
106             raise PLCInvalidArgument, "Invalid MAC address"
107
108         return mac
109
110     validate_gateway = validate_ip
111     validate_network = validate_ip
112     validate_broadcast = validate_ip
113     validate_netmask = validate_ip
114     validate_dns1 = validate_ip
115     validate_dns2 = validate_ip
116
117     def validate_bwlimit(self, bwlimit):
118         if bwlimit not in self.bwlimits:
119             raise PLCInvalidArgument, "Invalid bandwidth limit"
120         return bwlimit
121
122     def validate_hostname(self, hostname):
123         # Optional
124         if not hostname:
125             return hostname
126
127         if not PLC.Nodes.valid_hostname(hostname):
128             raise PLCInvalidArgument, "Invalid hostname"
129
130         conflicts = NodeNetworks(self.api, [hostname])
131         for nodenetwork_id, nodenetwork in conflicts.iteritems():
132             if 'nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id:
133                 raise PLCInvalidArgument, "Hostname already in use"
134
135         # Check for conflicts with a node hostname
136         conflicts = PLC.Nodes.Nodes(self.api, [hostname])
137         for node_id in conflicts.iteritems():
138             if 'node_id' not in self or self['node_id'] != node_id:
139                 raise PLCInvalidArgument, "Hostname already in use"
140
141         return hostname
142
143     def validate_node_id(self, node_id):
144         nodes = PLC.Nodes.Nodes(self.api, [node_id])
145         if not nodes:
146             raise PLCInvalidArgument, "No such node"
147
148         return node_id
149
150     def validate_is_primary(self, is_primary):
151         """
152         Set this interface to be the primary one.
153         """
154
155         if is_primary:
156             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']]).values()
157             if not nodes:
158                 raise PLCInvalidArgument, "No such node"
159             node = nodes[0]
160
161             if node['nodenetwork_ids']:
162                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
163                 for nodenetwork_id, nodenetwork in conflicts.iteritems():
164                     if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
165                        nodenetwork['is_primary']:
166                         raise PLCInvalidArgument, "Can only set one primary interface per node"
167
168         return is_primary
169
170     def validate(self):
171         """
172         Flush changes back to the database.
173         """
174
175         # Basic validation
176         Row.validate(self)
177
178         assert 'method' in self
179         method = self['method']
180
181         if method == "proxy" or method == "tap":
182             if 'mac' in self and self['mac']:
183                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
184             if 'ip' not in self or not self['ip']:
185                 raise PLCInvalidArgument, "For %s method, ip is required" % method
186             if method == "tap" and ('gateway' not in self or not self['gateway']):
187                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
188                       "the IP address of the node that proxies for this address"
189             # Should check that the proxy address is reachable, but
190             # there's no way to tell if the only primary interface is
191             # DHCP!
192
193         elif method == "static":
194             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
195                 if key not in self or not self[key]:
196                     raise PLCInvalidArgument, "For static method, %s is required" % key
197                 globals()[key] = self[key]
198             if not in_same_network(ip, network, netmask):
199                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
200                       (ip, network, netmask)
201             if not in_same_network(broadcast, network, netmask):
202                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
203                       (broadcast, network, netmask)
204             if not in_same_network(ip, gateway, netmask):
205                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
206                       (gateway, ip, netmask)
207
208         elif method == "ipmi":
209             if 'ip' not in self or not self['ip']:
210                 raise PLCInvalidArgument, "For ipmi method, ip is required"
211
212     def delete(self, commit = True):
213         """
214         Delete existing nodenetwork.
215         """
216
217         assert 'nodenetwork_id' in self
218
219         # Delete ourself
220         self.api.db.do("DELETE FROM nodenetworks" \
221                        " WHERE nodenetwork_id = %d" % \
222                        self['nodenetwork_id'])
223         
224         if commit:
225             self.api.db.commit()
226
227 class NodeNetworks(Table):
228     """
229     Representation of row(s) from the nodenetworks table in the
230     database.
231     """
232
233     def __init__(self, api, nodenetwork_id_or_hostname_list = None):
234         self.api = api
235
236         sql = "SELECT %s FROM nodenetworks" % \
237               ", ".join(NodeNetwork.fields)
238
239         if nodenetwork_id_or_hostname_list:
240             # Separate the list into integers and strings
241             nodenetwork_ids = filter(lambda nodenetwork_id: isinstance(nodenetwork_id, (int, long)),
242                                      nodenetwork_id_or_hostname_list)
243             hostnames = filter(lambda hostname: isinstance(hostname, StringTypes),
244                                nodenetwork_id_or_hostname_list)
245             sql += " WHERE (False"
246             if nodenetwork_ids:
247                 sql += " OR nodenetwork_id IN (%s)" % ", ".join(map(str, nodenetwork_ids))
248             if hostnames:
249                 sql += " OR hostname IN (%s)" % ", ".join(api.db.quote(hostnames)).lower()
250             sql += ")"
251
252         rows = self.api.db.selectall(sql)
253
254         for row in rows:
255             self[row['nodenetwork_id']] = NodeNetwork(api, row)