refactored
authorTony Mack <tmack@cs.princeton.edu>
Mon, 14 Sep 2009 23:11:16 +0000 (23:11 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Mon, 14 Sep 2009 23:11:16 +0000 (23:11 +0000)
sfa/plc/sfa-import-plc.py
sfa/plc/sfaImport.py

index 5da81bb..9b69749 100755 (executable)
@@ -54,7 +54,8 @@ def main():
     AuthHierarchy = sfaImporter.AuthHierarchy
     TrustedRoots = sfaImporter.TrustedRoots
     table = GeniTable()
-    table.create()
+    if not table.exists():
+        table.create()
 
     if not level1_auth or level1_auth in ['']:
         level1_auth = None
@@ -73,14 +74,80 @@ def main():
     authority = AuthHierarchy.get_auth_info(import_auth)
     TrustedRoots.add_gid(authority.get_gid_object())
 
-    sites = shell.GetSites(plc_auth, {'peer_id': None})
-    # create a fake internet2 site first
-    i2site = {'name': 'Internet2', 'abbreviated_name': 'I2',
+    if ".vini" in import_auth and import_auth.endswith('vini'):
+        # create a fake internet2 site first
+        i2site = {'name': 'Internet2', 'abbreviated_name': 'I2',
                     'login_base': 'internet2', 'site_id': -1}
-    sfaImporter.import_site(import_auth, i2site)
+        sfaImporter.import_site(import_auth, i2site)
+   
+    # create dict of all existing sfa records
+    existing_records = {}
+    existing_hrns = []
+    results = table.find()
+    for result in results:
+        existing_records[(result['hrn'], result['type'])] = result
+        existing_hrns.append(result['hrn']) 
             
+    # Get all plc sites
+    sites = shell.GetSites(plc_auth)
+    
+    # Get all plc users
+    persons = shell.GetPersons(plc_auth, {}, ['person_id', 'email', 'key_ids'])
+    persons_dict = {}
+    for person in persons:
+        persons_dict[person['person_id']] = person
+
+    # Get all plc nodes  
+    nodes = shell.GetNodes(plc_auth, {}, ['node_id', 'hostname'])
+    nodes_dict = {}
+    for node in nodes:
+        nodes_dict[node['node_id']] = node
+
+    # Get all plc slices
+    slices = shell.GetSlices(plc_auth, {}, ['slice_id', 'name'])
+    slices_dict = {}
+    for slice in slices:
+        slices_dict[slice['slice_id']] = slice
+
+    # start importing 
     for site in sites:
-        sfaImporter.import_site(import_auth, site)
+        site_hrn = import_auth + "." + site['login_base']
+        # import if hrn is not in list of existing hrns or if the hrn exists
+        # but its not a site record
+        if site_hrn not in existing_hrns or \
+           (site_hrn, 'authority') not in existing_records:   
+            sfaImporter.import_site(import_auth, site)
+
+        
+        # import node records
+        for node_id in site['node_ids']:
+            if node_id not in nodes_dict:
+                continue 
+            node = nodes_dict[node_id]
+            hrn =  hostname_to_hrn(import_auth, site['login_base'], node['hostname'])
+            if hrn not in existing_hrns or \
+               (hrn, 'node') not in existing_records:
+                sfaImporter.import_node(site_hrn, node)
+
+        # import slices
+        for slice_id in site['slice_ids']:
+            if slice_id not in slices_dict:
+                continue 
+            slice = slices_dict[slice_id]
+            hrn = slicename_to_hrn(import_auth, slice['name'])
+            if hrn not in existing_hrns or \
+               (hrn, 'slice') not in existing_records:
+                sfaImporter.import_slice(site_hrn, slice)      
 
+        # import persons
+        for person_id in site['person_ids']:
+            if person_id not in persons_dict:
+                continue 
+            person = persons_dict[person_id]
+            hrn = email_to_hrn(site_hrn, person['email'])
+            if hrn not in existing_hrns or \
+               (hrn, 'user') not in existing_records:
+                sfaImporter.import_person(site_hrn, person)
+        
 if __name__ == "__main__":
     main()
index d6691e7..da071e2 100644 (file)
@@ -48,19 +48,6 @@ def cleanup_string(str):
     str = str.replace('"', "_")
     return str
 
-def person_to_hrn(parent_hrn, person):
-    # the old way - Lastname_Firstname
-    #personname = person['last_name'] + "_" + person['first_name']
-
-    # the new way - use email address up to the "@"
-    personname = person['email'].split("@")[0]
-
-    personname = cleanup_string(personname)
-
-    hrn = parent_hrn + "." + personname
-    return hrn
-
-
 class sfaImport:
 
     def __init__(self):
@@ -114,43 +101,41 @@ class sfaImport:
 
     def import_person(self, parent_hrn, person):
         AuthHierarchy = self.AuthHierarchy
-        hrn = person_to_hrn(parent_hrn, person)
+        hrn = email_to_hrn(parent_hrn, person['email'])
 
         # ASN.1 will have problems with hrn's longer than 64 characters
         if len(hrn) > 64:
             hrn = hrn[:64]
 
         trace("Import: importing person " + hrn)
-
-        table = GeniTable()
-
         key_ids = []
         if 'key_ids' in person and person['key_ids']:
             key_ids = person["key_ids"]
-
             # get the user's private key from the SSH keys they have uploaded
             # to planetlab
             keys = self.shell.GetKeys(self.plc_auth, key_ids)
             key = keys[0]['key']
             pkey = convert_public_key(key)
+            if not pkey:
+                pkey = Keypair(create=True)
         else:
             # the user has no keys
             trace("   person " + hrn + " does not have a PL public key")
-
             # if a key is unavailable, then we still need to put something in the
             # user's GID. So make one up.
             pkey = Keypair(create=True)
 
         # create the gid
+        print "*", hrn
         person_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
-        person_record = table.find({'type': 'user', 'hrn': hrn})
-        if not person_record:
-            trace("  inserting user record for " + hrn)
-            person_record = GeniRecord(hrn=hrn, gid=person_gid, type="user", pointer=person['person_id'])
+        table = GeniTable()
+        person_record = GeniRecord(hrn=hrn, gid=person_gid, type="user", pointer=person['person_id'])
+        try:
             table.insert(person_record)
-        else:
-            trace("  updating user record for " + hrn)
-            person_record = GeniRecord(hrn=hrn, gid=person_gid, type="user", pointer=person['person_id'])
+        except:
+            trace("Import: %s exists, updating " % hrn)
+            existing_record = table.find(person_record)
+            person_record['record_id'] = existing_record['record_id']
             table.update(person_record)
 
     def import_slice(self, parent_hrn, slice):
@@ -165,43 +150,45 @@ class sfaImport:
         hrn = parent_hrn + "." + slicename
         trace("Import: importing slice " + hrn)
 
+        pkey = Keypair(create=True)
+        slice_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
+        slice_record = GeniRecord(hrn=hrn, gid=slice_gid, type="slice", pointer=slice['slice_id'])
         table = GeniTable()
-
-        slice_record = table.find({'type': 'sslice', 'hrn': hrn})
-        if not slice_record:
-            pkey = Keypair(create=True)
-            slice_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
-            slice_record = GeniRecord(hrn=hrn, gid=slice_gid, type="slice", pointer=slice['slice_id'])
-            trace("  inserting slice record for " + hrn)
+        try:
             table.insert(slice_record)
+        except:
+            trace("Import: %s exists, updating " % hrn)
+            existing_record = table.find(slice_record)
+            slice_record['record_id'] = existing_record['record_id']
+            table.update(slice_record)
 
     def import_node(self, parent_hrn, node):
         AuthHierarchy = self.AuthHierarchy
         nodename = node['hostname'].split(".")[0]
         nodename = cleanup_string(nodename)
-
+        
         if not nodename:
             error("Import_node: failed to parse node name " + node['hostname'])
             return
 
         hrn = parent_hrn + "." + nodename
-
+        trace("Import: importing node " + hrn)
         # ASN.1 will have problems with hrn's longer than 64 characters
         if len(hrn) > 64:
             hrn = hrn[:64]
 
-        trace("Import: importing node " + hrn)
-
         table = GeniTable()
-
         node_record = table.find({'type': 'node', 'hrn': hrn})
-        if not node_record:
-            pkey = Keypair(create=True)
-            node_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
-            node_record = GeniRecord(hrn=hrn, gid=node_gid, type="node", pointer=node['node_id'])
-            trace("  inserting node record for " + hrn)
+        pkey = Keypair(create=True)
+        node_gid = AuthHierarchy.create_gid(hrn, create_uuid(), pkey)
+        node_record = GeniRecord(hrn=hrn, gid=node_gid, type="node", pointer=node['node_id'])
+        try:
             table.insert(node_record)
-
+        except:
+            trace("Import: %s exists, updating " % hrn)
+            existing_record = table.find(node_record)
+            node_record['record_id'] = existing_record['record_id']
+            table.update(node_record)
 
     
     def import_site(self, parent_hrn, site):
@@ -233,37 +220,39 @@ class sfaImport:
         auth_info = AuthHierarchy.get_auth_info(hrn)
 
         table = GeniTable()
-
-        auth_record = table.find({'type': 'authority', 'hrn': 'hrn'})
-        if not auth_record:
-            auth_record = GeniRecord(hrn=hrn, gid=auth_info.get_gid_object(), type="authority", pointer=site['site_id'])
-            trace("  inserting authority record for " + hrn)
+        auth_record = GeniRecord(hrn=hrn, gid=auth_info.get_gid_object(), type="authority", pointer=site['site_id'])
+        try:
             table.insert(auth_record)
-
-        if 'person_ids' in site:
-            for person_id in site['person_ids']:
-                persons = shell.GetPersons(plc_auth, [person_id])
-                if persons:
-                    try:
-                        self.import_person(hrn, persons[0])
-                    except Exception, e:
-                        trace("Failed to import: %s (%s)" % (persons[0], e))
-        if 'slice_ids' in site:
-            for slice_id in site['slice_ids']:
-                slices = shell.GetSlices(plc_auth, [slice_id])
-                if slices:
-                    try:
-                        self.import_slice(hrn, slices[0])
-                    except Exception, e:
-                        trace("Failed to import: %s (%s)" % (slices[0], e))
-        if 'node_ids' in site:
-            for node_id in site['node_ids']:
-                nodes = shell.GetNodes(plc_auth, [node_id])
-                if nodes:
-                    try:
-                        self.import_node(hrn, nodes[0])
-                    except Exception, e:
-                        trace("Failed to import: %s (%s)" % (nodes[0], e))     
+        except:
+            trace("Import: %s exists, updating " % hrn)
+            existing_record = table.find(auth_record)
+            auth_record['record_id'] = existing_record['record_id']
+            table.update(auth_record)
+
+        #if 'person_ids' in site:
+        #    for person_id in site['person_ids']:
+        #        persons = shell.GetPersons(plc_auth, [person_id])
+        #        if persons:
+        #            try:
+        #                self.import_person(hrn, persons[0])
+        #            except Exception, e:
+        #                trace("Failed to import: %s (%s)" % (persons[0], e))
+        #if 'slice_ids' in site:
+        #    for slice_id in site['slice_ids']:
+        #        slices = shell.GetSlices(plc_auth, [slice_id])
+        #        if slices:
+        #            try:
+        #                self.import_slice(hrn, slices[0])
+        #            except Exception, e:
+        #                trace("Failed to import: %s (%s)" % (slices[0], e))
+        #if 'node_ids' in site:
+        #    for node_id in site['node_ids']:
+        #        nodes = shell.GetNodes(plc_auth, [node_id])
+        #        if nodes:
+        #            try:
+        #                self.import_node(hrn, nodes[0])
+        #            except Exception, e:
+        #                trace("Failed to import: %s (%s)" % (nodes[0], e))     
 
     def delete_record(self, parent_hrn, object, type):
         # get the hrn