group db-related stuff in sfa/storage
[sfa.git] / sfa / plc / pldriver.py
1 #
2 from sfa.util.faults import MissingSfaInfo, UnknownSfaType
3 from sfa.util.sfalogging import logger
4 from sfa.util.defaultdict import defaultdict
5 from sfa.util.xrn import hrn_to_urn, get_leaf
6 from sfa.util.plxrn import slicename_to_hrn, hostname_to_hrn, hrn_to_pl_slicename, hrn_to_pl_login_base
7
8 # one would think the driver should not need to mess with the SFA db, but..
9 from sfa.storage.table import SfaTable
10
11 # the driver interface, mostly provides default behaviours
12 from sfa.managers.driver import Driver
13
14 from sfa.plc.plshell import PlShell
15
def list_to_dict(recs, key):
    """
    Index a list of dictionaries by the value each one holds under 'key'.

    When several dictionaries share the same value, the last one in
    'recs' is the one kept in the result.
    """
    keyed = {}
    for rec in recs:
        keyed[rec[key]] = rec
    return keyed
22
23 #
24 # PlShell is just an xmlrpc serverproxy where methods
25 # can be sent as-is; it takes care of authentication
26 # from the global config
27
28 # so we inherit PlShell just so one can do driver.GetNodes
29 # which would not make much sense in the context of other testbeds
30 # so ultimately PlDriver should drop the PlShell inheritance
31 # and would have a driver.shell reference to a PlShell instead
32
33 class PlDriver (Driver, PlShell):
34
35     def __init__ (self, config):
36         PlShell.__init__ (self, config)
37  
38         self.hrn = config.SFA_INTERFACE_HRN
39
40     ########## disabled users 
41     def is_enabled (self, record):
42         # the incoming record was augmented already, so 'enabled' should be set
43         if record['type'] == 'user':
44             return record['enabled']
45         # only users can be disabled
46         return True
47
48     def augment_records_with_testbed_info (self, sfa_records):
49         return self.fill_record_info (sfa_records)
50
51     ########## 
52     def register (self, sfa_record, hrn, pub_key):
53         type = sfa_record['type']
54         pl_record = self.sfa_fields_to_pl_fields(type, hrn, sfa_record)
55
56         if type == 'authority':
57             sites = self.GetSites([pl_record['login_base']])
58             if not sites:
59                 pointer = self.AddSite(pl_record)
60             else:
61                 pointer = sites[0]['site_id']
62
63         elif type == 'slice':
64             acceptable_fields=['url', 'instantiation', 'name', 'description']
65             for key in pl_record.keys():
66                 if key not in acceptable_fields:
67                     pl_record.pop(key)
68             slices = self.GetSlices([pl_record['name']])
69             if not slices:
70                  pointer = self.AddSlice(pl_record)
71             else:
72                  pointer = slices[0]['slice_id']
73
74         elif type == 'user':
75             persons = self.GetPersons([sfa_record['email']])
76             if not persons:
77                 pointer = self.AddPerson(dict(sfa_record))
78             else:
79                 pointer = persons[0]['person_id']
80     
81             if 'enabled' in sfa_record and sfa_record['enabled']:
82                 self.UpdatePerson(pointer, {'enabled': sfa_record['enabled']})
83             # add this person to the site only if she is being added for the first
84             # time by sfa and doesont already exist in plc
85             if not persons or not persons[0]['site_ids']:
86                 login_base = get_leaf(sfa_record['authority'])
87                 self.AddPersonToSite(pointer, login_base)
88     
89             # What roles should this user have?
90             self.AddRoleToPerson('user', pointer)
91             # Add the user's key
92             if pub_key:
93                 self.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key})
94
95         elif type == 'node':
96             login_base = hrn_to_pl_login_base(sfa_record['authority'])
97             nodes = api.driver.GetNodes([pl_record['hostname']])
98             if not nodes:
99                 pointer = api.driver.AddNode(login_base, pl_record)
100             else:
101                 pointer = nodes[0]['node_id']
102     
103         return pointer
104         
105     ##########
106     # xxx actually old_sfa_record comes filled with plc stuff as well in the original code
107     def update (self, old_sfa_record, new_sfa_record, hrn, new_key):
108         pointer = old_sfa_record['pointer']
109         type = old_sfa_record['type']
110
111         # new_key implemented for users only
112         if new_key and type not in [ 'user' ]:
113             raise UnknownSfaType(type)
114
115         if (type == "authority"):
116             self.UpdateSite(pointer, new_sfa_record)
117     
118         elif type == "slice":
119             pl_record=self.sfa_fields_to_pl_fields(type, hrn, new_sfa_record)
120             if 'name' in pl_record:
121                 pl_record.pop('name')
122                 self.UpdateSlice(pointer, pl_record)
123     
124         elif type == "user":
125             # SMBAKER: UpdatePerson only allows a limited set of fields to be
126             #    updated. Ideally we should have a more generic way of doing
127             #    this. I copied the field names from UpdatePerson.py...
128             update_fields = {}
129             all_fields = new_sfa_record
130             for key in all_fields.keys():
131                 if key in ['first_name', 'last_name', 'title', 'email',
132                            'password', 'phone', 'url', 'bio', 'accepted_aup',
133                            'enabled']:
134                     update_fields[key] = all_fields[key]
135             self.UpdatePerson(pointer, update_fields)
136     
137             if new_key:
138                 # must check this key against the previous one if it exists
139                 persons = self.GetPersons([pointer], ['key_ids'])
140                 person = persons[0]
141                 keys = person['key_ids']
142                 keys = self.GetKeys(person['key_ids'])
143                 
144                 # Delete all stale keys
145                 key_exists = False
146                 for key in keys:
147                     if new_key != key['key']:
148                         self.DeleteKey(key['key_id'])
149                     else:
150                         key_exists = True
151                 if not key_exists:
152                     self.AddPersonKey(pointer, {'key_type': 'ssh', 'key': new_key})
153     
154         elif type == "node":
155             self.UpdateNode(pointer, new_sfa_record)
156
157         return True
158         
159
160     ##########
161     def remove (self, sfa_record):
162         type=sfa_record['type']
163         pointer=sfa_record['pointer']
164         if type == 'user':
165             persons = self.GetPersons(pointer)
166             # only delete this person if he has site ids. if he doesnt, it probably means
167             # he was just removed from a site, not actually deleted
168             if persons and persons[0]['site_ids']:
169                 self.DeletePerson(pointer)
170         elif type == 'slice':
171             if self.GetSlices(pointer):
172                 self.DeleteSlice(pointer)
173         elif type == 'node':
174             if self.GetNodes(pointer):
175                 self.DeleteNode(pointer)
176         elif type == 'authority':
177             if self.GetSites(pointer):
178                 self.DeleteSite(pointer)
179
180         return True
181
182
183
184
185
186     ##
187     # Convert SFA fields to PLC fields for use when registering or updating
188     # registry record in the PLC database
189     #
190
191     def sfa_fields_to_pl_fields(self, type, hrn, sfa_record):
192
193         pl_record = {}
194  
195         if type == "slice":
196             pl_record["name"] = hrn_to_pl_slicename(hrn)
197             if "instantiation" in sfa_record:
198                 pl_record['instantiation']=sfa_record['instantiation']
199             else:
200                 pl_record["instantiation"] = "plc-instantiated"
201             if "url" in sfa_record:
202                pl_record["url"] = sfa_record["url"]
203             if "description" in sfa_record:
204                 pl_record["description"] = sfa_record["description"]
205             if "expires" in sfa_record:
206                 pl_record["expires"] = int(sfa_record["expires"])
207
208         elif type == "node":
209             if not "hostname" in pl_record:
210                 # fetch from sfa_record
211                 if "hostname" not in sfa_record:
212                     raise MissingSfaInfo("hostname")
213                 pl_record["hostname"] = sfa_record["hostname"]
214             if "model" in sfa_record: 
215                 pl_record["model"] = sfa_record["model"]
216             else:
217                 pl_record["model"] = "geni"
218
219         elif type == "authority":
220             pl_record["login_base"] = hrn_to_pl_login_base(hrn)
221             if "name" not in sfa_record:
222                 pl_record["name"] = hrn
223             if "abbreviated_name" not in sfa_record:
224                 pl_record["abbreviated_name"] = hrn
225             if "enabled" not in sfa_record:
226                 pl_record["enabled"] = True
227             if "is_public" not in sfa_record:
228                 pl_record["is_public"] = True
229
230         return pl_record
231
232     ####################
233     def fill_record_info(self, records):
234         """
235         Given a (list of) SFA record, fill in the PLC specific 
236         and SFA specific fields in the record. 
237         """
238         if not isinstance(records, list):
239             records = [records]
240
241         self.fill_record_pl_info(records)
242         self.fill_record_hrns(records)
243         self.fill_record_sfa_info(records)
244         return records
245
246     def fill_record_pl_info(self, records):
247         """
248         Fill in the planetlab specific fields of a SFA record. This
249         involves calling the appropriate PLC method to retrieve the 
250         database record for the object.
251             
252         @param record: record to fill in field (in/out param)     
253         """
254         # get ids by type
255         node_ids, site_ids, slice_ids = [], [], [] 
256         person_ids, key_ids = [], []
257         type_map = {'node': node_ids, 'authority': site_ids,
258                     'slice': slice_ids, 'user': person_ids}
259                   
260         for record in records:
261             for type in type_map:
262                 if type == record['type']:
263                     type_map[type].append(record['pointer'])
264
265         # get pl records
266         nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
267         if node_ids:
268             node_list = self.GetNodes(node_ids)
269             nodes = list_to_dict(node_list, 'node_id')
270         if site_ids:
271             site_list = self.GetSites(site_ids)
272             sites = list_to_dict(site_list, 'site_id')
273         if slice_ids:
274             slice_list = self.GetSlices(slice_ids)
275             slices = list_to_dict(slice_list, 'slice_id')
276         if person_ids:
277             person_list = self.GetPersons(person_ids)
278             persons = list_to_dict(person_list, 'person_id')
279             for person in persons:
280                 key_ids.extend(persons[person]['key_ids'])
281
282         pl_records = {'node': nodes, 'authority': sites,
283                       'slice': slices, 'user': persons}
284
285         if key_ids:
286             key_list = self.GetKeys(key_ids)
287             keys = list_to_dict(key_list, 'key_id')
288
289         # fill record info
290         for record in records:
291             # records with pointer==-1 do not have plc info.
292             # for example, the top level authority records which are
293             # authorities, but not PL "sites"
294             if record['pointer'] == -1:
295                 continue
296            
297             for type in pl_records:
298                 if record['type'] == type:
299                     if record['pointer'] in pl_records[type]:
300                         record.update(pl_records[type][record['pointer']])
301                         break
302             # fill in key info
303             if record['type'] == 'user':
304                 if 'key_ids' not in record:
305                     logger.info("user record has no 'key_ids' - need to import from myplc ?")
306                 else:
307                     pubkeys = [keys[key_id]['key'] for key_id in record['key_ids'] if key_id in keys] 
308                     record['keys'] = pubkeys
309
310         return records
311
312     def fill_record_hrns(self, records):
313         """
314         convert pl ids to hrns
315         """
316
317         # get ids
318         slice_ids, person_ids, site_ids, node_ids = [], [], [], []
319         for record in records:
320             if 'site_id' in record:
321                 site_ids.append(record['site_id'])
322             if 'site_ids' in record:
323                 site_ids.extend(record['site_ids'])
324             if 'person_ids' in record:
325                 person_ids.extend(record['person_ids'])
326             if 'slice_ids' in record:
327                 slice_ids.extend(record['slice_ids'])
328             if 'node_ids' in record:
329                 node_ids.extend(record['node_ids'])
330
331         # get pl records
332         slices, persons, sites, nodes = {}, {}, {}, {}
333         if site_ids:
334             site_list = self.GetSites(site_ids, ['site_id', 'login_base'])
335             sites = list_to_dict(site_list, 'site_id')
336         if person_ids:
337             person_list = self.GetPersons(person_ids, ['person_id', 'email'])
338             persons = list_to_dict(person_list, 'person_id')
339         if slice_ids:
340             slice_list = self.GetSlices(slice_ids, ['slice_id', 'name'])
341             slices = list_to_dict(slice_list, 'slice_id')       
342         if node_ids:
343             node_list = self.GetNodes(node_ids, ['node_id', 'hostname'])
344             nodes = list_to_dict(node_list, 'node_id')
345        
346         # convert ids to hrns
347         for record in records:
348             # get all relevant data
349             type = record['type']
350             pointer = record['pointer']
351             auth_hrn = self.hrn
352             login_base = ''
353             if pointer == -1:
354                 continue
355
356             if 'site_id' in record:
357                 site = sites[record['site_id']]
358                 login_base = site['login_base']
359                 record['site'] = ".".join([auth_hrn, login_base])
360             if 'person_ids' in record:
361                 emails = [persons[person_id]['email'] for person_id in record['person_ids'] \
362                           if person_id in  persons]
363                 usernames = [email.split('@')[0] for email in emails]
364                 person_hrns = [".".join([auth_hrn, login_base, username]) for username in usernames]
365                 record['persons'] = person_hrns 
366             if 'slice_ids' in record:
367                 slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids'] \
368                               if slice_id in slices]
369                 slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
370                 record['slices'] = slice_hrns
371             if 'node_ids' in record:
372                 hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \
373                              if node_id in nodes]
374                 node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
375                 record['nodes'] = node_hrns
376             if 'site_ids' in record:
377                 login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids'] \
378                                if site_id in sites]
379                 site_hrns = [".".join([auth_hrn, lbase]) for lbase in login_bases]
380                 record['sites'] = site_hrns
381             
382         return records   
383
384     # aggregates is basically api.aggregates
385     def fill_record_sfa_info(self, records):
386
387         def startswith(prefix, values):
388             return [value for value in values if value.startswith(prefix)]
389
390         # get person ids
391         person_ids = []
392         site_ids = []
393         for record in records:
394             person_ids.extend(record.get("person_ids", []))
395             site_ids.extend(record.get("site_ids", [])) 
396             if 'site_id' in record:
397                 site_ids.append(record['site_id']) 
398         
399         # get all pis from the sites we've encountered
400         # and store them in a dictionary keyed on site_id 
401         site_pis = {}
402         if site_ids:
403             pi_filter = {'|roles': ['pi'], '|site_ids': site_ids} 
404             pi_list = self.GetPersons(pi_filter, ['person_id', 'site_ids'])
405             for pi in pi_list:
406                 # we will need the pi's hrns also
407                 person_ids.append(pi['person_id'])
408                 
409                 # we also need to keep track of the sites these pis
410                 # belong to
411                 for site_id in pi['site_ids']:
412                     if site_id in site_pis:
413                         site_pis[site_id].append(pi)
414                     else:
415                         site_pis[site_id] = [pi]
416                  
417         # get sfa records for all records associated with these records.   
418         # we'll replace pl ids (person_ids) with hrns from the sfa records
419         # we obtain
420         
421         # get the sfa records
422         table = SfaTable()
423         person_list, persons = [], {}
424         person_list = table.find({'type': 'user', 'pointer': person_ids})
425         # create a hrns keyed on the sfa record's pointer.
426         # Its possible for multiple records to have the same pointer so
427         # the dict's value will be a list of hrns.
428         persons = defaultdict(list)
429         for person in person_list:
430             persons[person['pointer']].append(person)
431
432         # get the pl records
433         pl_person_list, pl_persons = [], {}
434         pl_person_list = self.GetPersons(person_ids, ['person_id', 'roles'])
435         pl_persons = list_to_dict(pl_person_list, 'person_id')
436
437         # fill sfa info
438         for record in records:
439             # skip records with no pl info (top level authorities)
440             #if record['pointer'] == -1:
441             #    continue 
442             sfa_info = {}
443             type = record['type']
444             if (type == "slice"):
445                 # all slice users are researchers
446                 record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
447                 record['PI'] = []
448                 record['researcher'] = []
449                 for person_id in record.get('person_ids', []):
450                     hrns = [person['hrn'] for person in persons[person_id]]
451                     record['researcher'].extend(hrns)                
452
453                 # pis at the slice's site
454                 if 'site_id' in record and record['site_id'] in site_pis:
455                     pl_pis = site_pis[record['site_id']]
456                     pi_ids = [pi['person_id'] for pi in pl_pis]
457                     for person_id in pi_ids:
458                         hrns = [person['hrn'] for person in persons[person_id]]
459                         record['PI'].extend(hrns)
460                         record['geni_creator'] = record['PI'] 
461                 
462             elif (type.startswith("authority")):
463                 record['url'] = None
464                 if record['pointer'] != -1:
465                     record['PI'] = []
466                     record['operator'] = []
467                     record['owner'] = []
468                     for pointer in record.get('person_ids', []):
469                         if pointer not in persons or pointer not in pl_persons:
470                             # this means there is not sfa or pl record for this user
471                             continue   
472                         hrns = [person['hrn'] for person in persons[pointer]] 
473                         roles = pl_persons[pointer]['roles']   
474                         if 'pi' in roles:
475                             record['PI'].extend(hrns)
476                         if 'tech' in roles:
477                             record['operator'].extend(hrns)
478                         if 'admin' in roles:
479                             record['owner'].extend(hrns)
480                         # xxx TODO: OrganizationName
481             elif (type == "node"):
482                 sfa_info['dns'] = record.get("hostname", "")
483                 # xxx TODO: URI, LatLong, IP, DNS
484     
485             elif (type == "user"):
486                 sfa_info['email'] = record.get("email", "")
487                 sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
488                 sfa_info['geni_certificate'] = record['gid'] 
489                 # xxx TODO: PostalAddress, Phone
490             record.update(sfa_info)
491
492
493     ####################
494     # plcapi works by changes, compute what needs to be added/deleted
495     def update_relation (self, subject_type, target_type, subject_id, target_ids):
496         # hard-wire the code for slice/user for now, could be smarter if needed
497         if subject_type =='slice' and target_type == 'user':
498             subject=self.GetSlices (subject_id)[0]
499             current_target_ids = subject['person_ids']
500             add_target_ids = list ( set (target_ids).difference(current_target_ids))
501             del_target_ids = list ( set (current_target_ids).difference(target_ids))
502             logger.info ("subject_id = %s (type=%s)"%(subject_id,type(subject_id)))
503             for target_id in add_target_ids:
504                 self.AddPersonToSlice (target_id,subject_id)
505                 logger.info ("add_target_id = %s (type=%s)"%(target_id,type(target_id)))
506             for target_id in del_target_ids:
507                 logger.info ("del_target_id = %s (type=%s)"%(target_id,type(target_id)))
508                 self.DeletePersonFromSlice (target_id, subject_id)
509         else:
510             logger.info('unexpected relation to maintain, %s -> %s'%(subject_type,target_type))
511
512