Merge branch 'geni-v3' into dbsession
[sfa.git] / sfa / planetlab / pldriver.py
index 6fb8d60..0978a57 100644 (file)
@@ -10,7 +10,6 @@ from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf
 from sfa.util.cache import Cache
 
 # one would think the driver should not need to mess with the SFA db, but..
-from sfa.storage.alchemy import dbsession
 from sfa.storage.model import RegRecord, SliverAllocation
 from sfa.trust.credential import Credential
 
@@ -45,8 +44,9 @@ class PlDriver (Driver):
     # the cache instance is a class member so it survives across incoming requests
     cache = None
 
-    def __init__ (self, config):
-        Driver.__init__ (self, config)
+    def __init__ (self, api):
+        Driver.__init__ (self, api)
+        config=api.config
         self.shell = PlShell (config)
         self.cache=None
         if config.SFA_AGGREGATE_CACHING:
@@ -115,6 +115,7 @@ class PlDriver (Driver):
                 if 'max_slices' not in pl_record:
                     pl_record['max_slices']=2
                 pointer = self.shell.AddSite(pl_record)
+                self.shell.SetSiteHrn(int(pointer), hrn)
             else:
                 pointer = sites[0]['site_id']
 
@@ -126,6 +127,7 @@ class PlDriver (Driver):
             slices = self.shell.GetSlices([pl_record['name']])
             if not slices:
                  pointer = self.shell.AddSlice(pl_record)
+                 self.shell.SetSliceHrn(int(pointer), hrn)
             else:
                  pointer = slices[0]['slice_id']
 
@@ -138,6 +140,7 @@ class PlDriver (Driver):
                 can_add = ['first_name', 'last_name', 'title','email', 'password', 'phone', 'url', 'bio']
                 add_person_dict=dict ( [ (k,sfa_record[k]) for k in sfa_record if k in can_add ] )
                 pointer = self.shell.AddPerson(add_person_dict)
+                self.shell.SetPersonHrn(int(pointer), hrn)
             else:
                 pointer = persons[0]['person_id']
     
@@ -168,6 +171,7 @@ class PlDriver (Driver):
             nodes = self.shell.GetNodes([pl_record['hostname']])
             if not nodes:
                 pointer = self.shell.AddNode(login_base, pl_record)
+                self.shell.SetNodeHrn(int(pointer), hrn)
             else:
                 pointer = nodes[0]['node_id']
     
@@ -186,12 +190,14 @@ class PlDriver (Driver):
 
         if (type == "authority"):
             self.shell.UpdateSite(pointer, new_sfa_record)
+            self.shell.SetSiteHrn(pointer, hrn)
     
         elif type == "slice":
             pl_record=self.sfa_fields_to_pl_fields(type, hrn, new_sfa_record)
             if 'name' in pl_record:
                 pl_record.pop('name')
                 self.shell.UpdateSlice(pointer, pl_record)
+                self.shell.SetSliceHrn(pointer, hrn)
     
         elif type == "user":
             # SMBAKER: UpdatePerson only allows a limited set of fields to be
@@ -209,6 +215,7 @@ class PlDriver (Driver):
             if 'email' in update_fields and not update_fields['email']:
                 del update_fields['email']
             self.shell.UpdatePerson(pointer, update_fields)
+            self.shell.SetPersonHrn(pointer, hrn)
     
             if new_key:
                 # must check this key against the previous one if it exists
@@ -501,7 +508,7 @@ class PlDriver (Driver):
         
         # get the registry records
         person_list, persons = [], {}
-        person_list = dbsession.query (RegRecord).filter(RegRecord.pointer.in_(person_ids))
+        person_list = self.api.dbsession().query (RegRecord).filter(RegRecord.pointer.in_(person_ids))
         # create a hrns keyed on the sfa record's pointer.
         # Its possible for multiple records to have the same pointer so
         # the dict's value will be a list of hrns.
@@ -680,7 +687,8 @@ class PlDriver (Driver):
         slices.handle_peer(None, None, persons, peer)
         # update sliver allocation states and set them to geni_provisioned
         sliver_ids = [sliver['sliver_id'] for sliver in slivers]
-        SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned')
+        dbsession=self.api.dbsession()
+        SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession)
         version_manager = VersionManager()
         rspec_version = version_manager.get_version(options['geni_rspec_version']) 
         return self.describe(urns, rspec_version, options=options)
@@ -719,7 +727,8 @@ class PlDriver (Driver):
                     self.shell.DeleteLeases(leases_ids)
      
                 # delete sliver allocation states
-                SliverAllocation.delete_allocations(sliver_ids)
+                dbsession=self.api.dbsession()
+                SliverAllocation.delete_allocations(sliver_ids,dbsession)
             finally:
                 if peer:
                     self.shell.BindObjectToPeer('slice', slice_id, peer, slice['peer_slice_id'])