the model can store keys and email - import keys and email from plcapi
authorThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Thu, 2 Feb 2012 23:50:27 +0000 (00:50 +0100)
committerThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Thu, 2 Feb 2012 23:50:27 +0000 (00:50 +0100)
sfa/importer/plimporter.py
sfa/storage/model.py

index a902d93..802ab97 100644 (file)
@@ -8,26 +8,10 @@ from sfa.trust.gid import create_uuid
 from sfa.trust.certificate import convert_public_key, Keypair
 
 from sfa.storage.alchemy import dbsession
-from sfa.storage.model import RegRecord, RegAuthority, RegUser, RegSlice, RegNode
+from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey
 
 from sfa.plc.plshell import PlShell    
 
-def load_keys(filename):
-    keys = {}
-    tmp_dict = {}
-    try:
-        execfile(filename, tmp_dict)
-        if 'keys' in tmp_dict:
-            keys = tmp_dict['keys']
-        return keys
-    except:
-        return keys
-
-def save_keys(filename, keys):
-    f = open(filename, 'w')
-    f.write("keys = %s" % str(keys))
-    f.close()
-
 def _get_site_hrn(interface_hrn, site):
     # Hardcode 'internet2' into the hrn for sites hosting
     # internet2 nodes. This is a special operation for some vini
@@ -57,15 +41,19 @@ class PlImporter:
         shell = PlShell (config)
 
         # create dict of all existing sfa records
-        existing_records = {}
-        existing_hrns = []
+        records_by_hrn_type = {}
+        records_by_type_pointer = {}
         key_ids = []
-        for record in dbsession.query(RegRecord):
-            existing_records[ (record.hrn, record.type,) ] = record
-            existing_hrns.append(record.hrn) 
-            
+        records = dbsession.query(RegRecord)
+        for record in records:
+            records_by_hrn_type[ (record.hrn, record.type,) ] = record
+            if record.pointer != -1:
+                records_by_type_pointer [ (record.type, record.pointer,) ] = record
+
         # Get all plc sites
-        sites = shell.GetSites({'peer_id': None})
+        # retrieve only required stuf
+        sites = shell.GetSites({'peer_id': None, 'enabled' : True},
+                               ['site_id','login_base','node_ids','slice_ids','person_ids',])
         sites_dict = {}
         for site in sites:
             sites_dict[site['login_base']] = site 
@@ -78,21 +66,18 @@ class PlImporter:
             persons_dict[person['person_id']] = person
             key_ids.extend(person['key_ids'])
 
-        # Get all public keys
+        # Get all plc public keys
         keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids})
-        keys_dict = {}
-        for key in keys:
-            keys_dict[key['key_id']] = key['key']
+        keys_by_id = {}
+        for key in keys: keys_by_id[key['key_id']] = key
 
-        # create a dict of person keys keyed on key_id 
-        keys_filename = config.config_path + os.sep + 'person_keys.py' 
-        old_person_keys = load_keys(keys_filename)
-        person_keys = {} 
+        # create a dict person_id -> [ (plc)keys ]
+        keys_by_person_id = {} 
         for person in persons:
             pubkeys = []
             for key_id in person['key_ids']:
-                pubkeys.append(keys_dict[key_id])
-            person_keys[person['person_id']] = pubkeys
+                pubkeys.append(keys_by_id[key_id])
+            keys_by_person_id[person['person_id']] = pubkeys
 
         # Get all plc nodes  
         nodes = shell.GetNodes( {'peer_id': None}, ['node_id', 'hostname', 'site_id'])
@@ -109,13 +94,11 @@ class PlImporter:
         # special case for vini
         if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
             # create a fake internet2 site first
-            i2site = {'name': 'Internet2', 'abbreviated_name': 'I2',
-                        'login_base': 'internet2', 'site_id': -1}
+            i2site = {'name': 'Internet2', 'login_base': 'internet2', 'site_id': -1}
             site_hrn = _get_site_hrn(interface_hrn, i2site)
             # import if hrn is not in list of existing hrns or if the hrn exists
             # but its not a site record
-            if site_hrn not in existing_hrns or \
-               (site_hrn, 'authority') not in existing_records:
+            if (site_hrn, 'authority') not in records_by_hrn_type:
                 urn = hrn_to_urn(site_hrn, 'authority')
                 if not self.auth_hierarchy.auth_exists(urn):
                     self.auth_hierarchy.create_auth(urn)
@@ -136,8 +119,7 @@ class PlImporter:
     
             # import if hrn is not in list of existing hrns or if the hrn exists
             # but its not a site record
-            if site_hrn not in existing_hrns or \
-               (site_hrn, 'authority') not in existing_records:
+            if (site_hrn, 'authority') not in records_by_hrn_type:
                 try:
                     urn = hrn_to_urn(site_hrn, 'authority')
                     if not self.auth_hierarchy.auth_exists(urn):
@@ -168,8 +150,7 @@ class PlImporter:
                 hrn =  hostname_to_hrn(site_auth, site_name, node['hostname'])
                 if len(hrn) > 64:
                     hrn = hrn[:64]
-                if hrn not in existing_hrns or \
-                   (hrn, 'node') not in existing_records:
+                if (hrn, 'node') not in records_by_hrn_type:
                     try:
                         pkey = Keypair(create=True)
                         urn = hrn_to_urn(hrn, 'node')
@@ -193,8 +174,7 @@ class PlImporter:
                     continue 
                 slice = slices_dict[slice_id]
                 hrn = slicename_to_hrn(interface_hrn, slice['name'])
-                if hrn not in existing_hrns or \
-                   (hrn, 'slice') not in existing_records:
+                if (hrn, 'slice') not in records_by_hrn_type:
                     try:
                         pkey = Keypair(create=True)
                         urn = hrn_to_urn(hrn, 'slice')
@@ -214,63 +194,76 @@ class PlImporter:
             # import persons
             for person_id in site['person_ids']:
                 if person_id not in persons_dict:
+                    self.logger.warning ("PlImporter: skipping person %s"%person_id)
                     continue 
                 person = persons_dict[person_id]
                 hrn = email_to_hrn(site_hrn, person['email'])
                 if len(hrn) > 64:
                     hrn = hrn[:64]
     
+                previous_record = records_by_hrn_type.get( (hrn, 'user',) )
+                if not previous_record:
+                    previous_record = records_by_type_pointer.get ( ('user', person_id,) )
                 # if user's primary key has changed then we need to update the 
                 # users gid by forcing an update here
-                old_keys = []
-                new_keys = []
-                if person_id in old_person_keys:
-                    old_keys = old_person_keys[person_id]
-                if person_id in person_keys:
-                    new_keys = person_keys[person_id]
+                plc_keys = []
+                sfa_keys = []
+                if previous_record:
+                    sfa_keys = previous_record.keys
+                if person_id in keys_by_person_id:
+                    plc_keys = keys_by_person_id[person_id]
                 update_record = False
-                for key in new_keys:
-                    if key not in old_keys:
+                def key_in_list (key,sfa_keys):
+                    for reg_key in sfa_keys:
+                        if reg_key.key==key['key']: return True
+                    return False
+                for key in plc_keys:
+                    if not key_in_list (key,sfa_keys):
                         update_record = True 
     
-                if hrn not in existing_hrns or \
-                   (hrn, 'user') not in existing_records or update_record:
+                if not previous_record or update_record:
                     try:
+                        pubkey=None
                         if 'key_ids' in person and person['key_ids']:
-                            key = new_keys[0]
+                            # randomly pick first key in set
+                            pubkey = plc_keys[0]
                             try:
-                                pkey = convert_public_key(key)
+                                pkey = convert_public_key(pubkey['key'])
                             except:
-                                self.logger.warn('unable to convert public key for %s' % hrn)
+                                self.logger.warn('PlImporter: unable to convert public key for %s' % hrn)
                                 pkey = Keypair(create=True)
                         else:
                             # the user has no keys. Creating a random keypair for the user's gid
                             self.logger.warn("PlImporter: person %s does not have a PL public key"%hrn)
-                            pkey = Keypair(create=True) 
+                            pkey = Keypair(create=True)
                         urn = hrn_to_urn(hrn, 'user')
                         person_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
-                        person_record = RegUser ()
-                        person_record.type='user'
-                        person_record.hrn=hrn
-                        person_record.gid=person_gid
-                        person_record.pointer=person['person_id']
-                        person_record.authority=get_authority(hrn)
-                        dbsession.add (person_record)
-                        dbsession.commit()
-                        self.logger.info("PlImporter: imported person: %s" % person_record)
+                        if previous_record: 
+                            previous_record.gid=person_gid
+                            if pubkey: previous_record.keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            self.logger.info("PlImporter: updated person: %s" % previous_record)
+                        else:
+                            new_record = RegUser (hrn=hrn, gid=person_gid, 
+                                                  pointer=person['person_id'], 
+                                                  authority=get_authority(hrn),
+                                                  email=person['email'])
+                            if pubkey: new_record.keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            dbsession.add (new_record)
+                            dbsession.commit()
+                            self.logger.info("PlImporter: imported person: %s" % new_record)
                     except:
                         self.logger.log_exc("PlImporter: failed to import person.") 
     
         # remove stale records    
         system_records = [interface_hrn, root_auth, interface_hrn + '.slicemanager']
-        for (record_hrn, type) in existing_records.keys():
+        for record in records:
+            record_hrn=record.hrn
             if record_hrn in system_records:
                 continue
-            
-            record = existing_records[(record_hrn, type)]
             if record.peer_authority:
                 continue
-    
+            type=record.type
+            hrn=record.hrn
             # dont delete vini's internet2 placeholdder record
             # normally this would be deleted becuase it does not have a plc record 
             if ".vini" in interface_hrn and interface_hrn.endswith('vini') and \
@@ -279,14 +272,14 @@ class PlImporter:
     
             found = False
             
-            if type == 'authority':    
+            if isinstance (record, RegAuthority):
                 for site in sites:
                     site_hrn = interface_hrn + "." + site['login_base']
                     if site_hrn == record_hrn and site['site_id'] == record.pointer:
                         found = True
                         break
 
-            elif type == 'user':
+            elif isinstance (record, RegUser):
                 login_base = get_leaf(get_authority(record_hrn))
                 username = get_leaf(record_hrn)
                 if login_base in sites_dict:
@@ -300,7 +293,7 @@ class PlImporter:
                             found = True
                             break
         
-            elif type == 'slice':
+            elif isinstance (record, RegSlice):
                 slicename = hrn_to_pl_slicename(record_hrn)
                 for slice in slices:
                     if slicename == slice['name'] and \
@@ -308,7 +301,7 @@ class PlImporter:
                         found = True
                         break    
  
-            elif type == 'node':
+            elif isinstance (record, RegNode):
                 login_base = get_leaf(get_authority(record_hrn))
                 nodename = Xrn.unescape(get_leaf(record_hrn))
                 if login_base in sites_dict:
@@ -325,14 +318,14 @@ class PlImporter:
         
             if not found:
                 try:
-                    record_object = existing_records[(record_hrn, type)]
+                    record_object = records_by_hrn_type[(record_hrn, type)]
                     self.logger.info("PlImporter: deleting record: %s" % record)
                     dbsession.delete(record_object)
                     dbsession.commit()
                 except:
                     self.logger.log_exc("PlImporter: failded to delete record")                    
 
-        # save pub keys
-        self.logger.info('Import: saving current pub keys')
-        save_keys(keys_filename, person_keys)                
+#        # save pub keys
+#        self.logger.info('Import: saving current pub keys')
+#        save_keys(keys_filename, person_keys)                
         
index 810af6b..18029d1 100644 (file)
@@ -165,25 +165,6 @@ class RegRecord (Base,AlchemyObj):
         self.last_updated=now
 
 ##############################
-class RegUser (RegRecord):
-    __tablename__       = 'users'
-    # these objects will have type='user' in the records table
-    __mapper_args__     = { 'polymorphic_identity' : 'user' }
-    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
-    email               = Column ('email', String)
-    
-    # append stuff at the end of the record __repr__
-    def __repr__ (self): 
-        result = RegRecord.__repr__(self).replace("Record","User")
-        result.replace ("]"," email=%s"%self.email)
-        result += "]"
-        return result
-    
-    @validates('email') 
-    def validate_email(self, key, address):
-        assert '@' in address
-        return address
-
 class RegAuthority (RegRecord):
     __tablename__       = 'authorities'
     __mapper_args__     = { 'polymorphic_identity' : 'authority' }
@@ -193,6 +174,7 @@ class RegAuthority (RegRecord):
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Authority")
 
+##############################
 class RegSlice (RegRecord):
     __tablename__       = 'slices'
     __mapper_args__     = { 'polymorphic_identity' : 'slice' }
@@ -201,6 +183,7 @@ class RegSlice (RegRecord):
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Slice")
 
+##############################
 class RegNode (RegRecord):
     __tablename__       = 'nodes'
     __mapper_args__     = { 'polymorphic_identity' : 'node' }
@@ -209,6 +192,56 @@ class RegNode (RegRecord):
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Node")
 
+##############################
+class RegUser (RegRecord):
+    __tablename__       = 'users'
+    # these objects will have type='user' in the records table
+    __mapper_args__     = { 'polymorphic_identity' : 'user' }
+    record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
+    email               = Column ('email', String)
+    keys                = relationship ('RegKey', backref='user')
+    
+    def __init__ (self, **kwds):
+        # handle local settings
+        if 'email' in kwds: self.email=kwds.pop('email')
+        # fill in type if not previously set
+        if 'type' not in kwds: kwds['type']='user'
+        RegRecord.__init__(self, **kwds)
+
+    # append stuff at the end of the record __repr__
+    def __repr__ (self): 
+        result = RegRecord.__repr__(self).replace("Record","User")
+        result.replace ("]"," email=%s"%self.email)
+        result += "]"
+        return result
+    
+    @validates('email') 
+    def validate_email(self, key, address):
+        assert '@' in address
+        return address
+
+####################
+# xxx tocheck : not sure about eager loading of this one
+# meaning, when querying the whole records, we expect there should
+# be a single query to fetch all the keys 
+class RegKey (Base):
+    __tablename__       = 'keys'
+    key_id              = Column (Integer, primary_key=True)
+    record_id             = Column (Integer, ForeignKey ("records.record_id"))
+    key                 = Column (String)
+    pointer             = Column (Integer, default = -1)
+    
+    def __init__ (self, key, pointer=None):
+        self.key=key
+        if pointer: self.pointer=pointer
+
+    def __repr__ (self):
+        result="[key key=%s..."%self.key[8:16]
+        try:    result += " user=%s"%self.user.record_id
+        except: result += " <orphan>"
+        result += "]"
+        return result
+
 ##############################
 # although the db needs of course to be reachable,
 # the schema management functions are here and not in alchemy
@@ -245,6 +278,7 @@ def make_record_dict (record_dict):
     elif type=='node':
         result=RegNode (dict=record_dict)
     else:
+        logger.debug("Untyped RegRecord instance")
         result=RegRecord (dict=record_dict)
     logger.info ("converting dict into Reg* with type=%s"%type)
     logger.info ("returning=%s"%result)