No need to import sfa.plc.api just to get a connection to PLCAPI. Just use plcapi...
authorTony Mack <tmack@cs.princeton.edu>
Thu, 13 Aug 2009 04:31:10 +0000 (04:31 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Thu, 13 Aug 2009 04:31:10 +0000 (04:31 +0000)
PLC/SFA.py

index 8363e43..333ee3d 100644 (file)
@@ -1,5 +1,13 @@
-from types import StringTypes
 import traceback
+from types import StringTypes
+from PLC.Sites import Sites
+try:
+    from sfa.plc.sfaImport import sfaImport, cleanup_string
+    from sfa.util.debug import log
+    packages_imported = True
+except:
+    packages_imported = False
+    
 
 def wrap_exception(method):
     def wrap(*args, **kwds):
@@ -7,27 +15,24 @@ def wrap_exception(method):
             return method(*args, **kwds)
         except:
             traceback.print_exc()
-    return wrap 
+    return wrap
+
+def required_packages_imported(method):
+    def wrap(*args, **kwds):
+        if packages_imported:
+            return method(*args, **kwds)
+        else:
+            return
+    return wrap         
 
 class SFA:
     
     @wrap_exception
-    def __init__(self):
-        try:
-            from sfa.plc.sfaImport import sfaImport, cleanup_string
-            from sfa.plc.api import GeniAPI
-            from sfa.util.debug import log
-            packages_imported = True
-        except:
-            packages_imported = False
-            traceback.print_exc()        
+    @required_packages_imported
+    def __init__(self, api):
         
-        self.cleanup_string = cleanup_string
-        self.log = log
+        self.api = api
         self.sfa = sfaImport()
-        geniapi = GeniAPI()
-        self.plcapi = geniapi.plshell
-        self.auth = geniapi.plauth
 
         if self.sfa.level1_auth:
             self.authority = self.sfa.level1_auth
@@ -36,7 +41,7 @@ class SFA:
 
 
     def get_login_base(self, site_id):
-        sites = self.plcapi.GetSites(self.auth, [site_id], ['login_base'])
+        sites = Sites(self.api, [site_id], ['login_base'])
         login_base = sites[0]['login_base']
         return login_base
         
@@ -60,6 +65,7 @@ class SFA:
         return login_bases
 
     @wrap_exception
+    @required_packages_imported
     def update_record(self, object, type, login_bases = None):
         try:
             # determine this objects site and login_base
@@ -70,7 +76,7 @@ class SFA:
                 login_bases = [login_bases]
 
             for login_base in login_bases:
-                login_base = self.cleanup_string(login_base)
+                login_base = cleanup_string(login_base)
                 parent_hrn = self.authority + "." + login_base
                 if type in ['person']:
                     self.sfa.import_person(parent_hrn, object)
@@ -86,11 +92,12 @@ class SFA:
             for key in keys:
                 if object.has_key(key):
                     id = object[key]
-            traceback.print_exc(file = self.log)
-            print >> self.log, "Error importing %s record for %s into geni db: %s" % \
+            traceback.print_exc(file = log)
+            print >> log, "Error importing %s record for %s into geni db: %s" % \
                   (type, id, e.message)
 
     @wrap_exception
+    @required_packages_imported
     def delete_record(self, object, type, login_base = None):
 
         if not login_base:
@@ -99,7 +106,7 @@ class SFA:
             login_bases = [login_base]
 
         for login_base in login_bases:
-            login_base = self.cleanup_string(login_base)
+            login_base = cleanup_string(login_base)
             parent_hrn = self.authority + "." + login_base
             self.sfa.delete_record(parent_hrn, object, type)