Merge remote-tracking branch 'origin/geni-v3' into geni-v3
[sfa.git] / sfa / dummy / dummydriver.py
index 1fd0843..b33c85d 100644 (file)
@@ -2,8 +2,8 @@ import time
 import datetime
 #
 from sfa.util.faults import MissingSfaInfo, UnknownSfaType, \
-    RecordNotFound, SfaNotImplemented, SliverDoesNotExist
-
+    RecordNotFound, SfaNotImplemented, SliverDoesNotExist, SearchFailed, \
+    UnsupportedOperation, Forbidden
 from sfa.util.sfalogging import logger
 from sfa.util.defaultdict import defaultdict
 from sfa.util.sfatime import utcparse, datetime_to_string, datetime_to_epoch
@@ -11,12 +11,11 @@ from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf
 from sfa.util.cache import Cache
 
 # one would think the driver should not need to mess with the SFA db, but..
-from sfa.storage.alchemy import dbsession
-from sfa.storage.model import RegRecord
+from sfa.storage.model import RegRecord, SliverAllocation
+from sfa.trust.credential import Credential
 
 # used to be used in get_ticket
 #from sfa.trust.sfaticket import SfaTicket
-
 from sfa.rspecs.version_manager import VersionManager
 from sfa.rspecs.rspec import RSpec
 
@@ -44,14 +43,42 @@ class DummyDriver (Driver):
     # the cache instance is a class member so it survives across incoming requests
     cache = None
 
-    def __init__ (self, config):
-        Driver.__init__ (self, config)
-        self.config = config
+    def __init__ (self, api):
+        Driver.__init__ (self, api)
+        config = api.config
         self.hrn = config.SFA_INTERFACE_HRN
         self.root_auth = config.SFA_REGISTRY_ROOT_AUTH
         self.shell = DummyShell (config)
         self.testbedInfo = self.shell.GetTestbedInfo()
  
+    def check_sliver_credentials(self, creds, urns):
+        # build list of cred object hrns
+        slice_cred_names = []
+        for cred in creds:
+            slice_cred_hrn = Credential(cred=cred).get_gid_object().get_hrn()
+            slice_cred_names.append(DummyXrn(xrn=slice_cred_hrn).dummy_slicename())
+
+        # look up slice name of slivers listed in urns arg
+        slice_ids = []
+        for urn in urns:
+            sliver_id_parts = Xrn(xrn=urn).get_sliver_id_parts()
+            try:
+                slice_ids.append(int(sliver_id_parts[0]))
+            except ValueError:
+                pass
+
+        if not slice_ids:
+             raise Forbidden("sliver urn not provided")
+
+        slices = self.shell.GetSlices({'slice_ids': slice_ids})
+        sliver_names = [slice['slice_name'] for slice in slices]
+
+        # make sure we have a credential for every specified sliver ierd
+        for sliver_name in sliver_names:
+            if sliver_name not in slice_cred_names:
+                msg = "Valid credential not found for target: %s" % sliver_name
+                raise Forbidden(msg)
+
     ########################################
     ########## registry oriented
     ########################################
@@ -308,7 +335,7 @@ class DummyDriver (Driver):
         
         # get the registry records
         user_list, users = [], {}
-        user_list = dbsession.query (RegRecord).filter(RegRecord.pointer.in_(user_ids))
+        user_list = self.api.dbsession().query (RegRecord).filter(RegRecord.pointer.in_(user_ids))
         # create a hrns keyed on the sfa record's pointer.
         # Its possible for multiple records to have the same pointer so
         # the dict's value will be a list of hrns.
@@ -359,7 +386,9 @@ class DummyDriver (Driver):
     def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids):
         # hard-wire the code for slice/user for now, could be smarter if needed
         if subject_type =='slice' and target_type == 'user' and relation_name == 'researcher':
-            subject=self.shell.GetSlices (subject_id)[0]
+            subject=self.shell.GetSlices ({'slice_id': subject_id})[0]
+            if 'user_ids' not in subject.keys():
+                 subject['user_ids'] = []
             current_target_ids = subject['user_ids']
             add_target_ids = list ( set (target_ids).difference(current_target_ids))
             del_target_ids = list ( set (current_target_ids).difference(target_ids))
@@ -380,186 +409,143 @@ class DummyDriver (Driver):
 
     def testbed_name (self): return "dummy"
 
-    # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory
     def aggregate_version (self):
-        version_manager = VersionManager()
-        ad_rspec_versions = []
-        request_rspec_versions = []
-        for rspec_version in version_manager.versions:
-            if rspec_version.content_type in ['*', 'ad']:
-                ad_rspec_versions.append(rspec_version.to_dict())
-            if rspec_version.content_type in ['*', 'request']:
-                request_rspec_versions.append(rspec_version.to_dict()) 
-        return {
-            'testbed':self.testbed_name(),
-            'geni_request_rspec_versions': request_rspec_versions,
-            'geni_ad_rspec_versions': ad_rspec_versions,
-            }
-
-    def list_slices (self, creds, options):
-    
-        slices = self.shell.GetSlices()
-        slice_hrns = [slicename_to_hrn(self.hrn, slice['slice_name']) for slice in slices]
-        slice_urns = [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slice_hrns]
-    
-        return slice_urns
-        
-    # first 2 args are None in case of resource discovery
-    def list_resources (self, slice_urn, slice_hrn, creds, options):
-    
-        version_manager = VersionManager()
-        # get the rspec's return format from options
-        rspec_version = version_manager.get_version(options.get('geni_rspec_version'))
-        version_string = "rspec_%s" % (rspec_version)
-    
+        return {}
+
+    def list_resources (self, version=None, options={}):
         aggregate = DummyAggregate(self)
-        rspec =  aggregate.get_rspec(slice_xrn=slice_urn, version=rspec_version, 
-                                     options=options)
-    
+        rspec =  aggregate.list_resources(version=version, options=options)
         return rspec
-    
-    def sliver_status (self, slice_urn, slice_hrn):
-        # find out where this slice is currently running
-        slice_name = hrn_to_dummy_slicename(slice_hrn)
-        
-        slice = self.shell.GetSlices({'slice_name': slice_name})
-        if len(slices) == 0:        
-            raise SliverDoesNotExist("%s (used %s as slicename internally)" % (slice_hrn, slicename))
-        
-        # report about the local nodes only
-        nodes = self.shell.GetNodes({'node_ids':slice['node_ids']})
-
-        if len(nodes) == 0:
-            raise SliverDoesNotExist("You have not allocated any slivers here") 
-
-        # get login info
-        user = {}
-        keys = []
-        if slice['user_ids']:
-            users = self.shell.GetUsers({'user_ids': slice['user_ids']})
-            for user in users:
-                 keys.extend(user['keys'])
-
-            user.update({'urn': slice_urn,
-                         'login': slice['slice_name'],
-                         'protocol': ['ssh'],
-                         'port': ['22'],
-                         'keys': keys})
 
+    def describe(self, urns, version, options={}):
+        aggregate = DummyAggregate(self)
+        return aggregate.describe(urns, version=version, options=options)
     
-        result = {}
-        top_level_status = 'unknown'
-        if nodes:
-            top_level_status = 'ready'
-        result['geni_urn'] = slice_urn
-        result['dummy_login'] = slice['slice_name']
-        result['dummy_expires'] = datetime_to_string(utcparse(slice['expires']))
-        result['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
-        
-        resources = []
-        for node in nodes:
-            res = {}
-            res['dummy_hostname'] = node['hostname']
-            res['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
-            sliver_id = Xrn(slice_urn, type='slice', id=node['node_id'], authority=self.hrn).urn
-            res['geni_urn'] = sliver_id
-            res['geni_status'] = 'ready'
-            res['geni_error'] = ''
-            res['users'] = [users]  
-    
-            resources.append(res)
-            
-        result['geni_status'] = top_level_status
-        result['geni_resources'] = resources
-        return result
-
-    def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options):
+    def status (self, urns, options={}):
+        aggregate = DummyAggregate(self)
+        desc =  aggregate.describe(urns, version='GENI 3')
+        status = {'geni_urn': desc['geni_urn'],
+                  'geni_slivers': desc['geni_slivers']}
+        return status
 
+        
+    def allocate (self, urn, rspec_string, expiration, options={}):
+        xrn = Xrn(urn)
         aggregate = DummyAggregate(self)
         slices = DummySlices(self)
-        sfa_peer = slices.get_sfa_peer(slice_hrn)
-        slice_record=None    
+        slice_record=None
+        users = options.get('geni_users', [])
         if users:
             slice_record = users[0].get('slice_record', {})
-    
+
         # parse rspec
         rspec = RSpec(rspec_string)
         requested_attributes = rspec.version.get_slice_attributes()
-        
+
         # ensure slice record exists
-        slice = slices.verify_slice(slice_hrn, slice_record, peer, sfa_peer, options=options)
-        # ensure user records exists
-        users = slices.verify_users(slice_hrn, slice, users, peer, sfa_peer, options=options)
-        
+        slice = slices.verify_slice(xrn.hrn, slice_record, expiration=expiration, options=options)
+        # ensure person records exists
+        #persons = slices.verify_persons(xrn.hrn, slice, users, peer, sfa_peer, options=options)
+
         # add/remove slice from nodes
-        requested_slivers = []
-        for node in rspec.version.get_nodes_with_slivers():
-            hostname = None
-            if node.get('component_name'):
-                hostname = node.get('component_name').strip()
-            elif node.get('component_id'):
-                hostname = xrn_to_hostname(node.get('component_id').strip())
-            if hostname:
-                requested_slivers.append(hostname)
-        nodes = slices.verify_slice_nodes(slice, requested_slivers, peer) 
-    
-        return aggregate.get_rspec(slice_xrn=slice_urn, version=rspec.version)
+        request_nodes = rspec.version.get_nodes_with_slivers()
+        nodes = slices.verify_slice_nodes(urn, slice, request_nodes)
 
-    def delete_sliver (self, slice_urn, slice_hrn, creds, options):
-        slicename = hrn_to_dummy_slicename(slice_hrn)
-        slices = self.shell.GetSlices({'slice_name': slicename})
-        if not slices:
-            return True
-        slice = slices[0]
-        
-        try:
-            self.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'], 'node_ids': slice['node_ids']})
-            return True
-        except:
-            return False
-    
-    def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options):
-        slicename = hrn_to_dummy_slicename(slice_hrn)
-        slices = self.shell.GetSlices({'slice_name': slicename})
-        if not slices:
-            raise RecordNotFound(slice_hrn)
-        slice = slices[0]
+        return aggregate.describe([xrn.get_urn()], version=rspec.version)
+
+    def provision(self, urns, options={}):
+        # update users
+        slices = DummySlices(self)
+        aggregate = DummyAggregate(self)
+        slivers = aggregate.get_slivers(urns)
+        slice = slivers[0]
+        geni_users = options.get('geni_users', [])
+        #users = slices.verify_users(None, slice, geni_users, options=options)
+        # update sliver allocation states and set them to geni_provisioned
+        sliver_ids = [sliver['sliver_id'] for sliver in slivers]
+        dbsession=self.api.dbsession()
+        SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession)
+        version_manager = VersionManager()
+        rspec_version = version_manager.get_version(options['geni_rspec_version'])
+        return self.describe(urns, rspec_version, options=options)
+
+    def delete(self, urns, options={}):
+        # collect sliver ids so we can update sliver allocation states after
+        # we remove the slivers.
+        aggregate = DummyAggregate(self)
+        slivers = aggregate.get_slivers(urns)
+        if slivers:
+            slice_id = slivers[0]['slice_id']
+            node_ids = []
+            sliver_ids = []
+            for sliver in slivers:
+                node_ids.append(sliver['node_id'])
+                sliver_ids.append(sliver['sliver_id'])
+
+            # determine if this is a peer slice
+            # xxx I wonder if this would not need to use PlSlices.get_peer instead 
+            # in which case plc.peers could be deprecated as this here
+            # is the only/last call to this last method in plc.peers
+            slice_hrn = DummyXrn(auth=self.hrn, slicename=slivers[0]['slice_name']).get_hrn()
+            try:
+                self.shell.DeleteSliceFromNodes({'slice_id': slice_id, 'node_ids': node_ids})
+                # delete sliver allocation states
+                dbsession=self.api.dbsession()
+                SliverAllocation.delete_allocations(sliver_ids,dbsession)
+            finally:
+                pass
+
+        # prepare return struct
+        geni_slivers = []
+        for sliver in slivers:
+            geni_slivers.append(
+                {'geni_sliver_urn': sliver['sliver_id'],
+                 'geni_allocation_status': 'geni_unallocated',
+                 'geni_expires': datetime_to_string(utcparse(sliver['expires']))})  
+        return geni_slivers
+
+    def renew (self, urns, expiration_time, options={}):
+        aggregate = DummyAggregate(self)
+        slivers = aggregate.get_slivers(urns)
+        if not slivers:
+            raise SearchFailed(urns)
+        slice = slivers[0]
         requested_time = utcparse(expiration_time)
         record = {'expires': int(datetime_to_epoch(requested_time))}
-        try:
-            self.shell.UpdateSlice({'slice_id': slice['slice_id'], 'fields':record})
-            return True
-        except:
-            return False
-
-    # set the 'enabled' tag to True
-    def start_slice (self, slice_urn, slice_hrn, creds):
-        slicename = hrn_to_dummy_slicename(slice_hrn)
-        slices = self.shell.GetSlices({'slice_name': slicename})
-        if not slices:
-            raise RecordNotFound(slice_hrn)
-        slice_id = slices[0]['slice_id']
-        slice_enabled = slices[0]['enabled'] 
-        # just update the slice enabled tag
-        if not slice_enabled:
-            self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': True}})
-        return 1
-
-    # set the 'enabled' tag to False
-    def stop_slice (self, slice_urn, slice_hrn, creds):
-        slicename = hrn_to_pl_slicename(slice_hrn)
-        slices = self.shell.GetSlices({'slice_name': slicename})
+        self.shell.UpdateSlice({'slice_id': slice['slice_id'], 'fileds': record})
+        description = self.describe(urns, 'GENI 3', options)
+        return description['geni_slivers']
+
+    def perform_operational_action (self, urns, action, options={}):
+        # Dummy doesn't support operational actions. Lets pretend like it
+        # supports start, but reject everything else.
+        action = action.lower()
+        if action not in ['geni_start']:
+            raise UnsupportedOperation(action)
+
+        # fault if sliver is not full allocated (operational status is geni_pending_allocation)
+        description = self.describe(urns, 'GENI 3', options)
+        for sliver in description['geni_slivers']:
+            if sliver['geni_operational_status'] == 'geni_pending_allocation':
+                raise UnsupportedOperation(action, "Sliver must be fully allocated (operational status is not geni_pending_allocation)")
+        #
+        # Perform Operational Action Here
+        #
+
+        geni_slivers = self.describe(urns, 'GENI 3', options)['geni_slivers']
+        return geni_slivers
+
+    def shutdown (self, xrn, options={}):
+        xrn = DummyXrn(xrn=xrn, type='slice')
+        slicename = xrn.pl_slicename()
+        slices = self.shell.GetSlices({'name': slicename}, ['slice_id'])
         if not slices:
             raise RecordNotFound(slice_hrn)
         slice_id = slices[0]['slice_id']
-        slice_enabled = slices[0]['enabled']
-        # just update the slice enabled tag
-        if slice_enabled:
-            self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': False}})
+        slice_tags = self.shell.GetSliceTags({'slice_id': slice_id, 'tagname': 'enabled'})
+        if not slice_tags:
+            self.shell.AddSliceTag(slice_id, 'enabled', '0')
+        elif slice_tags[0]['value'] != "0":
+            tag_id = slice_tags[0]['slice_tag_id']
+            self.shell.UpdateSliceTag(tag_id, '0')
         return 1
-    
-    def reset_slice (self, slice_urn, slice_hrn, creds):
-        raise SfaNotImplemented ("reset_slice not available at this interface")
-    
-    def get_ticket (self, slice_urn, slice_hrn, creds, rspec_string, options):
-        raise SfaNotImplemented,"DummyDriver.get_ticket needs a rewrite"