create registry objects with kdws-style again, like RegAuthority(hrn=hrn)
authorThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Mon, 6 Feb 2012 10:48:31 +0000 (11:48 +0100)
committerThierry Parmentelat <thierry.parmentelat@sophia.inria.fr>
Mon, 6 Feb 2012 10:48:31 +0000 (11:48 +0100)
record_options->add_options

sfa/importer/openstackimporter.py
sfa/importer/plimporter.py
sfa/importer/sfa-import.py
sfa/importer/sfaimporter.py
sfa/storage/model.py

index 27d0344..844980d 100644 (file)
@@ -34,8 +34,8 @@ class OpenstackImporter:
         self.auth_hierarchy = auth_hierarchy
         self.logger=logger
 
-    def record_options (self, parser):
-        self.logger.debug ("PlImporter no options yet")
+    def add_options (self, parser):
+        self.logger.debug ("OpenstackImporter: no options yet")
         pass
 
     def run (self, options):
index 9a6a8f5..82fa5f1 100644 (file)
@@ -29,47 +29,69 @@ class PlImporter:
         self.auth_hierarchy = auth_hierarchy
         self.logger=logger
 
-    def record_options (self, parser):
-        self.logger.debug ("PlImporter no options yet")
+    def add_options (self, parser):
+        # we don't have any options for now
         pass
 
+    # this makes the run method a bit abtruse - out of the way
+    def create_special_vini_record (self, interface_hrn):
+        # special case for vini
+        if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
+            # create a fake internet2 site first
+            i2site = {'name': 'Internet2', 'login_base': 'internet2', 'site_id': -1}
+            site_hrn = _get_site_hrn(interface_hrn, i2site)
+            # import if hrn is not in list of existing hrns or if the hrn exists
+            # but its not a site record
+            if ( 'authority', site_hrn, ) not in self.records_by_type_hrn:
+                urn = hrn_to_urn(site_hrn, 'authority')
+                if not self.auth_hierarchy.auth_exists(urn):
+                    self.auth_hierarchy.create_auth(urn)
+                auth_info = self.auth_hierarchy.get_auth_info(urn)
+                auth_record = RegAuthority(hrn=site_hrn, gid=auth_info.get_gid_object(),
+                                           pointer=site['site_id'],
+                                           authority=get_authority(site_hrn))
+                auth_record.just_created()
+                dbsession.add(auth_record)
+                dbsession.commit()
+                self.logger.info("PlImporter: Imported authority (vini site) %s"%auth_record)
+
     def run (self, options):
-        # we don't have any options for now
         config = Config ()
         interface_hrn = config.SFA_INTERFACE_HRN
         root_auth = config.SFA_REGISTRY_ROOT_AUTH
         shell = PlShell (config)
 
-        # create dict of all existing sfa records
-        records_by_hrn_type = {}
-        records_by_type_pointer = {}
-        key_ids = []
+        ######## retrieve all existing SFA objects
         records = dbsession.query(RegRecord)
-        for record in records:
-            records_by_hrn_type[ (record.hrn, record.type,) ] = record
-            if record.pointer != -1:
-                records_by_type_pointer [ (record.type, record.pointer,) ] = record
+        # create indexes / hashes by (type,hrn) 
+        self.records_by_type_hrn = dict ( [ ( (record.type, record.hrn) , record ) for record in records ] )
+        # and by (type,pointer)
+        self.records_by_type_pointer = \
+            dict ( [ ( (record.type, record.pointer) , record ) for record in records if record.pointer != -1 ] )
 
+        ######## retrieve PLC records
         # Get all plc sites
         # retrieve only required stuf
         sites = shell.GetSites({'peer_id': None, 'enabled' : True},
                                ['site_id','login_base','node_ids','slice_ids','person_ids',])
-        sites_dict = {}
-        for site in sites:
-            sites_dict[site['login_base']] = site 
+        # create a hash of sites by login_base
+        sites_by_login_base = dict ( [ ( site['login_base'], site ) for site in sites ] )
     
         # Get all plc users
         persons = shell.GetPersons({'peer_id': None, 'enabled': True}, 
                                    ['person_id', 'email', 'key_ids', 'site_ids'])
-        persons_dict = {}
+        # create a hash of persons by person_id
+        persons_by_id = dict ( [ ( person['person_id'], person) for person in persons ] )
+        
+        # Get all plc public keys
+        # accumulate key ids for keys retrieval
+        key_ids = []
         for person in persons:
-            persons_dict[person['person_id']] = person
             key_ids.extend(person['key_ids'])
+        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids} )
 
-        # Get all plc public keys
-        keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids})
-        keys_by_id = {}
-        for key in keys: keys_by_id[key['key_id']] = key
+        # create a hash of keys by key_id
+        keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) 
 
         # create a dict person_id -> [ (plc)keys ]
         keys_by_person_id = {} 
@@ -81,37 +103,16 @@ class PlImporter:
 
         # Get all plc nodes  
         nodes = shell.GetNodes( {'peer_id': None}, ['node_id', 'hostname', 'site_id'])
-        nodes_dict = {}
-        for node in nodes:
-            nodes_dict[node['node_id']] = node
+        # create hash by node_id
+        nodes_by_id = dict ( [ ( node['node_id'], node, ) for node in nodes ] )
 
         # Get all plc slices
         slices = shell.GetSlices( {'peer_id': None}, ['slice_id', 'name'])
-        slices_dict = {}
-        for slice in slices:
-            slices_dict[slice['slice_id']] = slice
+        # create hash by slice_id
+        slices_by_id = dict ( [ (slice['slice_id'], slice ) for slice in slices ] )
 
-        # special case for vini
-        if ".vini" in interface_hrn and interface_hrn.endswith('vini'):
-            # create a fake internet2 site first
-            i2site = {'name': 'Internet2', 'login_base': 'internet2', 'site_id': -1}
-            site_hrn = _get_site_hrn(interface_hrn, i2site)
-            # import if hrn is not in list of existing hrns or if the hrn exists
-            # but its not a site record
-            if (site_hrn, 'authority') not in records_by_hrn_type:
-                urn = hrn_to_urn(site_hrn, 'authority')
-                if not self.auth_hierarchy.auth_exists(urn):
-                    self.auth_hierarchy.create_auth(urn)
-                auth_info = self.auth_hierarchy.get_auth_info(urn)
-                auth_record = RegAuthority()
-                auth_record.type='authority'
-                auth_record.hrn=site_hrn
-                auth_record.gid=auth_info.get_gid_object()
-                auth_record.pointer=site['site_id']
-                auth_record.authority=get_authority(site_hrn)
-                dbsession.add(auth_record)
-                dbsession.commit()
-                self.logger.info("PlImporter: Imported authority (vini site) %s"%auth_record)
+        # isolate special vini case in separate method
+        self.create_special_vini_record (interface_hrn)
 
         # start importing 
         for site in sites:
@@ -119,18 +120,16 @@ class PlImporter:
     
             # import if hrn is not in list of existing hrns or if the hrn exists
             # but its not a site record
-            if (site_hrn, 'authority') not in records_by_hrn_type:
+            if ( 'authority', site_hrn, ) not in self.records_by_type_hrn:
                 try:
                     urn = hrn_to_urn(site_hrn, 'authority')
                     if not self.auth_hierarchy.auth_exists(urn):
                         self.auth_hierarchy.create_auth(urn)
                     auth_info = self.auth_hierarchy.get_auth_info(urn)
-                    auth_record = RegAuthority()
-                    auth_record.type='authority'
-                    auth_record.hrn=site_hrn
-                    auth_record.gid=auth_info.get_gid_object()
-                    auth_record.pointer=site['site_id']
-                    auth_record.authority=get_authority(site_hrn)
+                    auth_record = RegAuthority(hrn=site_hrn, gid=auth_info.get_gid_object(),
+                                               pointer=site['site_id'],
+                                               authority=get_authority(site_hrn))
+                    auth_record.just_created()
                     dbsession.add(auth_record)
                     dbsession.commit()
                     self.logger.info("PlImporter: imported authority (site) : %s" % auth_record)  
@@ -142,25 +141,23 @@ class PlImporter:
              
             # import node records
             for node_id in site['node_ids']:
-                if node_id not in nodes_dict:
+                if node_id not in nodes_by_id:
                     continue 
-                node = nodes_dict[node_id]
+                node = nodes_by_id[node_id]
                 site_auth = get_authority(site_hrn)
                 site_name = get_leaf(site_hrn)
                 hrn =  hostname_to_hrn(site_auth, site_name, node['hostname'])
                 if len(hrn) > 64:
                     hrn = hrn[:64]
-                if (hrn, 'node') not in records_by_hrn_type:
+                if ( 'node', hrn, ) not in self.records_by_type_hrn:
                     try:
                         pkey = Keypair(create=True)
                         urn = hrn_to_urn(hrn, 'node')
                         node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
-                        node_record = RegNode ()
-                        node_record.type='node'
-                        node_record.hrn=hrn
-                        node_record.gid=node_gid
-                        node_record.pointer =node['node_id']
-                        node_record.authority=get_authority(hrn)
+                        node_record = RegNode (hrn=hrn, gid=node_gid, 
+                                               pointer =node['node_id'],
+                                               authority=get_authority(hrn))
+                        node_record.just_created()
                         dbsession.add(node_record)
                         dbsession.commit()
                         self.logger.info("PlImporter: imported node: %s" % node_record)  
@@ -170,21 +167,19 @@ class PlImporter:
 
             # import slices
             for slice_id in site['slice_ids']:
-                if slice_id not in slices_dict:
+                if slice_id not in slices_by_id:
                     continue 
-                slice = slices_dict[slice_id]
+                slice = slices_by_id[slice_id]
                 hrn = slicename_to_hrn(interface_hrn, slice['name'])
-                if (hrn, 'slice') not in records_by_hrn_type:
+                if ( 'slice', hrn, ) not in self.records_by_type_hrn:
                     try:
                         pkey = Keypair(create=True)
                         urn = hrn_to_urn(hrn, 'slice')
                         slice_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
-                        slice_record = RegSlice ()
-                        slice_record.type='slice'
-                        slice_record.hrn=hrn
-                        slice_record.gid=slice_gid
-                        slice_record.pointer=slice['slice_id']
-                        slice_record.authority=get_authority(hrn)
+                        slice_record = RegSlice (hrn=hrn, gid=slice_gid, 
+                                                 pointer=slice['slice_id'],
+                                                 authority=get_authority(hrn))
+                        slice_record.just_created()
                         dbsession.add(slice_record)
                         dbsession.commit()
                         self.logger.info("PlImporter: imported slice: %s" % slice_record)  
@@ -193,17 +188,17 @@ class PlImporter:
 
             # import persons
             for person_id in site['person_ids']:
-                if person_id not in persons_dict:
+                if person_id not in persons_by_id:
                     self.logger.warning ("PlImporter: skipping person %s"%person_id)
                     continue 
-                person = persons_dict[person_id]
+                person = persons_by_id[person_id]
                 hrn = email_to_hrn(site_hrn, person['email'])
                 if len(hrn) > 64:
                     hrn = hrn[:64]
     
-                previous_record = records_by_hrn_type.get( (hrn, 'user',) )
+                previous_record = self.records_by_type_hrn.get( ( 'user', hrn, ) )
                 if not previous_record:
-                    previous_record = records_by_type_pointer.get ( ('user', person_id,) )
+                    previous_record = self.records_by_type_pointer.get ( ('user', person_id,) )
                 # if user's primary key has changed then we need to update the 
                 # users gid by forcing an update here
                 plc_keys = []
@@ -240,14 +235,17 @@ class PlImporter:
                         person_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
                         if previous_record: 
                             previous_record.gid=person_gid
-                            if pubkey: previous_record.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            if pubkey: previous_record.reg_keys=[ RegKey (pubkey['key'], pubkey['key_id'])]
                             self.logger.info("PlImporter: updated person: %s" % previous_record)
                         else:
                             new_record = RegUser (hrn=hrn, gid=person_gid, 
                                                   pointer=person['person_id'], 
                                                   authority=get_authority(hrn),
                                                   email=person['email'])
-                            if pubkey: new_record.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            if pubkey: 
+                                new_record.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])]
+                            else:
+                                logger.warning("No key found for user %s"%new_record)
                             dbsession.add (new_record)
                             dbsession.commit()
                             self.logger.info("PlImporter: imported person: %s" % new_record)
@@ -282,8 +280,8 @@ class PlImporter:
             elif isinstance (record, RegUser):
                 login_base = get_leaf(get_authority(record_hrn))
                 username = get_leaf(record_hrn)
-                if login_base in sites_dict:
-                    site = sites_dict[login_base]
+                if login_base in sites_by_login_base:
+                    site = sites_by_login_base[login_base]
                     for person in persons:
                         tmp_username = person['email'].split("@")[0]
                         alt_username = person['email'].split("@")[0].replace(".", "_").replace("+", "_")
@@ -304,8 +302,8 @@ class PlImporter:
             elif isinstance (record, RegNode):
                 login_base = get_leaf(get_authority(record_hrn))
                 nodename = Xrn.unescape(get_leaf(record_hrn))
-                if login_base in sites_dict:
-                    site = sites_dict[login_base]
+                if login_base in sites_by_login_base:
+                    site = sites_by_login_base[login_base]
                     for node in nodes:
                         tmp_nodename = node['hostname']
                         if tmp_nodename == nodename and \
@@ -318,7 +316,7 @@ class PlImporter:
         
             if not found:
                 try:
-                    record_object = records_by_hrn_type[(record_hrn, type)]
+                    record_object = self.records_by_type_hrn[ ( type, record_hrn, ) ]
                     self.logger.info("PlImporter: deleting record: %s" % record)
                     dbsession.delete(record_object)
                     dbsession.commit()
index aeddaeb..2788282 100755 (executable)
@@ -36,9 +36,9 @@ def main ():
         testbed_importer = importer_class (auth_hierarchy, logger)
 
     parser = OptionParser ()
-    sfa_importer.record_options (parser)
+    sfa_importer.add_options (parser)
     if testbed_importer:
-        testbed_importer.record_options (parser)
+        testbed_importer.add_options (parser)
 
     (options, args) = parser.parse_args ()
     # no args supported ?
index a337655..bc1ea0a 100644 (file)
@@ -48,8 +48,8 @@ class SfaImporter:
        self.root_auth = self.config.SFA_REGISTRY_ROOT_AUTH
 
     # record options into an OptionParser
-    def record_options (self, parser):
-       self.logger.info ("SfaImporter.record_options : to do")
+    def add_options (self, parser):
+       # no generic option
        pass
 
     def run (self, options):
@@ -94,11 +94,8 @@ class SfaImporter:
         self.auth_hierarchy.create_top_level_auth(hrn)    
         # create the db record if it doesnt already exist    
         auth_info = self.auth_hierarchy.get_auth_info(hrn)
-        auth_record = RegAuthority()
-        auth_record.type='authority'
-        auth_record.hrn=hrn
-        auth_record.gid=auth_info.get_gid_object()
-        auth_record.authority=get_authority(hrn)
+        auth_record = RegAuthority(hrn=hrn, gid=auth_info.get_gid_object(),
+                                   authority=get_authority(hrn))
         auth_record.just_created()
         dbsession.add (auth_record)
         dbsession.commit()
@@ -115,11 +112,8 @@ class SfaImporter:
             self.auth_hierarchy.create_auth(urn)
 
         auth_info = self.auth_hierarchy.get_auth_info(hrn)
-        user_record = RegUser()
-        user_record.type='user'
-        user_record.hrn=hrn
-        user_record.gid=auth_info.get_gid_object()
-        user_record.authority=get_authority(hrn)
+        user_record = RegUser(hrn=hrn, gid=auth_info.get_gid_object(),
+                              authority=get_authority(hrn))
         user_record.just_created()
         dbsession.add (user_record)
         dbsession.commit()
@@ -139,11 +133,8 @@ class SfaImporter:
             gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey)
             # xxx this should probably use a RegAuthority, or a to-be-defined RegPeer object
             # but for now we have to preserve the authority+<> stuff
-            interface_record = RegAuthority()
-            interface_record.type=type
-            interface_record.hrn=hrn
-            interface_record.gid= gid
-            interface_record.authority=get_authority(hrn)
+            interface_record = RegAuthority(type=type, hrn=hrn, gid=gid,
+                                            authority=get_authority(hrn))
             interface_record.just_created()
             dbsession.add (interface_record)
             dbsession.commit()
index 30670f4..6da032b 100644 (file)
@@ -111,6 +111,8 @@ class RegRecord (Base,AlchemyObj):
     record_id           = Column (Integer, primary_key=True)
     # this is the discriminator that tells which class to use
     classtype           = Column (String)
+    # in a first version type was the discriminator
+    # but that could not accomodate for 'authority+sa' and the like
     type                = Column (String)
     hrn                 = Column (String)
     gid                 = Column (String)
@@ -136,12 +138,12 @@ class RegRecord (Base,AlchemyObj):
         if dict:                                self.load_from_dict (dict)
 
     def __repr__(self):
-        result="[Record id=%s, type=%s, hrn=%s, authority=%s, pointer=%s" % \
+        result="<Record id=%s, type=%s, hrn=%s, authority=%s, pointer=%s" % \
                 (self.record_id, self.type, self.hrn, self.authority, self.pointer)
         # skip the uniform '--- BEGIN CERTIFICATE --' stuff
         if self.gid: result+=" gid=%s..."%self.gid[28:36]
         else: result+=" nogid"
-        result += "]"
+        result += ">"
         return result
 
     @validates ('gid')
@@ -165,34 +167,51 @@ class RegRecord (Base,AlchemyObj):
         self.last_updated=now
 
 ##############################
+# all subclasses define a convenience constructor with a default value for type, 
+# and when applicable a way to define local fields in a kwd=value argument
+####################
 class RegAuthority (RegRecord):
     __tablename__       = 'authorities'
     __mapper_args__     = { 'polymorphic_identity' : 'authority' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     
+    def __init__ (self, **kwds):
+        # fill in type if not previously set
+        if 'type' not in kwds: kwds['type']='authority'
+        # base class constructor
+        RegRecord.__init__(self, **kwds)
+
     # no proper data yet, just hack the typename
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Authority")
 
-##############################
+####################
 class RegSlice (RegRecord):
     __tablename__       = 'slices'
     __mapper_args__     = { 'polymorphic_identity' : 'slice' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     
+    def __init__ (self, **kwds):
+        if 'type' not in kwds: kwds['type']='slice'
+        RegRecord.__init__(self, **kwds)
+
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Slice")
 
-##############################
+####################
 class RegNode (RegRecord):
     __tablename__       = 'nodes'
     __mapper_args__     = { 'polymorphic_identity' : 'node' }
     record_id           = Column (Integer, ForeignKey ("records.record_id"), primary_key=True)
     
+    def __init__ (self, **kwds):
+        if 'type' not in kwds: kwds['type']='node'
+        RegRecord.__init__(self, **kwds)
+
     def __repr__ (self):
         return RegRecord.__repr__(self).replace("Record","Node")
 
-##############################
+####################
 class RegUser (RegRecord):
     __tablename__       = 'users'
     # these objects will have type='user' in the records table
@@ -203,25 +222,26 @@ class RegUser (RegRecord):
     # a 'keys' tag, and assigning a list of strings in a reference column like this crashes
     reg_keys            = relationship ('RegKey', backref='reg_user')
     
+    # so we can use RegUser (email=.., hrn=..) and the like
     def __init__ (self, **kwds):
         # handle local settings
         if 'email' in kwds: self.email=kwds.pop('email')
-        # fill in type if not previously set
         if 'type' not in kwds: kwds['type']='user'
         RegRecord.__init__(self, **kwds)
 
     # append stuff at the end of the record __repr__
     def __repr__ (self): 
         result = RegRecord.__repr__(self).replace("Record","User")
-        result.replace ("]"," email=%s"%self.email)
-        result += "]"
+        result.replace (">"," email=%s"%self.email)
+        result += ">"
         return result
-    
+
     @validates('email') 
     def validate_email(self, key, address):
         assert '@' in address
         return address
 
+    # xxx this might be temporary
     def normalize_xml (self):
         if hasattr(self,'keys'): self.reg_keys = [ RegKey (key) for key in self.keys ]
 
@@ -229,6 +249,7 @@ class RegUser (RegRecord):
 # xxx tocheck : not sure about eager loading of this one
 # meaning, when querying the whole records, we expect there should
 # be a single query to fetch all the keys 
+# or, is it enough that we issue a single query to retrieve all the keys 
 class RegKey (Base):
     __tablename__       = 'keys'
     key_id              = Column (Integer, primary_key=True)
@@ -241,14 +262,14 @@ class RegKey (Base):
         if pointer: self.pointer=pointer
 
     def __repr__ (self):
-        result="[key key=%s..."%self.key[8:16]
-        try:    result += " user=%s"%self.user.record_id
-        except: result += " <orphan>"
-        result += "]"
+        result="<key id=%s key=%s..."%(self.key_id,self.key[8:16],)
+        try:    result += " user=%s"%self.reg_user.record_id
+        except: result += " no-user"
+        result += ">"
         return result
 
 ##############################
-# although the db needs of course to be reachable,
+# although the db needs of course to be reachable for the following functions
 # the schema management functions are here and not in alchemy
 # because the actual details of the classes need to be known
 # migrations: this code has no notion of the previous versions