renamed sfa/plc into sfa/planetlab
[sfa.git] / sfa / importer / plimporter.py
index f8cb4a7..d205a87 100644 (file)
@@ -28,7 +28,7 @@ from sfa.trust.certificate import convert_public_key, Keypair
 from sfa.storage.alchemy import dbsession
 from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey
 
-from sfa.plc.plshell import PlShell    
+from sfa.planetlab.plshell import PlShell    
 
 def _get_site_hrn(interface_hrn, site):
     # Hardcode 'internet2' into the hrn for sites hosting
@@ -147,7 +147,7 @@ class PlImporter:
 #        sites_by_login_base = dict ( [ ( site['login_base'], site ) for site in sites ] )
         # Get all plc users
         persons = shell.GetPersons({'peer_id': None, 'enabled': True}, 
-                                   ['person_id', 'email', 'key_ids', 'site_ids'])
+                                   ['person_id', 'email', 'key_ids', 'site_ids', 'role_ids'])
         # create a hash of persons by person_id
         persons_by_id = dict ( [ ( person['person_id'], person) for person in persons ] )
         # Get all plc public keys
@@ -155,7 +155,8 @@ class PlImporter:
         key_ids = []
         for person in persons:
             key_ids.extend(person['key_ids'])
-        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids} )
+        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids,
+                               'key_type': 'ssh'} )
         # create a hash of keys by key_id
         keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) 
         # create a dict person_id -> [ (plc)keys ]
@@ -163,7 +164,9 @@ class PlImporter:
         for person in persons:
             pubkeys = []
             for key_id in person['key_ids']:
-                pubkeys.append(keys_by_id[key_id])
+                key = keys_by_id[key_id]
+                if key['key_type'] == 'ssh': 
+                    pubkeys.append(key)
             keys_by_person_id[person['person_id']] = pubkeys
         # Get all plc nodes  
         nodes = shell.GetNodes( {'peer_id': None}, ['node_id', 'hostname', 'site_id'])
@@ -240,6 +243,7 @@ class PlImporter:
                     pass
                 node_record.stale=False
 
+            site_pis=[]
             # import persons
             for person_id in site['person_ids']:
                 try:
@@ -251,7 +255,7 @@ class PlImporter:
                 if len(person_hrn) > 64: person_hrn = person_hrn[:64]
                 person_urn = hrn_to_urn(person_hrn, 'user')
 
-                user_person = self.locate ( 'user', person_hrn, person['person_id'])
+                user_record = self.locate ( 'user', person_hrn, person['person_id'])
 
                 # return a tuple pubkey (a plc key object) and pkey (a Keypair object)
                 def init_person_key (person, plc_keys):
@@ -273,26 +277,28 @@ class PlImporter:
                 # new person
                 try:
                     plc_keys = keys_by_person_id.get(person['person_id'],[])
-                    if not user_person:
+                    if not user_record:
                         (pubkey,pkey) = init_person_key (person, plc_keys )
                         person_gid = self.auth_hierarchy.create_gid(person_urn, create_uuid(), pkey)
-                        user_person = RegUser (hrn=person_hrn, gid=person_gid, 
+                        person_gid.set_email(person['email'])
+                        user_record = RegUser (hrn=person_hrn, gid=person_gid, 
                                                  pointer=person['person_id'], 
                                                  authority=get_authority(person_hrn),
                                                  email=person['email'])
                         if pubkey: 
-                            user_person.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            user_record.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])]
                         else:
-                            self.logger.warning("No key found for user %s"%user_person)
-                        dbsession.add (user_person)
+                            self.logger.warning("No key found for user %s"%user_record)
+                        user_record.just_created()
+                        dbsession.add (user_record)
                         dbsession.commit()
-                        self.logger.info("PlImporter: imported person: %s" % user_person)
-                        self.remember_record ( user_person )
+                        self.logger.info("PlImporter: imported person: %s" % user_record)
+                        self.remember_record ( user_record )
                     else:
                         # update the record ?
                         # if user's primary key has changed then we need to update the 
                         # users gid by forcing an update here
-                        sfa_keys = user_person.reg_keys
+                        sfa_keys = user_record.reg_keys
                         def key_in_list (key,sfa_keys):
                             for reg_key in sfa_keys:
                                 if reg_key.key==key['key']: return True
@@ -306,16 +312,24 @@ class PlImporter:
                             (pubkey,pkey) = init_person_key (person, plc_keys)
                             person_gid = self.auth_hierarchy.create_gid(person_urn, create_uuid(), pkey)
                             if not pubkey:
-                                user_person.reg_keys=[]
+                                user_record.reg_keys=[]
                             else:
-                                user_person.reg_keys=[ RegKey (pubkey['key'], pubkey['key_id'])]
-                            self.logger.info("PlImporter: updated person: %s" % user_person)
-                    user_person.email = person['email']
+                                user_record.reg_keys=[ RegKey (pubkey['key'], pubkey['key_id'])]
+                            self.logger.info("PlImporter: updated person: %s" % user_record)
+                    user_record.email = person['email']
                     dbsession.commit()
-                    user_person.stale=False
+                    user_record.stale=False
+                    # accumulate PIs - PLCAPI has a limitation that when someone has PI role
+                    # this is valid for all sites she is in..
+                    # PI is coded with role_id==20
+                    if 20 in person['role_ids']:
+                        site_pis.append (user_record)
                 except:
                     self.logger.log_exc("PlImporter: failed to import person %d %s"%(person['person_id'],person['email']))
     
+            # maintain the list of PIs for a given site
+            site_record.reg_pis = site_pis
+
             # import slices
             for slice_id in site['slice_ids']:
                 try: