Return/raise SfaAPIError with a stack trace if we fail to connect to PLCAPI
[sfa.git] / sfa / plc / api.py
1 #
2 # SFA XML-RPC and SOAP interfaces
3 #
4 ### $Id$
5 ### $URL$
6 #
7
8 import sys
9 import os
10 import traceback
11 import string
12 import xmlrpclib
13 from sfa.trust.auth import Auth
14 from sfa.util.config import *
15 from sfa.util.faults import *
16 from sfa.util.debug import *
17 from sfa.trust.rights import *
18 from sfa.trust.credential import *
19 from sfa.trust.certificate import *
20 from sfa.util.namespace import *
21 from sfa.util.api import *
22 from sfa.util.nodemanager import NodeManager
23 from sfa.util.sfalogging import *
24
def list_to_dict(recs, key):
    """
    Index a sequence of dict-like records by one of their fields.

    @param recs: iterable of records supporting ``rec[key]``
    @param key: name of the field whose value becomes the dict key
    @return: dict mapping ``rec[key]`` -> ``rec``; when two records share a
        key value, the later record wins.
    """
    return dict((rec[key], rec) for rec in recs)
28
def handle_exception(method):
    """
    Decorator: re-raise any exception escaping ``method`` as SfaAPIError.

    The SfaAPIError is built from the formatted traceback of the original
    exception, so the caller sees where the failure happened (for example a
    failed connection to PLCAPI).

    KeyboardInterrupt and SystemExit are not converted. They are not
    subclasses of Exception and should keep propagating unchanged.
    """
    import functools

    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        try:
            return method(*args, **kwargs)
        except Exception:
            raise SfaAPIError(traceback.format_exc())
    return wrapper
34
class SfaAPI(BaseAPI):
    """
    SFA API front end backed by a PlanetLab (myPLC) installation.

    On top of the BaseAPI plumbing, this class:
      - loads the SFA configuration, keypair and certificate;
      - connects to PLCAPI (``self.plshell``) when the aggregate type is
        'pl' or 'vini';
      - obtains this interface's own credential;
      - converts between SFA records and PLC records.
    """

    # flat list of method names
    import sfa.methods
    methods = sfa.methods.all

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8", methods='sfa.methods',
                 peer_cert=None, interface=None, key_file=None, cert_file=None):
        """
        @param config: path of the SFA config file. None means "only
            documenting the API": setup stops right after SfaTable is bound.
        @param encoding: encoding stored on ``self.encoding``
        @param methods: module path of the API methods, handed to BaseAPI
        @param peer_cert: certificate of the peer; used to build ``self.auth``
        @param interface: interface name (e.g. 'registry'); selects how
            getCredential() obtains the credential
        @param key_file: private key file, loaded into ``self.key``
        @param cert_file: certificate file, loaded into ``self.cert``
        """
        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods,
                         peer_cert=peer_cert, interface=interface, key_file=key_file,
                         cert_file=cert_file)

        self.encoding = encoding

        from sfa.util.table import SfaTable
        self.SfaTable = SfaTable
        # Better just be documenting the API
        if config is None:
            return

        # Load configuration
        self.config = Config(config)
        self.auth = Auth(peer_cert)
        self.interface = interface
        self.key_file = key_file
        self.key = Keypair(filename=self.key_file)
        self.cert_file = cert_file
        self.cert = Certificate(filename=self.cert_file)
        self.credential = None
        # Initialize the PLC shell only if SFA wraps a myPLC
        rspec_type = self.config.get_aggregate_type()
        if rspec_type in ('pl', 'vini'):
            self.plshell = self.getPLCShell()
            self.plshell_version = self.getPLCShellVersion()

        self.hrn = self.config.SFA_INTERFACE_HRN
        self.time_format = "%Y-%m-%d %H:%M:%S"
        self.logger = get_sfa_logger()

    @handle_exception
    def getPLCShell(self):
        """
        Connect to PLCAPI and return a shell object.

        First tries the in-process ``PLC.Shell`` (importable when plcsh is
        installed on this host, under the directory of /usr/bin/plcsh). If
        that import fails, falls back to an XML-RPC proxy on SFA_PLC_URL.
        Either way the configured PLC credentials are checked with AuthCheck.

        Side effects: sets ``self.plauth`` and ``self.plshell_type``
        ('direct' or 'xmlrpc').
        Raises SfaAPIError (via handle_exception) on any failure.
        """
        self.plauth = {'Username': self.config.SFA_PLC_USER,
                       'AuthMethod': 'password',
                       'AuthString': self.config.SFA_PLC_PASSWORD}
        plcsh_dir = os.path.dirname(os.path.realpath("/usr/bin/plcsh"))
        # Don't grow sys.path with duplicates every time a shell is created.
        if plcsh_dir not in sys.path:
            sys.path.append(plcsh_dir)
        try:
            self.plshell_type = 'direct'
            import PLC.Shell
            shell = PLC.Shell.Shell(globals=globals())
            shell.AuthCheck(self.plauth)
            return shell
        except ImportError:
            self.plshell_type = 'xmlrpc'
            # connect via xmlrpc
            url = self.config.SFA_PLC_URL
            shell = xmlrpclib.Server(url, verbose=0, allow_none=True)
            shell.AuthCheck(self.plauth)
            return shell

    @handle_exception
    def getPLCShellVersion(self):
        """
        Guess the PLCAPI version behind ``self.plshell``.

        Some later calls differ depending on the API version. GetTagTypes only
        exists in PLCAPI 4.3, so if calling it succeeds we report '4.3';
        if it fails for any reason we assume '4.2'. This is a heuristic, not
        a real version query.
        """
        try:
            self.plshell.GetTagTypes(self.plauth)
        except Exception:
            return '4.2'
        return '4.3'

    def getCredential(self):
        """
        Return this interface's credential as a string.

        A registry mints it locally; every other interface asks a registry.
        """
        if self.interface in ['registry']:
            return self.getCredentialFromLocalRegistry()
        return self.getCredentialFromRegistry()

    def getCredentialFromRegistry(self):
        """
        Get our credential from a remote registry.

        The credential is cached in
        SFA_DATA_DIR/<interface>.<hrn>.authority.cred. If that file cannot be
        read (IOError), the credential is fetched from the registry for
        ``self.hrn`` (self credential first, then the authority credential),
        saved to the cache file, and returned.

        @return: the credential serialized as a string (with parents)
        """
        type = 'authority'
        path = self.config.SFA_DATA_DIR
        filename = ".".join([self.interface, self.hrn, type, "cred"])
        cred_filename = path + os.sep + filename
        try:
            credential = Credential(filename=cred_filename)
            return credential.save_to_string(save_parents=True)
        except IOError:
            from sfa.server.registry import Registries
            registries = Registries(self)
            registry = registries[self.hrn]
            cert_string = self.cert.save_to_string(save_parents=True)
            # get self credential
            self_cred = registry.get_self_credential(cert_string, type, self.hrn)
            # get credential
            cred = registry.get_credential(self_cred, type, self.hrn)

            # save cred to file
            Credential(string=cred).save_to_file(cred_filename, save_parents=True)
            return cred

    def getCredentialFromLocalRegistry(self):
        """
        Build and sign our current credential directly from the local registry.

        Looks up the record for ``self.hrn`` in the SFA table. Issues a
        credential whose caller and object are that record's GID. It is signed
        with the authority's private key and parented to the authority
        credential.

        @return: the credential serialized as a string (with parents)
        @raise RecordNotFound: if ``self.hrn`` has no record in the table
        """
        hrn = self.hrn
        auth_hrn = self.auth.get_authority(hrn)

        # is this a root or sub authority
        if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN:
            auth_hrn = hrn
        auth_info = self.auth.get_auth_info(auth_hrn)
        table = self.SfaTable()
        records = table.findObjects(hrn)
        if not records:
            raise RecordNotFound
        record = records[0]
        type = record['type']
        object_gid = record.get_gid_object()
        new_cred = Credential(subject=object_gid.get_subject())
        new_cred.set_gid_caller(object_gid)
        new_cred.set_gid_object(object_gid)
        new_cred.set_issuer(key=auth_info.get_pkey_object(), subject=auth_hrn)
        new_cred.set_pubkey(object_gid.get_pubkey())
        r1 = determine_rights(type, hrn)
        new_cred.set_privileges(r1)

        auth_kind = "authority,ma,sa"

        new_cred.set_parent(self.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind))

        new_cred.encode()
        new_cred.sign()

        return new_cred.save_to_string(save_parents=True)

    def loadCredential(self):
        """
        Set ``self.credential``: load it from file if present, otherwise get
        it from the registry.

        NOTE(review): on success this stores a Credential object, but on the
        IOError path it stores the string returned by
        getCredentialFromRegistry(). Callers may see either type.
        """
        # XX This is really the aggregate's credential. Using this is easier than getting
        # the registry's credential from itself (ssl errors).
        # NOTE(review): there is no separator between interface and hrn in this
        # file name; kept as-is so that existing credential files are still found.
        ma_cred_filename = self.config.SFA_DATA_DIR + os.sep + self.interface + self.hrn + ".ma.cred"
        try:
            self.credential = Credential(filename=ma_cred_filename)
        except IOError:
            self.credential = self.getCredentialFromRegistry()

    def sfa_fields_to_pl_fields(self, type, hrn, record):
        """
        Convert SFA fields to PLC fields for registering or updating a
        record in the PLC database.

        @param type: type of record ('slice', 'node', 'authority', ...)
        @param hrn: human readable name of the record
        @param record: dictionary of SFA fields
        @return: a new dictionary of PLC fields. Types other than slice, node
            and authority produce an empty dict.
        @raise MissingSfaInfo: if a node record has no 'hostname'
        """
        pl_record = {}

        if type == "slice":
            pl_record["instantiation"] = "plc-instantiated"
            pl_record["name"] = hrn_to_pl_slicename(hrn)
            if "url" in record:
                pl_record["url"] = record["url"]
            if "description" in record:
                pl_record["description"] = record["description"]
            if "expires" in record:
                pl_record["expires"] = int(record["expires"])

        elif type == "node":
            if "hostname" not in record:
                raise MissingSfaInfo("hostname")
            pl_record["hostname"] = record["hostname"]
            pl_record["model"] = "geni"

        elif type == "authority":
            pl_record["login_base"] = hrn_to_pl_login_base(hrn)
            pl_record["name"] = hrn
            pl_record["abbreviated_name"] = hrn
            pl_record["enabled"] = True
            pl_record["is_public"] = True

        return pl_record

    def fill_record_pl_info(self, records):
        """
        Fill in the PlanetLab-specific fields of SFA records, in place.

        Fetches the PLC node/site/slice/person records referenced by each
        record's 'pointer' and merges their fields into the record. User
        records also get a 'keys' list of public keys. Finally the PLC ids are
        translated to hrns via fill_record_hrns().

        @param records: list of SFA records (in/out)
        @return: the same records
        """
        # get ids by type
        node_ids, site_ids, slice_ids = [], [], []
        person_ids, key_ids = [], []
        type_map = {'node': node_ids, 'authority': site_ids,
                    'slice': slice_ids, 'user': person_ids}

        for record in records:
            ids = type_map.get(record['type'])
            if ids is not None:
                ids.append(record['pointer'])

        # get pl records
        nodes, sites, slices, persons, keys = {}, {}, {}, {}, {}
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids)
            nodes = list_to_dict(node_list, 'node_id')
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids)
            sites = list_to_dict(site_list, 'site_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids)
            slices = list_to_dict(slice_list, 'slice_id')
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids)
            persons = list_to_dict(person_list, 'person_id')
            for person in persons.values():
                key_ids.extend(person['key_ids'])

        pl_records = {'node': nodes, 'authority': sites,
                      'slice': slices, 'user': persons}

        if key_ids:
            key_list = self.plshell.GetKeys(self.plauth, key_ids)
            keys = list_to_dict(key_list, 'key_id')

        # fill record info
        for record in records:
            # records with pointer==-1 do not have plc info associated with them.
            # for example, the top level authority records which are
            # authorities, but not PL "sites"
            if record['pointer'] == -1:
                continue

            pl_by_id = pl_records.get(record['type'], {})
            if record['pointer'] in pl_by_id:
                record.update(pl_by_id[record['pointer']])
            # fill in key info; a user missing from PLC has no 'key_ids'
            if record['type'] == 'user':
                record['keys'] = [keys[key_id]['key'] for key_id in record.get('key_ids', [])
                                  if key_id in keys]

        # fill in record hrns
        records = self.fill_record_hrns(records)

        return records

    def fill_record_hrns(self, records):
        """
        Convert the PLC ids held by each record into hrns, in place.

        Uses site_id, site_ids, person_ids, slice_ids and node_ids to set
        'site', 'sites', 'persons', 'slices' and 'nodes'. Records with
        pointer == -1 are skipped.

        @param records: list of records already merged with PLC data
        @return: the same records
        """
        # get ids
        slice_ids, person_ids, site_ids, node_ids = [], [], [], []
        for record in records:
            if 'site_id' in record:
                site_ids.append(record['site_id'])
            if 'site_ids' in record:
                site_ids.extend(record['site_ids'])
            if 'person_ids' in record:
                person_ids.extend(record['person_ids'])
            if 'slice_ids' in record:
                slice_ids.extend(record['slice_ids'])
            if 'node_ids' in record:
                node_ids.extend(record['node_ids'])

        # get pl records
        slices, persons, sites, nodes = {}, {}, {}, {}
        if site_ids:
            site_list = self.plshell.GetSites(self.plauth, site_ids, ['site_id', 'login_base'])
            sites = list_to_dict(site_list, 'site_id')
        if person_ids:
            person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'email'])
            persons = list_to_dict(person_list, 'person_id')
        if slice_ids:
            slice_list = self.plshell.GetSlices(self.plauth, slice_ids, ['slice_id', 'name'])
            slices = list_to_dict(slice_list, 'slice_id')
        if node_ids:
            node_list = self.plshell.GetNodes(self.plauth, node_ids, ['node_id', 'hostname'])
            nodes = list_to_dict(node_list, 'node_id')

        # convert ids to hrns
        for record in records:
            if record['pointer'] == -1:
                continue
            auth_hrn = self.hrn
            login_base = ''

            if 'site_id' in record:
                site = sites[record['site_id']]
                login_base = site['login_base']
                record['site'] = ".".join([auth_hrn, login_base])
            if 'person_ids' in record:
                emails = [persons[person_id]['email'] for person_id in record['person_ids']
                          if person_id in persons]
                usernames = [email.split('@')[0] for email in emails]
                record['persons'] = [".".join([auth_hrn, login_base, username])
                                     for username in usernames]
            if 'slice_ids' in record:
                slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids']
                              if slice_id in slices]
                record['slices'] = [slicename_to_hrn(auth_hrn, slicename)
                                    for slicename in slicenames]
            if 'node_ids' in record:
                hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids']
                             if node_id in nodes]
                record['nodes'] = [hostname_to_hrn(auth_hrn, login_base, hostname)
                                   for hostname in hostnames]
            if 'site_ids' in record:
                login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids']
                               if site_id in sites]
                record['sites'] = [".".join([auth_hrn, lbase]) for lbase in login_bases]

        return records

    def fill_record_sfa_info(self, records):
        """
        Fill in SFA-specific fields on each record, in place.

        - slice: 'researcher' = hrns of member users known to SFA
        - authority: 'PI' / 'operator' / 'owner' = hrns of members with the
          PLC role 'pi' / 'tech' / 'admin' (always set, possibly empty)
        - node: 'dns' = hostname
        - user: 'email'
        Records with pointer == -1 (top level authorities) are skipped.
        """
        # get person ids
        has_authority = False
        person_ids = []
        for record in records:
            if record['type'] == 'authority':
                has_authority = True
            person_ids.extend(record.get("person_ids", []))

        # SFA user records, keyed by their PLC person_id ('pointer')
        table = self.SfaTable()
        person_list = table.find({'type': 'user', 'pointer': person_ids})
        persons = list_to_dict(person_list, 'pointer')
        # PLC roles are only needed to classify authority members
        pl_persons = {}
        if has_authority:
            pl_person_list = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'roles'])
            pl_persons = list_to_dict(pl_person_list, 'person_id')

        # fill sfa info
        for record in records:
            # skip records with no pl info (top level authorities)
            if record['pointer'] == -1:
                continue
            sfa_info = {}
            record_type = record['type']
            if record_type == "slice":
                sfa_info['researcher'] = [persons[person_id]['hrn']
                                          for person_id in record.get('person_ids', [])
                                          if person_id in persons]

            elif record_type == "authority":
                pis, techs, admins = [], [], []
                for pointer in record.get('person_ids', []):
                    if pointer not in persons or pointer not in pl_persons:
                        # this means there is no sfa or pl record for this user
                        continue
                    hrn = persons[pointer]['hrn']
                    roles = pl_persons[pointer]['roles']
                    if 'pi' in roles:
                        pis.append(hrn)
                    if 'tech' in roles:
                        techs.append(hrn)
                    if 'admin' in roles:
                        admins.append(hrn)
                # Set after the loop so authorities without matched members
                # still get (empty) role lists.
                sfa_info['PI'] = pis
                sfa_info['operator'] = techs
                sfa_info['owner'] = admins
                # xxx TODO: OrganizationName
            elif record_type == "node":
                sfa_info['dns'] = record.get("hostname", "")
                # xxx TODO: URI, LatLong, IP, DNS

            elif record_type == "user":
                sfa_info['email'] = record.get("email", "")
                # xxx TODO: PostalAddress, Phone
            record.update(sfa_info)

    def fill_record_info(self, records):
        """
        Given a SFA record (or list of records), fill in the PLC-specific and
        SFA-specific fields in place. Returns None.
        """
        if not isinstance(records, list):
            records = [records]

        self.fill_record_pl_info(records)
        self.fill_record_sfa_info(records)

    def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
        """
        Sync the PLC membership of a container (e.g. a slice) with ``record``.

        @param oldRecord: the current record, or None when registering
        @param record: the new record
        @param listName: field of the records holding member user hrns
        @param addFunc: called as addFunc(plauth, person_id, container_id)
            for each user added
        @param delFunc: called as delFunc(plauth, person_id, container_id)
            for each user removed
        """
        # get a list of the HRNs that are members of the old and new records
        if oldRecord:
            oldList = oldRecord.get(listName, [])
        else:
            oldList = []
        newList = record.get(listName, [])

        # if the lists are the same, then we don't have to update anything
        if oldList == newList:
            return

        # build a list of the new person ids, by looking up each person to get
        # their pointer
        table = self.SfaTable()
        records = table.find({'type': 'user', 'hrn': newList})
        newIdList = [rec['pointer'] for rec in records]

        # build a list of the old person ids from the person_ids field
        if oldRecord:
            oldIdList = oldRecord.get("person_ids", [])
            containerId = oldRecord.get_pointer()
        else:
            # if oldRecord==None, then we are doing a Register, instead of an
            # update.
            oldIdList = []
            containerId = record.get_pointer()

        # add people who are in the new list, but not the old list
        for personId in newIdList:
            if personId not in oldIdList:
                addFunc(self.plauth, personId, containerId)

        # remove people who are in the old list, but not the new list
        for personId in oldIdList:
            if personId not in newIdList:
                delFunc(self.plauth, personId, containerId)

    def update_membership(self, oldRecord, record):
        """
        Push membership changes for ``record`` to PLC.

        Only slices are handled (via their 'researcher' list); authority
        membership is not implemented yet.
        """
        if record.type == "slice":
            self.update_membership_list(oldRecord, record, 'researcher',
                                        self.plshell.AddPersonToSlice,
                                        self.plshell.DeletePersonFromSlice)
        elif record.type == "authority":
            # xxx TODO
            pass
510
511
512
class ComponentAPI(BaseAPI):
    """
    SFA API front end for a component (node), backed by the local NodeManager.
    """

    def __init__(self, config="/etc/sfa/sfa_config.py", encoding="utf-8", methods='sfa.methods',
                 peer_cert=None, interface=None, key_file=None, cert_file=None):
        """
        @param config: path of the SFA config file. None means "only
            documenting the API": no NodeManager is created.
        Remaining parameters are handed to BaseAPI unchanged.
        """
        BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, peer_cert=peer_cert,
                         interface=interface, key_file=key_file, cert_file=cert_file)
        self.encoding = encoding

        # Better just be documenting the API
        if config is None:
            return

        self.nodemanager = NodeManager(self.config)

    def sliver_exists(self, slicename):
        """
        Return True if NodeManager reports a sliver named ``slicename``.

        @param slicename: slice name, matched against the keys returned by
            NodeManager.GetXIDs()

        The previous signature took no argument and referenced an undefined
        ``slicename``, so every call raised NameError.
        """
        return slicename in self.nodemanager.GetXIDs()