Removed parsing.py and parse_filter function in GetNodes and GetSites.
[sfa.git] / sfa / senslab / slabdriver.py
index dcd3d9b..d2c98d9 100644 (file)
@@ -13,9 +13,7 @@ from sfa.storage.record import Record
 from sfa.storage.alchemy import dbsession
 from sfa.storage.model import RegRecord
 
-
-from sfa.trust.certificate import *
-from sfa.trust.credential import *
+from sfa.trust.credential import Credential
 from sfa.trust.gid import GID
 
 from sfa.managers.driver import Driver
@@ -29,23 +27,15 @@ from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicenam
 # is taken care of 
 # SlabDriver should be really only about talking to the senslab testbed
 
-## thierry : please avoid wildcard imports :)
+
 from sfa.senslab.OARrestapi import  OARrestapi
 from sfa.senslab.LDAPapi import LDAPapi
 
-from sfa.senslab.parsing import parse_filter
 from sfa.senslab.slabpostgres import SlabDB, slab_dbsession,SliceSenslab
 from sfa.senslab.slabaggregate import SlabAggregate
 from sfa.senslab.slabslices import SlabSlices
 
-def list_to_dict(recs, key):
-    """
-    convert a list of dictionaries into a dictionary keyed on the 
-    specified dictionary key 
-    """
 
-    keys = [rec[key] for rec in recs]
-    return dict(zip(keys, recs))
 
 # thierry : note
 # this inheritance scheme is so that the driver object can receive
@@ -60,9 +50,6 @@ class SlabDriver(Driver):
     
         self.root_auth = config.SFA_REGISTRY_ROOT_AUTH
 
-        
-       print >>sys.stderr, "\r\n_____________ SFA SENSLAB DRIVER \r\n" 
-
         self.oar = OARrestapi()
        self.ldap = LDAPapi()
         self.time_format = "%Y-%m-%d %H:%M:%S"
@@ -71,39 +58,48 @@ class SlabDriver(Driver):
         
     
     def sliver_status(self,slice_urn,slice_hrn):
-        # receive a status request for slice named urn/hrn urn:publicid:IDN+senslab+nturro_slice hrn senslab.nturro_slice
-        # shall return a structure as described in
-        # http://groups.geni.net/geni/wiki/GAPI_AM_API_V2#SliverStatus
-        # NT : not sure if we should implement this or not, but used by sface.
+        """Receive a status request for slice named urn/hrn 
+        urn:publicid:IDN+senslab+nturro_slice hrn senslab.nturro_slice
+        shall return a structure as described in
+        http://groups.geni.net/geni/wiki/GAPI_AM_API_V2#SliverStatus
+        NT : not sure if we should implement this or not, but used by sface.
         
-
+        """
+        
+        #First get the slice with the slice hrn
         sl = self.GetSlices(slice_filter= slice_hrn, filter_type = 'slice_hrn')
         if len(sl) is 0:
             raise SliverDoesNotExist("%s  slice_hrn" % (slice_hrn))
-
-        print >>sys.stderr, "\r\n \r\n_____________ Sliver status urn %s hrn %s sl %s \r\n " %(slice_urn,slice_hrn,sl)
+        
+        nodes_in_slice = sl['node_ids']
+        if len(nodes_in_slice) is 0:
+            raise SliverDoesNotExist("No slivers allocated ") 
+        
+        logger.debug("Slabdriver - sliver_status Sliver status urn %s hrn %s sl\
+                             %s \r\n " %(slice_urn,slice_hrn,sl) )
+                             
         if sl['oar_job_id'] is not -1:
-    
-            # report about the local nodes only
-            nodes_all = self.GetNodes({'hostname':sl['node_ids']},
+            #A job is running on Senslab for this slice
+            # report about the local nodes that are in the slice only
+            
+            nodes_all = self.GetNodes({'hostname':nodes_in_slice},
                             ['node_id', 'hostname','site','boot_state'])
             nodeall_byhostname = dict([(n['hostname'], n) for n in nodes_all])
-            nodes = sl['node_ids']
-            if len(nodes) is 0:
-                raise SliverDoesNotExist("No slivers allocated ") 
-                    
+            
 
             result = {}
             top_level_status = 'unknown'
             if nodes:
                 top_level_status = 'ready'
             result['geni_urn'] = slice_urn
-            result['pl_login'] = sl['job_user']
-            #result['slab_login'] = sl['job_user']
+            result['pl_login'] = sl['job_user'] #For compatibility
+
             
             timestamp = float(sl['startTime']) + float(sl['walltime']) 
-            result['pl_expires'] = strftime(self.time_format, gmtime(float(timestamp)))
-            #result['slab_expires'] = strftime(self.time_format, gmtime(float(timestamp)))
+            result['pl_expires'] = strftime(self.time_format, \
+                                                    gmtime(float(timestamp)))
+            #result['slab_expires'] = strftime(self.time_format,\
+                                                     #gmtime(float(timestamp)))
             
             resources = []
             for node in nodes:
@@ -113,8 +109,10 @@ class SlabDriver(Driver):
                 
                 res['pl_hostname'] = nodeall_byhostname[node]['hostname']
                 res['pl_boot_state'] = nodeall_byhostname[node]['boot_state']
-                res['pl_last_contact'] = strftime(self.time_format, gmtime(float(timestamp)))
-                sliver_id = urn_to_sliver_id(slice_urn, sl['record_id_slice'],nodeall_byhostname[node]['node_id'] ) 
+                res['pl_last_contact'] = strftime(self.time_format, \
+                                                    gmtime(float(timestamp)))
+                sliver_id = urn_to_sliver_id(slice_urn, sl['record_id_slice'], \
+                                            nodeall_byhostname[node]['node_id']) 
                 res['geni_urn'] = sliver_id 
                 if nodeall_byhostname[node]['boot_state'] == 'Alive':
 
@@ -383,11 +381,11 @@ class SlabDriver(Driver):
 
         return True
             
-    def GetPeers (self,auth = None, peer_filter=None, return_fields=None):
+    def GetPeers (self,auth = None, peer_filter=None, return_fields_list=None):
 
         existing_records = {}
         existing_hrns_by_types= {}
-        print >>sys.stderr, "\r\n \r\n SLABDRIVER GetPeers auth = %s, peer_filter %s, return_field %s " %(auth , peer_filter, return_fields)
+        print >>sys.stderr, "\r\n \r\n SLABDRIVER GetPeers auth = %s, peer_filter %s, return_field %s " %(auth , peer_filter, return_fields_list)
         all_records = dbsession.query(RegRecord).filter(RegRecord.type.like('%authority%')).all()
         for record in all_records:
             existing_records[(record.hrn,record.type)] = record
@@ -418,33 +416,36 @@ class SlabDriver(Driver):
                 pass
                 
         return_records = records_list
-        if not peer_filter and not return_fields:
+        if not peer_filter and not return_fields_list:
             return records_list
-        #return_records = parse_filter(records_list,peer_filter, 'peers', return_fields) 
+
        
         print >>sys.stderr, "\r\n \r\n SLABDRIVER GetPeers   return_records %s " %(return_records)
         return return_records
         
      
-            
-    def GetPersons(self, person_filter=None, return_fields=None):
-        
-        #if isinstance(person_filter,list):
-            #for f in person_filter:
-                #person = self.ldap.ldapSearch(f)
-        #if isinstance(person_filter,dict):    
-        person_list = self.ldap.ldapFindHrn({'authority': self.root_auth })
-        
-        #check = False
-        #if person_filter and isinstance(person_filter, dict):
-            #for k in  person_filter.keys():
-                #if k in person_list[0].keys():
-                    #check = True
+    #TODO  : Handling OR request in make_ldap_filters_from_records instead of the for loop 
+    #over the records' list
+    def GetPersons(self, person_filter=None, return_fields_list=None):
+        """
+        person_filter should be a list of dictionnaries when not set to None.
+        Returns a list of users found.
+       
+        """
+        print>>sys.stderr, "\r\n \r\n \t\t\t GetPersons person_filter %s" %(person_filter)
+        person_list = []
+        if person_filter and isinstance(person_filter,list):
+        #If we are looking for a list of users (list of dict records)
+        #Usually the list contains only one user record
+            for f in person_filter:
+                person = self.ldap.LdapFindUser(f)
+                person_list.append(person)
+          
+        else:
+              person_list  = self.ldap.LdapFindUser()  
                     
-        return_person_list = parse_filter(person_list,person_filter ,'persons', return_fields)
-        if return_person_list:
-            print>>sys.stderr, " \r\n GetPersons person_filter %s return_fields %s  " %(person_filter,return_fields)
-            return return_person_list
+        return person_list
 
     def GetTimezone(self):
         server_timestamp,server_tz = self.oar.parser.SendRequest("GET_timezone")
@@ -462,7 +463,7 @@ class SlabDriver(Driver):
         print>>sys.stderr, "\r\n \r\n  jobid  DeleteJobs %s "  %(answer)
         
                 
-    def GetJobs(self,job_id= None, resources=True,return_fields=None, username = None):
+    def GetJobs(self,job_id= None, resources=True,return_fields_list=None, username = None):
         #job_resources=['reserved_resources', 'assigned_resources','job_id', 'job_uri', 'assigned_nodes',\
         #'api_timestamp']
         #assigned_res = ['resource_id', 'resource_uri']
@@ -510,37 +511,76 @@ class SlabDriver(Driver):
             
     def GetReservedNodes(self):
         # this function returns a list of all the nodes already involved in an oar job
-
+       #jobs=self.oar.parser.SendRequest("GET_reserved_nodes") 
        jobs=self.oar.parser.SendRequest("GET_jobs_details") 
        nodes=[]
        for j in jobs :
           nodes=j['assigned_network_address']+nodes
        return nodes
      
-    def GetNodes(self,node_filter= None, return_fields=None):
-        node_dict =self.oar.parser.SendRequest("GET_resources_full")
-
+    def GetNodes(self,node_filter_dict = None, return_fields_list = None):
+        """
+        node_filter_dict : dictionnary of lists
+        
+        """
+        node_dict_by_id = self.oar.parser.SendRequest("GET_resources_full")
+        node_dict_list = node_dict_by_id.values()
+        
+        #No  filtering needed return the list directly
+        if not (node_filter_dict or return_fields_list):
+            return node_dict_list
+        
         return_node_list = []
-        if not (node_filter or return_fields):
-                return_node_list = node_dict.values()
-                return return_node_list
-    
-        return_node_list= parse_filter(node_dict.values(),node_filter ,'node', return_fields)
+        if node_filter_dict:
+            for filter_key in node_filter_dict:
+                try:
+                    #Filter the node_dict_list by each value contained in the 
+                    #list node_filter_dict[filter_key]
+                    for value in node_filter_dict[filter_key]:
+                        for node in node_dict_list:
+                            if node[filter_key] == value:
+                                if return_fields_list :
+                                   tmp = {}
+                                   for k in return_fields_list:
+                                        tmp[k] = node[k]     
+                                   return_node_list.append(tmp)
+                                else:
+                                   return_node_list.append(node)
+                except KeyError:
+                    logger.log_exc("GetNodes KeyError")
+                    return
+
+
         return return_node_list
     
   
-    def GetSites(self, site_filter = None, return_fields=None):
-        site_dict =self.oar.parser.SendRequest("GET_sites")
+    def GetSites(self, site_filter_name = None, return_fields_list = None):
+        site_dict = self.oar.parser.SendRequest("GET_sites")
+        #site_dict : dict where the key is the sit ename
         return_site_list = []
-        if not ( site_filter or return_fields):
+        if not ( site_filter_name or return_fields_list):
                 return_site_list = site_dict.values()
                 return return_site_list
-    
-        return_site_list = parse_filter(site_dict.values(), site_filter,'site', return_fields)
+        
+        if site_filter_name in site_dict:
+            if return_fields_list:
+                for field in return_fields_list:
+                    tmp = {}
+                    Create 
+                    try:
+                        tmp[field] = site_dict[site_filter_name][field]
+                    except KeyError:
+                        logger.error("GetSites KeyError %s "%(field))
+                        return None
+                return_site_list.append(tmp)
+            else:
+                return_site_list.append( site_dict[site_filter_name])
+            
+
         return return_site_list
         
 
-    def GetSlices(self,slice_filter = None, filter_type = None, return_fields=None):
+    def GetSlices(self,slice_filter = None, filter_type = None, return_fields_list=None):
         return_slice_list = []
         slicerec  = {}
         rec = {}
@@ -583,8 +623,8 @@ class SlabDriver(Driver):
 
         print >>sys.stderr, " \r\n \r\n \tSLABDRIVER.PY  GetSlices  slices %s slice_filter %s " %(return_slice_list,slice_filter)
         
-        #if return_fields:
-            #return_slice_list  = parse_filter(sliceslist, slice_filter,'slice', return_fields)
+        #if return_fields_list:
+            #return_slice_list  = parse_filter(sliceslist, slice_filter,'slice', return_fields_list)
         
         
                     
@@ -905,7 +945,8 @@ class SlabDriver(Driver):
                     'person_ids':[rec['record_id_user']]})
                     #retourne une liste 100512
                     
-                    user_slab = self.GetPersons({'hrn':recuser.hrn})
+                    #GetPersons takes [] as filters 
+                    user_slab = self.GetPersons([{'hrn':recuser.hrn}])
                     
 
                     rec.update({'type':'slice','hrn':rec['slice_hrn']})