- fix Auth so that it parses AuthMethod and doles out the actual
[plcapi.git] / PLC / NodeNetworks.py
index 954be05..8733f8d 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeNetworks.py,v 1.6 2006/10/10 20:27:13 mlhuang Exp $
+# $Id: NodeNetworks.py,v 1.15 2006/11/09 19:43:55 mlhuang Exp $
 #
 
 from types import StringTypes
 #
 
 from types import StringTypes
@@ -13,6 +13,7 @@ import struct
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.NetworkTypes import NetworkType, NetworkTypes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.NetworkTypes import NetworkType, NetworkTypes
@@ -51,47 +52,41 @@ class NodeNetwork(Row):
         'nodenetwork_id': Parameter(int, "Node interface identifier"),
         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
         'nodenetwork_id': Parameter(int, "Node interface identifier"),
         'method': Parameter(str, "Addressing method (e.g., 'static' or 'dhcp')"),
         'type': Parameter(str, "Address type (e.g., 'ipv4')"),
-        'ip': Parameter(str, "IP address"),
-        'mac': Parameter(str, "MAC address"),
-        'gateway': Parameter(str, "IP address of primary gateway"),
-        'network': Parameter(str, "Subnet address"),
-        'broadcast': Parameter(str, "Network broadcast address"),
-        'netmask': Parameter(str, "Subnet mask"),
-        'dns1': Parameter(str, "IP address of primary DNS server"),
-        'dns2': Parameter(str, "IP address of secondary DNS server"),
-        # XXX Should be an int (bps)
-        'bwlimit': Parameter(str, "Bandwidth limit"),
-        'hostname': Parameter(str, "(Optional) Hostname"),
-        'node_id': Parameter(int, "Node associated with this interface (if any)"),
+        'ip': Parameter(str, "IP address", nullok = True),
+        'mac': Parameter(str, "MAC address", nullok = True),
+        'gateway': Parameter(str, "IP address of primary gateway", nullok = True),
+        'network': Parameter(str, "Subnet address", nullok = True),
+        'broadcast': Parameter(str, "Network broadcast address", nullok = True),
+        'netmask': Parameter(str, "Subnet mask", nullok = True),
+        'dns1': Parameter(str, "IP address of primary DNS server", nullok = True),
+        'dns2': Parameter(str, "IP address of secondary DNS server", nullok = True),
+        'bwlimit': Parameter(int, "Bandwidth limit", min = 0, nullok = True),
+        'hostname': Parameter(str, "(Optional) Hostname", nullok = True),
+        'node_id': Parameter(int, "Node associated with this interface"),
         'is_primary': Parameter(bool, "Is the primary interface for this node"),
         }
 
         'is_primary': Parameter(bool, "Is the primary interface for this node"),
         }
 
-    bwlimits = ['-1',
-                '100kbit', '250kbit', '500kbit',
-                '1mbit', '2mbit', '5mbit',
-                '10mbit', '20mbit', '50mbit',
-                '100mbit']
-
-    def __init__(self, api, fields = {}):
-        Row.__init__(self, fields)
-        self.api = api
-
     def validate_method(self, method):
     def validate_method(self, method):
-        if method not in NetworkMethods(self.api):
-            raise PLCInvalidArgument, "Invalid addressing method"
+        network_methods = [row['method'] for row in NetworkMethods(self.api)]
+        if method not in network_methods:
+            raise PLCInvalidArgument, "Invalid addressing method %s"%method
        return method
 
     def validate_type(self, type):
        return method
 
     def validate_type(self, type):
-        if type not in NetworkTypes(self.api):
-            raise PLCInvalidArgument, "Invalid address type"
+        network_types = [row['type'] for row in NetworkTypes(self.api)]
+        if type not in network_types:
+            raise PLCInvalidArgument, "Invalid address type %s"%type
        return type
 
     def validate_ip(self, ip):
         if ip and not valid_ip(ip):
        return type
 
     def validate_ip(self, ip):
         if ip and not valid_ip(ip):
-            raise PLCInvalidArgument, "Invalid IP address " + ip
+            raise PLCInvalidArgument, "Invalid IP address %s"%ip
         return ip
 
     def validate_mac(self, mac):
         return ip
 
     def validate_mac(self, mac):
+        if not mac:
+            return mac
+
         try:
             bytes = mac.split(":")
             if len(bytes) < 6:
         try:
             bytes = mac.split(":")
             if len(bytes) < 6:
@@ -103,7 +98,7 @@ class NodeNetwork(Row):
                 bytes[i] = "%02x" % byte
             mac = ":".join(bytes)
         except:
                 bytes[i] = "%02x" % byte
             mac = ":".join(bytes)
         except:
-            raise PLCInvalidArgument, "Invalid MAC address"
+            raise PLCInvalidArgument, "Invalid MAC address %s"%mac
 
         return mac
 
 
         return mac
 
@@ -114,36 +109,20 @@ class NodeNetwork(Row):
     validate_dns1 = validate_ip
     validate_dns2 = validate_ip
 
     validate_dns1 = validate_ip
     validate_dns2 = validate_ip
 
-    def validate_bwlimit(self, bwlimit):
-        if bwlimit not in self.bwlimits:
-            raise PLCInvalidArgument, "Invalid bandwidth limit"
-       return bwlimit
-
     def validate_hostname(self, hostname):
         # Optional
         if not hostname:
             return hostname
 
         if not PLC.Nodes.valid_hostname(hostname):
     def validate_hostname(self, hostname):
         # Optional
         if not hostname:
             return hostname
 
         if not PLC.Nodes.valid_hostname(hostname):
-            raise PLCInvalidArgument, "Invalid hostname"
-
-        conflicts = NodeNetworks(self.api, [hostname])
-        for nodenetwork_id, nodenetwork in conflicts.iteritems():
-            if 'nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id:
-                raise PLCInvalidArgument, "Hostname already in use"
-
-        # Check for conflicts with a node hostname
-        conflicts = PLC.Nodes.Nodes(self.api, [hostname])
-        for node_id in conflicts.iteritems():
-            if 'node_id' not in self or self['node_id'] != node_id:
-                raise PLCInvalidArgument, "Hostname already in use"
+            raise PLCInvalidArgument, "Invalid hostname %s"%hostname
 
         return hostname
 
     def validate_node_id(self, node_id):
         nodes = PLC.Nodes.Nodes(self.api, [node_id])
         if not nodes:
 
         return hostname
 
     def validate_node_id(self, node_id):
         nodes = PLC.Nodes.Nodes(self.api, [node_id])
         if not nodes:
-            raise PLCInvalidArgument, "No such node"
+            raise PLCInvalidArgument, "No such node %d"%node_id
 
         return node_id
 
 
         return node_id
 
@@ -153,15 +132,16 @@ class NodeNetwork(Row):
         """
 
         if is_primary:
         """
 
         if is_primary:
-            nodes = PLC.Nodes.Nodes(self.api, [self['node_id']]).values()
+            nodes = PLC.Nodes.Nodes(self.api, [self['node_id']])
             if not nodes:
             if not nodes:
-                raise PLCInvalidArgument, "No such node"
+                raise PLCInvalidArgument, "No such node %d"%node_id
             node = nodes[0]
 
             if node['nodenetwork_ids']:
                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
             node = nodes[0]
 
             if node['nodenetwork_ids']:
                 conflicts = NodeNetworks(self.api, node['nodenetwork_ids'])
-                for nodenetwork_id, nodenetwork in conflicts.iteritems():
-                    if ('nodenetwork_id' not in self or self['nodenetwork_id'] != nodenetwork_id) and \
+                for nodenetwork in conflicts:
+                    if ('nodenetwork_id' not in self or \
+                        self['nodenetwork_id'] != nodenetwork['nodenetwork_id']) and \
                        nodenetwork['is_primary']:
                         raise PLCInvalidArgument, "Can only set one primary interface per node"
 
                        nodenetwork['is_primary']:
                         raise PLCInvalidArgument, "Can only set one primary interface per node"
 
@@ -209,47 +189,23 @@ class NodeNetwork(Row):
             if 'ip' not in self or not self['ip']:
                 raise PLCInvalidArgument, "For ipmi method, ip is required"
 
             if 'ip' not in self or not self['ip']:
                 raise PLCInvalidArgument, "For ipmi method, ip is required"
 
-    def delete(self, commit = True):
-        """
-        Delete existing nodenetwork.
-        """
-
-        assert 'nodenetwork_id' in self
-
-        # Delete ourself
-        self.api.db.do("DELETE FROM nodenetworks" \
-                       " WHERE nodenetwork_id = %d" % \
-                       self['nodenetwork_id'])
-        
-        if commit:
-            self.api.db.commit()
-
 class NodeNetworks(Table):
     """
     Representation of row(s) from the nodenetworks table in the
     database.
     """
 
 class NodeNetworks(Table):
     """
     Representation of row(s) from the nodenetworks table in the
     database.
     """
 
-    def __init__(self, api, nodenetwork_id_or_hostname_list = None):
-        self.api = api
-
-        sql = "SELECT %s FROM nodenetworks" % \
-              ", ".join(NodeNetwork.fields)
-
-        if nodenetwork_id_or_hostname_list:
-            # Separate the list into integers and strings
-            nodenetwork_ids = filter(lambda nodenetwork_id: isinstance(nodenetwork_id, (int, long)),
-                                     nodenetwork_id_or_hostname_list)
-            hostnames = filter(lambda hostname: isinstance(hostname, StringTypes),
-                               nodenetwork_id_or_hostname_list)
-            sql += " WHERE (False"
-            if nodenetwork_ids:
-                sql += " OR nodenetwork_id IN (%s)" % ", ".join(map(str, nodenetwork_ids))
-            if hostnames:
-                sql += " OR hostname IN (%s)" % ", ".join(api.db.quote(hostnames)).lower()
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-
-        for row in rows:
-            self[row['nodenetwork_id']] = NodeNetwork(api, row)
+    def __init__(self, api, nodenetwork_filter = None, columns = None):
+        Table.__init__(self, api, NodeNetwork, columns)
+
+        sql = "SELECT %s FROM nodenetworks WHERE True" % \
+              ", ".join(self.columns)
+
+        if nodenetwork_filter is not None:
+            if isinstance(nodenetwork_filter, (list, tuple, set)):
+                nodenetwork_filter = Filter(NodeNetwork.fields, {'nodenetwork_id': nodenetwork_filter})
+            elif isinstance(nodenetwork_filter, dict):
+                nodenetwork_filter = Filter(NodeNetwork.fields, nodenetwork_filter)
+            sql += " AND (%s)" % nodenetwork_filter.sql(api)
+
+        self.selectall(sql)