d892e25e2d2dba983a61a51f75e04a6127dcbb46
[sfa.git] / sfa / plc / pldriver.py
1 import datetime
2 #
3 from sfa.util.faults import MissingSfaInfo, UnknownSfaType
4 from sfa.util.sfalogging import logger
5 from sfa.util.defaultdict import defaultdict
6 from sfa.util.xrn import hrn_to_urn, get_leaf
7 from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base
8
9 # one would think the driver should not need to mess with the SFA db, but..
10 from sfa.storage.table import SfaTable
11
12 from sfa.rspecs.version_manager import VersionManager
13 from sfa.rspecs.rspec import RSpec
14
15 # the driver interface, mostly provides default behaviours
16 from sfa.managers.driver import Driver
17
18 from sfa.plc.plshell import PlShell
19
20 import sfa.plc.peers as peers
21 from sfa.plc.plaggregate import PlAggregate
22 from sfa.plc.plslices import PlSlices
23
def list_to_dict(recs, key):
    """
    Index a list of dictionaries on one of their fields.

    @param recs: iterable of dictionaries, each holding the field `key`
    @param key: name of the field whose value becomes the dictionary key
    @return: dict mapping rec[key] -> rec; when several records share the
        same value, the last one wins
    """
    index = {}
    for rec in recs:
        index[rec[key]] = rec
    return index
30
31 #
32 # PlShell is just an xmlrpc serverproxy where methods
33 # can be sent as-is; it takes care of authentication
34 # from the global config
35
36 # so we inherit PlShell just so one can do driver.GetNodes
37 # which would not make much sense in the context of other testbeds
38 # so ultimately PlDriver should drop the PlShell inheritance
39 # and would have a driver.shell reference to a PlShell instead
40
41 class PlDriver (Driver, PlShell):
42
43     def __init__ (self, config):
44         PlShell.__init__ (self, config)
45         Driver.__init__ (self, config)
46  
47     ########################################
48     ########## registry oriented
49     ########################################
50
51     ########## disabled users 
52     def is_enabled (self, record):
53         # the incoming record was augmented already, so 'enabled' should be set
54         if record['type'] == 'user':
55             return record['enabled']
56         # only users can be disabled
57         return True
58
59     def augment_records_with_testbed_info (self, sfa_records):
60         return self.fill_record_info (sfa_records)
61
62     ########## 
63     def register (self, sfa_record, hrn, pub_key):
64         type = sfa_record['type']
65         pl_record = self.sfa_fields_to_pl_fields(type, hrn, sfa_record)
66
67         if type == 'authority':
68             sites = self.GetSites([pl_record['login_base']])
69             if not sites:
70                 pointer = self.AddSite(pl_record)
71             else:
72                 pointer = sites[0]['site_id']
73
74         elif type == 'slice':
75             acceptable_fields=['url', 'instantiation', 'name', 'description']
76             for key in pl_record.keys():
77                 if key not in acceptable_fields:
78                     pl_record.pop(key)
79             slices = self.GetSlices([pl_record['name']])
80             if not slices:
81                  pointer = self.AddSlice(pl_record)
82             else:
83                  pointer = slices[0]['slice_id']
84
85         elif type == 'user':
86             persons = self.GetPersons([sfa_record['email']])
87             if not persons:
88                 pointer = self.AddPerson(dict(sfa_record))
89             else:
90                 pointer = persons[0]['person_id']
91     
92             if 'enabled' in sfa_record and sfa_record['enabled']:
93                 self.UpdatePerson(pointer, {'enabled': sfa_record['enabled']})
94             # add this person to the site only if she is being added for the first
95             # time by sfa and doesont already exist in plc
96             if not persons or not persons[0]['site_ids']:
97                 login_base = get_leaf(sfa_record['authority'])
98                 self.AddPersonToSite(pointer, login_base)
99     
100             # What roles should this user have?
101             self.AddRoleToPerson('user', pointer)
102             # Add the user's key
103             if pub_key:
104                 self.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key})
105
106         elif type == 'node':
107             login_base = hrn_to_pl_login_base(sfa_record['authority'])
108             nodes = api.driver.GetNodes([pl_record['hostname']])
109             if not nodes:
110                 pointer = api.driver.AddNode(login_base, pl_record)
111             else:
112                 pointer = nodes[0]['node_id']
113     
114         return pointer
115         
116     ##########
117     # xxx actually old_sfa_record comes filled with plc stuff as well in the original code
118     def update (self, old_sfa_record, new_sfa_record, hrn, new_key):
119         pointer = old_sfa_record['pointer']
120         type = old_sfa_record['type']
121
122         # new_key implemented for users only
123         if new_key and type not in [ 'user' ]:
124             raise UnknownSfaType(type)
125
126         if (type == "authority"):
127             self.UpdateSite(pointer, new_sfa_record)
128     
129         elif type == "slice":
130             pl_record=self.sfa_fields_to_pl_fields(type, hrn, new_sfa_record)
131             if 'name' in pl_record:
132                 pl_record.pop('name')
133                 self.UpdateSlice(pointer, pl_record)
134     
135         elif type == "user":
136             # SMBAKER: UpdatePerson only allows a limited set of fields to be
137             #    updated. Ideally we should have a more generic way of doing
138             #    this. I copied the field names from UpdatePerson.py...
139             update_fields = {}
140             all_fields = new_sfa_record
141             for key in all_fields.keys():
142                 if key in ['first_name', 'last_name', 'title', 'email',
143                            'password', 'phone', 'url', 'bio', 'accepted_aup',
144                            'enabled']:
145                     update_fields[key] = all_fields[key]
146             self.UpdatePerson(pointer, update_fields)
147     
148             if new_key:
149                 # must check this key against the previous one if it exists
150                 persons = self.GetPersons([pointer], ['key_ids'])
151                 person = persons[0]
152                 keys = person['key_ids']
153                 keys = self.GetKeys(person['key_ids'])
154                 
155                 # Delete all stale keys
156                 key_exists = False
157                 for key in keys:
158                     if new_key != key['key']:
159                         self.DeleteKey(key['key_id'])
160                     else:
161                         key_exists = True
162                 if not key_exists:
163                     self.AddPersonKey(pointer, {'key_type': 'ssh', 'key': new_key})
164     
165         elif type == "node":
166             self.UpdateNode(pointer, new_sfa_record)
167
168         return True
169         
170
171     ##########
172     def remove (self, sfa_record):
173         type=sfa_record['type']
174         pointer=sfa_record['pointer']
175         if type == 'user':
176             persons = self.GetPersons(pointer)
177             # only delete this person if he has site ids. if he doesnt, it probably means
178             # he was just removed from a site, not actually deleted
179             if persons and persons[0]['site_ids']:
180                 self.DeletePerson(pointer)
181         elif type == 'slice':
182             if self.GetSlices(pointer):
183                 self.DeleteSlice(pointer)
184         elif type == 'node':
185             if self.GetNodes(pointer):
186                 self.DeleteNode(pointer)
187         elif type == 'authority':
188             if self.GetSites(pointer):
189                 self.DeleteSite(pointer)
190
191         return True
192
193
194
195
196
197     ##
198     # Convert SFA fields to PLC fields for use when registering up updating
199     # registry record in the PLC database
200     #
201
202     def sfa_fields_to_pl_fields(self, type, hrn, sfa_record):
203
204         pl_record = {}
205  
206         if type == "slice":
207             pl_record["name"] = hrn_to_pl_slicename(hrn)
208             if "instantiation" in sfa_record:
209                 pl_record['instantiation']=sfa_record['instantiation']
210             else:
211                 pl_record["instantiation"] = "plc-instantiated"
212             if "url" in sfa_record:
213                pl_record["url"] = sfa_record["url"]
214             if "description" in sfa_record:
215                 pl_record["description"] = sfa_record["description"]
216             if "expires" in sfa_record:
217                 pl_record["expires"] = int(sfa_record["expires"])
218
219         elif type == "node":
220             if not "hostname" in pl_record:
221                 # fetch from sfa_record
222                 if "hostname" not in sfa_record:
223                     raise MissingSfaInfo("hostname")
224                 pl_record["hostname"] = sfa_record["hostname"]
225             if "model" in sfa_record: 
226                 pl_record["model"] = sfa_record["model"]
227             else:
228                 pl_record["model"] = "geni"
229
230         elif type == "authority":
231             pl_record["login_base"] = hrn_to_pl_login_base(hrn)
232             if "name" not in sfa_record:
233                 pl_record["name"] = hrn
234             if "abbreviated_name" not in sfa_record:
235                 pl_record["abbreviated_name"] = hrn
236             if "enabled" not in sfa_record:
237                 pl_record["enabled"] = True
238             if "is_public" not in sfa_record:
239                 pl_record["is_public"] = True
240
241         return pl_record
242
243     ####################
244     def fill_record_info(self, records):
245         """
246         Given a (list of) SFA record, fill in the PLC specific 
247         and SFA specific fields in the record. 
248         """
249         if not isinstance(records, list):
250             records = [records]
251
252         self.fill_record_pl_info(records)
253         self.fill_record_hrns(records)
254         self.fill_record_sfa_info(records)
255         return records
256
257     def fill_record_pl_info(self, records):
258         """
259         Fill in the planetlab specific fields of a SFA record. This
260         involves calling the appropriate PLC method to retrieve the 
261         database record for the object.
262             
263         @param record: record to fill in field (in/out param)     
264         """
265         # get ids by type
266         node_ids, site_ids, slice_ids = [], [], [] 
267         person_ids, key_ids = [], []
268         type_map = {'node': node_ids, 'authority': site_ids,
269                     'slice': slice_ids, 'user': person_ids}
270                   
271         for record in records:
272             for type in type_map:
273                 if type == record['type']:
274                     type_map[type].append(record['pointer'])
275
276         # get pl records
277         nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
278         if node_ids:
279             node_list = self.GetNodes(node_ids)
280             nodes = list_to_dict(node_list, 'node_id')
281         if site_ids:
282             site_list = self.GetSites(site_ids)
283             sites = list_to_dict(site_list, 'site_id')
284         if slice_ids:
285             slice_list = self.GetSlices(slice_ids)
286             slices = list_to_dict(slice_list, 'slice_id')
287         if person_ids:
288             person_list = self.GetPersons(person_ids)
289             persons = list_to_dict(person_list, 'person_id')
290             for person in persons:
291                 key_ids.extend(persons[person]['key_ids'])
292
293         pl_records = {'node': nodes, 'authority': sites,
294                       'slice': slices, 'user': persons}
295
296         if key_ids:
297             key_list = self.GetKeys(key_ids)
298             keys = list_to_dict(key_list, 'key_id')
299
300         # fill record info
301         for record in records:
302             # records with pointer==-1 do not have plc info.
303             # for example, the top level authority records which are
304             # authorities, but not PL "sites"
305             if record['pointer'] == -1:
306                 continue
307            
308             for type in pl_records:
309                 if record['type'] == type:
310                     if record['pointer'] in pl_records[type]:
311                         record.update(pl_records[type][record['pointer']])
312                         break
313             # fill in key info
314             if record['type'] == 'user':
315                 if 'key_ids' not in record:
316                     logger.info("user record has no 'key_ids' - need to import from myplc ?")
317                 else:
318                     pubkeys = [keys[key_id]['key'] for key_id in record['key_ids'] if key_id in keys] 
319                     record['keys'] = pubkeys
320
321         return records
322
323     def fill_record_hrns(self, records):
324         """
325         convert pl ids to hrns
326         """
327
328         # get ids
329         slice_ids, person_ids, site_ids, node_ids = [], [], [], []
330         for record in records:
331             if 'site_id' in record:
332                 site_ids.append(record['site_id'])
333             if 'site_ids' in record:
334                 site_ids.extend(record['site_ids'])
335             if 'person_ids' in record:
336                 person_ids.extend(record['person_ids'])
337             if 'slice_ids' in record:
338                 slice_ids.extend(record['slice_ids'])
339             if 'node_ids' in record:
340                 node_ids.extend(record['node_ids'])
341
342         # get pl records
343         slices, persons, sites, nodes = {}, {}, {}, {}
344         if site_ids:
345             site_list = self.GetSites(site_ids, ['site_id', 'login_base'])
346             sites = list_to_dict(site_list, 'site_id')
347         if person_ids:
348             person_list = self.GetPersons(person_ids, ['person_id', 'email'])
349             persons = list_to_dict(person_list, 'person_id')
350         if slice_ids:
351             slice_list = self.GetSlices(slice_ids, ['slice_id', 'name'])
352             slices = list_to_dict(slice_list, 'slice_id')       
353         if node_ids:
354             node_list = self.GetNodes(node_ids, ['node_id', 'hostname'])
355             nodes = list_to_dict(node_list, 'node_id')
356        
357         # convert ids to hrns
358         for record in records:
359             # get all relevant data
360             type = record['type']
361             pointer = record['pointer']
362             auth_hrn = self.hrn
363             login_base = ''
364             if pointer == -1:
365                 continue
366
367             if 'site_id' in record:
368                 site = sites[record['site_id']]
369                 login_base = site['login_base']
370                 record['site'] = ".".join([auth_hrn, login_base])
371             if 'person_ids' in record:
372                 emails = [persons[person_id]['email'] for person_id in record['person_ids'] \
373                           if person_id in  persons]
374                 usernames = [email.split('@')[0] for email in emails]
375                 person_hrns = [".".join([auth_hrn, login_base, username]) for username in usernames]
376                 record['persons'] = person_hrns 
377             if 'slice_ids' in record:
378                 slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids'] \
379                               if slice_id in slices]
380                 slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
381                 record['slices'] = slice_hrns
382             if 'node_ids' in record:
383                 hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \
384                              if node_id in nodes]
385                 node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
386                 record['nodes'] = node_hrns
387             if 'site_ids' in record:
388                 login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids'] \
389                                if site_id in sites]
390                 site_hrns = [".".join([auth_hrn, lbase]) for lbase in login_bases]
391                 record['sites'] = site_hrns
392             
393         return records   
394
395     # aggregates is basically api.aggregates
396     def fill_record_sfa_info(self, records):
397
398         def startswith(prefix, values):
399             return [value for value in values if value.startswith(prefix)]
400
401         # get person ids
402         person_ids = []
403         site_ids = []
404         for record in records:
405             person_ids.extend(record.get("person_ids", []))
406             site_ids.extend(record.get("site_ids", [])) 
407             if 'site_id' in record:
408                 site_ids.append(record['site_id']) 
409         
410         # get all pis from the sites we've encountered
411         # and store them in a dictionary keyed on site_id 
412         site_pis = {}
413         if site_ids:
414             pi_filter = {'|roles': ['pi'], '|site_ids': site_ids} 
415             pi_list = self.GetPersons(pi_filter, ['person_id', 'site_ids'])
416             for pi in pi_list:
417                 # we will need the pi's hrns also
418                 person_ids.append(pi['person_id'])
419                 
420                 # we also need to keep track of the sites these pis
421                 # belong to
422                 for site_id in pi['site_ids']:
423                     if site_id in site_pis:
424                         site_pis[site_id].append(pi)
425                     else:
426                         site_pis[site_id] = [pi]
427                  
428         # get sfa records for all records associated with these records.   
429         # we'll replace pl ids (person_ids) with hrns from the sfa records
430         # we obtain
431         
432         # get the sfa records
433         table = SfaTable()
434         person_list, persons = [], {}
435         person_list = table.find({'type': 'user', 'pointer': person_ids})
436         # create a hrns keyed on the sfa record's pointer.
437         # Its possible for multiple records to have the same pointer so
438         # the dict's value will be a list of hrns.
439         persons = defaultdict(list)
440         for person in person_list:
441             persons[person['pointer']].append(person)
442
443         # get the pl records
444         pl_person_list, pl_persons = [], {}
445         pl_person_list = self.GetPersons(person_ids, ['person_id', 'roles'])
446         pl_persons = list_to_dict(pl_person_list, 'person_id')
447
448         # fill sfa info
449         for record in records:
450             # skip records with no pl info (top level authorities)
451             #if record['pointer'] == -1:
452             #    continue 
453             sfa_info = {}
454             type = record['type']
455             if (type == "slice"):
456                 # all slice users are researchers
457                 record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
458                 record['PI'] = []
459                 record['researcher'] = []
460                 for person_id in record.get('person_ids', []):
461                     hrns = [person['hrn'] for person in persons[person_id]]
462                     record['researcher'].extend(hrns)                
463
464                 # pis at the slice's site
465                 if 'site_id' in record and record['site_id'] in site_pis:
466                     pl_pis = site_pis[record['site_id']]
467                     pi_ids = [pi['person_id'] for pi in pl_pis]
468                     for person_id in pi_ids:
469                         hrns = [person['hrn'] for person in persons[person_id]]
470                         record['PI'].extend(hrns)
471                         record['geni_creator'] = record['PI'] 
472                 
473             elif (type.startswith("authority")):
474                 record['url'] = None
475                 if record['pointer'] != -1:
476                     record['PI'] = []
477                     record['operator'] = []
478                     record['owner'] = []
479                     for pointer in record.get('person_ids', []):
480                         if pointer not in persons or pointer not in pl_persons:
481                             # this means there is not sfa or pl record for this user
482                             continue   
483                         hrns = [person['hrn'] for person in persons[pointer]] 
484                         roles = pl_persons[pointer]['roles']   
485                         if 'pi' in roles:
486                             record['PI'].extend(hrns)
487                         if 'tech' in roles:
488                             record['operator'].extend(hrns)
489                         if 'admin' in roles:
490                             record['owner'].extend(hrns)
491                         # xxx TODO: OrganizationName
492             elif (type == "node"):
493                 sfa_info['dns'] = record.get("hostname", "")
494                 # xxx TODO: URI, LatLong, IP, DNS
495     
496             elif (type == "user"):
497                 sfa_info['email'] = record.get("email", "")
498                 sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
499                 sfa_info['geni_certificate'] = record['gid'] 
500                 # xxx TODO: PostalAddress, Phone
501             record.update(sfa_info)
502
503
504     ####################
505     # plcapi works by changes, compute what needs to be added/deleted
506     def update_relation (self, subject_type, target_type, subject_id, target_ids):
507         # hard-wire the code for slice/user for now, could be smarter if needed
508         if subject_type =='slice' and target_type == 'user':
509             subject=self.GetSlices (subject_id)[0]
510             current_target_ids = subject['person_ids']
511             add_target_ids = list ( set (target_ids).difference(current_target_ids))
512             del_target_ids = list ( set (current_target_ids).difference(target_ids))
513             logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id)))
514             for target_id in add_target_ids:
515                 self.AddPersonToSlice (target_id,subject_id)
516                 logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id)))
517             for target_id in del_target_ids:
518                 logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id)))
519                 self.DeletePersonFromSlice (target_id, subject_id)
520         else:
521             logger.info('unexpected relation to maintain, %s -> %s'%(subject_type,target_type))
522
523         
524     ########################################
525     ########## aggregate oriented
526     ########################################
527
528     def testbed_name (self): return "myplc"
529
530     # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory
531     def aggregate_version (self):
532         version_manager = VersionManager()
533         ad_rspec_versions = []
534         request_rspec_versions = []
535         for rspec_version in version_manager.versions:
536             if rspec_version.content_type in ['*', 'ad']:
537                 ad_rspec_versions.append(rspec_version.to_dict())
538             if rspec_version.content_type in ['*', 'request']:
539                 request_rspec_versions.append(rspec_version.to_dict()) 
540         return {
541             'testbed':self.testbed_name(),
542             'geni_request_rspec_versions': request_rspec_versions,
543             'geni_ad_rspec_versions': ad_rspec_versions,
544             }
545
546     def sliver_status (self, slice_urn, slice_hrn):
547         # find out where this slice is currently running
548         slicename = hrn_to_pl_slicename(slice_hrn)
549         
550         slices = self.GetSlices([slicename], ['slice_id', 'node_ids','person_ids','name','expires'])
551         if len(slices) == 0:        
552             raise Exception("Slice %s not found (used %s as slicename internally)" % (slice_xrn, slicename))
553         slice = slices[0]
554         
555         # report about the local nodes only
556         nodes = self.GetNodes({'node_id':slice['node_ids'],'peer_id':None},
557                               ['node_id', 'hostname', 'site_id', 'boot_state', 'last_contact'])
558         site_ids = [node['site_id'] for node in nodes]
559     
560         result = {}
561         top_level_status = 'unknown'
562         if nodes:
563             top_level_status = 'ready'
564         result['geni_urn'] = slice_urn
565         result['pl_login'] = slice['name']
566         result['pl_expires'] = datetime.datetime.fromtimestamp(slice['expires']).ctime()
567         
568         resources = []
569         for node in nodes:
570             res = {}
571             res['pl_hostname'] = node['hostname']
572             res['pl_boot_state'] = node['boot_state']
573             res['pl_last_contact'] = node['last_contact']
574             if node['last_contact'] is not None:
575                 res['pl_last_contact'] = datetime.datetime.fromtimestamp(node['last_contact']).ctime()
576             sliver_id = urn_to_sliver_id(slice_urn, slice['slice_id'], node['node_id']) 
577             res['geni_urn'] = sliver_id
578             if node['boot_state'] == 'boot':
579                 res['geni_status'] = 'ready'
580             else:
581                 res['geni_status'] = 'failed'
582                 top_level_status = 'failed' 
583                 
584             res['geni_error'] = ''
585     
586             resources.append(res)
587             
588         result['geni_status'] = top_level_status
589         result['geni_resources'] = resources
590         return result
591
592     def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options):
593
594         aggregate = PlAggregate(self)
595         slices = PlSlices(self)
596         peer = slices.get_peer(slice_hrn)
597         sfa_peer = slices.get_sfa_peer(slice_hrn)
598         slice_record=None    
599         if users:
600             slice_record = users[0].get('slice_record', {})
601     
602         # parse rspec
603         rspec = RSpec(rspec_string)
604         requested_attributes = rspec.version.get_slice_attributes()
605         
606         # ensure site record exists
607         site = slices.verify_site(slice_hrn, slice_record, peer, sfa_peer, options=options)
608         # ensure slice record exists
609         slice = slices.verify_slice(slice_hrn, slice_record, peer, sfa_peer, options=options)
610         # ensure person records exists
611         persons = slices.verify_persons(slice_hrn, slice, users, peer, sfa_peer, options=options)
612         # ensure slice attributes exists
613         slices.verify_slice_attributes(slice, requested_attributes, options=options)
614         
615         # add/remove slice from nodes
616         requested_slivers = [node.get('component_name') for node in rspec.version.get_nodes_with_slivers()]
617         nodes = slices.verify_slice_nodes(slice, requested_slivers, peer) 
618    
619         # add/remove links links 
620         slices.verify_slice_links(slice, rspec.version.get_link_requests(), nodes)
621     
622         # handle MyPLC peer association.
623         # only used by plc and ple.
624         slices.handle_peer(site, slice, persons, peer)
625         
626         return aggregate.get_rspec(slice_xrn=slice_urn, version=rspec.version)
627
628     def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options):
629         slicename = hrn_to_pl_slicename(slice_hrn)
630         slices = self.driver.GetSlices({'name': slicename}, ['slice_id'])
631         if not slices:
632             raise RecordNotFound(slice_hrn)
633         slice = slices[0]
634         requested_time = utcparse(expiration_time)
635         record = {'expires': int(time.mktime(requested_time.timetuple()))}
636         try:
637             self.driver.UpdateSlice(slice['slice_id'], record)
638             return True
639         except:
640             return False
641