- update comment
[plcapi.git] / PLC / NodeGroups.py
index 0b893b3..340a74f 100644 (file)
@@ -4,7 +4,7 @@
 # Mark Huang <mlhuang@cs.princeton.edu>
 # Copyright (C) 2006 The Trustees of Princeton University
 #
-# $Id: NodeGroups.py,v 1.1 2006/09/06 15:36:07 mlhuang Exp $
+# $Id: NodeGroups.py,v 1.7 2006/09/14 15:45:24 tmack Exp $
 #
 
 from types import StringTypes
@@ -35,19 +35,65 @@ class NodeGroup(Row):
         'node_ids': Parameter([int], "List of nodes in this node group"),
         }
 
+    all_fields = dict(fields.items() + join_fields.items())
+
     def __init__(self, api, fields):
         Row.__init__(self, fields)
         self.api = api
 
     def validate_name(self, name):
-        conflicts = NodeGroups(self.api, [name])
-        for nodegroup_id in conflicts:
+       #remove leading and trailing spaces
+       name = name.strip()
+
+       #make sure name is not blank after we removed the spaces
+        if not len(name) > 0:
+                raise PLCInvalidArgument, "Invalid Node Group Name"
+       
+       #make sure name doenst alredy exist
+       conflicts = NodeGroups(self.api, [name])
+       for nodegroup_id in conflicts:
             if 'nodegroup_id' not in self or self['nodegroup_id'] != nodegroup_id:
-                raise PLCInvalidArgument, "Node group name already in use"
+               raise PLCInvalidArgument, "Node group name already in use"
+
+       return name
+       
+    def validate_description(self, description):
+       #remove trailing and leading spaces
+       description = description.strip()
+       
+       #make sure decription is not blank after we removed the spaces  
+       if not len(description) > 0:
+               raise PLCInvalidArgument, "Invalid Node Group Description"
+
+       return description
 
     def add_node(self, node, commit = True):
         """
-        Add existing node to specified nodegroup.
+        Add node to existing nodegroup.
+        """
+
+        assert 'nodegroup_id' in self
+        assert isinstance(node, Node)
+        assert 'node_id' in node
+
+        node_id = node['node_id']
+        nodegroup_id = self['nodegroup_id']
+        self.api.db.do("INSERT INTO nodegroup_nodes (nodegroup_id, node_id)" \
+                       " VALUES(%(nodegroup_id)d, %(node_id)d)",
+                       locals())
+
+        if commit:
+            self.api.db.commit()
+
+        if 'node_ids' in self and node_id not in self['node_ids']:
+            self['node_ids'].append(node_id)
+
+        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
+            node['nodegroup_ids'].append(nodegroup_id)
+
+    def remove_node(self, node, commit = True):
+        """
+        Remove node from existing nodegroup.
         """
 
         assert 'nodegroup_id' in self
@@ -62,7 +108,13 @@ class NodeGroup(Row):
 
         if commit:
             self.api.db.commit()
-        
+
+        if 'node_ids' in self and node_id not in self['node_ids']:
+            self['node_ids'].append(node_id)
+
+        if 'nodegroup_ids' in node and nodegroup_id not in node['nodegroup_ids']:
+            node['nodegroup_ids'].append(nodegroup_id)
+
     def flush(self, commit = True):
         """
         Flush changes back to the database.
@@ -93,7 +145,7 @@ class NodeGroup(Row):
             sql = "INSERT INTO nodegroups (%s) VALUES (%s)" % \
                   (", ".join(keys), ", ".join(values))
         else:
-            # Update existing row in sites table
+            # Update existing row in nodegroups table
             columns = ["%s = %s" % (key, value) for (key, value) in zip(keys, values)]
             sql = "UPDATE nodegroups SET " + \
                   ", ".join(columns) + \
@@ -103,6 +155,7 @@ class NodeGroup(Row):
 
         if commit:
             self.api.db.commit()
+           
 
     def delete(self, commit = True):
         """
@@ -110,7 +163,7 @@ class NodeGroup(Row):
         """
 
         assert 'nodegroup_id' in self
-
+       assert self is not {}
         # Delete ourself
         tables = ['nodegroup_nodes', 'override_bootscripts',
                   'conf_assoc', 'node_root_access']
@@ -123,8 +176,8 @@ class NodeGroup(Row):
 
         for table in tables:
             self.api.db.do("DELETE FROM %s" \
-                           " WHERE nodegroup_id = %(nodegroup_id)" % \
-                           table, self)
+                           " WHERE nodegroup_id = %d" % \
+                           (table, self['nodegroup_id']), self)
 
         if commit:
             self.api.db.commit()
@@ -136,7 +189,7 @@ class NodeGroups(Table):
     """
 
     def __init__(self, api, nodegroup_id_or_name_list = None):
-        self.api = api
+       self.api = api
 
         # N.B.: Node IDs returned may be deleted.
         sql = "SELECT nodegroups.*, nodegroup_nodes.node_id" \