remove unnecessary imports
[plcapi.git] / PLC / SFA.py
index bd06772..4fdaa1c 100644 (file)
@@ -1,32 +1,65 @@
-from types import StringTypes
 import traceback
 from types import StringTypes
-import traceback
-
-class SFA:
+from PLC.Sites import Sites
+try:
+    from sfa.util.geniclient import *
+    from sfa.util.config import *
+    from sfa.trust.credential import *         
+    from sfa.plc.sfaImport import cleanup_string
+    from sfa.util.record import *
+    from sfa.trust.hierarchy import *
+    from sfa.util.misc import *
+    packages_imported = True
+except:
+    packages_imported = False
     
-    def __init__(self):
+def wrap_exception(method):
+    def wrap(*args, **kwds):
         try:
-            from sfa.plc.sfaImport import sfaImport
-            from sfa.plc.api import GeniAPI
-            from sfa.util.debug import log 
-            self.log = log
-            self.sfa = sfaImport()
-            geniapi = GeniAPI()
-            self.plcapi = geniapi.plshell
-            self.auth = geniapi.plauth
+            return method(*args, **kwds)
         except:
-            traceback.print_exc(file = self.log)
+            traceback.print_exc()
+    return wrap
 
-        if self.gimport.level1_auth:
-            self.authority = self.gimport.level1_auth
+def required_packages_imported(method):
+    def wrap(*args, **kwds):
+        if packages_imported:
+            return method(*args, **kwds)
         else:
-            self.authority = self.gimport.root_auth
+            return
+    return wrap         
 
-
-    def get_login_base(site_id):
-        sites = self.plcapi.GetSites(self.auth, [site_id], ['login_base'])
-        login_base = sites
+class SFA:
+    
+    @wrap_exception
+    @required_packages_imported
+    def __init__(self, api):
+        
+        self.api = api
+        
+        # Get the path to the sfa server key/cert files from 
+        # the sfa hierarchy object
+        sfa_hierarchy = Hierarchy()
+        sfa_key_path = sfa_hierarchy.basedir
+        key_file = os.path.join(sfa_key_path, "server.key")
+        cert_file = os.path.join(sfa_key_path, "server.cert")
+    
+        # get a connection to our local sfa registry
+        # and a valid credential
+        config = Config()
+        self.authority = config.SFA_INTERFACE_HRN
+        url = 'http://%s:%s/' %(config.SFA_REGISTRY_HOST, config.SFA_REGISTRY_PORT) 
+        self.registry = GeniCleint(url, key_file, cert_file)
+        #self.sfa_api = GeniAPI(key_file = key_file, cert_file = cert_file)
+        #self.credential = self.sfa_api.getCredential()
+        cred_file = '/etc/sfa/slicemgr.plc.authority.cred'
+        self.credential = Credential(filename = cred_file)   
+
+    def get_login_base(self, site_id):
+        sites = Sites(self.api, [site_id], ['login_base'])
+        login_base = sites[0]['login_base']
+        return login_base
+        
 
     def get_login_bases(self, object):
         login_bases = []
@@ -38,7 +71,7 @@ class SFA:
         elif object.has_key('site_ids') and object['site_ids']:
             site_ids.extend(object['site_ids'])
         else:
-            raise Exception
+            return login_bases
 
         # get the login bases
         for site_id in site_ids:
@@ -46,70 +79,105 @@ class SFA:
 
         return login_bases
 
+    def get_object_hrn(self, type, object, authority, login_base):
+        parent_hrn = authority + "." + login_base 
+        if type in ['person', 'user']:
+            name_parts = object['email'].split("@")
+            hrn = parent_hrn + "." + name_parts[:1][0]
+        
+        elif type in ['slice']:
+            name_parts = object['name'].split("_")
+            hrn = parent_hrn + "." + name_parts[-1:][0]
+        
+        elif type in ['node']:
+            hrn = hostname_to_hrn(self.authority, login_base, object['hostname'])
+        
+        elif type in ['site', 'authority']:
+            hrn = parent_hrn
+        
+        else:
+            raise Exception, "Invalid record type %(type)s" % locals()
+
+        return hrn
+
+    def sfa_record_exists(self, hrn, type):
+        """
+        check if the record (hrn and type) already exist in our sfa db
+        """
+        exists = False
+        # list is quicker than resolve
+        parent_hrn = get_authority(hrn)
+        if not parent_hrn: parent_hrn = hrn
+        #records = self.registry.list(self.credential, parent_hrn)
+        records = self.registry.resolve(self.credential, hrn)
+        for record in records: 
+            if record['type'] == type and record['hrn'] == hrn:
+                exists = True
+        return exists 
+
+    @wrap_exception
+    @required_packages_imported
     def update_record(self, object, type, login_bases = None):
-        try:
-            # determine this objects site and login_base
-            if not login_bases:
-                login_bases = self.get_login_bases(object)
-
-            if isinstance(login_bases, StringTypes):
-                login_bases = [login_bases]
-
-            for login_base in login_bases:
-                login_base = self.sfa.cleanup_string(login_base)
-                parent_hrn = self.authority + "." + login_base
-                if type in ['person']:
-                    self.sfa.import_person(parent_hrn, object)
-                elif type in ['slice']:
-                    self.sfa.import_slice(parent_hrn, object)
-                elif type in ['node']:
-                    self.sfa.import_node(parent_hrn, object)
-                elif type in ['site']:
-                    self.sfa.import_site(self.authority, object)
-        except Exception, e:
-            id = None
-            keys = ['name', 'hostname', 'email', 'login_base']
-            for key in keys:
-                if object.has_key(key):
-                    id = object[key]
-            traceback.print_exc(file = self.log)
-            print >> self.log, "Error importing %s record for %s into geni db: %s" % \
-                  (type, id, e.message)
-
-    def delete_record(self, object, type, login_base = None):
+        # determine this objects site and login_base
         if not login_bases:
-            login_bases = get_login_bases(object)
+            login_bases = self.get_login_bases(object)
+
+        if isinstance(login_bases, StringTypes):
+            login_bases = [login_bases]
 
         for login_base in login_bases:
-            login_base = self.sfa.cleanup_string(login_base)
+            login_base = cleanup_string(login_base)
             parent_hrn = self.authority + "." + login_base
-            self.sfa.delete_record(parent_hrn, object, type)
-
-    def update_site(self, site, login_base = None):
-        self.update_record(site, 'site', login_base)
-
-    def update_site(self, site, login_base = None):
-        self.update_record(site, 'site', login_base)
-
-    def update_node(self, node, login_base = None):
-        self.update_record(node, 'node', login_base)
-
-    def update_slice(self, slice, login_base = None):
-        self.update_record(slice, 'slice', login_base)
+                        
+            if type in ['person']:
+                type = 'user'
+            elif type in ['site']:
+                type = 'authority'
+        
+            # set the object hrn, tpye and create the sfa record 
+            # object 
+            object['hrn'] = self.get_object_hrn(type, object, self.authority, login_base)   
+            object['type'] = type
+            if type in ['user']:
+                record = UserRecord(dict=object)
+
+            elif type in ['slice']:
+                record = SliceRecord(dict=object)
+
+            elif type in ['node']:
+                record = NodeRecord(dict=object)
+    
+            elif type in ['authority']:
+                record = AuthorityRecord(dict=object)
 
-    def update_person(self, person, login_base = None):
-        self.update_record(person, 'person', login_base)
+            else:
+                raise Exception, "Invalid record type %(type)s" % locals()
 
-    def delete_site(self, site, login_base = None):
-        site_name = site['login_base']
-        hrn = parent_hrn + site_name
-        self.delete_record(site, 'site', login_base)
+            # add the record to sfa
+            if not self.sfa_record_exists(object['hrn'], type):
+                self.registry.register(self.credential, record)
+            else:
+                self.registry.update(self.credential, record)
 
-    def delete_node(self, node, login_base = None):
-        self.delete_record(node, 'node', login_base)
+    @wrap_exception
+    @required_packages_imported
+    def delete_record(self, object, type, login_base = None):
+        if type in ['person']:
+            type = 'user'
+        elif type in ['site']:
+            type = 'authority'
+        
+        if type not in ['user', 'slice', 'node', 'authority']:
+            raise Exception, "Invalid type %(type)s" % locals()    
+     
+        if not login_base:
+            login_bases = self.get_login_bases(object)
+        else:
+            login_bases = [login_base]
 
-    def delete_slice(self, slice, login_base = None):
-        self.delete_record(slice, 'slice', login_base)
+        for login_base in login_bases:
+            login_base = cleanup_string(login_base)
+            hrn = self.get_object_hrn(type, object, self.authority, login_base)
+            if self.sfa_record_exists(hrn, type):
+                self.registry.remove(self.credential, type, hrn) 
 
-    def delete_person(self, person, login_base = None):
-        self.delete_record(person, 'person', login_base)