unset None fields, if allowed
[plcapi.git] / PLC / NodeNetworks.py
1 #
2 # Functions for interacting with the nodenetworks table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeNetworks.py,v 1.11 2006/10/25 14:29:13 mlhuang Exp $
8 #
9
10 from types import StringTypes
11 import socket
12 import struct
13
14 from PLC.Faults import *
15 from PLC.Parameter import Parameter
16 from PLC.Debug import profile
17 from PLC.Table import Row, Table
18 from PLC.NetworkTypes import NetworkType, NetworkTypes
19 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
20 import PLC.Nodes
21
22 def valid_ip(ip):
23     try:
24         ip = socket.inet_ntoa(socket.inet_aton(ip))
25         return True
26     except socket.error:
27         return False
28
29 def in_same_network(address1, address2, netmask):
30     """
31     Returns True if two IPv4 addresses are in the same network. Faults
32     if an address is invalid.
33     """
34
35     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
36     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
37     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
38
39     return (address1 & netmask) == (address2 & netmask)
40
41 class NodeNetwork(Row):
42     """
43     Representation of a row in the nodenetworks table. To use, optionally
44     instantiate with a dict of values. Update as you would a
45     dict. Commit to the database with sync().
46     """
47
48     table_name = 'nodenetworks'
49     primary_key = 'nodenetwork_id'
50     fields = {
51         'nodenetwork_id': Parameter(int, "Node interface identifier"),
52         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
53         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
54         'ip': Parameter(str, "IP address", nullok = True),
55         'mac': Parameter(str, "MAC address", nullok = True),
56         'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
57         'network': Parameter(str, "Subnet address", nullok = True),
58         'broadcast': Parameter(str, "Network broadcast address", nullok = True),
59         'netmask': Parameter(str, "Subnet mask", nullok = True),
60         'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
61         'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
62         'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
63         'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
64         'node_id': Parameter(int, "Node associated with this interface"),
65         'is_primary': Parameter(bool, "Is the primary interface for this node"),
66         }
67
68     def validate_method(self, method):
69         if method not in NetworkMethods(self.api):
70             raise PLCInvalidArgument, "Invalid addressing method"
71         return method
72
73     def validate_type(self, type):
74         if type not in NetworkTypes(self.api):
75             raise PLCInvalidArgument, "Invalid address type"
76         return type
77
78     def validate_ip(self, ip):
79         if ip and not valid_ip(ip):
80             raise PLCInvalidArgument, "Invalid IP address " + ip
81         return ip
82
83     def validate_mac(self, mac):
84         if not mac:
85             return mac
86
87         try:
88             bytes = mac.split(":")
89             if len(bytes) < 6:
90                 raise Exception
91             for i, byte in enumerate(bytes):
92                 byte = int(byte, 16)
93                 if byte < 0 or byte > 255:
94                     raise Exception
95                 bytes[i] = "%02x" % byte
96             mac = ":".join(bytes)
97         except:
98             raise PLCInvalidArgument, "Invalid MAC address"
99
100         return mac
101
102     validate_gateway = validate_ip
103     validate_network = validate_ip
104     validate_broadcast = validate_ip
105     validate_netmask = validate_ip
106     validate_dns1 = validate_ip
107     validate_dns2 = validate_ip
108
109     def validate_hostname(self, hostname):
110         # Optional
111         if not hostname:
112             return hostname
113
114         if not PLC.Nodes.valid_hostname(hostname):
115             raise PLCInvalidArgument, "Invalid hostname"
116
117         return hostname
118
119     def validate_node_id(self, node_id):
120         nodes = PLC.Nodes.Nodes(self.api, [node_id])
121         if not nodes:
122             raise PLCInvalidArgument, "No such node"
123
124         return node_id
125
126     def validate_is_primary(self, is_primary):
127         """
128         Set this interface to be the primary one.
129         """
130
131         if is_primary:
132             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']]).values()
133             if not nodes:
134                 raise PLCInvalidArgument, "No such node"
135             node = nodes[0]
136
137             if node['nodenetwork_ids']:
138                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
139                 for nodenetwork_id, nodenetwork in conflicts.iteritems():
140                     if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
141                        nodenetwork['is_primary']:
142                         raise PLCInvalidArgument, "Can only set one primary interface per node"
143
144         return is_primary
145
146     def validate(self):
147         """
148         Flush changes back to the database.
149         """
150
151         # Basic validation
152         Row.validate(self)
153
154         assert 'method' in self
155         method = self['method']
156
157         if method == "proxy" or method == "tap":
158             if 'mac' in self and self['mac']:
159                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
160             if 'ip' not in self or not self['ip']:
161                 raise PLCInvalidArgument, "For %s method, ip is required" % method
162             if method == "tap" and ('gateway' not in self or not self['gateway']):
163                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
164                       "the IP address of the node that proxies for this address"
165             # Should check that the proxy address is reachable, but
166             # there's no way to tell if the only primary interface is
167             # DHCP!
168
169         elif method == "static":
170             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
171                 if key not in self or not self[key]:
172                     raise PLCInvalidArgument, "For static method, %s is required" % key
173                 globals()[key] = self[key]
174             if not in_same_network(ip, network, netmask):
175                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
176                       (ip, network, netmask)
177             if not in_same_network(broadcast, network, netmask):
178                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
179                       (broadcast, network, netmask)
180             if not in_same_network(ip, gateway, netmask):
181                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
182                       (gateway, ip, netmask)
183
184         elif method == "ipmi":
185             if 'ip' not in self or not self['ip']:
186                 raise PLCInvalidArgument, "For ipmi method, ip is required"
187
188 class NodeNetworks(Table):
189     """
190     Representation of row(s) from the nodenetworks table in the
191     database.
192     """
193
194     def __init__(self, api, nodenetwork_id_or_ip_list = None):
195         self.api = api
196
197         sql = "SELECT %s FROM nodenetworks" % \
198               ", ".join(NodeNetwork.fields)
199
200         if nodenetwork_id_or_ip_list:
201             # Separate the list into integers and strings
202             nodenetwork_ids = filter(lambda nodenetwork_id: isinstance(nodenetwork_id, (int, long)),
203                                      nodenetwork_id_or_ip_list)
204             ips = filter(lambda ip: isinstance(ip, StringTypes),
205                                nodenetwork_id_or_ip_list)
206             sql += " WHERE (False"
207             if nodenetwork_ids:
208                 sql += " OR nodenetwork_id IN (%s)" % ", ".join(map(str, nodenetwork_ids))
209             if ips:
210                 sql += " OR ip IN (%s)" % ", ".join(api.db.quote(ips)).lower()
211             sql += ")"
212
213         rows = self.api.db.selectall(sql)
214
215         for row in rows:
216             self[row['nodenetwork_id']] = NodeNetwork(api, row)