sfadump more usable
[sfa.git] / sfa / plc / api.py
1 #
2 # SFA XML-RPC and SOAP interfaces
3 #
4
5 import sys
6 import os
7 import traceback
8 import string
9 import xmlrpclib
10
11 from sfa.util.faults import *
12 from sfa.util.api import *
13 from sfa.util.config import *
14 from sfa.util.sfalogging import sfa_logger
15 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
16 from sfa.trust.auth import Auth
17 from sfa.trust.rights import Right, Rights
18 from sfa.trust.credential import Credential,Keypair
19 from sfa.trust.certificate import Certificate
20 from sfa.util.namespace import get_authority, hrn_to_pl_slicename, hrn_to_pl_slicename, hrn_to_urn, slicename_to_hrn, hostname_to_hrn
21 from sfa.util.nodemanager import NodeManager
# Use collections.defaultdict when available (Python >= 2.5); otherwise fall
# back to the ActiveState recipe 523034 implementation below.
try:
    from collections import defaultdict
except ImportError:
    # Only a missing collections.defaultdict should trigger the fallback;
    # a bare except would also mask unrelated errors.
    class defaultdict(dict):
        """Minimal dict subclass that calls default_factory for missing keys."""

        def __init__(self, default_factory=None, *a, **kw):
            if (default_factory is not None and
                not hasattr(default_factory, '__call__')):
                raise TypeError('first argument must be callable')
            dict.__init__(self, *a, **kw)
            self.default_factory = default_factory

        def __getitem__(self, key):
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return self.__missing__(key)

        def __missing__(self, key):
            # store and return a fresh default, like collections.defaultdict
            if self.default_factory is None:
                raise KeyError(key)
            self[key] = value = self.default_factory()
            return value

        def __reduce__(self):
            if self.default_factory is None:
                args = tuple()
            else:
                args = self.default_factory,
            # pickle requires the dict-items element to be an iterator,
            # not a list
            return type(self), args, None, None, iter(self.items())

        def copy(self):
            return self.__copy__()

        def __copy__(self):
            return type(self)(self.default_factory, self)

        def __deepcopy__(self, memo):
            import copy
            # forward memo so shared/recursive references are preserved
            return type(self)(self.default_factory,
                              copy.deepcopy(self.items(), memo))

        def __repr__(self):
            return 'defaultdict(%s, %s)' % (self.default_factory,
                                            dict.__repr__(self))
## end of http://code.activestate.com/recipes/523034/ }}}
60
def list_to_dict(recs, key):
    """
    Convert a sequence of dictionaries into a dictionary keyed on the
    value each one holds for the specified key.

    @param recs: iterable of dictionaries, each containing `key`
                 (a one-shot iterator/generator is also accepted)
    @param key: name of the field whose value becomes the index
    @return: dict mapping rec[key] -> rec; when several records share
             the same value, the last one wins
    @raise KeyError: if a record lacks `key`
    """
    # single pass: the previous two-pass (keys list, then zip) version
    # returned {} for one-shot iterators
    return dict((rec[key], rec) for rec in recs)
68
class SfaAPI(BaseAPI):
    """
    SFA API object for an interface backed by a PlanetLab (myPLC) install.

    On top of BaseAPI this holds the interface's key, certificate and
    credential helpers, a handle on the PLC API (self.plshell/self.plauth,
    created only for the 'pl', 'vini' and 'eucalyptus' aggregate types),
    and the helpers that translate between SFA records and PLC records.
    """

    # flat list of method names
    import sfa.methods
    methods = sfa.methods.all

    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8",
                 methods='sfa.methods', peer_cert = None, interface = None,
                 key_file = None, cert_file = None, cache = None):
        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods,
                         peer_cert=peer_cert, interface=interface, key_file=key_file,
                         cert_file=cert_file, cache=cache)

        self.encoding = encoding
        from sfa.util.table import SfaTable
        self.SfaTable = SfaTable
        # Better just be documenting the API
        if config is None:
            return

        # Load configuration
        self.config = Config(config)
        self.auth = Auth(peer_cert)
        self.interface = interface
        self.key_file = key_file
        self.key = Keypair(filename=self.key_file)
        self.cert_file = cert_file
        self.cert = Certificate(filename=self.cert_file)
        self.credential = None
        # Initialize the PLC shell only if SFA wraps a myPLC
        rspec_type = self.config.get_aggregate_type()
        if rspec_type in ('pl', 'vini', 'eucalyptus'):
            self.plshell = self.getPLCShell()
            self.plshell_version = "4.3"

        self.hrn = self.config.SFA_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"
        self.logger = sfa_logger()

    def getPLCShell(self):
        """
        Return a handle on the PLC API and set self.plauth/self.plshell_type.

        Prefers an in-process PLC.Shell, imported from the directory that
        /usr/bin/plcsh resolves to ('direct'). If that fails, falls back to
        an XML-RPC proxy on SFA_PLC_URL ('xmlrpc').
        """
        self.plauth = {'Username': self.config.SFA_PLC_USER,
                       'AuthMethod': 'password',
                       'AuthString': self.config.SFA_PLC_PASSWORD}
        try:
            sys.path.append(os.path.dirname(os.path.realpath("/usr/bin/plcsh")))
            self.plshell_type = 'direct'
            import PLC.Shell
            shell = PLC.Shell.Shell(globals = globals())
        except Exception:
            # no usable local PLC (import or shell setup failed): go through
            # XML-RPC. Narrowed from a bare except so interpreter-exit
            # exceptions are not swallowed on Python >= 2.5.
            self.plshell_type = 'xmlrpc'
            url = self.config.SFA_PLC_URL
            shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)

        return shell

    def getCredential(self):
        """
        Return a valid credential for this interface.

        The registry builds its own credential locally; every other
        interface asks a registry for one.
        """
        if self.interface in ['registry']:
            return self.getCredentialFromLocalRegistry()
        else:
            return self.getCredentialFromRegistry()

    def getDelegatedCredential(self, creds):
        """
        Attempt to find a credential delegated to us in the specified
        creds (a single credential or a list of them).

        @return: the first credential whose caller is self.hrn, or None
        """
        # nothing to search (None or empty)
        if not creds:
            return None
        if not isinstance(creds, list):
            creds = [creds]
        # NOTE(review): filter_creds_by_caller is not imported explicitly in
        # this module; presumably it comes from a wildcard import -- confirm.
        delegated_creds = filter_creds_by_caller(creds, self.hrn)
        if not delegated_creds:
            return None
        return delegated_creds[0]

    def getCredentialFromRegistry(self):
        """
        Get our credential from a remote registry.

        A copy cached in SFA_DATA_DIR as
        "<interface>.<hrn>.authority.cred" is used when present. Otherwise
        the credential is requested from the registry for self.hrn, then
        saved to that file.

        @return: the credential as a string, including its parents
        """
        cred_type = 'authority'
        path = self.config.SFA_DATA_DIR
        filename = ".".join([self.interface, self.hrn, cred_type, "cred"])
        cred_filename = path + os.sep + filename
        try:
            credential = Credential(filename = cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            from sfa.server.registry import Registries
            registries = Registries(self)
            registry = registries[self.hrn]
            cert_string = self.cert.save_to_string(save_parents=True)
            # get self credential
            self_cred = registry.GetSelfCredential(cert_string, self.hrn, cred_type)
            # get credential
            cred = registry.GetCredential(self_cred, self.hrn, cred_type)

            # save cred to file
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)
            return cred

    def getCredentialFromLocalRegistry(self):
        """
        Build and sign our current credential directly from the local
        registry database.

        @return: the signed credential as a string, including its parents
        @raise RecordNotFound: if the registry has no record for self.hrn
        """
        hrn = self.hrn
        auth_hrn = self.auth.get_authority(hrn)

        # is this a root or sub authority
        if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN:
            auth_hrn = hrn
        auth_info = self.auth.get_auth_info(auth_hrn)
        table = self.SfaTable()
        records = table.findObjects(hrn)
        if not records:
            # report which hrn was missing instead of raising a bare class
            raise RecordNotFound(hrn)
        record = records[0]
        record_type = record['type']
        object_gid = record.get_gid_object()
        new_cred = Credential(subject = object_gid.get_subject())
        new_cred.set_gid_caller(object_gid)
        new_cred.set_gid_object(object_gid)
        new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename())

        # NOTE(review): determine_rights is not imported explicitly in this
        # module; presumably it comes from a wildcard import -- confirm.
        r1 = determine_rights(record_type, hrn)
        new_cred.set_privileges(r1)

        auth_kind = "authority,ma,sa"

        new_cred.set_parent(self.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind))

        new_cred.encode()
        new_cred.sign()

        return new_cred.save_to_string(save_parents=True)

    def loadCredential (self):
        """
        Attempt to load our credential from file. If the file does not
        exist, get the credential from the registry instead. The result is
        stored in self.credential.
        """
        # XX This is really the aggregate's credential. Using this is easier
        # than getting the registry's credential from itself (ssl errors).
        # NOTE(review): unlike getCredentialFromRegistry, no '.' separates
        # interface and hrn in this filename. Also, the file branch stores a
        # Credential object while the fallback stores a string -- confirm
        # both are intended.
        ma_cred_filename = self.config.SFA_DATA_DIR + os.sep + self.interface + self.hrn + ".ma.cred"
        try:
            self.credential = Credential(filename = ma_cred_filename)
        except IOError:
            self.credential = self.getCredentialFromRegistry()

    def sfa_fields_to_pl_fields(self, type, hrn, record):
        """
        Convert SFA fields to PLC fields, for use when registering or
        updating a registry record in the PLC database.

        @param type: type of record ("slice", "node", "authority", ...)
        @param hrn: human readable name
        @param record: dictionary of SFA fields
        @return: dictionary of PLC fields (empty for other types)
        @raise MissingSfaInfo: for a node record without a hostname
        """
        # The PLC record is built from scratch, so every default below
        # applies unconditionally.
        pl_record = {}

        if type == "slice":
            pl_record["instantiation"] = "plc-instantiated"
            pl_record["name"] = hrn_to_pl_slicename(hrn)
            if "url" in record:
                pl_record["url"] = record["url"]
            if "description" in record:
                pl_record["description"] = record["description"]
            if "expires" in record:
                pl_record["expires"] = int(record["expires"])

        elif type == "node":
            if not "hostname" in record:
                raise MissingSfaInfo("hostname")
            pl_record["hostname"] = record["hostname"]
            pl_record["model"] = "geni"

        elif type == "authority":
            # NOTE(review): hrn_to_pl_login_base is not in this module's
            # explicit imports; presumably provided by a wildcard import.
            pl_record["login_base"] = hrn_to_pl_login_base(hrn)
            pl_record["name"] = hrn
            pl_record["abbreviated_name"] = hrn
            pl_record["enabled"] = True
            pl_record["is_public"] = True

        return pl_record

    def fill_record_pl_info(self, records):
        """
        Fill in the planetlab specific fields of SFA records.

        Fetches the PLC objects the records point to (nodes, sites, slices,
        persons and the persons' keys) and merges their fields into each
        record. Users also get a 'keys' list of public keys. Records with
        pointer == -1 are skipped. Finally, PLC ids are converted to hrns
        via fill_record_hrns.

        @param records: list of records (in/out param)
        @return: the same list of records
        """
        # group PLC ids (record pointers) by record type
        type_map = {'node': [], 'authority': [], 'slice': [], 'user': []}
        for record in records:
            if record['type'] in type_map:
                type_map[record['type']].append(record['pointer'])
        node_ids = type_map['node']
        site_ids = type_map['authority']
        slice_ids = type_map['slice']
        person_ids = type_map['user']
        key_ids = []

        # get pl records, each keyed on its PLC id
        nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids)
            nodes = list_to_dict(node_list, 'node_id')
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids)
            sites = list_to_dict(site_list, 'site_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids)
            slices = list_to_dict(slice_list, 'slice_id')
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids)
            persons = list_to_dict(person_list, 'person_id')
            for person in persons.values():
                key_ids.extend(person['key_ids'])

        pl_records = {'node': nodes, 'authority': sites,
                      'slice': slices, 'user': persons}

        if key_ids:
            key_list = self.plshell.GetKeys(self.plauth, key_ids)
            keys = list_to_dict(key_list, 'key_id')

        # fill record info
        for record in records:
            # records with pointer==-1 do not have plc info.
            # for example, the top level authority records which are
            # authorities, but not PL "sites"
            if record['pointer'] == -1:
                continue

            pl_by_pointer = pl_records.get(record['type'], {})
            if record['pointer'] in pl_by_pointer:
                record.update(pl_by_pointer[record['pointer']])

            # fill in key info; a user missing from PLC has no 'key_ids'
            if record['type'] == 'user':
                record['keys'] = [keys[key_id]['key']
                                  for key_id in record.get('key_ids', [])
                                  if key_id in keys]

        # fill in record hrns
        records = self.fill_record_hrns(records)

        return records

    def fill_record_hrns(self, records):
        """
        Convert the PLC ids in records to hrns.

        Sets 'site' (from site_id), 'persons', 'slices', 'nodes' and
        'sites' (from the corresponding *_ids fields). The hrns are built
        under self.hrn and, for persons and nodes, the record's site
        login_base. Records with pointer == -1 are skipped.

        @param records: list of records (in/out param)
        @return: the same list of records
        """
        # get ids
        slice_ids, person_ids, site_ids, node_ids = [], [], [], []
        for record in records:
            if 'site_id' in record:
                site_ids.append(record['site_id'])
            # test the record, not the records list (previously this never
            # matched, so 'sites' was always empty)
            if 'site_ids' in record:
                site_ids.extend(record['site_ids'])
            if 'person_ids' in record:
                person_ids.extend(record['person_ids'])
            if 'slice_ids' in record:
                slice_ids.extend(record['slice_ids'])
            if 'node_ids' in record:
                node_ids.extend(record['node_ids'])

        # get pl records
        slices, persons, sites, nodes = {}, {}, {}, {}
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids, ['site_id', 'login_base'])
            sites = list_to_dict(site_list, 'site_id')
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'email'])
            persons = list_to_dict(person_list, 'person_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids, ['slice_id', 'name'])
            slices = list_to_dict(slice_list, 'slice_id')
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids, ['node_id', 'hostname'])
            nodes = list_to_dict(node_list, 'node_id')

        # convert ids to hrns
        for record in records:
            # get all relevant data
            pointer = record['pointer']
            auth_hrn = self.hrn
            login_base = ''
            if pointer == -1:
                continue

            if 'site_id' in record:
                site = sites[record['site_id']]
                login_base = site['login_base']
                record['site'] = ".".join([auth_hrn, login_base])
            if 'person_ids' in record:
                emails = [persons[person_id]['email'] for person_id in record['person_ids']
                          if person_id in persons]
                usernames = [email.split('@')[0] for email in emails]
                person_hrns = [".".join([auth_hrn, login_base, username]) for username in usernames]
                record['persons'] = person_hrns
            if 'slice_ids' in record:
                slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids']
                              if slice_id in slices]
                slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
                record['slices'] = slice_hrns
            if 'node_ids' in record:
                hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids']
                             if node_id in nodes]
                node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
                record['nodes'] = node_hrns
            if 'site_ids' in record:
                login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids']
                               if site_id in sites]
                site_hrns = [".".join([auth_hrn, lbase]) for lbase in login_bases]
                record['sites'] = site_hrns

        return records

    def fill_record_sfa_info(self, records):
        """
        Fill in the SFA specific fields of records already filled with
        PLC info.

        - slices: 'researcher' (hrns of slice members), 'PI' (hrns of the
          PIs at the slice's site), 'geni_urn' and 'geni_creator'.
        - authorities: 'PI'/'operator'/'owner' from PLC roles
          'pi'/'tech'/'admin'.
        - nodes: 'dns'.
        - users: 'email', 'geni_urn', 'geni_certificate'.

        Records with pointer == -1 are skipped.

        @param records: list of records (in/out param)
        """
        # get person and site ids
        person_ids = []
        site_ids = []
        for record in records:
            person_ids.extend(record.get("person_ids", []))
            site_ids.extend(record.get("site_ids", []))
            if 'site_id' in record:
                site_ids.append(record['site_id'])

        # get all pis from the sites we've encountered
        # and store them in a dictionary keyed on site_id
        site_pis = {}
        if site_ids:
            pi_filter = {'|roles': ['pi'], '|site_ids': site_ids}
            pi_list = self.plshell.GetPersons(self.plauth, pi_filter, ['person_id', 'site_ids'])
            for pi in pi_list:
                # we will need the pi's hrns also
                person_ids.append(pi['person_id'])

                # we also need to keep track of the sites these pis
                # belong to
                for site_id in pi['site_ids']:
                    site_pis.setdefault(site_id, []).append(pi)

        # get sfa records for all persons associated with these records;
        # we'll replace pl ids (person_ids) with hrns from them.
        table = self.SfaTable()
        person_list = table.find({'type': 'user', 'pointer': person_ids})
        # sfa records keyed on their pointer. Multiple records may share
        # a pointer, so each value is a list of records.
        persons = defaultdict(list)
        for person in person_list:
            persons[person['pointer']].append(person)

        # get the pl records (roles); skip the RPC when there is nobody
        pl_persons = {}
        if person_ids:
            pl_person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'roles'])
            pl_persons = list_to_dict(pl_person_list, 'person_id')

        # fill sfa info
        for record in records:
            # skip records with no pl info (top level authorities)
            if record['pointer'] == -1:
                continue
            sfa_info = {}
            record_type = record['type']
            if record_type == "slice":
                # all slice users are researchers
                record['PI'] = []
                record['researcher'] = []
                for person_id in record['person_ids']:
                    hrns = [person['hrn'] for person in persons[person_id]]
                    record['researcher'].extend(hrns)

                # pis at the slice's site (a site may have no pi at all)
                pl_pis = site_pis.get(record['site_id'], [])
                for pi in pl_pis:
                    hrns = [person['hrn'] for person in persons[pi['person_id']]]
                    record['PI'].extend(hrns)
                record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice')
                record['geni_creator'] = record['PI']

            elif record_type == "authority":
                record['PI'] = []
                record['operator'] = []
                record['owner'] = []
                for pointer in record['person_ids']:
                    if pointer not in persons or pointer not in pl_persons:
                        # this means there is not sfa or pl record for this user
                        continue
                    hrns = [person['hrn'] for person in persons[pointer]]
                    roles = pl_persons[pointer]['roles']
                    if 'pi' in roles:
                        record['PI'].extend(hrns)
                    if 'tech' in roles:
                        record['operator'].extend(hrns)
                    if 'admin' in roles:
                        record['owner'].extend(hrns)
                    # xxx TODO: OrganizationName
            elif record_type == "node":
                sfa_info['dns'] = record.get("hostname", "")
                # xxx TODO: URI, LatLong, IP, DNS

            elif record_type == "user":
                sfa_info['email'] = record.get("email", "")
                sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user')
                sfa_info['geni_certificate'] = record['gid']
                # xxx TODO: PostalAddress, Phone
            record.update(sfa_info)

    def fill_record_info(self, records):
        """
        Given SFA records (a single record or a list of them), fill in the
        PLC specific and then the SFA specific fields, in place.
        """
        if not isinstance(records, list):
            records = [records]

        self.fill_record_pl_info(records)
        self.fill_record_sfa_info(records)

    def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
        """
        Synchronize PLC membership of a container with an SFA hrn list.

        Persons named in record[listName] but not in the old record's
        person_ids are added with addFunc. Persons in the old person_ids
        but not in the new list are removed with delFunc. Both are called
        as func(self.plauth, person_id, container_id).

        @param oldRecord: the existing record, or None when registering
        @param record: the new record
        @param listName: field holding member hrns (e.g. 'researcher')
        """
        # get a list of the HRNs that are members of the old and new records
        if oldRecord:
            oldList = oldRecord.get(listName, [])
        else:
            oldList = []
        newList = record.get(listName, [])

        # if the lists are the same, then we don't have to update anything
        if oldList == newList:
            return

        # build a list of the new person ids, by looking up each person to get
        # their pointer
        table = self.SfaTable()
        records = table.find({'type': 'user', 'hrn': newList})
        newIdList = [rec['pointer'] for rec in records]

        # build a list of the old person ids from the person_ids field
        if oldRecord:
            oldIdList = oldRecord.get("person_ids", [])
            containerId = oldRecord.get_pointer()
        else:
            # if oldRecord==None, then we are doing a Register, instead of an
            # update.
            oldIdList = []
            containerId = record.get_pointer()

        # add people who are in the new list, but not the old list
        for personId in newIdList:
            if personId not in oldIdList:
                addFunc(self.plauth, personId, containerId)

        # remove people who are in the old list, but not the new list
        for personId in oldIdList:
            if personId not in newIdList:
                delFunc(self.plauth, personId, containerId)

    def update_membership(self, oldRecord, record):
        """
        Push membership changes of record to PLC.

        Slices sync their 'researcher' list with PLC slice membership.
        Authorities are not handled yet.
        """
        if record.type == "slice":
            self.update_membership_list(oldRecord, record, 'researcher',
                                        self.plshell.AddPersonToSlice,
                                        self.plshell.DeletePersonFromSlice)
        elif record.type == "authority":
            # xxx TODO
            pass
582
583
584
class ComponentAPI(BaseAPI):
    """
    API object for an SFA component (node).

    It talks to the local NodeManager and to the SFA registry configured
    by SFA_REGISTRY_HOST/SFA_REGISTRY_PORT.
    """

    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", methods='sfa.methods',
                 peer_cert = None, interface = None, key_file = None, cert_file = None):

        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, peer_cert=peer_cert,
                         interface=interface, key_file=key_file, cert_file=cert_file)
        self.encoding = encoding

        # Better just be documenting the API
        if config is None:
            return

        self.nodemanager = NodeManager(self.config)

    def sliver_exists(self, slicename=None):
        """
        Return True if NodeManager's GetXIDs() result has an entry for
        slicename, False otherwise.

        `slicename` was previously an undefined free variable, so every
        call failed with NameError. It is now an explicit, optional
        parameter (None never matches).
        """
        sliver_dict = self.nodemanager.GetXIDs()
        return slicename in sliver_dict.keys()

    def get_registry(self, key_file=None, cert_file=None):
        """
        Return a server proxy for the configured SFA registry.

        @param key_file: private key used for the connection
                         (default: self.key_file)
        @param cert_file: certificate used for the connection
                          (default: self.cert_file)
        """
        if key_file is None:
            key_file = self.key_file
        if cert_file is None:
            cert_file = self.cert_file
        addr, port = self.config.SFA_REGISTRY_HOST, self.config.SFA_REGISTRY_PORT
        url = "http://%(addr)s:%(port)s" % locals()
        server = xmlrpcprotocol.get_server(url, key_file, cert_file)
        return server

    def get_node_key(self):
        """
        Ask the registry to install this node's key (the registry will scp
        the key onto the node).

        This call requires no authentication, so the connection uses a
        throwaway keypair and self-signed certificate. They are written to
        temporary files, which are removed afterwards.
        """
        import tempfile
        subject = "component"
        (kfd, keyfile) = tempfile.mkstemp()
        (cfd, certfile) = tempfile.mkstemp()
        # only the paths are needed; don't leak the descriptors mkstemp opened
        os.close(kfd)
        os.close(cfd)
        try:
            key = Keypair(create=True)
            key.save_to_file(keyfile)
            cert = Certificate(subject=subject)
            cert.set_issuer(key=key, subject=subject)
            cert.set_pubkey(key)
            cert.sign()
            cert.save_to_file(certfile)
            # connect with the throwaway key/cert generated above
            registry = self.get_registry(keyfile, certfile)
            # the registry will scp the key onto the node
            registry.get_key()
        finally:
            for tmp_path in (keyfile, certfile):
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)

    def getCredential(self):
        """
        Get our credential from a remote registry.

        A credential cached in SFA_DATA_DIR/node.cred is used when present.
        Otherwise one is requested from the registry with GetSelfCredential
        for the node's hrn (read from config_path/node.gid, fetching the
        node key first if needed), then cached in that file.

        @return: the credential as a string
        """
        path = self.config.SFA_DATA_DIR
        config_dir = self.config.config_path
        cred_filename = path + os.sep + 'node.cred'
        try:
            credential = Credential(filename = cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            # GID was used here without being imported
            from sfa.trust.gid import GID
            node_pkey_file = config_dir + os.sep + "node.key"
            node_gid_file = config_dir + os.sep + "node.gid"
            cert_filename = path + os.sep + 'server.cert'
            if not os.path.exists(node_pkey_file) or \
               not os.path.exists(node_gid_file):
                self.get_node_key()

            # get node's hrn
            gid = GID(filename=node_gid_file)
            hrn = gid.get_hrn()
            # get credential from registry
            cert_str = Certificate(filename=cert_filename).save_to_string(save_parents=True)
            registry = self.get_registry()
            cred = registry.GetSelfCredential(cert_str, hrn, 'node')
            # cache it where the try branch above looks for it
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)

            return cred

    def clean_key_cred(self):
        """
        Remove the existing keypair and cred and generate new ones.
        """
        # KEYDIR was undefined. getCredential reads node.cred and
        # server.cert from SFA_DATA_DIR, so clean the same directory.
        key_dir = self.config.SFA_DATA_DIR
        files = ["server.key", "server.cert", "node.cred"]
        for f in files:
            filepath = key_dir + os.sep + f
            if os.path.isfile(filepath):
                # unlink the full path, not the bare file name
                os.unlink(filepath)

        # install the new key pair
        # GetCredential will take care of generating the new keypair
        # and credential
        self.get_node_key()
        self.getCredential()
674
675