start to define an abstract interface for the driver
authorThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Fri, 25 Nov 2011 17:26:07 +0000 (18:26 +0100)
committerThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Fri, 25 Nov 2011 17:26:07 +0000 (18:26 +0100)
sfa/managers/driver.py [new file with mode: 0644]
sfa/managers/registry_manager.py
sfa/plc/pldriver.py

diff --git a/sfa/managers/driver.py b/sfa/managers/driver.py
new file mode 100644 (file)
index 0000000..0709f1d
--- /dev/null
@@ -0,0 +1,21 @@
+# 
+# an attempt to document what a driver class should provide, 
+# and implement reasonable defaults
+#
+
+class Driver:
+    
+    def __init__ (self): pass
+
+    # redefine this if you want to check again records 
+    # when running GetCredential
+    # This is to reflect the 'enabled' user field in planetlab testbeds
+    # expected retcod boolean
+    def is_enabled_entity (self, record, aggregates) : return True
+    
+    # incoming record, as provided by the client to the Register API call
+    # expected retcod 'pointer'
+    # 'pointer' is typically an int db id, that makes sense in the testbed environment
+    # -1 if this feature is not relevant 
+    # here type will be 'authority'
+    def register (self, hrn, sfa_record, pub_key) : return -1
index 7b9ebc5..af0ccbb 100644 (file)
@@ -10,7 +10,7 @@ from sfa.util.faults import RecordNotFound, AccountNotEnabled, PermissionError,
 from sfa.util.prefixTree import prefixTree
 from sfa.util.record import SfaRecord
 from sfa.util.table import SfaTable
-from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn, urn_to_hrn
+from sfa.util.xrn import Xrn, get_authority, hrn_to_urn, urn_to_hrn
 from sfa.util.plxrn import hrn_to_pl_login_base
 from sfa.util.version import version_core
 
@@ -54,9 +54,7 @@ class RegistryManager:
     
         # verify_cancreate_credential requires that the member lists
         # (researchers, pis, etc) be filled in
-        api.driver.fill_record_info(record, api.aggregates)
-        if record['type']=='user':
-           if not record['enabled']:
+        if not api.driver.is_enabled_entity (record, api.aggregates):
               raise AccountNotEnabled(": PlanetLab account %s is not enabled. Please contact your site PI" %(record['email']))
     
         # get the callers gid
@@ -95,13 +93,14 @@ class RegistryManager:
     
     def Resolve(self, api, xrns, type=None, full=True):
     
-        # load all known registry names into a prefix tree and attempt to find
-        # the longest matching prefix
         if not isinstance(xrns, types.ListType):
+            xrns = [xrns]
+            # try to infer type if not set and we get a single input
             if not type:
                 type = Xrn(xrns).get_type()
-            xrns = [xrns]
         hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] 
+        # load all known registry names into a prefix tree and attempt to find
+        # the longest matching prefix
         # create a dict where key is a registry hrn and its value is a
         # hrns at that registry (determined by the known prefix tree).  
         xrn_dict = {}
@@ -133,22 +132,22 @@ class RegistryManager:
                 records.extend([SfaRecord(dict=record).as_dict() for record in peer_records])
     
         # try resolving the remaining unfound records at the local registry
-        remaining_hrns = set(hrns).difference([record['hrn'] for record in records])
-        # convert set to list
-        remaining_hrns = [hrn for hrn in remaining_hrns] 
+        local_hrns = list ( set(hrns).difference([record['hrn'] for record in records]) )
+        # 
         table = SfaTable()
-        local_records = table.findObjects({'hrn': remaining_hrns})
+        local_records = table.findObjects({'hrn': local_hrns})
+        # xxx driver todo
         if full:
             api.driver.fill_record_info(local_records, api.aggregates)
         
         # convert local record objects to dicts
         records.extend([dict(record) for record in local_records])
-        if not records:
-            raise RecordNotFound(str(hrns))
-    
         if type:
             records = filter(lambda rec: rec['type'] in [type], records)
     
+        if not records:
+            raise RecordNotFound(str(hrns))
+    
         return records
     
     def List(self, api, xrn, origin_hrn=None):
@@ -214,8 +213,6 @@ class RegistryManager:
            
         record = SfaRecord(dict = record)
         record['authority'] = get_authority(record['hrn'])
-        type = record['type']
-        hrn = record['hrn']
         auth_info = api.auth.get_auth_info(record['authority'])
         pub_key = None
         # make sure record has a gid
@@ -243,61 +240,17 @@ class RegistryManager:
             # get the GID from the newly created authority
             gid = auth_info.get_gid_object()
             record.set_gid(gid.save_to_string(save_parents=True))
-            pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record)
-            sites = api.driver.GetSites([pl_record['login_base']])
-            if not sites:
-                pointer = api.driver.AddSite(pl_record)
-            else:
-                pointer = sites[0]['site_id']
-    
-            record.set_pointer(pointer)
-            record['pointer'] = pointer
+            pointer = api.driver.register (hrn, record, pub_key)
     
         elif (type == "slice"):
-            acceptable_fields=['url', 'instantiation', 'name', 'description']
-            pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record)
-            for key in pl_record.keys():
-                if key not in acceptable_fields:
-                    pl_record.pop(key)
-            slices = api.driver.GetSlices([pl_record['name']])
-            if not slices:
-                 pointer = api.driver.AddSlice(pl_record)
-            else:
-                 pointer = slices[0]['slice_id']
-            record.set_pointer(pointer)
-            record['pointer'] = pointer
+            pointer = api.driver.register (hrn, record, pub_key)
     
         elif  (type == "user"):
-            persons = api.driver.GetPersons([record['email']])
-            if not persons:
-                pointer = api.driver.AddPerson(dict(record))
-            else:
-                pointer = persons[0]['person_id']
-    
-            if 'enabled' in record and record['enabled']:
-                api.driver.UpdatePerson(pointer, {'enabled': record['enabled']})
-            # add this persons to the site only if he is being added for the first
-            # time by sfa and doesont already exist in plc
-            if not persons or not persons[0]['site_ids']:
-                login_base = get_leaf(record['authority'])
-                api.driver.AddPersonToSite(pointer, login_base)
-    
-            # What roles should this user have?
-            api.driver.AddRoleToPerson('user', pointer)
-            # Add the user's key
-            if pub_key:
-                api.driver.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key})
+            pointer = api.driver.register (hrn, record, pub_key)
     
         elif (type == "node"):
-            pl_record = api.driver.sfa_fields_to_pl_fields(type, hrn, record)
-            login_base = hrn_to_pl_login_base(record['authority'])
-            nodes = api.driver.GetNodes([pl_record['hostname']])
-            if not nodes:
-                pointer = api.driver.AddNode(login_base, pl_record)
-            else:
-                pointer = nodes[0]['node_id']
-    
-        record['pointer'] = pointer
+            pointer = api.driver.register (hrn, record, pub_key)
+
         record.set_pointer(pointer)
         record_id = table.insert(record)
         record['record_id'] = record_id
index af3b213..2b9607e 100644 (file)
@@ -4,9 +4,12 @@ from sfa.util.sfalogging import logger
 from sfa.util.table import SfaTable
 from sfa.util.defaultdict import defaultdict
 
-from sfa.util.xrn import hrn_to_urn
+from sfa.util.xrn import hrn_to_urn, get_leaf
 from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base
 
+# the driver interface, mostly provides default behaviours
+from sfa.managers.driver import Driver
+
 from sfa.plc.plshell import PlShell
 
 def list_to_dict(recs, key):
@@ -17,7 +20,19 @@ def list_to_dict(recs, key):
     keys = [rec[key] for rec in recs]
     return dict(zip(keys, recs))
 
-class PlDriver (PlShell):
+#
+# inheriting Driver is not very helpful in the PL case but
+# makes sense in the general case
+# 
+# PlShell is just an xmlrpc serverproxy where methods
+# can be sent as-is; it takes care of authentication
+# from the global config
+# 
+# so OTOH we inherit PlShell just so one can do driver.GetNodes
+# which would not make much sense in the context of other testbeds
+# so ultimately PlDriver might just as well drop the PlShell inheritance
+# 
+class PlDriver (Driver, PlShell):
 
     def __init__ (self, config):
         PlShell.__init__ (self, config)
@@ -31,62 +46,127 @@ class PlDriver (PlShell):
         assert (rspec_type == 'pl' or rspec_type == 'vini' or \
                     rspec_type == 'eucalyptus' or rspec_type == 'max')
 
+    ########## disabled users 
+    def is_enabled_entity (self, record, aggregates):
+        self.fill_record_info(record, api.aggregates)
+        if record['type'] == 'user':
+            return record['enabled']
+        # only users can be disabled
+        return True
+
+    ########## 
+    def register (self, hrn, sfa_record, pub_key):
+        type = sfa_record['type']
+        pl_record = self.sfa_fields_to_pl_fields(type, hrn, sfa_record)
+
+        if type == 'authority':
+            sites = self.GetSites([pl_record['login_base']])
+            if not sites:
+                pointer = self.AddSite(pl_record)
+            else:
+                pointer = sites[0]['site_id']
+
+        elif type == 'slice':
+            acceptable_fields=['url', 'instantiation', 'name', 'description']
+            for key in pl_record.keys():
+                if key not in acceptable_fields:
+                    pl_record.pop(key)
+            slices = self.GetSlices([pl_record['name']])
+            if not slices:
+                 pointer = self.AddSlice(pl_record)
+            else:
+                 pointer = slices[0]['slice_id']
+
+        elif type == 'user':
+            persons = self.GetPersons([sfa_record['email']])
+            if not persons:
+                pointer = self.AddPerson(dict(sfa_record))
+            else:
+                pointer = persons[0]['person_id']
+    
+            if 'enabled' in sfa_record and sfa_record['enabled']:
+                self.UpdatePerson(pointer, {'enabled': sfa_record['enabled']})
+            # add this person to the site only if she is being added for the first
+            # time by sfa and doesont already exist in plc
+            if not persons or not persons[0]['site_ids']:
+                login_base = get_leaf(sfa_record['authority'])
+                self.AddPersonToSite(pointer, login_base)
+    
+            # What roles should this user have?
+            self.AddRoleToPerson('user', pointer)
+            # Add the user's key
+            if pub_key:
+                self.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key})
+
+        elif type == 'node':
+            login_base = hrn_to_pl_login_base(sfa_record['authority'])
+            nodes = api.driver.GetNodes([pl_record['hostname']])
+            if not nodes:
+                pointer = api.driver.AddNode(login_base, pl_record)
+            else:
+                pointer = nodes[0]['node_id']
+    
+        return pointer
+        
+
     ##
     # Convert SFA fields to PLC fields for use when registering up updating
     # registry record in the PLC database
     #
-    # @param type type of record (user, slice, ...)
-    # @param hrn human readable name
-    # @param sfa_fields dictionary of SFA fields
-    # @param pl_fields dictionary of PLC fields (output)
-
-    def sfa_fields_to_pl_fields(self, type, hrn, record):
 
-        def convert_ints(tmpdict, int_fields):
-            for field in int_fields:
-                if field in tmpdict:
-                    tmpdict[field] = int(tmpdict[field])
+    def sfa_fields_to_pl_fields(self, type, hrn, sfa_record):
 
         pl_record = {}
-        #for field in record:
-        #    pl_record[field] = record[field]
  
         if type == "slice":
-            if not "instantiation" in pl_record:
-                pl_record["instantiation"] = "plc-instantiated"
             pl_record["name"] = hrn_to_pl_slicename(hrn)
-           if "url" in record:
-               pl_record["url"] = record["url"]
-           if "description" in record:
-               pl_record["description"] = record["description"]
-           if "expires" in record:
-               pl_record["expires"] = int(record["expires"])
+            if "instantiation" in sfa_record:
+                pl_record['instantiation']=sfa_record['instantiation']
+            else:
+                pl_record["instantiation"] = "plc-instantiated"
+           if "url" in sfa_record:
+               pl_record["url"] = sfa_record["url"]
+           if "description" in sfa_record:
+               pl_record["description"] = sfa_record["description"]
+           if "expires" in sfa_record:
+               pl_record["expires"] = int(sfa_record["expires"])
 
         elif type == "node":
             if not "hostname" in pl_record:
-                if not "hostname" in record:
+                # fetch from sfa_record
+                if "hostname" not in sfa_record:
                     raise MissingSfaInfo("hostname")
-                pl_record["hostname"] = record["hostname"]
-            if not "model" in pl_record:
+                pl_record["hostname"] = sfa_record["hostname"]
+            if "model" in sfa_record: 
+                pl_record["model"] = sfa_record["model"]
+            else:
                 pl_record["model"] = "geni"
 
         elif type == "authority":
             pl_record["login_base"] = hrn_to_pl_login_base(hrn)
-
-            if not "name" in pl_record:
+            if "name" not in sfa_record:
                 pl_record["name"] = hrn
-
-            if not "abbreviated_name" in pl_record:
+            if "abbreviated_name" not in sfa_record:
                 pl_record["abbreviated_name"] = hrn
-
-            if not "enabled" in pl_record:
+            if "enabled" not in sfa_record:
                 pl_record["enabled"] = True
-
-            if not "is_public" in pl_record:
+            if "is_public" not in sfa_record:
                 pl_record["is_public"] = True
 
         return pl_record
 
+    ####################
+    def fill_record_info(self, records, aggregates):
+        """
+        Given a (list of) SFA record, fill in the PLC specific 
+        and SFA specific fields in the record. 
+        """
+        if not isinstance(records, list):
+            records = [records]
+
+        self.fill_record_pl_info(records)
+        self.fill_record_sfa_info(records, aggregates)
+
     def fill_record_pl_info(self, records):
         """
         Fill in the planetlab specific fields of a SFA record. This
@@ -274,7 +354,7 @@ class PlDriver (PlShell):
         person_list, persons = [], {}
         person_list = table.find({'type': 'user', 'pointer': person_ids})
         # create a hrns keyed on the sfa record's pointer.
-        # Its possible for  multiple records to have the same pointer so
+        # Its possible for multiple records to have the same pointer so
         # the dict's value will be a list of hrns.
         persons = defaultdict(list)
         for person in person_list:
@@ -344,16 +424,15 @@ class PlDriver (PlShell):
                 # xxx TODO: PostalAddress, Phone
             record.update(sfa_info)
 
-    def fill_record_info(self, records, aggregates):
-        """
-        Given a SFA record, fill in the PLC specific and SFA specific
-        fields in the record. 
-        """
-        if not isinstance(records, list):
-            records = [records]
-
-        self.fill_record_pl_info(records)
-        self.fill_record_sfa_info(records, aggregates)
+    ####################
+    def update_membership(self, oldRecord, record):
+        if record.type == "slice":
+            self.update_membership_list(oldRecord, record, 'researcher',
+                                        self.AddPersonToSlice,
+                                        self.DeletePersonFromSlice)
+        elif record.type == "authority":
+            # xxx TODO
+            pass
 
     def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
         # get a list of the HRNs that are members of the old and new records
@@ -396,12 +475,3 @@ class PlDriver (PlShell):
         for personId in oldIdList:
             if not (personId in newIdList):
                 delFunc(personId, containerId)
-
-    def update_membership(self, oldRecord, record):
-        if record.type == "slice":
-            self.update_membership_list(oldRecord, record, 'researcher',
-                                        self.AddPersonToSlice,
-                                        self.DeletePersonFromSlice)
-        elif record.type == "authority":
-            # xxx TODO
-            pass