move package imports out of __init__()
authorTony Mack <tmack@cs.princeton.edu>
Tue, 4 Aug 2009 01:14:46 +0000 (01:14 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Tue, 4 Aug 2009 01:14:46 +0000 (01:14 +0000)
PLC/Methods/AddNode.py
PLC/Methods/BootNotifyOwners.py
PLC/Methods/DeleteNode.py
PLC/Persons.py
PLC/SFA.py
PLC/__init__.py

index ef76fe2..cf9bdba 100644 (file)
@@ -12,6 +12,7 @@ from PLC.TagTypes import TagTypes
 from PLC.NodeTags import NodeTags
 from PLC.Methods.AddNodeTag import AddNodeTag
 from PLC.Methods.UpdateNodeTag import UpdateNodeTag
+from PLC.SFA import SFA
 
 can_update = ['hostname', 'node_type', 'boot_state', 'model', 'version']
 
@@ -45,7 +46,7 @@ class AddNode(Method):
         [native,tags,rejected]=Row.split_fields(node_fields,[Node.fields,Node.tags])
 
         # type checking
-        native = Row.check_fields (native, self.accepted_fields)
+        native = Row.check_fields(native, self.accepted_fields)
         if rejected:
             raise PLCInvalidArgument, "Cannot add Node with column(s) %r"%rejected
 
@@ -82,8 +83,12 @@ class AddNode(Method):
             else:
                 UpdateNodeTag(self.api).__call__(auth,node_tags[0]['node_tag_id'],value)
 
-       self.event_objects = {'Site': [site['site_id']],
+        self.event_objects = {'Site': [site['site_id']],
                             'Node': [node['node_id']]} 
-       self.message = "Node %s created" % node['node_id']
-
+        self.message = "Node %s created" % node['node_id']
+        
+        # sync with geni db
+        sfa = SFA()
+        sfa.update_record(node, 'node', site['login_base']) 
+        
         return node['node_id']
index 704c57b..3564814 100644 (file)
@@ -19,7 +19,7 @@ class BootNotifyOwners(Method):
     roles = ['node']
 
     accepts = [
-        Mixed(BootAuth(), SessionAuth()),
+        Auth(),
         Message.fields['message_id'],
         Parameter(int, "Notify PIs"),
         Parameter(int, "Notify technical contacts"),
index 5919f03..c97f20a 100644 (file)
@@ -5,6 +5,7 @@ from PLC.Method import Method
 from PLC.Parameter import Parameter, Mixed
 from PLC.Auth import Auth
 from PLC.Nodes import Node, Nodes
+from PLC.SFA import SFA
 
 class DeleteNode(Method):
     """
@@ -49,9 +50,12 @@ class DeleteNode(Method):
         site_id=node['site_id']
         node.delete()
 
-       # Logging variables
+        # Logging variables
         # it's not much use to attach to the node as it's going to vanish...
-       self.event_objects = {'Node': [node_id], 'Site': [site_id] }
-       self.message = "Node %d deleted" % node['node_id']
+        self.event_objects = {'Node': [node_id], 'Site': [site_id] }
+        self.message = "Node %d deleted" % node['node_id']
+
+        sfa = SFA()
+        sfa.delete_record(node, 'node')       
 
         return 1
index c2fd613..6feb270 100644 (file)
@@ -380,70 +380,15 @@ class Persons(Table):
     def __init__(self, api, person_filter = None, columns = None):
         Table.__init__(self, api, Person, columns)
 
-       foreign_fields = {'role_ids': ('role_id', 'person_role'),
-                         'roles': ('name', 'roles'),
-                          'site_ids': ('site_id', 'person_site'),
-                          'key_ids': ('key_id', 'person_key'),
-                          'slice_ids': ('slice_id', 'slice_person')
-                          }
-       foreign_keys = {}
-       db_fields = filter(lambda field: field not in foreign_fields.keys(), Person.fields.keys())
-       all_fields = db_fields + [value[0] for value in foreign_fields.values()]
-       fields = []
-       _select = "SELECT "
-       _from = " FROM persons "
-       _join = " LEFT JOIN peer_person USING (person_id) "  
-       _where = " WHERE deleted IS False "
-
-       if not columns:
-           # include all columns       
-           fields = all_fields
-           tables = [value[1] for value in foreign_fields.values()]
-           tables.sort()
-           for key in foreign_fields.keys():
-               foreign_keys[foreign_fields[key][0]] = key  
-           for table in tables:
-               if table in ['roles']:
-                   _join += " LEFT JOIN roles USING(role_id) "
-               else:   
-                   _join += " LEFT JOIN %s USING (person_id) " % (table)
-       else: 
-           tables = set()
-           columns = filter(lambda column: column in db_fields+foreign_fields.keys(), columns)
-           columns.sort()
-           for column in columns: 
-               if column in foreign_fields.keys():
-                   (field, table) = foreign_fields[column]
-                   foreign_keys[field] = column
-                   fields += [field]
-                   tables.add(table)
-                   if column in ['roles']:
-                       _join += " LEFT JOIN roles USING(role_id) "
-                   else:
-                       _join += " LEFT JOIN %s USING (person_id)" % \
-                               (foreign_fields[column][1])
-               
-               else:
-                   fields += [column]  
-       
-       # postgres will return timestamps as datetime objects. 
-       # XMLPRC cannot marshal datetime so convert to int
-       timestamps = ['date_created', 'last_updated', 'verification_expires']
-       for field in fields:
-           if field in timestamps:
-               fields[fields.index(field)] = \
-                "CAST(date_part('epoch', %s) AS bigint) AS %s" % (field, field)
-
-       _select += ", ".join(fields)
-       sql = _select + _from + _join + _where
-
-       # deal with filter                      
+       sql = "SELECT %s FROM view_persons WHERE deleted IS False" % \
+              ", ".join(self.columns)
+
         if person_filter is not None:
             if isinstance(person_filter, (list, tuple, set)):
                 # Separate the list into integers and strings
                 ints = filter(lambda x: isinstance(x, (int, long)), person_filter)
                 strs = filter(lambda x: isinstance(x, StringTypes), person_filter)
-                person_filter = Filter(Person.fields, {'person_id': ints, 'email': strs})
+                node_filter = Filter(Person.fields, {'person_id': ints, 'email': strs})
                 sql += " AND (%s) %s" % person_filter.sql(api, "OR")
             elif isinstance(person_filter, dict):
                 person_filter = Filter(Person.fields, person_filter)
@@ -457,29 +402,4 @@ class Persons(Table):
             else:
                 raise PLCInvalidArgument, "Wrong person filter %r"%person_filter
 
-       # aggregate data
-       all_persons = {}
-       for row in self.api.db.selectall(sql):
-           person_id = row['person_id']
-
-           if all_persons.has_key(person_id):
-               for (key, key_list) in foreign_keys.items():
-                   data = row.pop(key)
-                   row[key_list] = [data]
-                   if data and data not in all_persons[person_id][key_list]:
-                       all_persons[person_id][key_list].append(data)
-            else:
-               for key in foreign_keys.keys():
-                    value = row.pop(key)
-                   if value:   
-                       row[foreign_keys[key]] = [value]
-                   else:
-                       row[foreign_keys[key]] = []
-               if row: 
-                   all_persons[person_id] = row
-               
-       # populate self
-       for row in all_persons.values():
-           obj = self.classobj(self.api, row)
-            self.append(obj)
-
+        self.selectall(sql)
index bd06772..30ef40b 100644 (file)
@@ -1,32 +1,43 @@
 from types import StringTypes
 import traceback
-from types import StringTypes
-import traceback
+try:
+    from sfa.plc.sfaImport import sfaImport, cleanup_string
+    from sfa.plc.api import GeniAPI
+    from sfa.util.debug import log
+    packages_imported = True  
+except:
+    packages_imported = False
+    traceback.print_exc()
+
+def wrap_exception(method):
+    def wrap(*args, **kwds):
+        try:
+            return method(*args, **kwds)
+        except:
+            traceback.print_exc()
+    return wrap 
 
 class SFA:
     
+    @wrap_exception
     def __init__(self):
-        try:
-            from sfa.plc.sfaImport import sfaImport
-            from sfa.plc.api import GeniAPI
-            from sfa.util.debug import log 
-            self.log = log
-            self.sfa = sfaImport()
-            geniapi = GeniAPI()
-            self.plcapi = geniapi.plshell
-            self.auth = geniapi.plauth
-        except:
-            traceback.print_exc(file = self.log)
-
-        if self.gimport.level1_auth:
-            self.authority = self.gimport.level1_auth
+        self.log = log
+        self.sfa = sfaImport()
+        geniapi = GeniAPI()
+        self.plcapi = geniapi.plshell
+        self.auth = geniapi.plauth
+
+        if self.sfa.level1_auth:
+            self.authority = self.sfa.level1_auth
         else:
-            self.authority = self.gimport.root_auth
+            self.authority = self.sfa.root_auth
 
 
-    def get_login_base(site_id):
+    def get_login_base(self, site_id):
         sites = self.plcapi.GetSites(self.auth, [site_id], ['login_base'])
-        login_base = sites
+        login_base = sites[0]['login_base']
+        return login_base
+        
 
     def get_login_bases(self, object):
         login_bases = []
@@ -46,6 +57,7 @@ class SFA:
 
         return login_bases
 
+    @wrap_exception
     def update_record(self, object, type, login_bases = None):
         try:
             # determine this objects site and login_base
@@ -56,7 +68,7 @@ class SFA:
                 login_bases = [login_bases]
 
             for login_base in login_bases:
-                login_base = self.sfa.cleanup_string(login_base)
+                login_base = cleanup_string(login_base)
                 parent_hrn = self.authority + "." + login_base
                 if type in ['person']:
                     self.sfa.import_person(parent_hrn, object)
@@ -76,21 +88,22 @@ class SFA:
             print >> self.log, "Error importing %s record for %s into geni db: %s" % \
                   (type, id, e.message)
 
+    @wrap_exception
     def delete_record(self, object, type, login_base = None):
-        if not login_bases:
-            login_bases = get_login_bases(object)
+
+        if not login_base:
+            login_bases = self.get_login_bases(object)
+        else:
+            login_bases = [login_base]
 
         for login_base in login_bases:
-            login_base = self.sfa.cleanup_string(login_base)
+            login_base = cleanup_string(login_base)
             parent_hrn = self.authority + "." + login_base
             self.sfa.delete_record(parent_hrn, object, type)
 
     def update_site(self, site, login_base = None):
         self.update_record(site, 'site', login_base)
 
-    def update_site(self, site, login_base = None):
-        self.update_record(site, 'site', login_base)
-
     def update_node(self, node, login_base = None):
         self.update_record(node, 'node', login_base)
 
index 6096fb2..73385d7 100644 (file)
@@ -40,6 +40,7 @@ PyCurl
 Roles
 sendmail
 Sessions
+SFA
 Shell
 Sites
 SliceInstantiations