- support new schema
authorMark Huang <mlhuang@cs.princeton.edu>
Mon, 25 Sep 2006 14:55:43 +0000 (14:55 +0000)
committerMark Huang <mlhuang@cs.princeton.edu>
Mon, 25 Sep 2006 14:55:43 +0000 (14:55 +0000)
- use new table views instead of defining join_fields, extra_fields, and
  all_fields
- rename flush() to sync() to be like shelve
- validate_ip: can be blank (e.g., dns2)
- fix validation of ip parameters. Not sure why locals() doesn't work.
- deprecated node_nodenetworks join table

PLC/NodeNetworks.py

index f8ab85c..7f83448 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeNetworks.py,v 1.2 2006/09/15 20:31:59 tmack Exp $
+# $Id: NodeNetworks.py,v 1.3 2006/09/19 19:35:05 mlhuang Exp $
 #
 
 from types import StringTypes
@@ -33,7 +33,7 @@ class NodeNetwork(Row):
     """
     Representation of a row in the nodenetworks table. To use, optionally
     instantiate with a dict of values. Update as you would a
-    dict. Commit to the database with flush().
+    dict. Commit to the database with sync().
     """
 
     fields = {
@@ -51,17 +51,10 @@ class NodeNetwork(Row):
         # XXX Should be an int (bps)
         'bwlimit': Parameter(str, "Bandwidth limit"),
         'hostname': Parameter(str, "(Optional) Hostname"),
-        }
-
-    # These fields are derived from join tables and are not
-    # actually in the nodenetworks table.
-    join_fields = {
         'node_id': Parameter(int, "Node associated with this interface (if any)"),
         'is_primary': Parameter(bool, "Is the primary interface for this node"),
         }
 
-    all_fields = dict(join_fields.items() + fields.items())
-
     methods = ['static', 'dhcp', 'proxy', 'tap', 'ipmi', 'unknown']
 
     types = ['ipv4']
@@ -87,10 +80,11 @@ class NodeNetwork(Row):
        return type
 
     def validate_ip(self, ip):
-        try:
-            ip = socket.inet_ntoa(socket.inet_aton(ip))
-        except socket.error:
-            raise PLCInvalidArgument, "Invalid IP address " + ip
+        if ip:
+            try:
+                ip = socket.inet_ntoa(socket.inet_aton(ip))
+            except socket.error:
+                raise PLCInvalidArgument, "Invalid IP address " + ip
 
         return ip
 
@@ -130,7 +124,7 @@ class NodeNetwork(Row):
         # Validate hostname, and check for conflicts with a node hostname
         return PLC.Nodes.Node.validate_hostname(self, hostname)
 
-    def flush(self, commit = True):
+    def sync(self, commit = True):
         """
         Flush changes back to the database.
         """
@@ -145,11 +139,11 @@ class NodeNetwork(Row):
             raise PLCInvalidArgument, "method and type must both be specified"
 
         if method == "proxy" or method == "tap":
-            if 'mac' in self:
+            if 'mac' in self and self['mac']:
                 raise PLCInvalidArgument, "For %s method, mac should not be specified" % method
-            if 'ip' not in self:
+            if 'ip' not in self or not self['ip']:
                 raise PLCInvalidArgument, "For %s method, ip is required" % method
-            if method == "tap" and 'gateway' not in self:
+            if method == "tap" and ('gateway' not in self or not self['gateway']):
                 raise PLCInvalidArgument, "For tap method, gateway is required and should be " \
                       "the IP address of the node that proxies for this address"
             # Should check that the proxy address is reachable, but
@@ -158,9 +152,9 @@ class NodeNetwork(Row):
 
         elif method == "static":
             for key in ['ip', 'gateway', 'network', 'broadcast', 'netmask', 'dns1']:
-                if key not in self:
+                if key not in self or not self[key]:
                     raise PLCInvalidArgument, "For static method, %s is required" % key
-                locals()[key] = self[key]
+                globals()[key] = self[key]
             if not in_same_network(ip, network, netmask):
                 raise PLCInvalidArgument, "IP address %s is inconsistent with network %s/%s" % \
                       (ip, network, netmask)
@@ -172,7 +166,7 @@ class NodeNetwork(Row):
                       (gateway, ip, netmask)
 
         elif method == "ipmi":
-            if 'ip' not in self:
+            if 'ip' not in self or not self['ip']:
                 raise PLCInvalidArgument, "For ipmi method, ip is required"
 
         # Fetch a new nodenetwork_id if necessary
@@ -186,7 +180,8 @@ class NodeNetwork(Row):
             insert = False
 
         # Filter out fields that cannot be set or updated directly
-        fields = dict(filter(lambda (key, value): key in self.fields,
+        nodenetworks_fields = self.api.db.fields('nodenetworks')
+        fields = dict(filter(lambda (key, value): key in nodenetworks_fields,
                              self.items()))
 
         # Parameterize for safety
@@ -217,10 +212,9 @@ class NodeNetwork(Row):
         assert 'nodenetwork_id' in self
 
         # Delete ourself
-        for table in ['node_nodenetworks', 'nodenetworks']:
-            self.api.db.do("DELETE FROM %s" \
-                           " WHERE nodenetwork_id = %d" % \
-                           (table, self['nodenetwork_id']))
+        self.api.db.do("DELETE FROM nodenetworks" \
+                       " WHERE nodenetwork_id = %d" % \
+                       self['nodenetwork_id'])
         
         if commit:
             self.api.db.commit()
@@ -235,11 +229,7 @@ class NodeNetworks(Table):
         self.api = api
 
         # N.B.: Node IDs returned may be deleted.
-        sql = "SELECT nodenetworks.*" \
-              ", node_nodenetworks.node_id" \
-              ", node_nodenetworks.is_primary" \
-              " FROM nodenetworks" \
-              " LEFT JOIN node_nodenetworks USING (nodenetwork_id)"
+        sql = "SELECT * FROM nodenetworks"
 
         if nodenetwork_id_or_hostname_list:
             # Separate the list into integers and strings
@@ -255,9 +245,6 @@ class NodeNetworks(Table):
             sql += ")"
 
         rows = self.api.db.selectall(sql)
+
         for row in rows:
-            if self.has_key(row['nodenetwork_id']):
-                nodenetwork = self[row['nodenetwork_id']]
-                nodenetwork.update(row)
-            else:
-                self[row['nodenetwork_id']] = NodeNetwork(api, row)
+            self[row['nodenetwork_id']] = NodeNetwork(api, row)