Merge branch 'master' of ssh://git.onelab.eu/git/sfa
[sfa.git] / sfa / importer / plimporter.py
index 153104b..f6ecc45 100644 (file)
@@ -20,7 +20,6 @@ import os
 
 from sfa.util.config import Config
 from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
-from sfa.util.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename
 
 from sfa.trust.gid import create_uuid    
 from sfa.trust.certificate import convert_public_key, Keypair
@@ -28,7 +27,8 @@ from sfa.trust.certificate import convert_public_key, Keypair
 from sfa.storage.alchemy import dbsession
 from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey
 
-from sfa.plc.plshell import PlShell    
+from sfa.planetlab.plshell import PlShell    
+from sfa.planetlab.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename
 
 def _get_site_hrn(interface_hrn, site):
     # Hardcode 'internet2' into the hrn for sites hosting
@@ -147,7 +147,7 @@ class PlImporter:
 #        sites_by_login_base = dict ( [ ( site['login_base'], site ) for site in sites ] )
         # Get all plc users
         persons = shell.GetPersons({'peer_id': None, 'enabled': True}, 
-                                   ['person_id', 'email', 'key_ids', 'site_ids'])
+                                   ['person_id', 'email', 'key_ids', 'site_ids', 'role_ids'])
         # create a hash of persons by person_id
         persons_by_id = dict ( [ ( person['person_id'], person) for person in persons ] )
         # Get all plc public keys
@@ -155,7 +155,8 @@ class PlImporter:
         key_ids = []
         for person in persons:
             key_ids.extend(person['key_ids'])
-        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids} )
+        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids,
+                               'key_type': 'ssh'} )
         # create a hash of keys by key_id
         keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) 
         # create a dict person_id -> [ (plc)keys ]
@@ -163,7 +164,9 @@ class PlImporter:
         for person in persons:
             pubkeys = []
             for key_id in person['key_ids']:
-                pubkeys.append(keys_by_id[key_id])
+                key = keys_by_id[key_id]
+                if key['key_type'] == 'ssh': 
+                    pubkeys.append(key)
             keys_by_person_id[person['person_id']] = pubkeys
         # Get all plc nodes  
         nodes = shell.GetNodes( {'peer_id': None}, ['node_id', 'hostname', 'site_id'])
@@ -177,6 +180,11 @@ class PlImporter:
         # isolate special vini case in separate method
         self.create_special_vini_record (interface_hrn)
 
+        def check_hrn (record, hrn):
+            if record.hrn != hrn:
+                record.hrn=hrn
+                dbsession.commit()
+
         # start importing 
         for site in sites:
             site_hrn = _get_site_hrn(interface_hrn, site)
@@ -203,6 +211,8 @@ class PlImporter:
                     self.logger.log_exc("PlImporter: failed to import site. Skipping child records") 
                     continue 
             else:
+                # we might have renamed it - since we first use pointer to locate it
+                check_hrn(site_record, site_hrn)
                 # xxx update the record ...
                 pass
             site_record.stale=False
@@ -216,18 +226,18 @@ class PlImporter:
                     continue 
                 site_auth = get_authority(site_hrn)
                 site_name = site['login_base']
-                hrn =  hostname_to_hrn(site_auth, site_name, node['hostname'])
+                node_hrn =  hostname_to_hrn(site_auth, site_name, node['hostname'])
                 # xxx this sounds suspicious
-                if len(hrn) > 64: hrn = hrn[:64]
-                node_record = self.locate ( 'node', hrn , node['node_id'] )
+                if len(node_hrn) > 64: node_hrn = node_hrn[:64]
+                node_record = self.locate ( 'node', node_hrn , node['node_id'] )
                 if not node_record:
                     try:
                         pkey = Keypair(create=True)
-                        urn = hrn_to_urn(hrn, 'node')
+                        urn = hrn_to_urn(node_hrn, 'node')
                         node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
-                        node_record = RegNode (hrn=hrn, gid=node_gid, 
+                        node_record = RegNode (hrn=node_hrn, gid=node_gid, 
                                                pointer =node['node_id'],
-                                               authority=get_authority(hrn))
+                                               authority=get_authority(node_hrn))
                         node_record.just_created()
                         dbsession.add(node_record)
                         dbsession.commit()
@@ -237,9 +247,10 @@ class PlImporter:
                         self.logger.log_exc("PlImporter: failed to import node") 
                 else:
                     # xxx update the record ...
-                    pass
+                    check_hrn(node_record, node_hrn)
                 node_record.stale=False
 
+            site_pis=[]
             # import persons
             for person_id in site['person_ids']:
                 try:
@@ -294,6 +305,7 @@ class PlImporter:
                         # update the record ?
                         # if user's primary key has changed then we need to update the 
                         # users gid by forcing an update here
+                        check_hrn (user_record, person_hrn)
                         sfa_keys = user_record.reg_keys
                         def key_in_list (key,sfa_keys):
                             for reg_key in sfa_keys:
@@ -315,9 +327,17 @@ class PlImporter:
                     user_record.email = person['email']
                     dbsession.commit()
                     user_record.stale=False
+                    # accumulate PIs - PLCAPI has a limitation that when someone has PI role
+                    # this is valid for all sites she is in..
+                    # PI is coded with role_id==20
+                    if 20 in person['role_ids']:
+                        site_pis.append (user_record)
                 except:
                     self.logger.log_exc("PlImporter: failed to import person %d %s"%(person['person_id'],person['email']))
     
+            # maintain the list of PIs for a given site
+            site_record.reg_pis = site_pis
+
             # import slices
             for slice_id in site['slice_ids']:
                 try:
@@ -343,6 +363,7 @@ class PlImporter:
                         self.logger.log_exc("PlImporter: failed to import slice")
                 else:
                     # xxx update the record ...
+                    check_hrn (slice_record, slice_hrn)
                     self.logger.warning ("Slice update not yet implemented")
                     pass
                 # record current users affiliated with the slice