start to define an abstract interface for the driver
[sfa.git] / sfa / plc / pldriver.py
index af3b213..2b9607e 100644 (file)
@@ -4,9 +4,12 @@ from sfa.util.sfalogging import logger
 from sfa.util.table import SfaTable
 from sfa.util.defaultdict import defaultdict
 
-from sfa.util.xrn import hrn_to_urn
+from sfa.util.xrn import hrn_to_urn, get_leaf
 from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base
 
+# the driver interface, mostly provides default behaviours
+from sfa.managers.driver import Driver
+
 from sfa.plc.plshell import PlShell
 
 def list_to_dict(recs, key):
@@ -17,7 +20,19 @@ def list_to_dict(recs, key):
     keys = [rec[key] for rec in recs]
     return dict(zip(keys, recs))
 
-class PlDriver (PlShell):
+#
+# inheriting Driver is not very helpful in the PL case but
+# makes sense in the general case
+# 
+# PlShell is just an xmlrpc serverproxy where methods
+# can be sent as-is; it takes care of authentication
+# from the global config
+# 
+# so OTOH we inherit PlShell just so one can do driver.GetNodes
+# which would not make much sense in the context of other testbeds
+# so ultimately PlDriver might just as well drop the PlShell inheritance
+# 
+class PlDriver (Driver, PlShell):
 
     def __init__ (self, config):
         PlShell.__init__ (self, config)
@@ -31,62 +46,127 @@ class PlDriver (PlShell):
         assert (rspec_type == 'pl' or rspec_type == 'vini' or \
                     rspec_type == 'eucalyptus' or rspec_type == 'max')
 
+    ########## disabled users 
+    def is_enabled_entity (self, record, aggregates):
+        self.fill_record_info(record, api.aggregates)
+        if record['type'] == 'user':
+            return record['enabled']
+        # only users can be disabled
+        return True
+
+    ########## 
+    def register (self, hrn, sfa_record, pub_key):
+        type = sfa_record['type']
+        pl_record = self.sfa_fields_to_pl_fields(type, hrn, sfa_record)
+
+        if type == 'authority':
+            sites = self.GetSites([pl_record['login_base']])
+            if not sites:
+                pointer = self.AddSite(pl_record)
+            else:
+                pointer = sites[0]['site_id']
+
+        elif type == 'slice':
+            acceptable_fields=['url', 'instantiation', 'name', 'description']
+            for key in pl_record.keys():
+                if key not in acceptable_fields:
+                    pl_record.pop(key)
+            slices = self.GetSlices([pl_record['name']])
+            if not slices:
+                 pointer = self.AddSlice(pl_record)
+            else:
+                 pointer = slices[0]['slice_id']
+
+        elif type == 'user':
+            persons = self.GetPersons([sfa_record['email']])
+            if not persons:
+                pointer = self.AddPerson(dict(sfa_record))
+            else:
+                pointer = persons[0]['person_id']
+    
+            if 'enabled' in sfa_record and sfa_record['enabled']:
+                self.UpdatePerson(pointer, {'enabled': sfa_record['enabled']})
+            # add this person to the site only if she is being added for the first
+            # time by sfa and doesont already exist in plc
+            if not persons or not persons[0]['site_ids']:
+                login_base = get_leaf(sfa_record['authority'])
+                self.AddPersonToSite(pointer, login_base)
+    
+            # What roles should this user have?
+            self.AddRoleToPerson('user', pointer)
+            # Add the user's key
+            if pub_key:
+                self.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key})
+
+        elif type == 'node':
+            login_base = hrn_to_pl_login_base(sfa_record['authority'])
+            nodes = api.driver.GetNodes([pl_record['hostname']])
+            if not nodes:
+                pointer = api.driver.AddNode(login_base, pl_record)
+            else:
+                pointer = nodes[0]['node_id']
+    
+        return pointer
+        
+
     ##
     # Convert SFA fields to PLC fields for use when registering up updating
     # registry record in the PLC database
     #
-    # @param type type of record (user, slice, ...)
-    # @param hrn human readable name
-    # @param sfa_fields dictionary of SFA fields
-    # @param pl_fields dictionary of PLC fields (output)
-
-    def sfa_fields_to_pl_fields(self, type, hrn, record):
 
-        def convert_ints(tmpdict, int_fields):
-            for field in int_fields:
-                if field in tmpdict:
-                    tmpdict[field] = int(tmpdict[field])
+    def sfa_fields_to_pl_fields(self, type, hrn, sfa_record):
 
         pl_record = {}
-        #for field in record:
-        #    pl_record[field] = record[field]
  
         if type == "slice":
-            if not "instantiation" in pl_record:
-                pl_record["instantiation"] = "plc-instantiated"
             pl_record["name"] = hrn_to_pl_slicename(hrn)
-           if "url" in record:
-               pl_record["url"] = record["url"]
-           if "description" in record:
-               pl_record["description"] = record["description"]
-           if "expires" in record:
-               pl_record["expires"] = int(record["expires"])
+            if "instantiation" in sfa_record:
+                pl_record['instantiation']=sfa_record['instantiation']
+            else:
+                pl_record["instantiation"] = "plc-instantiated"
+           if "url" in sfa_record:
+               pl_record["url"] = sfa_record["url"]
+           if "description" in sfa_record:
+               pl_record["description"] = sfa_record["description"]
+           if "expires" in sfa_record:
+               pl_record["expires"] = int(sfa_record["expires"])
 
         elif type == "node":
             if not "hostname" in pl_record:
-                if not "hostname" in record:
+                # fetch from sfa_record
+                if "hostname" not in sfa_record:
                     raise MissingSfaInfo("hostname")
-                pl_record["hostname"] = record["hostname"]
-            if not "model" in pl_record:
+                pl_record["hostname"] = sfa_record["hostname"]
+            if "model" in sfa_record: 
+                pl_record["model"] = sfa_record["model"]
+            else:
                 pl_record["model"] = "geni"
 
         elif type == "authority":
             pl_record["login_base"] = hrn_to_pl_login_base(hrn)
-
-            if not "name" in pl_record:
+            if "name" not in sfa_record:
                 pl_record["name"] = hrn
-
-            if not "abbreviated_name" in pl_record:
+            if "abbreviated_name" not in sfa_record:
                 pl_record["abbreviated_name"] = hrn
-
-            if not "enabled" in pl_record:
+            if "enabled" not in sfa_record:
                 pl_record["enabled"] = True
-
-            if not "is_public" in pl_record:
+            if "is_public" not in sfa_record:
                 pl_record["is_public"] = True
 
         return pl_record
 
+    ####################
+    def fill_record_info(self, records, aggregates):
+        """
+        Given a (list of) SFA record, fill in the PLC specific 
+        and SFA specific fields in the record. 
+        """
+        if not isinstance(records, list):
+            records = [records]
+
+        self.fill_record_pl_info(records)
+        self.fill_record_sfa_info(records, aggregates)
+
     def fill_record_pl_info(self, records):
         """
         Fill in the planetlab specific fields of a SFA record. This
@@ -274,7 +354,7 @@ class PlDriver (PlShell):
         person_list, persons = [], {}
         person_list = table.find({'type': 'user', 'pointer': person_ids})
         # create a hrns keyed on the sfa record's pointer.
-        # Its possible for  multiple records to have the same pointer so
+        # Its possible for multiple records to have the same pointer so
         # the dict's value will be a list of hrns.
         persons = defaultdict(list)
         for person in person_list:
@@ -344,16 +424,15 @@ class PlDriver (PlShell):
                 # xxx TODO: PostalAddress, Phone
             record.update(sfa_info)
 
-    def fill_record_info(self, records, aggregates):
-        """
-        Given a SFA record, fill in the PLC specific and SFA specific
-        fields in the record. 
-        """
-        if not isinstance(records, list):
-            records = [records]
-
-        self.fill_record_pl_info(records)
-        self.fill_record_sfa_info(records, aggregates)
+    ####################
+    def update_membership(self, oldRecord, record):
+        if record.type == "slice":
+            self.update_membership_list(oldRecord, record, 'researcher',
+                                        self.AddPersonToSlice,
+                                        self.DeletePersonFromSlice)
+        elif record.type == "authority":
+            # xxx TODO
+            pass
 
     def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
         # get a list of the HRNs that are members of the old and new records
@@ -396,12 +475,3 @@ class PlDriver (PlShell):
         for personId in oldIdList:
             if not (personId in newIdList):
                 delFunc(personId, containerId)
-
-    def update_membership(self, oldRecord, record):
-        if record.type == "slice":
-            self.update_membership_list(oldRecord, record, 'researcher',
-                                        self.AddPersonToSlice,
-                                        self.DeletePersonFromSlice)
-        elif record.type == "authority":
-            # xxx TODO
-            pass