- use base class __init__() and delete() implementations
[plcapi.git] / PLC / NodeNetworks.py
1 #
2 # Functions for interacting with the nodenetworks table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeNetworks.py,v 1.8 2006/10/16 18:23:53 mlhuang Exp $
8 #
9
10 from types import StringTypes
11 import socket
12 import struct
13
14 from PLC.Faults import *
15 from PLC.Parameter import Parameter
16 from PLC.Debug import profile
17 from PLC.Table import Row, Table
18 from PLC.NetworkTypes import NetworkType, NetworkTypes
19 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
20 import PLC.Nodes
21
22 def valid_ip(ip):
23     try:
24         ip = socket.inet_ntoa(socket.inet_aton(ip))
25         return True
26     except socket.error:
27         return False
28
29 def in_same_network(address1, address2, netmask):
30     """
31     Returns True if two IPv4 addresses are in the same network. Faults
32     if an address is invalid.
33     """
34
35     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
36     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
37     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
38
39     return (address1 & netmask) == (address2 & netmask)
40
41 class NodeNetwork(Row):
42     """
43     Representation of a row in the nodenetworks table. To use, optionally
44     instantiate with a dict of values. Update as you would a
45     dict. Commit to the database with sync().
46     """
47
48     table_name = 'nodenetworks'
49     primary_key = 'nodenetwork_id'
50     fields = {
51         'nodenetwork_id': Parameter(int, "Node interface identifier"),
52         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
53         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
54         'ip': Parameter(str, "IP address"),
55         'mac': Parameter(str, "MAC address"),
56         'gateway': Parameter(str, "IP address of primary gateway"),
57         'network': Parameter(str, "Subnet address"),
58         'broadcast': Parameter(str, "Network broadcast address"),
59         'netmask': Parameter(str, "Subnet mask"),
60         'dns1': Parameter(str, "IP address of primary DNS server"),
61         'dns2': Parameter(str, "IP address of secondary DNS server"),
62         'bwlimit': Parameter(int, "Bandwidth limit", min = 0),
63         'hostname': Parameter(str, "(Optional) Hostname"),
64         'node_id': Parameter(int, "Node associated with this interface (if any)"),
65         'is_primary': Parameter(bool, "Is the primary interface for this node"),
66         }
67
68     def validate_method(self, method):
69         if method not in NetworkMethods(self.api):
70             raise PLCInvalidArgument, "Invalid addressing method"
71         return method
72
73     def validate_type(self, type):
74         if type not in NetworkTypes(self.api):
75             raise PLCInvalidArgument, "Invalid address type"
76         return type
77
78     def validate_ip(self, ip):
79         if ip and not valid_ip(ip):
80             raise PLCInvalidArgument, "Invalid IP address " + ip
81         return ip
82
83     def validate_mac(self, mac):
84         try:
85             bytes = mac.split(":")
86             if len(bytes) < 6:
87                 raise Exception
88             for i, byte in enumerate(bytes):
89                 byte = int(byte, 16)
90                 if byte < 0 or byte > 255:
91                     raise Exception
92                 bytes[i] = "%02x" % byte
93             mac = ":".join(bytes)
94         except:
95             raise PLCInvalidArgument, "Invalid MAC address"
96
97         return mac
98
99     validate_gateway = validate_ip
100     validate_network = validate_ip
101     validate_broadcast = validate_ip
102     validate_netmask = validate_ip
103     validate_dns1 = validate_ip
104     validate_dns2 = validate_ip
105
106     def validate_hostname(self, hostname):
107         # Optional
108         if not hostname:
109             return hostname
110
111         if not PLC.Nodes.valid_hostname(hostname):
112             raise PLCInvalidArgument, "Invalid hostname"
113
114         return hostname
115
116     def validate_node_id(self, node_id):
117         nodes = PLC.Nodes.Nodes(self.api, [node_id])
118         if not nodes:
119             raise PLCInvalidArgument, "No such node"
120
121         return node_id
122
123     def validate_is_primary(self, is_primary):
124         """
125         Set this interface to be the primary one.
126         """
127
128         if is_primary:
129             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']]).values()
130             if not nodes:
131                 raise PLCInvalidArgument, "No such node"
132             node = nodes[0]
133
134             if node['nodenetwork_ids']:
135                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
136                 for nodenetwork_id, nodenetwork in conflicts.iteritems():
137                     if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
138                        nodenetwork['is_primary']:
139                         raise PLCInvalidArgument, "Can only set one primary interface per node"
140
141         return is_primary
142
143     def validate(self):
144         """
145         Flush changes back to the database.
146         """
147
148         # Basic validation
149         Row.validate(self)
150
151         assert 'method' in self
152         method = self['method']
153
154         if method == "proxy" or method == "tap":
155             if 'mac' in self and self['mac']:
156                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
157             if 'ip' not in self or not self['ip']:
158                 raise PLCInvalidArgument, "For %s method, ip is required" % method
159             if method == "tap" and ('gateway' not in self or not self['gateway']):
160                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
161                       "the IP address of the node that proxies for this address"
162             # Should check that the proxy address is reachable, but
163             # there's no way to tell if the only primary interface is
164             # DHCP!
165
166         elif method == "static":
167             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
168                 if key not in self or not self[key]:
169                     raise PLCInvalidArgument, "For static method, %s is required" % key
170                 globals()[key] = self[key]
171             if not in_same_network(ip, network, netmask):
172                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
173                       (ip, network, netmask)
174             if not in_same_network(broadcast, network, netmask):
175                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
176                       (broadcast, network, netmask)
177             if not in_same_network(ip, gateway, netmask):
178                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
179                       (gateway, ip, netmask)
180
181         elif method == "ipmi":
182             if 'ip' not in self or not self['ip']:
183                 raise PLCInvalidArgument, "For ipmi method, ip is required"
184
185 class NodeNetworks(Table):
186     """
187     Representation of row(s) from the nodenetworks table in the
188     database.
189     """
190
191     def __init__(self, api, nodenetwork_id_or_ip_list = None):
192         self.api = api
193
194         sql = "SELECT %s FROM nodenetworks" % \
195               ", ".join(NodeNetwork.fields)
196
197         if nodenetwork_id_or_ip_list:
198             # Separate the list into integers and strings
199             nodenetwork_ids = filter(lambda nodenetwork_id: isinstance(nodenetwork_id, (int, long)),
200                                      nodenetwork_id_or_ip_list)
201             ips = filter(lambda ip: isinstance(ip, StringTypes),
202                                nodenetwork_id_or_ip_list)
203             sql += " WHERE (False"
204             if nodenetwork_ids:
205                 sql += " OR nodenetwork_id IN (%s)" % ", ".join(map(str, nodenetwork_ids))
206             if ips:
207                 sql += " OR ip IN (%s)" % ", ".join(api.db.quote(ips)).lower()
208             sql += ")"
209
210         rows = self.api.db.selectall(sql)
211
212         for row in rows:
213             self[row['nodenetwork_id']] = NodeNetwork(api, row)