* Tried to make the way things get logged more sensible, at least on the server side for now
[sfa.git] / sfa / plc / api.py
1 #
2 # SFA XML-RPC and SOAP interfaces
3 #
4 ### $Id$
5 ### $URL$
6 #
7
8 import sys
9 import os
10 import traceback
11 import string
12 import xmlrpclib
13
14 from sfa.util.sfalogging import sfa_logger
15 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
16 from sfa.trust.auth import Auth
17 from sfa.util.config import *
18 from sfa.util.faults import *
19 from sfa.trust.rights import *
20 from sfa.trust.credential import *
21 from sfa.trust.certificate import *
22 from sfa.util.namespace import *
23 from sfa.util.api import *
24 from sfa.util.nodemanager import NodeManager
try:
    from collections import defaultdict
except ImportError:
    # collections.defaultdict only exists from Python 2.5 on; this fallback
    # is adapted from ActiveState recipe 523034
    # (http://code.activestate.com/recipes/523034/).
    class defaultdict(dict):
        """
        dict subclass that calls default_factory() to supply (and store)
        a value for any missing key.
        """

        def __init__(self, default_factory=None, *a, **kw):
            if (default_factory is not None and
                    not hasattr(default_factory, '__call__')):
                raise TypeError('first argument must be callable')
            dict.__init__(self, *a, **kw)
            self.default_factory = default_factory

        def __getitem__(self, key):
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return self.__missing__(key)

        def __missing__(self, key):
            if self.default_factory is None:
                raise KeyError(key)
            self[key] = value = self.default_factory()
            return value

        def __reduce__(self):
            if self.default_factory is None:
                args = tuple()
            else:
                args = (self.default_factory,)
            # pickle requires the 5th element to be an *iterator* of
            # (key, value) pairs, not a list
            return type(self), args, None, None, iter(self.items())

        def copy(self):
            return self.__copy__()

        def __copy__(self):
            return type(self)(self.default_factory, self)

        def __deepcopy__(self, memo):
            import copy
            # share the memo so shared/cyclic values are copied only once
            return type(self)(self.default_factory,
                              copy.deepcopy(list(self.items()), memo))

        def __repr__(self):
            return 'defaultdict(%s, %s)' % (self.default_factory,
                                            dict.__repr__(self))
63
def list_to_dict(recs, key):
    """
    Index a sequence of dictionaries by one of their fields.

    @param recs: iterable of dict-like records (lists and one-shot
                 iterators/generators are both accepted)
    @param key: name of the field whose value becomes the dictionary key
    @return: dict mapping rec[key] -> rec; when several records share the
             same key value, the last one wins
    @raise KeyError: if a record lacks the field `key`
    """
    # single pass, so a generator is not exhausted before it is zipped
    return dict((rec[key], rec) for rec in recs)
71
class SfaAPI(BaseAPI):
    """
    SFA API for interfaces (registry, aggregate, slice manager) that may
    wrap a PlanetLab (myPLC) deployment.

    On top of BaseAPI this class obtains the interface's own credential and
    translates between SFA records and PLC records.
    """

    # flat list of method names
    import sfa.methods
    methods = sfa.methods.all

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8",
                 methods='sfa.methods', peer_cert=None, interface=None,
                 key_file=None, cert_file=None, cache=None):
        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods,
                         peer_cert=peer_cert, interface=interface, key_file=key_file,
                         cert_file=cert_file, cache=cache)

        self.encoding = encoding
        # set before the early return below so it is always available
        self.logger = sfa_logger
        from sfa.util.table import SfaTable
        self.SfaTable = SfaTable
        # Better just be documenting the API
        if config is None:
            return

        # Load configuration
        self.config = Config(config)
        self.auth = Auth(peer_cert)
        self.interface = interface
        self.key_file = key_file
        self.key = Keypair(filename=self.key_file)
        self.cert_file = cert_file
        self.cert = Certificate(filename=self.cert_file)
        self.credential = None
        # Initialize the PLC shell only if SFA wraps a myPLC
        rspec_type = self.config.get_aggregate_type()
        if rspec_type in ('pl', 'vini', 'eucalyptus'):
            self.plshell = self.getPLCShell()
            self.plshell_version = "4.3"

        self.hrn = self.config.SFA_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"

    def getPLCShell(self):
        """
        Return a handle on the PLC API and set self.plauth / self.plshell_type.

        A 'direct' in-process PLC.Shell (importable from the directory that
        holds /usr/bin/plcsh) is tried first; if that fails for any reason an
        XML-RPC proxy on SFA_PLC_URL is returned instead.
        """
        self.plauth = {'Username': self.config.SFA_PLC_USER,
                       'AuthMethod': 'password',
                       'AuthString': self.config.SFA_PLC_PASSWORD}
        plcsh_dir = os.path.dirname(os.path.realpath("/usr/bin/plcsh"))
        # don't grow sys.path each time a shell is built
        if plcsh_dir not in sys.path:
            sys.path.append(plcsh_dir)
        try:
            self.plshell_type = 'direct'
            import PLC.Shell
            shell = PLC.Shell.Shell(globals=globals())
        except Exception:
            # best effort: no usable local PLC install, go through XML-RPC
            self.plshell_type = 'xmlrpc'
            url = self.config.SFA_PLC_URL
            shell = xmlrpclib.Server(url, verbose=0, allow_none=True)

        return shell

    def getCredential(self):
        """
        Return a valid credential for this interface.
        """
        if self.interface in ['registry']:
            return self.getCredentialFromLocalRegistry()
        else:
            return self.getCredentialFromRegistry()

    def getDelegatedCredential(self, creds):
        """
        Attempt to find a credential delegated to us in
        the specified list of creds.

        @param creds: a single credential or a list of credentials (may be
                      None/empty)
        @return: the first credential whose caller is this interface's hrn,
                 or None
        """
        if not creds:
            return None
        if not isinstance(creds, list):
            creds = [creds]
        delegated_creds = filter_creds_by_caller(creds, self.hrn)
        if not delegated_creds:
            return None
        return delegated_creds[0]

    def getCredentialFromRegistry(self):
        """
        Get our credential from a remote registry.

        The credential is cached in
        SFA_DATA_DIR/<interface>.<hrn>.authority.cred; when that file cannot
        be read, a self credential is obtained from the registry and used to
        request the authority credential, which is then cached.
        """
        cred_type = 'authority'
        path = self.config.SFA_DATA_DIR
        filename = ".".join([self.interface, self.hrn, cred_type, "cred"])
        cred_filename = path + os.sep + filename
        try:
            credential = Credential(filename=cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            from sfa.server.registry import Registries
            registries = Registries(self)
            registry = registries[self.hrn]
            cert_string = self.cert.save_to_string(save_parents=True)
            # get self credential
            self_cred = registry.get_self_credential(cert_string, cred_type, self.hrn)
            # get credential
            cred = registry.get_credential(self_cred, cred_type, self.hrn)

            # save cred to file
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)
            return cred

    def getCredentialFromLocalRegistry(self):
        """
        Build our current credential directly from the local registry.

        The credential is issued with the keys of our authority (or of
        ourselves if we are a root authority), carries the rights of our own
        record type, and has the authority credential as parent.

        @raise RecordNotFound: if the local registry has no record for our hrn
        """
        hrn = self.hrn
        auth_hrn = self.auth.get_authority(hrn)

        # is this a root or sub authority
        if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN:
            auth_hrn = hrn
        auth_info = self.auth.get_auth_info(auth_hrn)
        table = self.SfaTable()
        records = table.findObjects(hrn)
        if not records:
            raise RecordNotFound(hrn)
        record = records[0]
        rec_type = record['type']
        object_gid = record.get_gid_object()
        new_cred = Credential(subject=object_gid.get_subject())
        new_cred.set_gid_caller(object_gid)
        new_cred.set_gid_object(object_gid)
        new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename())

        new_cred.set_privileges(determine_rights(rec_type, hrn))

        auth_kind = "authority,ma,sa"
        new_cred.set_parent(self.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind))

        new_cred.encode()
        new_cred.sign()

        return new_cred.save_to_string(save_parents=True)

    def loadCredential(self):
        """
        Load our credential into self.credential, from
        SFA_DATA_DIR/<interface>.<hrn>.ma.cred if that file is readable,
        otherwise from the registry.

        NOTE(review): the file branch stores a Credential object while the
        registry branch stores a string - callers must cope with both.
        """
        # XX This is really the aggregate's credential. Using this is easier than getting
        # the registry's credential from itself (ssl errors).
        filename = ".".join([self.interface, self.hrn, "ma", "cred"])
        ma_cred_filename = self.config.SFA_DATA_DIR + os.sep + filename
        try:
            self.credential = Credential(filename=ma_cred_filename)
        except IOError:
            self.credential = self.getCredentialFromRegistry()

    def sfa_fields_to_pl_fields(self, type, hrn, record):
        """
        Convert SFA fields to PLC fields, for use when registering or updating
        a registry record in the PLC database.

        @param type: type of record ("slice", "node", "authority", ...)
        @param hrn: human readable name of the object
        @param record: dictionary of SFA fields
        @return: a new dictionary of PLC fields (empty for other types)
        @raise MissingSfaInfo: for a node record without a hostname
        """
        pl_record = {}

        if type == "slice":
            pl_record["instantiation"] = "plc-instantiated"
            pl_record["name"] = hrn_to_pl_slicename(hrn)
            if "url" in record:
                pl_record["url"] = record["url"]
            if "description" in record:
                pl_record["description"] = record["description"]
            if "expires" in record:
                pl_record["expires"] = int(record["expires"])

        elif type == "node":
            if "hostname" not in record:
                raise MissingSfaInfo("hostname")
            pl_record["hostname"] = record["hostname"]
            pl_record["model"] = "geni"

        elif type == "authority":
            pl_record["login_base"] = hrn_to_pl_login_base(hrn)
            pl_record["name"] = hrn
            pl_record["abbreviated_name"] = hrn
            pl_record["enabled"] = True
            pl_record["is_public"] = True

        return pl_record

    def fill_record_pl_info(self, records):
        """
        Fill in the PlanetLab specific fields of SFA records.

        The PLC objects pointed to by each record's 'pointer' are fetched
        (one PLC call per object type) and merged into the record; user
        records also get a 'keys' list of public keys. PLC ids are then
        translated to hrns by fill_record_hrns.

        @param records: list of records (modified in place)
        @return: the same list of records
        """
        # group PLC ids by object type
        node_ids, site_ids, slice_ids, person_ids = [], [], [], []
        type_map = {'node': node_ids, 'authority': site_ids,
                    'slice': slice_ids, 'user': person_ids}
        for record in records:
            ids = type_map.get(record['type'])
            if ids is not None:
                ids.append(record['pointer'])

        # get pl records
        nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids)
            nodes = list_to_dict(node_list, 'node_id')
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids)
            sites = list_to_dict(site_list, 'site_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids)
            slices = list_to_dict(slice_list, 'slice_id')
        key_ids = []
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids)
            persons = list_to_dict(person_list, 'person_id')
            for person in persons.values():
                key_ids.extend(person['key_ids'])
        if key_ids:
            key_list = self.plshell.GetKeys(self.plauth, key_ids)
            keys = list_to_dict(key_list, 'key_id')

        pl_records = {'node': nodes, 'authority': sites,
                      'slice': slices, 'user': persons}

        # fill record info
        for record in records:
            # records with pointer==-1 do not have plc info.
            # for example, the top level authority records which are
            # authorities, but not PL "sites"
            if record['pointer'] == -1:
                continue

            pl_objects = pl_records.get(record['type'], {})
            if record['pointer'] in pl_objects:
                record.update(pl_objects[record['pointer']])
            # fill in key info; a user whose PLC record is gone has no key_ids
            if record['type'] == 'user':
                record['keys'] = [keys[key_id]['key']
                                  for key_id in record.get('key_ids', [])
                                  if key_id in keys]

        # fill in record hrns
        return self.fill_record_hrns(records)

    def fill_record_hrns(self, records):
        """
        Convert the PLC ids carried by records into hrns.

        Sets 'site' (from site_id), 'persons', 'slices', 'nodes' and 'sites'
        (from the corresponding *_ids lists); records with pointer == -1 are
        skipped.

        @param records: list of records (modified in place)
        @return: the same list of records
        """
        # collect every PLC id referenced by the records
        slice_ids, person_ids, site_ids, node_ids = [], [], [], []
        for record in records:
            if 'site_id' in record:
                site_ids.append(record['site_id'])
            if 'site_ids' in record:
                site_ids.extend(record['site_ids'])
            if 'person_ids' in record:
                person_ids.extend(record['person_ids'])
            if 'slice_ids' in record:
                slice_ids.extend(record['slice_ids'])
            if 'node_ids' in record:
                node_ids.extend(record['node_ids'])

        # get pl records
        slices, persons, sites, nodes = {}, {}, {}, {}
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids, ['site_id', 'login_base'])
            sites = list_to_dict(site_list, 'site_id')
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'email'])
            persons = list_to_dict(person_list, 'person_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids, ['slice_id', 'name'])
            slices = list_to_dict(slice_list, 'slice_id')
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids, ['node_id', 'hostname'])
            nodes = list_to_dict(node_list, 'node_id')

        # convert ids to hrns
        auth_hrn = self.hrn
        for record in records:
            if record['pointer'] == -1:
                continue

            login_base = ''
            if 'site_id' in record:
                login_base = sites[record['site_id']]['login_base']
                record['site'] = ".".join([auth_hrn, login_base])
            if 'person_ids' in record:
                emails = [persons[person_id]['email'] for person_id in record['person_ids']
                          if person_id in persons]
                usernames = [email.split('@')[0] for email in emails]
                record['persons'] = [".".join([auth_hrn, login_base, username])
                                     for username in usernames]
            if 'slice_ids' in record:
                slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids']
                              if slice_id in slices]
                record['slices'] = [slicename_to_hrn(auth_hrn, slicename)
                                    for slicename in slicenames]
            if 'node_ids' in record:
                hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids']
                             if node_id in nodes]
                record['nodes'] = [hostname_to_hrn(auth_hrn, login_base, hostname)
                                   for hostname in hostnames]
            if 'site_ids' in record:
                login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids']
                               if site_id in sites]
                record['sites'] = [".".join([auth_hrn, lbase]) for lbase in login_bases]

        return records

    def fill_record_sfa_info(self, records):
        """
        Fill in the SFA specific fields of records that already carry PLC
        data (person_ids, site_id(s), ...).

        - slice: 'researcher' (hrns of slice members), 'PI' (hrns of the PIs
          of the slice's site), 'geni_urn', 'geni_creator'
        - authority: 'PI' / 'operator' / 'owner' from the PLC roles
          pi / tech / admin of the site's members
        - node: 'dns'
        - user: 'email', 'geni_urn', 'geni_certificate'

        Records with pointer == -1 (no PLC counterpart) are skipped.
        Records are modified in place.
        """
        # PLC ids referenced by the records
        person_ids = []
        site_ids = []
        for record in records:
            person_ids.extend(record.get("person_ids", []))
            site_ids.extend(record.get("site_ids", []))
            if 'site_id' in record:
                site_ids.append(record['site_id'])

        # get all pis from the sites we've encountered
        # and store them in a dictionary keyed on site_id
        site_pis = {}
        if site_ids:
            pi_filter = {'|roles': ['pi'], '|site_ids': site_ids}
            pi_list = self.plshell.GetPersons(self.plauth, pi_filter, ['person_id', 'site_ids'])
            for pi in pi_list:
                # we will need the pi's hrns also
                person_ids.append(pi['person_id'])
                for site_id in pi['site_ids']:
                    site_pis.setdefault(site_id, []).append(pi)

        # SFA user records keyed on their PLC pointer. Several records may
        # share the same pointer, so each value is a list of records.
        persons = defaultdict(list)
        pl_persons = {}
        if person_ids:
            table = self.SfaTable()
            for person in table.find({'type': 'user', 'pointer': person_ids}):
                persons[person['pointer']].append(person)
            pl_person_list = self.plshell.GetPersons(self.plauth, person_ids,
                                                     ['person_id', 'roles'])
            pl_persons = list_to_dict(pl_person_list, 'person_id')

        def hrns_for(person_id):
            # hrns of the SFA user records pointing at a PLC person
            return [person['hrn'] for person in persons.get(person_id, [])]

        # fill sfa info
        for record in records:
            # skip records with no pl info (top level authorities)
            if record['pointer'] == -1:
                continue
            sfa_info = {}
            rec_type = record['type']
            if rec_type == "slice":
                # all slice users are researchers
                record['PI'] = []
                record['researcher'] = []
                for person_id in record.get('person_ids', []):
                    record['researcher'].extend(hrns_for(person_id))

                # pis at the slice's site (a site may have no pi at all)
                for pi in site_pis.get(record.get('site_id'), []):
                    record['PI'].extend(hrns_for(pi['person_id']))
                record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
                record['geni_creator'] = record['PI']

            elif rec_type == "authority":
                record['PI'] = []
                record['operator'] = []
                record['owner'] = []
                for pointer in record.get('person_ids', []):
                    if pointer not in persons or pointer not in pl_persons:
                        # this means there is no sfa or pl record for this user
                        continue
                    hrns = hrns_for(pointer)
                    roles = pl_persons[pointer]['roles']
                    if 'pi' in roles:
                        record['PI'].extend(hrns)
                    if 'tech' in roles:
                        record['operator'].extend(hrns)
                    if 'admin' in roles:
                        record['owner'].extend(hrns)
                    # xxx TODO: OrganizationName
            elif rec_type == "node":
                sfa_info['dns'] = record.get("hostname", "")
                # xxx TODO: URI, LatLong, IP, DNS

            elif rec_type == "user":
                sfa_info['email'] = record.get("email", "")
                sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
                sfa_info['geni_certificate'] = record['gid']
                # xxx TODO: PostalAddress, Phone
            record.update(sfa_info)

    def fill_record_info(self, records):
        """
        Given one SFA record or a list of them, fill in the PLC specific and
        SFA specific fields in place.
        """
        if not isinstance(records, list):
            records = [records]

        self.fill_record_pl_info(records)
        self.fill_record_sfa_info(records)

    def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
        """
        Synchronise PLC membership with the list of user hrns in
        record[listName].

        @param oldRecord: the record before the change, or None on Register
        @param record: the new record
        @param listName: field holding the member hrns (e.g. 'researcher')
        @param addFunc: PLC call (auth, person_id, container_id) adding a member
        @param delFunc: PLC call (auth, person_id, container_id) removing one
        """
        # get a list of the HRNs that are members of the old and new records
        if oldRecord:
            oldList = oldRecord.get(listName, [])
        else:
            oldList = []
        newList = record.get(listName, [])

        # if the lists are the same, then we don't have to update anything
        if oldList == newList:
            return

        # build a list of the new person ids, by looking up each person to get
        # their pointer
        table = self.SfaTable()
        newIdList = [rec['pointer'] for rec in table.find({'type': 'user', 'hrn': newList})]

        # build a list of the old person ids from the person_ids field
        if oldRecord:
            oldIdList = oldRecord.get("person_ids", [])
            containerId = oldRecord.get_pointer()
        else:
            # if oldRecord==None, then we are doing a Register, instead of an
            # update.
            oldIdList = []
            containerId = record.get_pointer()

        oldIds = set(oldIdList)
        newIds = set(newIdList)

        # add people who are in the new list, but not the old list
        for personId in newIdList:
            if personId not in oldIds:
                addFunc(self.plauth, personId, containerId)

        # remove people who are in the old list, but not the new list
        for personId in oldIdList:
            if personId not in newIds:
                delFunc(self.plauth, personId, containerId)

    def update_membership(self, oldRecord, record):
        """
        Push membership changes of a slice record to PLC (authorities are
        not handled yet).
        """
        if record.type == "slice":
            self.update_membership_list(oldRecord, record, 'researcher',
                                        self.plshell.AddPersonToSlice,
                                        self.plshell.DeletePersonFromSlice)
        elif record.type == "authority":
            # xxx TODO
            pass
585
586
587
class ComponentAPI(BaseAPI):
    """
    SFA API for the component manager running on a node: talks to the local
    NodeManager and obtains the node's key and credential from the registry.
    """

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8", methods='sfa.methods',
                 peer_cert=None, interface=None, key_file=None, cert_file=None):

        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, peer_cert=peer_cert,
                         interface=interface, key_file=key_file, cert_file=cert_file)
        self.encoding = encoding

        # Better just be documenting the API
        if config is None:
            return

        self.nodemanager = NodeManager(self.config)

    def sliver_exists(self, slicename):
        """
        Return True if slicename is one of the keys returned by the local
        NodeManager's GetXIDs(), False otherwise.
        """
        return slicename in self.nodemanager.GetXIDs()

    def get_registry(self, key_file=None, cert_file=None):
        """
        Return a proxy on the registry at SFA_REGISTRY_HOST:SFA_REGISTRY_PORT.

        @param key_file: private key used for the connection
                         (defaults to self.key_file)
        @param cert_file: certificate used for the connection
                          (defaults to self.cert_file)
        """
        if key_file is None:
            key_file = self.key_file
        if cert_file is None:
            cert_file = self.cert_file
        addr, port = self.config.SFA_REGISTRY_HOST, self.config.SFA_REGISTRY_PORT
        url = "http://%s:%s" % (addr, port)
        return xmlrpcprotocol.get_server(url, key_file, cert_file)

    def get_node_key(self):
        """
        Ask the registry to install this node's key.

        This call requires no authentication, so the connection is made with
        a throwaway self-signed keypair (the node may not have a key of its
        own yet); the temporary key/cert files are removed afterwards.
        """
        import tempfile
        subject = "component"
        (kfd, keyfile) = tempfile.mkstemp()
        (cfd, certfile) = tempfile.mkstemp()
        # mkstemp hands back open descriptors; the files are written by path
        os.close(kfd)
        os.close(cfd)
        try:
            key = Keypair(create=True)
            key.save_to_file(keyfile)
            cert = Certificate(subject=subject)
            cert.set_issuer(key=key, subject=subject)
            cert.set_pubkey(key)
            cert.sign()
            cert.save_to_file(certfile)
            registry = self.get_registry(key_file=keyfile, cert_file=certfile)
            # the registry will scp the key onto the node
            registry.get_key()
        finally:
            for tmpfile in (keyfile, certfile):
                if os.path.exists(tmpfile):
                    os.unlink(tmpfile)

    def getCredential(self):
        """
        Return this node's credential as a string.

        The credential is cached in SFA_DATA_DIR/node.cred. If that file
        cannot be read, the node key/gid are fetched first when missing
        (node.key / node.gid in the config directory), then a self
        credential for the node's hrn is requested from the registry and
        cached.
        """
        path = self.config.SFA_DATA_DIR
        config_dir = self.config.config_path
        cred_filename = path + os.sep + 'node.cred'
        try:
            credential = Credential(filename=cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            node_pkey_file = config_dir + os.sep + "node.key"
            node_gid_file = config_dir + os.sep + "node.gid"
            cert_filename = path + os.sep + 'server.cert'
            if not os.path.exists(node_pkey_file) or \
               not os.path.exists(node_gid_file):
                self.get_node_key()

            # get node's hrn
            gid = GID(filename=node_gid_file)
            hrn = gid.get_hrn()
            # get credential from registry
            cert_str = Certificate(filename=cert_filename).save_to_string(save_parents=True)
            registry = self.get_registry()
            cred = registry.get_self_credential(cert_str, 'node', hrn)
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)

            return cred

    def clean_key_cred(self):
        """
        Remove the existing keypair and credential from SFA_DATA_DIR and
        generate new ones.
        """
        keydir = self.config.SFA_DATA_DIR
        for name in ["server.key", "server.cert", "node.cred"]:
            filepath = keydir + os.sep + name
            if os.path.isfile(filepath):
                os.unlink(filepath)

        # install the new key pair
        # getCredential will take care of generating the new keypair
        # and credential
        self.get_node_key()
        self.getCredential()
677
678