Fix dummy driver
[sfa.git] / sfa / dummy / dummydriver.py
1 import time
2 import datetime
3 #
4 from sfa.util.faults import MissingSfaInfo, UnknownSfaType, \
5     RecordNotFound, SfaNotImplemented, SliverDoesNotExist
6
7 from sfa.util.sfalogging import logger
8 from sfa.util.defaultdict import defaultdict
9 from sfa.util.sfatime import utcparse, datetime_to_string, datetime_to_epoch
10 from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf
11 from sfa.util.cache import Cache
12
13 # one would think the driver should not need to mess with the SFA db, but..
14 from sfa.storage.alchemy import dbsession
15 from sfa.storage.model import RegRecord
16
17 # used to be used in get_ticket
18 #from sfa.trust.sfaticket import SfaTicket
19
20 from sfa.rspecs.version_manager import VersionManager
21 from sfa.rspecs.rspec import RSpec
22
23 # the driver interface, mostly provides default behaviours
24 from sfa.managers.driver import Driver
25
26 from sfa.dummy.dummyshell import DummyShell
27 from sfa.dummy.dummyaggregate import DummyAggregate
28 from sfa.dummy.dummyslices import DummySlices
29 from sfa.dummy.dummyxrn import DummyXrn, slicename_to_hrn, hostname_to_hrn, hrn_to_dummy_slicename, xrn_to_hostname
30
31
32 def list_to_dict(recs, key):
33     """
34     convert a list of dictionaries into a dictionary keyed on the 
35     specified dictionary key 
36     """
37     return dict ( [ (rec[key],rec) for rec in recs ] )
38
39 #
40 # DummyShell is just an xmlrpc serverproxy where methods can be sent as-is; 
41
42 class DummyDriver (Driver):
43
44     # the cache instance is a class member so it survives across incoming requests
45     cache = None
46
47     def __init__ (self, config):
48         Driver.__init__ (self, config)
49         self.config = config
50         self.hrn = config.SFA_INTERFACE_HRN
51         self.root_auth = config.SFA_REGISTRY_ROOT_AUTH
52         self.shell = DummyShell (config)
53         self.testbedInfo = self.shell.GetTestbedInfo()
54  
55     ########################################
56     ########## registry oriented
57     ########################################
58
59     def augment_records_with_testbed_info (self, sfa_records):
60         return self.fill_record_info (sfa_records)
61
62     ########## 
63     def register (self, sfa_record, hrn, pub_key):
64         type = sfa_record['type']
65         dummy_record = self.sfa_fields_to_dummy_fields(type, hrn, sfa_record)
66         
67         if type == 'authority':
68             pointer = -1
69
70         elif type == 'slice':
71             slices = self.shell.GetSlices({'slice_name': dummy_record['slice_name']})
72             if not slices:
73                  pointer = self.shell.AddSlice(dummy_record)
74             else:
75                  pointer = slices[0]['slice_id']
76
77         elif type == 'user':
78             users = self.shell.GetUsers({'email':sfa_record['email']})
79             if not users:
80                 pointer = self.shell.AddUser(dummy_record)
81             else:
82                 pointer = users[0]['user_id']
83     
84             # Add the user's key
85             if pub_key:
86                 self.shell.AddUserKey({'user_id' : pointer, 'key' : pub_key})
87
88         elif type == 'node':
89             nodes = self.shell.GetNodes(dummy_record['hostname'])
90             if not nodes:
91                 pointer = self.shell.AddNode(dummy_record)
92             else:
93                 pointer = users[0]['node_id']
94     
95         return pointer
96         
97     ##########
98     def update (self, old_sfa_record, new_sfa_record, hrn, new_key):
99         pointer = old_sfa_record['pointer']
100         type = old_sfa_record['type']
101         dummy_record=self.sfa_fields_to_dummy_fields(type, hrn, new_sfa_record)
102
103         # new_key implemented for users only
104         if new_key and type not in [ 'user' ]:
105             raise UnknownSfaType(type)
106
107     
108         if type == "slice":
109             self.shell.UpdateSlice({'slice_id': pointer, 'fields': dummy_record})
110     
111         elif type == "user":
112             self.shell.UpdateUser({'user_id': pointer, 'fields': dummy_record})
113
114             if new_key:
115                 self.shell.AddUserKey({'user_id' : pointer, 'key' : new_key})
116
117         elif type == "node":
118             self.shell.UpdateNode({'node_id': pointer, 'fields': dummy_record})
119
120
121         return True
122         
123
124     ##########
125     def remove (self, sfa_record):
126         type=sfa_record['type']
127         pointer=sfa_record['pointer']
128         if type == 'user':
129             self.shell.DeleteUser({'user_id': pointer})
130         elif type == 'slice':
131             self.shell.DeleteSlice({'slice_id': pointer})
132         elif type == 'node':
133             self.shell.DeleteNode({'node_id': pointer})
134
135         return True
136
137
138
139
140
141     ##
142     # Convert SFA fields to Dummy testbed fields for use when registering or updating
143     # registry record in the dummy testbed
144     #
145
146     def sfa_fields_to_dummy_fields(self, type, hrn, sfa_record):
147
148         dummy_record = {}
149  
150         if type == "slice":
151             dummy_record["slice_name"] = hrn_to_dummy_slicename(hrn)
152         
153         elif type == "node":
154             if "hostname" not in sfa_record:
155                 raise MissingSfaInfo("hostname")
156             dummy_record["hostname"] = sfa_record["hostname"]
157             if "type" in sfa_record:
158                dummy_record["type"] = sfa_record["type"]
159             else:
160                dummy_record["type"] = "dummy_type"
161  
162         elif type == "authority":
163             dummy_record["name"] = hrn
164
165         elif type == "user":
166             dummy_record["user_name"] = sfa_record["email"].split('@')[0]
167             dummy_record["email"] = sfa_record["email"]
168
169         return dummy_record
170
171     ####################
172     def fill_record_info(self, records):
173         """
174         Given a (list of) SFA record, fill in the DUMMY TESTBED specific 
175         and SFA specific fields in the record. 
176         """
177         if not isinstance(records, list):
178             records = [records]
179
180         self.fill_record_dummy_info(records)
181         self.fill_record_hrns(records)
182         self.fill_record_sfa_info(records)
183         return records
184
185     def fill_record_dummy_info(self, records):
186         """
187         Fill in the DUMMY specific fields of a SFA record. This
188         involves calling the appropriate DUMMY method to retrieve the 
189         database record for the object.
190             
191         @param record: record to fill in field (in/out param)     
192         """
193         # get ids by type
194         node_ids, slice_ids, user_ids = [], [], [] 
195         type_map = {'node': node_ids, 'slice': slice_ids, 'user': user_ids}
196                   
197         for record in records:
198             for type in type_map:
199                 if type == record['type']:
200                     type_map[type].append(record['pointer'])
201
202         # get dummy records
203         nodes, slices, users = {}, {}, {}
204         if node_ids:
205             node_list = self.shell.GetNodes({'node_ids':node_ids})
206             nodes = list_to_dict(node_list, 'node_id')
207         if slice_ids:
208             slice_list = self.shell.GetSlices({'slice_ids':slice_ids})
209             slices = list_to_dict(slice_list, 'slice_id')
210         if user_ids:
211             user_list = self.shell.GetUsers({'user_ids': user_ids})
212             users = list_to_dict(user_list, 'user_id')
213
214         dummy_records = {'node': nodes, 'slice': slices, 'user': users}
215
216
217         # fill record info
218         for record in records:
219             # records with pointer==-1 do not have dummy info.
220             if record['pointer'] == -1:
221                 continue
222            
223             for type in dummy_records:
224                 if record['type'] == type:
225                     if record['pointer'] in dummy_records[type]:
226                         record.update(dummy_records[type][record['pointer']])
227                         break
228             # fill in key info
229             if record['type'] == 'user':
230                 record['key_ids'] = []
231                 recors['keys'] = []
232                 for key in dummy_records['user'][record['pointer']]['keys']:
233                      record['key_ids'].append(-1)
234                      recors['keys'].append(key)
235
236         return records
237
238     def fill_record_hrns(self, records):
239         """
240         convert dummy ids to hrns
241         """
242
243         # get ids
244         slice_ids, user_ids, node_ids = [], [], []
245         for record in records:
246             if 'user_ids' in record:
247                 user_ids.extend(record['user_ids'])
248             if 'slice_ids' in record:
249                 slice_ids.extend(record['slice_ids'])
250             if 'node_ids' in record:
251                 node_ids.extend(record['node_ids'])
252
253         # get dummy records
254         slices, users, nodes = {}, {}, {}
255         if user_ids:
256             user_list = self.shell.GetUsers({'user_ids': user_ids})
257             users = list_to_dict(user_list, 'user_id')
258         if slice_ids:
259             slice_list = self.shell.GetSlices({'slice_ids': slice_ids})
260             slices = list_to_dict(slice_list, 'slice_id')       
261         if node_ids:
262             node_list = self.shell.GetNodes({'node_ids': node_ids})
263             nodes = list_to_dict(node_list, 'node_id')
264        
265         # convert ids to hrns
266         for record in records:
267             # get all relevant data
268             type = record['type']
269             pointer = record['pointer']
270             testbed_name = self.testbed_name()
271             auth_hrn = self.hrn
272             if pointer == -1:
273                 continue
274
275             if 'user_ids' in record:
276                 emails = [users[user_id]['email'] for user_id in record['user_ids'] \
277                           if user_id in  users]
278                 usernames = [email.split('@')[0] for email in emails]
279                 user_hrns = [".".join([auth_hrn, testbed_name, username]) for username in usernames]
280                 record['users'] = user_hrns 
281             if 'slice_ids' in record:
282                 slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] \
283                               if slice_id in slices]
284                 slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
285                 record['slices'] = slice_hrns
286             if 'node_ids' in record:
287                 hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \
288                              if node_id in nodes]
289                 node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
290                 record['nodes'] = node_hrns
291
292             
293         return records   
294
295     def fill_record_sfa_info(self, records):
296
297         def startswith(prefix, values):
298             return [value for value in values if value.startswith(prefix)]
299
300         # get user ids
301         user_ids = []
302         for record in records:
303             user_ids.extend(record.get("user_ids", []))
304         
305         # get sfa records for all records associated with these records.   
306         # we'll replace pl ids (person_ids) with hrns from the sfa records
307         # we obtain
308         
309         # get the registry records
310         user_list, users = [], {}
311         user_list = dbsession.query (RegRecord).filter(RegRecord.pointer.in_(user_ids))
312         # create a hrns keyed on the sfa record's pointer.
313         # Its possible for multiple records to have the same pointer so
314         # the dict's value will be a list of hrns.
315         users = defaultdict(list)
316         for user in user_list:
317             users[user.pointer].append(user)
318
319         # get the dummy records
320         dummy_user_list, dummy_users = [], {}
321         dummy_user_list = self.shell.GetUsers({'user_ids': user_ids})
322         dummy_users = list_to_dict(dummy_user_list, 'user_id')
323
324         # fill sfa info
325         for record in records:
326             # skip records with no pl info (top level authorities)
327             #if record['pointer'] == -1:
328             #    continue 
329             sfa_info = {}
330             type = record['type']
331             logger.info("fill_record_sfa_info - incoming record typed %s"%type)
332             if (type == "slice"):
333                 # all slice users are researchers
334                 record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
335                 record['PI'] = []
336                 record['researcher'] = []
337                 for user_id in record.get('user_ids', []):
338                     hrns = [user.hrn for user in users[user_id]]
339                     record['researcher'].extend(hrns)                
340
341             elif (type.startswith("authority")):
342                 record['url'] = None
343                 logger.info("fill_record_sfa_info - authority xherex")
344
345             elif (type == "node"):
346                 sfa_info['dns'] = record.get("hostname", "")
347                 # xxx TODO: URI, LatLong, IP, DNS
348     
349             elif (type == "user"):
350                 logger.info('setting user.email')
351                 sfa_info['email'] = record.get("email", "")
352                 sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
353                 sfa_info['geni_certificate'] = record['gid'] 
354                 # xxx TODO: PostalAddress, Phone
355             record.update(sfa_info)
356
357
358     ####################
359     def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids):
360         # hard-wire the code for slice/user for now, could be smarter if needed
361         if subject_type =='slice' and target_type == 'user' and relation_name == 'researcher':
362             subject=self.shell.GetSlices ({'slice_id': subject_id})[0]
363             if 'user_ids' not in subject.keys():
364                  subject['user_ids'] = []
365             current_target_ids = subject['user_ids']
366             add_target_ids = list ( set (target_ids).difference(current_target_ids))
367             del_target_ids = list ( set (current_target_ids).difference(target_ids))
368             logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id)))
369             for target_id in add_target_ids:
370                 self.shell.AddUserToSlice ({'user_id': target_id, 'slice_id': subject_id})
371                 logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id)))
372             for target_id in del_target_ids:
373                 logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id)))
374                 self.shell.DeleteUserFromSlice ({'user_id': target_id, 'slice_id': subject_id})
375         else:
376             logger.info('unexpected relation %s to maintain, %s -> %s'%(relation_name,subject_type,target_type))
377
378         
379     ########################################
380     ########## aggregate oriented
381     ########################################
382
383     def testbed_name (self): return "dummy"
384
385     # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory
386     def aggregate_version (self):
387         version_manager = VersionManager()
388         ad_rspec_versions = []
389         request_rspec_versions = []
390         for rspec_version in version_manager.versions:
391             if rspec_version.content_type in ['*', 'ad']:
392                 ad_rspec_versions.append(rspec_version.to_dict())
393             if rspec_version.content_type in ['*', 'request']:
394                 request_rspec_versions.append(rspec_version.to_dict()) 
395         return {
396             'testbed':self.testbed_name(),
397             'geni_request_rspec_versions': request_rspec_versions,
398             'geni_ad_rspec_versions': ad_rspec_versions,
399             }
400
401     def list_slices (self, creds, options):
402     
403         slices = self.shell.GetSlices()
404         slice_hrns = [slicename_to_hrn(self.hrn, slice['slice_name']) for slice in slices]
405         slice_urns = [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slice_hrns]
406     
407         return slice_urns
408         
409     # first 2 args are None in case of resource discovery
410     def list_resources (self, slice_urn, slice_hrn, creds, options):
411     
412         version_manager = VersionManager()
413         # get the rspec's return format from options
414         rspec_version = version_manager.get_version(options.get('geni_rspec_version'))
415         version_string = "rspec_%s" % (rspec_version)
416     
417         aggregate = DummyAggregate(self)
418         rspec =  aggregate.get_rspec(slice_xrn=slice_urn, version=rspec_version, 
419                                      options=options)
420     
421         return rspec
422     
423     def sliver_status (self, slice_urn, slice_hrn):
424         # find out where this slice is currently running
425         slice_name = hrn_to_dummy_slicename(slice_hrn)
426         
427         slices = self.shell.GetSlices({'slice_name': slice_name})
428         if len(slices) == 0:        
429             raise SliverDoesNotExist("%s (used %s as slicename internally)" % (slice_hrn, slicename))
430         slice = slices[0]
431         
432         # report about the local nodes only
433         nodes = self.shell.GetNodes({'node_ids':slice['node_ids']})
434
435         if len(nodes) == 0:
436             raise SliverDoesNotExist("You have not allocated any slivers here") 
437
438         # get login info
439         user = {}
440         keys = []
441         if slice['user_ids']:
442             users = self.shell.GetUsers({'user_ids': slice['user_ids']})
443             for user in users:
444                  keys.extend(user['keys'])
445
446             user.update({'urn': slice_urn,
447                          'login': slice['slice_name'],
448                          'protocol': ['ssh'],
449                          'port': ['22'],
450                          'keys': keys})
451
452     
453         result = {}
454         top_level_status = 'unknown'
455         if nodes:
456             top_level_status = 'ready'
457         result['geni_urn'] = slice_urn
458         result['dummy_login'] = slice['slice_name']
459         result['dummy_expires'] = datetime_to_string(utcparse(slice['expires']))
460         result['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
461         
462         resources = []
463         for node in nodes:
464             res = {}
465             res['dummy_hostname'] = node['hostname']
466             res['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
467             sliver_id = Xrn(slice_urn, type='slice', id=node['node_id']).urn
468             res['geni_urn'] = sliver_id
469             res['geni_status'] = 'ready'
470             res['geni_error'] = ''
471             res['users'] = [users]  
472     
473             resources.append(res)
474             
475         result['geni_status'] = top_level_status
476         result['geni_resources'] = resources
477         return result
478
479     def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options):
480
481         aggregate = DummyAggregate(self)
482         slices = DummySlices(self)
483         sfa_peer = slices.get_sfa_peer(slice_hrn)
484         slice_record=None    
485         if users:
486             slice_record = users[0].get('slice_record', {})
487     
488         # parse rspec
489         rspec = RSpec(rspec_string)
490         requested_attributes = rspec.version.get_slice_attributes()
491         
492         # ensure slice record exists
493         slice = slices.verify_slice(slice_hrn, slice_record, sfa_peer, options=options)
494         # ensure user records exists
495         #users = slices.verify_users(slice_hrn, slice, users, sfa_peer, options=options)
496         
497         # add/remove slice from nodes
498         requested_slivers = []
499         for node in rspec.version.get_nodes_with_slivers():
500             hostname = None
501             if node.get('component_name'):
502                 hostname = node.get('component_name').strip()
503             elif node.get('component_id'):
504                 hostname = xrn_to_hostname(node.get('component_id').strip())
505             if hostname:
506                 requested_slivers.append(hostname)
507         requested_slivers_ids = []
508         for hostname in requested_slivers:
509             node_id = self.shell.GetNodes({'hostname': hostname})[0]['node_id']
510             requested_slivers_ids.append(node_id) 
511         nodes = slices.verify_slice_nodes(slice, requested_slivers_ids) 
512     
513         return aggregate.get_rspec(slice_xrn=slice_urn, version=rspec.version)
514
515     def delete_sliver (self, slice_urn, slice_hrn, creds, options):
516         slicename = hrn_to_dummy_slicename(slice_hrn)
517         slices = self.shell.GetSlices({'slice_name': slicename})
518         if not slices:
519             return True
520         slice = slices[0]
521         
522         try:
523             self.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'], 'node_ids': slice['node_ids']})
524             return True
525         except:
526             return False
527     
528     def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options):
529         slicename = hrn_to_dummy_slicename(slice_hrn)
530         slices = self.shell.GetSlices({'slice_name': slicename})
531         if not slices:
532             raise RecordNotFound(slice_hrn)
533         slice = slices[0]
534         requested_time = utcparse(expiration_time)
535         record = {'expires': int(datetime_to_epoch(requested_time))}
536         try:
537             self.shell.UpdateSlice({'slice_id': slice['slice_id'], 'fields':record})
538             return True
539         except:
540             return False
541
542     # set the 'enabled' tag to True
543     def start_slice (self, slice_urn, slice_hrn, creds):
544         slicename = hrn_to_dummy_slicename(slice_hrn)
545         slices = self.shell.GetSlices({'slice_name': slicename})
546         if not slices:
547             raise RecordNotFound(slice_hrn)
548         slice_id = slices[0]['slice_id']
549         slice_enabled = slices[0]['enabled'] 
550         # just update the slice enabled tag
551         if not slice_enabled:
552             self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': True}})
553         return 1
554
555     # set the 'enabled' tag to False
556     def stop_slice (self, slice_urn, slice_hrn, creds):
557         slicename = hrn_to_pl_slicename(slice_hrn)
558         slices = self.shell.GetSlices({'slice_name': slicename})
559         if not slices:
560             raise RecordNotFound(slice_hrn)
561         slice_id = slices[0]['slice_id']
562         slice_enabled = slices[0]['enabled']
563         # just update the slice enabled tag
564         if slice_enabled:
565             self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': False}})
566         return 1
567     
568     def reset_slice (self, slice_urn, slice_hrn, creds):
569         raise SfaNotImplemented ("reset_slice not available at this interface")
570     
571     def get_ticket (self, slice_urn, slice_hrn, creds, rspec_string, options):
572         raise SfaNotImplemented,"DummyDriver.get_ticket needs a rewrite"