227ceba2b36be9b200ac13ed1af9add3f7d5e44d
[sfa.git] / sfa / dummy / dummydriver.py
1 import time
2 import datetime
3 #
4 from sfa.util.faults import MissingSfaInfo, UnknownSfaType, \
5     RecordNotFound, SfaNotImplemented, SliverDoesNotExist
6
7 from sfa.util.sfalogging import logger
8 from sfa.util.defaultdict import defaultdict
9 from sfa.util.sfatime import utcparse, datetime_to_string, datetime_to_epoch
10 from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf
11 from sfa.util.cache import Cache
12
13 # one would think the driver should not need to mess with the SFA db, but..
14 from sfa.storage.alchemy import dbsession
15 from sfa.storage.model import RegRecord
16
17 # used to be used in get_ticket
18 #from sfa.trust.sfaticket import SfaTicket
19
20 from sfa.rspecs.version_manager import VersionManager
21 from sfa.rspecs.rspec import RSpec
22
23 # the driver interface, mostly provides default behaviours
24 from sfa.managers.driver import Driver
25
26 from sfa.dummy.dummyshell import DummyShell
27 from sfa.dummy.dummyaggregate import DummyAggregate
28 from sfa.dummy.dummyslices import DummySlices
29 from sfa.dummy.dummyxrn import DummyXrn, slicename_to_hrn, hostname_to_hrn, hrn_to_dummy_slicename, xrn_to_hostname
30
31
32 def list_to_dict(recs, key):
33     """
34     convert a list of dictionaries into a dictionary keyed on the 
35     specified dictionary key 
36     """
37     return dict ( [ (rec[key],rec) for rec in recs ] )
38
39 #
40 # DummyShell is just an xmlrpc serverproxy where methods can be sent as-is; 
41
42 class DummyDriver (Driver):
43
44     # the cache instance is a class member so it survives across incoming requests
45     cache = None
46
47     def __init__ (self, config):
48         Driver.__init__ (self, config)
49         self.config = config
50         self.hrn = config.SFA_INTERFACE_HRN
51         self.root_auth = config.SFA_REGISTRY_ROOT_AUTH
52         self.shell = DummyShell (config)
53         self.testbedInfo = self.shell.GetTestbedInfo()
54  
55     ########################################
56     ########## registry oriented
57     ########################################
58
59     def augment_records_with_testbed_info (self, sfa_records):
60         return self.fill_record_info (sfa_records)
61
62     ########## 
63     def register (self, sfa_record, hrn, pub_key):
64         type = sfa_record['type']
65         dummy_record = self.sfa_fields_to_dummy_fields(type, hrn, sfa_record)
66         
67         if type == 'authority':
68             pointer = -1
69
70         elif type == 'slice':
71             slices = self.shell.GetSlices({'slice_name': dummy_record['slice_name']})
72             if not slices:
73                  pointer = self.shell.AddSlice(dummy_record)
74             else:
75                  pointer = slices[0]['slice_id']
76
77         elif type == 'user':
78             users = self.shell.GetUsers({'email':sfa_record['email']})
79             if not users:
80                 pointer = self.shell.AddUser(dummy_record)
81             else:
82                 pointer = users[0]['user_id']
83     
84             # Add the user's key
85             if pub_key:
86                 self.shell.AddUserKey({'user_id' : pointer, 'key' : pub_key})
87
88         elif type == 'node':
89             nodes = self.shell.GetNodes(dummy_record['hostname'])
90             if not nodes:
91                 pointer = self.shell.AddNode(dummy_record)
92             else:
93                 pointer = users[0]['node_id']
94     
95         return pointer
96         
97     ##########
98     def update (self, old_sfa_record, new_sfa_record, hrn, new_key):
99         pointer = old_sfa_record['pointer']
100         type = old_sfa_record['type']
101         dummy_record=self.sfa_fields_to_dummy_fields(type, hrn, new_sfa_record)
102
103         # new_key implemented for users only
104         if new_key and type not in [ 'user' ]:
105             raise UnknownSfaType(type)
106
107     
108         if type == "slice":
109             self.shell.UpdateSlice({'slice_id': pointer, 'fields': dummy_record})
110     
111         elif type == "user":
112             self.shell.UpdateUser({'user_id': pointer, 'fields': dummy_record})
113
114             if new_key:
115                 self.shell.AddUserKey({'user_id' : pointer, 'key' : new_key})
116
117         elif type == "node":
118             self.shell.UpdateNode({'node_id': pointer, 'fields': dummy_record})
119
120
121         return True
122         
123
124     ##########
125     def remove (self, sfa_record):
126         type=sfa_record['type']
127         pointer=sfa_record['pointer']
128         if type == 'user':
129             self.shell.DeleteUser{'user_id': pointer})
130         elif type == 'slice':
131             self.shell.DeleteSlice('slice_id': pointer)
132         elif type == 'node':
133             self.shell.DeleteNode('node_id': pointer)
134
135         return True
136
137
138
139
140
141     ##
142     # Convert SFA fields to Dummy testbed fields for use when registering or updating
143     # registry record in the dummy testbed
144     #
145
146     def sfa_fields_to_dummy_fields(self, type, hrn, sfa_record):
147
148         dummy_record = {}
149  
150         if type == "slice":
151             dummy_record["slice_name"] = hrn_to_dummy_slicename(hrn)
152         
153         elif type == "node":
154             if "hostname" not in sfa_record:
155                 raise MissingSfaInfo("hostname")
156             dummy_record["hostname"] = sfa_record["hostname"]
157             if "type" in sfa_record:
158                dummy_record["type"] = sfa_record["type"]
159             else:
160                dummy_record["type"] = "dummy_type"
161  
162         elif type == "authority":
163             dummy_record["name"] = hrn
164
165         elif type == "user":
166             dummy_record["user_name"] = sfa_record["email"].split('@')[0]
167             dummy_record["email"] = sfa_record["email"]
168
169         return dummy_record
170
171     ####################
172     def fill_record_info(self, records):
173         """
174         Given a (list of) SFA record, fill in the DUMMY TESTBED specific 
175         and SFA specific fields in the record. 
176         """
177         if not isinstance(records, list):
178             records = [records]
179
180         self.fill_record_dummy_info(records)
181         self.fill_record_hrns(records)
182         self.fill_record_sfa_info(records)
183         return records
184
185     def fill_record_dummy_info(self, records):
186         """
187         Fill in the DUMMY specific fields of a SFA record. This
188         involves calling the appropriate DUMMY method to retrieve the 
189         database record for the object.
190             
191         @param record: record to fill in field (in/out param)     
192         """
193         # get ids by type
194         node_ids, slice_ids, user_ids = [], [], [] 
195         type_map = {'node': node_ids, 'slice': slice_ids, 'user': user_ids}
196                   
197         for record in records:
198             for type in type_map:
199                 if type == record['type']:
200                     type_map[type].append(record['pointer'])
201
202         # get dummy records
203         nodes, slices, users = {}, {}, {}
204         if node_ids:
205             node_list = self.shell.GetNodes({'node_ids':node_ids})
206             nodes = list_to_dict(node_list, 'node_id')
207         if slice_ids:
208             slice_list = self.shell.GetSlices({'slice_ids':slice_ids})
209             slices = list_to_dict(slice_list, 'slice_id')
210         if user_ids:
211             user_list = self.shell.GetUsers({'user_ids': user_ids})
212             users = list_to_dict(user_list, 'user_id')
213
214         dummy_records = {'node': nodes, 'slice': slices, 'user': users}
215
216
217         # fill record info
218         for record in records:
219             # records with pointer==-1 do not have dummy info.
220             if record['pointer'] == -1:
221                 continue
222            
223             for type in dummy_records:
224                 if record['type'] == type:
225                     if record['pointer'] in dummy_records[type]:
226                         record.update(dummy_records[type][record['pointer']])
227                         break
228             # fill in key info
229             if record['type'] == 'user':
230                 record['key_ids'] = []
231                 recors['keys'] = []
232                 for key in dummy_records['user'][record['pointer']]['keys']:
233                      record['key_ids'].append(-1)
234                      recors['keys'].append(key)
235
236         return records
237
238     def fill_record_hrns(self, records):
239         """
240         convert dummy ids to hrns
241         """
242
243         # get ids
244         slice_ids, user_ids, node_ids = [], [], []
245         for record in records:
246             if 'user_ids' in record:
247                 user_ids.extend(record['user_ids'])
248             if 'slice_ids' in record:
249                 slice_ids.extend(record['slice_ids'])
250             if 'node_ids' in record:
251                 node_ids.extend(record['node_ids'])
252
253         # get dummy records
254         slices, users, nodes = {}, {}, {}
255         if user_ids:
256             user_list = self.shell.GetUsers({'user_ids': user_ids})
257             users = list_to_dict(user_list, 'user_id')
258         if slice_ids:
259             slice_list = self.shell.GetSlices({'slice_ids': slice_ids})
260             slices = list_to_dict(slice_list, 'slice_id')       
261         if node_ids:
262             node_list = self.shell.GetNodes({'node_ids': node_ids})
263             nodes = list_to_dict(node_list, 'node_id')
264        
265         # convert ids to hrns
266         for record in records:
267             # get all relevant data
268             type = record['type']
269             pointer = record['pointer']
270             testbed_name = self.testbed_name()
271             auth_hrn = self.hrn
272             if pointer == -1:
273                 continue
274
275             if 'user_ids' in record:
276                 emails = [users[user_id]['email'] for user_id in record['user_ids'] \
277                           if user_id in  users]
278                 usernames = [email.split('@')[0] for email in emails]
279                 user_hrns = [".".join([auth_hrn, testbed_name, username]) for username in usernames]
280                 record['users'] = user_hrns 
281             if 'slice_ids' in record:
282                 slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] \
283                               if slice_id in slices]
284                 slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
285                 record['slices'] = slice_hrns
286             if 'node_ids' in record:
287                 hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \
288                              if node_id in nodes]
289                 node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
290                 record['nodes'] = node_hrns
291
292             
293         return records   
294
295     def fill_record_sfa_info(self, records):
296
297         def startswith(prefix, values):
298             return [value for value in values if value.startswith(prefix)]
299
300         # get user ids
301         user_ids = []
302         for record in records:
303             user_ids.extend(record.get("user_ids", []))
304         
305         # get sfa records for all records associated with these records.   
306         # we'll replace pl ids (person_ids) with hrns from the sfa records
307         # we obtain
308         
309         # get the registry records
310         user_list, users = [], {}
311         user_list = dbsession.query (RegRecord).filter(RegRecord.pointer.in_(user_ids))
312         # create a hrns keyed on the sfa record's pointer.
313         # Its possible for multiple records to have the same pointer so
314         # the dict's value will be a list of hrns.
315         users = defaultdict(list)
316         for user in user_list:
317             users[user.pointer].append(user)
318
319         # get the dummy records
320         dummy_user_list, dummy_users = [], {}
321         dummy_user_list = self.shell.GetUsers({'user_ids': user_ids})
322         dummy_users = list_to_dict(dummy_user_list, 'user_id')
323
324         # fill sfa info
325         for record in records:
326             # skip records with no pl info (top level authorities)
327             #if record['pointer'] == -1:
328             #    continue 
329             sfa_info = {}
330             type = record['type']
331             logger.info("fill_record_sfa_info - incoming record typed %s"%type)
332             if (type == "slice"):
333                 # all slice users are researchers
334                 record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
335                 record['PI'] = []
336                 record['researcher'] = []
337                 for user_id in record.get('user_ids', []):
338                     hrns = [user.hrn for user in users[user_id]]
339                     record['researcher'].extend(hrns)                
340
341             elif (type.startswith("authority")):
342                 record['url'] = None
343                 logger.info("fill_record_sfa_info - authority xherex")
344
345             elif (type == "node"):
346                 sfa_info['dns'] = record.get("hostname", "")
347                 # xxx TODO: URI, LatLong, IP, DNS
348     
349             elif (type == "user"):
350                 logger.info('setting user.email')
351                 sfa_info['email'] = record.get("email", "")
352                 sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
353                 sfa_info['geni_certificate'] = record['gid'] 
354                 # xxx TODO: PostalAddress, Phone
355             record.update(sfa_info)
356
357
358     ####################
359     def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids):
360         # hard-wire the code for slice/user for now, could be smarter if needed
361         if subject_type =='slice' and target_type == 'user' and relation_name == 'researcher':
362             subject=self.shell.GetSlices (subject_id)[0]
363             current_target_ids = subject['user_ids']
364             add_target_ids = list ( set (target_ids).difference(current_target_ids))
365             del_target_ids = list ( set (current_target_ids).difference(target_ids))
366             logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id)))
367             for target_id in add_target_ids:
368                 self.shell.AddUserToSlice ({'user_id': target_id, 'slice_id': subject_id})
369                 logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id)))
370             for target_id in del_target_ids:
371                 logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id)))
372                 self.shell.DeleteUserFromSlice ({'user_id': target_id, 'slice_id': subject_id})
373         else:
374             logger.info('unexpected relation %s to maintain, %s -> %s'%(relation_name,subject_type,target_type))
375
376         
377     ########################################
378     ########## aggregate oriented
379     ########################################
380
381     def testbed_name (self): return "dummy"
382
383     # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory
384     def aggregate_version (self):
385         version_manager = VersionManager()
386         ad_rspec_versions = []
387         request_rspec_versions = []
388         for rspec_version in version_manager.versions:
389             if rspec_version.content_type in ['*', 'ad']:
390                 ad_rspec_versions.append(rspec_version.to_dict())
391             if rspec_version.content_type in ['*', 'request']:
392                 request_rspec_versions.append(rspec_version.to_dict()) 
393         return {
394             'testbed':self.testbed_name(),
395             'geni_request_rspec_versions': request_rspec_versions,
396             'geni_ad_rspec_versions': ad_rspec_versions,
397             }
398
399     def list_slices (self, creds, options):
400     
401         slices = self.shell.GetSlices()
402         slice_hrns = [slicename_to_hrn(self.hrn, slice['slice_name']) for slice in slices]
403         slice_urns = [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slice_hrns]
404     
405         return slice_urns
406         
407     # first 2 args are None in case of resource discovery
408     def list_resources (self, slice_urn, slice_hrn, creds, options):
409     
410         version_manager = VersionManager()
411         # get the rspec's return format from options
412         rspec_version = version_manager.get_version(options.get('geni_rspec_version'))
413         version_string = "rspec_%s" % (rspec_version)
414     
415         aggregate = DummyAggregate(self)
416         rspec =  aggregate.get_rspec(slice_xrn=slice_urn, version=rspec_version, 
417                                      options=options)
418     
419         return rspec
420     
421     def sliver_status (self, slice_urn, slice_hrn):
422         # find out where this slice is currently running
423         slice_name = hrn_to_dummy_slicename(slice_hrn)
424         
425         slice = self.shell.GetSlices({'slice_name': slice_name})
426         if len(slices) == 0:        
427             raise SliverDoesNotExist("%s (used %s as slicename internally)" % (slice_hrn, slicename))
428         
429         # report about the local nodes only
430         nodes = self.shell.GetNodes({'node_ids':slice['node_ids']})
431
432         if len(nodes) == 0:
433             raise SliverDoesNotExist("You have not allocated any slivers here") 
434
435         # get login info
436         user = {}
437         keys = []
438         if slice['user_ids']:
439             users = self.shell.GetUsers({'user_ids': slice['user_ids']})
440             for user in users:
441                  keys.extend(user['keys'])
442
443             user.update({'urn': slice_urn,
444                          'login': slice['slice_name'],
445                          'protocol': ['ssh'],
446                          'port': ['22'],
447                          'keys': keys})
448
449     
450         result = {}
451         top_level_status = 'unknown'
452         if nodes:
453             top_level_status = 'ready'
454         result['geni_urn'] = slice_urn
455         result['dummy_login'] = slice['slice_name']
456         result['dummy_expires'] = datetime_to_string(utcparse(slice['expires']))
457         result['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
458         
459         resources = []
460         for node in nodes:
461             res = {}
462             res['dummy_hostname'] = node['hostname']
463             res['geni_expires'] = datetime_to_string(utcparse(slice['expires']))
464             sliver_id = Xrn(slice_urn, type='slice', id=node['node_id'], authority=self.hrn).urn
465             res['geni_urn'] = sliver_id
466             res['geni_status'] = 'ready'
467             res['geni_error'] = ''
468             res['users'] = [users]  
469     
470             resources.append(res)
471             
472         result['geni_status'] = top_level_status
473         result['geni_resources'] = resources
474         return result
475
476     def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options):
477
478         aggregate = DummyAggregate(self)
479         slices = DummySlices(self)
480         sfa_peer = slices.get_sfa_peer(slice_hrn)
481         slice_record=None    
482         if users:
483             slice_record = users[0].get('slice_record', {})
484     
485         # parse rspec
486         rspec = RSpec(rspec_string)
487         requested_attributes = rspec.version.get_slice_attributes()
488         
489         # ensure slice record exists
490         slice = slices.verify_slice(slice_hrn, slice_record, peer, sfa_peer, options=options)
491         # ensure user records exists
492         users = slices.verify_users(slice_hrn, slice, users, peer, sfa_peer, options=options)
493         
494         # add/remove slice from nodes
495         requested_slivers = []
496         for node in rspec.version.get_nodes_with_slivers():
497             hostname = None
498             if node.get('component_name'):
499                 hostname = node.get('component_name').strip()
500             elif node.get('component_id'):
501                 hostname = xrn_to_hostname(node.get('component_id').strip())
502             if hostname:
503                 requested_slivers.append(hostname)
504         nodes = slices.verify_slice_nodes(slice, requested_slivers, peer) 
505     
506         return aggregate.get_rspec(slice_xrn=slice_urn, version=rspec.version)
507
508     def delete_sliver (self, slice_urn, slice_hrn, creds, options):
509         slicename = hrn_to_dummy_slicename(slice_hrn)
510         slices = self.shell.GetSlices({'slice_name': slicename})
511         if not slices:
512             return True
513         slice = slices[0]
514         
515         try:
516             self.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'], 'node_ids': slice['node_ids'])
517             return True
518         except:
519             return False
520     
521     def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options):
522         slicename = hrn_to_dummy_slicename(slice_hrn)
523         slices = self.shell.GetSlices({'slice_name': slicename})
524         if not slices:
525             raise RecordNotFound(slice_hrn)
526         slice = slices[0]
527         requested_time = utcparse(expiration_time)
528         record = {'expires': int(datetime_to_epoch(requested_time))}
529         try:
530             self.shell.UpdateSlice({'slice_id': slice['slice_id'], 'fields':record})
531             return True
532         except:
533             return False
534
535     # set the 'enabled' tag to True
536     def start_slice (self, slice_urn, slice_hrn, creds):
537         slicename = hrn_to_dummy_slicename(slice_hrn)
538         slices = self.shell.GetSlices({'slice_name': slicename})
539         if not slices:
540             raise RecordNotFound(slice_hrn)
541         slice_id = slices[0]['slice_id']
542         slice_enabled = slices[0]['enabled'] 
543         # just update the slice enabled tag
544         if not slice_enabled:
545             self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': True}})
546         return 1
547
548     # set the 'enabled' tag to False
549     def stop_slice (self, slice_urn, slice_hrn, creds):
550         slicename = hrn_to_pl_slicename(slice_hrn)
551         slices = self.shell.GetSlices({'slice_name': slicename})
552         if not slices:
553             raise RecordNotFound(slice_hrn)
554         slice_id = slices[0]['slice_id']
555         slice_enabled = slices[0]['enabled']
556         # just update the slice enabled tag
557         if slice_enabled:
558             self.shell.UpdateSlice({'slice_id': slice_id, 'fields': {'enabled': False}})
559         return 1
560     
561     def reset_slice (self, slice_urn, slice_hrn, creds):
562         raise SfaNotImplemented ("reset_slice not available at this interface")
563     
564     def get_ticket (self, slice_urn, slice_hrn, creds, rspec_string, options):
565         raise SfaNotImplemented,"DummyDriver.get_ticket needs a rewrite"