allow tuples and sets as sequence filters
[plcapi.git] / PLC / NodeNetworks.py
1 #
2 # Functions for interacting with the nodenetworks table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id: NodeNetworks.py,v 1.13 2006/11/08 22:59:34 mlhuang Exp $
8 #
9
10 from types import StringTypes
11 import socket
12 import struct
13
14 from PLC.Faults import *
15 from PLC.Parameter import Parameter
16 from PLC.Filter import Filter
17 from PLC.Debug import profile
18 from PLC.Table import Row, Table
19 from PLC.NetworkTypes import NetworkType, NetworkTypes
20 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
21 import PLC.Nodes
22
23 def valid_ip(ip):
24     try:
25         ip = socket.inet_ntoa(socket.inet_aton(ip))
26         return True
27     except socket.error:
28         return False
29
30 def in_same_network(address1, address2, netmask):
31     """
32     Returns True if two IPv4 addresses are in the same network. Faults
33     if an address is invalid.
34     """
35
36     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
37     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
38     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
39
40     return (address1 & netmask) == (address2 & netmask)
41
42 class NodeNetwork(Row):
43     """
44     Representation of a row in the nodenetworks table. To use, optionally
45     instantiate with a dict of values. Update as you would a
46     dict. Commit to the database with sync().
47     """
48
49     table_name = 'nodenetworks'
50     primary_key = 'nodenetwork_id'
51     fields = {
52         'nodenetwork_id': Parameter(int, "Node interface identifier"),
53         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
54         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
55         'ip': Parameter(str, "IP address", nullok = True),
56         'mac': Parameter(str, "MAC address", nullok = True),
57         'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
58         'network': Parameter(str, "Subnet address", nullok = True),
59         'broadcast': Parameter(str, "Network broadcast address", nullok = True),
60         'netmask': Parameter(str, "Subnet mask", nullok = True),
61         'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
62         'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
63         'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
64         'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
65         'node_id': Parameter(int, "Node associated with this interface"),
66         'is_primary': Parameter(bool, "Is the primary interface for this node"),
67         }
68
69     def validate_method(self, method):
70         if method not in NetworkMethods(self.api):
71             raise PLCInvalidArgument, "Invalid addressing method"
72         return method
73
74     def validate_type(self, type):
75         if type not in NetworkTypes(self.api):
76             raise PLCInvalidArgument, "Invalid address type"
77         return type
78
79     def validate_ip(self, ip):
80         if ip and not valid_ip(ip):
81             raise PLCInvalidArgument, "Invalid IP address " + ip
82         return ip
83
84     def validate_mac(self, mac):
85         if not mac:
86             return mac
87
88         try:
89             bytes = mac.split(":")
90             if len(bytes) < 6:
91                 raise Exception
92             for i, byte in enumerate(bytes):
93                 byte = int(byte, 16)
94                 if byte < 0 or byte > 255:
95                     raise Exception
96                 bytes[i] = "%02x" % byte
97             mac = ":".join(bytes)
98         except:
99             raise PLCInvalidArgument, "Invalid MAC address"
100
101         return mac
102
103     validate_gateway = validate_ip
104     validate_network = validate_ip
105     validate_broadcast = validate_ip
106     validate_netmask = validate_ip
107     validate_dns1 = validate_ip
108     validate_dns2 = validate_ip
109
110     def validate_hostname(self, hostname):
111         # Optional
112         if not hostname:
113             return hostname
114
115         if not PLC.Nodes.valid_hostname(hostname):
116             raise PLCInvalidArgument, "Invalid hostname"
117
118         return hostname
119
120     def validate_node_id(self, node_id):
121         nodes = PLC.Nodes.Nodes(self.api, [node_id])
122         if not nodes:
123             raise PLCInvalidArgument, "No such node"
124
125         return node_id
126
127     def validate_is_primary(self, is_primary):
128         """
129         Set this interface to be the primary one.
130         """
131
132         if is_primary:
133             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']]).values()
134             if not nodes:
135                 raise PLCInvalidArgument, "No such node"
136             node = nodes[0]
137
138             if node['nodenetwork_ids']:
139                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
140                 for nodenetwork_id, nodenetwork in conflicts.iteritems():
141                     if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
142                        nodenetwork['is_primary']:
143                         raise PLCInvalidArgument, "Can only set one primary interface per node"
144
145         return is_primary
146
147     def validate(self):
148         """
149         Flush changes back to the database.
150         """
151
152         # Basic validation
153         Row.validate(self)
154
155         assert 'method' in self
156         method = self['method']
157
158         if method == "proxy" or method == "tap":
159             if 'mac' in self and self['mac']:
160                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
161             if 'ip' not in self or not self['ip']:
162                 raise PLCInvalidArgument, "For %s method, ip is required" % method
163             if method == "tap" and ('gateway' not in self or not self['gateway']):
164                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
165                       "the IP address of the node that proxies for this address"
166             # Should check that the proxy address is reachable, but
167             # there's no way to tell if the only primary interface is
168             # DHCP!
169
170         elif method == "static":
171             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
172                 if key not in self or not self[key]:
173                     raise PLCInvalidArgument, "For static method, %s is required" % key
174                 globals()[key] = self[key]
175             if not in_same_network(ip, network, netmask):
176                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
177                       (ip, network, netmask)
178             if not in_same_network(broadcast, network, netmask):
179                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
180                       (broadcast, network, netmask)
181             if not in_same_network(ip, gateway, netmask):
182                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
183                       (gateway, ip, netmask)
184
185         elif method == "ipmi":
186             if 'ip' not in self or not self['ip']:
187                 raise PLCInvalidArgument, "For ipmi method, ip is required"
188
189 class NodeNetworks(Table):
190     """
191     Representation of row(s) from the nodenetworks table in the
192     database.
193     """
194
195     def __init__(self, api, nodenetwork_filter = None):
196         Table.__init__(self, api, NodeNetwork)
197
198         sql = "SELECT %s FROM nodenetworks WHERE True" % \
199               ", ".join(NodeNetwork.fields)
200
201         if nodenetwork_filter is not None:
202             if isinstance(nodenetwork_filter, (list, tuple, set)):
203                 nodenetwork_filter = Filter(NodeNetwork.fields, {'nodenetwork_id': nodenetwork_filter})
204             elif isinstance(nodenetwork_filter, dict):
205                 nodenetwork_filter = Filter(NodeNetwork.fields, nodenetwork_filter)
206             sql += " AND (%s)" % nodenetwork_filter.sql(api)
207
208         self.selectall(sql)