split api and driver
[sfa.git] / sfa / plc / plcdriver.py
similarity index 82%
rename from sfa/plc/plcsfaapi.py
rename to sfa/plc/plcdriver.py
index 842df31..e920ce6 100644 (file)
@@ -1,4 +1,3 @@
-import xmlrpclib
 #
 from sfa.util.faults import MissingSfaInfo
 from sfa.util.sfalogging import logger
@@ -8,7 +7,7 @@ from sfa.util.defaultdict import defaultdict
 from sfa.util.xrn import hrn_to_urn
 from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base
 
-from sfa.server.sfaapi import SfaApi
+from sfa.plc.plcshell import PlcShell
 
 def list_to_dict(recs, key):
     """
@@ -18,45 +17,17 @@ def list_to_dict(recs, key):
     keys = [rec[key] for rec in recs]
     return dict(zip(keys, recs))
 
-class PlcSfaApi(SfaApi):
-
-    def __init__ (self, encoding="utf-8", methods='sfa.methods', 
-                  config = "/etc/sfa/sfa_config.py", 
-                  peer_cert = None, interface = None, 
-                  key_file = None, cert_file = None, cache = None):
-        SfaApi.__init__(self, encoding=encoding, methods=methods, 
-                        config=config, 
-                        peer_cert=peer_cert, interface=interface, 
-                        key_file=key_file, 
-                        cert_file=cert_file, cache=cache)
+class PlcDriver (PlcShell):
+
+    def __init__ (self, config):
+        PlcShell.__init__ (self, config)
  
+        self.hrn = config.SFA_INTERFACE_HRN
         self.SfaTable = SfaTable
         # Initialize the PLC shell only if SFA wraps a myPLC
-        rspec_type = self.config.get_aggregate_type()
-        if (rspec_type == 'pl' or rspec_type == 'vini' or \
-            rspec_type == 'eucalyptus' or rspec_type == 'max'):
-            self.plshell = self.getPLCShell()
-            self.plshell_version = "4.3"
-
-    def getPLCShell(self):
-        self.plauth = {'Username': self.config.SFA_PLC_USER,
-                       'AuthMethod': 'password',
-                       'AuthString': self.config.SFA_PLC_PASSWORD}
-
-        # The native shell (PLC.Shell.Shell) is more efficient than xmlrpc,
-        # but it leaves idle db connections open. use xmlrpc until we can figure
-        # out why PLC.Shell.Shell doesn't close db connection properly     
-        #try:
-        #    sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
-        #    self.plshell_type = 'direct'
-        #    import PLC.Shell
-        #    shell = PLC.Shell.Shell(globals = globals())
-        #except:
-        
-        self.plshell_type = 'xmlrpc' 
-        url = self.config.SFA_PLC_URL
-        shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
-        return shell
+        rspec_type = config.get_aggregate_type()
+        assert (rspec_type == 'pl' or rspec_type == 'vini' or \
+                    rspec_type == 'eucalyptus' or rspec_type == 'max')
 
     ##
     # Convert SFA fields to PLC fields for use when registering up updating
@@ -138,16 +109,16 @@ class PlcSfaApi(SfaApi):
         # get pl records
         nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
         if node_ids:
-            node_list = self.plshell.GetNodes(self.plauth, node_ids)
+            node_list = self.GetNodes(node_ids)
             nodes = list_to_dict(node_list, 'node_id')
         if site_ids:
-            site_list = self.plshell.GetSites(self.plauth, site_ids)
+            site_list = self.GetSites(site_ids)
             sites = list_to_dict(site_list, 'site_id')
         if slice_ids:
-            slice_list = self.plshell.GetSlices(self.plauth, slice_ids)
+            slice_list = self.GetSlices(slice_ids)
             slices = list_to_dict(slice_list, 'slice_id')
         if person_ids:
-            person_list = self.plshell.GetPersons(self.plauth, person_ids)
+            person_list = self.GetPersons(person_ids)
             persons = list_to_dict(person_list, 'person_id')
             for person in persons:
                 key_ids.extend(persons[person]['key_ids'])
@@ -156,7 +127,7 @@ class PlcSfaApi(SfaApi):
                       'slice': slices, 'user': persons}
 
         if key_ids:
-            key_list = self.plshell.GetKeys(self.plauth, key_ids)
+            key_list = self.GetKeys(key_ids)
             keys = list_to_dict(key_list, 'key_id')
 
         # fill record info
@@ -207,16 +178,16 @@ class PlcSfaApi(SfaApi):
         # get pl records
         slices, persons, sites, nodes = {}, {}, {}, {}
         if site_ids:
-            site_list = self.plshell.GetSites(self.plauth, site_ids, ['site_id', 'login_base'])
+            site_list = self.GetSites(site_ids, ['site_id', 'login_base'])
             sites = list_to_dict(site_list, 'site_id')
         if person_ids:
-            person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'email'])
+            person_list = self.GetPersons(person_ids, ['person_id', 'email'])
             persons = list_to_dict(person_list, 'person_id')
         if slice_ids:
-            slice_list = self.plshell.GetSlices(self.plauth, slice_ids, ['slice_id', 'name'])
+            slice_list = self.GetSlices(slice_ids, ['slice_id', 'name'])
             slices = list_to_dict(slice_list, 'slice_id')       
         if node_ids:
-            node_list = self.plshell.GetNodes(self.plauth, node_ids, ['node_id', 'hostname'])
+            node_list = self.GetNodes(node_ids, ['node_id', 'hostname'])
             nodes = list_to_dict(node_list, 'node_id')
        
         # convert ids to hrns
@@ -257,7 +228,8 @@ class PlcSfaApi(SfaApi):
             
         return records   
 
-    def fill_record_sfa_info(self, records):
+    # aggregates is basically api.aggregates
+    def fill_record_sfa_info(self, records, aggregates):
 
         def startswith(prefix, values):
             return [value for value in values if value.startswith(prefix)]
@@ -276,7 +248,7 @@ class PlcSfaApi(SfaApi):
         site_pis = {}
         if site_ids:
             pi_filter = {'|roles': ['pi'], '|site_ids': site_ids} 
-            pi_list = self.plshell.GetPersons(self.plauth, pi_filter, ['person_id', 'site_ids'])
+            pi_list = self.GetPersons(pi_filter, ['person_id', 'site_ids'])
             for pi in pi_list:
                 # we will need the pi's hrns also
                 person_ids.append(pi['person_id'])
@@ -306,7 +278,7 @@ class PlcSfaApi(SfaApi):
 
         # get the pl records
         pl_person_list, pl_persons = [], {}
-        pl_person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'roles'])
+        pl_person_list = self.GetPersons(person_ids, ['person_id', 'roles'])
         pl_persons = list_to_dict(pl_person_list, 'person_id')
 
         # fill sfa info
@@ -336,9 +308,9 @@ class PlcSfaApi(SfaApi):
                 
             elif (type.startswith("authority")):
                 record['url'] = None
-                if record['hrn'] in self.aggregates:
+                if record['hrn'] in aggregates:
                     
-                    record['url'] = self.aggregates[record['hrn']].get_url()
+                    record['url'] = aggregates[record['hrn']].get_url()
 
                 if record['pointer'] != -1:
                     record['PI'] = []
@@ -368,7 +340,7 @@ class PlcSfaApi(SfaApi):
                 # xxx TODO: PostalAddress, Phone
             record.update(sfa_info)
 
-    def fill_record_info(self, records):
+    def fill_record_info(self, records, aggregates):
         """
         Given a SFA record, fill in the PLC specific and SFA specific
         fields in the record. 
@@ -377,7 +349,7 @@ class PlcSfaApi(SfaApi):
             records = [records]
 
         self.fill_record_pl_info(records)
-        self.fill_record_sfa_info(records)
+        self.fill_record_sfa_info(records, aggregates)
 
     def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
         # get a list of the HRNs that are members of the old and new records
@@ -412,18 +384,18 @@ class PlcSfaApi(SfaApi):
     # add people who are in the new list, but not the oldList
         for personId in newIdList:
             if not (personId in oldIdList):
-                addFunc(self.plauth, personId, containerId)
+                addFunc(personId, containerId)
 
         # remove people who are in the old list, but not the new list
         for personId in oldIdList:
             if not (personId in newIdList):
-                delFunc(self.plauth, personId, containerId)
+                delFunc(personId, containerId)
 
     def update_membership(self, oldRecord, record):
         if record.type == "slice":
             self.update_membership_list(oldRecord, record, 'researcher',
-                                        self.plshell.AddPersonToSlice,
-                                        self.plshell.DeletePersonFromSlice)
+                                        self.AddPersonToSlice,
+                                        self.DeletePersonFromSlice)
         elif record.type == "authority":
             # xxx TODO
             pass