No need to import sfa.plc.api just to get a connection to PLCAPI. Just use plcapi...
[plcapi.git] / PLC / SFA.py
index bd06772..333ee3d 100644 (file)
@@ -1,32 +1,50 @@
-from types import StringTypes
 import traceback
 from types import StringTypes
-import traceback
-
-class SFA:
+from PLC.Sites import Sites
+try:
+    from sfa.plc.sfaImport import sfaImport, cleanup_string
+    from sfa.util.debug import log
+    packages_imported = True
+except:
+    packages_imported = False
     
-    def __init__(self):
+
+def wrap_exception(method):
+    def wrap(*args, **kwds):
         try:
-            from sfa.plc.sfaImport import sfaImport
-            from sfa.plc.api import GeniAPI
-            from sfa.util.debug import log 
-            self.log = log
-            self.sfa = sfaImport()
-            geniapi = GeniAPI()
-            self.plcapi = geniapi.plshell
-            self.auth = geniapi.plauth
+            return method(*args, **kwds)
         except:
-            traceback.print_exc(file = self.log)
+            traceback.print_exc()
+    return wrap
 
-        if self.gimport.level1_auth:
-            self.authority = self.gimport.level1_auth
+def required_packages_imported(method):
+    def wrap(*args, **kwds):
+        if packages_imported:
+            return method(*args, **kwds)
         else:
-            self.authority = self.gimport.root_auth
+            return
+    return wrap         
 
+class SFA:
+    
+    @wrap_exception
+    @required_packages_imported
+    def __init__(self, api):
+        
+        self.api = api
+        self.sfa = sfaImport()
 
-    def get_login_base(site_id):
-        sites = self.plcapi.GetSites(self.auth, [site_id], ['login_base'])
-        login_base = sites
+        if self.sfa.level1_auth:
+            self.authority = self.sfa.level1_auth
+        else:
+            self.authority = self.sfa.root_auth
+
+
+    def get_login_base(self, site_id):
+        sites = Sites(self.api, [site_id], ['login_base'])
+        login_base = sites[0]['login_base']
+        return login_base
+        
 
     def get_login_bases(self, object):
         login_bases = []
@@ -46,6 +64,8 @@ class SFA:
 
         return login_bases
 
+    @wrap_exception
+    @required_packages_imported
     def update_record(self, object, type, login_bases = None):
         try:
             # determine this objects site and login_base
@@ -56,7 +76,7 @@ class SFA:
                 login_bases = [login_bases]
 
             for login_base in login_bases:
-                login_base = self.sfa.cleanup_string(login_base)
+                login_base = cleanup_string(login_base)
                 parent_hrn = self.authority + "." + login_base
                 if type in ['person']:
                     self.sfa.import_person(parent_hrn, object)
@@ -72,25 +92,27 @@ class SFA:
             for key in keys:
                 if object.has_key(key):
                     id = object[key]
-            traceback.print_exc(file = self.log)
-            print >> self.log, "Error importing %s record for %s into geni db: %s" % \
+            traceback.print_exc(file = log)
+            print >> log, "Error importing %s record for %s into geni db: %s" % \
                   (type, id, e.message)
 
+    @wrap_exception
+    @required_packages_imported
     def delete_record(self, object, type, login_base = None):
-        if not login_bases:
-            login_bases = get_login_bases(object)
+
+        if not login_base:
+            login_bases = self.get_login_bases(object)
+        else:
+            login_bases = [login_base]
 
         for login_base in login_bases:
-            login_base = self.sfa.cleanup_string(login_base)
+            login_base = cleanup_string(login_base)
             parent_hrn = self.authority + "." + login_base
             self.sfa.delete_record(parent_hrn, object, type)
 
     def update_site(self, site, login_base = None):
         self.update_record(site, 'site', login_base)
 
-    def update_site(self, site, login_base = None):
-        self.update_record(site, 'site', login_base)
-
     def update_node(self, node, login_base = None):
         self.update_record(node, 'node', login_base)