Trivial change to add support for non-PLC aggregates
[sfa.git] / sfa / plc / api.py
1 #
2 # Geniwrapper XML-RPC and SOAP interfaces
3 #
4 ### $Id$
5 ### $URL$
6 #
7
8 import sys
9 import os
10 import traceback
11 import string
12 import xmlrpclib
13
14 from sfa.trust.auth import Auth
15 from sfa.util.config import *
16 from sfa.util.faults import *
17 from sfa.util.debug import *
18 from sfa.trust.rights import *
19 from sfa.trust.credential import *
20 from sfa.util.misc import *
21 from sfa.util.sfalogging import *
22 from sfa.util.genitable import *
23
24 # See "2.2 Characters" in the XML specification:
25 #
26 # #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD]
27 # avoiding
28 # [#x7F-#x84], [#x86-#x9F], [#xFDD0-#xFDDF]
29
30 invalid_xml_ascii = map(chr, range(0x0, 0x8) + [0xB, 0xC] + range(0xE, 0x1F))
31 xml_escape_table = string.maketrans("".join(invalid_xml_ascii), "?" * len(invalid_xml_ascii))
32
33 def xmlrpclib_escape(s, replace = string.replace):
34     """
35     xmlrpclib does not handle invalid 7-bit control characters. This
36     function augments xmlrpclib.escape, which by default only replaces
37     '&', '<', and '>' with entities.
38     """
39
40     # This is the standard xmlrpclib.escape function
41     s = replace(s, "&", "&amp;")
42     s = replace(s, "<", "&lt;")
43     s = replace(s, ">", "&gt;",)
44
45     # Replace invalid 7-bit control characters with '?'
46     return s.translate(xml_escape_table)
47
48 def xmlrpclib_dump(self, value, write):
49     """
50     xmlrpclib cannot marshal instances of subclasses of built-in
51     types. This function overrides xmlrpclib.Marshaller.__dump so that
52     any value that is an instance of one of its acceptable types is
53     marshalled as that type.
54
55     xmlrpclib also cannot handle invalid 7-bit control characters. See
56     above.
57     """
58
59     # Use our escape function
60     args = [self, value, write]
61     if isinstance(value, (str, unicode)):
62         args.append(xmlrpclib_escape)
63
64     try:
65         # Try for an exact match first
66         f = self.dispatch[type(value)]
67     except KeyError:
68         raise
69         # Try for an isinstance() match
70         for Type, f in self.dispatch.iteritems():
71             if isinstance(value, Type):
72                 f(*args)
73                 return
74         raise TypeError, "cannot marshal %s objects" % type(value)
75     else:
76         f(*args)
77
78 # You can't hide from me!
79 xmlrpclib.Marshaller._Marshaller__dump = xmlrpclib_dump
80
81 # SOAP support is optional
82 try:
83     import SOAPpy
84     from SOAPpy.Parser import parseSOAPRPC
85     from SOAPpy.Types import faultType
86     from SOAPpy.NS import NS
87     from SOAPpy.SOAPBuilder import buildSOAP
88 except ImportError:
89     SOAPpy = None
90
91
92 def import_deep(name):
93     mod = __import__(name)
94     components = name.split('.')
95     for comp in components[1:]:
96         mod = getattr(mod, comp)
97     return mod
98
99 class GeniAPI:
100
101     # flat list of method names
102     import sfa.methods
103     methods = sfa.methods.all
104     
105     def __init__(self, config = "/etc/sfa/sfa_config", encoding = "utf-8", 
106                  peer_cert = None, interface = None, key_file = None, cert_file = None):
107         self.encoding = encoding
108
109         # Better just be documenting the API
110         if config is None:
111             return
112
113         # Load configuration
114         self.config = Config(config)
115         self.auth = Auth(peer_cert)
116         self.interface = interface
117         self.key_file = key_file
118         self.cert_file = cert_file
119         self.credential = None
120         
121         # Initialize the PLC shell only if SFA wraps a myPLC
122         if (self.config.get_aggregate_rspec_type() == 'pl'):
123             self.plshell = self.getPLCShell()
124             self.plshell_version = self.getPLCShellVersion()
125
126         self.hrn = self.config.SFA_INTERFACE_HRN
127         self.time_format = "%Y-%m-%d %H:%M:%S"
128         self.logger=get_sfa_logger()
129
130     def getPLCShell(self):
131         self.plauth = {'Username': self.config.SFA_PLC_USER,
132                        'AuthMethod': 'password',
133                        'AuthString': self.config.SFA_PLC_PASSWORD}
134         try:
135             self.plshell_type = 'direct'
136             import PLC.Shell
137             shell = PLC.Shell.Shell(globals = globals())
138             shell.AuthCheck(self.plauth)
139             return shell
140         except ImportError:
141             self.plshell_type = 'xmlrpc' 
142             # connect via xmlrpc
143             url = self.config.SFA_PLC_URL
144             shell = xmlrpclib.Server(url, verbose = 0, allow_none = True)
145             shell.AuthCheck(self.plauth)
146             return shell
147
148     def getPLCShellVersion(self):
149         # We need to figure out what version of PLCAPI we are talking to.
150         # Some calls we need to make later will be different depending on
151         # the api version. 
152         try:
153             # This is probably a bad way to determine api versions
154             # but its easy and will work for now. Lets try to make 
155             # a call that only exists is PLCAPI.4.3. If it fails, we
156             # can assume the api version is 4.2
157             self.plshell.GetTagTypes(self.plauth)
158             return '4.3'
159         except:
160             return '4.2'
161             
162
163     def getCredential(self):
164         if self.interface in ['registry']:
165             return self.getCredentialFromLocalRegistry()
166         else:
167             return self.getCredentialFromRegistry()
168     
169
170     def getCredentialFromRegistry(self):
171         """ 
172         Get our credential from a remote registry using a geniclient connection
173         """
174         type = 'authority'
175         path = self.config.SFA_BASE_DIR
176         filename = ".".join([self.interface, self.hrn, type, "cred"])
177         cred_filename = path + os.sep + filename
178         try:
179             credential = Credential(filename = cred_filename)
180             return credential
181         except IOError:
182             from sfa.server.registry import Registries
183             registries = Registries(self)
184             registry = registries[self.hrn]
185             self_cred = registry.get_credential(None, type, self.hrn)
186             cred = registry.get_credential(self_cred, type, self.hrn)
187             cred.save_to_file(cred_filename, save_parents=True)
188             return cred
189
190     def getCredentialFromLocalRegistry(self):
191         """
192         Get our current credential directly from the local registry.
193         """
194
195         hrn = self.hrn
196         auth_hrn = self.auth.get_authority(hrn)
197     
198         # is this a root or sub authority
199         if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN:
200             auth_hrn = hrn
201         auth_info = self.auth.get_auth_info(auth_hrn)
202         table = GeniTable()
203         records = table.findObjects(hrn)
204         if not records:
205             raise RecordNotFound
206         record = records[0]
207         type = record['type']
208         object_gid = record.get_gid_object()
209         new_cred = Credential(subject = object_gid.get_subject())
210         new_cred.set_gid_caller(object_gid)
211         new_cred.set_gid_object(object_gid)
212         new_cred.set_issuer(key=auth_info.get_pkey_object(), subject=auth_hrn)
213         new_cred.set_pubkey(object_gid.get_pubkey())
214         r1 = determine_rights(type, hrn)
215         new_cred.set_privileges(r1)
216
217         auth_kind = "authority,ma,sa"
218
219         new_cred.set_parent(self.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind))
220
221         new_cred.encode()
222         new_cred.sign()
223
224         return new_cred
225    
226
227     def loadCredential (self):
228         """
229         Attempt to load credential from file if it exists. If it doesnt get
230         credential from registry.
231         """
232
233         # see if this file exists
234         # XX This is really the aggregate's credential. Using this is easier than getting
235         # the registry's credential from iteslf (ssl errors).   
236         ma_cred_filename = self.config.SFA_BASE_DIR + os.sep + self.interface + self.hrn + ".ma.cred"
237         try:
238             self.credential = Credential(filename = ma_cred_filename)
239         except IOError:
240             self.credential = self.getCredentialFromRegistry()
241
242     ##
243     # Convert geni fields to PLC fields for use when registering up updating
244     # registry record in the PLC database
245     #
246     # @param type type of record (user, slice, ...)
247     # @param hrn human readable name
248     # @param geni_fields dictionary of geni fields
249     # @param pl_fields dictionary of PLC fields (output)
250
251     def geni_fields_to_pl_fields(self, type, hrn, record):
252
253         def convert_ints(tmpdict, int_fields):
254             for field in int_fields:
255                 if field in tmpdict:
256                     tmpdict[field] = int(tmpdict[field])
257
258         pl_record = {}
259         #for field in record:
260         #    pl_record[field] = record[field]
261  
262         if type == "slice":
263             if not "instantiation" in pl_record:
264                 pl_record["instantiation"] = "plc-instantiated"
265             pl_record["name"] = hrn_to_pl_slicename(hrn)
266             if "url" in record:
267                pl_record["url"] = record["url"]
268             if "description" in record:
269                 pl_record["description"] = record["description"]
270
271         elif type == "node":
272             if not "hostname" in pl_record:
273                 if not "hostname" in record:
274                     raise MissingGeniInfo("hostname")
275                 pl_record["hostname"] = record["hostname"]
276             if not "model" in pl_record:
277                 pl_record["model"] = "geni"
278
279         elif type == "authority":
280             pl_record["login_base"] = hrn_to_pl_login_base(hrn)
281
282             if not "name" in pl_record:
283                 pl_record["name"] = hrn
284
285             if not "abbreviated_name" in pl_record:
286                 pl_record["abbreviated_name"] = hrn
287
288             if not "enabled" in pl_record:
289                 pl_record["enabled"] = True
290
291             if not "is_public" in pl_record:
292                 pl_record["is_public"] = True
293
294         return pl_record
295
296     def fill_record_pl_info(self, record):
297         """
298         Fill in the planetlab specific fields of a Geni record. This
299         involves calling the appropriate PLC method to retrieve the 
300         database record for the object.
301         
302         PLC data is filled into the pl_info field of the record.
303     
304         @param record: record to fill in field (in/out param)     
305         """
306         type = record['type']
307         pointer = record['pointer']
308         auth_hrn = self.hrn
309         login_base = ''
310         # records with pointer==-1 do not have plc info associated with them.
311         # for example, the top level authority records which are
312         # authorities, but not PL "sites"
313         if pointer == -1:
314             record.update({})
315             return
316
317         if (type in ["authority"]):
318             pl_res = self.plshell.GetSites(self.plauth, [pointer])
319         elif (type == "slice"):
320             pl_res = self.plshell.GetSlices(self.plauth, [pointer])
321         elif (type == "user"):
322             pl_res = self.plshell.GetPersons(self.plauth, [pointer])
323         elif (type == "node"):
324             pl_res = self.plshell.GetNodes(self.plauth, [pointer])
325         else:
326             raise UnknownGeniType(type)
327         
328         if not pl_res:
329             raise PlanetLabRecordDoesNotExist(record['hrn'])
330
331         # convert ids to hrns
332         pl_record = pl_res[0]
333         if 'site_id' in pl_record:
334             sites = self.plshell.GetSites(self.plauth, pl_record['site_id'], ['login_base'])
335             site = sites[0]
336             login_base = site['login_base']
337             pl_record['site'] = ".".join([auth_hrn, login_base])
338         if 'person_ids' in pl_record:
339             persons =  self.plshell.GetPersons(self.plauth, pl_record['person_ids'], ['email'])
340             emails = [person['email'] for person in persons]
341             usernames = [email.split('@')[0] for email in emails]
342             person_hrns = [".".join([auth_hrn, login_base, username]) for username in usernames]
343             pl_record['persons'] = person_hrns 
344         if 'slice_ids' in pl_record:
345             slices = self.plshell.GetSlices(self.plauth, pl_record['slice_ids'], ['name'])
346             slicenames = [slice['name'] for slice in slices]
347             slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames]
348             pl_record['slices'] = slice_hrns
349         if 'node_ids' in pl_record:
350             nodes = self.plshell.GetNodes(self.plauth, pl_record['node_ids'], ['hostname'])
351             hostnames = [node['hostname'] for node in nodes]
352             node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames]
353             pl_record['nodes'] = node_hrns
354         if 'site_ids' in pl_record:
355             sites = self.plshell.GetSites(self.plauth, pl_record['site_ids'], ['login_base'])
356             login_bases = [site['login_base'] for site in sites]
357             site_hrns = [".".join([auth_hrn, lbase]) for lbase in login_bases]
358             pl_record['sites'] = site_hrns
359         if 'key_ids' in pl_record:
360             keys = self.plshell.GetKeys(self.plauth, pl_record['key_ids'])
361             pubkeys = []
362             if keys:
363                 pubkeys = [key['key'] for key in keys]
364             pl_record['keys'] = pubkeys     
365
366         record.update(pl_record)
367
368
369
370     def fill_record_geni_info(self, record):
371         geni_info = {}
372         type = record['type']
373         table = GeniTable()
374         if (type == "slice"):
375             person_ids = record.get("person_ids", [])
376             persons = table.find({'type': 'user', 'pointer': person_ids})
377             researchers = [person['hrn'] for person in persons]
378             geni_info['researcher'] = researchers
379
380         elif (type == "authority"):
381             person_ids = record.get("person_ids", [])
382             persons = table.find({'type': 'user', 'pointer': person_ids})
383             persons_dict = {}
384             for person in persons:
385                 persons_dict[person['pointer']] = person 
386             pl_persons = self.plshell.GetPersons(self.plauth, person_ids, ['person_id', 'roles'])
387             pis, techs, admins = [], [], []
388             for person in pl_persons:
389                 pointer = person['person_id']
390                 
391                 if pointer not in persons_dict:
392                     # this means there is not sfa record for this user
393                     continue    
394                 hrn = persons_dict[pointer]['hrn']    
395                 if 'pi' in person['roles']:
396                     pis.append(hrn)
397                 if 'tech' in person['roles']:
398                     techs.append(hrn)
399                 if 'admin' in person['roles']:
400                     admins.append(hrn)
401             
402             geni_info['PI'] = pis
403             geni_info['operator'] = techs
404             geni_info['owner'] = admins
405             # xxx TODO: OrganizationName
406
407         elif (type == "node"):
408             geni_info['dns'] = record.get("hostname", "")
409             # xxx TODO: URI, LatLong, IP, DNS
410     
411         elif (type == "user"):
412             geni_info['email'] = record.get("email", "")
413             # xxx TODO: PostalAddress, Phone
414
415         record.update(geni_info)
416
417     def fill_record_info(self, record):
418         """
419         Given a geni record, fill in the PLC specific and Geni specific
420         fields in the record. 
421         """
422         self.fill_record_pl_info(record)
423         self.fill_record_geni_info(record)
424
425     def update_membership_list(self, oldRecord, record, listName, addFunc, delFunc):
426         # get a list of the HRNs tht are members of the old and new records
427         if oldRecord:
428             oldList = oldRecord.get(listName, [])
429         else:
430             oldList = []     
431         newList = record.get(listName, [])
432
433         # if the lists are the same, then we don't have to update anything
434         if (oldList == newList):
435             return
436
437         # build a list of the new person ids, by looking up each person to get
438         # their pointer
439         newIdList = []
440         table = GeniTable()
441         records = table.find({'type': 'user', 'hrn': newList})
442         for rec in records:
443             newIdList.append(rec['pointer'])
444
445         # build a list of the old person ids from the person_ids field 
446         if oldRecord:
447             oldIdList = oldRecord.get("person_ids", [])
448             containerId = oldRecord.get_pointer()
449         else:
450             # if oldRecord==None, then we are doing a Register, instead of an
451             # update.
452             oldIdList = []
453             containerId = record.get_pointer()
454
455     # add people who are in the new list, but not the oldList
456         for personId in newIdList:
457             if not (personId in oldIdList):
458                 addFunc(self.plauth, personId, containerId)
459
460         # remove people who are in the old list, but not the new list
461         for personId in oldIdList:
462             if not (personId in newIdList):
463                 delFunc(self.plauth, personId, containerId)
464
465     def update_membership(self, oldRecord, record):
466         if record.type == "slice":
467             self.update_membership_list(oldRecord, record, 'researcher',
468                                         self.plshell.AddPersonToSlice,
469                                         self.plshell.DeletePersonFromSlice)
470         elif record.type == "authority":
471             # xxx TODO
472             pass
473
474
475     def callable(self, method):
476         """
477         Return a new instance of the specified method.
478         """
479         # Look up method
480         if method not in self.methods:
481             raise GeniInvalidAPIMethod, method
482         
483         # Get new instance of method
484         try:
485             classname = method.split(".")[-1]
486             module = __import__("sfa.methods." + method, globals(), locals(), [classname])
487             callablemethod = getattr(module, classname)(self)
488             return getattr(module, classname)(self)
489         except ImportError, AttributeError:
490             raise
491             raise GeniInvalidAPIMethod, method
492
493     def call(self, source, method, *args):
494         """
495         Call the named method from the specified source with the
496         specified arguments.
497         """
498         function = self.callable(method)
499         function.source = source
500         return function(*args)
501
502     def handle(self, source, data):
503         """
504         Handle an XML-RPC or SOAP request from the specified source.
505         """
506         # Parse request into method name and arguments
507         try:
508             interface = xmlrpclib
509             (args, method) = xmlrpclib.loads(data)
510             methodresponse = True
511         except Exception, e:
512             if SOAPpy is not None:
513                 interface = SOAPpy
514                 (r, header, body, attrs) = parseSOAPRPC(data, header = 1, body = 1, attrs = 1)
515                 method = r._name
516                 args = r._aslist()
517                 # XXX Support named arguments
518             else:
519                 raise e
520
521         try:
522             result = self.call(source, method, *args)
523         except Exception, fault:
524             traceback.print_exc(file = log)
525             # Handle expected faults
526             if interface == xmlrpclib:
527                 result = fault
528                 methodresponse = None
529             elif interface == SOAPpy:
530                 result = faultParameter(NS.ENV_T + ":Server", "Method Failed", method)
531                 result._setDetail("Fault %d: %s" % (fault.faultCode, fault.faultString))
532             else:
533                 raise
534
535         # Return result
536         if interface == xmlrpclib:
537             if not isinstance(result, GeniFault):
538                 result = (result,)
539
540             data = xmlrpclib.dumps(result, methodresponse = True, encoding = self.encoding, allow_none = 1)
541         elif interface == SOAPpy:
542             data = buildSOAP(kw = {'%sResponse' % method: {'Result': result}}, encoding = self.encoding)
543
544         return data
545