- use PLC_WWW_SSL_PORT
[plcapi.git] / PLC / NodeGroups.py
index 571e6fa..90bd5f9 100644 (file)
@@ -4,13 +4,14 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.9 2006/09/20 14:41:59 mlhuang Exp $
+# $Id: NodeGroups.py,v 1.18 2006/11/09 03:07:42 mlhuang Exp $
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
 #
 
 from types import StringTypes
 
 from PLC.Faults import *
 from PLC.Parameter import Parameter
+from PLC.Filter import Filter
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
 from PLC.Debug import profile
 from PLC.Table import Row, Table
 from PLC.Nodes import Node, Nodes
@@ -22,23 +23,20 @@ class NodeGroup(Row):
     dict. Commit to the database with sync().
     """
 
     dict. Commit to the database with sync().
     """
 
+    table_name = 'nodegroups'
+    primary_key = 'nodegroup_id'
+    join_tables = ['nodegroup_node', 'conf_file_nodegroup']
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
         'name': Parameter(str, "Node group name", max = 50),
     fields = {
         'nodegroup_id': Parameter(int, "Node group identifier"),
         'name': Parameter(str, "Node group name", max = 50),
-        'description': Parameter(str, "Node group description", max = 200),
+        'description': Parameter(str, "Node group description", max = 200, nullok = True),
         'node_ids': Parameter([int], "List of nodes in this node group"),
         'node_ids': Parameter([int], "List of nodes in this node group"),
+        'conf_file_ids': Parameter([int], "List of configuration files specific to this node group"),
         }
 
         }
 
-    def __init__(self, api, fields):
-        Row.__init__(self, fields)
-        self.api = api
-
     def validate_name(self, name):
     def validate_name(self, name):
-       # Remove leading and trailing spaces
-       name = name.strip()
-
-       # Make sure name is not blank after we removed the spaces
-        if not len(name) > 0:
+       # Make sure name is not blank
+        if not len(name):
                 raise PLCInvalidArgument, "Invalid node group name"
        
        # Make sure node group does not alredy exist
                 raise PLCInvalidArgument, "Invalid node group name"
        
        # Make sure node group does not alredy exist
@@ -60,17 +58,18 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
-                       " VALUES(%(nodegroup_id)d, %(node_id)d)",
-                       locals())
 
 
-        if commit:
-            self.api.db.commit()
+        if node_id not in self['node_ids']:
+            assert nodegroup_id not in node['nodegroup_ids']
 
 
-        if 'node_ids' in self and node_id not in self['node_ids']:
-            self['node_ids'].append(node_id)
+            self.api.db.do("INSERT INTO nodegroup_node (nodegroup_id, node_id)" \
+                           " VALUES(%(nodegroup_id)d, %(node_id)d)",
+                           locals())
 
 
-        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
+            if commit:
+                self.api.db.commit()
+
+            self['node_ids'].append(node_id)
             node['nodegroup_ids'].append(nodegroup_id)
 
     def remove_node(self, node, commit = True):
             node['nodegroup_ids'].append(nodegroup_id)
 
     def remove_node(self, node, commit = True):
@@ -84,78 +83,20 @@ class NodeGroup(Row):
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
 
         node_id = node['node_id']
         nodegroup_id = self['nodegroup_id']
-        self.api.db.do("DELETE FROM nodegroup_node" \
-                       " WHERE nodegroup_id = %(nodegroup_id)d" \
-                       " AND node_id = %(node_id)d",
-                       locals())
 
 
-        if commit:
-            self.api.db.commit()
+        if node_id in self['node_ids']:
+            assert nodegroup_id in node['nodegroup_ids']
 
 
-        if 'node_ids' in self and node_id in self['node_ids']:
-            self['node_ids'].remove(node_id)
+            self.api.db.do("DELETE FROM nodegroup_node" \
+                           " WHERE nodegroup_id = %(nodegroup_id)d" \
+                           " AND node_id = %(node_id)d",
+                           locals())
 
 
-        if 'nodegroup_ids' in node and nodegroup_id in node['nodegroup_ids']:
-            node['nodegroup_ids'].remove(nodegroup_id)
+            if commit:
+                self.api.db.commit()
 
 
-    def sync(self, commit = True):
-        """
-        Flush changes back to the database.
-        """
-
-        self.validate()
-
-        # Fetch a new nodegroup_id if necessary
-        if 'nodegroup_id' not in self:
-            rows = self.api.db.selectall("SELECT NEXTVAL('nodegroups_nodegroup_id_seq') AS nodegroup_id")
-            if not rows:
-                raise PLCDBError, "Unable to fetch new nodegroup_id"
-            self['nodegroup_id'] = rows[0]['nodegroup_id']
-            insert = True
-        else:
-            insert = False
-
-        # Filter out fields that cannot be set or updated directly
-        nodegroups_fields = self.api.db.fields('nodegroups')
-        fields = dict(filter(lambda (key, value): key in nodegroups_fields,
-                             self.items()))
-
-        # Parameterize for safety
-        keys = fields.keys()
-        values = [self.api.db.param(key, value) for (key, value) in fields.items()]
-
-        if insert:
-            # Insert new row in nodegroups table
-            sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
-                  (", ".join(keys), ", ".join(values))
-        else:
-            # Update existing row in nodegroups table
-            columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
-            sql = "UPDATE nodegroups SET " + \
-                  ", ".join(columns) + \
-                  " WHERE nodegroup_id = %(nodegroup_id)d"
-
-        self.api.db.do(sql, fields)
-
-        if commit:
-            self.api.db.commit()
-           
-
-    def delete(self, commit = True):
-        """
-        Delete existing nodegroup from the database.
-        """
-
-        assert 'nodegroup_id' in self
-
-        # Clean up miscellaneous join tables
-        for table in ['nodegroup_node', 'nodegroups']:
-            self.api.db.do("DELETE FROM %s" \
-                           " WHERE nodegroup_id = %d" % \
-                           (table, self['nodegroup_id']), self)
-
-        if commit:
-            self.api.db.commit()
+            self['node_ids'].remove(node_id)
+            node['nodegroup_ids'].remove(nodegroup_id)
 
 class NodeGroups(Table):
     """
 
 class NodeGroups(Table):
     """
@@ -163,30 +104,21 @@ class NodeGroups(Table):
     database.
     """
 
     database.
     """
 
-    def __init__(self, api, nodegroup_id_or_name_list = None):
-       self.api = api
-
-        sql = "SELECT * FROM view_nodegroups"
-
-        if nodegroup_id_or_name_list:
-            # Separate the list into integers and strings
-            nodegroup_ids = filter(lambda nodegroup_id: isinstance(nodegroup_id, (int, long)),
-                                   nodegroup_id_or_name_list)
-            names = filter(lambda name: isinstance(name, StringTypes),
-                           nodegroup_id_or_name_list)
-            sql += " WHERE (False"
-            if nodegroup_ids:
-                sql += " OR nodegroup_id IN (%s)" % ", ".join(map(str, nodegroup_ids))
-            if names:
-                sql += " OR name IN (%s)" % ", ".join(api.db.quote(names))
-            sql += ")"
-
-        rows = self.api.db.selectall(sql)
-
-        for row in rows:
-            self[row['nodegroup_id']] = nodegroup = NodeGroup(api, row)
-            for aggregate in ['node_ids']:
-                if not nodegroup.has_key(aggregate) or nodegroup[aggregate] is None:
-                    nodegroup[aggregate] = []
-                else:
-                    nodegroup[aggregate] = map(int, nodegroup[aggregate].split(','))
+    def __init__(self, api, nodegroup_filter = None, columns = None):
+        Table.__init__(self, api, NodeGroup, columns)
+
+        sql = "SELECT %s FROM view_nodegroups WHERE True" % \
+              ", ".join(self.columns)
+
+        if nodegroup_filter is not None:
+            if isinstance(nodegroup_filter, (list, tuple, set)):
+                # Separate the list into integers and strings
+                ints = filter(lambda x: isinstance(x, (int, long)), nodegroup_filter)
+                strs = filter(lambda x: isinstance(x, StringTypes), nodegroup_filter)
+                nodegroup_filter = Filter(NodeGroup.fields, {'nodegroup_id': ints, 'name': strs})
+                sql += " AND (%s)" % nodegroup_filter.sql(api, "OR")
+            elif isinstance(nodegroup_filter, dict):
+                nodegroup_filter = Filter(NodeGroup.fields, nodegroup_filter)
+                sql += " AND (%s)" % nodegroup_filter.sql(api, "AND")
+
+        self.selectall(sql)