Fixes for Eucalyptus aggregate manager
[sfa.git] / sfa / plc / api.py
index b4d295b..c0c0d13 100644 (file)
@@ -21,6 +21,47 @@ from sfa.util.namespace import *
 from sfa.util.api import *
 from sfa.util.nodemanager import NodeManager
 from sfa.util.sfalogging import *
+try:
+    from collections import defaultdict
+except:
+    class defaultdict(dict):
+        def __init__(self, default_factory=None, *a, **kw):
+            if (default_factory is not None and
+                not hasattr(default_factory, '__call__')):
+                raise TypeError('first argument must be callable')
+            dict.__init__(self, *a, **kw)
+            self.default_factory = default_factory
+        def __getitem__(self, key):
+            try:
+                return dict.__getitem__(self, key)
+            except KeyError:
+                return self.__missing__(key)
+        def __missing__(self, key):
+            if self.default_factory is None:
+                raise KeyError(key)
+            self[key] = value = self.default_factory()
+            return value
+        def __reduce__(self):
+            if self.default_factory is None:
+                args = tuple()
+            else:
+                args = self.default_factory,
+            return type(self), args, None, None, self.items()
+        def copy(self):
+            return self.__copy__()
+        def __copy__(self):
+            return type(self)(self.default_factory, self)
+        def __deepcopy__(self, memo):
+            import copy
+            return type(self)(self.default_factory,
+                              copy.deepcopy(self.items()))
+        def __repr__(self):
+            return 'defaultdict(%s, %s)' % (self.default_factory,
+                                            dict.__repr__(self))
+
+
+## end of http://code.activestate.com/recipes/523034/ }}}
+
 
 def list_to_dict(recs, key):
     """
@@ -36,11 +77,12 @@ class SfaAPI(BaseAPI):
     import sfa.methods
     methods = sfa.methods.all
     
-    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", methods='sfa.methods', \
-                 peer_cert = None, interface = None, key_file = None, cert_file = None):
+    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", 
+                 methods='sfa.methods', peer_cert = None, interface = None, 
+                key_file = None, cert_file = None, cache = None):
         BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, \
                          peer_cert=peer_cert, interface=interface, key_file=key_file, \
-                         cert_file=cert_file)
+                         cert_file=cert_file, cache=cache)
  
         self.encoding = encoding
 
@@ -61,7 +103,7 @@ class SfaAPI(BaseAPI):
         self.credential = None
         # Initialize the PLC shell only if SFA wraps a myPLC
         rspec_type = self.config.get_aggregate_type()
-        if (rspec_type == 'pl' or rspec_type == 'vini'):
+        if (rspec_type == 'pl' or rspec_type == 'vini' or rspec_type == 'eucalyptus'):
             self.plshell = self.getPLCShell()
             self.plshell_version = "4.3"
 
@@ -74,6 +116,7 @@ class SfaAPI(BaseAPI):
                        'AuthMethod': 'password',
                        'AuthString': self.config.SFA_PLC_PASSWORD}
 
+
         self.plshell_type = 'xmlrpc' 
         # connect via xmlrpc
         url = self.config.SFA_PLC_URL
@@ -133,8 +176,8 @@ class SfaAPI(BaseAPI):
         new_cred = Credential(subject = object_gid.get_subject())
         new_cred.set_gid_caller(object_gid)
         new_cred.set_gid_object(object_gid)
-        new_cred.set_issuer(key=auth_info.get_pkey_object(), subject=auth_hrn)
-        new_cred.set_pubkey(object_gid.get_pubkey())
+        new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename())
+        
         r1 = determine_rights(type, hrn)
         new_cred.set_privileges(r1)
 
@@ -361,6 +404,10 @@ class SfaAPI(BaseAPI):
         return records   
 
     def fill_record_sfa_info(self, records):
+
+        def startswith(prefix, values):
+            return [value for value in values if value.startswith(prefix)]
+
         # get person ids
         person_ids = []
         site_ids = []
@@ -396,7 +443,12 @@ class SfaAPI(BaseAPI):
         table = self.SfaTable()
         person_list, persons = [], {}
         person_list = table.find({'type': 'user', 'pointer': person_ids})
-        persons = list_to_dict(person_list, 'pointer')
+        # create a hrns keyed on the sfa record's pointer.
+        # Its possible for  multiple records to have the same pointer so
+        # the dict's value will be a list of hrns.
+        persons = defaultdict(list)
+        for person in person_list:
+            persons[person['pointer']].append(person)
 
         # get the pl records
         pl_person_list, pl_persons = [], {}
@@ -411,32 +463,36 @@ class SfaAPI(BaseAPI):
             sfa_info = {}
             type = record['type']
             if (type == "slice"):
-                # slice users
-                researchers = [persons[person_id]['hrn'] for person_id in record['person_ids'] \
-                               if person_id in persons] 
-                sfa_info['researcher'] = researchers
+                # all slice users are researchers
+                record['PI'] = []
+                record['researcher'] = []
+                for person_id in record['person_ids']:
+                    hrns = [person['hrn'] for person in persons[person_id]]
+                    record['researcher'].extend(hrns)                
+
                 # pis at the slice's site
                 pl_pis = site_pis[record['site_id']]
-                pi_ids = [pi['person_id'] for pi in pl_pis] 
-                sfa_info['PI'] = [persons[person_id]['hrn'] for person_id in pi_ids]
+                pi_ids = [pi['person_id'] for pi in pl_pis]
+                for person_id in pi_ids:
+                    hrns = [person['hrn'] for person in persons[person_id]]
+                    record['PI'].extend(hrns)
                 
             elif (type == "authority"):
-                pis, techs, admins = [], [], []
+                record['PI'] = []
+                record['operator'] = []
+                record['owner'] = []
                 for pointer in record['person_ids']:
                     if pointer not in persons or pointer not in pl_persons:
                         # this means there is not sfa or pl record for this user
                         continue   
-                    hrn = persons[pointer]['hrn'
+                    hrns = [person['hrn'] for person in persons[pointer]
                     roles = pl_persons[pointer]['roles']   
                     if 'pi' in roles:
-                        pis.append(hrn)
+                        record['PI'].extend(hrns)
                     if 'tech' in roles:
-                        techs.append(hrn)
+                        record['operator'].extend(hrns)
                     if 'admin' in roles:
-                        admins.append(hrn)
-                    sfa_info['PI'] = pis
-                    sfa_info['operator'] = techs
-                    sfa_info['owner'] = admins
+                        record['owner'].extend(hrns)
                     # xxx TODO: OrganizationName
             elif (type == "node"):
                 sfa_info['dns'] = record.get("hostname", "")