- query network_{methods,types} tables when validating method and type
[plcapi.git] / PLC / NodeNetworks.py
1 #
2 # Functions for interacting with the nodenetworks table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeNetworks.py,v 1.5 2006/10/03 19:25:37 mlhuang Exp $
8 #
9
10 from types import StringTypes
11 import socket
12 import struct
13
14 from PLC.Faults import *
15 from PLC.Parameter import Parameter
16 from PLC.Debug import profile
17 from PLC.Table import Row, Table
18 from PLC.NetworkTypes import NetworkType, NetworkTypes
19 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
20 import PLC.Nodes
21
22 def in_same_network(address1, address2, netmask):
23     """
24     Returns True if two IPv4 addresses are in the same network. Faults
25     if an address is invalid.
26     """
27
28     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
29     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
30     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
31
32     return (address1 & netmask) == (address2 & netmask)
33
34 class NodeNetwork(Row):
35     """
36     Representation of a row in the nodenetworks table. To use, optionally
37     instantiate with a dict of values. Update as you would a
38     dict. Commit to the database with sync().
39     """
40
41     table_name = 'nodenetworks'
42     primary_key = 'nodenetwork_id'
43     fields = {
44         'nodenetwork_id': Parameter(int, "Node interface identifier"),
45         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
46         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
47         'ip': Parameter(str, "IP address"),
48         'mac': Parameter(str, "MAC address"),
49         'gateway': Parameter(str, "IP address of primary gateway"),
50         'network': Parameter(str, "Subnet address"),
51         'broadcast': Parameter(str, "Network broadcast address"),
52         'netmask': Parameter(str, "Subnet mask"),
53         'dns1': Parameter(str, "IP address of primary DNS server"),
54         'dns2': Parameter(str, "IP address of secondary DNS server"),
55         # XXX Should be an int (bps)
56         'bwlimit': Parameter(str, "Bandwidth limit"),
57         'hostname': Parameter(str, "(Optional) Hostname"),
58         'node_id': Parameter(int, "Node associated with this interface (if any)"),
59         'is_primary': Parameter(bool, "Is the primary interface for this node"),
60         }
61
62     bwlimits = ['-1',
63                 '100kbit', '250kbit', '500kbit',
64                 '1mbit', '2mbit', '5mbit',
65                 '10mbit', '20mbit', '50mbit',
66                 '100mbit']
67
68     def __init__(self, api, fields = {}):
69         Row.__init__(self, fields)
70         self.api = api
71
72     def validate_method(self, method):
73         if method not in NetworkMethods(self.api):
74             raise PLCInvalidArgument, "Invalid addressing method"
75         return method
76
77     def validate_type(self, type):
78         if type not in NetworkTypes(self.api):
79             raise PLCInvalidArgument, "Invalid address type"
80         return type
81
82     def validate_ip(self, ip):
83         if ip:
84             try:
85                 ip = socket.inet_ntoa(socket.inet_aton(ip))
86             except socket.error:
87                 raise PLCInvalidArgument, "Invalid IP address " + ip
88
89         return ip
90
91     def validate_mac(self, mac):
92         try:
93             bytes = mac.split(":")
94             if len(bytes) < 6:
95                 raise Exception
96             for i, byte in enumerate(bytes):
97                 byte = int(byte, 16)
98                 if byte < 0 or byte > 255:
99                     raise Exception
100                 bytes[i] = "%02x" % byte
101             mac = ":".join(bytes)
102         except:
103             raise PLCInvalidArgument, "Invalid MAC address"
104
105         return mac
106
107     validate_gateway = validate_ip
108     validate_network = validate_ip
109     validate_broadcast = validate_ip
110     validate_netmask = validate_ip
111     validate_dns1 = validate_ip
112     validate_dns2 = validate_ip
113
114     def validate_bwlimit(self, bwlimit):
115         if bwlimit not in self.bwlimits:
116             raise PLCInvalidArgument, "Invalid bandwidth limit"
117         return bwlimit
118
119     def validate_hostname(self, hostname):
120         # Optional
121         if not hostname:
122             return hostname
123
124         # Validate hostname, and check for conflicts with a node hostname
125         return PLC.Nodes.Node.validate_hostname(self, hostname)
126
127     def validate_node_id(self, node_id):
128         nodes = PLC.Nodes.Nodes(self.api, [node_id])
129         if not nodes:
130             raise PLCInvalidArgument, "No such node"
131
132         return node_id
133
134     def validate_is_primary(self, is_primary):
135         """
136         Set this interface to be the primary one.
137         """
138
139         if is_primary:
140             nodes = Nodes(self.api, [self['node_id']])
141             if not nodes:
142                 raise PLCInvalidArgument, "No such node"
143             node = nodes[0]
144
145             if node['nodenetwork_ids']:
146                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
147                 for nodenetwork_id, nodenetwork in conflicts.iteritems():
148                     if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
149                        nodenetwork['is_primary']:
150                         raise PLCInvalidArgument, "Can only set one primary interface per node"
151
152         return is_primary
153
154     def validate(self):
155         """
156         Flush changes back to the database.
157         """
158
159         # Basic validation
160         Row.validate(self)
161
162         assert 'method' in self
163         method = self['method']
164
165         if method == "proxy" or method == "tap":
166             if 'mac' in self and self['mac']:
167                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
168             if 'ip' not in self or not self['ip']:
169                 raise PLCInvalidArgument, "For %s method, ip is required" % method
170             if method == "tap" and ('gateway' not in self or not self['gateway']):
171                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
172                       "the IP address of the node that proxies for this address"
173             # Should check that the proxy address is reachable, but
174             # there's no way to tell if the only primary interface is
175             # DHCP!
176
177         elif method == "static":
178             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
179                 if key not in self or not self[key]:
180                     raise PLCInvalidArgument, "For static method, %s is required" % key
181                 globals()[key] = self[key]
182             if not in_same_network(ip, network, netmask):
183                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
184                       (ip, network, netmask)
185             if not in_same_network(broadcast, network, netmask):
186                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
187                       (broadcast, network, netmask)
188             if not in_same_network(ip, gateway, netmask):
189                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
190                       (gateway, ip, netmask)
191
192         elif method == "ipmi":
193             if 'ip' not in self or not self['ip']:
194                 raise PLCInvalidArgument, "For ipmi method, ip is required"
195
196     def delete(self, commit = True):
197         """
198         Delete existing nodenetwork.
199         """
200
201         assert 'nodenetwork_id' in self
202
203         # Delete ourself
204         self.api.db.do("DELETE FROM nodenetworks" \
205                        " WHERE nodenetwork_id = %d" % \
206                        self['nodenetwork_id'])
207         
208         if commit:
209             self.api.db.commit()
210
211 class NodeNetworks(Table):
212     """
213     Representation of row(s) from the nodenetworks table in the
214     database.
215     """
216
217     def __init__(self, api, nodenetwork_id_or_hostname_list = None):
218         self.api = api
219
220         sql = "SELECT %s FROM nodenetworks" % \
221               ", ".join(NodeNetwork.fields)
222
223         if nodenetwork_id_or_hostname_list:
224             # Separate the list into integers and strings
225             nodenetwork_ids = filter(lambda nodenetwork_id: isinstance(nodenetwork_id, (int, long)),
226                                      nodenetwork_id_or_hostname_list)
227             hostnames = filter(lambda hostname: isinstance(hostname, StringTypes),
228                                nodenetwork_id_or_hostname_list)
229             sql += " WHERE (False"
230             if nodenetwork_ids:
231                 sql += " OR nodenetwork_id IN (%s)" % ", ".join(map(str, nodenetwork_ids))
232             if hostnames:
233                 sql += " OR hostname IN (%s)" % ", ".join(api.db.quote(hostnames)).lower()
234             sql += ")"
235
236         rows = self.api.db.selectall(sql)
237
238         for row in rows:
239             self[row['nodenetwork_id']] = NodeNetwork(api, row)