Use correct gateway address for secondary interfaces.
[plcapi.git] / PLC / Interfaces.py
1 #
2 # Functions for interacting with the interfaces table in the database
3 #
4 # Mark Huang <mlhuang@cs.princeton.edu>
5 # Copyright (C) 2006 The Trustees of Princeton University
6 #
7 # $Id$
8 # $URL$
9 #
10
11 from types import StringTypes
12 import socket
13 import struct
14
15 from PLC.Faults import *
16 from PLC.Parameter import Parameter
17 from PLC.Filter import Filter
18 from PLC.Debug import profile
19 from PLC.Table import Row, Table
20 from PLC.NetworkTypes import NetworkType, NetworkTypes
21 from PLC.NetworkMethods import NetworkMethod, NetworkMethods
22 import PLC.Nodes
23
24 def valid_ip(ip):
25     try:
26         ip = socket.inet_ntoa(socket.inet_aton(ip))
27         return True
28     except socket.error:
29         return False
30
31 def in_same_network(address1, address2, netmask):
32     """
33     Returns True if two IPv4 addresses are in the same network. Faults
34     if an address is invalid.
35     """
36
37     address1 = struct.unpack('>L', socket.inet_aton(address1))[0]
38     address2 = struct.unpack('>L', socket.inet_aton(address2))[0]
39     netmask = struct.unpack('>L', socket.inet_aton(netmask))[0]
40
41     return (address1 & netmask) == (address2 & netmask)
42
43 class Interface(Row):
44     """
45     Representation of a row in the interfaces table. To use, optionally
46     instantiate with a dict of values. Update as you would a
47     dict. Commit to the database with sync().
48     """
49
50     table_name = 'interfaces'
51     primary_key = 'interface_id'
52     join_tables = ['interface_tag']
53     fields = {
54         'interface_id': Parameter(int, "Node interface identifier"),
55         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
56         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
57         'ip': Parameter(str, "IP address", nullok = True),
58         'mac': Parameter(str, "MAC address", nullok = True),
59         'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
60         'network': Parameter(str, "Subnet address", nullok = True),
61         'broadcast': Parameter(str, "Network broadcast address", nullok = True),
62         'netmask': Parameter(str, "Subnet mask", nullok = True),
63         'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
64         'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
65         'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
66         'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
67         'node_id': Parameter(int, "Node associated with this interface"),
68         'is_primary': Parameter(bool, "Is the primary interface for this node"),
69         'interface_tag_ids' : Parameter([int], "List of interface settings"),
70         }
71
72     view_tags_name = "view_interface_tags"
73     tags = {}
74
75     def validate_method(self, method):
76         network_methods = [row['method'] for row in NetworkMethods(self.api)]
77         if method not in network_methods:
78             raise PLCInvalidArgument, "Invalid addressing method %s"%method
79         return method
80
81     def validate_type(self, type):
82         network_types = [row['type'] for row in NetworkTypes(self.api)]
83         if type not in network_types:
84             raise PLCInvalidArgument, "Invalid address type %s"%type
85         return type
86
87     def validate_ip(self, ip):
88         if ip and not valid_ip(ip):
89             raise PLCInvalidArgument, "Invalid IP address %s"%ip
90         return ip
91
92     def validate_mac(self, mac):
93         if not mac:
94             return mac
95
96         try:
97             bytes = mac.split(":")
98             if len(bytes) < 6:
99                 raise Exception
100             for i, byte in enumerate(bytes):
101                 byte = int(byte, 16)
102                 if byte < 0 or byte > 255:
103                     raise Exception
104                 bytes[i] = "%02x" % byte
105             mac = ":".join(bytes)
106         except:
107             raise PLCInvalidArgument, "Invalid MAC address %s"%mac
108
109         return mac
110
111     validate_gateway = validate_ip
112     validate_network = validate_ip
113     validate_broadcast = validate_ip
114     validate_netmask = validate_ip
115     validate_dns1 = validate_ip
116     validate_dns2 = validate_ip
117
118     def validate_bwlimit(self, bwlimit):
119         if not bwlimit:
120             return bwlimit
121
122         if bwlimit < 500000:
123             raise PLCInvalidArgument, 'Minimum bw is 500 kbs'
124
125         return bwlimit  
126
127     def validate_hostname(self, hostname):
128         # Optional
129         if not hostname:
130             return hostname
131
132         if not PLC.Nodes.valid_hostname(hostname):
133             raise PLCInvalidArgument, "Invalid hostname %s"%hostname
134
135         return hostname
136
137     def validate_node_id(self, node_id):
138         nodes = PLC.Nodes.Nodes(self.api, [node_id])
139         if not nodes:
140             raise PLCInvalidArgument, "No such node %d"%node_id
141
142         return node_id
143
144     def validate_is_primary(self, is_primary):
145         """
146         Set this interface to be the primary one.
147         """
148
149         if is_primary:
150             nodes = PLC.Nodes.Nodes(self.api, [self['node_id']])
151             if not nodes:
152                 raise PLCInvalidArgument, "No such node %d"%node_id
153             node = nodes[0]
154
155             if node['interface_ids']:
156                 conflicts = Interfaces(self.api, node['interface_ids'])
157                 for interface in conflicts:
158                     if ('interface_id' not in self or \
159                         self['interface_id'] != interface['interface_id']) and \
160                        interface['is_primary']:
161                         raise PLCInvalidArgument, "Can only set one primary interface per node"
162
163         return is_primary
164
165     def validate(self):
166         """
167         Flush changes back to the database.
168         """
169
170         # Basic validation
171         Row.validate(self)
172
173         assert 'method' in self
174         method = self['method']
175
176         if method == "proxy" or method == "tap":
177             if 'mac' in self and self['mac']:
178                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
179             if 'ip' not in self or not self['ip']:
180                 raise PLCInvalidArgument, "For %s method, ip is required" % method
181             if method == "tap" and ('gateway' not in self or not self['gateway']):
182                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
183                       "the IP address of the node that proxies for this address"
184             # Should check that the proxy address is reachable, but
185             # there's no way to tell if the only primary interface is
186             # DHCP!
187
188         elif method == "static":
189             for key in ['gateway', 'dns1']:
190                 if key not in self or not self[key]:
191                     if 'is_primary' in self and self['is_primary'] is True:
192                         raise PLCInvalidArgument, "For static method primary network, %s is required" % key
193                 else:
194                     globals()[key] = self[key]
195             for key in ['ip', 'network', 'broadcast', 'netmask']:
196                 if key not in self or not self[key]:
197                     raise PLCInvalidArgument, "For static method, %s is required" % key
198                 globals()[key] = self[key]
199             if not in_same_network(ip, network, netmask):
200                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
201                       (ip, network, netmask)
202             if not in_same_network(broadcast, network, netmask):
203                 raise PLCInvalidArgument, "Broadcast address %s is inconsistent with network %s/%s" % \
204                       (broadcast, network, netmask)
205             if 'gateway' in globals() and not in_same_network(ip, gateway, netmask):
206                 raise PLCInvalidArgument, "Gateway %s is not reachable from %s/%s" % \
207                       (gateway, ip, netmask)
208
209         elif method == "ipmi":
210             if 'ip' not in self or not self['ip']:
211                 raise PLCInvalidArgument, "For ipmi method, ip is required"
212
213 class Interfaces(Table):
214     """
215     Representation of row(s) from the interfaces table in the
216     database.
217     """
218
219     def __init__(self, api, interface_filter = None, columns = None):
220         Table.__init__(self, api, Interface, columns)
221
222         # the view that we're selecting upon: start with view_nodes
223         view = "view_interfaces"
224         # as many left joins as requested tags
225         for tagname in self.tag_columns:
226             view= "%s left join %s using (%s)"%(view,Interface.tagvalue_view_name(tagname),
227                                                 Interface.primary_key)
228             
229         sql = "SELECT %s FROM %s WHERE True" % \
230             (", ".join(self.columns.keys()+self.tag_columns.keys()),view)
231
232         if interface_filter is not None:
233             if isinstance(interface_filter, (list, tuple, set)):
234                 # Separate the list into integers and strings
235                 ints = filter(lambda x: isinstance(x, (int, long)), interface_filter)
236                 strs = filter(lambda x: isinstance(x, StringTypes), interface_filter)
237                 interface_filter = Filter(Interface.fields, {'interface_id': ints, 'ip': strs})
238                 sql += " AND (%s) %s" % interface_filter.sql(api, "OR")
239             elif isinstance(interface_filter, dict):
240                 allowed_fields=dict(Interface.fields.items()+Interface.tags.items())
241                 interface_filter = Filter(allowed_fields, interface_filter)
242                 sql += " AND (%s) %s" % interface_filter.sql(api)
243             elif isinstance(interface_filter, int):
244                 interface_filter = Filter(Interface.fields, {'interface_id': [interface_filter]})
245                 sql += " AND (%s) %s" % interface_filter.sql(api)
246             elif isinstance (interface_filter, StringTypes):
247                 interface_filter = Filter(Interface.fields, {'ip':[interface_filter]})
248                 sql += " AND (%s) %s" % interface_filter.sql(api, "AND")
249             else:
250                 raise PLCInvalidArgument, "Wrong interface filter %r"%interface_filter
251
252         self.selectall(sql)