52411592baf2a652b72903aaf3c6bc5113c7a843
[sfa.git] / sfa / plc / api-dev.py
1 #
2 # SFA XML-RPC and SOAP interfaces
3 #
4 ### $Id: api.py 17793 2010-04-26 21:40:57Z tmack $
5 ### $URL: https://svn.planet-lab.org/svn/sfa/trunk/sfa/plc/api.py $
6 #
7
8 import sys
9 import os
10 import traceback
11 import string
12 import xmlrpclib
13
14 from sfa.util.sfalogging import sfa_logger
15 from sfa.trust.auth import Auth
16 from sfa.util.config import *
17 from sfa.util.faults import *
18 from sfa.trust.rights import *
19 from sfa.trust.credential import *
20 from sfa.trust.certificate import *
21 from sfa.util.namespace import *
22 from sfa.util.api import *
23 from sfa.util.nodemanager import NodeManager
24 from collections import defaultdict
25
26
class RecordInfo(object):
    """
    Join a batch of SFA registry records with the PLC objects they map to.

    For every SFA record of type authority, slice, user or node, the PLC
    object named by its 'pointer' is fetched through the PLC shell and merged
    into the record (the record dicts are updated in place). The PLC objects
    referenced by those objects (via site_id/site_ids/person_ids/slice_ids/
    node_ids) are fetched as well. This lets PLC ids be translated into HRNs
    and the SFA role lists (PI, researchers, operator, owner) be filled in.
    """

    shell = None
    auth = None

    def __init__(self, api, records):
        """
        @param api object exposing plshell, plauth and hrn (e.g. SfaAPI)
        @param records list of SFA record dicts; updated in place
        """
        self.api = api
        self.shell = api.plshell
        self.auth = api.plauth

        # Per-instance caches. These used to be class attributes, so every
        # RecordInfo shared (and kept growing) the same dicts and the same
        # 'records' list across calls.
        # pl records, keyed on their PLC id
        self.sites = {}
        self.slices = {}
        self.persons = {}
        self.nodes = {}
        self.keys = {}
        # sfa records, keyed on their 'pointer' (the PLC id they refer to)
        self.sfa_authorities = {}
        self.sfa_slices = {}
        self.sfa_users = {}
        self.sfa_nodes = {}
        # sfa records of a handled type, in input order
        self.records = []

        # first pass: the PLC objects the SFA records point to
        site_ids, slice_ids, person_ids, node_ids = self._group_records(records)
        self._update_pl_info(site_ids, slice_ids, person_ids, node_ids)

        # second pass: the PLC objects referenced by those objects
        site_ids, slice_ids, person_ids, node_ids = \
            self._collect_related_ids(records)
        self._update_pl_info(site_ids, slice_ids, person_ids, node_ids)

        # convert pl ids to hrns
        self.update_hrns()

        # update sfa info
        self.update_sfa_info(person_ids)

    def _group_records(self, records):
        """
        Index SFA records by type and collect the PLC ids they point to.

        Records whose type is not authority/slice/user/node are ignored.

        @return (site_ids, slice_ids, person_ids, node_ids)
        """
        site_ids, slice_ids, person_ids, node_ids = [], [], [], []
        for record in records:
            pointer = record['pointer']
            record_type = record['type']
            if record_type == 'authority':
                self.sfa_authorities[pointer] = record
                site_ids.append(pointer)
            elif record_type == 'slice':
                self.sfa_slices[pointer] = record
                slice_ids.append(pointer)
            elif record_type == 'user':
                self.sfa_users[pointer] = record
                person_ids.append(pointer)
            elif record_type == 'node':
                self.sfa_nodes[pointer] = record
                node_ids.append(pointer)
            else:
                # other record types are not handled here
                continue
            self.records.append(record)
        return site_ids, slice_ids, person_ids, node_ids

    @staticmethod
    def _collect_related_ids(records):
        """
        Collect the PLC ids referenced by the (already merged) records.

        @return (site_ids, slice_ids, person_ids, node_ids)
        """
        site_ids, slice_ids, person_ids, node_ids = [], [], [], []
        for record in records:
            if 'site_id' in record:
                site_ids.append(record['site_id'])
            # this used to test ``'site_ids' in records`` (the whole list),
            # which is never true, so site_ids were silently dropped
            if 'site_ids' in record:
                site_ids.extend(record['site_ids'])
            if 'person_ids' in record:
                person_ids.extend(record['person_ids'])
            if 'slice_ids' in record:
                slice_ids.extend(record['slice_ids'])
            if 'node_ids' in record:
                node_ids.extend(record['node_ids'])
        return site_ids, slice_ids, person_ids, node_ids

    def _update_pl_info(self, site_ids, slice_ids, person_ids, node_ids):
        """Fetch the given PLC objects and merge them into matching records."""
        self.update_pl_sites(site_ids)
        self.update_pl_slices(slice_ids)
        self.update_pl_persons(person_ids)
        self.update_pl_nodes(node_ids)

    def update_pl_sites(self, site_ids):
        """
        Cache the PLC sites in site_ids and merge them into authority records.
        """
        if not site_ids:
            return
        for site in self.shell.GetSites(self.auth, site_ids):
            site_id = site['site_id']
            self.sites[site_id] = site
            if site_id in self.sfa_authorities:
                self.sfa_authorities[site_id].update(site)

    def update_pl_slices(self, slice_ids):
        """
        Cache the PLC slices in slice_ids and merge them into slice records.
        """
        if not slice_ids:
            return
        for slice in self.shell.GetSlices(self.auth, slice_ids):
            slice_id = slice['slice_id']
            self.slices[slice_id] = slice
            if slice_id in self.sfa_slices:
                self.sfa_slices[slice_id].update(slice)

    def update_pl_persons(self, person_ids):
        """
        Cache the PLC persons in person_ids, merge them into user records,
        then fetch their public keys.
        """
        if not person_ids:
            return
        key_ids = []
        for person in self.shell.GetPersons(self.auth, person_ids):
            person_id = person['person_id']
            self.persons[person_id] = person
            key_ids.extend(person['key_ids'])
            if person_id in self.sfa_users:
                self.sfa_users[person_id].update(person)
        self.update_pl_keys(key_ids)

    def update_pl_keys(self, key_ids):
        """
        Cache the PLC keys in key_ids and add each key string to the 'keys'
        list of the owning user record.
        """
        if not key_ids:
            return
        for key in self.shell.GetKeys(self.auth, key_ids):
            person_id = key['person_id']
            self.keys[key['key_id']] = key
            if person_id in self.sfa_users:
                user_keys = self.sfa_users[person_id].setdefault('keys', [])
                # a person can be fetched in both passes of __init__;
                # don't list the same key twice
                if key['key'] not in user_keys:
                    user_keys.append(key['key'])

    def update_pl_nodes(self, node_ids):
        """
        Cache the PLC nodes in node_ids and merge them into node records.
        """
        if not node_ids:
            return
        for node in self.shell.GetNodes(self.auth, node_ids):
            node_id = node['node_id']
            self.nodes[node_id] = node
            if node_id in self.sfa_nodes:
                self.sfa_nodes[node_id].update(node)

    def update_hrns(self):
        """
        Translate the PLC ids held by each SFA record into HRNs.

        When the id field is present on the record, sets:
          site    from site_id     (auth_hrn.login_base)
          persons from person_ids  (auth_hrn.login_base.<email local part>)
          slices  from slice_ids   (via slicename_to_hrn)
          nodes   from node_ids    (via hostname_to_hrn)
          sites   from site_ids    (auth_hrn.login_base)
        Ids whose PLC object was not fetched are skipped. Records whose
        pointer is -1 are left untouched.
        """
        auth_hrn = self.api.hrn
        for record in self.records:
            if record['pointer'] == -1:
                continue

            # login base of the record's own site; '' when it has none
            login_base = ''
            if 'site_id' in record:
                login_base = self.sites[record['site_id']]['login_base']
                record['site'] = ".".join([auth_hrn, login_base])
            if 'person_ids' in record:
                emails = [self.persons[person_id]['email']
                          for person_id in record['person_ids']
                          if person_id in self.persons]
                record['persons'] = [".".join([auth_hrn, login_base, email.split('@')[0]])
                                     for email in emails]
            if 'slice_ids' in record:
                slicenames = [self.slices[slice_id]['name']
                              for slice_id in record['slice_ids']
                              if slice_id in self.slices]
                record['slices'] = [slicename_to_hrn(auth_hrn, slicename)
                                    for slicename in slicenames]
            if 'node_ids' in record:
                hostnames = [self.nodes[node_id]['hostname']
                             for node_id in record['node_ids']
                             if node_id in self.nodes]
                record['nodes'] = [hostname_to_hrn(auth_hrn, login_base, hostname)
                                   for hostname in hostnames]
            if 'site_ids' in record:
                login_bases = [self.sites[site_id]['login_base']
                               for site_id in record['site_ids']
                               if site_id in self.sites]
                record['sites'] = [".".join([auth_hrn, lbase]) for lbase in login_bases]

    def update_sfa_info(self, person_ids):
        """
        Fill in SFA-only fields derived from the SFA registry table.

        Looks up the SFA user records whose pointer is in person_ids. Their
        HRNs are then used to set:
          slice:     'researchers' (HRNs of the slice's persons) and 'PI'
                     (HRNs under the slice's authority of PIs at its site)
          authority: 'PI', 'operator', 'owner' (HRNs under this authority of
                     members holding the pi, tech and admin PLC roles)
          node:      'dns' (copied from 'hostname')
        Records whose pointer is -1 are skipped.

        @param person_ids PLC person ids whose SFA user records to look up
        """
        from sfa.util.table import SfaTable
        table = SfaTable()
        persons = table.find({'type': 'user', 'pointer': person_ids})
        # hrns keyed on the sfa record's pointer. Several records may share
        # a pointer, so each value is a list of hrns.
        person_dict = defaultdict(list)
        for person in persons:
            person_dict[person['pointer']].append(person['hrn'])

        def startswith(prefix, values):
            return [value for value in values if value.startswith(prefix)]

        for record in self.records:
            if record['pointer'] == -1:
                continue
            record_type = record['type']

            if record_type == 'slice':
                authority = record['authority']
                # all slice users are researchers
                record['PI'] = []
                record['researchers'] = []
                for person_id in record['person_ids']:
                    record['researchers'].extend(person_dict[person_id])
                # also add the pis at the slice's site
                site = self.sites[record['site_id']]
                for person_id in site['person_ids']:
                    # Site members are only fetched when some record
                    # references them, so skip unknown ones instead of
                    # raising KeyError.
                    # NOTE(review): such PIs are therefore not listed;
                    # fetching the site's members would be needed for that.
                    person = self.persons.get(person_id)
                    if person is not None and 'pi' in person['roles']:
                        # PLCAPI doesn't support per site roles
                        # (a pi has the pi role at every site he belongs to).
                        # We shouldnt allow this in SFA
                        record['PI'].extend(startswith(authority, person_dict[person_id]))

            elif record_type == 'authority':
                record['PI'] = []
                record['operator'] = []
                record['owner'] = []
                for person_id in record['person_ids']:
                    person = self.persons.get(person_id)
                    if person is None:
                        # not returned by GetPersons; nothing to classify
                        continue
                    # only keep hrns that live under this authority
                    local_hrns = startswith(record['hrn'], person_dict[person_id])
                    roles = person['roles']
                    if 'pi' in roles:
                        record['PI'].extend(local_hrns)
                    if 'tech' in roles:
                        record['operator'].extend(local_hrns)
                    if 'admin' in roles:
                        record['owner'].extend(local_hrns)

            elif record_type == 'node':
                record['dns'] = record['hostname']

    def get_records(self):
        """Return the SFA records handled by this instance, in input order."""
        return self.records
287  
288
class SfaAPI(BaseAPI):
    """
    SFA API for an interface (registry, aggregate, ...) that wraps a myPLC.

    On top of BaseAPI this sets up the PLC XML-RPC shell, obtains the
    interface's own credential, and converts records between SFA and PLC.
    """

    # flat list of method names
    import sfa.methods
    methods = sfa.methods.all

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8",
                 methods='sfa.methods', peer_cert=None, interface=None,
                 key_file=None, cert_file=None, cache=None):
        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods,
                         peer_cert=peer_cert, interface=interface, key_file=key_file,
                         cert_file=cert_file, cache=cache)

        self.encoding = encoding

        from sfa.util.table import SfaTable
        self.SfaTable = SfaTable
        # Better just be documenting the API
        if config is None:
            return

        # Load configuration
        self.config = Config(config)
        self.auth = Auth(peer_cert)
        self.interface = interface
        self.key_file = key_file
        self.key = Keypair(filename=self.key_file)
        self.cert_file = cert_file
        self.cert = Certificate(filename=self.cert_file)
        self.credential = None
        # Initialize the PLC shell only if SFA wraps a myPLC
        rspec_type = self.config.get_aggregate_type()
        if rspec_type in ('pl', 'vini'):
            self.plshell = self.getPLCShell()
            self.plshell_version = "4.3"

        self.hrn = self.config.SFA_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"
        self.logger = sfa_logger()

    def getPLCShell(self):
        """
        Build the PLC password auth struct (stored in self.plauth) and return
        an XML-RPC proxy for SFA_PLC_URL.
        """
        self.plauth = {'Username': self.config.SFA_PLC_USER,
                       'AuthMethod': 'password',
                       'AuthString': self.config.SFA_PLC_PASSWORD}

        self.plshell_type = 'xmlrpc'
        # connect via xmlrpc
        url = self.config.SFA_PLC_URL
        return xmlrpclib.Server(url, verbose=0, allow_none=True)

    def getCredential(self):
        """
        Return our credential string: built locally when this interface is
        the registry, fetched from the registry otherwise.
        """
        if self.interface in ['registry']:
            return self.getCredentialFromLocalRegistry()
        return self.getCredentialFromRegistry()

    def getCredentialFromRegistry(self):
        """
        Get our credential from a remote registry, caching it on disk.

        The credential is cached in SFA_DATA_DIR as
        <interface>.<hrn>.authority.cred. When that file cannot be read
        (IOError), it is obtained from the registry for self.hrn via
        GetSelfCredential + GetCredential and saved to that file.

        @return the credential as a string
        """
        cred_type = 'authority'
        path = self.config.SFA_DATA_DIR
        filename = ".".join([self.interface, self.hrn, cred_type, "cred"])
        cred_filename = path + os.sep + filename
        try:
            credential = Credential(filename=cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            from sfa.server.registry import Registries
            registries = Registries(self)
            registry = registries[self.hrn]
            cert_string = self.cert.save_to_string(save_parents=True)
            # get self credential
            self_cred = registry.GetSelfCredential(cert_string, self.hrn, cred_type)
            # get credential
            cred = registry.GetCredential(self_cred, cred_type, self.hrn)

            # save cred to file
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)
            return cred

    def getCredentialFromLocalRegistry(self):
        """
        Build and sign our current credential directly from the local registry.

        @return the signed credential as a string
        @raise RecordNotFound if no registry record exists for self.hrn
        """
        hrn = self.hrn
        auth_hrn = self.auth.get_authority(hrn)

        # is this a root or sub authority
        if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN:
            auth_hrn = hrn
        auth_info = self.auth.get_auth_info(auth_hrn)
        table = self.SfaTable()
        records = table.findObjects(hrn)
        if not records:
            # SFA faults take the offending value; a bare class here would
            # fail to instantiate instead of reporting the missing record
            raise RecordNotFound(hrn)
        record = records[0]
        record_type = record['type']
        object_gid = record.get_gid_object()
        new_cred = Credential(subject=object_gid.get_subject())
        new_cred.set_gid_caller(object_gid)
        new_cred.set_gid_object(object_gid)
        new_cred.set_issuer(key=auth_info.get_pkey_object(), subject=auth_hrn)
        new_cred.set_pubkey(object_gid.get_pubkey())
        new_cred.set_privileges(determine_rights(record_type, hrn))

        auth_kind = "authority,ma,sa"
        new_cred.set_parent(self.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind))

        new_cred.encode()
        new_cred.sign()

        return new_cred.save_to_string(save_parents=True)

    def loadCredential(self):
        """
        Attempt to load credential from file if it exists. If it doesnt get
        credential from registry.
        """
        # see if this file exists
        # XX This is really the aggregate's credential. Using this is easier than getting
        # the registry's credential from iteslf (ssl errors).
        # NOTE(review): interface and hrn are concatenated without a '.', unlike
        # the "<interface>.<hrn>..." name used by getCredentialFromRegistry.
        ma_cred_filename = self.config.SFA_DATA_DIR + os.sep + self.interface + self.hrn + ".ma.cred"
        try:
            self.credential = Credential(filename=ma_cred_filename)
        except IOError:
            # NOTE(review): this stores a string, while the try branch stores
            # a Credential object.
            self.credential = self.getCredentialFromRegistry()

    ##
    # Convert SFA fields to PLC fields for use when registering or updating
    # a registry record in the PLC database
    #
    # @param type type of record (user, slice, ...)
    # @param hrn human readable name
    # @param record dictionary of SFA fields
    # @return dictionary of PLC fields

    def sfa_fields_to_pl_fields(self, type, hrn, record):
        """
        Return the PLC fields for an SFA record of the given type.

        slice:     instantiation, name (from hrn), plus url/description/expires
                   when present in record (expires converted to int)
        node:      hostname (required in record) and model
        authority: login_base (from hrn), name, abbreviated_name, enabled,
                   is_public
        Other types yield an empty dict.

        @raise MissingSfaInfo if a node record has no hostname
        """
        pl_record = {}

        if type == "slice":
            pl_record["instantiation"] = "plc-instantiated"
            pl_record["name"] = hrn_to_pl_slicename(hrn)
            if "url" in record:
                pl_record["url"] = record["url"]
            if "description" in record:
                pl_record["description"] = record["description"]
            if "expires" in record:
                pl_record["expires"] = int(record["expires"])

        elif type == "node":
            if "hostname" not in record:
                raise MissingSfaInfo("hostname")
            pl_record["hostname"] = record["hostname"]
            pl_record["model"] = "geni"

        elif type == "authority":
            pl_record["login_base"] = hrn_to_pl_login_base(hrn)
            pl_record["name"] = hrn
            pl_record["abbreviated_name"] = hrn
            pl_record["enabled"] = True
            pl_record["is_public"] = True

        return pl_record

    def fill_record_info(self, records):
        """
        Given a SFA record (or list of records), fill in the PLC specific and
        SFA specific fields in place and return the handled records as a list.
        """
        if not isinstance(records, list):
            records = [records]

        record_info = RecordInfo(self, records)
        return record_info.get_records()

    def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
        """
        Sync PLC membership with the member HRNs listed in record[listName].

        New member HRNs are resolved to PLC person ids through the SFA user
        records. Persons missing from the old record's person_ids are added with
        addFunc; old person_ids no longer listed are removed with delFunc.

        @param oldRecord the stored record, or None when registering
        @param record the new record
        @param listName record field holding member HRNs (e.g. 'researcher')
        @param addFunc PLC call taking (auth, person_id, container_id)
        @param delFunc PLC call taking (auth, person_id, container_id)
        """
        # get a list of the HRNs that are members of the old and new records
        if oldRecord:
            oldList = oldRecord.get(listName, [])
        else:
            oldList = []
        newList = record.get(listName, [])

        # if the lists are the same, then we don't have to update anything
        if oldList == newList:
            return

        # build a list of the new person ids, by looking up each person to get
        # their pointer; with no new members there is nothing to look up
        newIdList = []
        if newList:
            table = self.SfaTable()
            records = table.find({'type': 'user', 'hrn': newList})
            newIdList = [rec['pointer'] for rec in records]

        # build a list of the old person ids from the person_ids field
        if oldRecord:
            oldIdList = oldRecord.get("person_ids", [])
            containerId = oldRecord.get_pointer()
        else:
            # if oldRecord==None, then we are doing a Register, instead of an
            # update.
            oldIdList = []
            containerId = record.get_pointer()

        # add people who are in the new list, but not the old list
        for personId in newIdList:
            if personId not in oldIdList:
                addFunc(self.plauth, personId, containerId)

        # remove people who are in the old list, but not the new list
        for personId in oldIdList:
            if personId not in newIdList:
                delFunc(self.plauth, personId, containerId)

    def update_membership(self, oldRecord, record):
        """
        Sync PLC membership for a slice's 'researcher' list; authorities are
        not handled yet.
        """
        if record.type == "slice":
            self.update_membership_list(oldRecord, record, 'researcher',
                                        self.plshell.AddPersonToSlice,
                                        self.plshell.DeletePersonFromSlice)
        elif record.type == "authority":
            # xxx TODO
            pass
539
540
541
class ComponentAPI(BaseAPI):
    """
    SFA API exposed by a component (node), backed by the local NodeManager.
    """

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8", methods='sfa.methods',
                 peer_cert=None, interface=None, key_file=None, cert_file=None):

        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, peer_cert=peer_cert,
                         interface=interface, key_file=key_file, cert_file=cert_file)
        self.encoding = encoding

        # Better just be documenting the API
        if config is None:
            return

        self.nodemanager = NodeManager(self.config)

    def sliver_exists(self, slicename=None):
        """
        Return True if slicename is a key of NodeManager's GetXIDs() result.

        @param slicename name of the sliver to look up
        @raise ValueError if slicename is not given
        """
        # The original signature took no slice name yet read an undefined
        # 'slicename', so every call raised NameError.
        if slicename is None:
            raise ValueError("sliver_exists() requires a slicename")
        sliver_dict = self.nodemanager.GetXIDs()
        return slicename in sliver_dict