From: Thierry Parmentelat Date: Wed, 11 Jan 2017 13:30:28 +0000 (+0100) Subject: autopep8 X-Git-Tag: sfa-3.1-21~5 X-Git-Url: http://git.onelab.eu/?p=sfa.git;a=commitdiff_plain;h=04a3f20dc71bf8b3f96b1e3172623aa346a638a7 autopep8 --- diff --git a/clientbin/getNodes.py b/clientbin/getNodes.py index ec2e4a2d..89689a5b 100644 --- a/clientbin/getNodes.py +++ b/clientbin/getNodes.py @@ -7,6 +7,7 @@ from pprint import pprint from sfa.util.py23 import StringType + def create_parser(): command = sys.argv[0] argv = sys.argv[1:] @@ -14,27 +15,27 @@ def create_parser(): description = """getNodes will open a rspec file and print all key/values, or filter results based on a given key or set of keys.""" parser = OptionParser(usage=usage, description=description) parser.add_option("-i", "--infile", dest="infile", default=None, - help = "input rspec file") + help="input rspec file") parser.add_option("-t", "--tag", dest="tag", default=None, - help = "filter rspec for this tag") + help="filter rspec for this tag") parser.add_option("-a", "--attribute", dest="attribute", default=None, - help = "comma separated list of attributes to display") + help="comma separated list of attributes to display") parser.add_option("-r", "--recursive", dest="print_children", default=False, action="store_true", - help = "print the tag's child nodes") + help="print the tag's child nodes") - return parser + return parser def print_dict(rdict, options, counter=1): print_children = options.print_children attributes = [] - if options.attribute: - attributes = options.attribute.split(',') + if options.attribute: + attributes = options.attribute.split(',') lists = [] tab = " " - + if not isinstance(rdict, dict): - raise "%s not a dict" % rdict + raise "%s not a dict" % rdict for (key, value) in rdict.iteritems(): if isinstance(value, StringType): if (attributes and key in attributes) or not attributes: @@ -44,29 +45,30 @@ def print_dict(rdict, options, counter=1): if isinstance(listitem, dict): lists.append((key, listitem)) elif isinstance(value, dict): - lists.append((key, value)) - - if counter == 1 or print_children: + lists.append((key, value)) + + if counter == 1 or print_children: for (key, listitem) in lists: if isinstance(listitem, dict): print tab * (counter - 1) + key - print_dict(listitem, options, counter+1) + print_dict(listitem, options, counter + 1) elif not attributes or (attributes and 'children' in attributes): keys = set([key for (key, listitem) in lists]) - if keys: print tab * (counter) + "(children: %s)" % (",".join(keys)) - + if keys: + print tab * (counter) + "(children: %s)" % (",".join(keys)) + # # this code probably is obsolete # RSpec is not imported, it does not have a toDict() method anyway # plus, getNodes.py is not exposed in packaging -# +# def main(): - parser = create_parser(); + parser = create_parser() (options, args) = parser.parse_args() if not options.infile: print "RSpec file not specified" - return + return rspec = RSpec() try: @@ -79,14 +81,15 @@ def main(): rspec_dicts = rspec.getDictsByTagName(tag_name) rspec_dict = {tag_name: rspec_dicts} else: - rspec_dict = rspec.toDict() - + rspec_dict = rspec.toDict() + print_dict(rspec_dict, options) return if __name__ == '__main__': - try: main() + try: + main() except Exception as e: raise print e diff --git a/clientbin/getRecord.py b/clientbin/getRecord.py index c88f193e..28608f3f 100755 --- a/clientbin/getRecord.py +++ b/clientbin/getRecord.py @@ -14,37 +14,39 @@ import os from optparse import OptionParser from pprint import pprint from xml.parsers.expat import ExpatError -from sfa.util.xml import XML +from sfa.util.xml import XML + def create_parser(): command = sys.argv[0] argv = sys.argv[1:] usage = "%(command)s [options]" % locals() description = """getRecord will parse a supplied (via stdin) record and print all values or key/values, and filter results based on a given key or set of keys.""" - parser = OptionParser(usage=usage,description=description) + parser = OptionParser(usage=usage, description=description) parser.add_option("-d", "--debug", dest="DEBUG", action="store_true", - default=False, help = "record file path") + default=False, help="record file path") parser.add_option("-k", "--key", dest="withkey", action="store_true", - default=False, help = "print SSH keys and certificates") + default=False, help="print SSH keys and certificates") parser.add_option("-p", "--plinfo", dest="plinfo", action="store_true", - default=False, help = "print PlanetLab specific internal fields") - - return parser + default=False, help="print PlanetLab specific internal fields") + + return parser def printRec(record_dict, filters, options): line = "" if len(filters): for filter in filters: - if options.DEBUG: print "Filtering on %s" %filter - line += "%s: %s\n" % (filter, - printVal(record_dict.get(filter, None))) + if options.DEBUG: + print "Filtering on %s" % filter + line += "%s: %s\n" % (filter, + printVal(record_dict.get(filter, None))) print line else: # print the wole thing for (key, value) in record_dict.iteritems(): if (not options.withkey and key in ('gid', 'keys')) or\ - (not options.plinfo and key == 'pl_info'): + (not options.plinfo and key == 'pl_info'): continue line += "%s: %s\n" % (key, printVal(value)) print line @@ -62,22 +64,23 @@ def printVal(value): def main(): - parser = create_parser(); + parser = create_parser() (options, args) = parser.parse_args() stdin = sys.stdin.read() - + record = XML(stdin) record_dict = record.todict() - - if options.DEBUG: + + if options.DEBUG: pprint(record.toxml()) print "#####################################################" printRec(record_dict, args, options) if __name__ == '__main__': - try: main() + try: + main() except ExpatError as e: print "RecordError. Is your record valid XML?" print e diff --git a/clientbin/setRecord.py b/clientbin/setRecord.py index aafd95c2..886b8014 100755 --- a/clientbin/setRecord.py +++ b/clientbin/setRecord.py @@ -16,16 +16,17 @@ from optparse import OptionParser from pprint import pprint from sfa.util.xml import XML + def create_parser(): command = sys.argv[0] argv = sys.argv[1:] usage = "%(command)s [options]" % locals() description = """setRecord will edit a record (from stdin), modify its contents, then print the new record to stdout""" - parser = OptionParser(usage=usage,description=description) + parser = OptionParser(usage=usage, description=description) parser.add_option("-d", "--debug", dest="DEBUG", action="store_true", - default=False, help = "print debug info") - - return parser + default=False, help="print debug info") + + return parser def editDict(args, recordDict, options): @@ -37,19 +38,19 @@ def editDict(args, recordDict, options): if vect.count("+="): # append value modDict({vect.split("+=")[0]: returnVal(vect.split("+=")[1])}, - recordDict, options) - + recordDict, options) + elif vect.count("="): # reassign value replaceDict({vect.split("=")[0]: returnVal("=".join(vect.split("=")[1:]))}, - recordDict, options) + recordDict, options) else: if vect in recordDict: del recordDict[vect] else: - raise TypeError("Argument error: Records are updated with \n" \ - "key=val1,val2,valN or\n" \ - "key+=val1,val2,valN \n%s Unknown key/val" % vect) + raise TypeError("Argument error: Records are updated with \n" + "key=val1,val2,valN or\n" + "key+=val1,val2,valN \n%s Unknown key/val" % vect) def replaceDict(newval, recordDict, options): @@ -60,6 +61,7 @@ def replaceDict(newval, recordDict, options): for (key, val) in newval.iteritems(): recordDict[key] = val + def modDict(newval, recordDict, options): """ Checks type of existing field, addends new field @@ -86,8 +88,9 @@ def returnVal(arg): else: return arg + def main(): - parser = create_parser(); + parser = create_parser() (options, args) = parser.parse_args() record = XML(sys.stdin.read()) @@ -96,12 +99,13 @@ def main(): editDict(args, record_dict, options) if options.DEBUG: print "New Record:\n%s" % record_dict - + record.parse_dict(record_dict) s = record.toxml() sys.stdout.write(s) if __name__ == '__main__': - try: main() + try: + main() except Exception as e: print e diff --git a/clientbin/sfadump.py b/clientbin/sfadump.py index 617635b9..4a9316d9 100755 --- a/clientbin/sfadump.py +++ b/clientbin/sfadump.py @@ -2,12 +2,14 @@ from __future__ import with_statement import sys -import os, os.path +import os +import os.path import tempfile from argparse import ArgumentParser from sfa.util.sfalogging import logger -from sfa.util.faults import CredentialNotVerifiable, CertMissingParent #, ChildRightsNotSubsetOfParent +# , ChildRightsNotSubsetOfParent +from sfa.util.faults import CredentialNotVerifiable, CertMissingParent from sfa.trust.certificate import Certificate from sfa.trust.credential import Credential @@ -15,26 +17,34 @@ from sfa.trust.gid import GID from sfa.storage.record import Record + def determine_sfa_filekind(fn): - if fn.endswith('.gid'): return 'gid' - elif fn.endswith('.cert'): return 'certificate' - elif fn.endswith('cred'): return 'credential' + if fn.endswith('.gid'): + return 'gid' + elif fn.endswith('.cert'): + return 'certificate' + elif fn.endswith('cred'): + return 'credential' try: - cred=Credential(filename=fn) + cred = Credential(filename=fn) return 'credential' - except: pass + except: + pass - try: - gid=GID(filename=fn) - if gid.uuid: return 'gid' - except: pass + try: + gid = GID(filename=fn) + if gid.uuid: + return 'gid' + except: + pass try: - cert = Certificate(filename = fn) + cert = Certificate(filename=fn) return 'certificate' - except: pass + except: + pass # to be completed # if "gidCaller" in dict: @@ -45,97 +55,105 @@ def determine_sfa_filekind(fn): return "unknown" + def save_gid(gid): - hrn = gid.get_hrn() - lastpart = hrn.split(".")[-1] - filename = lastpart + ".gid" + hrn = gid.get_hrn() + lastpart = hrn.split(".")[-1] + filename = lastpart + ".gid" - if os.path.exists(filename): - print filename, ": already exists... skipping" - return + if os.path.exists(filename): + print filename, ": already exists... skipping" + return - print filename, ": extracting gid of", hrn + print filename, ": extracting gid of", hrn + + gid.save_to_file(filename, save_parents=True) - gid.save_to_file(filename, save_parents = True) def extract_gids(cred, extract_parents): - gidCaller = cred.get_gid_caller() - if gidCaller: - save_gid(gidCaller) + gidCaller = cred.get_gid_caller() + if gidCaller: + save_gid(gidCaller) - gidObject = cred.get_gid_object() - if gidObject and ((gidCaller == None) or (gidCaller.get_hrn() != gidObject.get_hrn())): - save_gid(gidObject) + gidObject = cred.get_gid_object() + if gidObject and ((gidCaller == None) or (gidCaller.get_hrn() != gidObject.get_hrn())): + save_gid(gidObject) - # no such method Credential.get_parent + # no such method Credential.get_parent # if extract_parents: # parent = cred.get_parent() # if parent: # extract_gids(parent, extract_parents) -def verify_input_object (obj, kind, options): + +def verify_input_object(obj, kind, options): if options.trusted_roots: print "CHECKING...", - message= "against [" + (" + ".join(options.trusted_roots)) + "]" + message = "against [" + (" + ".join(options.trusted_roots)) + "]" try: - if kind=='credential': - print "verify",message, + if kind == 'credential': + print "verify", message, obj.verify(options.trusted_roots) - elif kind in ['certificate','gid']: - print "verify_chain",message, + elif kind in ('certificate', 'gid'): + print "verify_chain", message, obj.verify_chain(options.trusted_roots) print "--> OK" except Exception as inst: - print "--> KO",type(inst).__name__ + print "--> KO", type(inst).__name__ -def handle_input (filename, options): + +def handle_input(filename, options): kind = determine_sfa_filekind(filename) # dump methods current do 'print' so let's go this road for now - if kind=="certificate": - cert=Certificate (filename=filename) - print '--------------------',filename,'IS A',kind + if kind == "certificate": + cert = Certificate(filename=filename) + print '--------------------', filename, 'IS A', kind cert.dump(show_extensions=options.show_extensions) - verify_input_object (cert, kind, options) - elif kind=="credential": - cred = Credential(filename = filename) - print '--------------------',filename,'IS A',kind - cred.dump(dump_parents = options.dump_parents, show_xml=options.show_xml) + verify_input_object(cert, kind, options) + elif kind == "credential": + cred = Credential(filename=filename) + print '--------------------', filename, 'IS A', kind + cred.dump(dump_parents=options.dump_parents, show_xml=options.show_xml) if options.extract_gids: - print '--------------------',filename,'embedded GIDs' - extract_gids(cred, extract_parents = options.dump_parents) - verify_input_object (cred, kind, options) - elif kind=="gid": - gid = GID(filename = filename) - print '--------------------',filename,'IS A',kind - gid.dump(dump_parents = options.dump_parents) - verify_input_object (gid, kind, options) + print '--------------------', filename, 'embedded GIDs' + extract_gids(cred, extract_parents=options.dump_parents) + verify_input_object(cred, kind, options) + elif kind == "gid": + gid = GID(filename=filename) + print '--------------------', filename, 'IS A', kind + gid.dump(dump_parents=options.dump_parents) + verify_input_object(gid, kind, options) else: - print "%s: unknown filekind '%s'"% (filename,kind) + print "%s: unknown filekind '%s'" % (filename, kind) + def main(): usage = """%(prog)s file1 [ .. filen] display info on input files""" parser = ArgumentParser(usage=usage) - parser.add_argument("-g", "--extract-gids", action="store_true", dest="extract_gids", + parser.add_argument("-g", "--extract-gids", action="store_true", dest="extract_gids", default=False, help="Extract GIDs from credentials") - parser.add_argument("-p", "--dump-parents", action="store_true", dest="dump_parents", + parser.add_argument("-p", "--dump-parents", action="store_true", dest="dump_parents", default=False, help="Show parents") - parser.add_argument("-e", "--extensions", action="store_true", - dest="show_extensions", default="False", help="Show certificate extensions") - parser.add_argument("-v", "--verbose", action='count', + parser.add_argument("-e", "--extensions", action="store_true", + dest="show_extensions", default="False", + help="Show certificate extensions") + parser.add_argument("-v", "--verbose", action='count', dest='verbose', default=0, help="More and more verbose") - parser.add_argument("-x", "--xml", action='store_true', + parser.add_argument("-x", "--xml", action='store_true', dest='show_xml', default=False, help="dumps xml tree (cred. only)") parser.add_argument("-c", "--check", action='append', dest='trusted_roots', - help="cumulative list of trusted GIDs - when provided, the input is verify'ed against these") - parser.add_argument("filenames",metavar='F',nargs='+',help="filenames to dump") + help="cumulative list of trusted GIDs - " + "when provided, the input is verify'ed against these") + parser.add_argument("filenames", metavar='F', nargs='+', + help="filenames to dump") options = parser.parse_args() logger.setLevelFromOptVerbose(options.verbose) - for filename in options.filenames: - handle_input(filename,options) + for filename in options.filenames: + handle_input(filename, options) -if __name__=="__main__": - main() +if __name__ == "__main__": + main() diff --git a/clientbin/sfiAddAttribute.py b/clientbin/sfiAddAttribute.py index 6fa51b40..89f4a7c5 100755 --- a/clientbin/sfiAddAttribute.py +++ b/clientbin/sfiAddAttribute.py @@ -28,7 +28,6 @@ if command.opts.infile: nodes = f.read().split() f.close() - for name in attrs: print >> sys.stderr, name, attrs[name] for value in attrs[name]: @@ -36,12 +35,14 @@ if command.opts.infile: try: rspec.version.add_default_sliver_attribute(name, value) except: - logger.log_exc("sfiAddAttribute FAILED on all nodes: %s=%s" % (name, value)) + logger.log_exc( + "sfiAddAttribute FAILED on all nodes: %s=%s" % (name, value)) else: for node in nodes: try: rspec.version.add_sliver_attribute(node, name, value) except: - logger.log_exc ("sfiAddAttribute FAILED on node %s: %s=%s" % (node, name, value)) + logger.log_exc( + "sfiAddAttribute FAILED on node %s: %s=%s" % (node, name, value)) print rspec.toxml() diff --git a/clientbin/sfiAddLinks.py b/clientbin/sfiAddLinks.py index b16cae6a..f3792b6c 100755 --- a/clientbin/sfiAddLinks.py +++ b/clientbin/sfiAddLinks.py @@ -19,7 +19,7 @@ if not command.opts.linkfile: print "Missing link list -- exiting" command.parser.print_help() sys.exit(1) - + if command.opts.infile: infile = open(command.opts.infile) else: @@ -36,7 +36,8 @@ version_manager = VersionManager() try: type = ad_rspec.version.type version_num = ad_rspec.version.version - request_version = version_manager._get_version(type, version_num, 'request') + request_version = version_manager._get_version( + type, version_num, 'request') request_rspec = RSpec(version=request_version) request_rspec.version.merge(ad_rspec) request_rspec.version.add_link_requests(link_tuples) diff --git a/clientbin/sfiAddSliver.py b/clientbin/sfiAddSliver.py index 6e6042ce..93bb5005 100755 --- a/clientbin/sfiAddSliver.py +++ b/clientbin/sfiAddSliver.py @@ -19,13 +19,13 @@ if not command.opts.nodefile: print "Missing node list -- exiting" command.parser.print_help() sys.exit(1) - + if command.opts.infile: infile = open(command.opts.infile) else: infile = sys.stdin if command.opts.outfile: - outfile = open(command.opts.outfile,"w") + outfile = open(command.opts.outfile, "w") else: outfile = sys.stdout ad_rspec = RSpec(infile) @@ -34,7 +34,8 @@ version_manager = VersionManager() try: type = ad_rspec.version.type version_num = ad_rspec.version.version - request_version = version_manager._get_version(type, version_num, 'request') + request_version = version_manager._get_version( + type, version_num, 'request') request_rspec = RSpec(version=request_version) request_rspec.version.merge(ad_rspec) request_rspec.version.add_slivers(nodes) diff --git a/clientbin/sfiDeleteAttribute.py b/clientbin/sfiDeleteAttribute.py index 7e6a5aeb..d0ae0b4c 100755 --- a/clientbin/sfiDeleteAttribute.py +++ b/clientbin/sfiDeleteAttribute.py @@ -28,7 +28,6 @@ if command.opts.infile: nodes = f.read().split() f.close() - for name in attrs: print >> sys.stderr, name, attrs[name] for value in attrs[name]: @@ -36,12 +35,15 @@ if command.opts.infile: try: rspec.version.remove_default_sliver_attribute(name, value) except: - logger.log_exc("sfiDeleteAttribute FAILED on all nodes: %s=%s" % (name, value)) + logger.log_exc( + "sfiDeleteAttribute FAILED on all nodes: %s=%s" % (name, value)) else: for node in nodes: try: - rspec.version.remove_sliver_attribute(node, name, value) + rspec.version.remove_sliver_attribute( + node, name, value) except: - logger.log_exc("sfiDeleteAttribute FAILED on node %s: %s=%s" % (node, name, value)) + logger.log_exc( + "sfiDeleteAttribute FAILED on node %s: %s=%s" % (node, name, value)) print rspec.toxml() diff --git a/clientbin/sfiDeleteSliver.py b/clientbin/sfiDeleteSliver.py index 3dc50e65..8669ad69 100755 --- a/clientbin/sfiDeleteSliver.py +++ b/clientbin/sfiDeleteSliver.py @@ -21,14 +21,10 @@ if command.opts.infile: f = open(command.opts.nodefile, "r") nodes = f.read().split() f.close() - + try: slivers = [{'hostname': node} for node in nodes] rspec.version.remove_slivers(slivers) print rspec.toxml() except: logger.log_exc("sfiDeleteSliver FAILED with nodes %s" % nodes) - - - - diff --git a/clientbin/sfiListLinks.py b/clientbin/sfiListLinks.py index a4720ca1..67e603d1 100755 --- a/clientbin/sfiListLinks.py +++ b/clientbin/sfiListLinks.py @@ -3,11 +3,11 @@ import sys from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec -from sfa.util.xrn import Xrn +from sfa.util.xrn import Xrn command = Commands(usage="%prog [options]", - description="List all links in the RSpec. " + - "Use this to display the list of available links. " ) + description="List all links in the RSpec. " + + "Use this to display the list of available links. ") command.prep() if command.opts.infile: @@ -15,12 +15,8 @@ if command.opts.infile: links = rspec.version.get_links() if command.opts.outfile: sys.stdout = open(command.opts.outfile, 'w') - + for link in links: ifname1 = Xrn(link['interface1']['component_id']).get_leaf() ifname2 = Xrn(link['interface2']['component_id']).get_leaf() print "%s %s" % (ifname1, ifname2) - - - - diff --git a/clientbin/sfiListNodes.py b/clientbin/sfiListNodes.py index 17cb3414..13377c9a 100755 --- a/clientbin/sfiListNodes.py +++ b/clientbin/sfiListNodes.py @@ -6,11 +6,11 @@ from sfa.client.sfi_commands import Commands from sfa.rspecs.rspec import RSpec -from sfa.planetlab.plxrn import xrn_to_hostname +from sfa.planetlab.plxrn import xrn_to_hostname command = Commands(usage="%prog [options]", - description="List all nodes in the RSpec. " + - "Use this to display the list of nodes on which it is " + + description="List all nodes in the RSpec. " + + "Use this to display the list of nodes on which it is " + "possible to create a slice.") command.prep() @@ -19,14 +19,10 @@ if command.opts.infile: nodes = rspec.version.get_nodes() if command.opts.outfile: sys.stdout = open(command.opts.outfile, 'w') - + for node in nodes: hostname = None if node.get('component_id'): hostname = xrn_to_hostname(node['component_id']) if hostname: - print hostname - - - - + print hostname diff --git a/clientbin/sfiListSlivers.py b/clientbin/sfiListSlivers.py index c9611d0c..9fdd1c0c 100755 --- a/clientbin/sfiListSlivers.py +++ b/clientbin/sfiListSlivers.py @@ -8,8 +8,8 @@ from sfa.rspecs.rspec import RSpec from sfa.planetlab.plxrn import xrn_to_hostname command = Commands(usage="%prog [options]", - description="List all slivers in the RSpec. " + - "Use this to display the list of nodes belonging to " + + description="List all slivers in the RSpec. " + + "Use this to display the list of nodes belonging to " + "the slice.") command.add_show_attributes_option() command.prep() @@ -17,13 +17,13 @@ command.prep() if command.opts.infile: rspec = RSpec(command.opts.infile) nodes = rspec.version.get_nodes_with_slivers() - + if command.opts.showatt: defaults = rspec.version.get_default_sliver_attributes() if defaults: print "ALL NODES" for (name, value) in defaults: - print " %s: %s" % (name, value) + print " %s: %s" % (name, value) for node in nodes: hostname = None @@ -35,5 +35,3 @@ if command.opts.infile: atts = rspec.version.get_sliver_attributes(hostname) for (name, value) in atts: print " %s: %s" % (name, value) - - diff --git a/config/gen-sfa-cm-config.py b/config/gen-sfa-cm-config.py index abbf3d3d..62f72ded 100755 --- a/config/gen-sfa-cm-config.py +++ b/config/gen-sfa-cm-config.py @@ -10,8 +10,8 @@ sfa_config = SfaConfig() plc_config = PlcConfig() default_host = socket.gethostbyname(socket.gethostname()) all_vars = ['SFA_CONFIG_DIR', 'SFA_DATA_DIR', 'SFA_INTERFACE_HRN', - 'SFA_CM_SLICE_PREFIX', 'SFA_REGISTRY_HOST', 'SFA_REGISTRY_PORT', - 'SFA_AGGREGATE_HOST', 'SFA_AGGREGATE_PORT', + 'SFA_CM_SLICE_PREFIX', 'SFA_REGISTRY_HOST', 'SFA_REGISTRY_PORT', + 'SFA_AGGREGATE_HOST', 'SFA_AGGREGATE_PORT', 'SFA_SM_HOST', 'SFA_SM_PORT', 'SFA_CM_ENABLED', 'SFA_CM_HOST', 'SFA_CM_PORT', 'SFA_CM_TYPE', 'SFA_CM_SLICE_PREFIX', 'SFA_API_LOGLEVEL'] @@ -23,19 +23,18 @@ defaults = { 'SFA_CM_SLICE_PREFIX': plc_config.PLC_SLICE_PREFIX, 'SFA_CM_TYPE': 'pl', 'SFA_API_LOGLEVEL': '0' - } +} host_defaults = { 'SFA_REGISTRY_HOST': default_host, 'SFA_AGGREGATE_HOST': default_host, - 'SFA_SM_HOST': default_host, - } - + 'SFA_SM_HOST': default_host, +} + const_dict = {} for key in all_vars: value = "" - - + if key in defaults: value = defaults[key] elif hasattr(sfa_config, key): @@ -43,16 +42,14 @@ for key in all_vars: # sfa_config may specify localhost instead of a resolvalbe host or ip # if so replace this with the host's address if key in host_defaults and value in ['localhost', '127.0.0.1']: - value = host_defaults[key] + value = host_defaults[key] const_dict[key] = value filename = sfa_config.config_path + os.sep + 'sfa_component_config' conffile = open(filename, 'w') -format='%s="%s"\n' +format = '%s="%s"\n' for var in all_vars: conffile.write(format % (var, const_dict[var])) -conffile.close() - - +conffile.close() diff --git a/config/sfa-config b/config/sfa-config index e623ed84..93ebc051 100755 --- a/config/sfa-config +++ b/config/sfa-config @@ -48,11 +48,12 @@ Usage: %s [OPTION]... [FILES] sys.exit(1) -def deprecated (message): - print "%s: deprecated usage"%sys.argv[0] +def deprecated(message): + print "%s: deprecated usage" % sys.argv[0] print message sys.exit(1) + def main(): config = Config() fileobjs = [] @@ -85,7 +86,7 @@ def main(): for (opt, optval) in opts: if opt == "--shell" or \ - opt == "--bash": + opt == "--bash": output = config.output_shell elif opt == "--python": output = config.output_python @@ -108,17 +109,17 @@ def main(): elif opt == "--value": variable['value'] = optval elif opt == "--group": -# group['id'] = optval + # group['id'] = optval deprecated("option --group deprecated -- use .lst files instead") elif opt == "--package": -# package['name'] = optval + # package['name'] = optval deprecated("option --package deprecated -- use .lst files instead") elif opt == "--type": package['type'] = optval elif opt == '-s' or opt == "--save": if not optval: usage() - print 'parsed save option',optval + print 'parsed save option', optval save = optval elif opt == '-h' or opt == "--help": usage() @@ -152,11 +153,11 @@ def main(): # --save if save: # create directory if needed - # so that plc.d/{api,postgres} can create configs/site.xml - dirname = os.path.dirname (save) - if (not os.path.exists (dirname)): - os.makedirs(dirname,0755) - if (not os.path.exists (dirname)): + # so that plc.d/{api,postgres} can create configs/site.xml + dirname = os.path.dirname(save) + if (not os.path.exists(dirname)): + os.makedirs(dirname, 0755) + if (not os.path.exists(dirname)): print "Cannot create dir %s - exiting" % dirname sys.exit(1) config.save(save) diff --git a/config/sfa-config-tty b/config/sfa-config-tty index 4de43ee2..5377d46a 100755 --- a/config/sfa-config-tty +++ b/config/sfa-config-tty @@ -13,67 +13,71 @@ from optparse import OptionParser from sfa.util.version import version_tag from sfa.util.config import Config + def validator(validated_variables): pass # maint_user = validated_variables["PLC_API_MAINTENANCE_USER"] # root_user = validated_variables["PLC_ROOT_USER"] # if maint_user == root_user: -# errStr="PLC_API_MAINTENANCE_USER=%s cannot be the same as PLC_ROOT_USER=%s"%(maint_user,root_user) +# errStr="PLC_API_MAINTENANCE_USER=%s cannot be the same as PLC_ROOT_USER=%s"%(maint_user, root_user) # raise plc_config.ConfigurationException(errStr) usual_variables = [ "SFA_GENERIC_FLAVOUR", "SFA_INTERFACE_HRN", "SFA_REGISTRY_ROOT_AUTH", - "SFA_REGISTRY_HOST", + "SFA_REGISTRY_HOST", "SFA_AGGREGATE_HOST", "SFA_SM_HOST", "SFA_DB_HOST", - ] - -flavour_xml_section_hash = { \ - 'pl':'sfa_plc', - 'openstack':'sfa_nova', - 'fd':'sfa_federica', - 'nitos':'sfa_nitos', - 'dummy':'sfa_dummy', - } -configuration={ \ - 'name':'sfa', - 'service':"sfa", - 'usual_variables':usual_variables, - 'config_dir':"/etc/sfa", - 'validate_variables':{}, - 'validator':validator, - } +] + +flavour_xml_section_hash = { + 'pl': 'sfa_plc', + 'openstack': 'sfa_nova', + 'fd': 'sfa_federica', + 'nitos': 'sfa_nitos', + 'dummy': 'sfa_dummy', +} +configuration = { + 'name': 'sfa', + 'service': "sfa", + 'usual_variables': usual_variables, + 'config_dir': "/etc/sfa", + 'validate_variables': {}, + 'validator': validator, +} # GLOBAL VARIABLES # -g_configuration=None -usual_variables=None -config_dir=None -service=None +g_configuration = None +usual_variables = None +config_dir = None +service = None + def noop_validator(validated_variables): pass # historically we could also configure the devel pkg.... -def init_configuration (): + + +def init_configuration(): global g_configuration global usual_variables, config_dir, service - usual_variables=g_configuration["usual_variables"] - config_dir=g_configuration["config_dir"] - service=g_configuration["service"] + usual_variables = g_configuration["usual_variables"] + config_dir = g_configuration["config_dir"] + service = g_configuration["service"] global def_default_config, def_site_config, def_consolidated_config - def_default_config= "%s/default_config.xml" % config_dir + def_default_config = "%s/default_config.xml" % config_dir def_site_config = "%s/configs/site_config" % config_dir def_consolidated_config = "%s/%s_config" % (config_dir, service) global mainloop_usage - mainloop_usage= """Available commands: + mainloop_usage = """Available commands: Uppercase versions give variables comments, when available u/U\t\t\tEdit usual variables w\t\t\tWrite @@ -92,17 +96,18 @@ def init_configuration (): Typical usage involves: u, [l,] w, r, q """ % globals() -def usage (): - command_usage="%prog [options] [default-xml [site-xml [consolidated-xml]]]" - init_configuration () - command_usage +=""" + +def usage(): + command_usage = "%prog [options] [default-xml [site-xml [consolidated-xml]]]" + init_configuration() + command_usage += """ \t default-xml defaults to %s \t site-xml defaults to %s -\t consolidated-xml defaults to %s""" % (def_default_config,def_site_config, def_consolidated_config) +\t consolidated-xml defaults to %s""" % (def_default_config, def_site_config, def_consolidated_config) return command_usage #################### -variable_usage= """Edit Commands : +variable_usage = """Edit Commands : #\tShow variable comments .\tStops prompting, return to mainloop /\tCleans any site-defined value, reverts to default @@ -112,149 +117,171 @@ variable_usage= """Edit Commands : """ #################### -def get_value (config, category_id, variable_id): - value = config.get (category_id, variable_id) + + +def get_value(config, category_id, variable_id): + value = config.get(category_id, variable_id) return value -def get_type (config, category_id, variable_id): - value = config.get (category_id, variable_id) - #return variable['type'] + +def get_type(config, category_id, variable_id): + value = config.get(category_id, variable_id) + # return variable['type'] return str -def get_current_value (cread, cwrite, category_id, variable_id): + +def get_current_value(cread, cwrite, category_id, variable_id): # the value stored in cwrite, if present, is the one we want try: - result=get_value (cwrite,category_id,variable_id) + result = get_value(cwrite, category_id, variable_id) except: - result=get_value (cread,category_id,variable_id) + result = get_value(cread, category_id, variable_id) return result # refrain from using plc_config's _sanitize -def get_varname (config, category_id, variable_id): - varname = category_id +"_"+ variable_id + + +def get_varname(config, category_id, variable_id): + varname = category_id + "_" + variable_id config.locate_varname(varname) return varname # could not avoid using _sanitize here.. -def get_name_comments (config, cid, vid): + + +def get_name_comments(config, cid, vid): try: - (category, variable) = config.get (cid, vid) - (id, name, value, comments) = config._sanitize_variable (cid,variable) - return (name,comments) + (category, variable) = config.get(cid, vid) + (id, name, value, comments) = config._sanitize_variable(cid, variable) + return (name, comments) except: - return (None,[]) + return (None, []) + -def print_name_comments (config, cid, vid): - (name,comments)=get_name_comments(config,cid,vid) +def print_name_comments(config, cid, vid): + name, comments = get_name_comments(config, cid, vid) if name: print "### %s" % name if comments: for line in comments: print "# %s" % line else: - print "!!! No comment associated to %s_%s" % (cid,vid) + print "!!! No comment associated to %s_%s" % (cid, vid) #################### -def list_categories (config): - result=[] + + +def list_categories(config): + result = [] for section in config.sections(): result += [section] return result -def print_categories (config): + +def print_categories(config): print "Known categories" for cid in list_categories(config): print "%s" % (cid.upper()) #################### -def list_category (config, cid): - result=[] + + +def list_category(config, cid): + result = [] for section in config.sections(): if section == cid.lower(): - for (name,value) in config.items(section): - result += ["%s_%s" %(cid,name)] + for (name, value) in config.items(section): + result += ["%s_%s" % (cid, name)] return result -def print_category (config, cid, show_comments=True): - cid=cid.lower() - CID=cid.upper() - vids=list_category(config,cid) + +def print_category(config, cid, show_comments=True): + cid = cid.lower() + CID = cid.upper() + vids = list_category(config, cid) if (len(vids) == 0): - print "%s : no such category"%CID + print "%s : no such category" % CID else: - print "Category %s contains" %(CID) + print "Category %s contains" % (CID) for vid in vids: print vid.upper() #################### -def consolidate (default_config, site_config, consolidated_config): + + +def consolidate(default_config, site_config, consolidated_config): global service try: conso = Config(default_config) - conso.load (site_config) - conso.save (consolidated_config) + conso.load(site_config) + conso.save(consolidated_config) except Exception, inst: print "Could not consolidate, %s" % (str(inst)) return - print ("Merged\n\t%s\nand\t%s\ninto\t%s"%(default_config,site_config, - consolidated_config)) + print("Merged\n\t%s\nand\t%s\ninto\t%s" % (default_config, site_config, + consolidated_config)) -def reload_service (): + +def reload_service(): global service os.system("set -x ; service %s reload" % service) #################### -def restart_service (): + + +def restart_service(): global service - print ("==================== Stopping %s" % service) + print("==================== Stopping %s" % service) os.system("service %s stop" % service) - print ("==================== Starting %s" % service) + print("==================== Starting %s" % service) os.system("service %s start" % service) #################### -def prompt_variable (cdef, cread, cwrite, category, variable, - show_comments, support_next=False): +def prompt_variable(cdef, cread, cwrite, category, variable, + show_comments, support_next=False): + category_id = category variable_id = variable while True: - default_value = get_value(cdef,category_id,variable_id) - variable_type = get_type(cdef,category_id,variable_id) - current_value = get_current_value(cread,cwrite,category_id, variable_id) - varname = get_varname (cread,category_id, variable_id) - - if show_comments : - print_name_comments (cdef, category_id, variable_id) - prompt = "== %s : [%s] " % (varname,current_value) + default_value = get_value(cdef, category_id, variable_id) + variable_type = get_type(cdef, category_id, variable_id) + current_value = get_current_value( + cread, cwrite, category_id, variable_id) + varname = get_varname(cread, category_id, variable_id) + + if show_comments: + print_name_comments(cdef, category_id, variable_id) + prompt = "== %s : [%s] " % (varname, current_value) try: answer = raw_input(prompt).strip() - except EOFError : - raise Exception ('BailOut') + except EOFError: + raise Exception('BailOut') except KeyboardInterrupt: print "\n" - raise Exception ('BailOut') + raise Exception('BailOut') # no change if (answer == "") or (answer == current_value): return None elif (answer == "."): - raise Exception ('BailOut') + raise Exception('BailOut') elif (answer == "#"): - print_name_comments(cread,category_id,variable_id) + print_name_comments(cread, category_id, variable_id) elif (answer == "?"): print variable_usage.strip() elif (answer == "="): - print ("%s defaults to %s" %(varname,default_value)) + print("%s defaults to %s" % (varname, default_value)) # revert to default : remove from cwrite (i.e. site-config) elif (answer == "/"): - cwrite.delete(category_id,variable_id) - print ("%s reverted to %s" %(varname,default_value)) + cwrite.delete(category_id, variable_id) + print("%s reverted to %s" % (varname, default_value)) return elif (answer == ">"): if support_next: - raise Exception ('NextCategory') + raise Exception('NextCategory') else: print "No support for next category" else: @@ -264,83 +291,97 @@ def prompt_variable (cdef, cread, cwrite, category, variable, else: print "Not a valid value" -def prompt_variables_all (cdef, cread, cwrite, show_comments): + +def prompt_variables_all(cdef, cread, cwrite, show_comments): try: for (category_id, (category, variables)) in cread.variables().iteritems(): - print ("========== Category = %s" % category_id.upper()) + print("========== Category = %s" % category_id.upper()) for variable in variables.values(): try: - newvar = prompt_variable (cdef, cread, cwrite, category, variable, - show_comments, True) + newvar = prompt_variable(cdef, cread, cwrite, category, variable, + show_comments, True) except Exception, inst: - if (str(inst) == 'NextCategory'): break - else: raise + if (str(inst) == 'NextCategory'): + break + else: + raise except Exception, inst: - if (str(inst) == 'BailOut'): return - else: raise + if (str(inst) == 'BailOut'): + return + else: + raise -def prompt_variables_category (cdef, cread, cwrite, cid, show_comments): - cid=cid.lower() - CID=cid.upper() + +def prompt_variables_category(cdef, cread, cwrite, cid, show_comments): + cid = cid.lower() + CID = cid.upper() try: - print ("========== Category = %s" % CID) - for vid in list_category(cdef,cid): - (category,variable) = cdef.locate_varname(vid.upper()) - newvar = prompt_variable (cdef, cread, cwrite, category, variable, - show_comments, False) + print("========== Category = %s" % CID) + for vid in list_category(cdef, cid): + (category, variable) = cdef.locate_varname(vid.upper()) + newvar = prompt_variable(cdef, cread, cwrite, category, variable, + show_comments, False) except Exception, inst: - if (str(inst) == 'BailOut'): return - else: raise + if (str(inst) == 'BailOut'): + return + else: + raise #################### -def show_variable (cdef, cread, cwrite, - category, variable,show_value,show_comments): + + +def show_variable(cdef, cread, cwrite, + category, variable, show_value, show_comments): assert category.has_key('id') assert variable.has_key('id') - category_id = category ['id'] + category_id = category['id'] variable_id = variable['id'] - default_value = get_value(cdef,category_id,variable_id) - current_value = get_current_value(cread,cwrite,category_id,variable_id) - varname = get_varname (cread,category_id, variable_id) - if show_comments : - print_name_comments (cdef, category_id, variable_id) + default_value = get_value(cdef, category_id, variable_id) + current_value = get_current_value(cread, cwrite, category_id, variable_id) + varname = get_varname(cread, category_id, variable_id) + if show_comments: + print_name_comments(cdef, category_id, variable_id) if show_value: - print "%s = %s" % (varname,current_value) + print "%s = %s" % (varname, current_value) else: print "%s" % (varname) -def show_variables_all (cdef, cread, cwrite, show_value, show_comments): + +def show_variables_all(cdef, cread, cwrite, show_value, show_comments): for (category_id, (category, variables)) in cread.variables().iteritems(): - print ("========== Category = %s" % category_id.upper()) + print("========== Category = %s" % category_id.upper()) for variable in variables.values(): - show_variable (cdef, cread, cwrite, - category, variable,show_value,show_comments) - -def show_variables_category (cdef, cread, cwrite, cid, show_value,show_comments): - cid=cid.lower() - CID=cid.upper() - print ("========== Category = %s" % CID) - for vid in list_category(cdef,cid): - (category,variable) = cdef.locate_varname(vid.upper()) - show_variable (cdef, cread, cwrite, category, variable, - show_value,show_comments) + show_variable(cdef, cread, cwrite, + category, variable, show_value, show_comments) + + +def show_variables_category(cdef, cread, cwrite, cid, show_value, show_comments): + cid = cid.lower() + CID = cid.upper() + print("========== Category = %s" % CID) + for vid in list_category(cdef, cid): + (category, variable) = cdef.locate_varname(vid.upper()) + show_variable(cdef, cread, cwrite, category, variable, + show_value, show_comments) #################### -re_mainloop_0arg="^(?P[uUwrRqlLsSeEcvVhH\?])[ \t]*$" -re_mainloop_1arg="^(?P[sSeEvV])[ \t]+(?P\w+)$" -matcher_mainloop_0arg=re.compile(re_mainloop_0arg) -matcher_mainloop_1arg=re.compile(re_mainloop_1arg) +re_mainloop_0arg = "^(?P[uUwrRqlLsSeEcvVhH\?])[ \t]*$" +re_mainloop_1arg = "^(?P[sSeEvV])[ \t]+(?P\w+)$" +matcher_mainloop_0arg = re.compile(re_mainloop_0arg) +matcher_mainloop_1arg = re.compile(re_mainloop_1arg) -def mainloop (cdef, cread, cwrite, default_config, site_config, consolidated_config): + +def mainloop(cdef, cread, cwrite, default_config, site_config, consolidated_config): global service while True: try: - answer = raw_input("Enter command (u for usual changes, w to save, ? for help) ").strip() + answer = raw_input( + "Enter command (u for usual changes, w to save, ? for help) ").strip() except EOFError: - answer ="" + answer = "" except KeyboardInterrupt: print "\nBye" sys.exit() @@ -349,36 +390,36 @@ def mainloop (cdef, cread, cwrite, default_config, site_config, consolidated_con print mainloop_usage continue groups_parse = matcher_mainloop_0arg.match(answer) - command=None + command = None if (groups_parse): command = groups_parse.group('command') - arg=None + arg = None else: groups_parse = matcher_mainloop_1arg.match(answer) if (groups_parse): command = groups_parse.group('command') - arg=groups_parse.group('arg') + arg = groups_parse.group('arg') if not command: - print ("Unknown command >%s< -- use h for help" % answer) + print("Unknown command >%s< -- use h for help" % answer) continue - show_comments=command.isupper() + show_comments = command.isupper() - mode='ALL' + mode = 'ALL' if arg: - mode=None - arg=arg.lower() - variables=list_category (cdef,arg) + mode = None + arg = arg.lower() + variables = list_category(cdef, arg) if len(variables): # category_id as the category name # variables as the list of variable names - mode='CATEGORY' - category_id=arg - arg=arg.upper() - (category,variable)=cdef.locate_varname(arg) + mode = 'CATEGORY' + category_id = arg + arg = arg.upper() + (category, variable) = cdef.locate_varname(arg) if variable: # category/variable as output by locate_varname - mode='VARIABLE' + mode = 'VARIABLE' if not mode: print "%s: no such category or variable" % arg continue @@ -389,38 +430,42 @@ def mainloop (cdef, cread, cwrite, default_config, site_config, consolidated_con elif command == "w": try: # Confirm that various constraints are met before saving file. - validate_variables = g_configuration.get('validate_variables',{}) - validated_variables = cwrite.verify(cdef, cread, validate_variables) - validator = g_configuration.get('validator',noop_validator) + validate_variables = g_configuration.get( + 'validate_variables', {}) + validated_variables = cwrite.verify( + cdef, cread, validate_variables) + validator = g_configuration.get('validator', noop_validator) validator(validated_variables) cwrite.save(site_config) except: print "Save failed due to a configuration exception:" print traceback.print_exc() - print ("Could not save -- fix write access on %s" % site_config) + print("Could not save -- fix write access on %s" % site_config) break - print ("Wrote %s" % site_config) + print("Wrote %s" % site_config) consolidate(default_config, site_config, consolidated_config) - print ("You might want to type 'r' (restart %s), 'R' (reload %s) or 'q' (quit)" % \ - (service,service)) + print("You might want to type 'r' (restart %s), 'R' (reload %s) or 'q' (quit)" % + (service, service)) elif command in "uU": global usual_variables global flavour_xml_section_hash try: for varname in usual_variables: - (category,variable) = cdef.locate_varname(varname) + (category, variable) = cdef.locate_varname(varname) if not (category is None and variable is None): - prompt_variable(cdef, cread, cwrite, category, variable, False) - - # set the driver variable according to the already set flavour + prompt_variable(cdef, cread, cwrite, + category, variable, False) + + # set the driver variable according to the already set flavour generic_flavour = cwrite.items('sfa')[0][1] for section in cdef.sections(): - if generic_flavour in flavour_xml_section_hash and flavour_xml_section_hash[generic_flavour] == section: - for item in cdef.items(section): - category = section - variable = item[0] - prompt_variable(cdef, cread, cwrite, category, variable, False) - break + if generic_flavour in flavour_xml_section_hash and flavour_xml_section_hash[generic_flavour] == section: + for item in cdef.items(section): + category = section + variable = item[0] + prompt_variable(cdef, cread, cwrite, + category, variable, False) + break except Exception, inst: if (str(inst) != 'BailOut'): @@ -433,71 +478,80 @@ def mainloop (cdef, cread, cwrite, default_config, site_config, consolidated_con print_categories(cread) elif command in "eE": if mode == 'ALL': - prompt_variables_all(cdef, cread, cwrite,show_comments) + prompt_variables_all(cdef, cread, cwrite, show_comments) elif mode == 'CATEGORY': - prompt_variables_category(cdef,cread,cwrite,category_id,show_comments) + prompt_variables_category( + cdef, cread, cwrite, category_id, show_comments) elif mode == 'VARIABLE': try: - prompt_variable (cdef,cread,cwrite,category,variable, - show_comments,False) + prompt_variable(cdef, cread, cwrite, category, variable, + show_comments, False) except Exception, inst: if str(inst) != 'BailOut': raise elif command in "vVsSlL": - show_value=(command in "sSlL") - (c1,c2,c3) = (cdef, cread, cwrite) + show_value = (command in "sSlL") + (c1, c2, c3) = (cdef, cread, cwrite) if command in "lL": - (c1,c2,c3) = (cwrite,cwrite,cwrite) + (c1, c2, c3) = (cwrite, cwrite, cwrite) if mode == 'ALL': - show_variables_all(c1,c2,c3,show_value,show_comments) + show_variables_all(c1, c2, c3, show_value, show_comments) elif mode == 'CATEGORY': - show_variables_category(c1,c2,c3,category_id,show_value,show_comments) + show_variables_category( + c1, c2, c3, category_id, show_value, show_comments) elif mode == 'VARIABLE': - show_variable (c1,c2,c3,category,variable,show_value,show_comments) + show_variable(c1, c2, c3, category, variable, + show_value, show_comments) else: - print ("Unknown command >%s< -- use h for help" % answer) + print("Unknown command >%s< -- use h for help" % answer) #################### # creates directory for file if not yet existing -def check_dir (config_file): - dirname = os.path.dirname (config_file) - if (not os.path.exists (dirname)): +def check_dir(config_file): + dirname = os.path.dirname(config_file) + if (not os.path.exists(dirname)): try: - os.makedirs(dirname,0755) + os.makedirs(dirname, 0755) except OSError, e: - print "Cannot create dir %s due to %s - exiting" % (dirname,e) + print "Cannot create dir %s due to %s - exiting" % (dirname, e) sys.exit(1) - if (not os.path.exists (dirname)): + if (not os.path.exists(dirname)): print "Cannot create dir %s - exiting" % dirname sys.exit(1) else: print "Created directory %s" % dirname #################### + + def optParserSetup(configuration): - parser = OptionParser(usage=usage(), version="%prog " + version_tag ) + parser = OptionParser(usage=usage(), version="%prog " + version_tag) parser.set_defaults(config_dir=configuration['config_dir'], service=configuration['service'], usual_variables=configuration['usual_variables']) - parser.add_option("","--configdir",dest="config_dir",help="specify configuration directory") - parser.add_option("","--service",dest="service",help="specify /etc/init.d style service name") - parser.add_option("","--usual_variable",dest="usual_variables",action="append", help="add a usual variable") + parser.add_option("", "--configdir", dest="config_dir", + help="specify configuration directory") + parser.add_option("", "--service", dest="service", + help="specify /etc/init.d style service name") + parser.add_option("", "--usual_variable", dest="usual_variables", + action="append", help="add a usual variable") return parser -def main(command,argv,configuration): + +def main(command, argv, configuration): global g_configuration - g_configuration=configuration + g_configuration = configuration parser = optParserSetup(configuration) - (config,args) = parser.parse_args() - if len(args)>3: + (config, args) = parser.parse_args() + if len(args) > 3: parser.error("too many arguments") - configuration['service']=config.service - configuration['usual_variables']=config.usual_variables - configuration['config_dir']=config.config_dir + configuration['service'] = config.service + configuration['usual_variables'] = config.usual_variables + configuration['config_dir'] = config.config_dir # add in new usual_variables defined on the command line for usual_variable in config.usual_variables: if usual_variable not in configuration['usual_variables']: @@ -506,16 +560,17 @@ def main(command,argv,configuration): # intialize configuration init_configuration() - (default_config,site_config,consolidated_config) = (def_default_config, def_site_config, def_consolidated_config) + default_config, site_config, consolidated_config = \ + def_default_config, def_site_config, def_consolidated_config if len(args) >= 1: - default_config=args[0] + default_config = args[0] if len(args) >= 2: - site_config=args[1] + site_config = args[1] if len(args) == 3: - consolidated_config=args[2] + consolidated_config = args[2] - for c in (default_config,site_config,consolidated_config): - check_dir (c) + for c in (default_config, site_config, consolidated_config): + check_dir(c) try: # the default settings only - read only @@ -525,12 +580,13 @@ def main(command,argv,configuration): cread = Config(default_config) except: print traceback.print_exc() - print ("default config files %s not found, is myplc installed ?" % default_config) + print("default config files %s not found, is myplc installed ?" % + default_config) return 1 # local settings only, will be modified & saved config_filename = "%s/sfa_config" % config.config_dir - cwrite=Config(config_filename) + cwrite = Config(config_filename) try: cread.load(site_config) cwrite.load(default_config) @@ -538,10 +594,11 @@ def main(command,argv,configuration): except: cwrite = Config() - mainloop (cdef, cread, cwrite, default_config, site_config, consolidated_config) - return 0 + mainloop(cdef, cread, cwrite, default_config, + site_config, consolidated_config) + return 0 if __name__ == '__main__': - command=sys.argv[0] + command = sys.argv[0] argv = sys.argv[1:] - main(command,argv,configuration) + main(command, argv, configuration) diff --git a/flashpolicy/sfa_flashpolicy.py b/flashpolicy/sfa_flashpolicy.py index f4d3f165..6d266c27 100644 --- a/flashpolicy/sfa_flashpolicy.py +++ b/flashpolicy/sfa_flashpolicy.py @@ -19,20 +19,26 @@ import contextlib VERSION = 0.1 + def daemon(): """Daemonize the current process.""" - if os.fork() != 0: os._exit(0) + if os.fork() != 0: + os._exit(0) os.setsid() - if os.fork() != 0: os._exit(0) + if os.fork() != 0: + os._exit(0) os.umask(0) devnull = os.open(os.devnull, os.O_RDWR) os.dup2(devnull, 0) - # xxx fixme - this is just to make sure that nothing gets stupidly lost - should use devnull + # xxx fixme - this is just to make sure that nothing gets stupidly lost - + # should use devnull crashlog = os.open('/var/log/sfa_flashpolicy.log', os.O_RDWR | os.O_APPEND | os.O_CREAT, 0644) os.dup2(crashlog, 1) os.dup2(crashlog, 2) + class policy_server(object): + def __init__(self, port, path): self.port = port self.path = path @@ -49,6 +55,7 @@ class policy_server(object): self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.sock.bind(('', port)) self.sock.listen(5) + def read_policy(self, path): with open(path, 'rb') as f: policy = f.read(10001) @@ -59,42 +66,46 @@ class policy_server(object): raise exceptions.RuntimeError('Not a valid policy file', path) return policy + def run(self): try: while True: thread.start_new_thread(self.handle, self.sock.accept()) except socket.error as e: - self.log('Error accepting connection: %s' % (e[1],)) + self.log('Error accepting connection: %s' % e[1]) + def handle(self, conn, addr): - addrstr = '%s:%s' % (addr[0],addr[1]) + addrstr = '%s:%s' % (addr[0], addr[1]) try: - self.log('Connection from %s' % (addrstr,)) + self.log('Connection from %s' % addrstr) with contextlib.closing(conn): # It's possible that we won't get the entire request in # a single recv, but very unlikely. request = conn.recv(1024).strip() - #if request != '\0': + # if request != '\0': # self.log('Unrecognized request from %s: %s' % (addrstr, request)) # return - self.log('Valid request received from %s' % (addrstr,)) + self.log('Valid request received from %s' % addrstr) conn.sendall(self.policy) - self.log('Sent policy file to %s' % (addrstr,)) + self.log('Sent policy file to %s' % addrstr) except socket.error as e: self.log('Error handling connection from %s: %s' % (addrstr, e[1])) except Exception as e: self.log('Error handling connection from %s: %s' % (addrstr, e[1])) + def log(self, str): print >>sys.stderr, str + def main(): - parser = optparse.OptionParser(usage = '%prog [--port=PORT] --file=FILE', + parser = optparse.OptionParser(usage='%prog [--port=PORT] --file=FILE', version='%prog ' + str(VERSION)) parser.add_option('-p', '--port', dest='port', type=int, default=843, help='listen on port PORT', metavar='PORT') parser.add_option('-f', '--file', dest='path', help='server policy file FILE', metavar='FILE') parser.add_option("-d", "--daemon", dest="daemon", action="store_true", - help="Run as daemon.", default=False) + help="Run as daemon.", default=False) opts, args = parser.parse_args() if args: parser.error('No arguments are needed. See help.') diff --git a/keyconvert/keyconvert.py b/keyconvert/keyconvert.py index 5a239d59..143a8c86 100755 --- a/keyconvert/keyconvert.py +++ b/keyconvert/keyconvert.py @@ -13,7 +13,7 @@ class RSA_pub_fix(RSA.RSA_pub): return self.save_pub_key_bio(bio) def rsa_new_pub_key(couple): - (e,n)=couple + (e, n)=couple rsa = m2.rsa_new() m2.rsa_set_e(rsa, e) m2.rsa_set_n(rsa, n) diff --git a/setup.py b/setup.py index 45f58f8e..ab111bf8 100755 --- a/setup.py +++ b/setup.py @@ -103,25 +103,26 @@ for d in processor_subdirs: if sys.argv[1] in ['uninstall', 'remove', 'delete', 'clean']: python_path = sys.path - site_packages_path = [ os.path.join(p,'sfa') for p in python_path if p.endswith('site-packages')] - site_packages_path += [ os.path.join(p,'sfatables') for p in python_path if p.endswith('site-packages')] + site_packages_path = [ os.path.join(p, 'sfa') for p in python_path if p.endswith('site-packages')] + site_packages_path += [ os.path.join(p, 'sfatables') for p in python_path if p.endswith('site-packages')] remove_dirs = ['/etc/sfa/', '/etc/sfatables'] + site_packages_path remove_bins = [ '/usr/bin/' + os.path.basename(bin) for bin in scripts ] remove_files = remove_bins + [ "/etc/init.d/%s"%x for x in initscripts ] # remove files - def feedback (file, msg): print ("removing", file, "...",msg) + def feedback (file, msg): + print ("removing", file, "...", msg) for filepath in remove_files: try: os.remove(filepath) - feedback(filepath,"success") + feedback(filepath, "success") except: - feedback(filepath,"failed") + feedback(filepath, "failed") # remove directories for directory in remove_dirs: try: shutil.rmtree(directory) - feedback (directory,"success") + feedback (directory, "success") except: feedback (directory, "failed") else: @@ -142,7 +143,7 @@ else: packages = packages, data_files = data_files, version = version_tag, - keywords = ['federation','testbeds','SFA','SfaWrap'], + keywords = ['federation', 'testbeds', 'SFA', 'SfaWrap'], url = "http://svn.planet-lab.org/wiki/SFATutorial", author = "Thierry Parmentelat, Tony Mack, Scott Baker", author_email = "thierry.parmentelat@inria.fr, tmack@princeton.cs.edu, smbaker@gmail.com", diff --git a/sfa/client/candidates.py b/sfa/client/candidates.py index 830e24d8..c36e1991 100644 --- a/sfa/client/candidates.py +++ b/sfa/client/candidates.py @@ -1,52 +1,64 @@ from __future__ import print_function -### utility to match command-line args to names +# utility to match command-line args to names + + class Candidates: - def __init__ (self, names): - self.names=names + + def __init__(self, names): + self.names = names # is an input string acceptable for one of the known names? + @staticmethod - def fits (input, name): - return name.find(input)==0 + def fits(input, name): + return name.find(input) == 0 # returns one of the names if the input name has a unique match # or None otherwise - def only_match (self, input): - if input in self.names: return input - matches=[ name for name in self.names if Candidates.fits(input,name) ] - if len(matches)==1: return matches[0] - else: return None - -#################### minimal test -candidates_specs=[ -('create delete reset resources slices start status stop version create_gid', - [ ('ver','version'), - ('r',None), - ('re',None), - ('res',None), - ('rese','reset'), - ('reset','reset'), - ('reso','resources'), - ('sli','slices'), - ('st',None), - ('sta',None), - ('stop','stop'), - ('a',None), - ('cre',None), - ('create','create'), - ('create_','create_gid'), - ('create_g','create_gid'), - ('create_gi','create_gid'), - ('create_gid','create_gid'), -]) + + def only_match(self, input): + if input in self.names: + return input + matches = [name for name in self.names if Candidates.fits(input, name)] + if len(matches) == 1: + return matches[0] + else: + return None + +# minimal test +candidates_specs = [ + ('create delete reset resources slices start status stop version create_gid', + [('ver', 'version'), + ('r', None), + ('re', None), + ('res', None), + ('rese', 'reset'), + ('reset', 'reset'), + ('reso', 'resources'), + ('sli', 'slices'), + ('st', None), + ('sta', None), + ('stop', 'stop'), + ('a', None), + ('cre', None), + ('create', 'create'), + ('create_', 'create_gid'), + ('create_g', 'create_gid'), + ('create_gi', 'create_gid'), + ('create_gid', 'create_gid'), + ]) ] -def test_candidates (): + +def test_candidates(): for (names, tuples) in candidates_specs: - names=names.split() - for (input,expected) in tuples: - got=Candidates(names).only_match(input) - if got==expected: print('.', end=' ') - else: print('X FAIL','names[',names,'] input',input,'expected',expected,'got',got) + names = names.split() + for (input, expected) in tuples: + got = Candidates(names).only_match(input) + if got == expected: + print('.', end=' ') + else: + print('X FAIL', 'names[', names, '] input', + input, 'expected', expected, 'got', got) if __name__ == '__main__': test_candidates() diff --git a/sfa/client/client_helper.py b/sfa/client/client_helper.py index 6e917060..fd15c0fe 100644 --- a/sfa/client/client_helper.py +++ b/sfa/client/client_helper.py @@ -14,21 +14,24 @@ # the fact e.g. that PlanetLab insists on getting a first_name and last_name is not # exactly consistent with the GENI spec. of CreateSliver # + + def pg_users_arg(records): - users = [] + users = [] for record in records: - if record['type'] != 'user': + if record['type'] != 'user': continue user = {'urn': record['reg-urn'], 'keys': record['reg-keys'], 'email': record['email']} users.append(user) - return users + return users + -def sfa_users_arg (records, slice_record): +def sfa_users_arg(records, slice_record): users = [] for record in records: - if record['type'] != 'user': + if record['type'] != 'user': continue user = {'urn': record['reg-urn'], 'keys': record['reg-keys'], @@ -36,22 +39,25 @@ def sfa_users_arg (records, slice_record): } # fill as much stuff as possible from planetlab or similar # note that reg-email is not yet available - pl_fields = ['email', 'person_id', 'first_name', 'last_name', 'key_ids'] - nitos_fields = [ 'email', 'user_id' ] - extra_fields = list ( set(pl_fields).union(set(nitos_fields))) + pl_fields = ['email', 'person_id', + 'first_name', 'last_name', 'key_ids'] + nitos_fields = ['email', 'user_id'] + extra_fields = list(set(pl_fields).union(set(nitos_fields))) # try to fill all these in for field in extra_fields: - if field in record: user[field]=record[field] + if field in record: + user[field] = record[field] users.append(user) return users + def sfa_to_pg_users_arg(users): new_users = [] fields = ['urn', 'keys'] for user in users: - new_user = dict([item for item in user.items() \ - if item[0] in fields]) + new_user = dict([item for item in user.items() + if item[0] in fields]) new_users.append(new_user) - return new_users + return new_users diff --git a/sfa/client/common.py b/sfa/client/common.py index 16a06940..f27098db 100644 --- a/sfa/client/common.py +++ b/sfa/client/common.py @@ -2,105 +2,128 @@ from __future__ import print_function + def optparse_listvalue_callback(option, opt, value, parser): - former=getattr(parser.values,option.dest) - if not former: former=[] + former = getattr(parser.values, option.dest) + if not former: + former = [] # support for using e.g. sfi update -t slice -x the.slice.hrn -r none # instead of -r '' which is painful and does not pass well through ssh - if value.lower()=='none': - newvalue=former + if value.lower() == 'none': + newvalue = former else: - newvalue=former+value.split(',') + newvalue = former + value.split(',') setattr(parser.values, option.dest, newvalue) -def optparse_dictvalue_callback (option, option_string, value, parser): + +def optparse_dictvalue_callback(option, option_string, value, parser): try: - (k,v)=value.split('=',1) - d=getattr(parser.values, option.dest) - d[k]=v + (k, v) = value.split('=', 1) + d = getattr(parser.values, option.dest) + d[k] = v except: parser.print_help() sys.exit(1) -# a code fragment that could be helpful for argparse which unfortunately is +# a code fragment that could be helpful for argparse which unfortunately is # available with 2.7 only, so this feels like too strong a requirement for the client side -#class ExtraArgAction (argparse.Action): +# class ExtraArgAction (argparse.Action): # def __call__ (self, parser, namespace, values, option_string=None): # would need a try/except of course -# (k,v)=values.split('=') -# d=getattr(namespace,self.dest) -# d[k]=v +# (k, v) = values.split('=') +# d = getattr(namespace, self.dest) +# d[k] = v ##### -#parser.add_argument ("-X","--extra",dest='extras', default={}, action=ExtraArgAction, -# help="set extra flags, testbed dependent, e.g. --extra enabled=true") - +# parser.add_argument ("-X", "--extra", dest='extras', default={}, action=ExtraArgAction, +# help="set extra flags, testbed dependent, e.g. --extra enabled=true") + ############################## # these are not needed from the outside -def terminal_render_plural (how_many, name,names=None): - if not names: names="%ss"%name - if how_many<=0: return "No %s"%name - elif how_many==1: return "1 %s"%name - else: return "%d %s"%(how_many,names) -def terminal_render_default (record,options): + +def terminal_render_plural(how_many, name, names=None): + if not names: + names = "%ss" % name + if how_many <= 0: + return "No %s" % name + elif how_many == 1: + return "1 %s" % name + else: + return "%d %s" % (how_many, names) + + +def terminal_render_default(record, options): print("%s (%s)" % (record['hrn'], record['type'])) -def terminal_render_user (record, options): - print("%s (User)"%record['hrn'], end=' ') + + +def terminal_render_user(record, options): + print("%s (User)" % record['hrn'], end=' ') if options.verbose and record.get('email', None): print("email='{}'".format(record['email']), end=' ') if record.get('reg-pi-authorities', None): - print(" [PI at %s]"%(" and ".join(record['reg-pi-authorities'])), end=' ') + print(" [PI at %s]" % + (" and ".join(record['reg-pi-authorities'])), end=' ') if record.get('reg-slices', None): - print(" [IN slices %s]"%(" and ".join(record['reg-slices'])), end=' ') - user_keys=record.get('reg-keys',[]) + print(" [IN slices %s]" % + (" and ".join(record['reg-slices'])), end=' ') + user_keys = record.get('reg-keys', []) if not options.verbose: - print(" [has %s]"%(terminal_render_plural(len(user_keys),"key"))) + print(" [has %s]" % (terminal_render_plural(len(user_keys), "key"))) else: print("") - for key in user_keys: print(8*' ',key.strip("\n")) - -def terminal_render_slice (record, options): - print("%s (Slice)"%record['hrn'], end=' ') + for key in user_keys: + print(8 * ' ', key.strip("\n")) + + +def terminal_render_slice(record, options): + print("%s (Slice)" % record['hrn'], end=' ') if record.get('reg-researchers', None): - print(" [USERS %s]"%(" and ".join(record['reg-researchers'])), end=' ') + print(" [USERS %s]" % + (" and ".join(record['reg-researchers'])), end=' ') # print record.keys() print("") -def terminal_render_authority (record, options): - print("%s (Authority)"%record['hrn'], end=' ') + + +def terminal_render_authority(record, options): + print("%s (Authority)" % record['hrn'], end=' ') if options.verbose and record.get('name'): print("name='{}'".format(record['name'])) if record.get('reg-pis', None): - print(" [PIS %s]"%(" and ".join(record['reg-pis'])), end=' ') + print(" [PIS %s]" % (" and ".join(record['reg-pis'])), end=' ') print("") -def terminal_render_node (record, options): - print("%s (Node)"%record['hrn']) -### used in sfi list -def terminal_render (records, options): +def terminal_render_node(record, options): + print("%s (Node)" % record['hrn']) + + +# used in sfi list +def terminal_render(records, options): # sort records by type grouped_by_type = {} for record in records: type = record['type'] if type not in grouped_by_type: - grouped_by_type[type]=[] + grouped_by_type[type] = [] grouped_by_type[type].append(record) group_types = grouped_by_type.keys() group_types.sort() for type in group_types: group = grouped_by_type[type] # print 20 * '-', type - try: renderer = eval('terminal_render_' + type) - except: renderer = terminal_render_default + try: + renderer = eval('terminal_render_' + type) + except: + renderer = terminal_render_default for record in group: renderer(record, options) #################### + + def filter_records(type, records): filtered_records = [] for record in records: if (record['type'] == type) or (type == "all"): filtered_records.append(record) return filtered_records - - diff --git a/sfa/client/manifolduploader.py b/sfa/client/manifolduploader.py index 3b9de167..e377c8d7 100755 --- a/sfa/client/manifolduploader.py +++ b/sfa/client/manifolduploader.py @@ -16,7 +16,7 @@ # v1, that offers an AddCredential API call, towards a new API v2 that # manages credentials with the same set of Get/Update calls as other # objects -# +# # mostly this is intended to be used through 'sfi myslice' # so the defaults below are of no real importance @@ -27,64 +27,75 @@ DEFAULT_PLATFORM = 'ple' # starting with 2.7.9 we need to turn off server verification import ssl -try: turn_off_server_verify = { 'context' : ssl._create_unverified_context() } -except: turn_off_server_verify = {} +try: + turn_off_server_verify = {'context': ssl._create_unverified_context()} +except: + turn_off_server_verify = {} import getpass from sfa.util.py23 import xmlrpc_client + class ManifoldUploader: """A utility class for uploading delegated credentials to a manifold/MySlice infrastructure""" - # platform is a name internal to the manifold deployment, + # platform is a name internal to the manifold deployment, # that maps to a testbed, like e.g. 'ple' - def __init__ (self, logger, url=None, platform=None, username=None, password=None, ): - self._url=url - self._platform=platform - self._username=username - self._password=password - self.logger=logger - self._proxy=None - - def username (self): + def __init__(self, logger, url=None, platform=None, username=None, password=None, ): + self._url = url + self._platform = platform + self._username = username + self._password = password + self.logger = logger + self._proxy = None + + def username(self): if not self._username: - self._username=raw_input("Enter your manifold username: ") + self._username = raw_input("Enter your manifold username: ") return self._username - def password (self): + def password(self): if not self._password: - username=self.username() - self._password=getpass.getpass("Enter password for manifold user %s: "%username) + username = self.username() + self._password = getpass.getpass( + "Enter password for manifold user %s: " % username) return self._password - def platform (self): + def platform(self): if not self._platform: - self._platform=raw_input("Enter your manifold platform [%s]: "%DEFAULT_PLATFORM) - if self._platform.strip()=="": self._platform = DEFAULT_PLATFORM + self._platform = raw_input( + "Enter your manifold platform [%s]: " % DEFAULT_PLATFORM) + if self._platform.strip() == "": + self._platform = DEFAULT_PLATFORM return self._platform - def url (self): + def url(self): if not self._url: - self._url=raw_input("Enter the URL for your manifold API [%s]: "%DEFAULT_URL) - if self._url.strip()=="": self._url = DEFAULT_URL + self._url = raw_input( + "Enter the URL for your manifold API [%s]: " % DEFAULT_URL) + if self._url.strip() == "": + self._url = DEFAULT_URL return self._url def prompt_all(self): - self.username(); self.password(); self.platform(); self.url() + self.username() + self.password() + self.platform() + self.url() # looks like the current implementation of manifold server # won't be happy with several calls issued in the same session # so we do not cache this one - def proxy (self): -# if not self._proxy: -# url=self.url() -# self.logger.info("Connecting manifold url %s"%url) -# self._proxy = xmlrpc_client.ServerProxy(url, allow_none = True) -# return self._proxy - url=self.url() - self.logger.debug("Connecting manifold url %s"%url) - proxy = xmlrpc_client.ServerProxy(url, allow_none = True, + def proxy(self): + # if not self._proxy: + # url=self.url() + # self.logger.info("Connecting manifold url %s"%url) + # self._proxy = xmlrpc_client.ServerProxy(url, allow_none = True) + # return self._proxy + url = self.url() + self.logger.debug("Connecting manifold url %s" % url) + proxy = xmlrpc_client.ServerProxy(url, allow_none=True, **turn_off_server_verify) return proxy @@ -92,19 +103,21 @@ class ManifoldUploader: # does the job for one credential # expects the credential (string) and an optional message (e.g. hrn) for reporting # return True upon success and False otherwise - def upload (self, delegated_credential, message=None): - platform=self.platform() - username=self.username() - password=self.password() - auth = {'AuthMethod': 'password', 'Username': username, 'AuthString': password} - if not message: message="" + def upload(self, delegated_credential, message=None): + platform = self.platform() + username = self.username() + password = self.password() + auth = {'AuthMethod': 'password', + 'Username': username, 'AuthString': password} + if not message: + message = "" try: - manifold=self.proxy() + manifold = self.proxy() # the code for a V2 interface - query = { 'action': 'update', + query = {'action': 'update', 'object': 'local:account', - 'filters': [ ['platform', '=', platform] ] , + 'filters': [['platform', '=', platform]], 'params': {'credential': delegated_credential, }, } annotation = {'authentication': auth, } @@ -112,61 +125,71 @@ class ManifoldUploader: # but fill in error code and messages instead # however this is only theoretical so let's be on the safe side try: - self.logger.debug("Using new v2 method forward+annotation@%s %s"%(platform,message)) - retcod2=manifold.forward (query, annotation) + self.logger.debug( + "Using new v2 method forward+annotation@%s %s" % (platform, message)) + retcod2 = manifold.forward(query, annotation) except Exception as e: - # xxx we need a constant constant for UNKNOWN, how about using 1 - MANIFOLD_UNKNOWN=1 - retcod2={'code':MANIFOLD_UNKNOWN,'description':"%s"%e} - if retcod2['code']==0: - info="" - if message: info += message+" " + # xxx we need a constant constant for UNKNOWN, how about using + # 1 + MANIFOLD_UNKNOWN = 1 + retcod2 = {'code': MANIFOLD_UNKNOWN, 'description': "%s" % e} + if retcod2['code'] == 0: + info = "" + if message: + info += message + " " info += 'v2 upload OK' self.logger.info(info) return True # everything has failed, let's report - self.logger.error("Could not upload %s"%(message if message else "credential")) - self.logger.info(" V2 Update returned code %s and error >>%s<<"%(retcod2['code'],retcod2['description'])) + self.logger.error("Could not upload %s" % + (message if message else "credential")) + self.logger.info(" V2 Update returned code %s and error >>%s<<" % ( + retcod2['code'], retcod2['description'])) self.logger.debug("****** full retcod2") - for (k,v) in retcod2.items(): self.logger.debug("**** %s: %s"%(k,v)) + for k, v in retcod2.items(): + self.logger.debug("**** %s: %s" % (k, v)) return False except Exception as e: - if message: self.logger.error("Could not upload %s %s"%(message,e)) - else: self.logger.error("Could not upload credential %s"%e) + if message: + self.logger.error("Could not upload %s %s" % (message, e)) + else: + self.logger.error("Could not upload credential %s" % e) if self.logger.debugEnabled(): import traceback traceback.print_exc() return False -### this is mainly for unit testing this class but can come in handy as well -def main (): +# this is mainly for unit testing this class but can come in handy as well + + +def main(): from argparse import ArgumentParser - parser = ArgumentParser (description="manifoldupoader simple tester.") - parser.add_argument ('credential_files',metavar='FILE',type=str,nargs='+', - help="the filenames to upload") - parser.add_argument ('-u','--url',dest='url', action='store',default=None, - help='the URL of the manifold API') - parser.add_argument ('-p','--platform',dest='platform',action='store',default=None, - help='the manifold platform name') - parser.add_argument ('-U','--user',dest='username',action='store',default=None, - help='the manifold username') - parser.add_argument ('-P','--password',dest='password',action='store',default=None, - help='the manifold password') - parser.add_argument ('-v','--verbose',dest='verbose',action='count',default=0, - help='more and more verbose') - args = parser.parse_args () - + parser = ArgumentParser(description="manifoldupoader simple tester.") + parser.add_argument('credential_files', metavar='FILE', type=str, nargs='+', + help="the filenames to upload") + parser.add_argument('-u', '--url', dest='url', action='store', default=None, + help='the URL of the manifold API') + parser.add_argument('-p', '--platform', dest='platform', action='store', default=None, + help='the manifold platform name') + parser.add_argument('-U', '--user', dest='username', action='store', default=None, + help='the manifold username') + parser.add_argument('-P', '--password', dest='password', action='store', default=None, + help='the manifold password') + parser.add_argument('-v', '--verbose', dest='verbose', action='count', default=0, + help='more and more verbose') + args = parser.parse_args() + from sfa.util.sfalogging import sfi_logger sfi_logger.enable_console() sfi_logger.setLevelFromOptVerbose(args.verbose) - uploader = ManifoldUploader (url=args.url, platform=args.platform, - username=args.username, password=args.password, - logger=sfi_logger) + uploader = ManifoldUploader(url=args.url, platform=args.platform, + username=args.username, password=args.password, + logger=sfi_logger) for filename in args.credential_files: with open(filename) as f: - result=uploader.upload (f.read(),filename) - sfi_logger.info('... result=%s'%result) + result = uploader.upload(f.read(), filename) + sfi_logger.info('... result=%s' % result) if __name__ == '__main__': main() diff --git a/sfa/client/multiclient.py b/sfa/client/multiclient.py index 6bdf2b27..3939994b 100644 --- a/sfa/client/multiclient.py +++ b/sfa/client/multiclient.py @@ -6,6 +6,7 @@ import time from Queue import Queue from sfa.util.sfalogging import logger + def ThreadedMethod(callable, results, errors): """ A function decorator that returns a running thread. The thread @@ -13,20 +14,20 @@ def ThreadedMethod(callable, results, errors): results queue """ def wrapper(args, kwds): - class ThreadInstance(threading.Thread): + class ThreadInstance(threading.Thread): + def run(self): try: results.put(callable(*args, **kwds)) except Exception as e: logger.log_exc('MultiClient: Error in thread: ') errors.put(traceback.format_exc()) - + thread = ThreadInstance() thread.start() return thread return wrapper - class MultiClient: """ @@ -39,7 +40,7 @@ class MultiClient: self.errors = Queue() self.threads = [] - def run (self, method, *args, **kwds): + def run(self, method, *args, **kwds): """ Execute a callable in a separate thread. """ @@ -68,11 +69,11 @@ class MultiClient: results = [] if not lenient: errors = self.get_errors() - if errors: + if errors: raise Exception(errors[0]) while not self.results.empty(): - results.append(self.results.get()) + results.append(self.results.get()) return results def get_errors(self): @@ -90,24 +91,25 @@ class MultiClient: Get the value that should be returuned to the client. If there are errors then the first error is returned. If there are no errors, then the first result is returned """ - - + + if __name__ == '__main__': def f(name, n, sleep=1): nums = [] - for i in range(n, n+5): + for i in range(n, n + 5): print("%s: %s" % (name, i)) nums.append(i) time.sleep(sleep) return nums + def e(name, n, sleep=1): nums = [] - for i in range(n, n+3) + ['n', 'b']: + for i in range(n, n + 3) + ['n', 'b']: print("%s: 1 + %s:" % (name, i)) nums.append(i + 1) time.sleep(sleep) - return nums + return nums threads = MultiClient() threads.run(f, "Thread1", 10, 2) @@ -116,7 +118,6 @@ if __name__ == '__main__': #results = threads.get_results() #errors = threads.get_errors() - #print "Results:", results - #print "Errors:", errors + # print "Results:", results + # print "Errors:", errors results_xlenient = threads.get_results(lenient=False) - diff --git a/sfa/client/return_value.py b/sfa/client/return_value.py index c2c4f476..b69cd1c4 100644 --- a/sfa/client/return_value.py +++ b/sfa/client/return_value.py @@ -1,22 +1,21 @@ class ReturnValue(dict): - @staticmethod def get_code(return_value): - return ReturnValue.get_key_value('code', return_value) + return ReturnValue.get_key_value('code', return_value) @staticmethod def get_value(return_value): - return ReturnValue.get_key_value('value', return_value) + return ReturnValue.get_key_value('value', return_value) @staticmethod def get_output(return_value): - return ReturnValue.get_key_value('output', return_value) + return ReturnValue.get_key_value('output', return_value) @staticmethod def get_key_value(key, return_value): if isinstance(return_value, dict) and key in return_value: return return_value.get(key) else: - return return_value + return return_value diff --git a/sfa/client/sfaadmin.py b/sfa/client/sfaadmin.py index 49757b32..261218e1 100755 --- a/sfa/client/sfaadmin.py +++ b/sfa/client/sfaadmin.py @@ -9,7 +9,7 @@ from optparse import OptionParser from sfa.generic import Generic from sfa.util.xrn import Xrn -from sfa.storage.record import Record +from sfa.storage.record import Record from sfa.trust.hierarchy import Hierarchy from sfa.trust.gid import GID @@ -22,9 +22,10 @@ from sfa.client.sfi import save_records_to_file pprinter = PrettyPrinter(indent=4) try: - help_basedir=Hierarchy().basedir + help_basedir = Hierarchy().basedir except: - help_basedir='*unable to locate Hierarchy().basedir' + help_basedir = '*unable to locate Hierarchy().basedir' + def add_options(*args, **kwargs): def _decorator(func): @@ -32,7 +33,9 @@ def add_options(*args, **kwargs): return func return _decorator + class Commands(object): + def _get_commands(self): command_names = [] for attrib in dir(self): @@ -42,49 +45,51 @@ class Commands(object): class RegistryCommands(Commands): + def __init__(self, *args, **kwds): - self.api= Generic.the_flavour().make_api(interface='registry') - + self.api = Generic.the_flavour().make_api(interface='registry') + def version(self): - """Display the Registry version""" + """Display the Registry version""" version = self.api.manager.GetVersion(self.api, {}) pprinter.pprint(version) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='authority to list (hrn/urn - mandatory)') - @add_options('-t', '--type', dest='type', metavar='', help='object type', default='all') - @add_options('-r', '--recursive', dest='recursive', metavar='', help='list all child records', + @add_options('-x', '--xrn', dest='xrn', metavar='', help='authority to list (hrn/urn - mandatory)') + @add_options('-t', '--type', dest='type', metavar='', help='object type', default='all') + @add_options('-r', '--recursive', dest='recursive', metavar='', help='list all child records', action='store_true', default=False) @add_options('-v', '--verbose', dest='verbose', action='store_true', default=False) def list(self, xrn, type=None, recursive=False, verbose=False): """List names registered at a given authority - possibly filtered by type""" - xrn = Xrn(xrn, type) + xrn = Xrn(xrn, type) options_dict = {'recursive': recursive} - records = self.api.manager.List(self.api, xrn.get_hrn(), options=options_dict) + records = self.api.manager.List( + self.api, xrn.get_hrn(), options=options_dict) list = filter_records(type, records) # terminal_render expects an options object - class Options: pass - options=Options() - options.verbose=verbose - terminal_render (list, options) + class Options: + pass + options = Options() + options.verbose = verbose + terminal_render(list, options) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') - @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) - @add_options('-o', '--outfile', dest='outfile', metavar='', help='save record to file') - @add_options('-f', '--format', dest='format', metavar='', type='choice', - choices=('text', 'xml', 'simple'), help='display record in different formats') + @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') + @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) + @add_options('-o', '--outfile', dest='outfile', metavar='', help='save record to file') + @add_options('-f', '--format', dest='format', metavar='', type='choice', + choices=('text', 'xml', 'simple'), help='display record in different formats') def show(self, xrn, type=None, format=None, outfile=None): """Display details for a registered object""" records = self.api.manager.Resolve(self.api, xrn, type, details=True) for record in records: sfa_record = Record(dict=record) - sfa_record.dump(format) + sfa_record.dump(format) if outfile: - save_records_to_file(outfile, records) - + save_records_to_file(outfile, records) - def _record_dict(self, xrn, type, email, key, - slices, researchers, pis, + def _record_dict(self, xrn, type, email, key, + slices, researchers, pis, url, description, extras): record_dict = {} if xrn: @@ -117,9 +122,8 @@ class RegistryCommands(Commands): record_dict.update(extras) return record_dict - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn', default=None) - @add_options('-t', '--type', dest='type', metavar='', help='object type (mandatory)',) + @add_options('-t', '--type', dest='type', metavar='', help='object type (mandatory)') @add_options('-a', '--all', dest='all', metavar='', action='store_true', default=False, help='check all users GID') @add_options('-v', '--verbose', dest='verbose', metavar='', action='store_true', default=False, help='verbose mode: display user\'s hrn ') def check_gid(self, xrn=None, type=None, all=None, verbose=None): @@ -148,72 +152,69 @@ class RegistryCommands(Commands): ERROR = [] NOKEY = [] for record in records: - # get the pubkey stored in SFA DB - if record.reg_keys: - db_pubkey_str = record.reg_keys[0].key - try: - db_pubkey_obj = convert_public_key(db_pubkey_str) - except: - ERROR.append(record.hrn) - continue - else: - NOKEY.append(record.hrn) - continue - - # get the pubkey from the gid - gid_str = record.gid - gid_obj = GID(string = gid_str) - gid_pubkey_obj = gid_obj.get_pubkey() - - # Check if gid_pubkey_obj and db_pubkey_obj are the same - check = gid_pubkey_obj.is_same(db_pubkey_obj) - if check : - OK.append(record.hrn) - else: - NOK.append(record.hrn) + # get the pubkey stored in SFA DB + if record.reg_keys: + db_pubkey_str = record.reg_keys[0].key + try: + db_pubkey_obj = convert_public_key(db_pubkey_str) + except: + ERROR.append(record.hrn) + continue + else: + NOKEY.append(record.hrn) + continue + + # get the pubkey from the gid + gid_str = record.gid + gid_obj = GID(string=gid_str) + gid_pubkey_obj = gid_obj.get_pubkey() + + # Check if gid_pubkey_obj and db_pubkey_obj are the same + check = gid_pubkey_obj.is_same(db_pubkey_obj) + if check: + OK.append(record.hrn) + else: + NOK.append(record.hrn) if not verbose: print("Users NOT having a PubKey: %s\n\ Users having a non RSA PubKey: %s\n\ Users having a GID/PubKey correpondence OK: %s\n\ -Users having a GID/PubKey correpondence Not OK: %s\n"%(len(NOKEY), len(ERROR), len(OK), len(NOK))) +Users having a GID/PubKey correpondence Not OK: %s\n" % (len(NOKEY), len(ERROR), len(OK), len(NOK))) else: print("Users NOT having a PubKey: %s and are: \n%s\n\n\ Users having a non RSA PubKey: %s and are: \n%s\n\n\ Users having a GID/PubKey correpondence OK: %s and are: \n%s\n\n\ -Users having a GID/PubKey correpondence NOT OK: %s and are: \n%s\n\n"%(len(NOKEY),NOKEY, len(ERROR), ERROR, len(OK), OK, len(NOK), NOK)) - - +Users having a GID/PubKey correpondence NOT OK: %s and are: \n%s\n\n" % (len(NOKEY), NOKEY, len(ERROR), ERROR, len(OK), OK, len(NOK), NOK)) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') - @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) + @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') + @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) @add_options('-e', '--email', dest='email', default="", help="email (mandatory for users)") @add_options('-u', '--url', dest='url', metavar='', default=None, help="URL, useful for slices") - @add_options('-d', '--description', dest='description', metavar='', + @add_options('-d', '--description', dest='description', metavar='', help='Description, useful for slices', default=None) - @add_options('-k', '--key', dest='key', metavar='', help='public key string or file', + @add_options('-k', '--key', dest='key', metavar='', help='public key string or file', default=None) - @add_options('-s', '--slices', dest='slices', metavar='', help='Set/replace slice xrns', + @add_options('-s', '--slices', dest='slices', metavar='', help='Set/replace slice xrns', default='', type="str", action='callback', callback=optparse_listvalue_callback) - @add_options('-r', '--researchers', dest='researchers', metavar='', help='Set/replace slice researchers', + @add_options('-r', '--researchers', dest='researchers', metavar='', help='Set/replace slice researchers', default='', type="str", action='callback', callback=optparse_listvalue_callback) - @add_options('-p', '--pis', dest='pis', metavar='', - help='Set/replace Principal Investigators/Project Managers', + @add_options('-p', '--pis', dest='pis', metavar='', + help='Set/replace Principal Investigators/Project Managers', default='', type="str", action='callback', callback=optparse_listvalue_callback) - @add_options('-X','--extra',dest='extras',default={},type='str',metavar="", - action="callback", callback=optparse_dictvalue_callback, nargs=1, + @add_options('-X', '--extra', dest='extras', default={}, type='str', metavar="", + action="callback", callback=optparse_dictvalue_callback, nargs=1, help="set extra/testbed-dependent flags, e.g. --extra enabled=true") - def register(self, xrn, type=None, email='', key=None, + def register(self, xrn, type=None, email='', key=None, slices='', pis='', researchers='', url=None, description=None, extras={}): """Create a new Registry record""" - record_dict = self._record_dict(xrn=xrn, type=type, email=email, key=key, + record_dict = self._record_dict(xrn=xrn, type=type, email=email, key=key, slices=slices, researchers=researchers, pis=pis, url=url, description=description, extras=extras) - self.api.manager.Register(self.api, record_dict) - + self.api.manager.Register(self.api, record_dict) @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) @@ -229,33 +230,32 @@ Users having a GID/PubKey correpondence NOT OK: %s and are: \n%s\n\n"%(len(NOKEY @add_options('-p', '--pis', dest='pis', metavar='', help='Set/replace Principal Investigators/Project Managers', default='', type="str", action='callback', callback=optparse_listvalue_callback) - @add_options('-X','--extra',dest='extras',default={},type='str',metavar="", - action="callback", callback=optparse_dictvalue_callback, nargs=1, + @add_options('-X', '--extra', dest='extras', default={}, type='str', metavar="", + action="callback", callback=optparse_dictvalue_callback, nargs=1, help="set extra/testbed-dependent flags, e.g. --extra enabled=true") - def update(self, xrn, type=None, email='', key=None, + def update(self, xrn, type=None, email='', key=None, slices='', pis='', researchers='', url=None, description=None, extras={}): - """Update an existing Registry record""" - record_dict = self._record_dict(xrn=xrn, type=type, email=email, key=key, + """Update an existing Registry record""" + record_dict = self._record_dict(xrn=xrn, type=type, email=email, key=key, slices=slices, researchers=researchers, pis=pis, url=url, description=description, extras=extras) self.api.manager.Update(self.api, record_dict) - - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') - @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) + + @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') + @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) def remove(self, xrn, type=None): """Remove given object from the registry""" xrn = Xrn(xrn, type) - self.api.manager.Remove(self.api, xrn) + self.api.manager.Remove(self.api, xrn) - - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') - @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) + @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') + @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) def credential(self, xrn, type=None): """Invoke GetCredential""" - cred = self.api.manager.GetCredential(self.api, xrn, type, self.api.hrn) + cred = self.api.manager.GetCredential( + self.api, xrn, type, self.api.hrn) print(cred) - def import_registry(self): """Run the importer""" @@ -266,23 +266,24 @@ Users having a GID/PubKey correpondence NOT OK: %s and are: \n%s\n\n"%(len(NOKEY def sync_db(self): """Initialize or upgrade the db""" from sfa.storage.dbschema import DBSchema - dbschema=DBSchema() + dbschema = DBSchema() dbschema.init_or_upgrade() - + @add_options('-a', '--all', dest='all', metavar='', action='store_true', default=False, help='Remove all registry records and all files in %s area' % help_basedir) @add_options('-c', '--certs', dest='certs', metavar='', action='store_true', default=False, - help='Remove all cached certs/gids found in %s' % help_basedir ) + help='Remove all cached certs/gids found in %s' % help_basedir) @add_options('-0', '--no-reinit', dest='reinit', metavar='', action='store_false', default=True, help='Prevents new DB schema from being installed after cleanup') def nuke(self, all=False, certs=False, reinit=True): """Cleanup local registry DB, plus various additional filesystem cleanups optionally""" from sfa.storage.dbschema import DBSchema from sfa.util.sfalogging import _SfaLogger - logger = _SfaLogger(logfile='/var/log/sfa_import.log', loggername='importlog') + logger = _SfaLogger( + logfile='/var/log/sfa_import.log', loggername='importlog') logger.setLevelFromOptVerbose(self.api.config.SFA_API_LOGLEVEL) logger.info("Purging SFA records from database") - dbschema=DBSchema() + dbschema = DBSchema() dbschema.nuke() # for convenience we re-create the schema here, so there's no need for an explicit @@ -292,31 +293,33 @@ Users having a GID/PubKey correpondence NOT OK: %s and are: \n%s\n\n"%(len(NOKEY logger.info("re-creating empty schema") dbschema.init_or_upgrade() - # remove the server certificate and all gids found in /var/lib/sfa/authorities + # remove the server certificate and all gids found in + # /var/lib/sfa/authorities if certs: logger.info("Purging cached certificates") for (dir, _, files) in os.walk('/var/lib/sfa/authorities'): for file in files: if file.endswith('.gid') or file == 'server.cert': - path=dir+os.sep+file + path = dir + os.sep + file os.unlink(path) # just remove all files that do not match 'server.key' or 'server.cert' if all: logger.info("Purging registry filesystem cache") - preserved_files = [ 'server.key', 'server.cert'] - for (dir,_,files) in os.walk(Hierarchy().basedir): + preserved_files = ['server.key', 'server.cert'] + for dir, _, files in os.walk(Hierarchy().basedir): for file in files: - if file in preserved_files: continue - path=dir+os.sep+file + if file in preserved_files: + continue + path = dir + os.sep + file os.unlink(path) - - + + class CertCommands(Commands): def __init__(self, *args, **kwds): - self.api= Generic.the_flavour().make_api(interface='registry') - + self.api = Generic.the_flavour().make_api(interface='registry') + def import_gid(self, xrn): pass @@ -324,12 +327,13 @@ class CertCommands(Commands): @add_options('-t', '--type', dest='type', metavar='', help='object type', default=None) @add_options('-o', '--outfile', dest='outfile', metavar='', help='output file', default=None) def export(self, xrn, type=None, outfile=None): - """Fetch an object's GID from the Registry""" + """Fetch an object's GID from the Registry""" from sfa.storage.model import RegRecord hrn = Xrn(xrn).get_hrn() - request=self.api.dbsession().query(RegRecord).filter_by(hrn=hrn) - if type: request = request.filter_by(type=type) - record=request.first() + request = self.api.dbsession().query(RegRecord).filter_by(hrn=hrn) + if type: + request = request.filter_by(type=type) + record = request.first() if record: gid = GID(string=record.gid) else: @@ -345,8 +349,8 @@ class CertCommands(Commands): if not outfile: outfile = os.path.abspath('./%s.gid' % gid.get_hrn()) gid.save_to_file(outfile, save_parents=True) - - @add_options('-g', '--gidfile', dest='gid', metavar='', help='path of gid file to display (mandatory)') + + @add_options('-g', '--gidfile', dest='gid', metavar='', help='path of gid file to display (mandatory)') def display(self, gidfile): """Print contents of a GID file""" gid_path = os.path.abspath(gidfile) @@ -355,70 +359,66 @@ class CertCommands(Commands): sys.exit(1) gid = GID(filename=gid_path) gid.dump(dump_parents=True) - + class AggregateCommands(Commands): def __init__(self, *args, **kwds): - self.api= Generic.the_flavour().make_api(interface='aggregate') - + self.api = Generic.the_flavour().make_api(interface='aggregate') + def version(self): """Display the Aggregate version""" version = self.api.manager.GetVersion(self.api, {}) pprinter.pprint(version) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') + @add_options('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') def status(self, xrn): """Retrieve the status of the slivers belonging to the named slice (Status)""" urns = [Xrn(xrn, 'slice').get_urn()] status = self.api.manager.Status(self.api, urns, [], {}) pprinter.pprint(status) - - @add_options('-r', '--rspec-version', dest='rspec_version', metavar='', - default='GENI', help='version/format of the resulting rspec response') + + @add_options('-r', '--rspec-version', dest='rspec_version', metavar='', + default='GENI', help='version/format of the resulting rspec response') def resources(self, rspec_version='GENI'): - """Display the available resources at an aggregate""" + """Display the available resources at an aggregate""" options = {'geni_rspec_version': rspec_version} print(options) resources = self.api.manager.ListResources(self.api, [], options) print(resources) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='slice hrn/urn (mandatory)') @add_options('-r', '--rspec', dest='rspec', metavar='', help='rspec file (mandatory)') def allocate(self, xrn, rspec): """Allocate slivers""" xrn = Xrn(xrn, 'slice') - slice_urn=xrn.get_urn() + slice_urn = xrn.get_urn() rspec_string = open(rspec).read() - options={} - manifest = self.api.manager.Allocate(self.api, slice_urn, [], rspec_string, options) + options = {} + manifest = self.api.manager.Allocate( + self.api, slice_urn, [], rspec_string, options) print(manifest) - @add_options('-x', '--xrn', dest='xrn', metavar='', help='slice hrn/urn (mandatory)') def provision(self, xrn): """Provision slivers""" xrn = Xrn(xrn, 'slice') - slice_urn=xrn.get_urn() - options={} - manifest = self.api.manager.provision(self.api, [slice_urn], [], options) + slice_urn = xrn.get_urn() + options = {} + manifest = self.api.manager.provision( + self.api, [slice_urn], [], options) print(manifest) - - @add_options('-x', '--xrn', dest='xrn', metavar='', help='slice hrn/urn (mandatory)') def delete(self, xrn): - """Delete slivers""" + """Delete slivers""" self.api.manager.Delete(self.api, [xrn], [], {}) - - class SliceManagerCommands(AggregateCommands): - + def __init__(self, *args, **kwds): - self.api= Generic.the_flavour().make_api(interface='slicemgr') + self.api = Generic.the_flavour().make_api(interface='slicemgr') class SfaAdmin: @@ -429,55 +429,58 @@ class SfaAdmin: 'slicemgr': SliceManagerCommands} # returns (name,class) or (None,None) - def find_category (self, input): - full_name=Candidates (SfaAdmin.CATEGORIES.keys()).only_match(input) - if not full_name: return (None,None) - return (full_name,SfaAdmin.CATEGORIES[full_name]) + def find_category(self, input): + full_name = Candidates(SfaAdmin.CATEGORIES.keys()).only_match(input) + if not full_name: + return (None, None) + return (full_name, SfaAdmin.CATEGORIES[full_name]) - def summary_usage (self, category=None): + def summary_usage(self, category=None): print("Usage:", self.script_name + " category command []") - if category and category in SfaAdmin.CATEGORIES: - categories=[category] + if category and category in SfaAdmin.CATEGORIES: + categories = [category] else: - categories=SfaAdmin.CATEGORIES + categories = SfaAdmin.CATEGORIES for c in categories: - cls=SfaAdmin.CATEGORIES[c] - print("==================== category=%s"%c) - names=cls.__dict__.keys() + cls = SfaAdmin.CATEGORIES[c] + print("==================== category=%s" % c) + names = cls.__dict__.keys() names.sort() for name in names: - method=cls.__dict__[name] - if name.startswith('_'): continue - margin=15 - format="%%-%ds"%margin - print("%-15s"%name, end=' ') - doc=getattr(method,'__doc__',None) - if not doc: + method = cls.__dict__[name] + if name.startswith('_'): + continue + margin = 15 + format = "%%-%ds" % margin + print("%-15s" % name, end=' ') + doc = getattr(method, '__doc__', None) + if not doc: print("") continue - lines=[line.strip() for line in doc.split("\n")] - line1=lines.pop(0) + lines = [line.strip() for line in doc.split("\n")] + line1 = lines.pop(0) print(line1) - for extra_line in lines: print(margin*" ",extra_line) + for extra_line in lines: + print(margin * " ", extra_line) sys.exit(2) def main(self): argv = copy.deepcopy(sys.argv) self.script_name = argv.pop(0) - # ensure category is specified + # ensure category is specified if len(argv) < 1: self.summary_usage() # ensure category is valid category_input = argv.pop(0) - (category_name, category_class) = self.find_category (category_input) + (category_name, category_class) = self.find_category(category_input) if not category_name or not category_class: self.summary_usage(category_name) usage = "%%prog %s command [options]" % (category_name) parser = OptionParser(usage=usage) - - # ensure command is valid + + # ensure command is valid category_instance = category_class() commands = category_instance._get_commands() if len(argv) < 1: @@ -485,8 +488,8 @@ class SfaAdmin: command_name = '__call__' else: command_input = argv.pop(0) - command_name = Candidates (commands).only_match (command_input) - + command_name = Candidates(commands).only_match(command_input) + if command_name and hasattr(category_instance, command_name): command = getattr(category_instance, command_name) else: @@ -507,17 +510,18 @@ class SfaAdmin: # execute command try: - #print "invoking %s *=%s **=%s"%(command.__name__,cmd_args, cmd_kwds) + # print "invoking %s *=%s **=%s"%(command.__name__, cmd_args, + # cmd_kwds) command(*cmd_args, **cmd_kwds) sys.exit(0) except TypeError: print("Possible wrong number of arguments supplied") #import traceback - #traceback.print_exc() + # traceback.print_exc() print(command.__doc__) parser.print_help() sys.exit(1) - #raise + # raise except Exception: print("Command failed, please check log for more info") raise diff --git a/sfa/client/sfaclientlib.py b/sfa/client/sfaclientlib.py index b7114be4..cac4d1fd 100644 --- a/sfa/client/sfaclientlib.py +++ b/sfa/client/sfaclientlib.py @@ -10,13 +10,14 @@ from __future__ import print_function # certificates and automatically retrieve fresh ones when expired import sys -import os, os.path +import os +import os.path import subprocess from datetime import datetime from sfa.util.xrn import Xrn import sfa.util.sfalogging -# importing sfa.utils.faults does pull a lot of stuff +# importing sfa.utils.faults does pull a lot of stuff # OTOH it's imported from Certificate anyways, so.. from sfa.util.faults import RecordNotFound @@ -26,13 +27,13 @@ from sfa.client.sfaserverproxy import SfaServerProxy from sfa.trust.certificate import Keypair, Certificate from sfa.trust.credential import Credential from sfa.trust.gid import GID -########## +########## # a helper class to implement the bootstrapping of cryptoa. material -# assuming we are starting from scratch on the client side +# assuming we are starting from scratch on the client side # what's needed to complete a full slice creation cycle -# (**) prerequisites: -# (*) a local private key -# (*) the corresp. public key in the registry +# (**) prerequisites: +# (*) a local private key +# (*) the corresp. public key in the registry # (**) step1: a self-signed certificate # default filename is .sscert # (**) step2: a user credential @@ -47,88 +48,91 @@ from sfa.trust.gid import GID # From that point on, the GID is used as the SSL certificate # and the following can be done # -# (**) retrieve a slice (or authority) credential +# (**) retrieve a slice (or authority) credential # obtained at the registry with GetCredential # using the (step2) user-credential as credential # default filename is ..cred -# (**) retrieve a slice (or authority) GID +# (**) retrieve a slice (or authority) GID # obtained at the registry with Resolve # using the (step2) user-credential as credential # default filename is ..cred # -# (**) additionnally, it might make sense to upgrade a GID file +# (**) additionnally, it might make sense to upgrade a GID file # into a pkcs12 certificate usable in a browser # this bundled format allows for embedding the private key -# +# -########## Implementation notes +# Implementation notes # # (*) decorators # -# this implementation is designed as a guideline for +# this implementation is designed as a guideline for # porting to other languages # -# the decision to go for decorators aims at focusing +# the decision to go for decorators aims at focusing # on the core of what needs to be done when everything -# works fine, and to take caching and error management +# works fine, and to take caching and error management # out of the way -# -# for non-pythonic developers, it should be enough to +# +# for non-pythonic developers, it should be enough to # implement the bulk of this code, namely the _produce methods -# and to add caching and error management by whichever means -# is available, including inline +# and to add caching and error management by whichever means +# is available, including inline # # (*) self-signed certificates -# +# # still with other languages in mind, we've tried to keep the # dependencies to the rest of the code as low as possible -# +# # however this still relies on the sfa.trust.certificate module # for the initial generation of a self-signed-certificate that # is associated to the user's ssh-key -# (for user-friendliness, and for smooth operations with planetlab, +# (for user-friendliness, and for smooth operations with planetlab, # the usage model is to reuse an existing keypair) -# +# # there might be a more portable, i.e. less language-dependant way, to # implement this step by exec'ing the openssl command. -# a known successful attempt at this approach that worked +# a known successful attempt at this approach that worked # for Java is documented below # http://nam.ece.upatras.gr/fstoolkit/trac/wiki/JavaSFAClient # # (*) pkcs12 -# +# # the implementation of the pkcs12 wrapping, which is a late addition, # is done through direct calls to openssl # #################### -class SfaClientException(Exception): pass + +class SfaClientException(Exception): + pass + class SfaClientBootstrap: # dir is mandatory but defaults to '.' - def __init__(self, user_hrn, registry_url, dir=None, + def __init__(self, user_hrn, registry_url, dir=None, verbose=False, timeout=None, logger=None): self.hrn = user_hrn self.registry_url = registry_url if dir is None: - dir="." + dir = "." self.dir = dir self.verbose = verbose self.timeout = timeout # default for the logger is to use the global sfa logger - if logger is None: + if logger is None: logger = sfa.util.sfalogging.logger self.logger = logger - ######################################## *_produce methods - ### step1 + # *_produce methods + # step1 # unconditionnally create a self-signed certificate def self_signed_cert_produce(self, output): self.assert_private_key() private_key_filename = self.private_key_filename() keypair = Keypair(filename=private_key_filename) - self_signed = Certificate(subject = self.hrn) + self_signed = Certificate(subject=self.hrn) self_signed.set_pubkey(keypair) self_signed.set_issuer(keypair, self.hrn) self_signed.sign() @@ -137,7 +141,7 @@ class SfaClientBootstrap: .format(self.hrn, output)) return output - ### step2 + # step2 # unconditionnally retrieve my credential (GetSelfCredential) # we always use the self-signed-cert as the SSL cert def my_credential_produce(self, output): @@ -149,21 +153,25 @@ class SfaClientBootstrap: self.private_key_filename(), certificate_filename) try: - credential_string = registry_proxy.GetSelfCredential(certificate_string, self.hrn, "user") + credential_string = registry_proxy.GetSelfCredential( + certificate_string, self.hrn, "user") except: - # some urns hrns may replace non hierarchy delimiters '.' with an '_' instead of escaping the '.' - hrn = Xrn(self.hrn).get_hrn().replace('\.', '_') - credential_string = registry_proxy.GetSelfCredential(certificate_string, hrn, "user") + # some urns hrns may replace non hierarchy delimiters '.' with an + # '_' instead of escaping the '.' + hrn = Xrn(self.hrn).get_hrn().replace('\.', '_') + credential_string = registry_proxy.GetSelfCredential( + certificate_string, hrn, "user") self.plain_write(output, credential_string) - self.logger.debug("SfaClientBootstrap: Wrote result of GetSelfCredential in {}".format(output)) + self.logger.debug( + "SfaClientBootstrap: Wrote result of GetSelfCredential in {}".format(output)) return output - ### step3 - # unconditionnally retrieve my GID - use the general form + # step3 + # unconditionnally retrieve my GID - use the general form def my_gid_produce(self, output): return self.gid_produce(output, self.hrn, "user") - ### retrieve any credential (GetCredential) unconditionnal form + # retrieve any credential (GetCredential) unconditionnal form # we always use the GID as the SSL cert def credential_produce(self, output, hrn, type): self.assert_my_gid() @@ -173,9 +181,11 @@ class SfaClientBootstrap: certificate_filename) self.assert_my_credential() my_credential_string = self.my_credential_string() - credential_string = registry_proxy.GetCredential(my_credential_string, hrn, type) + credential_string = registry_proxy.GetCredential( + my_credential_string, hrn, type) self.plain_write(output, credential_string) - self.logger.debug("SfaClientBootstrap: Wrote result of GetCredential in {}".format(output)) + self.logger.debug( + "SfaClientBootstrap: Wrote result of GetCredential in {}".format(output)) return output def slice_credential_produce(self, output, hrn): @@ -184,16 +194,16 @@ class SfaClientBootstrap: def authority_credential_produce(self, output, hrn): return self.credential_produce(output, hrn, "authority") - ### retrieve any gid(Resolve) - unconditionnal form + # retrieve any gid(Resolve) - unconditionnal form # use my GID when available as the SSL cert, otherwise the self-signed - def gid_produce(self, output, hrn, type ): + def gid_produce(self, output, hrn, type): try: self.assert_my_gid() certificate_filename = self.my_gid_filename() except: self.assert_self_signed_cert() certificate_filename = self.self_signed_cert_filename() - + self.assert_private_key() registry_proxy = SfaServerProxy(self.registry_url, self.private_key_filename(), certificate_filename) @@ -201,24 +211,27 @@ class SfaClientBootstrap: records = registry_proxy.Resolve(hrn, credential_string) records = [record for record in records if record['type'] == type] if not records: - raise RecordNotFound("hrn {} ({}) unknown to registry {}".format(hrn, type, self.registry_url)) + raise RecordNotFound("hrn {} ({}) unknown to registry {}".format( + hrn, type, self.registry_url)) record = records[0] self.plain_write(output, record['gid']) - self.logger.debug("SfaClientBootstrap: Wrote GID for {} ({}) in {}".format(hrn, type, output)) + self.logger.debug( + "SfaClientBootstrap: Wrote GID for {} ({}) in {}".format(hrn, type, output)) return output # http://trac.myslice.info/wiki/MySlice/Developer/SFALogin -### produce a pkcs12 bundled certificate from GID and private key +# produce a pkcs12 bundled certificate from GID and private key # xxx for now we put a hard-wired password that's just, well, 'password' -# when leaving this empty on the mac, result can't seem to be loaded in keychain.. +# when leaving this empty on the mac, result can't seem to be loaded in +# keychain.. def my_pkcs12_produce(self, filename): password = raw_input("Enter password for p12 certificate: ") - openssl_command = ['openssl', 'pkcs12', "-export"] - openssl_command += [ "-password", "pass:{}".format(password) ] - openssl_command += [ "-inkey", self.private_key_filename()] - openssl_command += [ "-in", self.my_gid_filename()] - openssl_command += [ "-out", filename ] + openssl_command = ['openssl', 'pkcs12', "-export"] + openssl_command += ["-password", "pass:{}".format(password)] + openssl_command += ["-inkey", self.private_key_filename()] + openssl_command += ["-in", self.my_gid_filename()] + openssl_command += ["-out", filename] if subprocess.call(openssl_command) == 0: print("Successfully created {}".format(filename)) else: @@ -232,10 +245,9 @@ class SfaClientBootstrap: if cred.get_expiration() < datetime.utcnow(): valid = False return valid - - #################### public interface - + # public interface + # return my_gid, run all missing steps in the bootstrap sequence def bootstrap_my_gid(self): self.self_signed_cert() @@ -268,41 +280,49 @@ class SfaClientBootstrap: os.chmod(private_key_filename, os.stat(user_private_key).st_mode) self.logger.debug("SfaClientBootstrap: Copied private key from {} into {}" .format(user_private_key, private_key_filename)) - - #################### private details + + # private details # stupid stuff def fullpath(self, file): return os.path.join(self.dir, file) # the expected filenames for the various pieces - def private_key_filename(self): + def private_key_filename(self): return self.fullpath("{}.pkey".format(Xrn.unescape(self.hrn))) - def self_signed_cert_filename(self): + + def self_signed_cert_filename(self): return self.fullpath("{}.sscert".format(self.hrn)) + def my_credential_filename(self): return self.credential_filename(self.hrn, "user") # the tests use sfi -u ; meaning that the slice credential filename # needs to keep track of the user too - def credential_filename(self, hrn, type): + + def credential_filename(self, hrn, type): if type in ['user']: basename = "{}.{}.cred".format(hrn, type) else: basename = "{}-{}.{}.cred".format(self.hrn, hrn, type) return self.fullpath(basename) - def slice_credential_filename(self, hrn): + + def slice_credential_filename(self, hrn): return self.credential_filename(hrn, 'slice') - def authority_credential_filename(self, hrn): + + def authority_credential_filename(self, hrn): return self.credential_filename(hrn, 'authority') + def my_gid_filename(self): return self.gid_filename(self.hrn, "user") - def gid_filename(self, hrn, type): + + def gid_filename(self, hrn, type): return self.fullpath("{}.{}.gid".format(hrn, type)) + def my_pkcs12_filename(self): return self.fullpath("{}.p12".format(self.hrn)) # optimizing dependencies -# originally we used classes GID or Credential or Certificate -# like e.g. +# originally we used classes GID or Credential or Certificate +# like e.g. # return Credential(filename=self.my_credential()).save_to_string() # but in order to make it simpler to other implementations/languages.. def plain_read(self, filename): @@ -317,17 +337,19 @@ class SfaClientBootstrap: if not os.path.isfile(filename): raise IOError("Missing {} file {}".format(kind, filename)) return True - + def assert_private_key(self): return self.assert_filename(self.private_key_filename(), "private key") + def assert_self_signed_cert(self): return self.assert_filename(self.self_signed_cert_filename(), "self-signed certificate") + def assert_my_credential(self): return self.assert_filename(self.my_credential_filename(), "user's credential") + def assert_my_gid(self): return self.assert_filename(self.my_gid_filename(), "user's GID") - # decorator to make up the other methods def get_or_produce(filename_method, produce_method, validate_method=None): # default validator returns true @@ -337,18 +359,19 @@ class SfaClientBootstrap: if os.path.isfile(filename): if not validate_method: return filename - elif validate_method(self, filename): + elif validate_method(self, filename): return filename else: # remove invalid file - self.logger.warning("Removing {} - has expired".format(filename)) - os.unlink(filename) + self.logger.warning( + "Removing {} - has expired".format(filename)) + os.unlink(filename) try: produce_method(self, filename, *args, **kw) return filename except IOError: - raise - except : + raise + except: error = sys.exc_info()[:2] message = "Could not produce/retrieve {} ({} -- {})"\ .format(filename, error[0], error[1]) @@ -379,17 +402,18 @@ class SfaClientBootstrap: def authority_credential(self, hrn): pass @get_or_produce(gid_filename, gid_produce) - def gid(self, hrn, type ): pass - + def gid(self, hrn, type): pass # get the credentials as strings, for inserting as API arguments - def my_credential_string(self): + def my_credential_string(self): self.my_credential() return self.plain_read(self.my_credential_filename()) - def slice_credential_string(self, hrn): + + def slice_credential_string(self, hrn): self.slice_credential(hrn) return self.plain_read(self.slice_credential_filename(hrn)) - def authority_credential_string(self, hrn): + + def authority_credential_string(self, hrn): self.authority_credential(hrn) return self.plain_read(self.authority_credential_filename(hrn)) @@ -431,5 +455,6 @@ class SfaClientBootstrap: # to_gid = GID(to_gidfile ) # to_hrn = delegee_gid.get_hrn() # print 'to_hrn', to_hrn - delegated_credential = original_credential.delegate(to_gidfile, self.private_key(), my_gid) + delegated_credential = original_credential.delegate( + to_gidfile, self.private_key(), my_gid) return delegated_credential.save_to_string(save_parents=True) diff --git a/sfa/client/sfascan.py b/sfa/client/sfascan.py index 136835db..ce9717c3 100644 --- a/sfa/client/sfascan.py +++ b/sfa/client/sfascan.py @@ -1,6 +1,7 @@ from __future__ import print_function -import sys, os.path +import sys +import os.path import pickle import time import socket @@ -19,331 +20,366 @@ from sfa.client.sfi import Sfi from sfa.util.sfalogging import logger, DEBUG from sfa.client.sfaserverproxy import SfaServerProxy -def url_hostname_port (url): - if url.find("://")<0: - url="http://"+url - parsed_url=urlparse(url) + +def url_hostname_port(url): + if url.find("://") < 0: + url = "http://" + url + parsed_url = urlparse(url) # 0(scheme) returns protocol - default_port='80' - if parsed_url[0]=='https': default_port='443' + default_port = '80' + if parsed_url[0] == 'https': + default_port = '443' # 1(netloc) returns the hostname+port part - parts=parsed_url[1].split(":") + parts = parsed_url[1].split(":") # just a hostname - if len(parts)==1: - return (url,parts[0],default_port) + if len(parts) == 1: + return (url, parts[0], default_port) else: - return (url,parts[0],parts[1]) + return (url, parts[0], parts[1]) + +# a very simple cache mechanism so that successive runs (see make) +# will go *much* faster +# assuming everything is sequential, as simple as it gets +# { url -> (timestamp, version)} + -### a very simple cache mechanism so that successive runs (see make) -### will go *much* faster -### assuming everything is sequential, as simple as it gets -### { url -> (timestamp,version)} class VersionCache: # default expiration period is 1h - def __init__ (self, filename=None, expires=60*60): + + def __init__(self, filename=None, expires=60 * 60): # default is to store cache in the same dir as argv[0] if filename is None: - filename=os.path.join(os.path.dirname(sys.argv[0]),"sfascan-version-cache.pickle") - self.filename=filename - self.expires=expires - self.url2version={} + filename = os.path.join(os.path.dirname( + sys.argv[0]), "sfascan-version-cache.pickle") + self.filename = filename + self.expires = expires + self.url2version = {} self.load() - def load (self): + def load(self): try: - infile=open(self.filename,'r') - self.url2version=pickle.load(infile) + infile = open(self.filename, 'r') + self.url2version = pickle.load(infile) infile.close() except: logger.debug("Cannot load version cache, restarting from scratch") self.url2version = {} - logger.debug("loaded version cache with %d entries %s"%(len(self.url2version),self.url2version.keys())) + logger.debug("loaded version cache with %d entries %s" % (len(self.url2version), + self.url2version.keys())) - def save (self): + def save(self): try: - outfile=open(self.filename,'w') - pickle.dump(self.url2version,outfile) + outfile = open(self.filename, 'w') + pickle.dump(self.url2version, outfile) outfile.close() except: - logger.log_exc ("Cannot save version cache into %s"%self.filename) - def clean (self): + logger.log_exc("Cannot save version cache into %s" % self.filename) + + def clean(self): try: - retcod=os.unlink(self.filename) - logger.info("Cleaned up version cache %s, retcod=%d"%(self.filename,retcod)) + retcod = os.unlink(self.filename) + logger.info("Cleaned up version cache %s, retcod=%d" % + (self.filename, retcod)) except: - logger.info ("Could not unlink version cache %s"%self.filename) + logger.info("Could not unlink version cache %s" % self.filename) + + def show(self): + entries = len(self.url2version) + print("version cache from file %s has %d entries" % + (self.filename, entries)) + key_values = self.url2version.items() - def show (self): - entries=len(self.url2version) - print("version cache from file %s has %d entries"%(self.filename,entries)) - key_values=self.url2version.items() - def old_first (kv1,kv2): return int(kv1[1][0]-kv2[1][0]) + def old_first(kv1, kv2): return int(kv1[1][0] - kv2[1][0]) key_values.sort(old_first) for key_value in key_values: - (url,tuple) = key_value - (timestamp,version) = tuple - how_old = time.time()-timestamp - if how_old<=self.expires: - print(url,"-- %d seconds ago"%how_old) + (url, tuple) = key_value + (timestamp, version) = tuple + how_old = time.time() - timestamp + if how_old <= self.expires: + print(url, "-- %d seconds ago" % how_old) else: - print("OUTDATED",url,"(%d seconds ago, expires=%d)"%(how_old,self.expires)) - + print("OUTDATED", url, "(%d seconds ago, expires=%d)" % + (how_old, self.expires)) + # turns out we might have trailing slashes or not - def normalize (self, url): + def normalize(self, url): return url.strip("/") - - def set (self,url,version): - url=self.normalize(url) - self.url2version[url]=( time.time(), version) - def get (self,url): - url=self.normalize(url) + + def set(self, url, version): + url = self.normalize(url) + self.url2version[url] = (time.time(), version) + + def get(self, url): + url = self.normalize(url) try: - (timestamp,version)=self.url2version[url] - how_old = time.time()-timestamp - if how_old<=self.expires: return version - else: return None + (timestamp, version) = self.url2version[url] + how_old = time.time() - timestamp + if how_old <= self.expires: + return version + else: + return None except: return None ### # non-existing hostnames happen... # for better perfs we cache the result of gethostbyname too + + class Interface: - def __init__ (self,url,mentioned_in=None,verbose=False): - self._url=url - self.verbose=verbose - cache=VersionCache() - key="interface:%s"%url + def __init__(self, url, mentioned_in=None, verbose=False): + self._url = url + self.verbose = verbose + cache = VersionCache() + key = "interface:%s" % url try: - (self._url,self.hostname,self.port)=url_hostname_port(url) + (self._url, self.hostname, self.port) = url_hostname_port(url) # look for ip in the cache - tuple=cache.get(key) + tuple = cache.get(key) if tuple: (self.hostname, self.ip, self.port) = tuple else: - self.ip=socket.gethostbyname(self.hostname) + self.ip = socket.gethostbyname(self.hostname) except: - msg="can't resolve hostname %s\n\tfound in url %s"%(self.hostname,self._url) + msg = "can't resolve hostname %s\n\tfound in url %s" % ( + self.hostname, self._url) if mentioned_in: - msg += "\n\t(mentioned at %s)"%mentioned_in - logger.warning (msg) - self.hostname="unknown" - self.ip='0.0.0.0' - self.port="???" + msg += "\n\t(mentioned at %s)" % mentioned_in + logger.warning(msg) + self.hostname = "unknown" + self.ip = '0.0.0.0' + self.port = "???" - cache.set(key, (self.hostname, self.ip, self.port,) ) + cache.set(key, (self.hostname, self.ip, self.port,)) cache.save() - self.probed=False + self.probed = False # mark unknown interfaces as probed to avoid unnecessary attempts - if self.hostname=='unknown': + if self.hostname == 'unknown': # don't really try it - self.probed=True - self._version={} - + self.probed = True + self._version = {} def url(self): return self._url # this is used as a key for creating graph nodes and to avoid duplicates - def uid (self): - return "%s:%s"%(self.ip,self.port) + def uid(self): + return "%s:%s" % (self.ip, self.port) # connect to server and trigger GetVersion def get_version(self): - ### if we already know the answer: + # if we already know the answer: if self.probed: return self._version - ### otherwise let's look in the cache file - logger.debug("searching in version cache %s"%self.url()) + # otherwise let's look in the cache file + logger.debug("searching in version cache %s" % self.url()) cached_version = VersionCache().get(self.url()) if cached_version is not None: - logger.info("Retrieved version info from cache %s"%self.url()) + logger.info("Retrieved version info from cache %s" % self.url()) return cached_version - ### otherwise let's do the hard work + # otherwise let's do the hard work # dummy to meet Sfi's expectations for its 'options' field + class DummyOptions: pass - options=DummyOptions() - options.verbose=self.verbose - options.timeout=10 + options = DummyOptions() + options.verbose = self.verbose + options.timeout = 10 try: - client=Sfi(options) + client = Sfi(options) client.read_config() client.bootstrap() key_file = client.private_key cert_file = client.my_gid - logger.debug("using key %s & cert %s"%(key_file,cert_file)) - url=self.url() - logger.info('issuing GetVersion at %s'%url) + logger.debug("using key %s & cert %s" % (key_file, cert_file)) + url = self.url() + logger.info('issuing GetVersion at %s' % url) # setting timeout here seems to get the call to fail - even though the response time is fast #server=SfaServerProxy(url, key_file, cert_file, verbose=self.verbose, timeout=options.timeout) - server=SfaServerProxy(url, key_file, cert_file, verbose=self.verbose) - self._version=ReturnValue.get_value(server.GetVersion()) + server = SfaServerProxy( + url, key_file, cert_file, verbose=self.verbose) + self._version = ReturnValue.get_value(server.GetVersion()) except: logger.log_exc("failed to get version") - self._version={} + self._version = {} # so that next run from this process will find out - self.probed=True + self.probed = True # store in version cache so next processes will remember for an hour - cache=VersionCache() - cache.set(self.url(),self._version) + cache = VersionCache() + cache.set(self.url(), self._version) cache.save() - logger.debug("Saved version for url=%s in version cache"%self.url()) + logger.debug("Saved version for url=%s in version cache" % self.url()) # that's our result return self._version @staticmethod def multi_lines_label(*lines): - result='<
' + \ + result = '<
' + \ '
'.join(lines) + \ '
>' return result # default is for when we can't determine the type of the service - # typically the server is down, or we can't authenticate, or it's too old code - shapes = {"registry": "diamond", "slicemgr":"ellipse", "aggregate":"box", 'default':'plaintext'} - abbrevs = {"registry": "REG", "slicemgr":"SA", "aggregate":"AM", 'default':'[unknown interface]'} + # typically the server is down, or we can't authenticate, or it's too old + # code + shapes = {"registry": "diamond", "slicemgr": "ellipse", + "aggregate": "box", 'default': 'plaintext'} + abbrevs = {"registry": "REG", "slicemgr": "SA", + "aggregate": "AM", 'default': '[unknown interface]'} # return a dictionary that translates into the node's attr - def get_layout (self): - layout={} - ### retrieve cached GetVersion - version=self.get_version() - # set the href; xxx would make sense to try and 'guess' the web URL, not the API's one... - layout['href']=self.url() - ### set html-style label - ### see http://www.graphviz.org/doc/info/shapes.html#html + def get_layout(self): + layout = {} + # retrieve cached GetVersion + version = self.get_version() + # set the href; xxx would make sense to try and 'guess' the web URL, + # not the API's one... + layout['href'] = self.url() + # set html-style label + # see http://www.graphviz.org/doc/info/shapes.html#html # if empty the service is unreachable if not version: - label="offline" + label = "offline" else: - label='' - try: abbrev=Interface.abbrevs[version['interface']] - except: abbrev=Interface.abbrevs['default'] + label = '' + try: + abbrev = Interface.abbrevs[version['interface']] + except: + abbrev = Interface.abbrevs['default'] label += abbrev - if 'hrn' in version: label += " %s"%version['hrn'] - else: label += "[no hrn]" - if 'code_tag' in version: - label += " %s"%version['code_tag'] + if 'hrn' in version: + label += " %s" % version['hrn'] + else: + label += "[no hrn]" + if 'code_tag' in version: + label += " %s" % version['code_tag'] if 'testbed' in version: - label += " (%s)"%version['testbed'] - layout['label']=Interface.multi_lines_label(self.url(),label) - ### set shape - try: shape=Interface.shapes[version['interface']] - except: shape=Interface.shapes['default'] - layout['shape']=shape - ### fill color to outline wrongly configured or unreachable bodies + label += " (%s)" % version['testbed'] + layout['label'] = Interface.multi_lines_label(self.url(), label) + # set shape + try: + shape = Interface.shapes[version['interface']] + except: + shape = Interface.shapes['default'] + layout['shape'] = shape + # fill color to outline wrongly configured or unreachable bodies # as of sfa-2.0 registry doesn't have 'sfa' not 'geni_api', but have peers # slicemgr and aggregate have 'geni_api' and 'sfa' if 'geni_api' not in version and 'peers' not in version: - layout['style']='filled' - layout['fillcolor']='gray' + layout['style'] = 'filled' + layout['fillcolor'] = 'gray' return layout + class Scanner: # provide the entry points (a list of interfaces) - def __init__ (self, left_to_right=False, verbose=False): - self.verbose=verbose - self.left_to_right=left_to_right - - def graph (self,entry_points): - graph=pygraphviz.AGraph(directed=True) - if self.left_to_right: - graph.graph_attr['rankdir']='LR' - self.scan(entry_points,graph) + def __init__(self, left_to_right=False, verbose=False): + self.verbose = verbose + self.left_to_right = left_to_right + + def graph(self, entry_points): + graph = pygraphviz.AGraph(directed=True) + if self.left_to_right: + graph.graph_attr['rankdir'] = 'LR' + self.scan(entry_points, graph) return graph - + # scan from the given interfaces as entry points - def scan(self,interfaces,graph): - if not isinstance(interfaces,list): - interfaces=[interfaces] + def scan(self, interfaces, graph): + if not isinstance(interfaces, list): + interfaces = [interfaces] # remember node to interface mapping - node2interface={} + node2interface = {} # add entry points right away using the interface uid's as a key - to_scan=interfaces - for i in interfaces: + to_scan = interfaces + for i in interfaces: graph.add_node(i.uid()) - node2interface[graph.get_node(i.uid())]=i - scanned=[] + node2interface[graph.get_node(i.uid())] = i + scanned = [] # keep on looping until we reach a fixed point # don't worry about abels and shapes that will get fixed later on while to_scan: for interface in to_scan: # performing xmlrpc call - logger.info("retrieving/fetching version at interface %s"%interface.url()) - version=interface.get_version() + logger.info( + "retrieving/fetching version at interface %s" % interface.url()) + version = interface.get_version() if not version: - logger.info("") - else: - for (k,v) in version.iteritems(): - if not isinstance(v,dict): - logger.debug("\r\t%s:%s"%(k,v)) + logger.info( + "") + else: + for (k, v) in version.iteritems(): + if not isinstance(v, dict): + logger.debug("\r\t%s:%s" % (k, v)) else: logger.debug(k) - for (k1,v1) in v.iteritems(): - logger.debug("\r\t\t%s:%s"%(k1,v1)) + for (k1, v1) in v.iteritems(): + logger.debug("\r\t\t%s:%s" % (k1, v1)) # proceed with neighbours - if 'peers' in version: - for (next_name,next_url) in version['peers'].iteritems(): - next_interface=Interface(next_url,mentioned_in=interface.url()) + if 'peers' in version: + for (next_name, next_url) in version['peers'].iteritems(): + next_interface = Interface( + next_url, mentioned_in=interface.url()) # locate or create node in graph try: # if found, we're good with this one - next_node=graph.get_node(next_interface.uid()) + next_node = graph.get_node(next_interface.uid()) except: # otherwise, let's move on with it graph.add_node(next_interface.uid()) - next_node=graph.get_node(next_interface.uid()) - node2interface[next_node]=next_interface + next_node = graph.get_node(next_interface.uid()) + node2interface[next_node] = next_interface to_scan.append(next_interface) - graph.add_edge(interface.uid(),next_interface.uid()) + graph.add_edge(interface.uid(), next_interface.uid()) scanned.append(interface) to_scan.remove(interface) - # we've scanned the whole graph, let's get the labels and shapes right + # we've scanned the whole graph, let's get the labels and shapes + # right for node in graph.nodes(): - interface=node2interface.get(node,None) + interface = node2interface.get(node, None) if interface: - for (k,v) in interface.get_layout().iteritems(): - node.attr[k]=v + for (k, v) in interface.get_layout().iteritems(): + node.attr[k] = v else: - logger.error("MISSED interface with node %s"%node) - + logger.error("MISSED interface with node %s" % node) + class SfaScan: - default_outfiles=['sfa.png','sfa.svg','sfa.dot'] + default_outfiles = ['sfa.png', 'sfa.svg', 'sfa.dot'] def main(self): - usage="%prog [options] url-entry-point(s)" - parser=OptionParser(usage=usage) + usage = "%prog [options] url-entry-point(s)" + parser = OptionParser(usage=usage) parser.add_option("-d", "--dir", dest="sfi_dir", help="config & working directory - default is " + Sfi.default_sfi_dir(), metavar="PATH", default=Sfi.default_sfi_dir()) - parser.add_option("-o","--output",action='append',dest='outfiles',default=[], - help="output filenames (cumulative) - defaults are %r"%SfaScan.default_outfiles) - parser.add_option("-l","--left-to-right",action="store_true",dest="left_to_right",default=False, + parser.add_option("-o", "--output", action='append', dest='outfiles', default=[], + help="output filenames (cumulative) - defaults are %r" % SfaScan.default_outfiles) + parser.add_option("-l", "--left-to-right", action="store_true", dest="left_to_right", default=False, help="instead of top-to-bottom") parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0, help="verbose - can be repeated for more verbosity") - parser.add_option("-c", "--clean-cache",action='store_true', - dest='clean_cache',default=False, + parser.add_option("-c", "--clean-cache", action='store_true', + dest='clean_cache', default=False, help='clean/trash version cache and exit') - parser.add_option("-s","--show-cache",action='store_true', - dest='show_cache',default=False, + parser.add_option("-s", "--show-cache", action='store_true', + dest='show_cache', default=False, help='show/display version cache') - - (options,args)=parser.parse_args() + + (options, args) = parser.parse_args() logger.enable_console() # apply current verbosity to logger logger.setLevelFromOptVerbose(options.verbose) - # figure if we need to be verbose for these local classes that only have a bool flag - bool_verbose=logger.getBoolVerboseFromOpt(options.verbose) - - if options.show_cache: + # figure if we need to be verbose for these local classes that only + # have a bool flag + bool_verbose = logger.getBoolVerboseFromOpt(options.verbose) + + if options.show_cache: VersionCache().show() sys.exit(0) if options.clean_cache: @@ -352,21 +388,23 @@ class SfaScan: if not args: parser.print_help() sys.exit(1) - + if not options.outfiles: - options.outfiles=SfaScan.default_outfiles - scanner=Scanner(left_to_right=options.left_to_right, verbose=bool_verbose) - entries = [ Interface(entry,mentioned_in="command line") for entry in args ] + options.outfiles = SfaScan.default_outfiles + scanner = Scanner(left_to_right=options.left_to_right, + verbose=bool_verbose) + entries = [Interface(entry, mentioned_in="command line") + for entry in args] try: - g=scanner.graph(entries) + g = scanner.graph(entries) logger.info("creating layout") g.layout(prog='dot') for outfile in options.outfiles: - logger.info("drawing in %s"%outfile) + logger.info("drawing in %s" % outfile) g.draw(outfile) logger.info("done") # test mode when pygraphviz is not available except: - entry=entries[0] - print("GetVersion at %s returned %s"%(entry.url(),entry.get_version())) - + entry = entries[0] + print("GetVersion at %s returned %s" % + (entry.url(), entry.get_version())) diff --git a/sfa/client/sfaserverproxy.py b/sfa/client/sfaserverproxy.py index e281f66d..73aefc8c 100644 --- a/sfa/client/sfaserverproxy.py +++ b/sfa/client/sfaserverproxy.py @@ -2,8 +2,10 @@ # starting with 2.7.9 we need to turn off server verification import ssl -try: turn_off_server_verify = { 'context' : ssl._create_unverified_context() } -except: turn_off_server_verify = {} +try: + turn_off_server_verify = {'context': ssl._create_unverified_context()} +except: + turn_off_server_verify = {} from sfa.util.py23 import xmlrpc_client from sfa.util.py23 import http_client @@ -20,10 +22,13 @@ except: # Used to convert server exception strings back to an exception. # from usenet, Raghuram Devarakonda + class ServerException(Exception): pass + class ExceptionUnmarshaller(xmlrpc_client.Unmarshaller): + def close(self): try: return xmlrpc_client.Unmarshaller.close(self) @@ -37,20 +42,21 @@ class ExceptionUnmarshaller(xmlrpc_client.Unmarshaller): # targetting only python-2.7 we can get rid of some older code + class XMLRPCTransport(xmlrpc_client.Transport): - - def __init__(self, key_file = None, cert_file = None, timeout = None): + + def __init__(self, key_file=None, cert_file=None, timeout=None): xmlrpc_client.Transport.__init__(self) - self.timeout=timeout + self.timeout = timeout self.key_file = key_file self.cert_file = cert_file - + def make_connection(self, host): # create a HTTPS connection object from a host descriptor # host may be a string, or a (host, x509-dict) tuple host, extra_headers, x509 = self.get_host_info(host) - conn = http_client.HTTPSConnection(host, None, key_file = self.key_file, - cert_file = self.cert_file, + conn = http_client.HTTPSConnection(host, None, key_file=self.key_file, + cert_file=self.cert_file, **turn_off_server_verify) # Some logic to deal with timeouts. It appears that some (or all) versions @@ -77,7 +83,9 @@ class XMLRPCTransport(xmlrpc_client.Transport): parser = xmlrpc_client.ExpatParser(unmarshaller) return parser, unmarshaller + class XMLRPCServerProxy(xmlrpc_client.ServerProxy): + def __init__(self, url, transport, allow_none=True, verbose=False): # remember url for GetVersion # xxx not sure this is still needed as SfaServerProxy has this too @@ -90,10 +98,12 @@ class XMLRPCServerProxy(xmlrpc_client.ServerProxy): logger.debug("xml-rpc %s method:%s" % (self.url, attr)) return xmlrpc_client.ServerProxy.__getattr__(self, attr) -########## the object on which we can send methods that get sent over xmlrpc +# the object on which we can send methods that get sent over xmlrpc + + class SfaServerProxy: - def __init__ (self, url, keyfile, certfile, verbose=False, timeout=None): + def __init__(self, url, keyfile, certfile, verbose=False, timeout=None): self.url = url self.keyfile = keyfile self.certfile = certfile @@ -101,9 +111,10 @@ class SfaServerProxy: self.timeout = timeout # an instance of xmlrpc_client.ServerProxy transport = XMLRPCTransport(keyfile, certfile, timeout) - self.serverproxy = XMLRPCServerProxy(url, transport, allow_none=True, verbose=verbose) + self.serverproxy = XMLRPCServerProxy( + url, transport, allow_none=True, verbose=verbose) - # this is python magic to return the code to run when + # this is python magic to return the code to run when # SfaServerProxy receives a method call # so essentially we send the same method with identical arguments # to the server_proxy object @@ -111,4 +122,3 @@ class SfaServerProxy: def func(*args, **kwds): return getattr(self.serverproxy, name)(*args, **kwds) return func - diff --git a/sfa/client/sfi.py b/sfa/client/sfi.py index 95acf181..c32669c9 100644 --- a/sfa/client/sfi.py +++ b/sfa/client/sfi.py @@ -8,7 +8,8 @@ from __future__ import print_function import sys sys.path.append('.') -import os, os.path +import os +import os.path import socket import re import datetime @@ -52,9 +53,11 @@ CM_PORT = 12346 DEFAULT_RSPEC_VERSION = "GENI 3" from sfa.client.common import optparse_listvalue_callback, optparse_dictvalue_callback, \ - terminal_render, filter_records + terminal_render, filter_records # display methods + + def display_rspec(rspec, format='rspec'): if format in ['dns']: tree = etree.parse(StringIO(rspec)) @@ -72,15 +75,18 @@ def display_rspec(rspec, format='rspec'): print(result) return + def display_list(results): for result in results: print(result) + def display_records(recordList, dump=False): ''' Print all fields in the record''' for record in recordList: display_record(record, dump) + def display_record(record, dump=False): if dump: record.dump(sort=True) @@ -109,14 +115,18 @@ def credential_printable(cred): result += "rights={}\n".format(rights) return result + def show_credentials(cred_s): - if not isinstance(cred_s, list): cred_s = [cred_s] + if not isinstance(cred_s, list): + cred_s = [cred_s] for cred in cred_s: print("Using Credential {}".format(credential_printable(cred))) -########## save methods +# save methods + +# raw + -### raw def save_raw_to_file(var, filename, format='text', banner=None): if filename == '-': _save_raw_to_file(var, sys.stdout, format, banner) @@ -125,11 +135,14 @@ def save_raw_to_file(var, filename, format='text', banner=None): _save_raw_to_file(var, fileobj, format, banner) print("(Over)wrote {}".format(filename)) + def _save_raw_to_file(var, f, format, banner): if format == "text": - if banner: f.write(banner+"\n") + if banner: + f.write(banner + "\n") f.write("{}".format(var)) - if banner: f.write('\n'+banner+"\n") + if banner: + f.write('\n' + banner + "\n") elif format == "pickled": f.write(pickle.dumps(var)) elif format == "json": @@ -138,7 +151,9 @@ def _save_raw_to_file(var, f, format, banner): # this should never happen print("unknown output format", format) -### +### + + def save_rspec_to_file(rspec, filename): if not filename.endswith(".rspec"): filename = filename + ".rspec" @@ -146,6 +161,7 @@ def save_rspec_to_file(rspec, filename): f.write("{}".format(rspec)) print("(Over)wrote {}".format(filename)) + def save_record_to_file(filename, record_dict): record = Record(dict=record_dict) xml = record.save_as_xml() @@ -153,6 +169,7 @@ def save_record_to_file(filename, record_dict): f.write(xml) print("(Over)wrote {}".format(filename)) + def save_records_to_file(filename, record_dicts, format="xml"): if format == "xml": for index, record_dict in enumerate(record_dicts): @@ -162,7 +179,8 @@ def save_records_to_file(filename, record_dicts, format="xml"): f.write("\n") for record_dict in record_dicts: record_obj = Record(dict=record_dict) - f.write('\n') + f.write('\n') f.write("\n") print("(Over)wrote {}".format(filename)) @@ -178,11 +196,15 @@ def save_records_to_file(filename, record_dicts, format="xml"): print("unknown output format", format) # minimally check a key argument + + def check_ssh_key(key): good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$' return re.match(good_ssh_key, key, re.IGNORECASE) # load methods + + def normalize_type(type): if type.startswith('au'): return 'authority' @@ -200,6 +222,7 @@ def normalize_type(type): print('unknown type {} - should start with one of au|us|sl|no|ag|al'.format(type)) return None + def load_record_from_opts(options): record_dict = {} if hasattr(options, 'xrn') and options.xrn: @@ -216,7 +239,8 @@ def load_record_from_opts(options): except IOError: pubkey = options.key if not check_ssh_key(pubkey): - raise SfaInvalidArgument(name='key', msg="Could not find file, or wrong key format") + raise SfaInvalidArgument( + name='key', msg="Could not find file, or wrong key format") record_dict['reg-keys'] = [pubkey] if hasattr(options, 'slices') and options.slices: record_dict['slices'] = options.slices @@ -232,46 +256,52 @@ def load_record_from_opts(options): # handle extra settings record_dict.update(options.extras) - + return Record(dict=record_dict) + def load_record_from_file(filename): with codecs.open(filename, encoding="utf-8", mode="r") as f: xml_str = f.read() return Record(xml=xml_str) import uuid + + def unique_call_id(): return uuid.uuid4().urn -########## a simple model for maintaing 3 doc attributes per command (instead of just one) +# a simple model for maintaing 3 doc attributes per command (instead of just one) # essentially for the methods that implement a subcommand like sfi list # we need to keep track of # (*) doc a few lines that tell what it does, still located in __doc__ # (*) args_string a simple one-liner that describes mandatory arguments # (*) example well, one or several releant examples -# +# # since __doc__ only accounts for one, we use this simple mechanism below # however we keep doc in place for easier migration from functools import wraps # we use a list as well as a dict so we can keep track of the order -commands_list=[] -commands_dict={} +commands_list = [] +commands_dict = {} + def declare_command(args_string, example, aliases=None): - def wrap(m): - name=getattr(m, '__name__') - doc=getattr(m, '__doc__', "-- missing doc --") - doc=doc.strip(" \t\n") + def wrap(m): + name = getattr(m, '__name__') + doc = getattr(m, '__doc__', "-- missing doc --") + doc = doc.strip(" \t\n") commands_list.append(name) - # last item is 'canonical' name, so we can know which commands are aliases - command_tuple=(doc, args_string, example, name) - commands_dict[name]=command_tuple + # last item is 'canonical' name, so we can know which commands are + # aliases + command_tuple = (doc, args_string, example, name) + commands_dict[name] = command_tuple if aliases is not None: for alias in aliases: commands_list.append(alias) - commands_dict[alias]=command_tuple + commands_dict[alias] = command_tuple + @wraps(m) def new_method(*args, **kwds): return m(*args, **kwds) return new_method @@ -279,19 +309,22 @@ def declare_command(args_string, example, aliases=None): def remove_none_fields(record): - none_fields=[ k for (k, v) in record.items() if v is None ] - for k in none_fields: del record[k] + none_fields = [k for (k, v) in record.items() if v is None] + for k in none_fields: + del record[k] ########## + class Sfi: - + # dirty hack to make this class usable from the outside - required_options=['verbose', 'debug', 'registry', 'sm', 'auth', 'user', 'user_private_key'] + required_options = ['verbose', 'debug', 'registry', + 'sm', 'auth', 'user', 'user_private_key'] @staticmethod def default_sfi_dir(): - if os.path.isfile("./sfi_config"): + if os.path.isfile("./sfi_config"): return os.getcwd() else: return os.path.expanduser("~/.sfi/") @@ -302,7 +335,8 @@ class Sfi: pass def __init__(self, options=None): - if options is None: options=Sfi.DummyOptions() + if options is None: + options = Sfi.DummyOptions() for opt in Sfi.required_options: if not hasattr(options, opt): setattr(options, opt, None) @@ -313,21 +347,22 @@ class Sfi: self.authority = None self.logger = sfi_logger self.logger.enable_console() - ### various auxiliary material that we keep at hand - self.command=None - # need to call this other than just 'config' as we have a command/method with that name + # various auxiliary material that we keep at hand + self.command = None + # need to call this other than just 'config' as we have a + # command/method with that name self.config_instance = None self.config_file = None self.client_bootstrap = None - ### suitable if no reasonable command has been provided + # suitable if no reasonable command has been provided def print_commands_help(self, options): verbose = getattr(options, 'verbose') format3 = "%10s %-35s %s" format3offset = 47 - line = 80*'-' + line = 80 * '-' if not verbose: - print(format3%("command", "cmd_args", "description")) + print(format3 % ("command", "cmd_args", "description")) print(line) else: print(line) @@ -337,19 +372,19 @@ class Sfi: try: (doc, args_string, example, canonical) = commands_dict[command] except: - print("Cannot find info on command %s - skipped"%command) + print("Cannot find info on command %s - skipped" % command) continue if verbose: print(line) - if command==canonical: + if command == canonical: doc = doc.replace("\n", "\n" + format3offset * ' ') print(format3 % (command, args_string, doc)) if verbose: self.create_parser_command(command).print_help() else: - print(format3 % (command, "<>"%canonical, "")) - - ### now if a known command was found we can be more verbose on that one + print(format3 % (command, "<>" % canonical, "")) + + # now if a known command was found we can be more verbose on that one def print_help(self): print("==================== Generic sfi usage") self.sfi_parser.print_help() @@ -372,47 +407,46 @@ class Sfi: usage="sfi [sfi_options] command [cmd_options] [cmd_args]", description="Commands: {}".format(" ".join(commands_list))) parser.add_option("-r", "--registry", dest="registry", - help="root registry", metavar="URL", default=None) + help="root registry", metavar="URL", default=None) parser.add_option("-s", "--sliceapi", dest="sm", default=None, metavar="URL", - help="slice API - in general a SM URL, but can be used to talk to an aggregate") + help="slice API - in general a SM URL, but can be used to talk to an aggregate") parser.add_option("-R", "--raw", dest="raw", default=None, help="Save raw, unparsed server response to a file") parser.add_option("", "--rawformat", dest="rawformat", type="choice", help="raw file format ([text]|pickled|json)", default="text", - choices=("text","pickled","json")) + choices=("text", "pickled", "json")) parser.add_option("", "--rawbanner", dest="rawbanner", default=None, help="text string to write before and after raw output") parser.add_option("-d", "--dir", dest="sfi_dir", - help="config & working directory - default is %default", - metavar="PATH", default=Sfi.default_sfi_dir()) + help="config & working directory - default is %default", + metavar="PATH", default=Sfi.default_sfi_dir()) parser.add_option("-u", "--user", dest="user", - help="user name", metavar="HRN", default=None) + help="user name", metavar="HRN", default=None) parser.add_option("-a", "--auth", dest="auth", - help="authority name", metavar="HRN", default=None) + help="authority name", metavar="HRN", default=None) parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0, - help="verbose mode - cumulative") + help="verbose mode - cumulative") parser.add_option("-D", "--debug", action="store_true", dest="debug", default=False, help="Debug (xml-rpc) protocol messages") # would it make sense to use ~/.ssh/id_rsa as a default here ? parser.add_option("-k", "--private-key", - action="store", dest="user_private_key", default=None, - help="point to the private key file to use if not yet installed in sfi_dir") + action="store", dest="user_private_key", default=None, + help="point to the private key file to use if not yet installed in sfi_dir") parser.add_option("-t", "--timeout", dest="timeout", default=None, - help="Amout of time to wait before timing out the request") - parser.add_option("-h", "--help", - action="store_true", dest="help", default=False, - help="one page summary on commands & exit") + help="Amout of time to wait before timing out the request") + parser.add_option("-h", "--help", + action="store_true", dest="help", default=False, + help="one page summary on commands & exit") parser.disable_interspersed_args() return parser - def create_parser_command(self, command): if command not in commands_dict: - msg="Invalid command\n" - msg+="Commands: " - msg += ','.join(commands_list) + msg = "Invalid command\n" + msg += "Commands: " + msg += ','.join(commands_list) self.logger.critical(msg) sys.exit(2) @@ -420,57 +454,62 @@ class Sfi: (_, args_string, __, canonical) = commands_dict[command] parser = OptionParser(add_help_option=False, - usage="sfi [sfi_options] {} [cmd_options] {}"\ + usage="sfi [sfi_options] {} [cmd_options] {}" .format(command, args_string)) - parser.add_option("-h","--help",dest='help',action='store_true',default=False, - help="Summary of one command usage") + parser.add_option("-h", "--help", dest='help', action='store_true', default=False, + help="Summary of one command usage") if canonical in ("config"): parser.add_option('-m', '--myslice', dest='myslice', action='store_true', default=False, help='how myslice config variables as well') if canonical in ("version"): - parser.add_option("-l","--local", + parser.add_option("-l", "--local", action="store_true", dest="version_local", default=False, help="display version of the local client") if canonical in ("version", "trusted", "introspect"): - parser.add_option("-R","--registry_interface", + parser.add_option("-R", "--registry_interface", action="store_true", dest="registry_interface", default=False, help="target the registry interface instead of slice interface") if canonical in ("register", "update"): - parser.add_option('-x', '--xrn', dest='xrn', metavar='', help='object hrn/urn (mandatory)') - parser.add_option('-t', '--type', dest='type', metavar='', help='object type (2 first chars is enough)', default=None) - parser.add_option('-e', '--email', dest='email', default="", help="email (mandatory for users)") - parser.add_option('-n', '--name', dest='name', default="", help="name (optional for authorities)") - parser.add_option('-k', '--key', dest='key', metavar='', help='public key string or file', + parser.add_option('-x', '--xrn', dest='xrn', + metavar='', help='object hrn/urn (mandatory)') + parser.add_option('-t', '--type', dest='type', metavar='', + help='object type (2 first chars is enough)', default=None) + parser.add_option('-e', '--email', dest='email', + default="", help="email (mandatory for users)") + parser.add_option('-n', '--name', dest='name', + default="", help="name (optional for authorities)") + parser.add_option('-k', '--key', dest='key', metavar='', help='public key string or file', default=None) parser.add_option('-s', '--slices', dest='slices', metavar='', help='Set/replace slice xrns', default='', type="str", action='callback', callback=optparse_listvalue_callback) - parser.add_option('-r', '--researchers', dest='reg_researchers', metavar='', - help='Set/replace slice researchers - use -r none to reset', default=None, type="str", action='callback', + parser.add_option('-r', '--researchers', dest='reg_researchers', metavar='', + help='Set/replace slice researchers - use -r none to reset', default=None, type="str", action='callback', callback=optparse_listvalue_callback) parser.add_option('-p', '--pis', dest='reg_pis', metavar='', help='Set/replace Principal Investigators/Project Managers', default='', type="str", action='callback', callback=optparse_listvalue_callback) - parser.add_option('-X','--extra',dest='extras',default={},type='str',metavar="", - action="callback", callback=optparse_dictvalue_callback, nargs=1, - help="set extra/testbed-dependent flags, e.g. --extra enabled=true") - - # user specifies remote aggregate/sm/component - if canonical in ("resources", "describe", "allocate", "provision", "delete", "allocate", "provision", - "action", "shutdown", "renew", "status"): - parser.add_option("-d", "--delegate", dest="delegate", default=None, - action="store_true", - help="Include a credential delegated to the user's root"+\ - "authority in set of credentials for this call") + parser.add_option('-X', '--extra', dest='extras', default={}, type='str', metavar="", + action="callback", callback=optparse_dictvalue_callback, nargs=1, + help="set extra/testbed-dependent flags, e.g. --extra enabled=true") + + # user specifies remote aggregate/sm/component + if canonical in ("resources", "describe", "allocate", "provision", "delete", "allocate", "provision", + "action", "shutdown", "renew", "status"): + parser.add_option("-d", "--delegate", dest="delegate", default=None, + action="store_true", + help="Include a credential delegated to the user's root" + + "authority in set of credentials for this call") # show_credential option - if canonical in ("list","resources", "describe", "provision", "allocate", "register","update","remove","delete","status","renew"): - parser.add_option("-C","--credential",dest='show_credential',action='store_true',default=False, + if canonical in ("list", "resources", "describe", "provision", "allocate", "register", + "update", "remove", "delete", "status", "renew"): + parser.add_option("-C", "--credential", dest='show_credential', action='store_true', default=False, help="show credential(s) used in human-readable form") if canonical in ("renew"): - parser.add_option("-l","--as-long-as-possible",dest='alap',action='store_true',default=False, + parser.add_option("-l", "--as-long-as-possible", dest='alap', action='store_true', default=False, help="renew as long as possible") # registy filter option if canonical in ("list", "show", "remove"): @@ -478,9 +517,9 @@ class Sfi: default="all", help="type filter - 2 first chars is enough ([all]|user|slice|authority|node|aggregate)") if canonical in ("show"): - parser.add_option("-k","--key",dest="keys",action="append",default=[], + parser.add_option("-k", "--key", dest="keys", action="append", default=[], help="specify specific keys to be displayed from record") - parser.add_option("-n","--no-details",dest="no_details",action="store_true",default=False, + parser.add_option("-n", "--no-details", dest="no_details", action="store_true", default=False, help="call Resolve without the 'details' option") if canonical in ("resources", "describe"): # rspec version @@ -488,72 +527,74 @@ class Sfi: help="schema type and version of resulting RSpec (default:{})".format(DEFAULT_RSPEC_VERSION)) # disable/enable cached rspecs parser.add_option("-c", "--current", dest="current", default=False, - action="store_true", + action="store_true", help="Request the current rspec bypassing the cache. Cached rspecs are returned by default") # display formats parser.add_option("-f", "--format", dest="format", type="choice", - help="display format ([xml]|dns|ip)", default="xml", - choices=("xml", "dns", "ip")) - #panos: a new option to define the type of information about resources a user is interested in + help="display format ([xml]|dns|ip)", default="xml", + choices=("xml", "dns", "ip")) + # panos: a new option to define the type of information about + # resources a user is interested in parser.add_option("-i", "--info", dest="info", - help="optional component information", default=None) - # a new option to retrieve or not reservation-oriented RSpecs (leases) + help="optional component information", default=None) + # a new option to retrieve or not reservation-oriented RSpecs + # (leases) parser.add_option("-l", "--list_leases", dest="list_leases", type="choice", - help="Retrieve or not reservation-oriented RSpecs ([resources]|leases|all)", - choices=("all", "resources", "leases"), default="resources") - + help="Retrieve or not reservation-oriented RSpecs ([resources]|leases|all)", + choices=("all", "resources", "leases"), default="resources") if canonical in ("resources", "describe", "allocate", "provision", "show", "list", "gid"): - parser.add_option("-o", "--output", dest="file", - help="output XML to file", metavar="FILE", default=None) + parser.add_option("-o", "--output", dest="file", + help="output XML to file", metavar="FILE", default=None) if canonical in ("show", "list"): - parser.add_option("-f", "--format", dest="format", type="choice", - help="display format ([text]|xml)", default="text", - choices=("text", "xml")) + parser.add_option("-f", "--format", dest="format", type="choice", + help="display format ([text]|xml)", default="text", + choices=("text", "xml")) - parser.add_option("-F", "--fileformat", dest="fileformat", type="choice", - help="output file format ([xml]|xmllist|hrnlist)", default="xml", - choices=("xml", "xmllist", "hrnlist")) + parser.add_option("-F", "--fileformat", dest="fileformat", type="choice", + help="output file format ([xml]|xmllist|hrnlist)", default="xml", + choices=("xml", "xmllist", "hrnlist")) if canonical == 'list': - parser.add_option("-r", "--recursive", dest="recursive", action='store_true', - help="list all child records", default=False) - parser.add_option("-v", "--verbose", dest="verbose", action='store_true', - help="gives details, like user keys", default=False) + parser.add_option("-r", "--recursive", dest="recursive", action='store_true', + help="list all child records", default=False) + parser.add_option("-v", "--verbose", dest="verbose", action='store_true', + help="gives details, like user keys", default=False) if canonical in ("delegate"): - parser.add_option("-u", "--user", - action="store_true", dest="delegate_user", default=False, - help="delegate your own credentials; default if no other option is provided") - parser.add_option("-s", "--slice", dest="delegate_slices",action='append',default=[], - metavar="slice_hrn", help="delegate cred. for slice HRN") - parser.add_option("-a", "--auths", dest='delegate_auths',action='append',default=[], - metavar='auth_hrn', help="delegate cred for auth HRN") - # this primarily is a shorthand for -A my_hrn^ - parser.add_option("-p", "--pi", dest='delegate_pi', default=None, action='store_true', - help="delegate your PI credentials, so s.t. like -A your_hrn^") - parser.add_option("-A","--to-authority",dest='delegate_to_authority',action='store_true',default=False, - help="""by default the mandatory argument is expected to be a user, + parser.add_option("-u", "--user", + action="store_true", dest="delegate_user", default=False, + help="delegate your own credentials; default if no other option is provided") + parser.add_option("-s", "--slice", dest="delegate_slices", action='append', default=[], + metavar="slice_hrn", help="delegate cred. for slice HRN") + parser.add_option("-a", "--auths", dest='delegate_auths', action='append', default=[], + metavar='auth_hrn', help="delegate cred for auth HRN") + # this primarily is a shorthand for -A my_hrn^ + parser.add_option("-p", "--pi", dest='delegate_pi', default=None, action='store_true', + help="delegate your PI credentials, so s.t. like -A your_hrn^") + parser.add_option("-A", "--to-authority", dest='delegate_to_authority', action='store_true', default=False, + help="""by default the mandatory argument is expected to be a user, use this if you mean an authority instead""") if canonical in ("myslice"): - parser.add_option("-p","--password",dest='password',action='store',default=None, + parser.add_option("-p", "--password", dest='password', action='store', default=None, help="specify mainfold password on the command line") - parser.add_option("-s", "--slice", dest="delegate_slices",action='append',default=[], - metavar="slice_hrn", help="delegate cred. for slice HRN") - parser.add_option("-a", "--auths", dest='delegate_auths',action='append',default=[], - metavar='auth_hrn', help="delegate PI cred for auth HRN") - parser.add_option('-d', '--delegate', dest='delegate', help="Override 'delegate' from the config file") - parser.add_option('-b', '--backend', dest='backend', help="Override 'backend' from the config file") - + parser.add_option("-s", "--slice", dest="delegate_slices", action='append', default=[], + metavar="slice_hrn", help="delegate cred. for slice HRN") + parser.add_option("-a", "--auths", dest='delegate_auths', action='append', default=[], + metavar='auth_hrn', help="delegate PI cred for auth HRN") + parser.add_option('-d', '--delegate', dest='delegate', + help="Override 'delegate' from the config file") + parser.add_option('-b', '--backend', dest='backend', + help="Override 'backend' from the config file") + return parser - # # Main: parse arguments and dispatch to command # def dispatch(self, command, command_options, command_args): (doc, args_string, example, canonical) = commands_dict[command] - method=getattr(self, canonical, None) + method = getattr(self, canonical, None) if not method: print("sfi: unknown command {}".format(command)) raise SystemExit("Unknown command {}".format(command)) @@ -566,7 +607,7 @@ use this if you mean an authority instead""") def main(self): self.sfi_parser = self.create_parser_global() (options, args) = self.sfi_parser.parse_args() - if options.help: + if options.help: self.print_commands_help(options) sys.exit(1) self.options = options @@ -577,7 +618,7 @@ use this if you mean an authority instead""") self.logger.critical("No command given. Use -h for help.") self.print_commands_help(options) return -1 - + # complete / find unique match with command set command_candidates = Candidates(commands_list) input = args[0] @@ -588,7 +629,8 @@ use this if you mean an authority instead""") # second pass options parsing self.command = command self.command_parser = self.create_parser_command(command) - (command_options, command_args) = self.command_parser.parse_args(args[1:]) + (command_options, command_args) = self.command_parser.parse_args( + args[1:]) if command_options.help: self.print_help() sys.exit(1) @@ -599,8 +641,8 @@ use this if you mean an authority instead""") command_options.type = normalize_type(command_options.type) if not command_options.type: sys.exit(1) - - self.read_config() + + self.read_config() self.bootstrap() self.logger.debug("Command={}".format(self.command)) @@ -612,24 +654,25 @@ use this if you mean an authority instead""") self.logger.log_exc("sfi command {} failed".format(command)) return 1 return retcod - + #################### def read_config(self): config_file = os.path.join(self.options.sfi_dir, "sfi_config") - shell_config_file = os.path.join(self.options.sfi_dir, "sfi_config.sh") + shell_config_file = os.path.join(self.options.sfi_dir, "sfi_config.sh") try: if Config.is_ini(config_file): config = Config(config_file) else: # try upgrading from shell config format - fp, fn = mkstemp(suffix='sfi_config', text=True) + fp, fn = mkstemp(suffix='sfi_config', text=True) config = Config(fn) - # we need to preload the sections we want parsed + # we need to preload the sections we want parsed # from the shell config config.add_section('sfi') - # sface users should be able to use this same file to configure their stuff + # sface users should be able to use this same file to configure + # their stuff config.add_section('sface') - # manifold users should be able to specify the details + # manifold users should be able to specify the details # of their backend server here for 'sfi myslice' config.add_section('myslice') config.load(config_file) @@ -637,57 +680,64 @@ use this if you mean an authority instead""") shutil.move(config_file, shell_config_file) # write new config config.save(config_file) - + except: - self.logger.critical("Failed to read configuration file {}".format(config_file)) - self.logger.info("Make sure to remove the export clauses and to add quotes") + self.logger.critical( + "Failed to read configuration file {}".format(config_file)) + self.logger.info( + "Make sure to remove the export clauses and to add quotes") if self.options.verbose == 0: self.logger.info("Re-run with -v for more details") else: - self.logger.log_exc("Could not read config file {}".format(config_file)) + self.logger.log_exc( + "Could not read config file {}".format(config_file)) sys.exit(1) - + self.config_instance = config errors = 0 # Set SliceMgr URL if (self.options.sm is not None): - self.sm_url = self.options.sm + self.sm_url = self.options.sm elif hasattr(config, "SFI_SM"): - self.sm_url = config.SFI_SM + self.sm_url = config.SFI_SM else: - self.logger.error("You need to set e.g. SFI_SM='http://your.slicemanager.url:12347/' in {}".format(config_file)) - errors += 1 + self.logger.error( + "You need to set e.g. SFI_SM='http://your.slicemanager.url:12347/' in {}".format(config_file)) + errors += 1 # Set Registry URL if (self.options.registry is not None): - self.reg_url = self.options.registry + self.reg_url = self.options.registry elif hasattr(config, "SFI_REGISTRY"): - self.reg_url = config.SFI_REGISTRY + self.reg_url = config.SFI_REGISTRY else: - self.logger.error("You need to set e.g. SFI_REGISTRY='http://your.registry.url:12345/' in {}".format(config_file)) - errors += 1 + self.logger.error( + "You need to set e.g. SFI_REGISTRY='http://your.registry.url:12345/' in {}".format(config_file)) + errors += 1 # Set user HRN if (self.options.user is not None): - self.user = self.options.user + self.user = self.options.user elif hasattr(config, "SFI_USER"): - self.user = config.SFI_USER + self.user = config.SFI_USER else: - self.logger.error("You need to set e.g. SFI_USER='plc.princeton.username' in {}".format(config_file)) - errors += 1 + self.logger.error( + "You need to set e.g. SFI_USER='plc.princeton.username' in {}".format(config_file)) + errors += 1 # Set authority HRN if (self.options.auth is not None): - self.authority = self.options.auth + self.authority = self.options.auth elif hasattr(config, "SFI_AUTH"): - self.authority = config.SFI_AUTH + self.authority = config.SFI_AUTH else: - self.logger.error("You need to set e.g. SFI_AUTH='plc.princeton' in {}".format(config_file)) - errors += 1 + self.logger.error( + "You need to set e.g. SFI_AUTH='plc.princeton' in {}".format(config_file)) + errors += 1 self.config_file = config_file if errors: - sys.exit(1) + sys.exit(1) # # Get various credential and spec files @@ -701,48 +751,51 @@ use this if you mean an authority instead""") # - bootstrap authority credential from user credential # - bootstrap slice credential from user credential # - + # init self-signed cert, user credentials and gid def bootstrap(self): if self.options.verbose: - self.logger.info("Initializing SfaClientBootstrap with {}".format(self.reg_url)) + self.logger.info( + "Initializing SfaClientBootstrap with {}".format(self.reg_url)) client_bootstrap = SfaClientBootstrap(self.user, self.reg_url, self.options.sfi_dir, - logger=self.logger) + logger=self.logger) # if -k is provided, use this to initialize private key if self.options.user_private_key: - client_bootstrap.init_private_key_if_missing(self.options.user_private_key) + client_bootstrap.init_private_key_if_missing( + self.options.user_private_key) else: - # trigger legacy compat code if needed + # trigger legacy compat code if needed # the name has changed from just .pkey to .pkey if not os.path.isfile(client_bootstrap.private_key_filename()): self.logger.info("private key not found, trying legacy name") try: legacy_private_key = os.path.join(self.options.sfi_dir, "{}.pkey" - .format(Xrn.unescape(get_leaf(self.user)))) + .format(Xrn.unescape(get_leaf(self.user)))) self.logger.debug("legacy_private_key={}" .format(legacy_private_key)) - client_bootstrap.init_private_key_if_missing(legacy_private_key) + client_bootstrap.init_private_key_if_missing( + legacy_private_key) self.logger.info("Copied private key from legacy location {}" .format(legacy_private_key)) except: self.logger.log_exc("Can't find private key ") sys.exit(1) - + # make it bootstrap client_bootstrap.bootstrap_my_gid() # extract what's needed self.private_key = client_bootstrap.private_key() self.my_credential_string = client_bootstrap.my_credential_string() self.my_credential = {'geni_type': 'geni_sfa', - 'geni_version': '3', + 'geni_version': '3', 'geni_value': self.my_credential_string} self.my_gid = client_bootstrap.my_gid() self.client_bootstrap = client_bootstrap - def my_authority_credential_string(self): if not self.authority: - self.logger.critical("no authority specified. Use -a or set SF_AUTH") + self.logger.critical( + "no authority specified. Use -a or set SF_AUTH") sys.exit(-1) return self.client_bootstrap.authority_credential_string(self.authority) @@ -755,16 +808,16 @@ use this if you mean an authority instead""") def slice_credential(self, name): return {'geni_type': 'geni_sfa', 'geni_version': '3', - 'geni_value': self.slice_credential_string(name)} + 'geni_value': self.slice_credential_string(name)} # xxx should be supported by sfaclientbootstrap as well def delegate_cred(self, object_cred, hrn, type='authority'): # the gid and hrn of the object we are delegating if isinstance(object_cred, str): - object_cred = Credential(string=object_cred) + object_cred = Credential(string=object_cred) object_gid = object_cred.get_gid_object() object_hrn = object_gid.get_hrn() - + if not object_cred.get_privileges().get_all_delegate(): self.logger.error("Object credential {} does not have delegate bit set" .format(object_hrn)) @@ -772,54 +825,60 @@ use this if you mean an authority instead""") # the delegating user's gid caller_gidfile = self.my_gid() - + # the gid of the user who will be delegated to delegee_gid = self.client_bootstrap.gid(hrn, type) delegee_hrn = delegee_gid.get_hrn() - dcred = object_cred.delegate(delegee_gid, self.private_key, caller_gidfile) + dcred = object_cred.delegate( + delegee_gid, self.private_key, caller_gidfile) return dcred.save_to_string(save_parents=True) - + # # Management of the servers - # + # def registry(self): # cache the result if not hasattr(self, 'registry_proxy'): self.logger.info("Contacting Registry at: {}".format(self.reg_url)) self.registry_proxy \ - = SfaServerProxy(self.reg_url, self.private_key, self.my_gid, - timeout=self.options.timeout, verbose=self.options.debug) + = SfaServerProxy(self.reg_url, self.private_key, self.my_gid, + timeout=self.options.timeout, verbose=self.options.debug) return self.registry_proxy def sliceapi(self): # cache the result if not hasattr(self, 'sliceapi_proxy'): - # if the command exposes the --component option, figure it's hostname and connect at CM_PORT + # if the command exposes the --component option, figure it's + # hostname and connect at CM_PORT if hasattr(self.command_options, 'component') and self.command_options.component: # resolve the hrn at the registry node_hrn = self.command_options.component records = self.registry().Resolve(node_hrn, self.my_credential_string) records = filter_records('node', records) if not records: - self.logger.warning("No such component:{}".format(opts.component)) + self.logger.warning( + "No such component:{}".format(opts.component)) record = records[0] cm_url = "http://{}:{}/".format(record['hostname'], CM_PORT) - self.sliceapi_proxy = SfaServerProxy(cm_url, self.private_key, self.my_gid) + self.sliceapi_proxy = SfaServerProxy( + cm_url, self.private_key, self.my_gid) else: - # otherwise use what was provided as --sliceapi, or SFI_SM in the config + # otherwise use what was provided as --sliceapi, or SFI_SM in + # the config if not self.sm_url.startswith('http://') or self.sm_url.startswith('https://'): self.sm_url = 'http://' + self.sm_url - self.logger.info("Contacting Slice Manager at: {}".format(self.sm_url)) + self.logger.info( + "Contacting Slice Manager at: {}".format(self.sm_url)) self.sliceapi_proxy \ - = SfaServerProxy(self.sm_url, self.private_key, self.my_gid, - timeout=self.options.timeout, verbose=self.options.debug) + = SfaServerProxy(self.sm_url, self.private_key, self.my_gid, + timeout=self.options.timeout, verbose=self.options.debug) return self.sliceapi_proxy def get_cached_server_version(self, server): # check local cache first cache = None - version = None + version = None cache_file = os.path.join(self.options.sfi_dir, 'sfi_cache.dat') cache_key = server.url + "-version" try: @@ -831,31 +890,31 @@ use this if you mean an authority instead""") if cache: version = cache.get(cache_key) - if not version: + if not version: result = server.GetVersion() version = ReturnValue.get_value(result) # cache version for 20 minutes - cache.add(cache_key, version, ttl=60*20) + cache.add(cache_key, version, ttl=60 * 20) self.logger.info("Updating cache file {}".format(cache_file)) cache.save_to_file(cache_file) - return version - - ### resurrect this temporarily so we can support V1 aggregates for a while + return version + + # resurrect this temporarily so we can support V1 aggregates for a while def server_supports_options_arg(self, server): """ Returns true if server support the optional call_id arg, false otherwise. """ server_version = self.get_cached_server_version(server) result = False - # xxx need to rewrite this + # xxx need to rewrite this if int(server_version.get('geni_api')) >= 2: result = True return result def server_supports_call_id_arg(self, server): server_version = self.get_cached_server_version(server) - result = False + result = False if 'sfa' in server_version and 'code_tag' in server_version: code_tag = server_version['code_tag'] code_tag_parts = code_tag.split("-") @@ -864,56 +923,58 @@ use this if you mean an authority instead""") rev = code_tag_parts[1] if int(major) == 1 and minor == 0 and build >= 22: result = True - return result + return result - ### ois = options if supported - # to be used in something like serverproxy.Method(arg1, arg2, *self.ois(api_options)) + # ois = options if supported + # to be used in something like serverproxy.Method(arg1, arg2, + # *self.ois(api_options)) def ois(self, server, option_dict): - if self.server_supports_options_arg(server): + if self.server_supports_options_arg(server): return [option_dict] elif self.server_supports_call_id_arg(server): - return [ unique_call_id() ] - else: + return [unique_call_id()] + else: return [] - ### cis = call_id if supported - like ois + # cis = call_id if supported - like ois def cis(self, server): if self.server_supports_call_id_arg(server): - return [ unique_call_id ] + return [unique_call_id] else: return [] - ######################################## miscell utilities + # miscell utilities def get_rspec_file(self, rspec): - if (os.path.isabs(rspec)): - file = rspec - else: - file = os.path.join(self.options.sfi_dir, rspec) - if (os.path.isfile(file)): - return file - else: - self.logger.critical("No such rspec file {}".format(rspec)) - sys.exit(1) - - def get_record_file(self, record): - if (os.path.isabs(record)): - file = record - else: - file = os.path.join(self.options.sfi_dir, record) - if (os.path.isfile(file)): - return file - else: - self.logger.critical("No such registry record file {}".format(record)) - sys.exit(1) + if (os.path.isabs(rspec)): + file = rspec + else: + file = os.path.join(self.options.sfi_dir, rspec) + if (os.path.isfile(file)): + return file + else: + self.logger.critical("No such rspec file {}".format(rspec)) + sys.exit(1) + def get_record_file(self, record): + if (os.path.isabs(record)): + file = record + else: + file = os.path.join(self.options.sfi_dir, record) + if (os.path.isfile(file)): + return file + else: + self.logger.critical( + "No such registry record file {}".format(record)) + sys.exit(1) # helper function to analyze raw output - # for main : return 0 if everything is fine, something else otherwise (mostly 1 for now) + # for main : return 0 if everything is fine, something else otherwise + # (mostly 1 for now) def success(self, raw): return_value = ReturnValue(raw) output = ReturnValue.get_output(return_value) # means everything is fine - if not output: + if not output: return 0 # something went wrong print('ERROR:', output) @@ -935,23 +996,26 @@ use this if you mean an authority instead""") sys.exit(1) print("# From configuration file {}".format(self.config_file)) - flags = [ ('sfi', [ ('registry', 'reg_url'), - ('auth', 'authority'), - ('user', 'user'), - ('sm', 'sm_url'), - ]), - ] + flags = [('sfi', [('registry', 'reg_url'), + ('auth', 'authority'), + ('user', 'user'), + ('sm', 'sm_url'), + ]), + ] if options.myslice: - flags.append( ('myslice', ['backend', 'delegate', 'platform', 'username'] ) ) + flags.append( + ('myslice', ['backend', 'delegate', 'platform', 'username'])) for (section, tuples) in flags: print("[{}]".format(section)) try: for external_name, internal_name in tuples: - print("{:<20} = {}".format(external_name, getattr(self, internal_name))) + print("{:<20} = {}".format( + external_name, getattr(self, internal_name))) except: for external_name, internal_name in tuples: - varname = "{}_{}".format(section.upper(), external_name.upper()) + varname = "{}_{}".format( + section.upper(), external_name.upper()) value = getattr(self.config_instance, varname) print("{:<20} = {}".format(external_name, value)) # xxx should analyze result @@ -966,7 +1030,7 @@ use this if you mean an authority instead""") if len(args) != 0: self.print_help() sys.exit(1) - + if options.version_local: version = version_core() else: @@ -977,7 +1041,8 @@ use this if you mean an authority instead""") result = server.GetVersion() version = ReturnValue.get_value(result) if self.options.raw: - save_raw_to_file(result, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(result, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: pprinter = PrettyPrinter(indent=4) pprinter.pprint(version) @@ -997,7 +1062,7 @@ use this if you mean an authority instead""") opts = {} if options.recursive: opts['recursive'] = options.recursive - + if options.show_credential: show_credentials(self.my_credential_string) try: @@ -1013,7 +1078,7 @@ use this if you mean an authority instead""") save_records_to_file(options.file, list, options.fileformat) # xxx should analyze result return 0 - + @declare_command("name", "") def show(self, options, args): """ @@ -1028,7 +1093,8 @@ use this if you mean an authority instead""") resolve_options = {} if not options.no_details: resolve_options['details'] = True - record_dicts = self.registry().Resolve(hrn, self.my_credential_string, resolve_options) + record_dicts = self.registry().Resolve( + hrn, self.my_credential_string, resolve_options) record_dicts = filter_records(options.type, record_dicts) if not record_dicts: self.logger.error("No record of type {}".format(options.type)) @@ -1038,20 +1104,26 @@ use this if you mean an authority instead""") def project(record): projected = {} for key in options.keys: - try: projected[key] = record[key] - except: pass + try: + projected[key] = record[key] + except: + pass return projected - record_dicts = [ project(record) for record in record_dicts ] - records = [ Record(dict=record_dict) for record_dict in record_dicts ] + record_dicts = [project(record) for record in record_dicts] + records = [Record(dict=record_dict) for record_dict in record_dicts] for record in records: - if (options.format == "text"): record.dump(sort=True) - else: print(record.save_as_xml()) + if (options.format == "text"): + record.dump(sort=True) + else: + print(record.save_as_xml()) if options.file: - save_records_to_file(options.file, record_dicts, options.fileformat) + save_records_to_file( + options.file, record_dicts, options.fileformat) # xxx should analyze result return 0 - - # this historically was named 'add', it is now 'register' with an alias for legacy + + # this historically was named 'add', it is now 'register' with an alias + # for legacy @declare_command("[xml-filename]", "", ['add']) def register(self, options, args): """ @@ -1071,14 +1143,15 @@ use this if you mean an authority instead""") try: record_filepath = args[0] rec_file = self.get_record_file(record_filepath) - record_dict.update(load_record_from_file(rec_file).record_to_dict()) + record_dict.update(load_record_from_file( + rec_file).record_to_dict()) except: print("Cannot load record file {}".format(record_filepath)) sys.exit(1) if options: record_dict.update(load_record_from_opts(options).record_to_dict()) # we should have a type by now - if 'type' not in record_dict : + if 'type' not in record_dict: self.print_help() sys.exit(1) # this is still planetlab dependent.. as plc will whine without that @@ -1087,13 +1160,13 @@ use this if you mean an authority instead""") if not 'first_name' in record_dict: record_dict['first_name'] = record_dict['hrn'] if 'last_name' not in record_dict: - record_dict['last_name'] = record_dict['hrn'] + record_dict['last_name'] = record_dict['hrn'] register = self.registry().Register(record_dict, auth_cred) # xxx looks like the result here is not ReturnValue-compatible - #return self.success (register) + # return self.success (register) # xxx should analyze result return 0 - + @declare_command("[xml-filename]", "") def update(self, options, args): """ @@ -1109,7 +1182,8 @@ use this if you mean an authority instead""") if len(args) == 1: record_filepath = args[0] rec_file = self.get_record_file(record_filepath) - record_dict.update(load_record_from_file(rec_file).record_to_dict()) + record_dict.update(load_record_from_file( + rec_file).record_to_dict()) if options: record_dict.update(load_record_from_opts(options).record_to_dict()) # at the very least we need 'type' here @@ -1128,26 +1202,27 @@ use this if you mean an authority instead""") try: cred = self.slice_credential_string(record_dict['hrn']) except ServerException as e: - # XXX smbaker -- once we have better error return codes, update this - # to do something better than a string compare - if "Permission error" in e.args[0]: - cred = self.my_authority_credential_string() - else: - raise + # XXX smbaker -- once we have better error return codes, update this + # to do something better than a string compare + if "Permission error" in e.args[0]: + cred = self.my_authority_credential_string() + else: + raise elif record_dict['type'] in ['authority']: cred = self.my_authority_credential_string() elif record_dict['type'] in ['node']: cred = self.my_authority_credential_string() else: - raise Exception("unknown record type {}".format(record_dict['type'])) + raise Exception( + "unknown record type {}".format(record_dict['type'])) if options.show_credential: show_credentials(cred) update = self.registry().Update(record_dict, cred) # xxx looks like the result here is not ReturnValue-compatible - #return self.success(update) + # return self.success(update) # xxx should analyze result return 0 - + @declare_command("hrn", "") def remove(self, options, args): """ @@ -1159,17 +1234,17 @@ use this if you mean an authority instead""") sys.exit(1) hrn = args[0] - type = options.type + type = options.type if type in ['all']: type = '*' if options.show_credential: show_credentials(auth_cred) remove = self.registry().Remove(hrn, auth_cred, type) # xxx looks like the result here is not ReturnValue-compatible - #return self.success (remove) + # return self.success (remove) # xxx should analyze result return 0 - + # ================================================================== # Slice-related commands # ================================================================== @@ -1188,7 +1263,8 @@ use this if you mean an authority instead""") # set creds creds = [self.my_credential] if options.delegate: - creds.append(self.delegate_cred(cred, get_authority(self.authority))) + creds.append(self.delegate_cred( + cred, get_authority(self.authority))) if options.show_credential: show_credentials(creds) @@ -1196,9 +1272,9 @@ use this if you mean an authority instead""") # been a required argument since v1 API api_options = {} # always send call_id to v2 servers - api_options ['call_id'] = unique_call_id() + api_options['call_id'] = unique_call_id() # ask for cached value if available - api_options ['cached'] = True + api_options['cached'] = True if options.info: api_options['info'] = options.info if options.list_leases: @@ -1209,12 +1285,14 @@ use this if you mean an authority instead""") else: api_options['cached'] = True version_manager = VersionManager() - api_options['geni_rspec_version'] = version_manager.get_version(options.rspec_version).to_dict() + api_options['geni_rspec_version'] = version_manager.get_version( + options.rspec_version).to_dict() list_resources = server.ListResources(creds, api_options) value = ReturnValue.get_value(list_resources) if self.options.raw: - save_raw_to_file(list_resources, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(list_resources, self.options.raw, + self.options.rawformat, self.options.rawbanner) if options.file is not None: save_rspec_to_file(value, options.file) if (self.options.raw is None) and (options.file is None): @@ -1235,7 +1313,8 @@ use this if you mean an authority instead""") # set creds creds = [self.slice_credential(args[0])] if options.delegate: - creds.append(self.delegate_cred(cred, get_authority(self.authority))) + creds.append(self.delegate_cred( + cred, get_authority(self.authority))) if options.show_credential: show_credentials(creds) @@ -1244,7 +1323,7 @@ use this if you mean an authority instead""") 'info': options.info, 'list_leases': options.list_leases, 'geni_rspec_version': {'type': 'geni', 'version': '3'}, - } + } if options.info: api_options['info'] = options.info @@ -1253,15 +1332,18 @@ use this if you mean an authority instead""") server_version = self.get_cached_server_version(server) if 'sfa' in server_version: # just request the version the client wants - api_options['geni_rspec_version'] = version_manager.get_version(options.rspec_version).to_dict() + api_options['geni_rspec_version'] = version_manager.get_version( + options.rspec_version).to_dict() else: - api_options['geni_rspec_version'] = {'type': 'geni', 'version': '3'} + api_options['geni_rspec_version'] = { + 'type': 'geni', 'version': '3'} urn = Xrn(args[0], type='slice').get_urn() - remove_none_fields(api_options) + remove_none_fields(api_options) describe = server.Describe([urn], creds, api_options) value = ReturnValue.get_value(describe) if self.options.raw: - save_raw_to_file(describe, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(describe, self.options.raw, + self.options.rawformat, self.options.rawbanner) if options.file is not None: save_rspec_to_file(value['geni_rspec'], options.file) if (self.options.raw is None) and (options.file is None): @@ -1280,7 +1362,7 @@ use this if you mean an authority instead""") server = self.sliceapi() # slice urn slice_hrn = args[0] - slice_urn = hrn_to_urn(slice_hrn, 'slice') + slice_urn = hrn_to_urn(slice_hrn, 'slice') if len(args) > 1: # we have sliver urns @@ -1292,16 +1374,18 @@ use this if you mean an authority instead""") # creds slice_cred = self.slice_credential(slice_hrn) creds = [slice_cred] - + # options and call_id when supported api_options = {} - api_options ['call_id'] = unique_call_id() + api_options['call_id'] = unique_call_id() if options.show_credential: show_credentials(creds) - delete = server.Delete(sliver_urns, creds, *self.ois(server, api_options ) ) + delete = server.Delete(sliver_urns, creds, * + self.ois(server, api_options)) value = ReturnValue.get_value(delete) if self.options.raw: - save_raw_to_file(delete, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(delete, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: print(value) return self.success(delete) @@ -1330,9 +1414,9 @@ use this if you mean an authority instead""") # delegate our cred to the slice manager # do not delegate cred to slicemgr...not working at the moment pass - #if server_version.get('hrn'): + # if server_version.get('hrn'): # delegated_cred = self.delegate_cred(slice_cred, server_version['hrn']) - #elif server_version.get('urn'): + # elif server_version.get('urn'): # delegated_cred = self.delegate_cred(slice_cred, urn_to_hrn(server_version['urn'])) if options.show_credential: @@ -1340,17 +1424,19 @@ use this if you mean an authority instead""") # rspec api_options = {} - api_options ['call_id'] = unique_call_id() + api_options['call_id'] = unique_call_id() # users sfa_users = [] geni_users = [] - slice_records = self.registry().Resolve(slice_urn, [self.my_credential_string]) + slice_records = self.registry().Resolve( + slice_urn, [self.my_credential_string]) remove_none_fields(slice_records[0]) if slice_records and 'reg-researchers' in slice_records[0] and slice_records[0]['reg-researchers'] != []: slice_record = slice_records[0] user_hrns = slice_record['reg-researchers'] user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] - user_records = self.registry().Resolve(user_urns, [self.my_credential_string]) + user_records = self.registry().Resolve( + user_urns, [self.my_credential_string]) sfa_users = sfa_users_arg(user_records, slice_record) geni_users = pg_users_arg(user_records) @@ -1359,10 +1445,12 @@ use this if you mean an authority instead""") with open(rspec_file) as rspec: rspec_xml = rspec.read() - allocate = server.Allocate(slice_urn, creds, rspec_xml, api_options) + allocate = server.Allocate( + slice_urn, creds, rspec_xml, api_options) value = ReturnValue.get_value(allocate) if self.options.raw: - save_raw_to_file(allocate, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(allocate, self.options.raw, + self.options.rawformat, self.options.rawbanner) if options.file is not None: save_rspec_to_file(value['geni_rspec'], options.file) if (self.options.raw is None) and (options.file is None): @@ -1396,16 +1484,16 @@ use this if you mean an authority instead""") # delegate our cred to the slice manager # do not delegate cred to slicemgr...not working at the moment pass - #if server_version.get('hrn'): + # if server_version.get('hrn'): # delegated_cred = self.delegate_cred(slice_cred, server_version['hrn']) - #elif server_version.get('urn'): + # elif server_version.get('urn'): # delegated_cred = self.delegate_cred(slice_cred, urn_to_hrn(server_version['urn'])) if options.show_credential: show_credentials(creds) api_options = {} - api_options ['call_id'] = unique_call_id() + api_options['call_id'] = unique_call_id() # set the requtested rspec version version_manager = VersionManager() @@ -1419,19 +1507,22 @@ use this if you mean an authority instead""") # keys: [, ] # }] users = [] - slice_records = self.registry().Resolve(slice_urn, [self.my_credential_string]) + slice_records = self.registry().Resolve( + slice_urn, [self.my_credential_string]) if slice_records and 'reg-researchers' in slice_records[0] and slice_records[0]['reg-researchers'] != []: slice_record = slice_records[0] user_hrns = slice_record['reg-researchers'] user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] - user_records = self.registry().Resolve(user_urns, [self.my_credential_string]) + user_records = self.registry().Resolve( + user_urns, [self.my_credential_string]) users = pg_users_arg(user_records) - + api_options['geni_users'] = users provision = server.Provision(sliver_urns, creds, api_options) value = ReturnValue.get_value(provision) if self.options.raw: - save_raw_to_file(provision, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(provision, self.options.raw, + self.options.rawformat, self.options.rawbanner) if options.file is not None: save_rspec_to_file(value['geni_rspec'], options.file) if (self.options.raw is None) and (options.file is None): @@ -1450,9 +1541,9 @@ use this if you mean an authority instead""") server = self.sliceapi() # slice urn slice_hrn = args[0] - slice_urn = hrn_to_urn(slice_hrn, 'slice') + slice_urn = hrn_to_urn(slice_hrn, 'slice') - # creds + # creds slice_cred = self.slice_credential(slice_hrn) creds = [slice_cred] @@ -1461,10 +1552,12 @@ use this if you mean an authority instead""") api_options['call_id'] = unique_call_id() if options.show_credential: show_credentials(creds) - status = server.Status([slice_urn], creds, *self.ois(server, api_options)) + status = server.Status([slice_urn], creds, * + self.ois(server, api_options)) value = ReturnValue.get_value(status) if self.options.raw: - save_raw_to_file(status, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(status, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: print(value) return self.success(status) @@ -1494,13 +1587,16 @@ use this if you mean an authority instead""") slice_cred = self.slice_credential(args[0]) creds = [slice_cred] if options.delegate: - delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)) + delegated_cred = self.delegate_cred( + slice_cred, get_authority(self.authority)) creds.append(delegated_cred) - - perform_action = server.PerformOperationalAction(sliver_urns, creds, action , api_options) + + perform_action = server.PerformOperationalAction( + sliver_urns, creds, action, api_options) value = ReturnValue.get_value(perform_action) if self.options.raw: - save_raw_to_file(perform_action, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(perform_action, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: print(value) return self.success(perform_action) @@ -1510,7 +1606,7 @@ use this if you mean an authority instead""") "sfi renew onelab.ple.heartbeat 2015-04-31T14:00:00Z", "sfi renew onelab.ple.heartbeat +5d", "sfi renew onelab.ple.heartbeat +3w", - "sfi renew onelab.ple.heartbeat +2m",])) + "sfi renew onelab.ple.heartbeat +2m", ])) def renew(self, options, args): """ renew slice(Renew) @@ -1542,10 +1638,12 @@ use this if you mean an authority instead""") api_options['geni_extend_alap'] = True if options.show_credential: show_credentials(creds) - renew = server.Renew(sliver_urns, creds, input_time, *self.ois(server, api_options)) + renew = server.Renew(sliver_urns, creds, input_time, + *self.ois(server, api_options)) value = ReturnValue.get_value(renew) if self.options.raw: - save_raw_to_file(renew, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(renew, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: print(value) return self.success(renew) @@ -1562,14 +1660,15 @@ use this if you mean an authority instead""") server = self.sliceapi() # slice urn slice_hrn = args[0] - slice_urn = hrn_to_urn(slice_hrn, 'slice') + slice_urn = hrn_to_urn(slice_hrn, 'slice') # creds slice_cred = self.slice_credential(slice_hrn) creds = [slice_cred] shutdown = server.Shutdown(slice_urn, creds) value = ReturnValue.get_value(shutdown) if self.options.raw: - save_raw_to_file(shutdown, self.options.raw, self.options.rawformat, self.options.rawbanner) + save_raw_to_file(shutdown, self.options.raw, + self.options.rawformat, self.options.rawbanner) else: print(value) return self.success(shutdown) @@ -1584,17 +1683,18 @@ use this if you mean an authority instead""") sys.exit(1) target_hrn = args[0] - my_gid_string = open(self.client_bootstrap.my_gid()).read() + my_gid_string = open(self.client_bootstrap.my_gid()).read() gid = self.registry().CreateGid(self.my_credential_string, target_hrn, my_gid_string) if options.file: filename = options.file else: - filename = os.sep.join([self.options.sfi_dir, '{}.gid'.format(target_hrn)]) + filename = os.sep.join( + [self.options.sfi_dir, '{}.gid'.format(target_hrn)]) self.logger.info("writing {} gid to {}".format(target_hrn, filename)) GID(string=gid).save_to_file(filename) # xxx should analyze result return 0 - + #################### @declare_command("to_hrn", """$ sfi delegate -u -p -s ple.inria.heartbeat -s ple.inria.omftest ple.upmc.slicebrowser @@ -1626,16 +1726,16 @@ use this if you mean an authority instead""") for slice_hrn in options.delegate_slices: message = "{}.slice".format(slice_hrn) original = self.slice_credential_string(slice_hrn) - tuples.append( (message, original,) ) + tuples.append((message, original,)) if options.delegate_pi: my_authority = self.authority message = "{}.pi".format(my_authority) original = self.my_authority_credential_string() - tuples.append( (message, original,) ) + tuples.append((message, original,)) for auth_hrn in options.delegate_auths: message = "{}.auth".format(auth_hrn) original = self.authority_credential_string(auth_hrn) - tuples.append( (message, original, ) ) + tuples.append((message, original, )) # if nothing was specified at all at this point, let's assume -u if not tuples: options.delegate_user = True @@ -1643,7 +1743,7 @@ use this if you mean an authority instead""") if options.delegate_user: message = "{}.user".format(self.user) original = self.my_credential_string - tuples.append( (message, original, ) ) + tuples.append((message, original, )) # default type for beneficial is user unless -A to_type = 'authority' if options.delegate_to_authority else 'user' @@ -1651,14 +1751,15 @@ use this if you mean an authority instead""") # let's now handle all this # it's all in the filenaming scheme for (message, original) in tuples: - delegated_string = self.client_bootstrap.delegate_credential_string(original, to_hrn, to_type) + delegated_string = self.client_bootstrap.delegate_credential_string( + original, to_hrn, to_type) delegated_credential = Credential(string=delegated_string) filename = os.path.join(self.options.sfi_dir, "{}_for_{}.{}.cred".format(message, to_hrn, to_type)) delegated_credential.save_to_file(filename, save_parents=True) self.logger.info("delegated credential for {} to {} and wrote to {}" .format(message, to_hrn, filename)) - + #################### @declare_command("", """$ less +/myslice sfi_config [myslice] @@ -1684,9 +1785,8 @@ $ sfi m -b http://mymanifold.foo.com:7080/ is synonym to sfi myslice as no other command starts with an 'm' and uses a custom backend for this one call """ -) # declare_command + ) # declare_command def myslice(self, options, args): - """ This helper is for refreshing your credentials at myslice; it will * compute all the slices that you currently have credentials on * refresh all your credentials (you as a user and pi, your slices) @@ -1700,14 +1800,14 @@ $ sfi m -b http://mymanifold.foo.com:7080/ self.print_help() sys.exit(1) # enable info by default - self.logger.setLevelFromOptVerbose(self.options.verbose+1) - ### the rough sketch goes like this + self.logger.setLevelFromOptVerbose(self.options.verbose + 1) + # the rough sketch goes like this # (0) produce a p12 file self.client_bootstrap.my_pkcs12() # (a) rain check for sufficient config in sfi_config myslice_dict = {} - myslice_keys = [ 'backend', 'delegate', 'platform', 'username'] + myslice_keys = ['backend', 'delegate', 'platform', 'username'] for key in myslice_keys: value = None # oct 2013 - I'm finding myself juggling with config files @@ -1733,63 +1833,74 @@ $ sfi m -b http://mymanifold.foo.com:7080/ sys.exit(1) my_record = my_records[0] my_auths_all = my_record['reg-pi-authorities'] - self.logger.info("Found {} authorities that we are PI for".format(len(my_auths_all))) + self.logger.info( + "Found {} authorities that we are PI for".format(len(my_auths_all))) self.logger.debug("They are {}".format(my_auths_all)) - + my_auths = my_auths_all if options.delegate_auths: - my_auths = list(set(my_auths_all).intersection(set(options.delegate_auths))) - self.logger.debug("Restricted to user-provided auths {}".format(my_auths)) + my_auths = list(set(my_auths_all).intersection( + set(options.delegate_auths))) + self.logger.debug( + "Restricted to user-provided auths {}".format(my_auths)) # (c) get the set of slices that we are in my_slices_all = my_record['reg-slices'] - self.logger.info("Found {} slices that we are member of".format(len(my_slices_all))) + self.logger.info( + "Found {} slices that we are member of".format(len(my_slices_all))) self.logger.debug("They are: {}".format(my_slices_all)) - + my_slices = my_slices_all # if user provided slices, deal only with these - if they are found if options.delegate_slices: - my_slices = list(set(my_slices_all).intersection(set(options.delegate_slices))) - self.logger.debug("Restricted to user-provided slices: {}".format(my_slices)) + my_slices = list(set(my_slices_all).intersection( + set(options.delegate_slices))) + self.logger.debug( + "Restricted to user-provided slices: {}".format(my_slices)) # (d) make sure we have *valid* credentials for all these hrn_credentials = [] - hrn_credentials.append( (self.user, 'user', self.my_credential_string,) ) + hrn_credentials.append((self.user, 'user', self.my_credential_string,)) for auth_hrn in my_auths: - hrn_credentials.append( (auth_hrn, 'auth', self.authority_credential_string(auth_hrn),) ) + hrn_credentials.append( + (auth_hrn, 'auth', self.authority_credential_string(auth_hrn),)) for slice_hrn in my_slices: try: - hrn_credentials.append( (slice_hrn, 'slice', self.slice_credential_string(slice_hrn),) ) + hrn_credentials.append( + (slice_hrn, 'slice', self.slice_credential_string(slice_hrn),)) except: print("WARNING: could not get slice credential for slice {}" .format(slice_hrn)) # (e) check for the delegated version of these - # xxx todo add an option -a/-A? like for 'sfi delegate' for when we ever + # xxx todo add an option -a/-A? like for 'sfi delegate' for when we ever # switch to myslice using an authority instead of a user delegatee_type = 'user' delegatee_hrn = myslice_dict['delegate'] hrn_delegated_credentials = [] for (hrn, htype, credential) in hrn_credentials: - delegated_credential = self.client_bootstrap.delegate_credential_string(credential, delegatee_hrn, delegatee_type) + delegated_credential = self.client_bootstrap.delegate_credential_string( + credential, delegatee_hrn, delegatee_type) # save these so user can monitor what she's uploaded - filename = os.path.join( self.options.sfi_dir, - "{}.{}_for_{}.{}.cred"\ - .format(hrn, htype, delegatee_hrn, delegatee_type)) + filename = os.path.join(self.options.sfi_dir, + "{}.{}_for_{}.{}.cred" + .format(hrn, htype, delegatee_hrn, delegatee_type)) with open(filename, 'w') as f: f.write(delegated_credential) self.logger.debug("(Over)wrote {}".format(filename)) - hrn_delegated_credentials.append((hrn, htype, delegated_credential, filename, )) + hrn_delegated_credentials.append( + (hrn, htype, delegated_credential, filename, )) # (f) and finally upload them to manifold server # xxx todo add an option so the password can be set on the command line # (but *NOT* in the config file) so other apps can leverage this - self.logger.info("Uploading on backend at {}".format(myslice_dict['backend'])) + self.logger.info("Uploading on backend at {}".format( + myslice_dict['backend'])) uploader = ManifoldUploader(logger=self.logger, - url=myslice_dict['backend'], - platform=myslice_dict['platform'], - username=myslice_dict['username'], - password=options.password) + url=myslice_dict['backend'], + platform=myslice_dict['platform'], + username=myslice_dict['username'], + password=options.password) uploader.prompt_all() (count_all, count_success) = (0, 0) for (hrn, htype, delegated_credential, filename) in hrn_delegated_credentials: @@ -1816,7 +1927,7 @@ $ sfi m -b http://mymanifold.foo.com:7080/ def trusted(self, options, args): """ return the trusted certs at this interface (get_trusted_certs) - """ + """ if options.registry_interface: server = self.registry() else: @@ -1850,7 +1961,7 @@ $ sfi m -b http://mymanifold.foo.com:7080/ # at first sight a list here means it's fine, # and a dict suggests an error (no support for introspection?) if isinstance(results, list): - results = [ name for name in results if 'system.' not in name ] + results = [name for name in results if 'system.' not in name] results.sort() print("== methods supported at {}".format(server.url)) if 'Discover' in results: diff --git a/sfa/client/sfi_commands.py b/sfa/client/sfi_commands.py index 80897cd2..a8724b1c 100755 --- a/sfa/client/sfi_commands.py +++ b/sfa/client/sfi_commands.py @@ -3,7 +3,9 @@ import sys from optparse import OptionParser + class Commands: + def __init__(self, usage, description, epilog=None): self.parser = OptionParser(usage=usage, description=description, epilog=epilog) @@ -17,7 +19,7 @@ class Commands: def add_nodefile_option(self): self.nodefile = True - self.parser.add_option("-n", "", dest="nodefile", + self.parser.add_option("-n", "", dest="nodefile", metavar="FILE", help="read node list from FILE"), @@ -25,11 +27,11 @@ class Commands: self.linkfile = True self.parser.add_option("-l", "", dest="linkfile", metavar="FILE", - help="read link list from FILE") + help="read link list from FILE") def add_show_attributes_option(self): - self.parser.add_option("-s", "--show-attributes", action="store_true", - dest="showatt", default=False, + self.parser.add_option("-s", "--show-attributes", action="store_true", + dest="showatt", default=False, help="show sliver attributes") def add_attribute_options(self): @@ -41,73 +43,73 @@ class Commands: help="Demux HTTP between slices using " + "localhost ports") self.parser.add_option("", "--cpu-pct", action="append", - metavar="", + metavar="", help="Reserved CPU percent (e.g., 25)") self.parser.add_option("", "--cpu-share", action="append", - metavar="", + metavar="", help="Number of CPU shares (e.g., 5)") - self.parser.add_option("", "--delegations", + self.parser.add_option("", "--delegations", metavar="", action="append", help="List of slices with delegation authority") - self.parser.add_option("", "--disk-max", + self.parser.add_option("", "--disk-max", metavar="", action="append", help="Disk quota (1k disk blocks)") - self.parser.add_option("", "--initscript", + self.parser.add_option("", "--initscript", metavar="", action="append", help="Slice initialization script (e.g., stork)") self.parser.add_option("", "--ip-addresses", action="append", - metavar="", + metavar="", help="Add an IP address to a sliver") - self.parser.add_option("", "--net-i2-max-kbyte", + self.parser.add_option("", "--net-i2-max-kbyte", metavar="", action="append", help="Maximum daily network Tx limit " + "to I2 hosts.") - self.parser.add_option("", "--net-i2-max-rate", + self.parser.add_option("", "--net-i2-max-rate", metavar="", action="append", help="Maximum bandwidth over I2 routes") - self.parser.add_option("", "--net-i2-min-rate", + self.parser.add_option("", "--net-i2-min-rate", metavar="", action="append", help="Minimum bandwidth over I2 routes") - self.parser.add_option("", "--net-i2-share", + self.parser.add_option("", "--net-i2-share", metavar="", action="append", help="Number of bandwidth shares over I2 routes") - self.parser.add_option("", "--net-i2-thresh-kbyte", + self.parser.add_option("", "--net-i2-thresh-kbyte", metavar="", action="append", help="Limit sent to I2 hosts before warning, " + "throttling") - self.parser.add_option("", "--net-max-kbyte", + self.parser.add_option("", "--net-max-kbyte", metavar="", action="append", help="Maximum daily network Tx limit " + "to non-I2 hosts.") - self.parser.add_option("", "--net-max-rate", + self.parser.add_option("", "--net-max-rate", metavar="", action="append", help="Maximum bandwidth over non-I2 routes") - self.parser.add_option("", "--net-min-rate", + self.parser.add_option("", "--net-min-rate", metavar="", action="append", help="Minimum bandwidth over non-I2 routes") - self.parser.add_option("", "--net-share", + self.parser.add_option("", "--net-share", metavar="", action="append", help="Number of bandwidth shares over non-I2 " + "routes") - self.parser.add_option("", "--net-thresh-kbyte", + self.parser.add_option("", "--net-thresh-kbyte", metavar="", action="append", help="Limit sent to non-I2 hosts before " + "warning, throttling") - self.parser.add_option("", "--vsys", + self.parser.add_option("", "--vsys", metavar="", action="append", help="Vsys script (e.g., fd_fusemount)") - self.parser.add_option("", "--vsys-vnet", + self.parser.add_option("", "--vsys-vnet", metavar="", action="append", help="Allocate a virtual private network") def get_attribute_dict(self): - attrlist = ['capabilities','codemux','cpu_pct','cpu_share', - 'delegations','disk_max','initscript','ip_addresses', - 'net_i2_max_kbyte','net_i2_max_rate','net_i2_min_rate', - 'net_i2_share','net_i2_thresh_kbyte', - 'net_max_kbyte','net_max_rate','net_min_rate', - 'net_share','net_thresh_kbyte', - 'vsys','vsys_vnet'] + attrlist = ['capabilities', 'codemux', 'cpu_pct', 'cpu_share', + 'delegations', 'disk_max', 'initscript', 'ip_addresses', + 'net_i2_max_kbyte', 'net_i2_max_rate', 'net_i2_min_rate', + 'net_i2_share', 'net_i2_thresh_kbyte', + 'net_max_kbyte', 'net_max_rate', 'net_min_rate', + 'net_share', 'net_thresh_kbyte', + 'vsys', 'vsys_vnet'] attrdict = {} for attr in attrlist: value = getattr(self.opts, attr, None) @@ -118,12 +120,12 @@ class Commands: def prep(self): (self.opts, self.args) = self.parser.parse_args() - #if self.opts.infile: + # if self.opts.infile: # sys.stdin = open(self.opts.infile, "r") #xml = sys.stdin.read() #self.rspec = RSpec(xml) - # - #if self.nodefile: + # + # if self.nodefile: # if self.opts.nodefile: # f = open(self.opts.nodefile, "r") # self.nodes = f.read().split() @@ -131,12 +133,5 @@ class Commands: # else: # self.nodes = self.args # - #if self.opts.outfile: + # if self.opts.outfile: # sys.stdout = open(self.opts.outfile, "w") - - - - - - - diff --git a/sfa/dummy/dummy_testbed_api.py b/sfa/dummy/dummy_testbed_api.py index f553e40e..697cda58 100644 --- a/sfa/dummy/dummy_testbed_api.py +++ b/sfa/dummy/dummy_testbed_api.py @@ -3,73 +3,81 @@ import time dummy_api_addr = ("localhost", 8080) -# Fake Testbed DB +# Fake Testbed DB nodes_list = [] -for i in range(1,11): - node = {'hostname': 'node'+str(i)+'.dummy-testbed.org', 'type': 'dummy-node', 'node_id': i} +for i in range(1, 11): + node = {'hostname': 'node' + + str(i) + '.dummy-testbed.org', 'type': 'dummy-node', 'node_id': i} nodes_list.append(node) slices_list = [] -for i in range(1,3): - slice = {'slice_name': 'slice'+str(i), - 'user_ids': range(i,4,2), - 'slice_id': i, - 'node_ids': range(i,10,2), +for i in range(1, 3): + slice = {'slice_name': 'slice' + str(i), + 'user_ids': range(i, 4, 2), + 'slice_id': i, + 'node_ids': range(i, 10, 2), 'enabled': True, - 'expires': int(time.time())+60*60*24*30} + 'expires': int(time.time()) + 60 * 60 * 24 * 30} slices_list.append(slice) users_list = [] -for i in range(1,5): - user = {'user_name': 'user'+str(i), 'user_id': i, 'email': 'user'+str(i)+'@dummy-testbed.org', 'keys': ['user_ssh_pub_key_'+str(i)]} +for i in range(1, 5): + user = {'user_name': 'user' + str(i), 'user_id': i, 'email': 'user' + str( + i) + '@dummy-testbed.org', 'keys': ['user_ssh_pub_key_' + str(i)]} users_list.append(user) -DB = {'nodes_list': nodes_list,'node_index': 11, 'slices_list': slices_list, 'slice_index': 3, 'users_list': users_list, 'user_index': 5} +DB = {'nodes_list': nodes_list, 'node_index': 11, 'slices_list': slices_list, + 'slice_index': 3, 'users_list': users_list, 'user_index': 5} + +# Filter function gor the GET methods -#Filter function gor the GET methods def FilterList(myfilter, mylist): result = [] result.extend(mylist) for item in mylist: - for key in myfilter.keys(): - if 'ids' in key: - pass - else: - if isinstance(myfilter[key], str) and myfilter[key] != item[key] or isinstance(myfilter[key], list) and item[key] not in myfilter[key]: - result.remove(item) - break + for key in myfilter.keys(): + if 'ids' in key: + pass + else: + if isinstance(myfilter[key], str) and myfilter[key] != item[key] or isinstance(myfilter[key], list) and item[key] not in myfilter[key]: + result.remove(item) + break return result # RPC functions definition -#GET +# GET def GetTestbedInfo(): - return {'name': 'dummy', 'longitude': 123456, 'latitude': 654321, 'domain':'dummy-testbed.org'} + return {'name': 'dummy', 'longitude': 123456, 'latitude': 654321, 'domain': 'dummy-testbed.org'} + def GetNodes(filter=None): - if filter is None: filter={} + if filter is None: + filter = {} global DB result = [] result.extend(DB['nodes_list']) if 'node_ids' in filter: for node in DB['nodes_list']: - if node['node_id'] not in filter['node_ids']: - result.remove(node) + if node['node_id'] not in filter['node_ids']: + result.remove(node) if filter: result = FilterList(filter, result) return result + def GetSlices(filter=None): - if filter is None: filter={} + if filter is None: + filter = {} global DB result = [] result.extend(DB['slices_list']) if 'slice_ids' in filter: for slice in DB['slices_list']: - if slice['slice_id'] not in filter['slice_ids']: - result.remove(slice) + if slice['slice_id'] not in filter['slice_ids']: + result.remove(slice) if filter: result = FilterList(filter, result) @@ -77,47 +85,48 @@ def GetSlices(filter=None): def GetUsers(filter=None): - if filter is None: filter={} + if filter is None: + filter = {} global DB result = [] result.extend(DB['users_list']) if 'user_ids' in filter: for user in DB['users_list']: - if user['user_id'] not in filter['user_ids']: - result.remove(user) + if user['user_id'] not in filter['user_ids']: + result.remove(user) if filter: result = FilterList(filter, result) return result -#def GetKeys(): - +# def GetKeys(): -#add +# add def AddNode(node): global DB if not isinstance(node, dict): return False for key in node.keys(): - if key not in ['hostname', 'type']: - return False + if key not in ['hostname', 'type']: + return False node['node_id'] = DB['node_index'] DB['node_index'] += 1 - DB['nodes_list'].append(node) + DB['nodes_list'].append(node) return node['node_id'] + def AddSlice(slice): global DB if not isinstance(slice, dict): return False for key in slice.keys(): - if key not in ['slice_name', 'user_ids', 'node_ids', 'enabled', 'expires']: - return False + if key not in ['slice_name', 'user_ids', 'node_ids', 'enabled', 'expires']: + return False slice['slice_id'] = DB['slice_index'] - slice['expires'] = int(time.time())+60*60*24*30 + slice['expires'] = int(time.time()) + 60 * 60 * 24 * 30 DB['slice_index'] += 1 DB['slices_list'].append(slice) return slice['slice_id'] @@ -128,8 +137,8 @@ def AddUser(user): if not isinstance(user, dict): return False for key in user.keys(): - if key not in ['user_name', 'email', 'keys']: - return False + if key not in ['user_name', 'email', 'keys']: + return False user['user_id'] = DB['user_index'] DB['user_index'] += 1 DB['users_list'].append(user) @@ -142,46 +151,50 @@ def AddUserKey(param): return False try: for user in DB['users_list']: - if param['user_id'] == user['user_id']: - if 'keys' in user.keys(): - user['keys'].append(param['key']) - else: - user['keys'] = [param['key']] - return True + if param['user_id'] == user['user_id']: + if 'keys' in user.keys(): + user['keys'].append(param['key']) + else: + user['keys'] = [param['key']] + return True return False except: return False + def AddUserToSlice(param): global DB if not isinstance(param, dict): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id']: - if not 'user_ids' in slice: slice['user_ids'] = [] - slice['user_ids'].append(param['user_id']) - return True + if param['slice_id'] == slice['slice_id']: + if not 'user_ids' in slice: + slice['user_ids'] = [] + slice['user_ids'].append(param['user_id']) + return True return False except: return False + def AddSliceToNodes(param): global DB if not isinstance(param, dict): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id']: - if not 'node_ids' in slice: slice['node_ids'] = [] - slice['node_ids'].extend(param['node_ids']) - return True + if param['slice_id'] == slice['slice_id']: + if not 'node_ids' in slice: + slice['node_ids'] = [] + slice['node_ids'].extend(param['node_ids']) + return True return False except: return False -#Delete +# Delete def DeleteNode(param): global DB @@ -189,14 +202,14 @@ def DeleteNode(param): return False try: for node in DB['nodes_list']: - if param['node_id'] == node['node_id']: - DB['nodes_list'].remove(node) - for slice in DB['slices_list']: - if param['node_id'] in slice['node_ids']: - slice['node_ids'].remove(param['node_id']) - return True + if param['node_id'] == node['node_id']: + DB['nodes_list'].remove(node) + for slice in DB['slices_list']: + if param['node_id'] in slice['node_ids']: + slice['node_ids'].remove(param['node_id']) + return True return False - except: + except: return False @@ -206,9 +219,9 @@ def DeleteSlice(param): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id']: - DB['slices_list'].remove(slice) - return True + if param['slice_id'] == slice['slice_id']: + DB['slices_list'].remove(slice) + return True return False except: return False @@ -220,16 +233,16 @@ def DeleteUser(param): return False try: for user in DB['users_list']: - if param['user_id'] == user['user_id']: - DB['users_list'].remove(user) - for slice in DB['slices_list']: - if param['user_id'] in slice['user_ids']: - slice['user_ids'].remove(param['user_id']) - return True + if param['user_id'] == user['user_id']: + DB['users_list'].remove(user) + for slice in DB['slices_list']: + if param['user_id'] in slice['user_ids']: + slice['user_ids'].remove(param['user_id']) + return True return False except: return False - + def DeleteKey(param): global DB @@ -237,26 +250,27 @@ def DeleteKey(param): return False try: for user in DB['users_list']: - if param['key'] in user['keys']: - user['keys'].remove(param['key']) - return True + if param['key'] in user['keys']: + user['keys'].remove(param['key']) + return True return False except: return False + def DeleteUserFromSlice(param): global DB if not isinstance(param, dict): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id'] and param['user_id'] in slice['user_ids']: - slice['user_ids'].remove(param['user_id']) - return True + if param['slice_id'] == slice['slice_id'] and param['user_id'] in slice['user_ids']: + slice['user_ids'].remove(param['user_id']) + return True return False except: return False - + def DeleteSliceFromNodes(param): global DB @@ -264,16 +278,17 @@ def DeleteSliceFromNodes(param): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id']: - for node_id in param['node_ids']: - if node_id in slice['node_ids']: slice['node_ids'].remove(node_id) - return True + if param['slice_id'] == slice['slice_id']: + for node_id in param['node_ids']: + if node_id in slice['node_ids']: + slice['node_ids'].remove(node_id) + return True return False except: return False -#Update +# Update def UpdateNode(param): global DB @@ -281,11 +296,11 @@ def UpdateNode(param): return False try: for node in DB['nodes_list']: - if param['node_id'] == node['node_id']: - for key in param['fields'].keys(): - if key in ['hostname', 'type']: - node[key] = param['fields'][key] - return True + if param['node_id'] == node['node_id']: + for key in param['fields'].keys(): + if key in ['hostname', 'type']: + node[key] = param['fields'][key] + return True return False except: return False @@ -297,11 +312,11 @@ def UpdateSlice(param): return False try: for slice in DB['slices_list']: - if param['slice_id'] == slice['slice_id']: - for key in param['fields'].keys(): - if key in ['slice_name']: - slice[key] = param['fields'][key] - return True + if param['slice_id'] == slice['slice_id']: + for key in param['fields'].keys(): + if key in ['slice_name']: + slice[key] = param['fields'][key] + return True return False except: return False @@ -313,19 +328,17 @@ def UpdateUser(param): return False try: for user in DB['users_list']: - if param['user_id'] == user['user_id']: - for key in param['fields'].keys(): - if key in ['user_name', 'email']: - user[key] = param['fields'][key] - return True + if param['user_id'] == user['user_id']: + for key in param['fields'].keys(): + if key in ['user_name', 'email']: + user[key] = param['fields'][key] + return True return False except: return False - - -# Instantiate the XMLRPC server +# Instantiate the XMLRPC server dummy_api_server = SimpleXMLRPCServer.SimpleXMLRPCServer(dummy_api_addr) # RPC functions registration @@ -355,6 +368,3 @@ dummy_api_server.register_introspection_functions() # Handle requests dummy_api_server.serve_forever() - - - diff --git a/sfa/dummy/dummy_testbed_api_client.py b/sfa/dummy/dummy_testbed_api_client.py index 57a04825..2c3d363f 100644 --- a/sfa/dummy/dummy_testbed_api_client.py +++ b/sfa/dummy/dummy_testbed_api_client.py @@ -8,7 +8,8 @@ dummy_url = "http://localhost:8080" dummy_api = xmlrpc_client.ServerProxy(dummy_url) # Add a user: -my_user_id = dummy_api.AddUser({'email': 'john.doe@test.net', 'user_name': 'john.doe', 'keys': ['copy here your ssh-rsa public key']}) +my_user_id = dummy_api.AddUser({'email': 'john.doe@test.net', 'user_name': 'john.doe', 'keys': [ + 'copy here your ssh-rsa public key']}) # Attach the user with the slice named : slice2 : dummy_api.AddUserToSlice({'slice_id': 2, 'user_id': my_user_id}) diff --git a/sfa/dummy/dummyaggregate.py b/sfa/dummy/dummyaggregate.py index c5b4d10c..dc8384fc 100644 --- a/sfa/dummy/dummyaggregate.py +++ b/sfa/dummy/dummyaggregate.py @@ -22,6 +22,7 @@ from sfa.dummy.dummyxrn import DummyXrn, hostname_to_urn, hrn_to_dummy_slicename from sfa.storage.model import SliverAllocation import time + class DummyAggregate: def __init__(self, driver): @@ -42,24 +43,27 @@ class DummyAggregate: if not slices: return (slice, slivers) slice = slices[0] - - # sort slivers by node id + + # sort slivers by node id slice_nodes = [] if 'node_ids' in slice.keys(): - slice_nodes = self.driver.shell.GetNodes({'node_ids': slice['node_ids']}) + slice_nodes = self.driver.shell.GetNodes( + {'node_ids': slice['node_ids']}) for node in slice_nodes: - slivers[node['node_id']] = node + slivers[node['node_id']] = node return (slice, slivers) def get_nodes(self, options=None): - if options is None: options={} + if options is None: + options = {} filter = {} nodes = self.driver.shell.GetNodes(filter) return nodes def get_slivers(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} slice_names = set() slice_ids = set() node_ids = [] @@ -88,7 +92,8 @@ class DummyAggregate: if not slices: return [] slice = slices[0] - slice['hrn'] = DummyXrn(auth=self.driver.hrn, slicename=slice['slice_name']).hrn + slice['hrn'] = DummyXrn(auth=self.driver.hrn, + slicename=slice['slice_name']).hrn # get sliver users users = [] @@ -111,13 +116,15 @@ class DummyAggregate: users_list.append(user) if node_ids: - node_ids = [node_id for node_id in node_ids if node_id in slice['node_ids']] + node_ids = [ + node_id for node_id in node_ids if node_id in slice['node_ids']] slice['node_ids'] = node_ids nodes_dict = self.get_slice_nodes(slice, options) slivers = [] for node in nodes_dict.values(): node.update(slice) - sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice['slice_id'], node['node_id']) + sliver_hrn = '%s.%s-%s' % (self.driver.hrn, + slice['slice_id'], node['node_id']) node['sliver_id'] = Xrn(sliver_hrn, type='sliver').urn node['urn'] = node['sliver_id'] node['services_user'] = users @@ -125,20 +132,25 @@ class DummyAggregate: return slivers def node_to_rspec_node(self, node, options=None): - if options is None: options={} + if options is None: + options = {} rspec_node = NodeElement() - site=self.driver.testbedInfo - rspec_node['component_id'] = hostname_to_urn(self.driver.hrn, site['name'], node['hostname']) + site = self.driver.testbedInfo + rspec_node['component_id'] = hostname_to_urn( + self.driver.hrn, site['name'], node['hostname']) rspec_node['component_name'] = node['hostname'] - rspec_node['component_manager_id'] = Xrn(self.driver.hrn, 'authority+cm').get_urn() - rspec_node['authority_id'] = hrn_to_urn(DummyXrn.site_hrn(self.driver.hrn, site['name']), 'authority+sa') - #distinguish between Shared and Reservable nodes + rspec_node['component_manager_id'] = Xrn( + self.driver.hrn, 'authority+cm').get_urn() + rspec_node['authority_id'] = hrn_to_urn(DummyXrn.site_hrn( + self.driver.hrn, site['name']), 'authority+sa') + # distinguish between Shared and Reservable nodes rspec_node['exclusive'] = 'false' rspec_node['hardware_types'] = [HardwareType({'name': 'dummy-pc'}), HardwareType({'name': 'pc'})] if site['longitude'] and site['latitude']: - location = Location({'longitude': site['longitude'], 'latitude': site['latitude'], 'country': 'unknown'}) + location = Location({'longitude': site['longitude'], 'latitude': site[ + 'latitude'], 'country': 'unknown'}) rspec_node['location'] = location return rspec_node @@ -147,27 +159,30 @@ class DummyAggregate: rspec_node['expires'] = datetime_to_string(utcparse(sliver['expires'])) # add sliver info rspec_sliver = Sliver({'sliver_id': sliver['urn'], - 'name': sliver['slice_name'], - 'type': 'dummy-vserver', - 'tags': []}) + 'name': sliver['slice_name'], + 'type': 'dummy-vserver', + 'tags': []}) rspec_node['sliver_id'] = rspec_sliver['sliver_id'] if sliver['urn'] in sliver_allocations: - rspec_node['client_id'] = sliver_allocations[sliver['urn']].client_id + rspec_node['client_id'] = sliver_allocations[ + sliver['urn']].client_id if sliver_allocations[sliver['urn']].component_id: - rspec_node['component_id'] = sliver_allocations[sliver['urn']].component_id + rspec_node['component_id'] = sliver_allocations[ + sliver['urn']].component_id rspec_node['slivers'] = [rspec_sliver] # slivers always provide the ssh service login = Login({'authentication': 'ssh-keys', 'hostname': sliver['hostname'], - 'port':'22', + 'port': '22', 'username': sliver['slice_name'], 'login': sliver['slice_name'] - }) + }) return rspec_node def get_slice_nodes(self, slice, options=None): - if options is None: options={} + if options is None: + options = {} nodes_dict = {} filter = {} if slice and slice.get('node_ids'): @@ -180,15 +195,16 @@ class DummyAggregate: nodes_dict[node['node_id']] = node return nodes_dict - def rspec_node_to_geni_sliver(self, rspec_node, sliver_allocations = None): - if sliver_allocations is None: sliver_allocations={} + def rspec_node_to_geni_sliver(self, rspec_node, sliver_allocations=None): + if sliver_allocations is None: + sliver_allocations = {} if rspec_node['sliver_id'] in sliver_allocations: # set sliver allocation and operational status sliver_allocation = sliver_allocations[rspec_node['sliver_id']] if sliver_allocation: allocation_status = sliver_allocation.allocation_state if allocation_status == 'geni_allocated': - op_status = 'geni_pending_allocation' + op_status = 'geni_pending_allocation' elif allocation_status == 'geni_provisioned': op_status = 'geni_ready' else: @@ -201,22 +217,24 @@ class DummyAggregate: # required fields geni_sliver = {'geni_sliver_urn': rspec_node['sliver_id'], 'geni_expires': rspec_node['expires'], - 'geni_allocation_status' : allocation_status, + 'geni_allocation_status': allocation_status, 'geni_operational_status': op_status, 'geni_error': '', } return geni_sliver - def list_resources(self, version = None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'ad') + rspec_version = version_manager._get_version( + version.type, version.version, 'ad') rspec = RSpec(version=rspec_version, user_options=options) # get nodes - nodes = self.get_nodes(options) + nodes = self.get_nodes(options) nodes_dict = {} for node in nodes: nodes_dict[node['node_id']] = node @@ -231,10 +249,12 @@ class DummyAggregate: return rspec.toxml() def describe(self, urns, version=None, options=None): - if options is None: options={} + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'manifest') + rspec_version = version_manager._get_version( + version.type, version.version, 'manifest') rspec = RSpec(version=rspec_version, user_options=options) # get slivers @@ -250,11 +270,13 @@ class DummyAggregate: geni_urn = urns[0] sliver_ids = [sliver['sliver_id'] for sliver in slivers] constraint = SliverAllocation.sliver_id.in_(sliver_ids) - sliver_allocations = self.driver.api.dbsession().query(SliverAllocation).filter(constraint) + sliver_allocations = self.driver.api.dbsession().query( + SliverAllocation).filter(constraint) sliver_allocation_dict = {} for sliver_allocation in sliver_allocations: geni_urn = sliver_allocation.slice_urn - sliver_allocation_dict[sliver_allocation.sliver_id] = sliver_allocation + sliver_allocation_dict[ + sliver_allocation.sliver_id] = sliver_allocation # add slivers nodes_dict = {} @@ -262,13 +284,14 @@ class DummyAggregate: nodes_dict[sliver['node_id']] = sliver rspec_nodes = [] for sliver in slivers: - rspec_node = self.sliver_to_rspec_node(sliver, sliver_allocation_dict) + rspec_node = self.sliver_to_rspec_node( + sliver, sliver_allocation_dict) rspec_nodes.append(rspec_node) - geni_sliver = self.rspec_node_to_geni_sliver(rspec_node, sliver_allocation_dict) + geni_sliver = self.rspec_node_to_geni_sliver( + rspec_node, sliver_allocation_dict) geni_slivers.append(geni_sliver) rspec.version.add_nodes(rspec_nodes) return {'geni_urn': geni_urn, 'geni_rspec': rspec.toxml(), 'geni_slivers': geni_slivers} - diff --git a/sfa/dummy/dummydriver.py b/sfa/dummy/dummydriver.py index a69662e1..01bc8d09 100644 --- a/sfa/dummy/dummydriver.py +++ b/sfa/dummy/dummydriver.py @@ -33,30 +33,34 @@ def list_to_dict(recs, key): convert a list of dictionaries into a dictionary keyed on the specified dictionary key """ - return dict ( [ (rec[key],rec) for rec in recs ] ) + return dict([(rec[key], rec) for rec in recs]) # -# DummyShell is just an xmlrpc serverproxy where methods can be sent as-is; -# +# DummyShell is just an xmlrpc serverproxy where methods can be sent as-is; +# + + class DummyDriver (Driver): - # the cache instance is a class member so it survives across incoming requests + # the cache instance is a class member so it survives across incoming + # requests cache = None - def __init__ (self, api): - Driver.__init__ (self, api) + def __init__(self, api): + Driver.__init__(self, api) config = api.config self.hrn = config.SFA_INTERFACE_HRN self.root_auth = config.SFA_REGISTRY_ROOT_AUTH - self.shell = DummyShell (config) + self.shell = DummyShell(config) self.testbedInfo = self.shell.GetTestbedInfo() - + def check_sliver_credentials(self, creds, urns): # build list of cred object hrns slice_cred_names = [] for cred in creds: slice_cred_hrn = Credential(cred=cred).get_gid_object().get_hrn() - slice_cred_names.append(DummyXrn(xrn=slice_cred_hrn).dummy_slicename()) + slice_cred_names.append( + DummyXrn(xrn=slice_cred_hrn).dummy_slicename()) # look up slice name of slivers listed in urns arg slice_ids = [] @@ -68,7 +72,7 @@ class DummyDriver (Driver): pass if not slice_ids: - raise Forbidden("sliver urn not provided") + raise Forbidden("sliver urn not provided") slices = self.shell.GetSlices({'slice_ids': slice_ids}) sliver_names = [slice['slice_name'] for slice in slices] @@ -80,37 +84,38 @@ class DummyDriver (Driver): raise Forbidden(msg) ######################################## - ########## registry oriented + # registry oriented ######################################## - def augment_records_with_testbed_info (self, sfa_records): - return self.fill_record_info (sfa_records) + def augment_records_with_testbed_info(self, sfa_records): + return self.fill_record_info(sfa_records) - ########## - def register (self, sfa_record, hrn, pub_key): + ########## + def register(self, sfa_record, hrn, pub_key): type = sfa_record['type'] dummy_record = self.sfa_fields_to_dummy_fields(type, hrn, sfa_record) - + if type == 'authority': pointer = -1 elif type == 'slice': - slices = self.shell.GetSlices({'slice_name': dummy_record['slice_name']}) + slices = self.shell.GetSlices( + {'slice_name': dummy_record['slice_name']}) if not slices: - pointer = self.shell.AddSlice(dummy_record) + pointer = self.shell.AddSlice(dummy_record) else: - pointer = slices[0]['slice_id'] + pointer = slices[0]['slice_id'] elif type == 'user': - users = self.shell.GetUsers({'email':sfa_record['email']}) + users = self.shell.GetUsers({'email': sfa_record['email']}) if not users: pointer = self.shell.AddUser(dummy_record) else: pointer = users[0]['user_id'] - + # Add the user's key if pub_key: - self.shell.AddUserKey({'user_id' : pointer, 'key' : pub_key}) + self.shell.AddUserKey({'user_id': pointer, 'key': pub_key}) elif type == 'node': nodes = self.shell.GetNodes(dummy_record['hostname']) @@ -118,40 +123,39 @@ class DummyDriver (Driver): pointer = self.shell.AddNode(dummy_record) else: pointer = users[0]['node_id'] - + return pointer - + ########## - def update (self, old_sfa_record, new_sfa_record, hrn, new_key): + def update(self, old_sfa_record, new_sfa_record, hrn, new_key): pointer = old_sfa_record['pointer'] type = old_sfa_record['type'] - dummy_record=self.sfa_fields_to_dummy_fields(type, hrn, new_sfa_record) + dummy_record = self.sfa_fields_to_dummy_fields( + type, hrn, new_sfa_record) # new_key implemented for users only - if new_key and type not in [ 'user' ]: + if new_key and type not in ['user']: raise UnknownSfaType(type) - if type == "slice": - self.shell.UpdateSlice({'slice_id': pointer, 'fields': dummy_record}) - + self.shell.UpdateSlice( + {'slice_id': pointer, 'fields': dummy_record}) + elif type == "user": self.shell.UpdateUser({'user_id': pointer, 'fields': dummy_record}) if new_key: - self.shell.AddUserKey({'user_id' : pointer, 'key' : new_key}) + self.shell.AddUserKey({'user_id': pointer, 'key': new_key}) elif type == "node": self.shell.UpdateNode({'node_id': pointer, 'fields': dummy_record}) - return True - ########## - def remove (self, sfa_record): - type=sfa_record['type'] - pointer=sfa_record['pointer'] + def remove(self, sfa_record): + type = sfa_record['type'] + pointer = sfa_record['pointer'] if type == 'user': self.shell.DeleteUser({'user_id': pointer}) elif type == 'slice': @@ -161,10 +165,6 @@ class DummyDriver (Driver): return True - - - - ## # Convert SFA fields to Dummy testbed fields for use when registering or updating # registry record in the dummy testbed @@ -173,19 +173,19 @@ class DummyDriver (Driver): def sfa_fields_to_dummy_fields(self, type, hrn, sfa_record): dummy_record = {} - + if type == "slice": dummy_record["slice_name"] = hrn_to_dummy_slicename(hrn) - + elif type == "node": if "hostname" not in sfa_record: raise MissingSfaInfo("hostname") dummy_record["hostname"] = sfa_record["hostname"] if "type" in sfa_record: - dummy_record["type"] = sfa_record["type"] + dummy_record["type"] = sfa_record["type"] else: - dummy_record["type"] = "dummy_type" - + dummy_record["type"] = "dummy_type" + elif type == "authority": dummy_record["name"] = hrn @@ -214,13 +214,13 @@ class DummyDriver (Driver): Fill in the DUMMY specific fields of a SFA record. This involves calling the appropriate DUMMY method to retrieve the database record for the object. - + @param record: record to fill in field (in/out param) """ # get ids by type - node_ids, slice_ids, user_ids = [], [], [] + node_ids, slice_ids, user_ids = [], [], [] type_map = {'node': node_ids, 'slice': slice_ids, 'user': user_ids} - + for record in records: for type in type_map: if type == record['type']: @@ -229,10 +229,10 @@ class DummyDriver (Driver): # get dummy records nodes, slices, users = {}, {}, {} if node_ids: - node_list = self.shell.GetNodes({'node_ids':node_ids}) + node_list = self.shell.GetNodes({'node_ids': node_ids}) nodes = list_to_dict(node_list, 'node_id') if slice_ids: - slice_list = self.shell.GetSlices({'slice_ids':slice_ids}) + slice_list = self.shell.GetSlices({'slice_ids': slice_ids}) slices = list_to_dict(slice_list, 'slice_id') if user_ids: user_list = self.shell.GetUsers({'user_ids': user_ids}) @@ -240,13 +240,12 @@ class DummyDriver (Driver): dummy_records = {'node': nodes, 'slice': slices, 'user': users} - # fill record info for record in records: # records with pointer==-1 do not have dummy info. if record['pointer'] == -1: continue - + for type in dummy_records: if record['type'] == type: if record['pointer'] in dummy_records[type]: @@ -257,8 +256,8 @@ class DummyDriver (Driver): record['key_ids'] = [] record['keys'] = [] for key in dummy_records['user'][record['pointer']]['keys']: - record['key_ids'].append(-1) - record['keys'].append(key) + record['key_ids'].append(-1) + record['keys'].append(key) return records @@ -284,11 +283,11 @@ class DummyDriver (Driver): users = list_to_dict(user_list, 'user_id') if slice_ids: slice_list = self.shell.GetSlices({'slice_ids': slice_ids}) - slices = list_to_dict(slice_list, 'slice_id') + slices = list_to_dict(slice_list, 'slice_id') if node_ids: node_list = self.shell.GetNodes({'node_ids': node_ids}) nodes = list_to_dict(node_list, 'node_id') - + # convert ids to hrns for record in records: # get all relevant data @@ -300,24 +299,26 @@ class DummyDriver (Driver): continue if 'user_ids' in record: - emails = [users[user_id]['email'] for user_id in record['user_ids'] \ - if user_id in users] + emails = [users[user_id]['email'] for user_id in record['user_ids'] + if user_id in users] usernames = [email.split('@')[0] for email in emails] - user_hrns = [".".join([auth_hrn, testbed_name, username]) for username in usernames] - record['users'] = user_hrns + user_hrns = [".".join([auth_hrn, testbed_name, username]) + for username in usernames] + record['users'] = user_hrns if 'slice_ids' in record: - slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] \ + slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] if slice_id in slices] - slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames] + slice_hrns = [slicename_to_hrn( + auth_hrn, slicename) for slicename in slicenames] record['slices'] = slice_hrns if 'node_ids' in record: - hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \ + hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] if node_id in nodes] - node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames] + node_hrns = [hostname_to_hrn( + auth_hrn, login_base, hostname) for hostname in hostnames] record['nodes'] = node_hrns - - return records + return records def fill_record_sfa_info(self, records): @@ -328,14 +329,15 @@ class DummyDriver (Driver): user_ids = [] for record in records: user_ids.extend(record.get("user_ids", [])) - - # get sfa records for all records associated with these records. + + # get sfa records for all records associated with these records. # we'll replace pl ids (person_ids) with hrns from the sfa records # we obtain - + # get the registry records user_list, users = [], {} - user_list = self.api.dbsession().query (RegRecord).filter(RegRecord.pointer.in_(user_ids)) + user_list = self.api.dbsession().query( + RegRecord).filter(RegRecord.pointer.in_(user_ids)) # create a hrns keyed on the sfa record's pointer. # Its possible for multiple records to have the same pointer so # the dict's value will be a list of hrns. @@ -351,11 +353,12 @@ class DummyDriver (Driver): # fill sfa info for record in records: # skip records with no pl info (top level authorities) - #if record['pointer'] == -1: - # continue + # if record['pointer'] == -1: + # continue sfa_info = {} type = record['type'] - logger.info("fill_record_sfa_info - incoming record typed %s"%type) + logger.info( + "fill_record_sfa_info - incoming record typed %s" % type) if (type == "slice"): # all slice users are researchers record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice') @@ -363,7 +366,7 @@ class DummyDriver (Driver): record['researcher'] = [] for user_id in record.get('user_ids', []): hrns = [user.hrn for user in users[user_id]] - record['researcher'].extend(hrns) + record['researcher'].extend(hrns) elif (type.startswith("authority")): record['url'] = None @@ -372,72 +375,81 @@ class DummyDriver (Driver): elif (type == "node"): sfa_info['dns'] = record.get("hostname", "") # xxx TODO: URI, LatLong, IP, DNS - + elif (type == "user"): logger.info('setting user.email') sfa_info['email'] = record.get("email", "") sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user') - sfa_info['geni_certificate'] = record['gid'] + sfa_info['geni_certificate'] = record['gid'] # xxx TODO: PostalAddress, Phone record.update(sfa_info) - #################### - def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids): + def update_relation(self, subject_type, target_type, relation_name, subject_id, target_ids): # hard-wire the code for slice/user for now, could be smarter if needed - if subject_type =='slice' and target_type == 'user' and relation_name == 'researcher': - subject=self.shell.GetSlices ({'slice_id': subject_id})[0] + if subject_type == 'slice' and target_type == 'user' and relation_name == 'researcher': + subject = self.shell.GetSlices({'slice_id': subject_id})[0] if 'user_ids' not in subject.keys(): - subject['user_ids'] = [] + subject['user_ids'] = [] current_target_ids = subject['user_ids'] - add_target_ids = list ( set (target_ids).difference(current_target_ids)) - del_target_ids = list ( set (current_target_ids).difference(target_ids)) - logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id))) + add_target_ids = list( + set(target_ids).difference(current_target_ids)) + del_target_ids = list( + set(current_target_ids).difference(target_ids)) + logger.debug("subject_id = %s (type=%s)" % + (subject_id, type(subject_id))) for target_id in add_target_ids: - self.shell.AddUserToSlice ({'user_id': target_id, 'slice_id': subject_id}) - logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id))) + self.shell.AddUserToSlice( + {'user_id': target_id, 'slice_id': subject_id}) + logger.debug("add_target_id = %s (type=%s)" % + (target_id, type(target_id))) for target_id in del_target_ids: - logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id))) - self.shell.DeleteUserFromSlice ({'user_id': target_id, 'slice_id': subject_id}) + logger.debug("del_target_id = %s (type=%s)" % + (target_id, type(target_id))) + self.shell.DeleteUserFromSlice( + {'user_id': target_id, 'slice_id': subject_id}) else: - logger.info('unexpected relation %s to maintain, %s -> %s'%(relation_name,subject_type,target_type)) + logger.info('unexpected relation %s to maintain, %s -> %s' % + (relation_name, subject_type, target_type)) - ######################################## - ########## aggregate oriented + # aggregate oriented ######################################## - def testbed_name (self): return "dummy" + def testbed_name(self): return "dummy" - def aggregate_version (self): + def aggregate_version(self): return {} - def list_resources (self, version=None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} aggregate = DummyAggregate(self) - rspec = aggregate.list_resources(version=version, options=options) + rspec = aggregate.list_resources(version=version, options=options) return rspec def describe(self, urns, version, options=None): - if options is None: options={} + if options is None: + options = {} aggregate = DummyAggregate(self) return aggregate.describe(urns, version=version, options=options) - - def status (self, urns, options=None): - if options is None: options={} + + def status(self, urns, options=None): + if options is None: + options = {} aggregate = DummyAggregate(self) - desc = aggregate.describe(urns, version='GENI 3') + desc = aggregate.describe(urns, version='GENI 3') status = {'geni_urn': desc['geni_urn'], 'geni_slivers': desc['geni_slivers']} return status - - def allocate (self, urn, rspec_string, expiration, options=None): - if options is None: options={} + def allocate(self, urn, rspec_string, expiration, options=None): + if options is None: + options = {} xrn = Xrn(urn) aggregate = DummyAggregate(self) slices = DummySlices(self) - slice_record=None + slice_record = None users = options.get('geni_users', []) if users: slice_record = users[0].get('slice_record', {}) @@ -447,7 +459,8 @@ class DummyDriver (Driver): requested_attributes = rspec.version.get_slice_attributes() # ensure slice record exists - slice = slices.verify_slice(xrn.hrn, slice_record, expiration=expiration, options=options) + slice = slices.verify_slice( + xrn.hrn, slice_record, expiration=expiration, options=options) # ensure person records exists #persons = slices.verify_persons(xrn.hrn, slice, users, peer, sfa_peer, options=options) @@ -458,7 +471,8 @@ class DummyDriver (Driver): return aggregate.describe([xrn.get_urn()], version=rspec.version) def provision(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} # update users slices = DummySlices(self) aggregate = DummyAggregate(self) @@ -468,14 +482,17 @@ class DummyDriver (Driver): #users = slices.verify_users(None, slice, geni_users, options=options) # update sliver allocation states and set them to geni_provisioned sliver_ids = [sliver['sliver_id'] for sliver in slivers] - dbsession=self.api.dbsession() - SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession) + dbsession = self.api.dbsession() + SliverAllocation.set_allocations( + sliver_ids, 'geni_provisioned', dbsession) version_manager = VersionManager() - rspec_version = version_manager.get_version(options['geni_rspec_version']) + rspec_version = version_manager.get_version( + options['geni_rspec_version']) return self.describe(urns, rspec_version, options=options) def delete(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} # collect sliver ids so we can update sliver allocation states after # we remove the slivers. aggregate = DummyAggregate(self) @@ -489,15 +506,17 @@ class DummyDriver (Driver): sliver_ids.append(sliver['sliver_id']) # determine if this is a peer slice - # xxx I wonder if this would not need to use PlSlices.get_peer instead + # xxx I wonder if this would not need to use PlSlices.get_peer instead # in which case plc.peers could be deprecated as this here # is the only/last call to this last method in plc.peers - slice_hrn = DummyXrn(auth=self.hrn, slicename=slivers[0]['slice_name']).get_hrn() + slice_hrn = DummyXrn(auth=self.hrn, slicename=slivers[ + 0]['slice_name']).get_hrn() try: - self.shell.DeleteSliceFromNodes({'slice_id': slice_id, 'node_ids': node_ids}) + self.shell.DeleteSliceFromNodes( + {'slice_id': slice_id, 'node_ids': node_ids}) # delete sliver allocation states - dbsession=self.api.dbsession() - SliverAllocation.delete_allocations(sliver_ids,dbsession) + dbsession = self.api.dbsession() + SliverAllocation.delete_allocations(sliver_ids, dbsession) finally: pass @@ -507,11 +526,12 @@ class DummyDriver (Driver): geni_slivers.append( {'geni_sliver_urn': sliver['sliver_id'], 'geni_allocation_status': 'geni_unallocated', - 'geni_expires': datetime_to_string(utcparse(sliver['expires']))}) + 'geni_expires': datetime_to_string(utcparse(sliver['expires']))}) return geni_slivers - def renew (self, urns, expiration_time, options=None): - if options is None: options={} + def renew(self, urns, expiration_time, options=None): + if options is None: + options = {} aggregate = DummyAggregate(self) slivers = aggregate.get_slivers(urns) if not slivers: @@ -519,23 +539,27 @@ class DummyDriver (Driver): slice = slivers[0] requested_time = utcparse(expiration_time) record = {'expires': int(datetime_to_epoch(requested_time))} - self.shell.UpdateSlice({'slice_id': slice['slice_id'], 'fileds': record}) + self.shell.UpdateSlice( + {'slice_id': slice['slice_id'], 'fileds': record}) description = self.describe(urns, 'GENI 3', options) return description['geni_slivers'] - def perform_operational_action (self, urns, action, options=None): - if options is None: options={} + def perform_operational_action(self, urns, action, options=None): + if options is None: + options = {} # Dummy doesn't support operational actions. Lets pretend like it # supports start, but reject everything else. action = action.lower() if action not in ['geni_start']: raise UnsupportedOperation(action) - # fault if sliver is not full allocated (operational status is geni_pending_allocation) + # fault if sliver is not full allocated (operational status is + # geni_pending_allocation) description = self.describe(urns, 'GENI 3', options) for sliver in description['geni_slivers']: if sliver['geni_operational_status'] == 'geni_pending_allocation': - raise UnsupportedOperation(action, "Sliver must be fully allocated (operational status is not geni_pending_allocation)") + raise UnsupportedOperation( + action, "Sliver must be fully allocated (operational status is not geni_pending_allocation)") # # Perform Operational Action Here # @@ -543,15 +567,17 @@ class DummyDriver (Driver): geni_slivers = self.describe(urns, 'GENI 3', options)['geni_slivers'] return geni_slivers - def shutdown (self, xrn, options=None): - if options is None: options={} + def shutdown(self, xrn, options=None): + if options is None: + options = {} xrn = DummyXrn(xrn=xrn, type='slice') slicename = xrn.pl_slicename() slices = self.shell.GetSlices({'name': slicename}, ['slice_id']) if not slices: raise RecordNotFound(slice_hrn) slice_id = slices[0]['slice_id'] - slice_tags = self.shell.GetSliceTags({'slice_id': slice_id, 'tagname': 'enabled'}) + slice_tags = self.shell.GetSliceTags( + {'slice_id': slice_id, 'tagname': 'enabled'}) if not slice_tags: self.shell.AddSliceTag(slice_id, 'enabled', '0') elif slice_tags[0]['value'] != "0": diff --git a/sfa/dummy/dummyshell.py b/sfa/dummy/dummyshell.py index 67007d7e..d647cb5f 100644 --- a/sfa/dummy/dummyshell.py +++ b/sfa/dummy/dummyshell.py @@ -5,30 +5,31 @@ from urlparse import urlparse from sfa.util.sfalogging import logger from sfa.util.py23 import xmlrpc_client + class DummyShell: """ A simple xmlrpc shell to the dummy testbed API instance """ - - direct_calls = ['AddNode', 'AddSlice', 'AddUser', 'AddUserKey', 'AddUserToSlice', 'AddSliceToNodes', + + direct_calls = ['AddNode', 'AddSlice', 'AddUser', 'AddUserKey', 'AddUserToSlice', 'AddSliceToNodes', 'GetTestbedInfo', 'GetNodes', 'GetSlices', 'GetUsers', - 'DeleteNode', 'DeleteSlice', 'DeleteUser', 'DeleteKey', 'DeleteUserFromSlice', + 'DeleteNode', 'DeleteSlice', 'DeleteUser', 'DeleteKey', 'DeleteUserFromSlice', 'DeleteSliceFromNodes', 'UpdateNode', 'UpdateSlice', 'UpdateUser', - ] + ] - - def __init__ ( self, config ) : + def __init__(self, config): url = config.SFA_DUMMY_URL - self.proxy = xmlrpc_client.ServerProxy(url, verbose = False, allow_none = True) + self.proxy = xmlrpc_client.ServerProxy( + url, verbose=False, allow_none=True) def __getattr__(self, name): def func(*args, **kwds): if not name in DummyShell.direct_calls: - raise Exception("Illegal method call %s for DUMMY driver"%(name)) - result=getattr(self.proxy, name)(*args, **kwds) - logger.debug('DummyShell %s returned ... '%(name)) + raise Exception( + "Illegal method call %s for DUMMY driver" % (name)) + result = getattr(self.proxy, name)(*args, **kwds) + logger.debug('DummyShell %s returned ... ' % (name)) return result return func - diff --git a/sfa/dummy/dummyslices.py b/sfa/dummy/dummyslices.py index 394e816f..6c0f8aec 100644 --- a/sfa/dummy/dummyslices.py +++ b/sfa/dummy/dummyslices.py @@ -10,7 +10,8 @@ from sfa.storage.model import SliverAllocation from sfa.dummy.dummyxrn import DummyXrn, hrn_to_dummy_slicename -MAXINT = 2L**31-1 +MAXINT = 2L**31 - 1 + class DummySlices: @@ -19,18 +20,18 @@ class DummySlices: def get_slivers(self, xrn, node=None): hrn, type = urn_to_hrn(xrn) - + slice_name = hrn_to_dummy_slicename(hrn) - + slices = self.driver.shell.GetSlices({'slice_name': slice_name}) slice = slices[0] # Build up list of users and slice attributes user_ids = slice['user_ids'] # Get user information - all_users_list = self.driver.shell.GetUsers({'user_id':user_ids}) + all_users_list = self.driver.shell.GetUsers({'user_id': user_ids}) all_users = {} for user in all_users_list: - all_users[user['user_id']] = user + all_users[user['user_id']] = user # Build up list of keys all_keys = set() @@ -42,8 +43,9 @@ class DummySlices: keys = all_keys # XXX Sanity check; though technically this should be a system invariant # checked with an assertion - if slice['expires'] > MAXINT: slice['expires']= MAXINT - + if slice['expires'] > MAXINT: + slice['expires'] = MAXINT + slivers.append({ 'hrn': hrn, 'name': slice['name'], @@ -53,7 +55,6 @@ class DummySlices: }) return slivers - def verify_slice_nodes(self, slice_urn, slice, rspec_nodes): @@ -67,7 +68,8 @@ class DummySlices: elif component_id: hostname = xrn_to_hostname(component_id) if hostname: - slivers[hostname] = {'client_id': client_id, 'component_id': component_id} + slivers[hostname] = { + 'client_id': client_id, 'component_id': component_id} all_nodes = self.driver.shell.GetNodes() requested_slivers = [] for node in all_nodes: @@ -75,63 +77,72 @@ class DummySlices: requested_slivers.append(node['node_id']) if 'node_ids' not in slice.keys(): - slice['node_ids']=[] + slice['node_ids'] = [] nodes = self.driver.shell.GetNodes({'node_ids': slice['node_ids']}) current_slivers = [node['node_id'] for node in nodes] # remove nodes not in rspec - deleted_nodes = list(set(current_slivers).difference(requested_slivers)) + deleted_nodes = list( + set(current_slivers).difference(requested_slivers)) # add nodes from rspec - added_nodes = list(set(requested_slivers).difference(current_slivers)) + added_nodes = list(set(requested_slivers).difference(current_slivers)) try: - self.driver.shell.AddSliceToNodes({'slice_id': slice['slice_id'], 'node_ids': added_nodes}) - self.driver.shell.DeleteSliceFromNodes({'slice_id': slice['slice_id'], 'node_ids': deleted_nodes}) + self.driver.shell.AddSliceToNodes( + {'slice_id': slice['slice_id'], 'node_ids': added_nodes}) + self.driver.shell.DeleteSliceFromNodes( + {'slice_id': slice['slice_id'], 'node_ids': deleted_nodes}) - except: + except: logger.log_exc('Failed to add/remove slice from nodes') - slices = self.driver.shell.GetSlices({'slice_name': slice['slice_name']}) - resulting_nodes = self.driver.shell.GetNodes({'node_ids': slices[0]['node_ids']}) + slices = self.driver.shell.GetSlices( + {'slice_name': slice['slice_name']}) + resulting_nodes = self.driver.shell.GetNodes( + {'node_ids': slices[0]['node_ids']}) # update sliver allocations for node in resulting_nodes: client_id = slivers[node['hostname']]['client_id'] component_id = slivers[node['hostname']]['component_id'] - sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice['slice_id'], node['node_id']) + sliver_hrn = '%s.%s-%s' % (self.driver.hrn, + slice['slice_id'], node['node_id']) sliver_id = Xrn(sliver_hrn, type='sliver').urn record = SliverAllocation(sliver_id=sliver_id, client_id=client_id, component_id=component_id, - slice_urn = slice_urn, + slice_urn=slice_urn, allocation_state='geni_allocated') record.sync(self.driver.api.dbsession()) return resulting_nodes - def verify_slice(self, slice_hrn, slice_record, expiration, options=None): - if options is None: options={} + if options is None: + options = {} slicename = hrn_to_dummy_slicename(slice_hrn) parts = slicename.split("_") login_base = parts[0] - slices = self.driver.shell.GetSlices({'slice_name': slicename}) + slices = self.driver.shell.GetSlices({'slice_name': slicename}) if not slices: slice = {'slice_name': slicename} - # add the slice + # add the slice slice['slice_id'] = self.driver.shell.AddSlice(slice) slice['node_ids'] = [] slice['user_ids'] = [] else: slice = slices[0] if slice_record and slice_record.get('expires'): - requested_expires = int(datetime_to_epoch(utcparse(slice_record['expires']))) + requested_expires = int(datetime_to_epoch( + utcparse(slice_record['expires']))) if requested_expires and slice['expires'] != requested_expires: - self.driver.shell.UpdateSlice( {'slice_id': slice['slice_id'], 'fields':{'expires' : expiration}}) - + self.driver.shell.UpdateSlice( + {'slice_id': slice['slice_id'], 'fields': {'expires': expiration}}) + return slice def verify_users(self, slice_hrn, slice_record, users, options=None): - if options is None: options={} + if options is None: + options = {} slice_name = hrn_to_dummy_slicename(slice_hrn) users_by_email = {} for user in users: @@ -141,38 +152,41 @@ class DummySlices: user['username'] = username if 'email' in user: - user['email'] = user['email'].lower() + user['email'] = user['email'].lower() users_by_email[user['email']] = user - + # start building a list of existing users existing_users_by_email = {} existing_slice_users_by_email = {} existing_users = self.driver.shell.GetUsers() - existing_slice_users_ids = self.driver.shell.GetSlices({'slice_name': slice_name})[0]['user_ids'] + existing_slice_users_ids = self.driver.shell.GetSlices( + {'slice_name': slice_name})[0]['user_ids'] for user in existing_users: - existing_users_by_email[user['email']] = user - if user['user_id'] in existing_slice_users_ids: + existing_users_by_email[user['email']] = user + if user['user_id'] in existing_slice_users_ids: existing_slice_users_by_email[user['email']] = user - - add_users_by_email = set(users_by_email).difference(existing_slice_user_by_email) - delete_users_by_email = set(existing_slice_user_by_email).difference(users_by_email) + + add_users_by_email = set(users_by_email).difference( + existing_slice_user_by_email) + delete_users_by_email = set( + existing_slice_user_by_email).difference(users_by_email) try: - for user in add_users_by_email: + for user in add_users_by_email: self.driver.shell.AddUser() - except: + except: pass - def verify_keys(self, old_users, new_users, options=None): - if options is None: options={} - # existing keys + if options is None: + options = {} + # existing keys existing_keys = [] for user in old_users: - existing_keys.append(user['keys']) + existing_keys.append(user['keys']) userdict = {} for user in old_users: - userdict[user['email']] = user - + userdict[user['email']] = user + # add new keys requested_keys = [] updated_users = [] @@ -184,17 +198,17 @@ class DummySlices: if key_string not in existing_keys: key = key_string try: - self.driver.shell.AddUserKey({'user_id': user['user_id'], 'key':key}) - + self.driver.shell.AddUserKey( + {'user_id': user['user_id'], 'key': key}) + except: - pass + pass # remove old keys (only if we are not appending) append = options.get('append', True) - if append == False: + if append == False: removed_keys = set(existing_keys).difference(requested_keys) for key in removed_keys: - try: - self.driver.shell.DeleteKey({'key': key}) - except: - pass - + try: + self.driver.shell.DeleteKey({'key': key}) + except: + pass diff --git a/sfa/dummy/dummyxrn.py b/sfa/dummy/dummyxrn.py index 05c36816..b934da21 100644 --- a/sfa/dummy/dummyxrn.py +++ b/sfa/dummy/dummyxrn.py @@ -3,58 +3,74 @@ import re from sfa.util.xrn import Xrn # temporary helper functions to use this module instead of namespace -def hostname_to_hrn (auth, testbed_name, hostname): - return DummyXrn(auth=auth+'.'+testbed_name,hostname=hostname).get_hrn() + + +def hostname_to_hrn(auth, testbed_name, hostname): + return DummyXrn(auth=auth + '.' + testbed_name, hostname=hostname).get_hrn() + + def hostname_to_urn(auth, testbed_name, hostname): - return DummyXrn(auth=auth+'.'+testbed_name,hostname=hostname).get_urn() -def slicename_to_hrn (auth_hrn, slicename): - return DummyXrn(auth=auth_hrn,slicename=slicename).get_hrn() -def email_to_hrn (auth_hrn, email): + return DummyXrn(auth=auth + '.' + testbed_name, hostname=hostname).get_urn() + + +def slicename_to_hrn(auth_hrn, slicename): + return DummyXrn(auth=auth_hrn, slicename=slicename).get_hrn() + + +def email_to_hrn(auth_hrn, email): return DummyXrn(auth=auth_hrn, email=email).get_hrn() -def hrn_to_dummy_slicename (hrn): - return DummyXrn(xrn=hrn,type='slice').dummy_slicename() -def hrn_to_dummy_authname (hrn): - return DummyXrn(xrn=hrn,type='any').dummy_authname() + + +def hrn_to_dummy_slicename(hrn): + return DummyXrn(xrn=hrn, type='slice').dummy_slicename() + + +def hrn_to_dummy_authname(hrn): + return DummyXrn(xrn=hrn, type='any').dummy_authname() + + def xrn_to_hostname(hrn): return Xrn.unescape(PlXrn(xrn=hrn, type='node').get_leaf()) + class DummyXrn (Xrn): - @staticmethod - def site_hrn (auth, testbed_name): - return '.'.join([auth,testbed_name]) + @staticmethod + def site_hrn(auth, testbed_name): + return '.'.join([auth, testbed_name]) - def __init__ (self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs): - #def hostname_to_hrn(auth_hrn, login_base, hostname): + def __init__(self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs): + # def hostname_to_hrn(auth_hrn, login_base, hostname): if hostname is not None: - self.type='node' + self.type = 'node' # keep only the first part of the DNS name #self.hrn='.'.join( [auth,hostname.split(".")[0] ] ) # escape the '.' in the hostname - self.hrn='.'.join( [auth,Xrn.escape(hostname)] ) + self.hrn = '.'.join([auth, Xrn.escape(hostname)]) self.hrn_to_urn() - #def slicename_to_hrn(auth_hrn, slicename): + # def slicename_to_hrn(auth_hrn, slicename): elif slicename is not None: - self.type='slice' + self.type = 'slice' # split at the first _ - parts = slicename.split("_",1) - self.hrn = ".".join([auth] + parts ) + parts = slicename.split("_", 1) + self.hrn = ".".join([auth] + parts) self.hrn_to_urn() - #def email_to_hrn(auth_hrn, email): + # def email_to_hrn(auth_hrn, email): elif email is not None: - self.type='person' + self.type = 'person' # keep only the part before '@' and replace special chars into _ - self.hrn='.'.join([auth,email.split('@')[0].replace(".", "_").replace("+", "_")]) + self.hrn = '.'.join( + [auth, email.split('@')[0].replace(".", "_").replace("+", "_")]) self.hrn_to_urn() elif interface is not None: self.type = 'interface' self.hrn = auth + '.' + interface self.hrn_to_urn() else: - Xrn.__init__ (self,**kwargs) + Xrn.__init__(self, **kwargs) - #def hrn_to_pl_slicename(hrn): - def dummy_slicename (self): + # def hrn_to_pl_slicename(hrn): + def dummy_slicename(self): self._normalize() leaf = self.leaf sliver_id_parts = leaf.split(':') @@ -62,8 +78,8 @@ class DummyXrn (Xrn): name = re.sub('[^a-zA-Z0-9_]', '', name) return name - #def hrn_to_pl_authname(hrn): - def dummy_authname (self): + # def hrn_to_pl_authname(hrn): + def dummy_authname(self): self._normalize() return self.authority[-1] @@ -71,18 +87,18 @@ class DummyXrn (Xrn): self._normalize() return self.leaf - def dummy_login_base (self): + def dummy_login_base(self): self._normalize() if self.type and self.type.startswith('authority'): - base = self.leaf + base = self.leaf else: base = self.authority[-1] - + # Fix up names of GENI Federates base = base.lower() base = re.sub('\\\[^a-zA-Z0-9]', '', base) if len(base) > 20: - base = base[len(base)-20:] - + base = base[len(base) - 20:] + return base diff --git a/sfa/examples/miniclient.py b/sfa/examples/miniclient.py index 535b4465..2366c782 100755 --- a/sfa/examples/miniclient.py +++ b/sfa/examples/miniclient.py @@ -8,95 +8,105 @@ from __future__ import print_function # init logging on console import logging console = logging.StreamHandler() -logger=logging.getLogger('') +logger = logging.getLogger('') logger.addHandler(console) logger.setLevel(logging.DEBUG) import uuid + + def unique_call_id(): return uuid.uuid4().urn # use sys.argv to point to a completely fresh directory import sys -args=sys.argv[1:] -if len(args)!=1: - print("Usage: %s directory"%sys.argv[0]) +args = sys.argv[1:] +if len(args) != 1: + print("Usage: %s directory" % sys.argv[0]) sys.exit(1) -dir=args[0] -logger.debug('sfaclientsample: Using directory %s'%dir) +dir = args[0] +logger.debug('sfaclientsample: Using directory %s' % dir) ### # this uses a test sfa deployment at openlab -registry_url="http://sfa1.pl.sophia.inria.fr:12345/" -aggregate_url="http://sfa1.pl.sophia.inria.fr:12347/" +registry_url = "http://sfa1.pl.sophia.inria.fr:12345/" +aggregate_url = "http://sfa1.pl.sophia.inria.fr:12347/" # this is where the private key sits - would be ~/.ssh/id_rsa in most cases # but in this context, create this local file # the tests key pair can be found in # http://git.onelab.eu/?p=tests.git;a=blob;f=system/config_default.py # search for public_key / private_key -private_key="miniclient-private-key" +private_key = "miniclient-private-key" # user hrn -user_hrn="pla.inri.fake-pi1" +user_hrn = "pla.inri.fake-pi1" -slice_hrn="pla.inri.slpl1" +slice_hrn = "pla.inri.slpl1" # hrn_to_urn(slice_hrn,'slice') -slice_urn='urn:publicid:IDN+pla:inri+slice+slpl1' +slice_urn = 'urn:publicid:IDN+pla:inri+slice+slpl1' from sfa.client.sfaclientlib import SfaClientBootstrap -bootstrap = SfaClientBootstrap (user_hrn, registry_url, dir=dir, logger=logger) +bootstrap = SfaClientBootstrap(user_hrn, registry_url, dir=dir, logger=logger) # install the private key in the client directory from 'private_key' bootstrap.init_private_key_if_missing(private_key) + def truncate(content, length=20, suffix='...'): - if isinstance (content, (int) ): return content - if isinstance (content, list): return truncate ( "%s"%content, length, suffix) + if isinstance(content, (int)): + return content + if isinstance(content, list): + return truncate("%s" % content, length, suffix) if len(content) <= length: return content else: - return content[:length+1]+ ' '+suffix + return content[:length + 1] + ' ' + suffix -### issue a GetVersion call -### this assumes we've already somehow initialized the certificate -def get_version (url): +# issue a GetVersion call +# this assumes we've already somehow initialized the certificate +def get_version(url): # make sure we have a self-signed cert bootstrap.self_signed_cert() server_proxy = bootstrap.server_proxy_simple(url) server_version = server_proxy.GetVersion() - print("miniclient: GetVersion at %s returned:"%(url)) - for (k,v) in server_version.iteritems(): print("miniclient: \tversion[%s]=%s"%(k,truncate(v))) + print("miniclient: GetVersion at %s returned:" % (url)) + for (k, v) in server_version.iteritems(): + print("miniclient: \tversion[%s]=%s" % (k, truncate(v))) # version_dict = {'type': 'SFA', 'version': '1', } -version_dict = {'type':'ProtoGENI', 'version':'2'} +version_dict = {'type': 'ProtoGENI', 'version': '2'} # ditto with list resources -def list_resources (): +def list_resources(): bootstrap.bootstrap_my_gid() credential = bootstrap.my_credential_string() - credentials = [ credential ] + credentials = [credential] options = {} - options [ 'geni_rspec_version' ] = version_dict - options [ 'call_id' ] = unique_call_id() - list_resources = bootstrap.server_proxy (aggregate_url).ListResources(credentials,options) - print("miniclient: ListResources at %s returned : %s"%(aggregate_url,truncate(list_resources))) + options['geni_rspec_version'] = version_dict + options['call_id'] = unique_call_id() + list_resources = bootstrap.server_proxy( + aggregate_url).ListResources(credentials, options) + print("miniclient: ListResources at %s returned : %s" % + (aggregate_url, truncate(list_resources))) -def list_slice_resources (): + +def list_slice_resources(): bootstrap.bootstrap_my_gid() - credential = bootstrap.slice_credential_string (slice_hrn) - credentials = [ credential ] - options = { } - options [ 'geni_rspec_version' ] = version_dict - options [ 'geni_slice_urn' ] = slice_urn - options [ 'call_id' ] = unique_call_id() - list_resources = bootstrap.server_proxy (aggregate_url).ListResources(credentials,options) - print("miniclient: ListResources at %s for slice %s returned : %s"%(aggregate_url,slice_urn,truncate(list_resources))) - - - -def main (): + credential = bootstrap.slice_credential_string(slice_hrn) + credentials = [credential] + options = {} + options['geni_rspec_version'] = version_dict + options['geni_slice_urn'] = slice_urn + options['call_id'] = unique_call_id() + list_resources = bootstrap.server_proxy( + aggregate_url).ListResources(credentials, options) + print("miniclient: ListResources at %s for slice %s returned : %s" % + (aggregate_url, slice_urn, truncate(list_resources))) + + +def main(): get_version(registry_url) get_version(aggregate_url) list_resources() diff --git a/sfa/federica/fddriver.py b/sfa/federica/fddriver.py index cec702d6..0307506d 100644 --- a/sfa/federica/fddriver.py +++ b/sfa/federica/fddriver.py @@ -8,110 +8,119 @@ from sfa.federica.fdshell import FdShell # hardwired for now # this could/should be obtained by issuing getRSpecVersion -federica_version_string="RSpecV2" - -#### avail. methods on the federica side as of 2012/02/13 -# listAvailableResources(String credentials, String rspecVersion) -# listSliceResources(String credentials, String rspecVersion, String sliceUrn) -# createSlice(String credentials, String sliceUrn, String rspecVersion, String rspecString) -# deleteSlice(String credentials, String sliceUrn) -# listSlices() +federica_version_string = "RSpecV2" + +# avail. methods on the federica side as of 2012/02/13 +# listAvailableResources(String credentials, String rspecVersion) +# listSliceResources(String credentials, String rspecVersion, String sliceUrn) +# createSlice(String credentials, String sliceUrn, String rspecVersion, String rspecString) +# deleteSlice(String credentials, String sliceUrn) +# listSlices() # getRSpecVersion() -##### all return +# all return # Result: {'code': 0, 'value': RSpec} if success # {'code': code_id, 'output': Error message} if error + class FdDriver (PlDriver): - def __init__ (self,api): - PlDriver.__init__ (self, api) + def __init__(self, api): + PlDriver.__init__(self, api) config = api.config - self.shell=FdShell(config) + self.shell = FdShell(config) # the agreement with the federica driver is for them to expose results in a way # compliant with the avpi v2 return code, i.e. a dict with 'code' 'value' 'output' # essentially, either 'code'==0, then 'value' is set to the actual result - # otherwise, 'code' is set to an error code and 'output' holds an error message - def response (self, from_xmlrpc): - if isinstance (from_xmlrpc, dict) and 'code' in from_xmlrpc: - if from_xmlrpc['code']==0: + # otherwise, 'code' is set to an error code and 'output' holds an error + # message + def response(self, from_xmlrpc): + if isinstance(from_xmlrpc, dict) and 'code' in from_xmlrpc: + if from_xmlrpc['code'] == 0: return from_xmlrpc['value'] else: - raise SfaFault(from_xmlrpc['code'],from_xmlrpc['output']) + raise SfaFault(from_xmlrpc['code'], from_xmlrpc['output']) else: logger.warning("unexpected result from federica xmlrpc api") return from_xmlrpc - def aggregate_version (self): - result={} - federica_version_string_api = self.response(self.shell.getRSpecVersion()) - result ['federica_version_string_api']=federica_version_string_api + def aggregate_version(self): + result = {} + federica_version_string_api = self.response( + self.shell.getRSpecVersion()) + result['federica_version_string_api'] = federica_version_string_api if federica_version_string_api != federica_version_string: - result['WARNING']="hard-wired rspec version %d differs from what the API currently exposes"%\ - federica_version_string + result['WARNING'] = "hard-wired rspec version %d differs from what the API currently exposes" %\ + federica_version_string return result - def testbed_name (self): + def testbed_name(self): return "federica" - def list_slices (self, creds, options): + def list_slices(self, creds, options): # the issue is that federica returns the list of slice's urn in a string format - # this is why this dirty hack is needed until federica fixes it. + # this is why this dirty hack is needed until federica fixes it. slices_str = self.shell.listSlices()['value'][1:-1] slices_list = slices_str.split(", ") return slices_list - def sliver_status (self, slice_urn, slice_hrn): - return "fddriver.sliver_status: undefined/todo for slice %s"%slice_hrn + def sliver_status(self, slice_urn, slice_hrn): + return "fddriver.sliver_status: undefined/todo for slice %s" % slice_hrn - def list_resources (self, slice_urn, slice_hrn, creds, options): + def list_resources(self, slice_urn, slice_hrn, creds, options): # right now rspec_version is ignored on the federica side # we normally derive it from options # look in cache if client has requested so - cached_requested = options.get('cached', True) + cached_requested = options.get('cached', True) # global advertisement if not slice_hrn: - # self.cache is initialized unless the global config has it turned off + # self.cache is initialized unless the global config has it turned + # off if cached_requested and self.cache: # using federica_version_string as the key into the cache rspec = self.cache.get(federica_version_string) if rspec: - logger.debug("FdDriver.ListResources: returning cached advertisement") + logger.debug( + "FdDriver.ListResources: returning cached advertisement") return self.response(rspec) # otherwise, need to get it # java code expects creds as a String # rspec = self.shell.listAvailableResources (creds, federica_version_string) - rspec = self.shell.listAvailableResources ("", federica_version_string) + rspec = self.shell.listAvailableResources( + "", federica_version_string) # rspec = self.shell.listAvailableResources (federica_version_string) # cache it for future use if self.cache: - logger.debug("FdDriver.ListResources: stores advertisement in cache") + logger.debug( + "FdDriver.ListResources: stores advertisement in cache") self.cache.add(federica_version_string, rspec) return self.response(rspec) # about a given slice : don't cache else: # java code expects creds as a String -# return self.response(self.shell.listSliceResources(creds, federica_version_string, slice_urn)) + # return self.response(self.shell.listSliceResources(creds, + # federica_version_string, slice_urn)) return self.response(self.shell.listSliceResources("", federica_version_string, slice_urn)) - def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options): + def create_sliver(self, slice_urn, slice_hrn, creds, rspec_string, users, options): # right now version_string is ignored on the federica side # we normally derive it from options # java code expects creds as a String -# return self.response(self.shell.createSlice(creds, slice_urn, federica_version_string, rspec_string)) + # return self.response(self.shell.createSlice(creds, slice_urn, + # federica_version_string, rspec_string)) return self.response(self.shell.createSlice("", slice_urn, federica_version_string, rspec_string)) - def delete_sliver (self, slice_urn, slice_hrn, creds, options): + def delete_sliver(self, slice_urn, slice_hrn, creds, options): # right now version_string is ignored on the federica side # we normally derive it from options # xxx not sure if that's currentl supported at all # java code expects creds as a String -# return self.response(self.shell.deleteSlice(creds, slice_urn)) + # return self.response(self.shell.deleteSlice(creds, slice_urn)) return self.response(self.shell.deleteSlice("", slice_urn)) # for the the following methods we use what is provided by the default driver class - #def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options): - #def start_slice (self, slice_urn, slice_xrn, creds): - #def stop_slice (self, slice_urn, slice_xrn, creds): - #def reset_slice (self, slice_urn, slice_xrn, creds): - #def get_ticket (self, slice_urn, slice_xrn, creds, rspec, options): + # def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options): + # def start_slice (self, slice_urn, slice_xrn, creds): + # def stop_slice (self, slice_urn, slice_xrn, creds): + # def reset_slice (self, slice_urn, slice_xrn, creds): + # def get_ticket (self, slice_urn, slice_xrn, creds, rspec, options): diff --git a/sfa/federica/fdshell.py b/sfa/federica/fdshell.py index 1e7349c5..179c8516 100644 --- a/sfa/federica/fdshell.py +++ b/sfa/federica/fdshell.py @@ -1,27 +1,28 @@ from sfa.util.sfalogging import logger from sfa.util.py23 import xmlrpc_client + class FdShell: """ A simple xmlrpc shell to a federica API server This class can receive the XMLRPC calls to the federica testbed For safety this is limited to a set of hard-coded calls """ - - direct_calls = [ 'listAvailableResources', - 'listSliceResources', - 'createSlice', - 'deleteSlice', - 'getRSpecVersion', - 'listSlices', + + direct_calls = ['listAvailableResources', + 'listSliceResources', + 'createSlice', + 'deleteSlice', + 'getRSpecVersion', + 'listSlices', ] - def __init__ ( self, config ) : - url=config.SFA_FEDERICA_URL + def __init__(self, config): + url = config.SFA_FEDERICA_URL # xxx not sure if java xmlrpc has support for None # self.proxy = xmlrpc_client.ServerProxy(url, verbose = False, allow_none = True) # xxx turn on verbosity - self.proxy = xmlrpc_client.ServerProxy(url, verbose = True) + self.proxy = xmlrpc_client.ServerProxy(url, verbose=True) # xxx get credentials from the config ? # right now basic auth data goes into the URL @@ -29,12 +30,13 @@ class FdShell: def __getattr__(self, name): def func(*args, **kwds): if name not in FdShell.direct_calls: - raise Exception("Illegal method call %s for FEDERICA driver"%(name)) - logger.info("Issuing %s args=%s kwds=%s to federica"%\ - (name,args,kwds)) + raise Exception( + "Illegal method call %s for FEDERICA driver" % (name)) + logger.info("Issuing %s args=%s kwds=%s to federica" % + (name, args, kwds)) # result=getattr(self.proxy, "AggregateManager.%s"%name)(credential, *args, **kwds) - result=getattr(self.proxy, "AggregateManager.%s"%name)(*args, **kwds) - logger.debug('FdShell %s (%s) returned ... '%(name,name)) + result = getattr(self.proxy, "AggregateManager.%s" % + name)(*args, **kwds) + logger.debug('FdShell %s (%s) returned ... ' % (name, name)) return result return func - diff --git a/sfa/generic/__init__.py b/sfa/generic/__init__.py index ece7e2b8..a7a0741a 100644 --- a/sfa/generic/__init__.py +++ b/sfa/generic/__init__.py @@ -3,113 +3,122 @@ from sfa.util.config import Config from sfa.managers.managerwrapper import ManagerWrapper -# a bundle is the combination of +# a bundle is the combination of # (*) an api that reacts on the incoming requests to trigger the API methods -# (*) a manager that implements the function of the service, +# (*) a manager that implements the function of the service, # either aggregate, registry, or slicemgr # (*) a driver that controls the underlying testbed -# -# -# The Generic class is a utility that uses the configuration to figure out -# which combination of these pieces need to be put together +# +# +# The Generic class is a utility that uses the configuration to figure out +# which combination of these pieces need to be put together # from config. # this extra indirection is needed to adapt to the current naming scheme -# where we have 'pl' and 'plc' and components and the like, that does not +# where we have 'pl' and 'plc' and components and the like, that does not # yet follow a sensible scheme # needs refinements to cache more efficiently, esp. wrt the config + class Generic: - def __init__ (self, flavour, config): - self.flavour=flavour - self.config=config + def __init__(self, flavour, config): + self.flavour = flavour + self.config = config # proof of concept # example flavour='pl' -> sfa.generic.pl.pl() @staticmethod - def the_flavour (flavour=None, config=None): - if config is None: config=Config() - if flavour is None: flavour=config.SFA_GENERIC_FLAVOUR + def the_flavour(flavour=None, config=None): + if config is None: + config = Config() + if flavour is None: + flavour = config.SFA_GENERIC_FLAVOUR flavour = flavour.lower() #mixed = flavour.capitalize() - module_path="sfa.generic.%s"%flavour - classname="%s"%flavour - logger.debug("Generic.the_flavour with flavour=%s"%flavour) + module_path = "sfa.generic.%s" % flavour + classname = "%s" % flavour + logger.debug("Generic.the_flavour with flavour=%s" % flavour) try: - module = __import__ (module_path, globals(), locals(), [classname]) - return getattr(module, classname)(flavour,config) + module = __import__(module_path, globals(), locals(), [classname]) + return getattr(module, classname)(flavour, config) except: - logger.log_exc("Cannot locate generic instance with flavour=%s"%flavour) + logger.log_exc( + "Cannot locate generic instance with flavour=%s" % flavour) # provide default for importer_class - def importer_class (self): + def importer_class(self): return None # in the simplest case these can be redefined to the class/module objects to be used # see pl.py for an example # some descendant of SfaApi - def api_class (self) : pass + def api_class(self): pass # the python classes to use to build up the context - def registry_class (self) : pass - def slicemgr_class (self) : pass - def aggregate_class (self) : pass - def component_class (self) : pass + def registry_class(self): pass + + def slicemgr_class(self): pass + + def aggregate_class(self): pass + + def component_class(self): pass # build an API object - # insert a manager instance - def make_api (self, *args, **kwargs): + # insert a manager instance + def make_api(self, *args, **kwargs): # interface is a required arg if not 'interface' in kwargs: logger.critical("Generic.make_api: no interface found") api = self.api_class()(*args, **kwargs) - # xxx can probably drop support for managers implemented as modules + # xxx can probably drop support for managers implemented as modules # which makes it a bit awkward manager_class_or_module = self.make_manager(api.interface) - driver = self.make_driver (api) - ### arrange stuff together + driver = self.make_driver(api) + # arrange stuff together # add a manager wrapper - manager_wrap = ManagerWrapper(manager_class_or_module,api.interface,api.config) - api.manager=manager_wrap + manager_wrap = ManagerWrapper( + manager_class_or_module, api.interface, api.config) + api.manager = manager_wrap # add it in api as well; driver.api is set too as part of make_driver - api.driver=driver + api.driver = driver return api - def make_manager (self, interface): + def make_manager(self, interface): """ interface expected in ['registry', 'aggregate', 'slicemgr', 'component'] flavour is e.g. 'pl' or 'max' or whatever """ flavour = self.flavour - message="Generic.make_manager for interface=%s and flavour=%s"%(interface,flavour) - - classname = "%s_manager_class"%interface + message = "Generic.make_manager for interface=%s and flavour=%s" % ( + interface, flavour) + + classname = "%s_manager_class" % interface try: - module_or_class = getattr(self,classname)() - logger.debug("%s : %s"%(message,module_or_class)) - # this gets passed to ManagerWrapper that will call the class constructor + module_or_class = getattr(self, classname)() + logger.debug("%s : %s" % (message, module_or_class)) + # this gets passed to ManagerWrapper that will call the class constructor # if it's a class, or use the module as is if it's a module # so bottom line is, don't try the constructor here return module_or_class except: logger.log_exc_critical(message) - + # need interface to select the right driver - def make_driver (self, api): - config=api.config - interface=api.interface + def make_driver(self, api): + config = api.config + interface = api.interface flavour = self.flavour - message="Generic.make_driver for flavour=%s and interface=%s"%(flavour,interface) - + message = "Generic.make_driver for flavour=%s and interface=%s" % ( + flavour, interface) + if interface == "component": classname = "component_driver_class" else: classname = "driver_class" try: - class_obj = getattr(self,classname)() - logger.debug("%s : %s"%(message,class_obj)) + class_obj = getattr(self, classname)() + logger.debug("%s : %s" % (message, class_obj)) return class_obj(api) except: logger.log_exc_critical(message) - diff --git a/sfa/generic/dummy.py b/sfa/generic/dummy.py index ee8aff33..1d5a3f28 100644 --- a/sfa/generic/dummy.py +++ b/sfa/generic/dummy.py @@ -1,30 +1,32 @@ from sfa.generic import Generic + class dummy (Generic): - + # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.dummyimporter return sfa.importer.dummyimporter.DummyImporter - + # use the standard api class - def api_class (self): + def api_class(self): import sfa.server.sfaapi return sfa.server.sfaapi.SfaApi # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): import sfa.dummy.dummydriver return sfa.dummy.dummydriver.DummyDriver - diff --git a/sfa/generic/fd.py b/sfa/generic/fd.py index f18b7359..60f172bd 100644 --- a/sfa/generic/fd.py +++ b/sfa/generic/fd.py @@ -1,13 +1,14 @@ -# +# from sfa.generic.pl import pl import sfa.federica.fddriver -# the federica flavour behaves like pl, except for +# the federica flavour behaves like pl, except for # the driver + class fd (pl): - def driver_class (self) : + def driver_class(self): import sfa.managers.v2_to_v3_adapter return sfa.managers.v2_to_v3_adapter.V2ToV3Adapter diff --git a/sfa/generic/iotlab.py b/sfa/generic/iotlab.py index 22f93102..818e01ef 100644 --- a/sfa/generic/iotlab.py +++ b/sfa/generic/iotlab.py @@ -3,37 +3,37 @@ from sfa.generic import Generic import sfa.server.sfaapi - class iotlab (Generic): # use the standard api class - def api_class (self): + def api_class(self): return sfa.server.sfaapi.SfaApi # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.iotlabimporter return sfa.importer.iotlabimporter.IotLabImporter # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager - def aggregate_manager_class (self) : + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager - def driver_class (self): + def driver_class(self): import sfa.iotlab.iotlabdriver return sfa.iotlab.iotlabdriver.IotLabDriver - def component_manager_class (self): + def component_manager_class(self): return None # driver_class - def component_driver_class (self): - return None \ No newline at end of file + + def component_driver_class(self): + return None diff --git a/sfa/generic/max.py b/sfa/generic/max.py index d54afc50..8b572268 100644 --- a/sfa/generic/max.py +++ b/sfa/generic/max.py @@ -1,20 +1,20 @@ # an example of how to plugin the max aggregate manager with the flavour model # might need to be tested -# +# from sfa.generic.pl import pl + class max (pl): -# the max flavour behaves like pl, except for -# the aggregate - def aggregate_manager_class (self) : + # the max flavour behaves like pl, except for + # the aggregate + def aggregate_manager_class(self): import sfa.managers.aggregate_manager_max return sfa.managers.aggregate_manager_max.AggregateManagerMax # I believe the component stuff is not implemented - def component_manager_class (self): - return None - def component_driver_class (self): + def component_manager_class(self): return None - + def component_driver_class(self): + return None diff --git a/sfa/generic/nitos.py b/sfa/generic/nitos.py index 5bad942f..3d8f394f 100644 --- a/sfa/generic/nitos.py +++ b/sfa/generic/nitos.py @@ -1,37 +1,41 @@ from sfa.generic import Generic + class nitos (Generic): # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.nitosimporter return sfa.importer.nitosimporter.NitosImporter # use the standard api class - def api_class (self): + def api_class(self): import sfa.server.sfaapi return sfa.server.sfaapi.SfaApi # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): import sfa.managers.v2_to_v3_adapter return sfa.managers.v2_to_v3_adapter.V2ToV3Adapter # for the component mode, to be run on board planetlab nodes # manager class - def component_manager_class (self): + def component_manager_class(self): return None # driver_class - def component_driver_class (self): + + def component_driver_class(self): return None diff --git a/sfa/generic/openstack.py b/sfa/generic/openstack.py index bac57dc7..bd88e90e 100644 --- a/sfa/generic/openstack.py +++ b/sfa/generic/openstack.py @@ -9,22 +9,21 @@ import sfa.managers.slice_manager # use pl as a model so we only redefine what's different from sfa.generic.pl import pl + class openstack (pl): - + # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.openstackimporter return sfa.importer.openstackimporter.OpenstackImporter - + # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): return sfa.managers.registry_manager_openstack.RegistryManager - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): return sfa.openstack.nova_driver.NovaDriver - - - diff --git a/sfa/generic/pl.py b/sfa/generic/pl.py index 457a09b4..cbeabd11 100644 --- a/sfa/generic/pl.py +++ b/sfa/generic/pl.py @@ -1,40 +1,43 @@ from sfa.generic import Generic + class pl (Generic): - + # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.plimporter return sfa.importer.plimporter.PlImporter - + # use the standard api class - def api_class (self): + def api_class(self): import sfa.server.sfaapi return sfa.server.sfaapi.SfaApi # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): import sfa.planetlab.pldriver return sfa.planetlab.pldriver.PlDriver # for the component mode, to be run on board planetlab nodes # manager class - def component_manager_class (self): + def component_manager_class(self): import sfa.managers return sfa.managers.component_manager_pl # driver_class - def component_driver_class (self): + + def component_driver_class(self): import sfa.planetlab.plcomponentdriver return sfa.planetlab.plcomponentdriver.PlComponentDriver - diff --git a/sfa/generic/teagle.py b/sfa/generic/teagle.py index 094e8880..a4ec66ad 100644 --- a/sfa/generic/teagle.py +++ b/sfa/generic/teagle.py @@ -1,32 +1,34 @@ from sfa.generic import Generic + class teagle (Generic): - + # the importer class - def importer_class (self): + def importer_class(self): import sfa.importer.dummyimporter return sfa.importer.dummyimporter.DummyImporter - + # use the standard api class - def api_class (self): + def api_class(self): import sfa.server.sfaapi return sfa.server.sfaapi.SfaApi # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): import teaglesfa.driver return teaglesfa.driver.TeagleDriver # import sfa.dummy.dummydriver # return sfa.dummy.dummydriver.DummyDriver - diff --git a/sfa/generic/void.py b/sfa/generic/void.py index 6c496f4f..bbd5deaf 100644 --- a/sfa/generic/void.py +++ b/sfa/generic/void.py @@ -2,36 +2,38 @@ from sfa.generic import Generic + class void (Generic): - + # the importer class # when set to None, the importer only performs the basic stuff - # xxx this convention probably is confusing, since None suggests that + # xxx this convention probably is confusing, since None suggests that # *nothing* should be done.. # xxx need to refactor the importers anyway - def importer_class (self): + def importer_class(self): return None - + # use the standard api class - def api_class (self): + def api_class(self): import sfa.server.sfaapi return sfa.server.sfaapi.SfaApi # the manager classes for the server-side services - def registry_manager_class (self) : + def registry_manager_class(self): import sfa.managers.registry_manager return sfa.managers.registry_manager.RegistryManager - def slicemgr_manager_class (self) : + + def slicemgr_manager_class(self): import sfa.managers.slice_manager return sfa.managers.slice_manager.SliceManager # most likely you'll want to turn OFF the aggregate in sfa-config-tty # SFA_AGGREGATE_ENABLED=false - def aggregate_manager_class (self) : + + def aggregate_manager_class(self): import sfa.managers.aggregate_manager return sfa.managers.aggregate_manager.AggregateManager # driver class for server-side services, talk to the whole testbed - def driver_class (self): + def driver_class(self): import sfa.managers.driver return sfa.managers.driver.Driver - diff --git a/sfa/importer/__init__.py b/sfa/importer/__init__.py index e1c3a289..65e57ec0 100644 --- a/sfa/importer/__init__.py +++ b/sfa/importer/__init__.py @@ -10,7 +10,7 @@ from sfa.util.sfalogging import _SfaLogger from sfa.trust.hierarchy import Hierarchy #from sfa.trust.trustedroots import TrustedRoots from sfa.trust.gid import create_uuid -# using global alchemy.session() here is fine +# using global alchemy.session() here is fine # as importer is on standalone one-shot process from sfa.storage.alchemy import global_dbsession from sfa.storage.model import RegRecord, RegAuthority, RegUser @@ -19,25 +19,26 @@ from sfa.trust.certificate import convert_public_key, Keypair class Importer: - def __init__(self,auth_hierarchy=None,logger=None): + def __init__(self, auth_hierarchy=None, logger=None): self.config = Config() if auth_hierarchy is not None: - self.auth_hierarchy=auth_hierarchy + self.auth_hierarchy = auth_hierarchy else: - self.auth_hierarchy = Hierarchy () + self.auth_hierarchy = Hierarchy() if logger is not None: - self.logger=logger + self.logger = logger else: - self.logger = _SfaLogger(logfile='/var/log/sfa_import.log', loggername='importlog') + self.logger = _SfaLogger( + logfile='/var/log/sfa_import.log', loggername='importlog') self.logger.setLevelFromOptVerbose(self.config.SFA_API_LOGLEVEL) # ugly side effect so that other modules get it right import sfa.util.sfalogging - sfa.util.sfalogging.logger=logger -# self.TrustedRoots = TrustedRoots(self.config.get_trustedroots_dir()) - + sfa.util.sfalogging.logger = logger +# self.TrustedRoots = TrustedRoots(self.config.get_trustedroots_dir()) + # check before creating a RegRecord entry as we run this over and over - def record_exists (self, type, hrn): - return global_dbsession.query(RegRecord).filter_by(hrn=hrn,type=type).count()!=0 + def record_exists(self, type, hrn): + return global_dbsession.query(RegRecord).filter_by(hrn=hrn, type=type).count() != 0 def create_top_level_auth_records(self, hrn): """ @@ -53,15 +54,15 @@ class Importer: # ensure key and cert exists: self.auth_hierarchy.create_top_level_auth(hrn) # create the db record if it doesnt already exist - if not self.record_exists ('authority',hrn): + if not self.record_exists('authority', hrn): auth_info = self.auth_hierarchy.get_auth_info(hrn) auth_record = RegAuthority(hrn=hrn, gid=auth_info.get_gid_object(), authority=get_authority(hrn)) auth_record.just_created() - global_dbsession.add (auth_record) + global_dbsession.add(auth_record) global_dbsession.commit() - self.logger.info("SfaImporter: imported authority (parent) %s " % auth_record) - + self.logger.info( + "SfaImporter: imported authority (parent) %s " % auth_record) def create_sm_client_record(self): """ @@ -73,15 +74,16 @@ class Importer: self.logger.info("SfaImporter: creating Slice Manager user") self.auth_hierarchy.create_auth(urn) - if self.record_exists ('user',hrn): return + if self.record_exists('user', hrn): + return auth_info = self.auth_hierarchy.get_auth_info(hrn) user_record = RegUser(hrn=hrn, gid=auth_info.get_gid_object(), authority=get_authority(hrn)) user_record.just_created() - global_dbsession.add (user_record) + global_dbsession.add(user_record) global_dbsession.commit() - self.logger.info("SfaImporter: importing user (slicemanager) %s " % user_record) - + self.logger.info( + "SfaImporter: importing user (slicemanager) %s " % user_record) def create_interface_records(self): """ @@ -89,44 +91,48 @@ class Importer: """ # just create certs for all sfa interfaces even if they # aren't enabled - auth_info = self.auth_hierarchy.get_auth_info(self.config.SFA_INTERFACE_HRN) + auth_info = self.auth_hierarchy.get_auth_info( + self.config.SFA_INTERFACE_HRN) pkey = auth_info.get_pkey_object() - hrn=self.config.SFA_INTERFACE_HRN - for type in [ 'authority+sa', 'authority+am', 'authority+sm', ]: + hrn = self.config.SFA_INTERFACE_HRN + for type in ['authority+sa', 'authority+am', 'authority+sm', ]: urn = hrn_to_urn(hrn, type) gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) # for now we have to preserve the authority+<> stuff - if self.record_exists (type,hrn): continue + if self.record_exists(type, hrn): + continue interface_record = RegAuthority(type=type, hrn=hrn, gid=gid, authority=get_authority(hrn)) interface_record.just_created() - global_dbsession.add (interface_record) + global_dbsession.add(interface_record) global_dbsession.commit() - self.logger.info("SfaImporter: imported authority (%s) %s " % (type,interface_record)) - + self.logger.info("SfaImporter: imported authority (%s) %s " % ( + type, interface_record)) + def run(self, options=None): if not self.config.SFA_REGISTRY_ENABLED: - self.logger.critical("Importer: need SFA_REGISTRY_ENABLED to run import") + self.logger.critical( + "Importer: need SFA_REGISTRY_ENABLED to run import") # testbed-neutral : create local certificates and the like - auth_hierarchy = Hierarchy () + auth_hierarchy = Hierarchy() self.create_top_level_auth_records(self.config.SFA_INTERFACE_HRN) self.create_interface_records() - + # testbed-specific testbed_importer = None - generic=Generic.the_flavour() + generic = Generic.the_flavour() importer_class = generic.importer_class() if importer_class: - begin_time=datetime.utcnow() - self.logger.info (30*'=') - self.logger.info ("Starting import on %s, using class %s from flavour %s"%\ - (begin_time,importer_class.__name__,generic.flavour)) - testbed_importer = importer_class (auth_hierarchy, self.logger) + begin_time = datetime.utcnow() + self.logger.info(30 * '=') + self.logger.info("Starting import on %s, using class %s from flavour %s" % + (begin_time, importer_class.__name__, generic.flavour)) + testbed_importer = importer_class(auth_hierarchy, self.logger) if testbed_importer: testbed_importer.add_options(options) - testbed_importer.run (options) - end_time=datetime.utcnow() - duration=end_time-begin_time - self.logger.info("Import took %s"%duration) - self.logger.info (30*'=') + testbed_importer.run(options) + end_time = datetime.utcnow() + duration = end_time - begin_time + self.logger.info("Import took %s" % duration) + self.logger.info(30 * '=') diff --git a/sfa/importer/dummyimporter.py b/sfa/importer/dummyimporter.py index d274b279..4594ad6e 100644 --- a/sfa/importer/dummyimporter.py +++ b/sfa/importer/dummyimporter.py @@ -1,89 +1,93 @@ # # Dummy importer -# +# # requirements -# +# # read the planetlab database and update the local registry database accordingly # so we update the following collections # . authorities (from pl sites) # . node (from pl nodes) # . users+keys (from pl persons and attached keys) # known limitation : *one* of the ssh keys is chosen at random here -# xxx todo/check xxx at the very least, when a key is known to the registry +# xxx todo/check xxx at the very least, when a key is known to the registry # and is still current in plc # then we should definitely make sure to keep that one in sfa... # . slice+researchers (from pl slices and attached users) -# +# import os from sfa.util.config import Config from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn -from sfa.trust.gid import create_uuid +from sfa.trust.gid import create_uuid from sfa.trust.certificate import convert_public_key, Keypair -# using global alchemy.session() here is fine +# using global alchemy.session() here is fine # as importer is on standalone one-shot process from sfa.storage.alchemy import global_dbsession from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey -from sfa.dummy.dummyshell import DummyShell +from sfa.dummy.dummyshell import DummyShell from sfa.dummy.dummyxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_dummy_slicename + def _get_site_hrn(interface_hrn, site): - hrn = ".".join([interface_hrn, site['name']]) - return hrn + hrn = ".".join([interface_hrn, site['name']]) + return hrn class DummyImporter: - def __init__ (self, auth_hierarchy, logger): + def __init__(self, auth_hierarchy, logger): self.auth_hierarchy = auth_hierarchy - self.logger=logger + self.logger = logger - def add_options (self, parser): + def add_options(self, parser): # we don't have any options for now pass # hrn hash is initialized from current db # remember just-created records as we go # xxx might make sense to add a UNIQUE constraint in the db itself - def remember_record_by_hrn (self, record): + def remember_record_by_hrn(self, record): tuple = (record.type, record.hrn) if tuple in self.records_by_type_hrn: - self.logger.warning ("DummyImporter.remember_record_by_hrn: duplicate (%s,%s)"%tuple) + self.logger.warning( + "DummyImporter.remember_record_by_hrn: duplicate (%s,%s)" % tuple) return - self.records_by_type_hrn [ tuple ] = record + self.records_by_type_hrn[tuple] = record # ditto for pointer hash - def remember_record_by_pointer (self, record): + def remember_record_by_pointer(self, record): if record.pointer == -1: - self.logger.warning ("DummyImporter.remember_record_by_pointer: pointer is void") + self.logger.warning( + "DummyImporter.remember_record_by_pointer: pointer is void") return tuple = (record.type, record.pointer) if tuple in self.records_by_type_pointer: - self.logger.warning ("DummyImporter.remember_record_by_pointer: duplicate (%s,%s)"%tuple) + self.logger.warning( + "DummyImporter.remember_record_by_pointer: duplicate (%s,%s)" % tuple) return - self.records_by_type_pointer [ ( record.type, record.pointer,) ] = record + self.records_by_type_pointer[(record.type, record.pointer,)] = record - def remember_record (self, record): - self.remember_record_by_hrn (record) - self.remember_record_by_pointer (record) + def remember_record(self, record): + self.remember_record_by_hrn(record) + self.remember_record_by_pointer(record) - def locate_by_type_hrn (self, type, hrn): - return self.records_by_type_hrn.get ( (type, hrn), None) + def locate_by_type_hrn(self, type, hrn): + return self.records_by_type_hrn.get((type, hrn), None) - def locate_by_type_pointer (self, type, pointer): - return self.records_by_type_pointer.get ( (type, pointer), None) + def locate_by_type_pointer(self, type, pointer): + return self.records_by_type_pointer.get((type, pointer), None) # a convenience/helper function to see if a record is already known - # a former, broken, attempt (in 2.1-9) had been made + # a former, broken, attempt (in 2.1-9) had been made # to try and use 'pointer' as a first, most significant attempt - # the idea being to preserve stuff as much as possible, and thus + # the idea being to preserve stuff as much as possible, and thus # to avoid creating a new gid in the case of a simple hrn rename # however this of course doesn't work as the gid depends on the hrn... - #def locate (self, type, hrn=None, pointer=-1): + # def locate (self, type, hrn=None, pointer=-1): # if pointer!=-1: # attempt = self.locate_by_type_pointer (type, pointer) # if attempt : return attempt @@ -94,28 +98,32 @@ class DummyImporter: # this makes the run method a bit abtruse - out of the way - def run (self, options): - config = Config () + def run(self, options): + config = Config() interface_hrn = config.SFA_INTERFACE_HRN root_auth = config.SFA_REGISTRY_ROOT_AUTH - shell = DummyShell (config) + shell = DummyShell(config) - ######## retrieve all existing SFA objects + # retrieve all existing SFA objects all_records = global_dbsession.query(RegRecord).all() - # create hash by (type,hrn) - # we essentially use this to know if a given record is already known to SFA + # create hash by (type,hrn) + # we essentially use this to know if a given record is already known to + # SFA self.records_by_type_hrn = \ - dict ( [ ( (record.type, record.hrn) , record ) for record in all_records ] ) - # create hash by (type,pointer) + dict([((record.type, record.hrn), record) + for record in all_records]) + # create hash by (type,pointer) self.records_by_type_pointer = \ - dict ( [ ( (record.type, record.pointer) , record ) for record in all_records - if record.pointer != -1] ) + dict([((record.type, record.pointer), record) for record in all_records + if record.pointer != -1]) - # initialize record.stale to True by default, then mark stale=False on the ones that are in use - for record in all_records: record.stale=True + # initialize record.stale to True by default, then mark stale=False on + # the ones that are in use + for record in all_records: + record.stale = True - ######## retrieve Dummy TB data + # retrieve Dummy TB data # Get all plc sites # retrieve only required stuf sites = [shell.GetTestbedInfo()] @@ -124,33 +132,32 @@ class DummyImporter: # Get all dummy TB users users = shell.GetUsers() # create a hash of users by user_id - users_by_id = dict ( [ ( user['user_id'], user) for user in users ] ) + users_by_id = dict([(user['user_id'], user) for user in users]) # Get all dummy TB public keys keys = [] for user in users: if 'keys' in user: keys.extend(user['keys']) # create a dict user_id -> [ keys ] - keys_by_person_id = {} + keys_by_person_id = {} for user in users: - if 'keys' in user: - keys_by_person_id[user['user_id']] = user['keys'] - # Get all dummy TB nodes + if 'keys' in user: + keys_by_person_id[user['user_id']] = user['keys'] + # Get all dummy TB nodes nodes = shell.GetNodes() # create hash by node_id - nodes_by_id = dict ( [ ( node['node_id'], node, ) for node in nodes ] ) + nodes_by_id = dict([(node['node_id'], node, ) for node in nodes]) # Get all dummy TB slices slices = shell.GetSlices() # create hash by slice_id - slices_by_id = dict ( [ (slice['slice_id'], slice ) for slice in slices ] ) - + slices_by_id = dict([(slice['slice_id'], slice) for slice in slices]) - # start importing + # start importing for site in sites: site_hrn = _get_site_hrn(interface_hrn, site) # import if hrn is not in list of existing hrns or if the hrn exists # but its not a site record - site_record=self.locate_by_type_hrn ('authority', site_hrn) + site_record = self.locate_by_type_hrn('authority', site_hrn) if not site_record: try: urn = hrn_to_urn(site_hrn, 'authority') @@ -158,177 +165,205 @@ class DummyImporter: self.auth_hierarchy.create_auth(urn) auth_info = self.auth_hierarchy.get_auth_info(urn) site_record = RegAuthority(hrn=site_hrn, gid=auth_info.get_gid_object(), - pointer= -1, + pointer=-1, authority=get_authority(site_hrn)) site_record.just_created() global_dbsession.add(site_record) global_dbsession.commit() - self.logger.info("DummyImporter: imported authority (site) : %s" % site_record) - self.remember_record (site_record) + self.logger.info( + "DummyImporter: imported authority (site) : %s" % site_record) + self.remember_record(site_record) except: # if the site import fails then there is no point in trying to import the - # site's child records (node, slices, persons), so skip them. - self.logger.log_exc("DummyImporter: failed to import site. Skipping child records") - continue + # site's child records (node, slices, persons), so skip + # them. + self.logger.log_exc( + "DummyImporter: failed to import site. Skipping child records") + continue else: # xxx update the record ... pass - site_record.stale=False - + site_record.stale = False + # import node records for node in nodes: site_auth = get_authority(site_hrn) site_name = site['name'] - node_hrn = hostname_to_hrn(site_auth, site_name, node['hostname']) + node_hrn = hostname_to_hrn( + site_auth, site_name, node['hostname']) # xxx this sounds suspicious - if len(node_hrn) > 64: node_hrn = node_hrn[:64] - node_record = self.locate_by_type_hrn ( 'node', node_hrn ) + if len(node_hrn) > 64: + node_hrn = node_hrn[:64] + node_record = self.locate_by_type_hrn('node', node_hrn) if not node_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(node_hrn, 'node') - node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - node_record = RegNode (hrn=node_hrn, gid=node_gid, - pointer =node['node_id'], - authority=get_authority(node_hrn)) + node_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + node_record = RegNode(hrn=node_hrn, gid=node_gid, + pointer=node['node_id'], + authority=get_authority(node_hrn)) node_record.just_created() global_dbsession.add(node_record) global_dbsession.commit() - self.logger.info("DummyImporter: imported node: %s" % node_record) - self.remember_record (node_record) + self.logger.info( + "DummyImporter: imported node: %s" % node_record) + self.remember_record(node_record) except: - self.logger.log_exc("DummyImporter: failed to import node") + self.logger.log_exc( + "DummyImporter: failed to import node") else: # xxx update the record ... pass - node_record.stale=False + node_record.stale = False - site_pis=[] + site_pis = [] # import users for user in users: user_hrn = email_to_hrn(site_hrn, user['email']) # xxx suspicious again - if len(user_hrn) > 64: user_hrn = user_hrn[:64] + if len(user_hrn) > 64: + user_hrn = user_hrn[:64] user_urn = hrn_to_urn(user_hrn, 'user') - user_record = self.locate_by_type_hrn ( 'user', user_hrn) + user_record = self.locate_by_type_hrn('user', user_hrn) - # return a tuple pubkey (a dummy TB key object) and pkey (a Keypair object) + # return a tuple pubkey (a dummy TB key object) and pkey (a + # Keypair object) - def init_user_key (user): + def init_user_key(user): pubkey = None pkey = None - if user['keys']: + if user['keys']: # randomly pick first key in set for key in user['keys']: - pubkey = key - try: + pubkey = key + try: pkey = convert_public_key(pubkey) break - except: + except: continue if not pkey: - self.logger.warn('DummyImporter: unable to convert public key for %s' % user_hrn) + self.logger.warn( + 'DummyImporter: unable to convert public key for %s' % user_hrn) pkey = Keypair(create=True) else: - # the user has no keys. Creating a random keypair for the user's gid - self.logger.warn("DummyImporter: user %s does not have a NITOS public key"%user_hrn) + # the user has no keys. Creating a random keypair for + # the user's gid + self.logger.warn( + "DummyImporter: user %s does not have a NITOS public key" % user_hrn) pkey = Keypair(create=True) return (pubkey, pkey) # new user try: if not user_record: - (pubkey,pkey) = init_user_key (user) - user_gid = self.auth_hierarchy.create_gid(user_urn, create_uuid(), pkey) + (pubkey, pkey) = init_user_key(user) + user_gid = self.auth_hierarchy.create_gid( + user_urn, create_uuid(), pkey) user_gid.set_email(user['email']) - user_record = RegUser (hrn=user_hrn, gid=user_gid, - pointer=user['user_id'], - authority=get_authority(user_hrn), - email=user['email']) - if pubkey: - user_record.reg_keys=[RegKey (pubkey)] + user_record = RegUser(hrn=user_hrn, gid=user_gid, + pointer=user['user_id'], + authority=get_authority( + user_hrn), + email=user['email']) + if pubkey: + user_record.reg_keys = [RegKey(pubkey)] else: - self.logger.warning("No key found for user %s"%user_record) + self.logger.warning( + "No key found for user %s" % user_record) user_record.just_created() - global_dbsession.add (user_record) + global_dbsession.add(user_record) global_dbsession.commit() - self.logger.info("DummyImporter: imported person: %s" % user_record) - self.remember_record ( user_record ) + self.logger.info( + "DummyImporter: imported person: %s" % user_record) + self.remember_record(user_record) else: # update the record ? - # if user's primary key has changed then we need to update the + # if user's primary key has changed then we need to update the # users gid by forcing an update here sfa_keys = user_record.reg_keys - def key_in_list (key,sfa_keys): + + def key_in_list(key, sfa_keys): for reg_key in sfa_keys: - if reg_key.key==key: return True + if reg_key.key == key: + return True return False # is there a new key in Dummy TB ? - new_keys=False + new_keys = False for key in user['keys']: - if not key_in_list (key,sfa_keys): + if not key_in_list(key, sfa_keys): new_keys = True if new_keys: - (pubkey,pkey) = init_user_key (user) - user_gid = self.auth_hierarchy.create_gid(user_urn, create_uuid(), pkey) + (pubkey, pkey) = init_user_key(user) + user_gid = self.auth_hierarchy.create_gid( + user_urn, create_uuid(), pkey) if not pubkey: - user_record.reg_keys=[] + user_record.reg_keys = [] else: - user_record.reg_keys=[ RegKey (pubkey)] - self.logger.info("DummyImporter: updated person: %s" % user_record) + user_record.reg_keys = [RegKey(pubkey)] + self.logger.info( + "DummyImporter: updated person: %s" % user_record) user_record.email = user['email'] global_dbsession.commit() - user_record.stale=False + user_record.stale = False except: - self.logger.log_exc("DummyImporter: failed to import user %d %s"%(user['user_id'],user['email'])) - + self.logger.log_exc("DummyImporter: failed to import user %d %s" % ( + user['user_id'], user['email'])) # import slices for slice in slices: slice_hrn = slicename_to_hrn(site_hrn, slice['slice_name']) - slice_record = self.locate_by_type_hrn ('slice', slice_hrn) + slice_record = self.locate_by_type_hrn('slice', slice_hrn) if not slice_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(slice_hrn, 'slice') - slice_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - slice_record = RegSlice (hrn=slice_hrn, gid=slice_gid, - pointer=slice['slice_id'], - authority=get_authority(slice_hrn)) + slice_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + slice_record = RegSlice(hrn=slice_hrn, gid=slice_gid, + pointer=slice['slice_id'], + authority=get_authority(slice_hrn)) slice_record.just_created() global_dbsession.add(slice_record) global_dbsession.commit() - self.logger.info("DummyImporter: imported slice: %s" % slice_record) - self.remember_record ( slice_record ) + self.logger.info( + "DummyImporter: imported slice: %s" % slice_record) + self.remember_record(slice_record) except: - self.logger.log_exc("DummyImporter: failed to import slice") + self.logger.log_exc( + "DummyImporter: failed to import slice") else: # xxx update the record ... - self.logger.warning ("Slice update not yet implemented") + self.logger.warning("Slice update not yet implemented") pass # record current users affiliated with the slice slice_record.reg_researchers = \ - [ self.locate_by_type_pointer ('user',user_id) for user_id in slice['user_ids'] ] + [self.locate_by_type_pointer( + 'user', user_id) for user_id in slice['user_ids']] global_dbsession.commit() - slice_record.stale=False + slice_record.stale = False - ### remove stale records + # remove stale records # special records must be preserved - system_hrns = [interface_hrn, root_auth, interface_hrn + '.slicemanager'] - for record in all_records: - if record.hrn in system_hrns: - record.stale=False + system_hrns = [interface_hrn, root_auth, + interface_hrn + '.slicemanager'] + for record in all_records: + if record.hrn in system_hrns: + record.stale = False if record.peer_authority: - record.stale=False + record.stale = False for record in all_records: - try: stale=record.stale - except: - stale=True - self.logger.warning("stale not found with %s"%record) + try: + stale = record.stale + except: + stale = True + self.logger.warning("stale not found with %s" % record) if stale: - self.logger.info("DummyImporter: deleting stale record: %s" % record) + self.logger.info( + "DummyImporter: deleting stale record: %s" % record) global_dbsession.delete(record) global_dbsession.commit() diff --git a/sfa/importer/iotlabimporter.py b/sfa/importer/iotlabimporter.py index 3aa45ec7..a794a40f 100644 --- a/sfa/importer/iotlabimporter.py +++ b/sfa/importer/iotlabimporter.py @@ -6,6 +6,7 @@ from sfa.storage.model import init_tables from sqlalchemy import Table, MetaData from sqlalchemy.exc import NoSuchTableError + class IotLabImporter: """ Creates the iotlab specific lease table to keep track @@ -16,10 +17,10 @@ class IotLabImporter: self.logger = loc_logger self.logger.setLevelDebug() - def add_options (self, parser): + def add_options(self, parser): """ Not used and need by SFA """ pass - + def _exists(self, tablename): """ Checks if the table exists in SFA database. @@ -31,7 +32,6 @@ class IotLabImporter: except NoSuchTableError: return False - def run(self, options): """ Run importer""" diff --git a/sfa/importer/nitosimporter.py b/sfa/importer/nitosimporter.py index 425be778..f77abd42 100644 --- a/sfa/importer/nitosimporter.py +++ b/sfa/importer/nitosimporter.py @@ -4,70 +4,74 @@ import os from sfa.util.config import Config from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn -from sfa.trust.gid import create_uuid +from sfa.trust.gid import create_uuid from sfa.trust.certificate import convert_public_key, Keypair -# using global alchemy.session() here is fine +# using global alchemy.session() here is fine # as importer is on standalone one-shot process from sfa.storage.alchemy import global_dbsession from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey -from sfa.nitos.nitosshell import NitosShell +from sfa.nitos.nitosshell import NitosShell from sfa.nitos.nitosxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_nitos_slicename, username_to_hrn + def _get_site_hrn(interface_hrn, site): - hrn = ".".join([interface_hrn, site['name']]) + hrn = ".".join([interface_hrn, site['name']]) return hrn class NitosImporter: - def __init__ (self, auth_hierarchy, logger): + def __init__(self, auth_hierarchy, logger): self.auth_hierarchy = auth_hierarchy - self.logger=logger + self.logger = logger - def add_options (self, parser): + def add_options(self, parser): # we don't have any options for now pass # hrn hash is initialized from current db # remember just-created records as we go # xxx might make sense to add a UNIQUE constraint in the db itself - def remember_record_by_hrn (self, record): + def remember_record_by_hrn(self, record): tuple = (record.type, record.hrn) if tuple in self.records_by_type_hrn: - self.logger.warning ("NitosImporter.remember_record_by_hrn: duplicate (%s,%s)"%tuple) + self.logger.warning( + "NitosImporter.remember_record_by_hrn: duplicate (%s,%s)" % tuple) return - self.records_by_type_hrn [ tuple ] = record + self.records_by_type_hrn[tuple] = record # ditto for pointer hash - def remember_record_by_pointer (self, record): + def remember_record_by_pointer(self, record): if record.pointer == -1: - self.logger.warning ("NitosImporter.remember_record_by_pointer: pointer is void") + self.logger.warning( + "NitosImporter.remember_record_by_pointer: pointer is void") return tuple = (record.type, record.pointer) if tuple in self.records_by_type_pointer: - self.logger.warning ("NitosImporter.remember_record_by_pointer: duplicate (%s,%s)"%tuple) + self.logger.warning( + "NitosImporter.remember_record_by_pointer: duplicate (%s,%s)" % tuple) return - self.records_by_type_pointer [ ( record.type, record.pointer,) ] = record + self.records_by_type_pointer[(record.type, record.pointer,)] = record - def remember_record (self, record): - self.remember_record_by_hrn (record) - self.remember_record_by_pointer (record) + def remember_record(self, record): + self.remember_record_by_hrn(record) + self.remember_record_by_pointer(record) - def locate_by_type_hrn (self, type, hrn): - return self.records_by_type_hrn.get ( (type, hrn), None) + def locate_by_type_hrn(self, type, hrn): + return self.records_by_type_hrn.get((type, hrn), None) - def locate_by_type_pointer (self, type, pointer): - return self.records_by_type_pointer.get ( (type, pointer), None) + def locate_by_type_pointer(self, type, pointer): + return self.records_by_type_pointer.get((type, pointer), None) # a convenience/helper function to see if a record is already known - # a former, broken, attempt (in 2.1-9) had been made + # a former, broken, attempt (in 2.1-9) had been made # to try and use 'pointer' as a first, most significant attempt - # the idea being to preserve stuff as much as possible, and thus + # the idea being to preserve stuff as much as possible, and thus # to avoid creating a new gid in the case of a simple hrn rename # however this of course doesn't work as the gid depends on the hrn... - #def locate (self, type, hrn=None, pointer=-1): + # def locate (self, type, hrn=None, pointer=-1): # if pointer!=-1: # attempt = self.locate_by_type_pointer (type, pointer) # if attempt : return attempt @@ -78,28 +82,32 @@ class NitosImporter: # this makes the run method a bit abtruse - out of the way - def run (self, options): - config = Config () + def run(self, options): + config = Config() interface_hrn = config.SFA_INTERFACE_HRN root_auth = config.SFA_REGISTRY_ROOT_AUTH - shell = NitosShell (config) + shell = NitosShell(config) - ######## retrieve all existing SFA objects + # retrieve all existing SFA objects all_records = global_dbsession.query(RegRecord).all() - # create hash by (type,hrn) - # we essentially use this to know if a given record is already known to SFA + # create hash by (type,hrn) + # we essentially use this to know if a given record is already known to + # SFA self.records_by_type_hrn = \ - dict ( [ ( (record.type, record.hrn) , record ) for record in all_records ] ) - # create hash by (type,pointer) + dict([((record.type, record.hrn), record) + for record in all_records]) + # create hash by (type,pointer) self.records_by_type_pointer = \ - dict ( [ ( (record.type, record.pointer) , record ) for record in all_records - if record.pointer != -1] ) + dict([((record.type, record.pointer), record) for record in all_records + if record.pointer != -1]) - # initialize record.stale to True by default, then mark stale=False on the ones that are in use - for record in all_records: record.stale=True + # initialize record.stale to True by default, then mark stale=False on + # the ones that are in use + for record in all_records: + record.stale = True - ######## retrieve NITOS data + # retrieve NITOS data # Get site info # retrieve only required stuf site = shell.getTestbedInfo() @@ -107,9 +115,9 @@ class NitosImporter: # create a hash of sites by login_base # # sites_by_login_base = dict ( [ ( site['login_base'], site ) for site in sites ] ) # Get all NITOS users - users = shell.getUsers() + users = shell.getUsers() # create a hash of users by user_id - users_by_id = dict ( [ ( user['user_id'], user) for user in users ] ) + users_by_id = dict([(user['user_id'], user) for user in users]) # Get all NITOS public keys # accumulate key ids for keys retrieval # key_ids = [] @@ -118,26 +126,26 @@ class NitosImporter: # keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids, # 'key_type': 'ssh'} ) # # create a hash of keys by key_id -# keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) +# keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) # create a dict user_id -> [ (nitos)keys ] - keys_by_user_id = dict ( [ ( user['user_id'], user['keys']) for user in users ] ) - # Get all nitos nodes + keys_by_user_id = dict( + [(user['user_id'], user['keys']) for user in users]) + # Get all nitos nodes nodes = shell.getNodes({}, []) # create hash by node_id - nodes_by_id = dict ( [ (node['node_id'], node) for node in nodes ] ) + nodes_by_id = dict([(node['node_id'], node) for node in nodes]) # Get all nitos slices slices = shell.getSlices({}, []) # create hash by slice_id - slices_by_id = dict ( [ (slice['slice_id'], slice) for slice in slices ] ) + slices_by_id = dict([(slice['slice_id'], slice) for slice in slices]) - - # start importing + # start importing for site in sites: - #for i in [0]: + # for i in [0]: site_hrn = _get_site_hrn(interface_hrn, site) # import if hrn is not in list of existing hrns or if the hrn exists # but its not a site record - site_record=self.locate_by_type_hrn ('authority', site_hrn) + site_record = self.locate_by_type_hrn('authority', site_hrn) if not site_record: try: urn = hrn_to_urn(site_hrn, 'authority') @@ -150,181 +158,206 @@ class NitosImporter: site_record.just_created() global_dbsession.add(site_record) global_dbsession.commit() - self.logger.info("NitosImporter: imported authority (site) : %s" % site_record) - self.remember_record (site_record) + self.logger.info( + "NitosImporter: imported authority (site) : %s" % site_record) + self.remember_record(site_record) except: # if the site import fails then there is no point in trying to import the - # site's child records (node, slices, persons), so skip them. - self.logger.log_exc("NitosImporter: failed to import site. Skipping child records") - continue + # site's child records (node, slices, persons), so skip + # them. + self.logger.log_exc( + "NitosImporter: failed to import site. Skipping child records") + continue else: # xxx update the record ... pass - site_record.stale=False - + site_record.stale = False + # import node records for node in nodes: site_auth = get_authority(site_hrn) site_name = site['name'] - node_hrn = hostname_to_hrn(site_auth, site_name, node['hostname']) + node_hrn = hostname_to_hrn( + site_auth, site_name, node['hostname']) # xxx this sounds suspicious - if len(node_hrn) > 64: node_hrn = node_hrn[:64] - node_record = self.locate_by_type_hrn ( 'node', node_hrn ) + if len(node_hrn) > 64: + node_hrn = node_hrn[:64] + node_record = self.locate_by_type_hrn('node', node_hrn) if not node_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(node_hrn, 'node') - node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - node_record = RegNode (hrn=node_hrn, gid=node_gid, - pointer =node['node_id'], - authority=get_authority(node_hrn)) + node_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + node_record = RegNode(hrn=node_hrn, gid=node_gid, + pointer=node['node_id'], + authority=get_authority(node_hrn)) node_record.just_created() global_dbsession.add(node_record) global_dbsession.commit() - self.logger.info("NitosImporter: imported node: %s" % node_record) - self.remember_record (node_record) + self.logger.info( + "NitosImporter: imported node: %s" % node_record) + self.remember_record(node_record) except: - self.logger.log_exc("NitosImporter: failed to import node") + self.logger.log_exc( + "NitosImporter: failed to import node") else: # xxx update the record ... pass - - node_record.stale=False + node_record.stale = False # import users for user in users: - user_hrn = username_to_hrn(interface_hrn, site['name'], user['username']) + user_hrn = username_to_hrn( + interface_hrn, site['name'], user['username']) # xxx suspicious again - if len(user_hrn) > 64: user_hrn = user_hrn[:64] + if len(user_hrn) > 64: + user_hrn = user_hrn[:64] user_urn = hrn_to_urn(user_hrn, 'user') - user_record = self.locate_by_type_hrn ( 'user', user_hrn) + user_record = self.locate_by_type_hrn('user', user_hrn) - # return a tuple pubkey (a nitos key object) and pkey (a Keypair object) - def init_user_key (user): + # return a tuple pubkey (a nitos key object) and pkey (a + # Keypair object) + def init_user_key(user): pubkey = None pkey = None - if user['keys']: + if user['keys']: # randomly pick first key in set for key in user['keys']: - pubkey = key - try: + pubkey = key + try: pkey = convert_public_key(pubkey) break - except: + except: continue if not pkey: - self.logger.warn('NitosImporter: unable to convert public key for %s' % user_hrn) + self.logger.warn( + 'NitosImporter: unable to convert public key for %s' % user_hrn) pkey = Keypair(create=True) else: - # the user has no keys. Creating a random keypair for the user's gid - self.logger.warn("NitosImporter: user %s does not have a NITOS public key"%user_hrn) + # the user has no keys. Creating a random keypair for + # the user's gid + self.logger.warn( + "NitosImporter: user %s does not have a NITOS public key" % user_hrn) pkey = Keypair(create=True) return (pubkey, pkey) # new user try: if not user_record: - (pubkey,pkey) = init_user_key (user) - user_gid = self.auth_hierarchy.create_gid(user_urn, create_uuid(), pkey) + (pubkey, pkey) = init_user_key(user) + user_gid = self.auth_hierarchy.create_gid( + user_urn, create_uuid(), pkey) user_gid.set_email(user['email']) - user_record = RegUser (hrn=user_hrn, gid=user_gid, - pointer=user['user_id'], - authority=get_authority(user_hrn), - email=user['email']) - if pubkey: - user_record.reg_keys=[RegKey (pubkey)] + user_record = RegUser(hrn=user_hrn, gid=user_gid, + pointer=user['user_id'], + authority=get_authority( + user_hrn), + email=user['email']) + if pubkey: + user_record.reg_keys = [RegKey(pubkey)] else: - self.logger.warning("No key found for user %s"%user_record) + self.logger.warning( + "No key found for user %s" % user_record) user_record.just_created() - global_dbsession.add (user_record) + global_dbsession.add(user_record) global_dbsession.commit() - self.logger.info("NitosImporter: imported user: %s" % user_record) - self.remember_record ( user_record ) + self.logger.info( + "NitosImporter: imported user: %s" % user_record) + self.remember_record(user_record) else: # update the record ? - # if user's primary key has changed then we need to update the + # if user's primary key has changed then we need to update the # users gid by forcing an update here sfa_keys = user_record.reg_keys - def sfa_key_in_list (sfa_key,nitos_user_keys): + def sfa_key_in_list(sfa_key, nitos_user_keys): for nitos_key in nitos_user_keys: - if nitos_key==sfa_key: return True + if nitos_key == sfa_key: + return True return False # are all the SFA keys known to nitos ? - new_keys=False + new_keys = False if not sfa_keys and user['keys']: new_keys = True else: for sfa_key in sfa_keys: - if not sfa_key_in_list (sfa_key.key,user['keys']): - new_keys = True + if not sfa_key_in_list(sfa_key.key, user['keys']): + new_keys = True if new_keys: - (pubkey,pkey) = init_user_key (user) - user_gid = self.auth_hierarchy.create_gid(user_urn, create_uuid(), pkey) + (pubkey, pkey) = init_user_key(user) + user_gid = self.auth_hierarchy.create_gid( + user_urn, create_uuid(), pkey) if not pubkey: - user_record.reg_keys=[] + user_record.reg_keys = [] else: - user_record.reg_keys=[ RegKey (pubkey)] + user_record.reg_keys = [RegKey(pubkey)] user_record.gid = user_gid user_record.just_updated() - self.logger.info("NitosImporter: updated user: %s" % user_record) + self.logger.info( + "NitosImporter: updated user: %s" % user_record) user_record.email = user['email'] global_dbsession.commit() - user_record.stale=False + user_record.stale = False except: - self.logger.log_exc("NitosImporter: failed to import user %s %s"%(user['user_id'],user['email'])) - + self.logger.log_exc("NitosImporter: failed to import user %s %s" % ( + user['user_id'], user['email'])) # import slices for slice in slices: - slice_hrn = slicename_to_hrn(interface_hrn, site['name'], slice['slice_name']) - slice_record = self.locate_by_type_hrn ('slice', slice_hrn) + slice_hrn = slicename_to_hrn( + interface_hrn, site['name'], slice['slice_name']) + slice_record = self.locate_by_type_hrn('slice', slice_hrn) if not slice_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(slice_hrn, 'slice') - slice_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - slice_record = RegSlice (hrn=slice_hrn, gid=slice_gid, - pointer=slice['slice_id'], - authority=get_authority(slice_hrn)) + slice_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + slice_record = RegSlice(hrn=slice_hrn, gid=slice_gid, + pointer=slice['slice_id'], + authority=get_authority(slice_hrn)) slice_record.just_created() global_dbsession.add(slice_record) global_dbsession.commit() - self.logger.info("NitosImporter: imported slice: %s" % slice_record) - self.remember_record ( slice_record ) + self.logger.info( + "NitosImporter: imported slice: %s" % slice_record) + self.remember_record(slice_record) except: - self.logger.log_exc("NitosImporter: failed to import slice") + self.logger.log_exc( + "NitosImporter: failed to import slice") else: # xxx update the record ... - self.logger.warning ("Slice update not yet implemented") + self.logger.warning("Slice update not yet implemented") pass # record current users affiliated with the slice slice_record.reg_researchers = \ - [ self.locate_by_type_pointer ('user',int(user_id)) for user_id in slice['user_ids'] ] + [self.locate_by_type_pointer('user', int( + user_id)) for user_id in slice['user_ids']] global_dbsession.commit() - slice_record.stale=False + slice_record.stale = False - - ### remove stale records + # remove stale records # special records must be preserved - system_hrns = [interface_hrn, root_auth, interface_hrn + '.slicemanager'] - for record in all_records: - if record.hrn in system_hrns: - record.stale=False + system_hrns = [interface_hrn, root_auth, + interface_hrn + '.slicemanager'] + for record in all_records: + if record.hrn in system_hrns: + record.stale = False if record.peer_authority: - record.stale=False + record.stale = False for record in all_records: - try: stale=record.stale - except: - stale=True - self.logger.warning("stale not found with %s"%record) + try: + stale = record.stale + except: + stale = True + self.logger.warning("stale not found with %s" % record) if stale: - self.logger.info("NitosImporter: deleting stale record: %s" % record) + self.logger.info( + "NitosImporter: deleting stale record: %s" % record) global_dbsession.delete(record) global_dbsession.commit() - - diff --git a/sfa/importer/openstackimporter.py b/sfa/importer/openstackimporter.py index c8233bda..eb531738 100644 --- a/sfa/importer/openstackimporter.py +++ b/sfa/importer/openstackimporter.py @@ -2,14 +2,15 @@ import os from sfa.util.config import Config from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn -from sfa.trust.gid import create_uuid +from sfa.trust.gid import create_uuid from sfa.trust.certificate import convert_public_key, Keypair -# using global alchemy.session() here is fine +# using global alchemy.session() here is fine # as importer is on standalone one-shot process from sfa.storage.alchemy import global_dbsession from sfa.storage.model import RegRecord, RegAuthority, RegUser, RegSlice, RegNode from sfa.openstack.osxrn import OSXrn -from sfa.openstack.shell import Shell +from sfa.openstack.shell import Shell + def load_keys(filename): keys = {} @@ -22,23 +23,25 @@ def load_keys(filename): except: return keys + def save_keys(filename, keys): f = open(filename, 'w') f.write("keys = %s" % str(keys)) f.close() + class OpenstackImporter: - def __init__ (self, auth_hierarchy, logger): + def __init__(self, auth_hierarchy, logger): self.auth_hierarchy = auth_hierarchy - self.logger=logger - self.config = Config () + self.logger = logger + self.config = Config() self.interface_hrn = self.config.SFA_INTERFACE_HRN self.root_auth = self.config.SFA_REGISTRY_ROOT_AUTH - self.shell = Shell (self.config) + self.shell = Shell(self.config) - def add_options (self, parser): - self.logger.debug ("OpenstackImporter: no options yet") + def add_options(self, parser): + self.logger.debug("OpenstackImporter: no options yet") pass def import_users(self, existing_hrns, existing_records): @@ -52,38 +55,44 @@ class OpenstackImporter: auth_hrn = self.config.SFA_INTERFACE_HRN if user.tenantId is not None: tenant = self.shell.auth_manager.tenants.find(id=user.tenantId) - auth_hrn = OSXrn(name=tenant.name, auth=self.config.SFA_INTERFACE_HRN, type='authority').get_hrn() + auth_hrn = OSXrn( + name=tenant.name, auth=self.config.SFA_INTERFACE_HRN, type='authority').get_hrn() hrn = OSXrn(name=user.name, auth=auth_hrn, type='user').get_hrn() users_dict[hrn] = user old_keys = old_user_keys.get(hrn, []) keyname = OSXrn(xrn=hrn, type='user').get_slicename() - keys = [k.public_key for k in self.shell.nova_manager.keypairs.findall(name=keyname)] + keys = [ + k.public_key for k in self.shell.nova_manager.keypairs.findall(name=keyname)] user_keys[hrn] = keys update_record = False if old_keys != keys: update_record = True if hrn not in existing_hrns or \ - (hrn, 'user') not in existing_records or update_record: + (hrn, 'user') not in existing_records or update_record: urn = OSXrn(xrn=hrn, type='user').get_urn() if keys: try: pkey = convert_public_key(keys[0]) except: - self.logger.log_exc('unable to convert public key for %s' % hrn) + self.logger.log_exc( + 'unable to convert public key for %s' % hrn) pkey = Keypair(create=True) else: - self.logger.warn("OpenstackImporter: person %s does not have a PL public key"%hrn) + self.logger.warn( + "OpenstackImporter: person %s does not have a PL public key" % hrn) pkey = Keypair(create=True) - user_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey, email=user.email) - user_record = RegUser () - user_record.type='user' - user_record.hrn=hrn - user_record.gid=user_gid - user_record.authority=get_authority(hrn) + user_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey, email=user.email) + user_record = RegUser() + user_record.type = 'user' + user_record.hrn = hrn + user_record.gid = user_gid + user_record.authority = get_authority(hrn) global_dbsession.add(user_record) global_dbsession.commit() - self.logger.info("OpenstackImporter: imported person %s" % user_record) + self.logger.info( + "OpenstackImporter: imported person %s" % user_record) return users_dict, user_keys @@ -97,7 +106,8 @@ class OpenstackImporter: for tenant in tenants: hrn = self.config.SFA_INTERFACE_HRN + '.' + tenant.name tenants_dict[hrn] = tenant - authority_hrn = OSXrn(xrn=hrn, type='authority').get_authority_hrn() + authority_hrn = OSXrn( + xrn=hrn, type='authority').get_authority_hrn() if hrn in existing_hrns: continue @@ -110,71 +120,73 @@ class OpenstackImporter: self.auth_hierarchy.create_auth(urn) auth_info = self.auth_hierarchy.get_auth_info(urn) gid = auth_info.get_gid_object() - record.type='authority' - record.hrn=hrn - record.gid=gid - record.authority=get_authority(hrn) + record.type = 'authority' + record.hrn = hrn + record.gid = gid + record.authority = get_authority(hrn) global_dbsession.add(record) global_dbsession.commit() - self.logger.info("OpenstackImporter: imported authority: %s" % record) + self.logger.info( + "OpenstackImporter: imported authority: %s" % record) else: - record = RegSlice () + record = RegSlice() urn = OSXrn(xrn=hrn, type='slice').get_urn() pkey = Keypair(create=True) gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - record.type='slice' - record.hrn=hrn - record.gid=gid - record.authority=get_authority(hrn) + record.type = 'slice' + record.hrn = hrn + record.gid = gid + record.authority = get_authority(hrn) global_dbsession.add(record) global_dbsession.commit() - self.logger.info("OpenstackImporter: imported slice: %s" % record) + self.logger.info( + "OpenstackImporter: imported slice: %s" % record) return tenants_dict - def run (self, options): + def run(self, options): # we don't have any options for now - self.logger.info ("OpenstackImporter.run : to do") + self.logger.info("OpenstackImporter.run : to do") # create dict of all existing sfa records existing_records = {} existing_hrns = [] key_ids = [] for record in global_dbsession.query(RegRecord): - existing_records[ (record.hrn, record.type,) ] = record - existing_hrns.append(record.hrn) - + existing_records[(record.hrn, record.type,)] = record + existing_hrns.append(record.hrn) tenants_dict = self.import_tenants(existing_hrns, existing_records) - users_dict, user_keys = self.import_users(existing_hrns, existing_records) - - # remove stale records - system_records = [self.interface_hrn, self.root_auth, self.interface_hrn + '.slicemanager'] + users_dict, user_keys = self.import_users( + existing_hrns, existing_records) + + # remove stale records + system_records = [self.interface_hrn, self.root_auth, + self.interface_hrn + '.slicemanager'] for (record_hrn, type) in existing_records.keys(): if record_hrn in system_records: continue - + record = existing_records[(record_hrn, type)] if record.peer_authority: continue if type == 'user': if record_hrn in users_dict: - continue + continue elif type in['slice', 'authority']: if record_hrn in tenants_dict: continue else: - continue - - record_object = existing_records[ (record_hrn, type) ] + continue + + record_object = existing_records[(record_hrn, type)] self.logger.info("OpenstackImporter: removing %s " % record) global_dbsession.delete(record_object) global_dbsession.commit() - + # save pub keys self.logger.info('OpenstackImporter: saving current pub keys') keys_filename = self.config.config_path + os.sep + 'person_keys.py' - save_keys(keys_filename, user_keys) - + save_keys(keys_filename, user_keys) diff --git a/sfa/importer/plimporter.py b/sfa/importer/plimporter.py index ae0e83f9..3bbd8c30 100644 --- a/sfa/importer/plimporter.py +++ b/sfa/importer/plimporter.py @@ -1,8 +1,8 @@ # # PlanetLab importer -# +# # requirements -# +# # read the planetlab database and update the local registry database accordingly # (in other words, with this testbed, the SFA registry is *not* authoritative) # so we update the following collections @@ -10,33 +10,34 @@ # . node (from pl nodes) # . users+keys (from pl persons and attached keys) # known limitation : *one* of the ssh keys is chosen at random here -# xxx todo/check xxx at the very least, when a key is known to the registry +# xxx todo/check xxx at the very least, when a key is known to the registry # and is still current in plc # then we should definitely make sure to keep that one in sfa... # . slice+researchers (from pl slices and attached users) -# +# import os from sfa.util.config import Config from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn -from sfa.trust.gid import create_uuid +from sfa.trust.gid import create_uuid from sfa.trust.certificate import convert_public_key, Keypair -# using global alchemy.session() here is fine +# using global alchemy.session() here is fine # as importer is on standalone one-shot process from sfa.storage.alchemy import global_dbsession from sfa.storage.model import RegRecord, RegAuthority, RegSlice, RegNode, RegUser, RegKey -from sfa.planetlab.plshell import PlShell +from sfa.planetlab.plshell import PlShell from sfa.planetlab.plxrn import hostname_to_hrn, slicename_to_hrn, email_to_hrn, hrn_to_pl_slicename + def _get_site_hrn(interface_hrn, site): # Hardcode 'internet2' into the hrn for sites hosting # internet2 nodes. This is a special operation for some vini # sites only - hrn = ".".join([interface_hrn, site['login_base']]) + hrn = ".".join([interface_hrn, site['login_base']]) if ".vini" in interface_hrn and interface_hrn.endswith('vini'): if site['login_base'].startswith("i2") or site['login_base'].startswith("nlr"): hrn = ".".join([interface_hrn, "internet2", site['login_base']]) @@ -45,52 +46,55 @@ def _get_site_hrn(interface_hrn, site): class PlImporter: - def __init__ (self, auth_hierarchy, logger): + def __init__(self, auth_hierarchy, logger): self.auth_hierarchy = auth_hierarchy self.logger = logger - def add_options (self, parser): + def add_options(self, parser): # we don't have any options for now pass # hrn hash is initialized from current db # remember just-created records as we go # xxx might make sense to add a UNIQUE constraint in the db itself - def remember_record_by_hrn (self, record): + def remember_record_by_hrn(self, record): tuple = (record.type, record.hrn) if tuple in self.records_by_type_hrn: - self.logger.warning ("PlImporter.remember_record_by_hrn: duplicate {}".format(tuple)) + self.logger.warning( + "PlImporter.remember_record_by_hrn: duplicate {}".format(tuple)) return - self.records_by_type_hrn [ tuple ] = record + self.records_by_type_hrn[tuple] = record # ditto for pointer hash - def remember_record_by_pointer (self, record): + def remember_record_by_pointer(self, record): if record.pointer == -1: - self.logger.warning ("PlImporter.remember_record_by_pointer: pointer is void") + self.logger.warning( + "PlImporter.remember_record_by_pointer: pointer is void") return tuple = (record.type, record.pointer) if tuple in self.records_by_type_pointer: - self.logger.warning ("PlImporter.remember_record_by_pointer: duplicate {}".format(tuple)) + self.logger.warning( + "PlImporter.remember_record_by_pointer: duplicate {}".format(tuple)) return - self.records_by_type_pointer [ ( record.type, record.pointer,) ] = record + self.records_by_type_pointer[(record.type, record.pointer,)] = record - def remember_record (self, record): - self.remember_record_by_hrn (record) - self.remember_record_by_pointer (record) + def remember_record(self, record): + self.remember_record_by_hrn(record) + self.remember_record_by_pointer(record) - def locate_by_type_hrn (self, type, hrn): - return self.records_by_type_hrn.get ( (type, hrn), None) + def locate_by_type_hrn(self, type, hrn): + return self.records_by_type_hrn.get((type, hrn), None) - def locate_by_type_pointer (self, type, pointer): - return self.records_by_type_pointer.get ( (type, pointer), None) + def locate_by_type_pointer(self, type, pointer): + return self.records_by_type_pointer.get((type, pointer), None) # a convenience/helper function to see if a record is already known - # a former, broken, attempt (in 2.1-9) had been made + # a former, broken, attempt (in 2.1-9) had been made # to try and use 'pointer' as a first, most significant attempt - # the idea being to preserve stuff as much as possible, and thus + # the idea being to preserve stuff as much as possible, and thus # to avoid creating a new gid in the case of a simple hrn rename # however this of course doesn't work as the gid depends on the hrn... - #def locate (self, type, hrn=None, pointer=-1): + # def locate (self, type, hrn=None, pointer=-1): # if pointer!=-1: # attempt = self.locate_by_type_pointer (type, pointer) # if attempt : return attempt @@ -100,15 +104,16 @@ class PlImporter: # return None # this makes the run method a bit abtruse - out of the way - def create_special_vini_record (self, interface_hrn): + def create_special_vini_record(self, interface_hrn): # special case for vini if ".vini" in interface_hrn and interface_hrn.endswith('vini'): # create a fake internet2 site first - i2site = {'name': 'Internet2', 'login_base': 'internet2', 'site_id': -1} + i2site = {'name': 'Internet2', + 'login_base': 'internet2', 'site_id': -1} site_hrn = _get_site_hrn(interface_hrn, i2site) # import if hrn is not in list of existing hrns or if the hrn exists # but its not a site record - if ( 'authority', site_hrn, ) not in self.records_by_type_hrn: + if ('authority', site_hrn, ) not in self.records_by_type_hrn: urn = hrn_to_urn(site_hrn, 'authority') if not self.auth_hierarchy.auth_exists(urn): self.auth_hierarchy.create_auth(urn) @@ -119,57 +124,65 @@ class PlImporter: auth_record.just_created() global_dbsession.add(auth_record) global_dbsession.commit() - self.logger.info("PlImporter: Imported authority (vini site) {}".format(auth_record)) - self.remember_record ( site_record ) + self.logger.info( + "PlImporter: Imported authority (vini site) {}".format(auth_record)) + self.remember_record(site_record) - def run (self, options): - config = Config () + def run(self, options): + config = Config() interface_hrn = config.SFA_INTERFACE_HRN root_auth = config.SFA_REGISTRY_ROOT_AUTH - shell = PlShell (config) + shell = PlShell(config) - ######## retrieve all existing SFA objects + # retrieve all existing SFA objects all_records = global_dbsession.query(RegRecord).all() - # create hash by (type,hrn) - # we essentially use this to know if a given record is already known to SFA + # create hash by (type,hrn) + # we essentially use this to know if a given record is already known to + # SFA self.records_by_type_hrn = \ - dict ( [ ( (record.type, record.hrn) , record ) for record in all_records ] ) - # create hash by (type,pointer) + dict([((record.type, record.hrn), record) + for record in all_records]) + # create hash by (type,pointer) self.records_by_type_pointer = \ - dict ( [ ( (record.type, record.pointer) , record ) for record in all_records - if record.pointer != -1] ) + dict([((record.type, record.pointer), record) for record in all_records + if record.pointer != -1]) - # initialize record.stale to True by default, then mark stale=False on the ones that are in use + # initialize record.stale to True by default, then mark stale=False on + # the ones that are in use for record in all_records: record.stale = True - ######## retrieve PLC data + # retrieve PLC data # Get all plc sites # retrieve only required stuf - sites = shell.GetSites({'peer_id': None, 'enabled' : True}, - ['site_id','login_base','node_ids','slice_ids','person_ids', 'name', 'hrn']) + sites = shell.GetSites({'peer_id': None, 'enabled': True}, + ['site_id', 'login_base', 'node_ids', 'slice_ids', 'person_ids', 'name', 'hrn']) # create a hash of sites by login_base # sites_by_login_base = dict ( [ ( site['login_base'], site ) for site in sites ] ) # Get all plc users - persons = shell.GetPersons({'peer_id': None, 'enabled': True}, + persons = shell.GetPersons({'peer_id': None, 'enabled': True}, ['person_id', 'email', 'key_ids', 'site_ids', 'role_ids', 'hrn']) # create a hash of persons by person_id - persons_by_id = dict ( [ ( person['person_id'], person) for person in persons ] ) - # also gather non-enabled user accounts so as to issue relevant warnings - disabled_persons = shell.GetPersons({'peer_id': None, 'enabled': False}, ['person_id']) - disabled_person_ids = [ person['person_id'] for person in disabled_persons ] + persons_by_id = dict([(person['person_id'], person) + for person in persons]) + # also gather non-enabled user accounts so as to issue relevant + # warnings + disabled_persons = shell.GetPersons( + {'peer_id': None, 'enabled': False}, ['person_id']) + disabled_person_ids = [person['person_id'] + for person in disabled_persons] # Get all plc public keys # accumulate key ids for keys retrieval key_ids = [] for person in persons: key_ids.extend(person['key_ids']) - keys = shell.GetKeys( {'peer_id': None, 'key_id': key_ids, - 'key_type': 'ssh'} ) + keys = shell.GetKeys({'peer_id': None, 'key_id': key_ids, + 'key_type': 'ssh'}) # create a hash of keys by key_id - keys_by_id = dict ( [ ( key['key_id'], key ) for key in keys ] ) + keys_by_id = dict([(key['key_id'], key) for key in keys]) # create a dict person_id -> [ (plc)keys ] - keys_by_person_id = {} + keys_by_person_id = {} for person in persons: pubkeys = [] for key_id in person['key_ids']: @@ -179,30 +192,33 @@ class PlImporter: key = keys_by_id[key_id] pubkeys.append(key) except: - self.logger.warning("Could not spot key {} - probably non-ssh".format(key_id)) + self.logger.warning( + "Could not spot key {} - probably non-ssh".format(key_id)) keys_by_person_id[person['person_id']] = pubkeys - # Get all plc nodes - nodes = shell.GetNodes( {'peer_id': None}, ['node_id', 'hostname', 'site_id']) + # Get all plc nodes + nodes = shell.GetNodes({'peer_id': None}, [ + 'node_id', 'hostname', 'site_id']) # create hash by node_id - nodes_by_id = dict ( [ ( node['node_id'], node, ) for node in nodes ] ) + nodes_by_id = dict([(node['node_id'], node, ) for node in nodes]) # Get all plc slices - slices = shell.GetSlices( {'peer_id': None}, ['slice_id', 'name', 'person_ids', 'hrn']) + slices = shell.GetSlices( + {'peer_id': None}, ['slice_id', 'name', 'person_ids', 'hrn']) # create hash by slice_id - slices_by_id = dict ( [ (slice['slice_id'], slice ) for slice in slices ] ) + slices_by_id = dict([(slice['slice_id'], slice) for slice in slices]) # isolate special vini case in separate method - self.create_special_vini_record (interface_hrn) + self.create_special_vini_record(interface_hrn) # Get top authority record - top_auth_record = self.locate_by_type_hrn ('authority', root_auth) + top_auth_record = self.locate_by_type_hrn('authority', root_auth) admins = [] - # start importing + # start importing for site in sites: try: - site_sfa_created = shell.GetSiteSfaCreated(site['site_id']) - except: - site_sfa_created = None + site_sfa_created = shell.GetSiteSfaCreated(site['site_id']) + except: + site_sfa_created = None if site['name'].startswith('sfa:') or site_sfa_created == 'True': continue @@ -210,7 +226,7 @@ class PlImporter: site_hrn = site['hrn'] # import if hrn is not in list of existing hrns or if the hrn exists # but its not a site record - site_record = self.locate_by_type_hrn ('authority', site_hrn) + site_record = self.locate_by_type_hrn('authority', site_hrn) if not site_record: try: urn = hrn_to_urn(site_hrn, 'authority') @@ -219,54 +235,62 @@ class PlImporter: auth_info = self.auth_hierarchy.get_auth_info(urn) site_record = RegAuthority(hrn=site_hrn, gid=auth_info.get_gid_object(), pointer=site['site_id'], - authority=get_authority(site_hrn), + authority=get_authority( + site_hrn), name=site['name']) site_record.just_created() global_dbsession.add(site_record) global_dbsession.commit() - self.logger.info("PlImporter: imported authority (site) : {}".format(site_record)) + self.logger.info( + "PlImporter: imported authority (site) : {}".format(site_record)) self.remember_record(site_record) except: # if the site import fails then there is no point in trying to import the - # site's child records (node, slices, persons), so skip them. - self.logger.log_exc("PlImporter: failed to import site {}. Skipping child records"\ + # site's child records (node, slices, persons), so skip + # them. + self.logger.log_exc("PlImporter: failed to import site {}. Skipping child records" .format(site_hrn)) - continue + continue else: # xxx update the record ... site_record.name = site['name'] pass site_record.stale = False - + # import node records for node_id in site['node_ids']: try: node = nodes_by_id[node_id] except: - self.logger.warning ("PlImporter: cannot find node_id {} - ignored" - .format(node_id)) - continue + self.logger.warning("PlImporter: cannot find node_id {} - ignored" + .format(node_id)) + continue site_auth = get_authority(site_hrn) site_name = site['login_base'] - node_hrn = hostname_to_hrn(site_auth, site_name, node['hostname']) + node_hrn = hostname_to_hrn( + site_auth, site_name, node['hostname']) # xxx this sounds suspicious - if len(node_hrn) > 64: node_hrn = node_hrn[:64] - node_record = self.locate_by_type_hrn ( 'node', node_hrn ) + if len(node_hrn) > 64: + node_hrn = node_hrn[:64] + node_record = self.locate_by_type_hrn('node', node_hrn) if not node_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(node_hrn, 'node') - node_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - node_record = RegNode (hrn=node_hrn, gid=node_gid, - pointer =node['node_id'], - authority=get_authority(node_hrn)) + node_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + node_record = RegNode(hrn=node_hrn, gid=node_gid, + pointer=node['node_id'], + authority=get_authority(node_hrn)) node_record.just_created() global_dbsession.add(node_record) global_dbsession.commit() - self.logger.info("PlImporter: imported node: {}".format(node_record)) - self.remember_record (node_record) + self.logger.info( + "PlImporter: imported node: {}".format(node_record)) + self.remember_record(node_record) except: - self.logger.log_exc("PlImporter: failed to import node {}".format(node_hrn)) + self.logger.log_exc( + "PlImporter: failed to import node {}".format(node_hrn)) continue else: # xxx update the record ... @@ -283,27 +307,30 @@ class PlImporter: elif person_id in disabled_person_ids: pass else: - self.logger.warning ("PlImporter: cannot locate person_id {} in site {} - ignored"\ - .format(person_id, site_hrn)) + self.logger.warning("PlImporter: cannot locate person_id {} in site {} - ignored" + .format(person_id, site_hrn)) # make sure to NOT run this if anything is wrong - if not proceed: continue + if not proceed: + continue #person_hrn = email_to_hrn(site_hrn, person['email']) person_hrn = person['hrn'] if person_hrn is None: - self.logger.warn("Person {} has no hrn - skipped".format(person['email'])) + self.logger.warn( + "Person {} has no hrn - skipped".format(person['email'])) continue # xxx suspicious again if len(person_hrn) > 64: person_hrn = person_hrn[:64] person_urn = hrn_to_urn(person_hrn, 'user') - user_record = self.locate_by_type_hrn ( 'user', person_hrn) + user_record = self.locate_by_type_hrn('user', person_hrn) - # return a tuple pubkey (a plc key object) and pkey (a Keypair object) - def init_person_key (person, plc_keys): + # return a tuple pubkey (a plc key object) and pkey (a Keypair + # object) + def init_person_key(person, plc_keys): pubkey = None - if person['key_ids']: + if person['key_ids']: # randomly pick first key in set pubkey = plc_keys[0] try: @@ -313,7 +340,8 @@ class PlImporter: .format(person_hrn)) pkey = Keypair(create=True) else: - # the user has no keys. Creating a random keypair for the user's gid + # the user has no keys. Creating a random keypair for + # the user's gid self.logger.warn("PlImporter: person {} does not have a PL public key" .format(person_hrn)) pkey = Keypair(create=True) @@ -321,24 +349,28 @@ class PlImporter: # new person try: - plc_keys = keys_by_person_id.get(person['person_id'],[]) + plc_keys = keys_by_person_id.get(person['person_id'], []) if not user_record: - (pubkey, pkey) = init_person_key (person, plc_keys ) + (pubkey, pkey) = init_person_key(person, plc_keys) person_gid = self.auth_hierarchy.create_gid(person_urn, create_uuid(), pkey, email=person['email']) - user_record = RegUser (hrn=person_hrn, gid=person_gid, - pointer=person['person_id'], - authority=get_authority(person_hrn), - email=person['email']) - if pubkey: - user_record.reg_keys=[RegKey (pubkey['key'], pubkey['key_id'])] + user_record = RegUser(hrn=person_hrn, gid=person_gid, + pointer=person['person_id'], + authority=get_authority( + person_hrn), + email=person['email']) + if pubkey: + user_record.reg_keys = [ + RegKey(pubkey['key'], pubkey['key_id'])] else: - self.logger.warning("No key found for user {}".format(user_record)) + self.logger.warning( + "No key found for user {}".format(user_record)) user_record.just_created() - global_dbsession.add (user_record) + global_dbsession.add(user_record) global_dbsession.commit() - self.logger.info("PlImporter: imported person: {}".format(user_record)) - self.remember_record ( user_record ) + self.logger.info( + "PlImporter: imported person: {}".format(user_record)) + self.remember_record(user_record) else: # update the record ? # @@ -359,7 +391,8 @@ class PlImporter: # NOTE: with this logic, the first key entered in PLC remains the one # current in SFA until it is removed from PLC sfa_keys = user_record.reg_keys - def sfa_key_in_list (sfa_key,plc_keys): + + def sfa_key_in_list(sfa_key, plc_keys): for plc_key in plc_keys: if plc_key['key'] == sfa_key.key: return True @@ -368,21 +401,24 @@ class PlImporter: new_keys = False if not sfa_keys and plc_keys: new_keys = True - else: + else: for sfa_key in sfa_keys: - if not sfa_key_in_list (sfa_key,plc_keys): - new_keys = True + if not sfa_key_in_list(sfa_key, plc_keys): + new_keys = True if new_keys: - (pubkey,pkey) = init_person_key (person, plc_keys) - person_gid = self.auth_hierarchy.create_gid(person_urn, create_uuid(), pkey) + (pubkey, pkey) = init_person_key(person, plc_keys) + person_gid = self.auth_hierarchy.create_gid( + person_urn, create_uuid(), pkey) person_gid.set_email(person['email']) if not pubkey: user_record.reg_keys = [] else: - user_record.reg_keys = [ RegKey (pubkey['key'], pubkey['key_id'])] + user_record.reg_keys = [ + RegKey(pubkey['key'], pubkey['key_id'])] user_record.gid = person_gid user_record.just_updated() - self.logger.info("PlImporter: updated person: {}".format(user_record)) + self.logger.info( + "PlImporter: updated person: {}".format(user_record)) user_record.email = person['email'] global_dbsession.commit() user_record.stale = False @@ -390,23 +426,25 @@ class PlImporter: # this is valid for all sites she is in.. # PI is coded with role_id == 20 if 20 in person['role_ids']: - site_pis.append (user_record) + site_pis.append(user_record) - # PL Admins need to marked as PI of the top authority record + # PL Admins need to marked as PI of the top authority + # record if 10 in person['role_ids'] and user_record not in top_auth_record.reg_pis: admins.append(user_record) except: self.logger.log_exc("PlImporter: failed to import person {} {}" .format(person['person_id'], person['email'])) - + # maintain the list of PIs for a given site # for the record, Jordan had proposed the following addition as a welcome hotfix to a previous version: - # site_pis = list(set(site_pis)) + # site_pis = list(set(site_pis)) # this was likely due to a bug in the above logic, that had to do with disabled persons # being improperly handled, and where the whole loop on persons # could be performed twice with the same person... - # so hopefully we do not need to eliminate duplicates explicitly here anymore + # so hopefully we do not need to eliminate duplicates explicitly + # here anymore site_record.reg_pis = list(set(site_pis)) global_dbsession.commit() @@ -415,8 +453,8 @@ class PlImporter: try: slice = slices_by_id[slice_id] except: - self.logger.warning ("PlImporter: cannot locate slice_id {} - ignored" - .format(slice_id)) + self.logger.warning("PlImporter: cannot locate slice_id {} - ignored" + .format(slice_id)) continue #slice_hrn = slicename_to_hrn(interface_hrn, slice['name']) slice_hrn = slice['hrn'] @@ -424,20 +462,22 @@ class PlImporter: self.logger.warning("Slice {} has no hrn - skipped" .format(slice['name'])) continue - slice_record = self.locate_by_type_hrn ('slice', slice_hrn) + slice_record = self.locate_by_type_hrn('slice', slice_hrn) if not slice_record: try: pkey = Keypair(create=True) urn = hrn_to_urn(slice_hrn, 'slice') - slice_gid = self.auth_hierarchy.create_gid(urn, create_uuid(), pkey) - slice_record = RegSlice (hrn=slice_hrn, gid=slice_gid, - pointer=slice['slice_id'], - authority=get_authority(slice_hrn)) + slice_gid = self.auth_hierarchy.create_gid( + urn, create_uuid(), pkey) + slice_record = RegSlice(hrn=slice_hrn, gid=slice_gid, + pointer=slice['slice_id'], + authority=get_authority(slice_hrn)) slice_record.just_created() global_dbsession.add(slice_record) global_dbsession.commit() - self.logger.info("PlImporter: imported slice: {}".format(slice_record)) - self.remember_record ( slice_record ) + self.logger.info( + "PlImporter: imported slice: {}".format(slice_record)) + self.remember_record(slice_record) except: self.logger.log_exc("PlImporter: failed to import slice {} ({})" .format(slice_hrn, slice['name'])) @@ -449,9 +489,11 @@ class PlImporter: pass # record current users affiliated with the slice slice_record.reg_researchers = \ - [ self.locate_by_type_pointer ('user', user_id) for user_id in slice['person_ids'] ] + [self.locate_by_type_pointer('user', user_id) for user_id in slice[ + 'person_ids']] # remove any weird value (looks like we can get 'None' here - slice_record.reg_researchers = [ x for x in slice_record.reg_researchers if x ] + slice_record.reg_researchers = [ + x for x in slice_record.reg_researchers if x] global_dbsession.commit() slice_record.stale = False @@ -462,24 +504,27 @@ class PlImporter: self.logger.info('PlImporter: set PL admins {} as PIs of {}' .format(admins, top_auth_record.hrn)) - ### remove stale records + # remove stale records # special records must be preserved - system_hrns = [interface_hrn, root_auth, interface_hrn + '.slicemanager'] - for record in all_records: - if record.hrn in system_hrns: + system_hrns = [interface_hrn, root_auth, + interface_hrn + '.slicemanager'] + for record in all_records: + if record.hrn in system_hrns: record.stale = False if record.peer_authority: record.stale = False if ".vini" in interface_hrn and interface_hrn.endswith('vini') and \ - record.hrn.endswith("internet2"): + record.hrn.endswith("internet2"): record.stale = False for record in all_records: - try: stale = record.stale - except: + try: + stale = record.stale + except: stale = True self.logger.warning("stale not found with {}".format(record)) if stale: - self.logger.info("PlImporter: deleting stale record: {}".format(record)) + self.logger.info( + "PlImporter: deleting stale record: {}".format(record)) global_dbsession.delete(record) global_dbsession.commit() diff --git a/sfa/iotlab/iotlabaggregate.py b/sfa/iotlab/iotlabaggregate.py index a0edf3a7..36b3291a 100644 --- a/sfa/iotlab/iotlabaggregate.py +++ b/sfa/iotlab/iotlabaggregate.py @@ -36,12 +36,12 @@ class IotLABAggregate(object): rspec_lease = Lease() rspec_lease['lease_id'] = lease['id'] iotlab_xrn = Xrn('.'.join([self.driver.root_auth, - Xrn.escape(node)]), + Xrn.escape(node)]), type='node') rspec_lease['component_id'] = iotlab_xrn.urn rspec_lease['start_time'] = str(lease['date']) # duration in minutes - duration = int(lease['duration'])/60 + duration = int(lease['duration']) / 60 rspec_lease['duration'] = duration rspec_lease['slice_id'] = lease['slice_id'] rspec_leases.append(rspec_lease) @@ -54,7 +54,7 @@ class IotLABAggregate(object): rspec_node['archi'] = node['archi'] rspec_node['radio'] = (node['archi'].split(':'))[1] iotlab_xrn = Xrn('.'.join([self.driver.root_auth, - Xrn.escape(node['network_address'])]), + Xrn.escape(node['network_address'])]), type='node') # rspec_node['boot_state'] = 'true' if node['state'] == 'Absent' or \ @@ -183,7 +183,8 @@ class IotLABAggregate(object): else: reserved_nodes[lease_id]['slice_id'] = \ hrn_to_urn(self.driver.root_auth + '.' + - reserved_nodes[lease_id]['owner']+"_slice", + reserved_nodes[lease_id][ + 'owner'] + "_slice", 'slice') leases.append(reserved_nodes[lease_id]) diff --git a/sfa/iotlab/iotlabdriver.py b/sfa/iotlab/iotlabdriver.py index f2c9a26b..fc7c7c27 100644 --- a/sfa/iotlab/iotlabdriver.py +++ b/sfa/iotlab/iotlabdriver.py @@ -143,7 +143,7 @@ class IotLabDriver(Driver): """ leases = rspec.version.get_leases() start_time = min([int(lease['start_time']) - for lease in leases]) + for lease in leases]) # ASAP jobs if start_time == 0: start_time = None @@ -152,13 +152,13 @@ class IotLabDriver(Driver): # schedule jobs else: end_time = max([int(lease['start_time']) + - int(lease['duration'])*60 + int(lease['duration']) * 60 for lease in leases]) from math import floor # minutes - duration = floor((end_time - start_time)/60) + duration = floor((end_time - start_time) / 60) nodes_list = [Xrn.unescape(Xrn(lease['component_id'].strip(), - type='node').get_leaf()) + type='node').get_leaf()) for lease in leases] # uniq hostnames nodes_list = list(set(nodes_list)) diff --git a/sfa/managers/aggregate_manager.py b/sfa/managers/aggregate_manager.py index 8b736444..a8c3af4b 100644 --- a/sfa/managers/aggregate_manager.py +++ b/sfa/managers/aggregate_manager.py @@ -10,8 +10,8 @@ from sfa.server.api_versions import ApiVersions class AggregateManager: - def __init__ (self, config): pass - + def __init__(self, config): pass + # essentially a union of the core version, the generic version (this code) and # whatever the driver needs to expose @@ -27,42 +27,52 @@ class AggregateManager: return { 'geni_request_rspec_versions': request_rspec_versions, 'geni_ad_rspec_versions': ad_rspec_versions, - } + } def get_rspec_version_string(self, rspec_version, options=None): - if options is None: options={} + if options is None: + options = {} version_string = "rspec_%s" % (rspec_version) - #panos adding the info option to the caching key (can be improved) + # panos adding the info option to the caching key (can be improved) if options.get('info'): - version_string = version_string + "_"+options.get('info', 'default') + version_string = version_string + \ + "_" + options.get('info', 'default') # Adding the list_leases option to the caching key if options.get('list_leases'): - version_string = version_string + "_"+options.get('list_leases', 'default') + version_string = version_string + "_" + \ + options.get('list_leases', 'default') # Adding geni_available to caching key if options.get('geni_available'): - version_string = version_string + "_" + str(options.get('geni_available')) + version_string = version_string + "_" + \ + str(options.get('geni_available')) return version_string def GetVersion(self, api, options): - xrn=Xrn(api.hrn, type='authority+am') + xrn = Xrn(api.hrn, type='authority+am') version = version_core() - cred_types = [{'geni_type': 'geni_sfa', 'geni_version': str(i)} for i in range(4)[-2:]] + cred_types = [{'geni_type': 'geni_sfa', + 'geni_version': str(i)} for i in range(4)[-2:]] geni_api_versions = ApiVersions().get_versions() - geni_api_versions['3'] = 'http://%s:%s' % (api.config.sfa_aggregate_host, api.config.sfa_aggregate_port) + geni_api_versions[ + '3'] = 'http://%s:%s' % (api.config.sfa_aggregate_host, api.config.sfa_aggregate_port) version_generic = { 'testbed': api.driver.testbed_name(), - 'interface':'aggregate', + 'interface': 'aggregate', 'sfa': 3, - 'hrn':xrn.get_hrn(), - 'urn':xrn.get_urn(), + 'hrn': xrn.get_hrn(), + 'urn': xrn.get_urn(), 'geni_api': 3, 'geni_api_versions': geni_api_versions, - 'geni_single_allocation': 0, # Accept operations that act on as subset of slivers in a given state. - 'geni_allocate': 'geni_many',# Multiple slivers can exist and be incrementally added, including those which connect or overlap in some way. + # Accept operations that act on as subset of slivers in a given + # state. + 'geni_single_allocation': 0, + # Multiple slivers can exist and be incrementally added, including + # those which connect or overlap in some way. + 'geni_allocate': 'geni_many', 'geni_credential_types': cred_types, 'geni_handles_speaksfor': True, # supports 'speaks for' credentials } @@ -71,14 +81,16 @@ class AggregateManager: testbed_version = api.driver.aggregate_version() version.update(testbed_version) return version - + def ListResources(self, api, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" # get the rspec's return format from options version_manager = VersionManager() - rspec_version = version_manager.get_version(options.get('geni_rspec_version')) + rspec_version = version_manager.get_version( + options.get('geni_rspec_version')) version_string = self.get_rspec_version_string(rspec_version, options) # look in cache first @@ -86,29 +98,32 @@ class AggregateManager: if cached_requested and api.driver.cache: rspec = api.driver.cache.get(version_string) if rspec: - logger.debug("%s.ListResources returning cached advertisement" % (api.driver.__module__)) + logger.debug("%s.ListResources returning cached advertisement" % ( + api.driver.__module__)) return rspec - - rspec = api.driver.list_resources (rspec_version, options) + + rspec = api.driver.list_resources(rspec_version, options) if api.driver.cache: - logger.debug("%s.ListResources stores advertisement in cache" % (api.driver.__module__)) - api.driver.cache.add(version_string, rspec) + logger.debug("%s.ListResources stores advertisement in cache" % ( + api.driver.__module__)) + api.driver.cache.add(version_string, rspec) return rspec - + def Describe(self, api, creds, urns, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" version_manager = VersionManager() - rspec_version = version_manager.get_version(options.get('geni_rspec_version')) + rspec_version = version_manager.get_version( + options.get('geni_rspec_version')) return api.driver.describe(urns, rspec_version, options) - - - def Status (self, api, urns, creds, options): + + def Status(self, api, urns, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return {} - return api.driver.status (urns, options=options) - + if Callids().already_handled(call_id): + return {} + return api.driver.status(urns, options=options) def Allocate(self, api, xrn, creds, rspec_string, expiration, options): """ @@ -116,16 +131,18 @@ class AggregateManager: to a slice with the named URN. """ call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" return api.driver.allocate(xrn, rspec_string, expiration, options) - + def Provision(self, api, xrns, creds, options): """ Create the sliver[s] (slice) at this aggregate. Verify HRN and initialize the slice record in PLC if necessary. """ call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" # make sure geni_rspec_version is specified in options if 'geni_rspec_version' not in options: @@ -133,32 +150,38 @@ class AggregateManager: raise SfaInvalidArgument(msg, 'geni_rspec_version') # make sure we support the requested rspec version version_manager = VersionManager() - rspec_version = version_manager.get_version(options['geni_rspec_version']) + rspec_version = version_manager.get_version( + options['geni_rspec_version']) if not rspec_version: raise InvalidRSpecVersion(options['geni_rspec_version']) - + return api.driver.provision(xrns, options) - + def Delete(self, api, xrns, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return True + if Callids().already_handled(call_id): + return True return api.driver.delete(xrns, options) def Renew(self, api, xrns, creds, expiration_time, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return True + if Callids().already_handled(call_id): + return True return api.driver.renew(xrns, expiration_time, options) def PerformOperationalAction(self, api, xrns, creds, action, options=None): - if options is None: options={} + if options is None: + options = {} call_id = options.get('call_id') - if Callids().already_handled(call_id): return True - return api.driver.perform_operational_action(xrns, action, options) + if Callids().already_handled(call_id): + return True + return api.driver.perform_operational_action(xrns, action, options) def Shutdown(self, api, xrn, creds, options=None): - if options is None: options={} + if options is None: + options = {} call_id = options.get('call_id') - if Callids().already_handled(call_id): return True - return api.driver.shutdown(xrn, options) - + if Callids().already_handled(call_id): + return True + return api.driver.shutdown(xrn, options) diff --git a/sfa/managers/aggregate_manager_max.py b/sfa/managers/aggregate_manager_max.py index 88df708f..67b115d1 100644 --- a/sfa/managers/aggregate_manager_max.py +++ b/sfa/managers/aggregate_manager_max.py @@ -17,13 +17,14 @@ from sfa.managers.aggregate_manager import AggregateManager from sfa.planetlab.plslices import PlSlices + class AggregateManagerMax (AggregateManager): - def __init__ (self, config): + def __init__(self, config): pass RSPEC_TMP_FILE_PREFIX = "/tmp/max_rspec" - + # execute shell command and return both exit code and text output def shell_execute(self, cmd, timeout): pipe = os.popen('{ ' + cmd + '; } 2>&1', 'r') @@ -33,13 +34,14 @@ class AggregateManagerMax (AggregateManager): line = pipe.read() text += line time.sleep(1) - timeout = timeout-1 + timeout = timeout - 1 code = pipe.close() - if code is None: code = 0 - if text[-1:] == '\n': text = text[:-1] + if code is None: + code = 0 + if text[-1:] == '\n': + text = text[:-1] return code, text - - + def call_am_apiclient(self, client_app, params, timeout): """ call AM API client with command like in the following example: @@ -49,56 +51,57 @@ class AggregateManagerMax (AggregateManager): ... params ... """ (client_path, am_url) = Config().get_max_aggrMgr_info() - sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + client_app + " ./repo " + am_url + " " + ' '.join(params) + sys_cmd = "cd " + client_path + "; java -classpath AggregateWS-client-api.jar:lib/* net.geni.aggregate.client.examples." + \ + client_app + " ./repo " + am_url + " " + ' '.join(params) ret = self.shell_execute(sys_cmd, timeout) logger.debug("shell_execute cmd: %s returns %s" % (sys_cmd, ret)) return ret - + # save request RSpec xml content to a tmp file def save_rspec_to_file(self, rspec): path = AggregateManagerMax.RSPEC_TMP_FILE_PREFIX + "_" + \ - time.strftime(SFATIME_FORMAT, time.gmtime(time.time())) +".xml" + time.strftime(SFATIME_FORMAT, time.gmtime(time.time())) + ".xml" file = open(path, "w") file.write(rspec) file.close() return path - + # get stripped down slice id/name plc.maxpl.xislice1 --> maxpl_xislice1 def get_plc_slice_id(self, cred, xrn): (hrn, type) = urn_to_hrn(xrn) slice_id = hrn.find(':') sep = '.' if hrn.find(':') != -1: - sep=':' + sep = ':' elif hrn.find('+') != -1: - sep='+' + sep = '+' else: - sep='.' + sep = '.' slice_id = hrn.split(sep)[-2] + '_' + hrn.split(sep)[-1] return slice_id - - # extract xml + + # extract xml def get_xml_by_tag(self, text, tag): - indx1 = text.find('<'+tag) - indx2 = text.find('/'+tag+'>') + indx1 = text.find('<' + tag) + indx2 = text.find('/' + tag + '>') xml = None - if indx1!=-1 and indx2>indx1: - xml = text[indx1:indx2+len(tag)+2] + if indx1 != -1 and indx2 > indx1: + xml = text[indx1:indx2 + len(tag) + 2] return xml - # formerly in aggregate_manager.py but got unused in there... + # formerly in aggregate_manager.py but got unused in there... def _get_registry_objects(self, slice_xrn, creds, users): """ - + """ hrn, _ = urn_to_hrn(slice_xrn) - + #hrn_auth = get_authority(hrn) - + # Build up objects that an SFA registry would return if SFA # could contact the slice's registry directly reg_objects = None - + if users: # dont allow special characters in the site login base #only_alphanumeric = re.compile('[^a-zA-Z0-9]+') @@ -108,23 +111,26 @@ class AggregateManagerMax (AggregateManager): reg_objects = {} site = {} site['site_id'] = 0 - site['name'] = 'geni.%s' % login_base + site['name'] = 'geni.%s' % login_base site['enabled'] = True site['max_slices'] = 100 - + # Note: # Is it okay if this login base is the same as one already at this myplc site? - # Do we need uniqueness? Should use hrn_auth instead of just the leaf perhaps? + # Do we need uniqueness? Should use hrn_auth instead of just the + # leaf perhaps? site['login_base'] = login_base site['abbreviated_name'] = login_base site['max_slivers'] = 1000 reg_objects['site'] = site - + slice = {} - - # get_expiration always returns a normalized datetime - no need to utcparse + + # get_expiration always returns a normalized datetime - no need to + # utcparse extime = Credential(string=creds[0]).get_expiration() - # If the expiration time is > 60 days from now, set the expiration time to 60 days from now + # If the expiration time is > 60 days from now, set the expiration + # time to 60 days from now if extime > datetime.datetime.utcnow() + datetime.timedelta(days=60): extime = datetime.datetime.utcnow() + datetime.timedelta(days=60) slice['expires'] = int(time.mktime(extime.timetuple())) @@ -134,7 +140,7 @@ class AggregateManagerMax (AggregateManager): slice['description'] = hrn slice['pointer'] = 0 reg_objects['slice_record'] = slice - + reg_objects['users'] = {} for user in users: user['key_ids'] = [] @@ -143,16 +149,16 @@ class AggregateManagerMax (AggregateManager): user['first_name'] = hrn user['last_name'] = hrn reg_objects['users'][user['email']] = user - + return reg_objects - + def prepare_slice(self, api, slice_xrn, creds, users): reg_objects = self._get_registry_objects(slice_xrn, creds, users) (hrn, type) = urn_to_hrn(slice_xrn) slices = PlSlices(self.driver) peer = slices.get_peer(hrn) sfa_peer = slices.get_sfa_peer(hrn) - slice_record=None + slice_record = None if users: slice_record = users[0].get('slice_record', {}) registry = api.registries[api.hrn] @@ -163,7 +169,7 @@ class AggregateManagerMax (AggregateManager): slice = slices.verify_slice(hrn, slice_record, peer, sfa_peer) # ensure person records exists persons = slices.verify_persons(hrn, slice, users, peer, sfa_peer) - + def parse_resources(self, text, slice_xrn): resources = [] urn = hrn_to_urn(slice_xrn, 'sliver') @@ -189,13 +195,14 @@ class AggregateManagerMax (AggregateManager): res['geni_status'] = 'configuring' resources.append(res) return resources - + def slice_status(self, api, slice_xrn, creds): urn = hrn_to_urn(slice_xrn, 'slice') result = {} top_level_status = 'unknown' slice_id = self.get_plc_slice_id(creds, urn) - (ret, output) = self.call_am_apiclient("QuerySliceNetworkClient", [slice_id,], 5) + (ret, output) = self.call_am_apiclient( + "QuerySliceNetworkClient", [slice_id, ], 5) # parse output into rspec XML if output.find("Unkown Rspec:") > 0: top_level_staus = 'failed' @@ -205,9 +212,9 @@ class AggregateManagerMax (AggregateManager): all_active = 0 if output.find("Status => FAILED") > 0: top_level_staus = 'failed' - elif ( output.find("Status => ACCEPTED") > 0 or output.find("Status => PENDING") > 0 - or output.find("Status => INSETUP") > 0 or output.find("Status => INCREATE") > 0 - ): + elif (output.find("Status => ACCEPTED") > 0 or output.find("Status => PENDING") > 0 + or output.find("Status => INSETUP") > 0 or output.find("Status => INCREATE") > 0 + ): top_level_status = 'configuring' else: top_level_status = 'ready' @@ -215,39 +222,46 @@ class AggregateManagerMax (AggregateManager): result['geni_urn'] = urn result['geni_status'] = top_level_status return result - + def create_slice(self, api, xrn, cred, rspec, users): indx1 = rspec.find("") if indx1 > -1 and indx2 > indx1: - rspec = rspec[indx1+len(""):indx2-1] + rspec = rspec[indx1 + len(""):indx2 - 1] rspec_path = self.save_rspec_to_file(rspec) self.prepare_slice(api, xrn, cred, users) slice_id = self.get_plc_slice_id(cred, xrn) - sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" +slice_id+ "/g\" " + rspec_path + ";sed -i \"s/:rspec=[^:'<\\\" ]*/:rspec=" +slice_id+ "/g\" " + rspec_path + sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" + slice_id + "/g\" " + \ + rspec_path + \ + ";sed -i \"s/:rspec=[^:'<\\\" ]*/:rspec=" + \ + slice_id + "/g\" " + rspec_path ret = self.shell_execute(sys_cmd, 1) - sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" + rspec_path + "/g\"" + sys_cmd = "sed -i \"s/rspec id=\\\"[^\\\"]*/rspec id=\\\"" + \ + rspec_path + "/g\"" ret = self.shell_execute(sys_cmd, 1) - (ret, output) = self.call_am_apiclient("CreateSliceNetworkClient", [rspec_path,], 3) + (ret, output) = self.call_am_apiclient( + "CreateSliceNetworkClient", [rspec_path, ], 3) # parse output ? rspec = " Done! " return True - + def delete_slice(self, api, xrn, cred): slice_id = self.get_plc_slice_id(cred, xrn) - (ret, output) = self.call_am_apiclient("DeleteSliceNetworkClient", [slice_id,], 3) + (ret, output) = self.call_am_apiclient( + "DeleteSliceNetworkClient", [slice_id, ], 3) # parse output ? return 1 - - + def get_rspec(self, api, cred, slice_urn): logger.debug("#### called max-get_rspec") - #geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1 + # geni_slice_urn: urn:publicid:IDN+plc:maxpl+slice+xi_rspec_test1 if slice_urn == None: - (ret, output) = self.call_am_apiclient("GetResourceTopology", ['all', '\"\"'], 5) + (ret, output) = self.call_am_apiclient( + "GetResourceTopology", ['all', '\"\"'], 5) else: slice_id = self.get_plc_slice_id(cred, slice_urn) - (ret, output) = self.call_am_apiclient("GetResourceTopology", ['all', slice_id,], 5) + (ret, output) = self.call_am_apiclient( + "GetResourceTopology", ['all', slice_id, ], 5) # parse output into rspec XML if output.find("No resouce found") > 0: rspec = " No resource found " @@ -256,62 +270,66 @@ class AggregateManagerMax (AggregateManager): logger.debug("#### computeResource %s" % comp_rspec) topo_rspec = self.get_xml_by_tag(output, 'topology') logger.debug("#### topology %s" % topo_rspec) - rspec = " " + rspec = " " if comp_rspec != None: rspec = rspec + self.get_xml_by_tag(output, 'computeResource') if topo_rspec != None: rspec = rspec + self.get_xml_by_tag(output, 'topology') rspec = rspec + " " return (rspec) - + def start_slice(self, api, xrn, cred): # service not supported return None - + def stop_slice(self, api, xrn, cred): # service not supported return None - + def reset_slices(self, api, xrn): # service not supported return None - - ### GENI AM API Methods - + + # GENI AM API Methods + def SliverStatus(self, api, slice_xrn, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return {} + if Callids().already_handled(call_id): + return {} return self.slice_status(api, slice_xrn, creds) - + def CreateSliver(self, api, slice_xrn, creds, rspec_string, users, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" - #TODO: create real CreateSliver response rspec + if Callids().already_handled(call_id): + return "" + # TODO: create real CreateSliver response rspec ret = self.create_slice(api, slice_xrn, creds, rspec_string, users) if ret: return self.get_rspec(api, creds, slice_xrn) else: return " Error! " - + def DeleteSliver(self, api, xrn, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" return self.delete_slice(api, xrn, creds) - + # no caching def ListResources(self, api, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" # version_string = "rspec_%s" % (rspec_version.get_version_name()) slice_urn = options.get('geni_slice_urn') return self.get_rspec(api, creds, slice_urn) - + def fetch_context(self, slice_hrn, user_hrn, contexts): """ Returns the request context required by sfatables. At some point, this mechanism should be changed to refer to "contexts", which is the information that sfatables is requesting. But for now, we just return the basic information needed in a dict. """ - base_context = {'sfa':{'user':{'hrn':user_hrn}}} + base_context = {'sfa': {'user': {'hrn': user_hrn}}} return base_context - diff --git a/sfa/managers/component_manager_default.py b/sfa/managers/component_manager_default.py index b79c7cc8..1c6e0a82 100644 --- a/sfa/managers/component_manager_default.py +++ b/sfa/managers/component_manager_default.py @@ -2,21 +2,26 @@ def start_slice(api, slicename): return + def stop_slice(api, slicename): return + def DeleteSliver(api, slicename, call_id): return + def reset_slice(api, slicename): - return - + return + + def ListSlices(api): return [] + def reboot(): return -def redeem_ticket(api, ticket_string): - return +def redeem_ticket(api, ticket_string): + return diff --git a/sfa/managers/component_manager_pl.py b/sfa/managers/component_manager_pl.py index ba9b7ebf..219d573f 100644 --- a/sfa/managers/component_manager_pl.py +++ b/sfa/managers/component_manager_pl.py @@ -6,9 +6,11 @@ from sfa.trust.sfaticket import SfaTicket from sfa.planetlab.plxrn import PlXrn + def GetVersion(api, options): - return version_core({'interface':'component', - 'testbed':'myplc'}) + return version_core({'interface': 'component', + 'testbed': 'myplc'}) + def init_server(): from sfa.server import sfa_component_setup @@ -21,51 +23,57 @@ def init_server(): sfa_component_setup.GetCredential(force=True) sfa_component_setup.get_trusted_certs() + def SliverStatus(api, slice_xrn, creds): result = {} result['geni_urn'] = slice_xrn result['geni_status'] = 'unknown' result['geni_resources'] = {} return result - + + def start_slice(api, xrn, creds): slicename = PlXrn(xrn, type='slice').pl_slicename() api.driver.nodemanager.Start(slicename) + def stop_slice(api, xrn, creds): slicename = PlXrn(xrn, type='slice').pl_slicename() api.driver.nodemanager.Stop(slicename) + def DeleteSliver(api, xrn, creds, call_id): slicename = PlXrn(xrn, type='slice').pl_slicename() api.driver.nodemanager.Destroy(slicename) + def reset_slice(api, xrn): slicename = PlXrn(xrn, type='slice').pl_slicename() if not api.sliver_exists(slicename): raise SliverDoesNotExist(slicename) api.driver.nodemanager.ReCreate(slicename) - + # xxx outdated - this should accept a credential & call_id + + def ListSlices(api): - # this returns a tuple, the data we want is at index 1 + # this returns a tuple, the data we want is at index 1 xids = api.driver.nodemanager.GetXIDs() - # unfortunately the data we want is given to us as + # unfortunately the data we want is given to us as # a string but we really want it as a dict # lets eval it slices = eval(xids[1]) return slices.keys() + def redeem_ticket(api, ticket_string): ticket = SfaTicket(string=ticket_string) ticket.decode() hrn = ticket.attributes['slivers'][0]['hrn'] - slicename = PlXrn (hrn).pl_slicename() + slicename = PlXrn(hrn).pl_slicename() if not api.sliver_exists(slicename): raise SliverDoesNotExist(slicename) # convert ticket to format nm is used to nm_ticket = xmlrpc_client.dumps((ticket.attributes,), methodresponse=True) api.driver.nodemanager.AdminTicket(nm_ticket) - - diff --git a/sfa/managers/driver.py b/sfa/managers/driver.py index 1985e0ee..1d4832fb 100644 --- a/sfa/managers/driver.py +++ b/sfa/managers/driver.py @@ -1,17 +1,18 @@ -# -# an attempt to document what a driver class should provide, +# +# an attempt to document what a driver class should provide, # and implement reasonable defaults # + class Driver: - - def __init__ (self, api): + + def __init__(self, api): self.api = api # this is the hrn attached to the running server self.hrn = api.config.SFA_INTERFACE_HRN ######################################## - ########## registry oriented + # registry oriented ######################################## # NOTE: the is_enabled method is deprecated @@ -22,21 +23,21 @@ class Driver: # after looking up the sfa db, we wish to be able to display # testbed-specific info as well # based on the principle that SFA should not rely on the testbed database - # to perform such a core operation (i.e. getting rights right) - # this is no longer in use when performing other SFA operations - def augment_records_with_testbed_info (self, sfa_records): + # to perform such a core operation (i.e. getting rights right) + # this is no longer in use when performing other SFA operations + def augment_records_with_testbed_info(self, sfa_records): return sfa_records # incoming record, as provided by the client to the Register API call # expected retcod 'pointer' # 'pointer' is typically an int db id, that makes sense in the testbed environment - # -1 if this feature is not relevant - def register (self, sfa_record, hrn, pub_key) : + # -1 if this feature is not relevant + def register(self, sfa_record, hrn, pub_key): return -1 # incoming record is the existing sfa_record # expected retcod boolean, error message logged if result is False - def remove (self, sfa_record): + def remove(self, sfa_record): return True # incoming are the sfa_record: @@ -44,7 +45,7 @@ class Driver: # (*) new_sfa_record is what was passed in the Update call # expected retcod boolean, error message logged if result is False # NOTE 1. about keys - # this is confusing because a user may have several ssh keys in + # this is confusing because a user may have several ssh keys in # the planetlab database, but we need to pick one to generate its cert # so as much as in principle we should be able to use new_sfa_record['keys'] # the manager code actually picks one (the first one), and it seems safer @@ -53,33 +54,34 @@ class Driver: # NOTE 2. about keys # when changing the ssh key through this method the gid gets changed too # should anything be passed back to the caller in this case ? - def update (self, old_sfa_record, new_sfa_record, hrn, new_key): + def update(self, old_sfa_record, new_sfa_record, hrn, new_key): return True # callack for register/update # this allows to capture changes in the relations between objects # the ids below are the ones found in the 'pointer' field # this can get typically called with - # 'slice' 'user' 'researcher' slice_id user_ids - # 'authority' 'user' 'pi' authority_id user_ids - def update_relation (self, subject_type, target_type, relation_name, subject_id, link_ids): + # 'slice' 'user' 'researcher' slice_id user_ids + # 'authority' 'user' 'pi' authority_id user_ids + def update_relation(self, subject_type, target_type, relation_name, subject_id, link_ids): pass ######################################## - ########## aggregate oriented + # aggregate oriented ######################################## - + # a name for identifying the kind of testbed - def testbed_name (self): return "undefined" + def testbed_name(self): return "undefined" # a dictionary that gets appended to the generic answer to GetVersion # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory - def aggregate_version (self): return {} + def aggregate_version(self): return {} # answer to ListResources # returns : advertisment rspec (xml string) - def list_resources (self, version=None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} return "dummy Driver.list_resources needs to be redefined" # the answer to Describe on a slice or a set of the slivers in a slice @@ -98,48 +100,60 @@ class Driver: # ... # ] #} - def describe (self, urns, version, options=None): - if options is None: options={} + def describe(self, urns, version, options=None): + if options is None: + options = {} return "dummy Driver.describe needs to be redefined" # the answer to Allocate on a given slicei or a set of the slivers in a slice # returns: same struct as for describe. - def allocate (self, urn, rspec_string, expiration, options=None): - if options is None: options={} + def allocate(self, urn, rspec_string, expiration, options=None): + if options is None: + options = {} return "dummy Driver.allocate needs to be redefined" # the answer to Provision on a given slice or a set of the slivers in a slice # returns: same struct as for describe. def provision(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} return "dummy Driver.provision needs to be redefined" # the answer to PerformOperationalAction on a given slice or a set of the slivers in a slice - # returns: struct containing "geni_slivers" list of the struct returned by describe. - def perform_operational_action (self, urns, action, options=None): - if options is None: options={} + # returns: struct containing "geni_slivers" list of the struct returned by + # describe. + def perform_operational_action(self, urns, action, options=None): + if options is None: + options = {} return "dummy Driver.perform_operational_action needs to be redefined" # the answer to Status on a given slice or a set of the slivers in a slice - # returns: struct containing "geni_urn" and "geni_slivers" list of the struct returned by describe. - def status (self, urns, options=None): - if options is None: options={} + # returns: struct containing "geni_urn" and "geni_slivers" list of the + # struct returned by describe. + def status(self, urns, options=None): + if options is None: + options = {} return "dummy Driver.status needs to be redefined" # the answer to Renew on a given slice or a set of the slivers in a slice - # returns: struct containing "geni_slivers" list of the struct returned by describe. - def renew (self, urns, expiration_time, options=None): - if options is None: options={} + # returns: struct containing "geni_slivers" list of the struct returned by + # describe. + def renew(self, urns, expiration_time, options=None): + if options is None: + options = {} return "dummy Driver.renew needs to be redefined" # the answer to Delete on a given slice - # returns: struct containing "geni_slivers" list of the struct returned by describe. + # returns: struct containing "geni_slivers" list of the struct returned by + # describe. def delete(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} return "dummy Driver.delete needs to be redefined" # the answer to Shutdown on a given slice # returns: boolean - def shutdown (self, xrn, options=None): - if options is None: options={} + def shutdown(self, xrn, options=None): + if options is None: + options = {} return False diff --git a/sfa/managers/eucalyptus/euca_rspec_validator.py b/sfa/managers/eucalyptus/euca_rspec_validator.py index 7ebaae5b..34fbf732 100755 --- a/sfa/managers/eucalyptus/euca_rspec_validator.py +++ b/sfa/managers/eucalyptus/euca_rspec_validator.py @@ -9,7 +9,8 @@ from lxml import etree as ET ## # The location of the RelaxNG schema. # -EUCALYPTUS_RSPEC_SCHEMA='eucalyptus.rng' +EUCALYPTUS_RSPEC_SCHEMA = 'eucalyptus.rng' + def main(): with open(sys.argv[1], 'r') as f: @@ -19,11 +20,10 @@ def main(): rspecXML = ET.XML(xml) if not rspecValidator(rspecXML): error = rspecValidator.error_log.last_error - message = '%s (line %s)' % (error.message, error.line) + message = '%s (line %s)' % (error.message, error.line) print(message) else: print('It is valid') if __name__ == "__main__": main() - diff --git a/sfa/managers/managerwrapper.py b/sfa/managers/managerwrapper.py index da8c98f6..77af9bf6 100644 --- a/sfa/managers/managerwrapper.py +++ b/sfa/managers/managerwrapper.py @@ -4,6 +4,8 @@ from sfa.util.faults import SfaNotImplemented, SfaAPIError from sfa.util.sfalogging import logger #################### + + class ManagerWrapper: """ This class acts as a wrapper around an SFA interface manager module, but @@ -15,23 +17,25 @@ class ManagerWrapper: is not implemented by a libarary and will generally be more helpful than the standard AttributeError """ + def __init__(self, manager, interface, config): - if isinstance (manager, ModuleType): + if isinstance(manager, ModuleType): # old-fashioned module implementation self.manager = manager - elif isinstance (manager, ClassType): - # create an instance; we don't pass the api in argument as it is passed + elif isinstance(manager, ClassType): + # create an instance; we don't pass the api in argument as it is passed # to the actual method calls anyway self.manager = manager(config) else: # that's what happens when there's something wrong with the db # or any bad stuff of that kind at startup time - logger.log_exc("Failed to create a manager, startup sequence is broken") - raise SfaAPIError("Argument to ManagerWrapper must be a module or class") + logger.log_exc( + "Failed to create a manager, startup sequence is broken") + raise SfaAPIError( + "Argument to ManagerWrapper must be a module or class") self.interface = interface - + def __getattr__(self, method): if not hasattr(self.manager, method): raise SfaNotImplemented(self.interface, method) return getattr(self.manager, method) - diff --git a/sfa/managers/registry_manager.py b/sfa/managers/registry_manager.py index 78d933d7..0646be7a 100644 --- a/sfa/managers/registry_manager.py +++ b/sfa/managers/registry_manager.py @@ -15,23 +15,23 @@ from sfa.util.sfalogging import logger from sfa.util.printable import printable -from sfa.trust.gid import GID +from sfa.trust.gid import GID from sfa.trust.credential import Credential from sfa.trust.certificate import Certificate, Keypair, convert_public_key from sfa.trust.gid import create_uuid from sfa.storage.model import make_record, RegRecord, RegAuthority, RegUser, RegSlice, RegKey, \ augment_with_sfa_builtins -### the types that we need to exclude from sqlobjects before being able to dump +# the types that we need to exclude from sqlobjects before being able to dump # them on the xmlrpc wire from sqlalchemy.orm.collections import InstrumentedList -### historical note -- april 2014 +# historical note -- april 2014 # the myslice chaps rightfully complained about the following discrepancy # they found that -# * read operations (resolve) expose stuff like e.g. +# * read operations (resolve) expose stuff like e.g. # 'reg-researchers', or 'reg-pis', but that -# * write operations (register, update) need e.g. +# * write operations (register, update) need e.g. # 'researcher' or 'pi' to be set - reg-* are just ignored # # the '_normalize_input' helper functions below aim at ironing this out @@ -43,63 +43,67 @@ from sqlalchemy.orm.collections import InstrumentedList # e.g. registry calls this 'reg-researchers' # while some drivers call this 'researcher' # we need to make sure that both keys appear and are the same + + def _normalize_input(record, reg_key, driver_key): # this looks right, use this for both keys if reg_key in record: # and issue a warning if they were both set and different # as we're overwriting some user data here if driver_key in record: - logger.warning ("normalize_input: incoming record has both values, using {}" - .format(reg_key)) + logger.warning("normalize_input: incoming record has both values, using {}" + .format(reg_key)) record[driver_key] = record[reg_key] # we only have one key set, duplicate for the other one elif driver_key in record: - logger.warning ("normalize_input: you should use '{}' instead of '{}'" - .format(reg_key, driver_key)) + logger.warning("normalize_input: you should use '{}' instead of '{}'" + .format(reg_key, driver_key)) record[reg_key] = record[driver_key] -def normalize_input_record (record): - _normalize_input (record, 'reg-researchers','researcher') - _normalize_input (record, 'reg-pis','pi') - _normalize_input (record, 'reg-keys','keys') + +def normalize_input_record(record): + _normalize_input(record, 'reg-researchers', 'researcher') + _normalize_input(record, 'reg-pis', 'pi') + _normalize_input(record, 'reg-keys', 'keys') # xxx the keys thing could use a little bit more attention: - # some parts of the code are using 'keys' while they should use 'reg-keys' + # some parts of the code are using 'keys' while they should use 'reg-keys' # but I run out of time for now if 'reg-keys' in record: record['keys'] = record['reg-keys'] return record + class RegistryManager: - def __init__ (self, config): + def __init__(self, config): logger.info("Creating RegistryManager[{}]".format(id(self))) # The GENI GetVersion call def GetVersion(self, api, options): - peers = dict ( [ (hrn,interface.get_url()) for (hrn,interface) in api.registries.iteritems() - if hrn != api.hrn]) - xrn=Xrn(api.hrn,type='authority') - return version_core({'interface':'registry', + peers = dict([(hrn, interface.get_url()) for (hrn, interface) in api.registries.iteritems() + if hrn != api.hrn]) + xrn = Xrn(api.hrn, type='authority') + return version_core({'interface': 'registry', 'sfa': 3, - 'hrn':xrn.get_hrn(), - 'urn':xrn.get_urn(), - 'peers':peers}) - + 'hrn': xrn.get_hrn(), + 'urn': xrn.get_urn(), + 'peers': peers}) + def GetCredential(self, api, xrn, input_type, caller_xrn=None): - # convert xrn to hrn + # convert xrn to hrn if input_type: hrn, _ = urn_to_hrn(xrn) type = input_type else: hrn, type = urn_to_hrn(xrn) - # Slivers don't have credentials but users should be able to + # Slivers don't have credentials but users should be able to # specify a sliver xrn and receive the slice's credential # However if input_type is specified - if type == 'sliver' or ( not input_type and '-' in Xrn(hrn).leaf): + if type == 'sliver' or (not input_type and '-' in Xrn(hrn).leaf): slice_xrn = api.driver.sliver_to_slice_xrn(hrn) - hrn = slice_xrn.hrn - + hrn = slice_xrn.hrn + # Is this a root or sub authority auth_hrn = api.auth.get_authority(hrn) if not auth_hrn or hrn == api.config.SFA_INTERFACE_HRN: @@ -108,7 +112,8 @@ class RegistryManager: # get record info dbsession = api.dbsession() - record = dbsession.query(RegRecord).filter_by(type=type, hrn=hrn).first() + record = dbsession.query(RegRecord).filter_by( + type=type, hrn=hrn).first() if not record: raise RecordNotFound("hrn={}, type={}".format(hrn, type)) @@ -121,15 +126,17 @@ class RegistryManager: else: caller_hrn, caller_type = urn_to_hrn(caller_xrn) if caller_type: - caller_record = dbsession.query(RegRecord).filter_by(hrn=caller_hrn,type=caller_type).first() + caller_record = dbsession.query(RegRecord).filter_by( + hrn=caller_hrn, type=caller_type).first() else: - caller_record = dbsession.query(RegRecord).filter_by(hrn=caller_hrn).first() + caller_record = dbsession.query( + RegRecord).filter_by(hrn=caller_hrn).first() if not caller_record: raise RecordNotFound( "Unable to associated caller (hrn={}, type={}) with credential for (hrn: {}, type: {})" .format(caller_hrn, caller_type, hrn, type)) caller_gid = GID(string=caller_record.gid) - + object_hrn = record.get_gid_object().get_hrn() # call the builtin authorization/credential generation engine rights = api.auth.determine_user_rights(caller_hrn, record) @@ -138,14 +145,15 @@ class RegistryManager: raise PermissionError("{} has no rights to {} ({})" .format(caller_hrn, object_hrn, xrn)) object_gid = GID(string=record.gid) - new_cred = Credential(subject = object_gid.get_subject()) + new_cred = Credential(subject=object_gid.get_subject()) new_cred.set_gid_caller(caller_gid) new_cred.set_gid_object(object_gid) - new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) - #new_cred.set_pubkey(object_gid.get_pubkey()) + new_cred.set_issuer_keys( + auth_info.get_privkey_filename(), auth_info.get_gid_filename()) + # new_cred.set_pubkey(object_gid.get_pubkey()) new_cred.set_privileges(rights) new_cred.get_privileges().delegate_all_privileges(True) - if hasattr(record,'expires'): + if hasattr(record, 'expires'): date = utcparse(record.expires) expires = datetime_to_epoch(date) new_cred.set_expiration(int(expires)) @@ -154,25 +162,25 @@ class RegistryManager: #new_cred.set_parent(api.auth.hierarchy.get_auth_cred(auth_hrn, kind=auth_kind)) new_cred.encode() new_cred.sign() - + return new_cred.save_to_string(save_parents=True) - - - # the default for full, which means 'dig into the testbed as well', should be false + + # the default for full, which means 'dig into the testbed as well', should + # be false def Resolve(self, api, xrns, type=None, details=False): - + dbsession = api.dbsession() if not isinstance(xrns, list): # try to infer type if not set and we get a single input if not type: type = Xrn(xrns).get_type() xrns = [xrns] - hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] + hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] # load all known registry names into a prefix tree and attempt to find # the longest matching prefix # create a dict where key is a registry hrn and its value is a list - # of hrns at that registry (determined by the known prefix tree). + # of hrns at that registry (determined by the known prefix tree). xrn_dict = {} registries = api.registries tree = prefixTree() @@ -183,14 +191,14 @@ class RegistryManager: if registry_hrn not in xrn_dict: xrn_dict[registry_hrn] = [] xrn_dict[registry_hrn].append(xrn) - - records = [] + + records = [] for registry_hrn in xrn_dict: # skip the hrn without a registry hrn - # XX should we let the user know the authority is unknown? + # XX should we let the user know the authority is unknown? if not registry_hrn: continue - + # if the best match (longest matching hrn) is not the local registry, # forward the request xrns = xrn_dict[registry_hrn] @@ -204,53 +212,60 @@ class RegistryManager: # pass foreign records as-is # previous code used to read # records.extend([SfaRecord(dict=record).as_dict() for record in peer_records]) - # not sure why the records coming through xmlrpc had to be processed at all + # not sure why the records coming through xmlrpc had to be + # processed at all records.extend(peer_records) - + # try resolving the remaining unfound records at the local registry - local_hrns = list ( set(hrns).difference([record['hrn'] for record in records]) ) - # - local_records = dbsession.query(RegRecord).filter(RegRecord.hrn.in_(local_hrns)) + local_hrns = list(set(hrns).difference( + [record['hrn'] for record in records])) + # + local_records = dbsession.query(RegRecord).filter( + RegRecord.hrn.in_(local_hrns)) if type: local_records = local_records.filter_by(type=type) local_records = local_records.all() - + for local_record in local_records: augment_with_sfa_builtins(local_record) logger.info("Resolve, (details={}, type={}) local_records={} " .format(details, type, local_records)) - local_dicts = [ record.__dict__ for record in local_records ] - + local_dicts = [record.__dict__ for record in local_records] + if details: - # in details mode we get as much info as we can, which involves contacting the + # in details mode we get as much info as we can, which involves contacting the # testbed for getting implementation details about the record api.driver.augment_records_with_testbed_info(local_dicts) # also we fill the 'url' field for known authorities # used to be in the driver code, sounds like a poorman thing though - def solve_neighbour_url (record): - if not record.type.startswith('authority'): return + + def solve_neighbour_url(record): + if not record.type.startswith('authority'): + return hrn = record.hrn - for neighbour_dict in [ api.aggregates, api.registries ]: + for neighbour_dict in [api.aggregates, api.registries]: if hrn in neighbour_dict: - record.url=neighbour_dict[hrn].get_url() - return + record.url = neighbour_dict[hrn].get_url() + return for record in local_records: - solve_neighbour_url (record) - + solve_neighbour_url(record) + # convert local record objects to dicts for xmlrpc # xxx somehow here calling dict(record) issues a weird error # however record.todict() seems to work fine # records.extend( [ dict(record) for record in local_records ] ) - records.extend( [ record.record_to_dict(exclude_types=(InstrumentedList,)) for record in local_records ] ) + records.extend([record.record_to_dict(exclude_types=( + InstrumentedList,)) for record in local_records]) if not records: raise RecordNotFound(str(hrns)) - + return records - - def List (self, api, xrn, origin_hrn=None, options=None): - if options is None: options={} + + def List(self, api, xrn, origin_hrn=None, options=None): + if options is None: + options = {} dbsession = api.dbsession() # load all know registry names into a prefix tree and attempt to find # the longest matching prefix @@ -260,13 +275,13 @@ class RegistryManager: tree = prefixTree() tree.load(registry_hrns) registry_hrn = tree.best_match(hrn) - - #if there was no match then this record belongs to an unknow registry + + # if there was no match then this record belongs to an unknow registry if not registry_hrn: raise MissingAuthority(xrn) # if the best match (longest matching hrn) is not the local registry, # forward the request - record_dicts = [] + record_dicts = [] if registry_hrn != api.hrn: credential = api.getCredential() interface = api.registries[registry_hrn] @@ -275,7 +290,7 @@ class RegistryManager: # same as above, no need to process what comes from through xmlrpc # pass foreign records as-is record_dicts = record_list - + # if we still have not found the record yet, try the local registry # logger.debug("before trying local records, {} foreign records".format(len(record_dicts))) if not record_dicts: @@ -289,10 +304,12 @@ class RegistryManager: if not api.auth.hierarchy.auth_exists(hrn): raise MissingAuthority(hrn) if recursive: - records = dbsession.query(RegRecord).filter(RegRecord.hrn.startswith(hrn)).all() + records = dbsession.query(RegRecord).filter( + RegRecord.hrn.startswith(hrn)).all() # logger.debug("recursive mode, found {} local records".format(len(records))) else: - records = dbsession.query(RegRecord).filter_by(authority=hrn).all() + records = dbsession.query( + RegRecord).filter_by(authority=hrn).all() # logger.debug("non recursive mode, found {} local records".format(len(records))) # so that sfi list can show more than plain names... for record in records: @@ -302,10 +319,11 @@ class RegistryManager: # was first observed with authorities' 'name' column # that would be missing from result as received by client augment_with_sfa_builtins(record) - record_dicts = [ record.record_to_dict(exclude_types=(InstrumentedList,)) for record in records ] - + record_dicts = [record.record_to_dict( + exclude_types=(InstrumentedList,)) for record in records] + return record_dicts - + def CreateGid(self, api, xrn, cert): # get the authority authority = Xrn(xrn=xrn).get_authority_hrn() @@ -314,7 +332,7 @@ class RegistryManager: pkey = Keypair(create=True) else: certificate = Certificate(string=cert) - pkey = certificate.get_pubkey() + pkey = certificate.get_pubkey() # Add the email of the user to SubjectAltName in the GID email = None @@ -322,56 +340,64 @@ class RegistryManager: dbsession = api.dbsession() record = dbsession.query(RegUser).filter_by(hrn=hrn).first() if record: - email = getattr(record,'email',None) - gid = api.auth.hierarchy.create_gid(xrn, create_uuid(), pkey, email=email) + email = getattr(record, 'email', None) + gid = api.auth.hierarchy.create_gid( + xrn, create_uuid(), pkey, email=email) return gid.save_to_string(save_parents=True) - + #################### - # utility for handling relationships among the SFA objects - + # utility for handling relationships among the SFA objects + # subject_record describes the subject of the relationships # ref_record contains the target values for the various relationships we need to manage # (to begin with, this is just the slice x person (researcher) and authority x person (pi) relationships) - def update_driver_relations (self, api, subject_obj, ref_obj): - type=subject_obj.type - #for (k,v) in subject_obj.__dict__.items(): print k,'=',v - if type=='slice' and hasattr(ref_obj,'researcher'): - self.update_driver_relation(api, subject_obj, ref_obj.researcher, 'user', 'researcher') - elif type=='authority' and hasattr(ref_obj,'pi'): - self.update_driver_relation(api, subject_obj,ref_obj.pi, 'user', 'pi') - + def update_driver_relations(self, api, subject_obj, ref_obj): + type = subject_obj.type + # for (k,v) in subject_obj.__dict__.items(): print k,'=',v + if type == 'slice' and hasattr(ref_obj, 'researcher'): + self.update_driver_relation( + api, subject_obj, ref_obj.researcher, 'user', 'researcher') + elif type == 'authority' and hasattr(ref_obj, 'pi'): + self.update_driver_relation( + api, subject_obj, ref_obj.pi, 'user', 'pi') + # field_key is the name of one field in the record, typically 'researcher' for a 'slice' record # hrns is the list of hrns that should be linked to the subject from now on # target_type would be e.g. 'user' in the 'slice' x 'researcher' example - def update_driver_relation (self, api, record_obj, hrns, target_type, relation_name): + def update_driver_relation(self, api, record_obj, hrns, target_type, relation_name): dbsession = api.dbsession() # locate the linked objects in our db - subject_type=record_obj.type - subject_id=record_obj.pointer + subject_type = record_obj.type + subject_id = record_obj.pointer # get the 'pointer' field of all matching records - link_id_tuples = dbsession.query(RegRecord.pointer).filter_by(type=target_type).filter(RegRecord.hrn.in_(hrns)).all() + link_id_tuples = dbsession.query(RegRecord.pointer).filter_by( + type=target_type).filter(RegRecord.hrn.in_(hrns)).all() # sqlalchemy returns named tuples for columns - link_ids = [ tuple.pointer for tuple in link_id_tuples ] - api.driver.update_relation (subject_type, target_type, relation_name, subject_id, link_ids) + link_ids = [tuple.pointer for tuple in link_id_tuples] + api.driver.update_relation( + subject_type, target_type, relation_name, subject_id, link_ids) def Register(self, api, record_dict): - - logger.debug("Register: entering with record_dict={}".format(printable(record_dict))) - normalize_input_record (record_dict) - logger.debug("Register: normalized record_dict={}".format(printable(record_dict))) + + logger.debug("Register: entering with record_dict={}".format( + printable(record_dict))) + normalize_input_record(record_dict) + logger.debug("Register: normalized record_dict={}".format( + printable(record_dict))) dbsession = api.dbsession() hrn, type = record_dict['hrn'], record_dict['type'] - urn = hrn_to_urn(hrn,type) + urn = hrn_to_urn(hrn, type) # validate the type if type not in ['authority', 'slice', 'node', 'user']: - raise UnknownSfaType(type) - + raise UnknownSfaType(type) + # check if record_dict already exists - existing_records = dbsession.query(RegRecord).filter_by(type=type, hrn=hrn).all() + existing_records = dbsession.query( + RegRecord).filter_by(type=type, hrn=hrn).all() if existing_records: raise ExistingRecord(hrn) - + assert ('type' in record_dict) # returns the right type of RegRecord according to type in record record = make_record(dict=record_dict) @@ -383,75 +409,83 @@ class RegistryManager: if not record.gid: uuid = create_uuid() pkey = Keypair(create=True) - pub_key=getattr(record,'reg-keys',None) + pub_key = getattr(record, 'reg-keys', None) if pub_key is not None: # use only first key in record - if pub_key and isinstance(pub_key, list): pub_key = pub_key[0] + if pub_key and isinstance(pub_key, list): + pub_key = pub_key[0] pkey = convert_public_key(pub_key) - - email = getattr(record,'email',None) - gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey, email = email) + + email = getattr(record, 'email', None) + gid_object = api.auth.hierarchy.create_gid( + urn, uuid, pkey, email=email) gid = gid_object.save_to_string(save_parents=True) record.gid = gid - - if isinstance (record, RegAuthority): + + if isinstance(record, RegAuthority): # update the tree if not api.auth.hierarchy.auth_exists(hrn): - api.auth.hierarchy.create_auth(hrn_to_urn(hrn,'authority')) - + api.auth.hierarchy.create_auth(hrn_to_urn(hrn, 'authority')) + # get the GID from the newly created authority auth_info = api.auth.get_auth_info(hrn) gid = auth_info.get_gid_object() - record.gid=gid.save_to_string(save_parents=True) + record.gid = gid.save_to_string(save_parents=True) # locate objects for relationships - pi_hrns = getattr(record,'reg-pis',None) - if pi_hrns is not None: record.update_pis (pi_hrns, dbsession) - - elif isinstance (record, RegSlice): - researcher_hrns = getattr(record,'reg-researchers',None) - if researcher_hrns is not None: record.update_researchers (researcher_hrns, dbsession) - - elif isinstance (record, RegUser): + pi_hrns = getattr(record, 'reg-pis', None) + if pi_hrns is not None: + record.update_pis(pi_hrns, dbsession) + + elif isinstance(record, RegSlice): + researcher_hrns = getattr(record, 'reg-researchers', None) + if researcher_hrns is not None: + record.update_researchers(researcher_hrns, dbsession) + + elif isinstance(record, RegUser): # create RegKey objects for incoming keys - if hasattr(record,'reg-keys'): + if hasattr(record, 'reg-keys'): keys = getattr(record, 'reg-keys') # some people send the key as a string instead of a list of strings # note for python2/3 : no need to consider unicode in a key if isinstance(keys, str): keys = [keys] - logger.debug("creating {} keys for user {}".format(len(keys), record.hrn)) - record.reg_keys = [ RegKey (key) for key in keys ] - + logger.debug("creating {} keys for user {}".format( + len(keys), record.hrn)) + record.reg_keys = [RegKey(key) for key in keys] + # update testbed-specific data if needed - pointer = api.driver.register (record.__dict__, hrn, pub_key) + pointer = api.driver.register(record.__dict__, hrn, pub_key) - record.pointer=pointer + record.pointer = pointer dbsession.add(record) dbsession.commit() - + # update membership for researchers, pis, owners, operators - self.update_driver_relations (api, record, record) - + self.update_driver_relations(api, record, record) + return record.get_gid_object().save_to_string(save_parents=True) - + def Update(self, api, record_dict): - logger.debug("Update: entering with record_dict={}".format(printable(record_dict))) - normalize_input_record (record_dict) - logger.debug("Update: normalized record_dict={}".format(printable(record_dict))) + logger.debug("Update: entering with record_dict={}".format( + printable(record_dict))) + normalize_input_record(record_dict) + logger.debug("Update: normalized record_dict={}".format( + printable(record_dict))) dbsession = api.dbsession() assert ('type' in record_dict) new_record = make_record(dict=record_dict) (type, hrn) = (new_record.type, new_record.hrn) - + # make sure the record exists - record = dbsession.query(RegRecord).filter_by(type=type, hrn=hrn).first() + record = dbsession.query(RegRecord).filter_by( + type=type, hrn=hrn).first() if not record: raise RecordNotFound("hrn={}, type={}".format(hrn, type)) record.just_updated() - + # Use the pointer from the existing record, not the one that the user # gave us. This prevents the user from inserting a forged pointer pointer = record.pointer @@ -461,7 +495,7 @@ class RegistryManager: if type == 'user': if getattr(new_record, 'keys', None): new_key = new_record.keys - if isinstance (new_key, list): + if isinstance(new_key, list): new_key = new_key[0] # take new_key into account @@ -469,23 +503,25 @@ class RegistryManager: # update the openssl key and gid pkey = convert_public_key(new_key) uuid = create_uuid() - urn = hrn_to_urn(hrn,type) + urn = hrn_to_urn(hrn, type) email = getattr(new_record, 'email', None) if email is None: email = getattr(record, 'email', None) - gid_object = api.auth.hierarchy.create_gid(urn, uuid, pkey, email = email) + gid_object = api.auth.hierarchy.create_gid( + urn, uuid, pkey, email=email) gid = gid_object.save_to_string(save_parents=True) - + # xxx should do side effects from new_record to record # not too sure how to do that - # not too big a deal with planetlab as the driver is authoritative, but... + # not too big a deal with planetlab as the driver is authoritative, + # but... # update native relations if isinstance(record, RegSlice): researcher_hrns = getattr(new_record, 'reg-researchers', None) if researcher_hrns is not None: - record.update_researchers (researcher_hrns, dbsession) + record.update_researchers(researcher_hrns, dbsession) elif isinstance(record, RegAuthority): pi_hrns = getattr(new_record, 'reg-pis', None) @@ -499,10 +535,10 @@ class RegistryManager: email = getattr(new_record, 'email', None) if email is not None: record.email = email - + # update the PLC information that was specified with the record # xxx mystery -- see also the bottom of model.py, - # oddly enough, without this useless statement, + # oddly enough, without this useless statement, # record.__dict__ as received by the driver seems to be off # anyway the driver should receive an object # (and then extract __dict__ itself if needed) @@ -513,19 +549,20 @@ class RegistryManager: # but that would need to be confirmed by more extensive tests new_key_pointer = -1 try: - (pointer, new_key_pointer) = api.driver.update (record.__dict__, new_record.__dict__, hrn, new_key) + (pointer, new_key_pointer) = api.driver.update( + record.__dict__, new_record.__dict__, hrn, new_key) except: - pass + pass if new_key and new_key_pointer: - record.reg_keys = [ RegKey(new_key, new_key_pointer) ] + record.reg_keys = [RegKey(new_key, new_key_pointer)] record.gid = gid dbsession.commit() # update membership for researchers, pis, owners, operators self.update_driver_relations(api, record, new_record) - - return 1 - + + return 1 + # expecting an Xrn instance def Remove(self, api, xrn, origin_hrn=None): dbsession = api.dbsession() @@ -534,43 +571,47 @@ class RegistryManager: request = dbsession.query(RegRecord).filter_by(hrn=hrn) if type and type not in ['all', '*']: request = request.filter_by(type=type) - + record = request.first() if not record: msg = "Could not find hrn {}".format(hrn) - if type: msg += " type={}".format(type) + if type: + msg += " type={}".format(type) raise RecordNotFound(msg) type = record.type - if type not in ['slice', 'user', 'node', 'authority'] : + if type not in ['slice', 'user', 'node', 'authority']: raise UnknownSfaType(type) credential = api.getCredential() registries = api.registries - + # Try to remove the object from the PLCDB of federated agg. - # This is attempted before removing the object from the local agg's PLCDB and sfa table + # This is attempted before removing the object from the local agg's + # PLCDB and sfa table if hrn.startswith(api.hrn) and type in ['user', 'slice', 'authority']: for registry in registries: if registry not in [api.hrn]: try: - result=registries[registry].remove_peer_object(credential, record, origin_hrn) + result = registries[registry].remove_peer_object( + credential, record, origin_hrn) except: pass # call testbed callback first - # IIUC this is done on the local testbed TOO because of the refreshpeer link + # IIUC this is done on the local testbed TOO because of the refreshpeer + # link if not api.driver.remove(record.__dict__): logger.warning("driver.remove failed") # delete from sfa db dbsession.delete(record) dbsession.commit() - + return 1 # This is a PLC-specific thing, won't work with other platforms - def get_key_from_incoming_ip (self, api): + def get_key_from_incoming_ip(self, api): dbsession = api.dbsession() # verify that the callers's ip address exist in the db and is an interface # for a node in the db @@ -578,16 +619,19 @@ class RegistryManager: interfaces = api.driver.shell.GetInterfaces({'ip': ip}, ['node_id']) if not interfaces: raise NonExistingRecord("no such ip {}".format(ip)) - nodes = api.driver.shell.GetNodes([interfaces[0]['node_id']], ['node_id', 'hostname']) + nodes = api.driver.shell.GetNodes( + [interfaces[0]['node_id']], ['node_id', 'hostname']) if not nodes: raise NonExistingRecord("no such node using ip {}".format(ip)) node = nodes[0] - + # look up the sfa record - record = dbsession.query(RegRecord).filter_by(type='node', pointer=node['node_id']).first() + record = dbsession.query(RegRecord).filter_by( + type='node', pointer=node['node_id']).first() if not record: - raise RecordNotFound("node with pointer {}".format(node['node_id'])) - + raise RecordNotFound( + "node with pointer {}".format(node['node_id'])) + # generate a new keypair and gid uuid = create_uuid() pkey = Keypair(create=True) @@ -600,21 +644,21 @@ class RegistryManager: # update the record dbsession.commit() - + # attempt the scp the key # and gid onto the node # this will only work for planetlab based components - (kfd, key_filename) = tempfile.mkstemp() - (gfd, gid_filename) = tempfile.mkstemp() + (kfd, key_filename) = tempfile.mkstemp() + (gfd, gid_filename) = tempfile.mkstemp() pkey.save_to_file(key_filename) gid_object.save_to_file(gid_filename, save_parents=True) host = node['hostname'] - key_dest="/etc/sfa/node.key" - gid_dest="/etc/sfa/node.gid" - scp = "/usr/bin/scp" + key_dest = "/etc/sfa/node.key" + gid_dest = "/etc/sfa/node.gid" + scp = "/usr/bin/scp" #identity = "/etc/planetlab/root_ssh_key.rsa" identity = "/etc/sfa/root_ssh_key" - scp_options = " -i {identity} ".format(**locals()) + scp_options = " -i {identity} ".format(**locals()) scp_options += "-o StrictHostKeyChecking=no " scp_key_command = "{scp} {scp_options} {key_filename} root@{host}:{key_dest}"\ .format(**locals()) @@ -622,7 +666,7 @@ class RegistryManager: .format(**locals()) all_commands = [scp_key_command, scp_gid_command] - + for command in all_commands: (status, output) = commands.getstatusoutput(command) if status: @@ -631,4 +675,4 @@ class RegistryManager: for filename in [key_filename, gid_filename]: os.unlink(filename) - return 1 + return 1 diff --git a/sfa/managers/slice_manager.py b/sfa/managers/slice_manager.py index 8252ec3f..011a22ee 100644 --- a/sfa/managers/slice_manager.py +++ b/sfa/managers/slice_manager.py @@ -17,56 +17,64 @@ from sfa.client.multiclient import MultiClient from sfa.rspecs.rspec_converter import RSpecConverter from sfa.rspecs.version_manager import VersionManager -from sfa.rspecs.rspec import RSpec +from sfa.rspecs.rspec import RSpec from sfa.client.client_helper import sfa_to_pg_users_arg from sfa.client.return_value import ReturnValue + class SliceManager: - # the cache instance is a class member so it survives across incoming requests + # the cache instance is a class member so it survives across incoming + # requests cache = None - def __init__ (self, config): - self.cache=None + def __init__(self, config): + self.cache = None if config.SFA_SM_CACHING: if SliceManager.cache is None: SliceManager.cache = Cache() self.cache = SliceManager.cache - + def GetVersion(self, api, options): # peers explicitly in aggregates.xml - peers =dict ([ (peername,interface.get_url()) for (peername,interface) in api.aggregates.iteritems() - if peername != api.hrn]) + peers = dict([(peername, interface.get_url()) for (peername, interface) in api.aggregates.iteritems() + if peername != api.hrn]) version_manager = VersionManager() ad_rspec_versions = [] request_rspec_versions = [] - cred_types = [{'geni_type': 'geni_sfa', 'geni_version': str(i)} for i in range(4)[-2:]] + cred_types = [{'geni_type': 'geni_sfa', + 'geni_version': str(i)} for i in range(4)[-2:]] for rspec_version in version_manager.versions: if rspec_version.content_type in ['*', 'ad']: ad_rspec_versions.append(rspec_version.to_dict()) if rspec_version.content_type in ['*', 'request']: request_rspec_versions.append(rspec_version.to_dict()) - xrn=Xrn(api.hrn, 'authority+sm') + xrn = Xrn(api.hrn, 'authority+sm') version_more = { - 'interface':'slicemgr', + 'interface': 'slicemgr', 'sfa': 2, 'geni_api': 3, 'geni_api_versions': {'3': 'http://%s:%s' % (api.config.SFA_SM_HOST, api.config.SFA_SM_PORT)}, - 'hrn' : xrn.get_hrn(), - 'urn' : xrn.get_urn(), + 'hrn': xrn.get_hrn(), + 'urn': xrn.get_urn(), 'peers': peers, - 'geni_single_allocation': 0, # Accept operations that act on as subset of slivers in a given state. - 'geni_allocate': 'geni_many',# Multiple slivers can exist and be incrementally added, including those which connect or overlap in some way. + # Accept operations that act on as subset of slivers in a given + # state. + 'geni_single_allocation': 0, + # Multiple slivers can exist and be incrementally added, including + # those which connect or overlap in some way. + 'geni_allocate': 'geni_many', 'geni_credential_types': cred_types, - } - sm_version=version_core(version_more) + } + sm_version = version_core(version_more) # local aggregate if present needs to have localhost resolved if api.hrn in api.aggregates: - local_am_url=api.aggregates[api.hrn].get_url() - sm_version['peers'][api.hrn]=local_am_url.replace('localhost',sm_version['hostname']) + local_am_url = api.aggregates[api.hrn].get_url() + sm_version['peers'][api.hrn] = local_am_url.replace( + 'localhost', sm_version['hostname']) return sm_version - + def drop_slicemgr_stats(self, rspec): try: stats_elements = rspec.xml.xpath('//statistics') @@ -74,20 +82,22 @@ class SliceManager: node.getparent().remove(node) except Exception as e: logger.warn("drop_slicemgr_stats failed: %s " % (str(e))) - + def add_slicemgr_stat(self, rspec, callname, aggname, elapsed, status, exc_info=None): try: stats_tags = rspec.xml.xpath('//statistics[@call="%s"]' % callname) if stats_tags: stats_tag = stats_tags[0] else: - stats_tag = rspec.xml.root.add_element("statistics", call=callname) + stats_tag = rspec.xml.root.add_element( + "statistics", call=callname) - stat_tag = stats_tag.add_element("aggregate", name=str(aggname), + stat_tag = stats_tag.add_element("aggregate", name=str(aggname), elapsed=str(elapsed), status=str(status)) if exc_info: - exc_tag = stat_tag.add_element("exc_info", name=str(exc_info[1])) + exc_tag = stat_tag.add_element( + "exc_info", name=str(exc_info[1])) # formats the traceback as one big text blob #exc_tag.text = "\n".join(traceback.format_exception(exc_info[0], exc_info[1], exc_info[2])) @@ -95,15 +105,17 @@ class SliceManager: # formats the traceback as a set of xml elements tb = traceback.extract_tb(exc_info[2]) for item in tb: - exc_frame = exc_tag.add_element("tb_frame", filename=str(item[0]), + exc_frame = exc_tag.add_element("tb_frame", filename=str(item[0]), line=str(item[1]), func=str(item[2]), code=str(item[3])) except Exception as e: - logger.warn("add_slicemgr_stat failed on %s: %s" %(aggname, str(e))) - + logger.warn("add_slicemgr_stat failed on %s: %s" % + (aggname, str(e))) + def ListResources(self, api, creds, options): - call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + call_id = options.get('call_id') + if Callids().already_handled(call_id): + return "" version_manager = VersionManager() @@ -113,35 +125,38 @@ class SliceManager: try: version = api.get_cached_server_version(server) # force ProtoGENI aggregates to give us a v2 RSpec - forward_options['geni_rspec_version'] = options.get('geni_rspec_version') + forward_options['geni_rspec_version'] = options.get( + 'geni_rspec_version') result = server.ListResources(credential, forward_options) - return {"aggregate": aggregate, "result": result, "elapsed": time.time()-tStart, "status": "success"} + return {"aggregate": aggregate, "result": result, "elapsed": time.time() - tStart, "status": "success"} except Exception as e: - api.logger.log_exc("ListResources failed at %s" %(server.url)) - return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception", "exc_info": sys.exc_info()} - + api.logger.log_exc("ListResources failed at %s" % (server.url)) + return {"aggregate": aggregate, "elapsed": time.time() - tStart, "status": "exception", "exc_info": sys.exc_info()} + # get slice's hrn from options xrn = options.get('geni_slice_urn', '') (hrn, type) = urn_to_hrn(xrn) if 'geni_compressed' in options: del(options['geni_compressed']) - + # get the rspec's return format from options - rspec_version = version_manager.get_version(options.get('geni_rspec_version')) + rspec_version = version_manager.get_version( + options.get('geni_rspec_version')) version_string = "rspec_%s" % (rspec_version) - + # look in cache first cached_requested = options.get('cached', True) if not xrn and self.cache and cached_requested: - rspec = self.cache.get(version_string) + rspec = self.cache.get(version_string) if rspec: - api.logger.debug("SliceManager.ListResources returns cached advertisement") + api.logger.debug( + "SliceManager.ListResources returns cached advertisement") return rspec - + # get the callers hrn valid_cred = api.auth.checkCredentials(creds, 'listnodes', hrn)[0] caller_hrn = Credential(cred=valid_cred).get_gid_caller().get_hrn() - + # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: @@ -152,61 +167,66 @@ class SliceManager: # unless the caller is the aggregate's SM if caller_hrn == aggregate and aggregate != api.hrn: continue - + # get the rspec from the aggregate interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) multiclient.run(_ListResources, aggregate, server, [cred], options) - - + results = multiclient.get_results() - rspec_version = version_manager.get_version(options.get('geni_rspec_version')) - if xrn: - result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'manifest') - else: - result_version = version_manager._get_version(rspec_version.type, rspec_version.version, 'ad') + rspec_version = version_manager.get_version( + options.get('geni_rspec_version')) + if xrn: + result_version = version_manager._get_version( + rspec_version.type, rspec_version.version, 'manifest') + else: + result_version = version_manager._get_version( + rspec_version.type, rspec_version.version, 'ad') rspec = RSpec(version=result_version) for result in results: - self.add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], - result["status"], result.get("exc_info",None)) - if result["status"]=="success": + self.add_slicemgr_stat(rspec, "ListResources", result["aggregate"], result["elapsed"], + result["status"], result.get("exc_info", None)) + if result["status"] == "success": res = result['result']['value'] try: rspec.version.merge(ReturnValue.get_value(res)) except: - api.logger.log_exc("SM.ListResources: Failed to merge aggregate rspec") - + api.logger.log_exc( + "SM.ListResources: Failed to merge aggregate rspec") + # cache the result if self.cache and not xrn: api.logger.debug("SliceManager.ListResources caches advertisement") self.cache.add(version_string, rspec.toxml()) - - return rspec.toxml() + return rspec.toxml() def Allocate(self, api, xrn, creds, rspec_str, expiration, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" - + if Callids().already_handled(call_id): + return "" + version_manager = VersionManager() + def _Allocate(aggregate, server, xrn, credential, rspec, options): tStart = time.time() try: # Need to call GetVersion at an aggregate to determine the supported # rspec type/format beofre calling CreateSliver at an Aggregate. #server_version = api.get_cached_server_version(server) - #if 'sfa' not in server_version and 'geni_api' in server_version: + # if 'sfa' not in server_version and 'geni_api' in server_version: # sfa aggregtes support both sfa and pg rspecs, no need to convert # if aggregate supports sfa rspecs. otherwise convert to pg rspec #rspec = RSpec(RSpecConverter.to_pg_rspec(rspec, 'request')) #filter = {'component_manager_id': server_version['urn']} - #rspec.filter(filter) + # rspec.filter(filter) #rspec = rspec.toxml() result = server.Allocate(xrn, credential, rspec, options) - return {"aggregate": aggregate, "result": result, "elapsed": time.time()-tStart, "status": "success"} + return {"aggregate": aggregate, "result": result, "elapsed": time.time() - tStart, "status": "success"} except: - logger.log_exc('Something wrong in _Allocate with URL %s'%server.url) - return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception", "exc_info": sys.exc_info()} + logger.log_exc( + 'Something wrong in _Allocate with URL %s' % server.url) + return {"aggregate": aggregate, "elapsed": time.time() - tStart, "status": "exception", "exc_info": sys.exc_info()} # Validate the RSpec against PlanetLab's schema --disabled for now # The schema used here needs to aggregate the PL and VINI schemas @@ -215,16 +235,16 @@ class SliceManager: # schema = None # if schema: # rspec.validate(schema) - + # if there is a section, the aggregates don't care about it, # so delete it. self.drop_slicemgr_stats(rspec) - + # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: cred = api.getCredential() - + # get the callers hrn hrn, type = urn_to_hrn(xrn) valid_cred = api.auth.checkCredentials(creds, 'createsliver', hrn)[0] @@ -232,54 +252,61 @@ class SliceManager: multiclient = MultiClient() for aggregate in api.aggregates: # prevent infinite loop. Dont send request back to caller - # unless the caller is the aggregate's SM + # unless the caller is the aggregate's SM if caller_hrn == aggregate and aggregate != api.hrn: continue interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) # Just send entire RSpec to each aggregate - multiclient.run(_Allocate, aggregate, server, xrn, [cred], rspec.toxml(), options) - + multiclient.run(_Allocate, aggregate, server, xrn, + [cred], rspec.toxml(), options) + results = multiclient.get_results() - manifest_version = version_manager._get_version(rspec.version.type, rspec.version.version, 'manifest') + manifest_version = version_manager._get_version( + rspec.version.type, rspec.version.version, 'manifest') result_rspec = RSpec(version=manifest_version) geni_urn = None geni_slivers = [] for result in results: - self.add_slicemgr_stat(result_rspec, "Allocate", result["aggregate"], result["elapsed"], - result["status"], result.get("exc_info",None)) - if result["status"]=="success": + self.add_slicemgr_stat(result_rspec, "Allocate", result["aggregate"], result["elapsed"], + result["status"], result.get("exc_info", None)) + if result["status"] == "success": try: res = result['result']['value'] geni_urn = res['geni_urn'] - result_rspec.version.merge(ReturnValue.get_value(res['geni_rspec'])) + result_rspec.version.merge( + ReturnValue.get_value(res['geni_rspec'])) geni_slivers.extend(res['geni_slivers']) except: - api.logger.log_exc("SM.Allocate: Failed to merge aggregate rspec") + api.logger.log_exc( + "SM.Allocate: Failed to merge aggregate rspec") return { 'geni_urn': geni_urn, 'geni_rspec': result_rspec.toxml(), 'geni_slivers': geni_slivers } - def Provision(self, api, xrn, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" version_manager = VersionManager() + def _Provision(aggregate, server, xrn, credential, options): tStart = time.time() try: # Need to call GetVersion at an aggregate to determine the supported - # rspec type/format beofre calling CreateSliver at an Aggregate. + # rspec type/format beofre calling CreateSliver at an + # Aggregate. server_version = api.get_cached_server_version(server) result = server.Provision(xrn, credential, options) - return {"aggregate": aggregate, "result": result, "elapsed": time.time()-tStart, "status": "success"} + return {"aggregate": aggregate, "result": result, "elapsed": time.time() - tStart, "status": "success"} except: - logger.log_exc('Something wrong in _Allocate with URL %s'%server.url) - return {"aggregate": aggregate, "elapsed": time.time()-tStart, "status": "exception", "exc_info": sys.exc_info()} + logger.log_exc( + 'Something wrong in _Allocate with URL %s' % server.url) + return {"aggregate": aggregate, "elapsed": time.time() - tStart, "status": "exception", "exc_info": sys.exc_info()} # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) @@ -298,45 +325,49 @@ class SliceManager: interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) # Just send entire RSpec to each aggregate - multiclient.run(_Provision, aggregate, server, xrn, [cred], options) + multiclient.run(_Provision, aggregate, + server, xrn, [cred], options) results = multiclient.get_results() - manifest_version = version_manager._get_version('GENI', '3', 'manifest') + manifest_version = version_manager._get_version( + 'GENI', '3', 'manifest') result_rspec = RSpec(version=manifest_version) geni_slivers = [] - geni_urn = None + geni_urn = None for result in results: self.add_slicemgr_stat(result_rspec, "Provision", result["aggregate"], result["elapsed"], - result["status"], result.get("exc_info",None)) - if result["status"]=="success": + result["status"], result.get("exc_info", None)) + if result["status"] == "success": try: res = result['result']['value'] geni_urn = res['geni_urn'] - result_rspec.version.merge(ReturnValue.get_value(res['geni_rspec'])) + result_rspec.version.merge( + ReturnValue.get_value(res['geni_rspec'])) geni_slivers.extend(res['geni_slivers']) except: - api.logger.log_exc("SM.Provision: Failed to merge aggregate rspec") + api.logger.log_exc( + "SM.Provision: Failed to merge aggregate rspec") return { 'geni_urn': geni_urn, 'geni_rspec': result_rspec.toxml(), 'geni_slivers': geni_slivers - } - + } - def Renew(self, api, xrn, creds, expiration_time, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return True + if Callids().already_handled(call_id): + return True def _Renew(aggregate, server, xrn, creds, expiration_time, options): try: - result=server.Renew(xrn, creds, expiration_time, options) - if type(result)!=dict: + result = server.Renew(xrn, creds, expiration_time, options) + if type(result) != dict: result = {'code': {'geni_code': 0}, 'value': result} result['aggregate'] = aggregate return result except: - logger.log_exc('Something wrong in _Renew with URL %s'%server.url) + logger.log_exc( + 'Something wrong in _Renew with URL %s' % server.url) return {'aggregate': aggregate, 'exc_info': traceback.format_exc(), 'code': {'geni_code': -1}, 'value': False, 'output': ""} @@ -348,7 +379,7 @@ class SliceManager: # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: - cred = api.getCredential(minimumExpiration=31*86400) + cred = api.getCredential(minimumExpiration=31 * 86400) multiclient = MultiClient() for aggregate in api.aggregates: # prevent infinite loop. Dont send request back to caller @@ -357,25 +388,29 @@ class SliceManager: continue interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) - multiclient.run(_Renew, aggregate, server, xrn, [cred], expiration_time, options) + multiclient.run(_Renew, aggregate, server, xrn, [ + cred], expiration_time, options) results = multiclient.get_results() geni_code = 0 - geni_output = ",".join([x.get('output',"") for x in results]) - geni_value = reduce (lambda x,y: x and y, [result.get('value',False) for result in results], True) + geni_output = ",".join([x.get('output', "") for x in results]) + geni_value = reduce(lambda x, y: x and y, [result.get( + 'value', False) for result in results], True) for agg_result in results: - agg_geni_code = agg_result['code'].get('geni_code',0) + agg_geni_code = agg_result['code'].get('geni_code', 0) if agg_geni_code: geni_code = agg_geni_code - results = {'aggregates': results, 'code': {'geni_code': geni_code}, 'value': geni_value, 'output': geni_output} + results = {'aggregates': results, 'code': { + 'geni_code': geni_code}, 'value': geni_value, 'output': geni_output} return results def Delete(self, api, xrn, creds, options): call_id = options.get('call_id') - if Callids().already_handled(call_id): return "" + if Callids().already_handled(call_id): + return "" def _Delete(server, xrn, creds, options): return server.Delete(xrn, creds, options) @@ -398,20 +433,20 @@ class SliceManager: interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) multiclient.run(_Delete, server, xrn, [cred], options) - + results = [] for result in multiclient.get_results(): results += ReturnValue.get_value(result) return results - - + # first draft at a merging SliverStatus def Status(self, api, slice_xrn, creds, options): def _Status(server, xrn, creds, options): return server.Status(xrn, creds, options) - call_id = options.get('call_id') - if Callids().already_handled(call_id): return {} + call_id = options.get('call_id') + if Callids().already_handled(call_id): + return {} # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: @@ -420,36 +455,41 @@ class SliceManager: for aggregate in api.aggregates: interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) - multiclient.run (_Status, server, slice_xrn, [cred], options) - results = [ReturnValue.get_value(result) for result in multiclient.get_results()] - - # get rid of any void result - e.g. when call_id was hit, where by convention we return {} - results = [ result for result in results if result and result['geni_slivers']] - + multiclient.run(_Status, server, slice_xrn, [cred], options) + results = [ReturnValue.get_value(result) + for result in multiclient.get_results()] + + # get rid of any void result - e.g. when call_id was hit, where by + # convention we return {} + results = [ + result for result in results if result and result['geni_slivers']] + # do not try to combine if there's no result - if not results : return {} - + if not results: + return {} + # otherwise let's merge stuff geni_slivers = [] - geni_urn = None + geni_urn = None for result in results: try: geni_urn = result['geni_urn'] geni_slivers.extend(result['geni_slivers']) except: - api.logger.log_exc("SM.Provision: Failed to merge aggregate rspec") + api.logger.log_exc( + "SM.Provision: Failed to merge aggregate rspec") return { 'geni_urn': geni_urn, 'geni_slivers': geni_slivers } - def Describe(self, api, creds, xrns, options): def _Describe(server, xrn, creds, options): return server.Describe(xrn, creds, options) call_id = options.get('call_id') - if Callids().already_handled(call_id): return {} + if Callids().already_handled(call_id): + return {} # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: @@ -458,39 +498,46 @@ class SliceManager: for aggregate in api.aggregates: interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) - multiclient.run (_Describe, server, xrns, [cred], options) - results = [ReturnValue.get_value(result) for result in multiclient.get_results()] + multiclient.run(_Describe, server, xrns, [cred], options) + results = [ReturnValue.get_value(result) + for result in multiclient.get_results()] - # get rid of any void result - e.g. when call_id was hit, where by convention we return {} - results = [ result for result in results if result and result.get('geni_urn')] + # get rid of any void result - e.g. when call_id was hit, where by + # convention we return {} + results = [ + result for result in results if result and result.get('geni_urn')] # do not try to combine if there's no result - if not results : return {} + if not results: + return {} # otherwise let's merge stuff version_manager = VersionManager() - manifest_version = version_manager._get_version('GENI', '3', 'manifest') + manifest_version = version_manager._get_version( + 'GENI', '3', 'manifest') result_rspec = RSpec(version=manifest_version) geni_slivers = [] - geni_urn = None + geni_urn = None for result in results: try: geni_urn = result['geni_urn'] - result_rspec.version.merge(ReturnValue.get_value(result['geni_rspec'])) + result_rspec.version.merge( + ReturnValue.get_value(result['geni_rspec'])) geni_slivers.extend(result['geni_slivers']) except: - api.logger.log_exc("SM.Provision: Failed to merge aggregate rspec") + api.logger.log_exc( + "SM.Provision: Failed to merge aggregate rspec") return { 'geni_urn': geni_urn, - 'geni_rspec': result_rspec.toxml(), + 'geni_rspec': result_rspec.toxml(), 'geni_slivers': geni_slivers - } - + } + def PerformOperationalAction(self, api, xrn, creds, action, options): # get the callers hrn valid_cred = api.auth.checkCredentials(creds, 'createsliver', xrn)[0] caller_hrn = Credential(cred=valid_cred).get_gid_caller().get_hrn() - + # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: @@ -502,18 +549,20 @@ class SliceManager: if caller_hrn == aggregate and aggregate != api.hrn: continue interface = api.aggregates[aggregate] - server = api.server_proxy(interface, cred) - multiclient.run(server.PerformOperationalAction, xrn, [cred], action, options) - multiclient.get_results() + server = api.server_proxy(interface, cred) + multiclient.run(server.PerformOperationalAction, + xrn, [cred], action, options) + multiclient.get_results() return 1 - + def Shutdown(self, api, xrn, creds, options=None): - if options is None: options={} - xrn = Xrn(xrn) + if options is None: + options = {} + xrn = Xrn(xrn) # get the callers hrn valid_cred = api.auth.checkCredentials(creds, 'stopslice', xrn.hrn)[0] caller_hrn = Credential(cred=valid_cred).get_gid_caller().get_hrn() - + # attempt to use delegated credential first cred = api.getDelegatedCredential(creds) if not cred: @@ -527,6 +576,5 @@ class SliceManager: interface = api.aggregates[aggregate] server = api.server_proxy(interface, cred) multiclient.run(server.Shutdown, xrn.urn, cred) - multiclient.get_results() + multiclient.get_results() return 1 - diff --git a/sfa/managers/v2_to_v3_adapter.py b/sfa/managers/v2_to_v3_adapter.py index 8b0ce61b..15c89e08 100644 --- a/sfa/managers/v2_to_v3_adapter.py +++ b/sfa/managers/v2_to_v3_adapter.py @@ -8,9 +8,10 @@ from sfa.util.cache import Cache from sfa.rspecs.rspec import RSpec from sfa.storage.model import SliverAllocation + class V2ToV3Adapter: - def __init__ (self, api): + def __init__(self, api): config = api.config flavour = config.SFA_GENERIC_FLAVOUR # to be cleaned @@ -21,23 +22,24 @@ class V2ToV3Adapter: from sfa.federica.fddriver import FdDriver self.driver = FdDriver(api) else: - logger.error("V2ToV3Adapter: Unknown Flavour !!!\n Supported Flavours: nitos, fd") - - # Caching + logger.error( + "V2ToV3Adapter: Unknown Flavour !!!\n Supported Flavours: nitos, fd") + + # Caching if config.SFA_AGGREGATE_CACHING: if self.driver.cache: self.cache = self.driver.cache else: self.cache = Cache() - def __getattr__(self, name): def func(*args, **kwds): if name == "list_resources": (version, options) = args slice_urn = slice_hrn = None creds = [] - rspec = getattr(self.driver, "list_resources")(slice_urn, slice_hrn, [], options) + rspec = getattr(self.driver, "list_resources")( + slice_urn, slice_hrn, [], options) result = rspec elif name == "describe": @@ -45,94 +47,103 @@ class V2ToV3Adapter: slice_urn = urns[0] slice_hrn, type = urn_to_hrn(slice_urn) creds = [] - rspec = getattr(self.driver, "list_resources")(slice_urn, slice_hrn, creds, options) - + rspec = getattr(self.driver, "list_resources")( + slice_urn, slice_hrn, creds, options) + # SliverAllocation if len(urns) == 1 and Xrn(xrn=urns[0]).type == 'slice': constraint = SliverAllocation.slice_urn.in_(urns) else: constraint = SliverAllocation.sliver_id.in_(urns) - - sliver_allocations = self.driver.api.dbsession().query (SliverAllocation).filter (constraint) - sliver_status = getattr(self.driver, "sliver_status")(slice_urn, slice_hrn) + + sliver_allocations = self.driver.api.dbsession().query( + SliverAllocation).filter(constraint) + sliver_status = getattr(self.driver, "sliver_status")( + slice_urn, slice_hrn) if 'geni_expires' in sliver_status.keys(): geni_expires = sliver_status['geni_expires'] - else: + else: geni_expires = '' - + geni_slivers = [] for sliver_allocation in sliver_allocations: geni_sliver = {} geni_sliver['geni_expires'] = geni_expires - geni_sliver['geni_allocation'] = sliver_allocation.allocation_state - geni_sliver['geni_sliver_urn'] = sliver_allocation.sliver_id + geni_sliver[ + 'geni_allocation'] = sliver_allocation.allocation_state + geni_sliver[ + 'geni_sliver_urn'] = sliver_allocation.sliver_id geni_sliver['geni_error'] = '' if geni_sliver['geni_allocation'] == 'geni_allocated': - geni_sliver['geni_operational_status'] = 'geni_pending_allocation' - else: + geni_sliver[ + 'geni_operational_status'] = 'geni_pending_allocation' + else: geni_sliver['geni_operational_status'] = 'geni_ready' geni_slivers.append(geni_sliver) - result = {'geni_urn': slice_urn, - 'geni_rspec': rspec, - 'geni_slivers': geni_slivers} + 'geni_rspec': rspec, + 'geni_slivers': geni_slivers} elif name == "allocate": (slice_urn, rspec_string, expiration, options) = args slice_hrn, type = urn_to_hrn(slice_urn) creds = [] users = options.get('sfa_users', []) - manifest_string = getattr(self.driver, "create_sliver")(slice_urn, slice_hrn, creds, rspec_string, users, options) - + manifest_string = getattr(self.driver, "create_sliver")( + slice_urn, slice_hrn, creds, rspec_string, users, options) + # slivers allocation rspec = RSpec(manifest_string) slivers = rspec.version.get_nodes_with_slivers() - - ##SliverAllocation + + # SliverAllocation for sliver in slivers: - client_id = sliver['client_id'] - component_id = sliver['component_id'] - component_name = sliver['component_name'] - slice_name = slice_hrn.replace('.','-') - component_short_name = component_name.split('.')[0] - # self.driver.hrn - sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice_name, component_short_name) - sliver_id = Xrn(sliver_hrn, type='sliver').urn - record = SliverAllocation(sliver_id=sliver_id, - client_id=client_id, - component_id=component_id, - slice_urn = slice_urn, - allocation_state='geni_allocated') - - record.sync(self.driver.api.dbsession()) - - + client_id = sliver['client_id'] + component_id = sliver['component_id'] + component_name = sliver['component_name'] + slice_name = slice_hrn.replace('.', '-') + component_short_name = component_name.split('.')[0] + # self.driver.hrn + sliver_hrn = '%s.%s-%s' % (self.driver.hrn, + slice_name, component_short_name) + sliver_id = Xrn(sliver_hrn, type='sliver').urn + record = SliverAllocation(sliver_id=sliver_id, + client_id=client_id, + component_id=component_id, + slice_urn=slice_urn, + allocation_state='geni_allocated') + + record.sync(self.driver.api.dbsession()) + # return manifest rspec_version = RSpec(rspec_string).version - rspec_version_str = "%s"%rspec_version - options['geni_rspec_version'] = {'version': rspec_version_str.split(' ')[1], 'type': rspec_version_str.lower().split(' ')[0]} + rspec_version_str = "%s" % rspec_version + options['geni_rspec_version'] = {'version': rspec_version_str.split( + ' ')[1], 'type': rspec_version_str.lower().split(' ')[0]} result = self.describe([slice_urn], rspec_version, options) - - elif name == "provision": + + elif name == "provision": (urns, options) = args if len(urns) == 1 and Xrn(xrn=urns[0]).type == 'slice': - constraint = SliverAllocation.slice_urn.in_(urns) + constraint = SliverAllocation.slice_urn.in_(urns) else: - constraint = SliverAllocation.sliver_id.in_(urns) - + constraint = SliverAllocation.sliver_id.in_(urns) + dbsession = self.driver.api.dbsession() - sliver_allocations = dbsession.query (SliverAllocation).filter(constraint) + sliver_allocations = dbsession.query( + SliverAllocation).filter(constraint) for sliver_allocation in sliver_allocations: - sliver_allocation.allocation_state = 'geni_provisioned' - + sliver_allocation.allocation_state = 'geni_provisioned' + dbsession.commit() result = self.describe(urns, '', options) elif name == "status": urns = args options = {} - options['geni_rspec_version'] = {'version': '3', 'type': 'GENI'} + options['geni_rspec_version'] = { + 'version': '3', 'type': 'GENI'} descr = self.describe(urns[0], '', options) result = {'geni_urn': descr['geni_urn'], 'geni_slivers': descr['geni_slivers']} @@ -142,48 +153,52 @@ class V2ToV3Adapter: slice_urn = urns[0] slice_hrn, type = urn_to_hrn(slice_urn) creds = [] - options['geni_rspec_version'] = {'version': '3', 'type': 'GENI'} + options['geni_rspec_version'] = { + 'version': '3', 'type': 'GENI'} descr = self.describe(urns, '', options) result = [] for sliver_allocation in descr['geni_slivers']: - geni_sliver = {'geni_sliver_urn': sliver_allocation['geni_sliver_urn'], - 'geni_allocation_status': 'geni_unallocated', - 'geni_expires': sliver_allocation['geni_expires'], - 'geni_error': sliver_allocation['geni_error']} - - result.append(geni_sliver) - - getattr(self.driver, "delete_sliver")(slice_urn, slice_hrn, creds, options) - - #SliverAllocation + geni_sliver = {'geni_sliver_urn': sliver_allocation['geni_sliver_urn'], + 'geni_allocation_status': 'geni_unallocated', + 'geni_expires': sliver_allocation['geni_expires'], + 'geni_error': sliver_allocation['geni_error']} + + result.append(geni_sliver) + + getattr(self.driver, "delete_sliver")( + slice_urn, slice_hrn, creds, options) + + # SliverAllocation constraints = SliverAllocation.slice_urn.in_(urns) dbsession = self.driver.api.dbsession() - sliver_allocations = dbsession.query(SliverAllocation).filter(constraints) - sliver_ids = [sliver_allocation.sliver_id for sliver_allocation in sliver_allocations] + sliver_allocations = dbsession.query( + SliverAllocation).filter(constraints) + sliver_ids = [ + sliver_allocation.sliver_id for sliver_allocation in sliver_allocations] SliverAllocation.delete_allocations(sliver_ids, dbsession) - elif name == "renew": (urns, expiration_time, options) = args - slice_urn = urns[0] + slice_urn = urns[0] slice_hrn, type = urn_to_hrn(slice_urn) creds = [] - getattr(self.driver, "renew_sliver")(slice_urn, slice_hrn, creds, expiration_time, options) + getattr(self.driver, "renew_sliver")( + slice_urn, slice_hrn, creds, expiration_time, options) - options['geni_rspec_version'] = {'version': '3', 'type': 'GENI'} + options['geni_rspec_version'] = { + 'version': '3', 'type': 'GENI'} descr = self.describe(urns, '', options) result = descr['geni_slivers'] - elif name == "perform_operational_action": (urns, action, options) = args - options['geni_rspec_version'] = {'version': '3', 'type': 'GENI'} + options['geni_rspec_version'] = { + 'version': '3', 'type': 'GENI'} result = self.describe(urns, '', options)['geni_slivers'] - - else: - # same as v2 ( registry methods) - result=getattr(self.driver, name)(*args, **kwds) + else: + # same as v2 ( registry methods) + result = getattr(self.driver, name)(*args, **kwds) return result return func diff --git a/sfa/methods/Allocate.py b/sfa/methods/Allocate.py index c8cbb0af..c2ecfc55 100644 --- a/sfa/methods/Allocate.py +++ b/sfa/methods/Allocate.py @@ -1,5 +1,5 @@ from sfa.util.faults import SfaInvalidArgument, InvalidRSpec, SfatablesRejected -from sfa.util.sfatime import datetime_to_string +from sfa.util.sfatime import datetime_to_string from sfa.util.xrn import Xrn, urn_to_hrn from sfa.util.method import Method from sfa.util.sfatablesRuntime import run_sfatables @@ -8,6 +8,7 @@ from sfa.storage.parameter import Parameter, Mixed from sfa.rspecs.rspec import RSpec from sfa.util.sfalogging import logger + class Allocate(Method): """ Allocate resources as described in a request RSpec argument @@ -50,37 +51,41 @@ class Allocate(Method): Parameter(type([dict]), "List of credentials"), Parameter(str, "RSpec"), Parameter(dict, "options"), - ] + ] returns = Parameter(str, "Allocated RSpec") def call(self, xrn, creds, rspec, options): xrn = Xrn(xrn, type='slice') # Find the valid credentials - valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'createsliver', xrn.get_hrn(), options=options) + valid_creds = self.api.auth.checkCredentialsSpeaksFor( + creds, 'createsliver', xrn.get_hrn(), options=options) the_credential = Credential(cred=valid_creds[0]) - # use the expiration from the first valid credential to determine when + # use the expiration from the first valid credential to determine when # the slivers should expire. expiration = datetime_to_string(the_credential.expiration) - - self.api.logger.debug("Allocate, received expiration from credential: %s"%expiration) + + self.api.logger.debug( + "Allocate, received expiration from credential: %s" % expiration) # turned off, as passing an empty rspec is indeed useful for cleaning up the slice # # make sure request is not empty # slivers = RSpec(rspec).version.get_nodes_with_slivers() # if not slivers: -# raise InvalidRSpec("Missing or element. Request rspec must explicitly allocate slivers") +# raise InvalidRSpec("Missing or element. Request rspec must explicitly allocate slivers") # flter rspec through sfatables if self.api.interface in ['aggregate']: chain_name = 'INCOMING' elif self.api.interface in ['slicemgr']: chain_name = 'FORWARD-INCOMING' - self.api.logger.debug("Allocate: sfatables on chain %s"%chain_name) + self.api.logger.debug("Allocate: sfatables on chain %s" % chain_name) actual_caller_hrn = the_credential.actual_caller_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, actual_caller_hrn, xrn.get_hrn(), self.name)) - rspec = run_sfatables(chain_name, xrn.get_hrn(), actual_caller_hrn, rspec) + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, actual_caller_hrn, xrn.get_hrn(), self.name)) + rspec = run_sfatables(chain_name, xrn.get_hrn(), + actual_caller_hrn, rspec) # turned off, as passing an empty rspec is indeed useful for cleaning up the slice # slivers = RSpec(rspec).version.get_nodes_with_slivers() # if not slivers: @@ -88,5 +93,6 @@ class Allocate(Method): # pass this to the driver code in case they need it options['actual_caller_hrn'] = actual_caller_hrn - result = self.api.manager.Allocate(self.api, xrn.get_urn(), creds, rspec, expiration, options) + result = self.api.manager.Allocate( + self.api, xrn.get_urn(), creds, rspec, expiration, options) return result diff --git a/sfa/methods/CreateGid.py b/sfa/methods/CreateGid.py index 0a2785f5..ab0c6bf1 100644 --- a/sfa/methods/CreateGid.py +++ b/sfa/methods/CreateGid.py @@ -5,40 +5,43 @@ from sfa.util.method import Method from sfa.storage.parameter import Parameter, Mixed from sfa.trust.credential import Credential + class CreateGid(Method): """ Create a signed credential for the object with the registry. In addition to being stored in the SFA database, the appropriate records will also be created in the PLC databases - + @param cred credential string @param xrn urn or hrn of certificate owner @param cert caller's certificate - + @return gid string representation """ interfaces = ['registry'] - + accepts = [ Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), Parameter(str, "URN or HRN of certificate owner"), Parameter(str, "Certificate string"), - ] + ] returns = Parameter(int, "String representation of gid object") - + def call(self, creds, xrn, cert=None): - # TODO: is there a better right to check for or is 'update good enough? + # TODO: is there a better right to check for or is 'update good enough? valid_creds = self.api.auth.checkCredentials(creds, 'update') # verify permissions hrn, type = urn_to_hrn(xrn) self.api.auth.verify_object_permission(hrn) - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrn, self.name)) + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, xrn, self.name)) return self.api.manager.CreateGid(self.api, xrn, cert) diff --git a/sfa/methods/Delete.py b/sfa/methods/Delete.py index 593de28f..92042562 100644 --- a/sfa/methods/Delete.py +++ b/sfa/methods/Delete.py @@ -4,6 +4,7 @@ from sfa.storage.parameter import Parameter, Mixed from sfa.trust.auth import Auth from sfa.trust.credential import Credential + class Delete(Method): """ Remove the slice or slivers and free the allocated resources @@ -14,22 +15,24 @@ class Delete(Method): """ interfaces = ['aggregate', 'slicemgr', 'component'] - + accepts = [ - Parameter(type([str]), "Human readable name of slice to delete (hrn or urn)"), + Parameter( + type([str]), "Human readable name of slice to delete (hrn or urn)"), Parameter(type([dict]), "Credentials"), Parameter(dict, "options"), - ] + ] returns = Parameter(int, "1 if successful") - + def call(self, xrns, creds, options): valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'deletesliver', xrns, - check_sliver_callback = self.api.driver.check_sliver_credentials, + check_sliver_callback=self.api.driver.check_sliver_credentials, options=options) - #log the call + # log the call origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrns, self.name)) + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, xrns, self.name)) return self.api.manager.Delete(self.api, xrns, creds, options) diff --git a/sfa/methods/Describe.py b/sfa/methods/Describe.py index 69d4fb7b..6930f1eb 100644 --- a/sfa/methods/Describe.py +++ b/sfa/methods/Describe.py @@ -8,6 +8,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class Describe(Method): """ Retrieve a manifest RSpec describing the resources contained by the @@ -21,43 +22,47 @@ class Describe(Method): interfaces = ['aggregate', 'slicemgr'] accepts = [ Parameter(type([str]), "List of URNs"), - Mixed(Parameter(str, "Credential string"), + Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), Parameter(dict, "Options") - ] + ] returns = Parameter(str, "List of resources") def call(self, urns, creds, options): - self.api.logger.info("interface: %s\tmethod-name: %s" % (self.api.interface, self.name)) - + self.api.logger.info("interface: %s\tmethod-name: %s" % + (self.api.interface, self.name)) + # client must specify a version if not options.get('geni_rspec_version'): if options.get('rspec_version'): options['geni_rspec_version'] = options['rspec_version'] else: - raise SfaInvalidArgument('Must specify an rspec version option. geni_rspec_version cannot be null') - valid_creds = self.api.auth.checkCredentialsSpeaksFor( - creds, 'listnodes', urns, - check_sliver_callback = self.api.driver.check_sliver_credentials, + raise SfaInvalidArgument( + 'Must specify an rspec version option. geni_rspec_version cannot be null') + valid_creds = self.api.auth.checkCredentialsSpeaksFor( + creds, 'listnodes', urns, + check_sliver_callback=self.api.driver.check_sliver_credentials, options=options) - # get hrn of the original caller + # get hrn of the original caller origin_hrn = options.get('origin_hrn', None) if not origin_hrn: - origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() + origin_hrn = Credential( + cred=valid_creds[0]).get_gid_caller().get_hrn() desc = self.api.manager.Describe(self.api, creds, urns, options) - # filter rspec through sfatables + # filter rspec through sfatables if self.api.interface in ['aggregate']: chain_name = 'OUTGOING' - elif self.api.interface in ['slicemgr']: + elif self.api.interface in ['slicemgr']: chain_name = 'FORWARD-OUTGOING' - self.api.logger.debug("ListResources: sfatables on chain %s"%chain_name) - desc['geni_rspec'] = run_sfatables(chain_name, '', origin_hrn, desc['geni_rspec']) - + self.api.logger.debug( + "ListResources: sfatables on chain %s" % chain_name) + desc['geni_rspec'] = run_sfatables( + chain_name, '', origin_hrn, desc['geni_rspec']) + if 'geni_compressed' in options and options['geni_compressed'] == True: - desc['geni_rspec'] = zlib.compress(desc['geni_rspec']).encode('base64') + desc['geni_rspec'] = zlib.compress( + desc['geni_rspec']).encode('base64') - return desc - - + return desc diff --git a/sfa/methods/GetCredential.py b/sfa/methods/GetCredential.py index 22d021cd..f7bc721a 100644 --- a/sfa/methods/GetCredential.py +++ b/sfa/methods/GetCredential.py @@ -5,6 +5,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class GetCredential(Method): """ Retrive a credential for an object @@ -18,19 +19,19 @@ class GetCredential(Method): """ interfaces = ['registry'] - + accepts = [ Mixed(Parameter(str, "Credential string"), - Parameter(type([str]), "List of credentials")), + Parameter(type([str]), "List of credentials")), Parameter(str, "Human readable name (hrn or urn)"), Mixed(Parameter(str, "Record type"), Parameter(None, "Type not specified")), - ] + ] returns = Parameter(str, "String representation of a credential object") def call(self, creds, xrn, type): - + if type: hrn = urn_to_hrn(xrn)[0] else: @@ -40,9 +41,10 @@ class GetCredential(Method): valid_creds = self.api.auth.checkCredentials(creds, 'getcredential') self.api.auth.verify_object_belongs_to_me(hrn) - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name)) + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrn, self.name)) return self.api.manager.GetCredential(self.api, xrn, type, self.api.auth.client_gid.get_urn()) - diff --git a/sfa/methods/GetGids.py b/sfa/methods/GetGids.py index 623fbe92..197e5220 100644 --- a/sfa/methods/GetGids.py +++ b/sfa/methods/GetGids.py @@ -5,6 +5,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class GetGids(Method): """ Get a list of record information (hrn, gid and type) for @@ -16,30 +17,31 @@ class GetGids(Method): """ interfaces = ['registry'] - + accepts = [ - Mixed(Parameter(str, "Human readable name (hrn or xrn)"), + Mixed(Parameter(str, "Human readable name (hrn or xrn)"), Parameter(type([str]), "List of Human readable names (hrn or xrn)")), Mixed(Parameter(str, "Credential string"), - Parameter(type([str]), "List of credentials")), - ] + Parameter(type([str]), "List of credentials")), + ] returns = [Parameter(dict, "Dictionary of gids keyed on hrn")] - + def call(self, xrns, creds): # validate the credential valid_creds = self.api.auth.checkCredentials(creds, 'getgids') # xxxpylintxxx origin_hrn is unused.. - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + # resolve the record - records = self.api.manager.Resolve(self.api, xrns, details = False) + records = self.api.manager.Resolve(self.api, xrns, details=False) if not records: raise RecordNotFound(xrns) - allowed_fields = ['hrn', 'type', 'gid'] + allowed_fields = ['hrn', 'type', 'gid'] for record in records: for key in record.keys(): if key not in allowed_fields: del(record[key]) - return records + return records diff --git a/sfa/methods/GetSelfCredential.py b/sfa/methods/GetSelfCredential.py index f3a96127..385d501f 100644 --- a/sfa/methods/GetSelfCredential.py +++ b/sfa/methods/GetSelfCredential.py @@ -6,6 +6,7 @@ from sfa.trust.certificate import Certificate from sfa.storage.parameter import Parameter, Mixed + class GetSelfCredential(Method): """ Retrive a credential for an object @@ -17,13 +18,13 @@ class GetSelfCredential(Method): """ interfaces = ['registry'] - + accepts = [ Parameter(str, "certificate"), Parameter(str, "Human readable name (hrn or urn)"), Mixed(Parameter(str, "Record type"), Parameter(None, "Type not specified")), - ] + ] returns = Parameter(str, "String representation of a credential object") @@ -46,28 +47,28 @@ class GetSelfCredential(Method): if type: hrn = urn_to_hrn(xrn)[0] else: - hrn, type = urn_to_hrn(xrn) + hrn, type = urn_to_hrn(xrn) self.api.auth.verify_object_belongs_to_me(hrn) origin_hrn = Certificate(string=cert).get_subject() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name)) - - - ### authenticate the gid + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrn, self.name)) + + # authenticate the gid # import here so we can load this module at build-time for sfa2wsdl #from sfa.storage.alchemy import dbsession from sfa.storage.model import RegRecord - # xxx-local - the current code runs Resolve, which would forward to + # xxx-local - the current code runs Resolve, which would forward to # another registry if needed - # I wonder if this is truly the intention, or shouldn't we instead + # I wonder if this is truly the intention, or shouldn't we instead # only look in the local db ? records = self.api.manager.Resolve(self.api, xrn, type, details=False) if not records: raise RecordNotFound(hrn) - record_obj = RegRecord (dict=records[0]) - # xxx-local the local-only version would read + record_obj = RegRecord(dict=records[0]) + # xxx-local the local-only version would read #record_obj = dbsession.query(RegRecord).filter_by(hrn=hrn).first() #if not record_obj: raise RecordNotFound(hrn) gid = record_obj.get_gid_object() @@ -76,11 +77,14 @@ class GetSelfCredential(Method): # authenticate the certificate against the gid in the db certificate = Certificate(string=cert) if not certificate.is_pubkey(gid.get_pubkey()): - for (obj,name) in [ (certificate,"CERT"), (gid,"GID"), ]: - self.api.logger.debug("ConnectionKeyGIDMismatch, %s pubkey: %s"%(name,obj.get_pubkey().get_pubkey_string())) - self.api.logger.debug("ConnectionKeyGIDMismatch, %s dump: %s"%(name,obj.dump_string())) - if hasattr (obj,'filename'): - self.api.logger.debug("ConnectionKeyGIDMismatch, %s filename: %s"%(name,obj.filename)) + for (obj, name) in [(certificate, "CERT"), (gid, "GID"), ]: + self.api.logger.debug("ConnectionKeyGIDMismatch, %s pubkey: %s" % ( + name, obj.get_pubkey().get_pubkey_string())) + self.api.logger.debug( + "ConnectionKeyGIDMismatch, %s dump: %s" % (name, obj.dump_string())) + if hasattr(obj, 'filename'): + self.api.logger.debug( + "ConnectionKeyGIDMismatch, %s filename: %s" % (name, obj.filename)) raise ConnectionKeyGIDMismatch(gid.get_subject()) - + return self.api.manager.GetCredential(self.api, xrn, type) diff --git a/sfa/methods/GetVersion.py b/sfa/methods/GetVersion.py index f043992e..4bbc86d8 100644 --- a/sfa/methods/GetVersion.py +++ b/sfa/methods/GetVersion.py @@ -8,14 +8,16 @@ class GetVersion(Method): Returns this GENI Aggregate Manager's Version Information @return version """ - interfaces = ['registry','aggregate', 'slicemgr', 'component'] + interfaces = ['registry', 'aggregate', 'slicemgr', 'component'] accepts = [ Parameter(dict, "Options") - ] + ] returns = Parameter(dict, "Version information") # API v2 specifies options is optional, so.. def call(self, options=None): - if options is None: options={} - self.api.logger.info("interface: %s\tmethod-name: %s" % (self.api.interface, self.name)) + if options is None: + options = {} + self.api.logger.info("interface: %s\tmethod-name: %s" % + (self.api.interface, self.name)) return self.api.manager.GetVersion(self.api, options) diff --git a/sfa/methods/List.py b/sfa/methods/List.py index 83d7a6e0..54b74082 100644 --- a/sfa/methods/List.py +++ b/sfa/methods/List.py @@ -6,6 +6,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class List(Method): """ List the records in an authority. @@ -15,23 +16,26 @@ class List(Method): @return list of record dictionaries """ interfaces = ['registry'] - + accepts = [ Parameter(str, "Human readable name (hrn or urn)"), Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), - ] + ] # xxx used to be [SfaRecord] returns = [Parameter(dict, "registry record")] - + def call(self, xrn, creds, options=None): - if options is None: options={} + if options is None: + options = {} hrn, type = urn_to_hrn(xrn) valid_creds = self.api.auth.checkCredentials(creds, 'list') - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name)) - - return self.api.manager.List(self.api, xrn, options=options) + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrn, self.name)) + + return self.api.manager.List(self.api, xrn, options=options) diff --git a/sfa/methods/ListResources.py b/sfa/methods/ListResources.py index 795594be..e2ec988b 100644 --- a/sfa/methods/ListResources.py +++ b/sfa/methods/ListResources.py @@ -8,6 +8,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class ListResources(Method): """ Returns information about available resources @@ -17,42 +18,45 @@ class ListResources(Method): """ interfaces = ['aggregate', 'slicemgr'] accepts = [ - Mixed(Parameter(str, "Credential string"), + Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), Parameter(dict, "Options") - ] + ] returns = Parameter(str, "List of resources") def call(self, creds, options): - self.api.logger.info("interface: %s\tmethod-name: %s" % (self.api.interface, self.name)) - + self.api.logger.info("interface: %s\tmethod-name: %s" % + (self.api.interface, self.name)) + # client must specify a version if not options.get('geni_rspec_version'): if options.get('rspec_version'): options['geni_rspec_version'] = options['rspec_version'] else: - raise SfaInvalidArgument('Must specify an rspec version option. geni_rspec_version cannot be null') + raise SfaInvalidArgument( + 'Must specify an rspec version option. geni_rspec_version cannot be null') # Find the valid credentials - valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'listnodes', options=options) + valid_creds = self.api.auth.checkCredentialsSpeaksFor( + creds, 'listnodes', options=options) - # get hrn of the original caller + # get hrn of the original caller origin_hrn = options.get('origin_hrn', None) if not origin_hrn: - origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() + origin_hrn = Credential( + cred=valid_creds[0]).get_gid_caller().get_hrn() rspec = self.api.manager.ListResources(self.api, creds, options) - # filter rspec through sfatables + # filter rspec through sfatables if self.api.interface in ['aggregate']: chain_name = 'OUTGOING' - elif self.api.interface in ['slicemgr']: + elif self.api.interface in ['slicemgr']: chain_name = 'FORWARD-OUTGOING' - self.api.logger.debug("ListResources: sfatables on chain %s"%chain_name) - filtered_rspec = run_sfatables(chain_name, '', origin_hrn, rspec) - + self.api.logger.debug( + "ListResources: sfatables on chain %s" % chain_name) + filtered_rspec = run_sfatables(chain_name, '', origin_hrn, rspec) + if 'geni_compressed' in options and options['geni_compressed'] == True: filtered_rspec = zlib.compress(filtered_rspec).encode('base64') - return filtered_rspec - - + return filtered_rspec diff --git a/sfa/methods/PerformOperationalAction.py b/sfa/methods/PerformOperationalAction.py index 41bf58fd..1081b0e1 100644 --- a/sfa/methods/PerformOperationalAction.py +++ b/sfa/methods/PerformOperationalAction.py @@ -5,6 +5,7 @@ from sfa.util.sfatablesRuntime import run_sfatables from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class PerformOperationalAction(Method): """ Request that the named geni_allocated slivers be made @@ -17,7 +18,7 @@ class PerformOperationalAction(Method): @param slice urns ([string]) URNs of slivers to provision to @param credentials (dict) of credentials @param options (dict) options - + """ interfaces = ['aggregate', 'slicemgr'] accepts = [ @@ -25,19 +26,22 @@ class PerformOperationalAction(Method): Parameter(type([dict]), "Credentials"), Parameter(str, "Action"), Parameter(dict, "Options"), - ] + ] returns = Parameter(dict, "Provisioned Resources") def call(self, xrns, creds, action, options): - self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, xrns, self.name)) + self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, xrns, self.name)) (speaking_for, _) = urn_to_hrn(options.get('geni_speaking_for')) - + # Find the valid credentials valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'createsliver', xrns, - check_sliver_callback = self.api.driver.check_sliver_credentials, - options=options) + check_sliver_callback=self.api.driver.check_sliver_credentials, + options=options) origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrns, self.name)) - result = self.api.manager.PerformOperationalAction(self.api, xrns, creds, action, options) + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, xrns, self.name)) + result = self.api.manager.PerformOperationalAction( + self.api, xrns, creds, action, options) return result diff --git a/sfa/methods/Provision.py b/sfa/methods/Provision.py index 7177854d..ad196692 100644 --- a/sfa/methods/Provision.py +++ b/sfa/methods/Provision.py @@ -6,6 +6,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed from sfa.rspecs.rspec import RSpec + class Provision(Method): """ Request that the named geni_allocated slivers be made @@ -18,24 +19,26 @@ class Provision(Method): @param slice urns ([string]) URNs of slivers to provision to @param credentials (dict) of credentials @param options (dict) options - + """ interfaces = ['aggregate', 'slicemgr'] accepts = [ Parameter(type([str]), "URNs"), Parameter(type([dict]), "Credentials"), Parameter(dict, "options"), - ] + ] returns = Parameter(dict, "Provisioned Resources") def call(self, xrns, creds, options): - self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, xrns, self.name)) + self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, xrns, self.name)) # Find the valid credentials valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'createsliver', xrns, - check_sliver_callback = self.api.driver.check_sliver_credentials, + check_sliver_callback=self.api.driver.check_sliver_credentials, options=options) origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrns, self.name)) + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, xrns, self.name)) result = self.api.manager.Provision(self.api, xrns, creds, options) return result diff --git a/sfa/methods/Register.py b/sfa/methods/Register.py index 87f9a0f8..8f573546 100644 --- a/sfa/methods/Register.py +++ b/sfa/methods/Register.py @@ -4,38 +4,41 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class Register(Method): """ Register an object with the registry. In addition to being stored in the SFA database, the appropriate records will also be created in the PLC databases - + @param cred credential string @param record_dict dictionary containing record fields - + @return gid string representation """ interfaces = ['registry'] - + accepts = [ Parameter(dict, "Record dictionary containing record fields"), Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), - ] + ] returns = Parameter(int, "String representation of gid object") - + def call(self, record, creds): - # validate cred + # validate cred valid_creds = self.api.auth.checkCredentials(creds, 'register') - + # verify permissions hrn = record.get('hrn', '') self.api.auth.verify_object_permission(hrn) - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name)) - + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrn, self.name)) + return self.api.manager.Register(self.api, record) diff --git a/sfa/methods/Remove.py b/sfa/methods/Remove.py index ee650f1c..a0614498 100644 --- a/sfa/methods/Remove.py +++ b/sfa/methods/Remove.py @@ -5,11 +5,12 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class Remove(Method): """ Remove an object from the registry. If the object represents a PLC object, then the PLC records will also be removed. - + @param cred credential string @param type record type @param xrn human readable name of record to remove (hrn or urn) @@ -18,27 +19,28 @@ class Remove(Method): """ interfaces = ['registry'] - + accepts = [ Parameter(str, "Human readable name of slice to instantiate (hrn or urn)"), Mixed(Parameter(str, "Credential string"), Parameter(type([str]), "List of credentials")), Mixed(Parameter(str, "Record type"), Parameter(None, "Type not specified")), - ] + ] returns = Parameter(int, "1 if successful") - + def call(self, xrn, creds, type): - xrn=Xrn(xrn,type=type) - + xrn = Xrn(xrn, type=type) + # validate the cred valid_creds = self.api.auth.checkCredentials(creds, "remove") self.api.auth.verify_object_permission(xrn.get_hrn()) - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tmethod-name: %s\tcaller-hrn: %s\ttarget-urn: %s"%( - self.api.interface, self.name, origin_hrn, xrn.get_urn())) + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tmethod-name: %s\tcaller-hrn: %s\ttarget-urn: %s" % ( + self.api.interface, self.name, origin_hrn, xrn.get_urn())) - return self.api.manager.Remove(self.api, xrn) + return self.api.manager.Remove(self.api, xrn) diff --git a/sfa/methods/Renew.py b/sfa/methods/Renew.py index 0a7ac1af..e3ae1c30 100644 --- a/sfa/methods/Renew.py +++ b/sfa/methods/Renew.py @@ -9,11 +9,12 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter + class Renew(Method): """ Renews the resources in the specified slice or slivers by extending the lifetime. - + @param urns ([string]) List of URNs of to renew @param credentials ([string]) of credentials @param expiration_time (string) requested time of expiration @@ -25,44 +26,45 @@ class Renew(Method): Parameter(type([str]), "List of credentials"), Parameter(str, "Expiration time in RFC 3339 format"), Parameter(dict, "Options"), - ] + ] returns = Parameter(bool, "Success or Failure") def call(self, urns, creds, expiration_time, options): - # Find the valid credentials valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'renewsliver', urns, - check_sliver_callback = self.api.driver.check_sliver_credentials, + check_sliver_callback=self.api.driver.check_sliver_credentials, options=options) the_credential = Credential(cred=valid_creds[0]) actual_caller_hrn = the_credential.actual_caller_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-urns: %s\texpiration:%s\tmethod-name: %s"%\ - (self.api.interface, actual_caller_hrn, urns, expiration_time,self.name)) - + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-urns: %s\texpiration:%s\tmethod-name: %s" % + (self.api.interface, actual_caller_hrn, urns, expiration_time, self.name)) - # extend as long as possible : take the min of requested and now+SFA_MAX_SLICE_RENEW + # extend as long as possible : take the min of requested and + # now+SFA_MAX_SLICE_RENEW if options.get('geni_extend_alap'): # ignore requested time and set to max - expiration_time = add_datetime(datetime.datetime.utcnow(), days=int(self.api.config.SFA_MAX_SLICE_RENEW)) + expiration_time = add_datetime(datetime.datetime.utcnow( + ), days=int(self.api.config.SFA_MAX_SLICE_RENEW)) - # Validate that the time does not go beyond the credential's expiration time + # Validate that the time does not go beyond the credential's expiration + # time requested_expire = utcparse(expiration_time) - self.api.logger.info("requested_expire = %s"%requested_expire) + self.api.logger.info("requested_expire = %s" % requested_expire) credential_expire = the_credential.get_expiration() - self.api.logger.info("credential_expire = %s"%credential_expire) + self.api.logger.info("credential_expire = %s" % credential_expire) max_renew_days = int(self.api.config.SFA_MAX_SLICE_RENEW) - max_expire = datetime.datetime.utcnow() + datetime.timedelta (days=max_renew_days) + max_expire = datetime.datetime.utcnow() + datetime.timedelta(days=max_renew_days) if requested_expire > credential_expire: - # used to throw an InsufficientRights exception here, this was not right - self.api.logger.warning("Requested expiration %s, after credential expiration (%s) -> trimming to the latter/sooner"%\ + # used to throw an InsufficientRights exception here, this was not + # right + self.api.logger.warning("Requested expiration %s, after credential expiration (%s) -> trimming to the latter/sooner" % (requested_expire, credential_expire)) requested_expire = credential_expire if requested_expire > max_expire: # likewise - self.api.logger.warning("Requested expiration %s, after maximal expiration %s days (%s) -> trimming to the latter/sooner"%\ - (requested_expire, self.api.config.SFA_MAX_SLICE_RENEW,max_expire)) + self.api.logger.warning("Requested expiration %s, after maximal expiration %s days (%s) -> trimming to the latter/sooner" % + (requested_expire, self.api.config.SFA_MAX_SLICE_RENEW, max_expire)) requested_expire = max_expire return self.api.manager.Renew(self.api, urns, creds, requested_expire, options) - diff --git a/sfa/methods/Resolve.py b/sfa/methods/Resolve.py index fc12df1b..05bffcbb 100644 --- a/sfa/methods/Resolve.py +++ b/sfa/methods/Resolve.py @@ -5,6 +5,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class Resolve(Method): """ Resolve a record. @@ -15,7 +16,7 @@ class Resolve(Method): """ interfaces = ['registry'] - + # should we not accept an optional 'details' argument ? accepts = [ Mixed(Parameter(str, "Human readable name (hrn or urn)"), @@ -23,29 +24,33 @@ class Resolve(Method): Mixed(Parameter(str, "Credential string"), Parameter(list, "List of credentials)")), Parameter(dict, "options"), - ] + ] # xxx used to be [SfaRecord] returns = [Parameter(dict, "registry record")] - + def call(self, xrns, creds, options=None): - if options is None: options={} - # use details=False by default, only when explicitly specified do we want + if options is None: + options = {} + # use details=False by default, only when explicitly specified do we want # to mess with the testbed details - if 'details' in options: details=options['details'] - else: details=False + if 'details' in options: + details = options['details'] + else: + details = False type = None if not isinstance(xrns, list): type = Xrn(xrns).get_type() - xrns=[xrns] + xrns = [xrns] hrns = [urn_to_hrn(xrn)[0] for xrn in xrns] - #find valid credentials + # find valid credentials valid_creds = self.api.auth.checkCredentials(creds, 'resolve') - #log the call - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrns, self.name)) - + # log the call + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrns, self.name)) + # send the call to the right manager return self.api.manager.Resolve(self.api, xrns, type, details=details) - diff --git a/sfa/methods/ResolveGENI.py b/sfa/methods/ResolveGENI.py index e32bf6de..5e6eac68 100644 --- a/sfa/methods/ResolveGENI.py +++ b/sfa/methods/ResolveGENI.py @@ -2,6 +2,7 @@ from sfa.util.method import Method from sfa.storage.parameter import Parameter + class ResolveGENI(Method): """ Lookup a URN and return information about the corresponding object. @@ -12,7 +13,7 @@ class ResolveGENI(Method): accepts = [ Parameter(str, "URN"), Parameter(type([str]), "List of credentials"), - ] + ] returns = Parameter(bool, "Success or Failure") def call(self, xrn): diff --git a/sfa/methods/Shutdown.py b/sfa/methods/Shutdown.py index f6f1841f..7b086af3 100644 --- a/sfa/methods/Shutdown.py +++ b/sfa/methods/Shutdown.py @@ -2,6 +2,7 @@ from sfa.storage.parameter import Parameter from sfa.trust.credential import Credential from sfa.util.method import Method + class Shutdown(Method): """ Perform an emergency shut down of a sliver. This operation is intended for administrative use. @@ -14,16 +15,16 @@ class Shutdown(Method): accepts = [ Parameter(str, "Slice URN"), Parameter(type([dict]), "Credentials"), - ] + ] returns = Parameter(bool, "Success or Failure") def call(self, xrn, creds): valid_creds = self.api.auth.checkCredentials(creds, 'stopslice', xrn, - check_sliver_callback = self.api.driver.check_sliver_credentials) - #log the call + check_sliver_callback=self.api.driver.check_sliver_credentials) + # log the call origin_hrn = Credential(cred=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, xrn, self.name)) + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, xrn, self.name)) return self.api.manager.Shutdown(self.api, xrn, creds) - diff --git a/sfa/methods/Status.py b/sfa/methods/Status.py index 68d928e6..dd15f5df 100644 --- a/sfa/methods/Status.py +++ b/sfa/methods/Status.py @@ -3,26 +3,27 @@ from sfa.util.method import Method from sfa.storage.parameter import Parameter, Mixed + class Status(Method): """ Get the status of a sliver - + @param slice_urn (string) URN of slice to allocate to - + """ interfaces = ['aggregate', 'slicemgr', 'component'] accepts = [ Parameter(type([str]), "Slice or sliver URNs"), Parameter(type([dict]), "credentials"), Parameter(dict, "Options") - ] + ] returns = Parameter(dict, "Status details") def call(self, xrns, creds, options): valid_creds = self.api.auth.checkCredentialsSpeaksFor(creds, 'sliverstatus', xrns, - check_sliver_callback = self.api.driver.check_sliver_credentials, + check_sliver_callback=self.api.driver.check_sliver_credentials, options=options) - self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, xrns, self.name)) + self.api.logger.info("interface: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, xrns, self.name)) return self.api.manager.Status(self.api, xrns, creds, options) - diff --git a/sfa/methods/Update.py b/sfa/methods/Update.py index e0b1003e..507509cd 100644 --- a/sfa/methods/Update.py +++ b/sfa/methods/Update.py @@ -4,12 +4,13 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter + class Update(Method): """ Update an object in the registry. Currently, this only updates the PLC information associated with the record. The SFA fields (name, type, GID) are fixed. - + @param cred credential string specifying rights of the caller @param record a record dictionary to be updated @@ -17,25 +18,26 @@ class Update(Method): """ interfaces = ['registry'] - + accepts = [ Parameter(dict, "Record dictionary to be updated"), Parameter(str, "Credential string"), - ] + ] returns = Parameter(int, "1 if successful") - + def call(self, record_dict, creds): # validate the cred valid_creds = self.api.auth.checkCredentials(creds, "update") - + # verify permissions - hrn = record_dict.get('hrn', '') + hrn = record_dict.get('hrn', '') self.api.auth.verify_object_permission(hrn) - + # log - origin_hrn = Credential(string=valid_creds[0]).get_gid_caller().get_hrn() - self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s"%(self.api.interface, origin_hrn, hrn, self.name)) - - return self.api.manager.Update(self.api, record_dict) + origin_hrn = Credential( + string=valid_creds[0]).get_gid_caller().get_hrn() + self.api.logger.info("interface: %s\tcaller-hrn: %s\ttarget-hrn: %s\tmethod-name: %s" % + (self.api.interface, origin_hrn, hrn, self.name)) + return self.api.manager.Update(self.api, record_dict) diff --git a/sfa/methods/__init__.py b/sfa/methods/__init__.py index 5113465b..d407c4c1 100644 --- a/sfa/methods/__init__.py +++ b/sfa/methods/__init__.py @@ -1,4 +1,4 @@ -## Please use make index to update this file +# Please use make index to update this file all = """ Allocate CreateGid diff --git a/sfa/methods/get_key_from_incoming_ip.py b/sfa/methods/get_key_from_incoming_ip.py index 5887f998..860220bc 100644 --- a/sfa/methods/get_key_from_incoming_ip.py +++ b/sfa/methods/get_key_from_incoming_ip.py @@ -3,6 +3,7 @@ from sfa.util.sfalogging import logger from sfa.storage.parameter import Parameter + class get_key_from_incoming_ip(Method): """ Generate a new keypair and gid for requesting caller (component/node). @@ -11,14 +12,15 @@ class get_key_from_incoming_ip(Method): """ interfaces = ['registry'] - + accepts = [] returns = Parameter(int, "1 if successful, faults otherwise") - + def call(self): - if hasattr(self.api.manager,'get_key_from_incoming_ip'): - return self.api.manager.get_key_from_incoming_ip (api) + if hasattr(self.api.manager, 'get_key_from_incoming_ip'): + return self.api.manager.get_key_from_incoming_ip(api) else: - logger.warning("get_key_from_incoming_ip not supported by registry manager") + logger.warning( + "get_key_from_incoming_ip not supported by registry manager") return 0 diff --git a/sfa/methods/get_trusted_certs.py b/sfa/methods/get_trusted_certs.py index 7a4e1c54..53a687f3 100644 --- a/sfa/methods/get_trusted_certs.py +++ b/sfa/methods/get_trusted_certs.py @@ -5,6 +5,7 @@ from sfa.trust.credential import Credential from sfa.storage.parameter import Parameter, Mixed + class get_trusted_certs(Method): """ @param cred credential string specifying the rights of the caller @@ -12,29 +13,29 @@ class get_trusted_certs(Method): """ interfaces = ['registry', 'aggregate', 'slicemgr'] - + accepts = [ Mixed(Parameter(str, "Credential string"), Parameter(None, "Credential not specified")) - ] + ] returns = Parameter(type([str]), "List of GID strings") - - def call(self, cred = None): + + def call(self, cred=None): # If cred is not specified just return the gid for this interface. # This is true when when a peer is attempting to initiate federation - # with this interface - self.api.logger.debug("get_trusted_certs: %r"%cred) + # with this interface + self.api.logger.debug("get_trusted_certs: %r" % cred) if not cred: gid_strings = [] for gid in self.api.auth.trusted_cert_list: if gid.get_hrn() == self.api.config.SFA_INTERFACE_HRN: - gid_strings.append(gid.save_to_string(save_parents=True)) + gid_strings.append(gid.save_to_string(save_parents=True)) return gid_strings # authenticate the cred self.api.auth.check(cred, 'gettrustedcerts') - gid_strings = [gid.save_to_string(save_parents=True) for \ - gid in self.api.auth.trusted_cert_list] - - return gid_strings + gid_strings = [gid.save_to_string(save_parents=True) for + gid in self.api.auth.trusted_cert_list] + + return gid_strings diff --git a/sfa/nitos/nitosaggregate.py b/sfa/nitos/nitosaggregate.py index 832a2c7c..4201a897 100644 --- a/sfa/nitos/nitosaggregate.py +++ b/sfa/nitos/nitosaggregate.py @@ -24,12 +24,12 @@ from sfa.planetlab.topology import Topology import time + class NitosAggregate: def __init__(self, driver): self.driver = driver - def get_slice_and_slivers(self, slice_xrn): """ Returns a dict of slivers keyed on the sliver's node_id @@ -42,37 +42,38 @@ class NitosAggregate: slice_hrn, _ = urn_to_hrn(slice_xrn) slice_name = hrn_to_nitos_slicename(slice_hrn) slices = self.driver.shell.getSlices({'slice_name': slice_name}, []) - #filter results + # filter results for slc in slices: - if slc['slice_name'] == slice_name: - slice = slc - break + if slc['slice_name'] == slice_name: + slice = slc + break if not slice: return (slice, slivers) - - reserved_nodes = self.driver.shell.getReservedNodes({'slice_id': slice['slice_id']}, []) + + reserved_nodes = self.driver.shell.getReservedNodes( + {'slice_id': slice['slice_id']}, []) reserved_node_ids = [] # filter on the slice for node in reserved_nodes: - if node['slice_id'] == slice['slice_id']: - reserved_node_ids.append(node['node_id']) - #get all the nodes + if node['slice_id'] == slice['slice_id']: + reserved_node_ids.append(node['node_id']) + # get all the nodes all_nodes = self.driver.shell.getNodes({}, []) - + for node in all_nodes: - if node['node_id'] in reserved_node_ids: - slivers[node['node_id']] = node - - return (slice, slivers) - + if node['node_id'] in reserved_node_ids: + slivers[node['node_id']] = node + return (slice, slivers) - def get_nodes(self, slice_xrn, slice=None,slivers=None, options=None): - if slivers is None: slivers={} - if options is None: options={} - # if we are dealing with a slice that has no node just return - # and empty list + def get_nodes(self, slice_xrn, slice=None, slivers=None, options=None): + if slivers is None: + slivers = {} + if options is None: + options = {} + # if we are dealing with a slice that has no node just return + # and empty list if slice_xrn: if not slice or not slivers: return [] @@ -80,34 +81,38 @@ class NitosAggregate: nodes = [slivers[sliver] for sliver in slivers] else: nodes = self.driver.shell.getNodes({}, []) - + # get the granularity in second for the reservation system grain = self.driver.testbedInfo['grain'] #grain = 1800 - rspec_nodes = [] for node in nodes: rspec_node = NodeElement() site_name = self.driver.testbedInfo['name'] - rspec_node['component_id'] = hostname_to_urn(self.driver.hrn, site_name, node['hostname']) + rspec_node['component_id'] = hostname_to_urn( + self.driver.hrn, site_name, node['hostname']) rspec_node['component_name'] = node['hostname'] - rspec_node['component_manager_id'] = Xrn(self.driver.hrn, 'authority+cm').get_urn() - rspec_node['authority_id'] = hrn_to_urn(NitosXrn.site_hrn(self.driver.hrn, site_name), 'authority+sa') + rspec_node['component_manager_id'] = Xrn( + self.driver.hrn, 'authority+cm').get_urn() + rspec_node['authority_id'] = hrn_to_urn( + NitosXrn.site_hrn(self.driver.hrn, site_name), 'authority+sa') # do not include boot state ( element) in the manifest rspec - #if not slice: + # if not slice: # rspec_node['boot_state'] = node['boot_state'] rspec_node['exclusive'] = 'true' # site location longitude = self.driver.testbedInfo['longitude'] - latitude = self.driver.testbedInfo['latitude'] - if longitude and latitude: - location = Location({'longitude': longitude, 'latitude': latitude, 'country': 'unknown'}) + latitude = self.driver.testbedInfo['latitude'] + if longitude and latitude: + location = Location( + {'longitude': longitude, 'latitude': latitude, 'country': 'unknown'}) rspec_node['location'] = location # 3D position - position_3d = Position3D({'x': node['position']['X'], 'y': node['position']['Y'], 'z': node['position']['Z']}) + position_3d = Position3D({'x': node['position']['X'], 'y': node[ + 'position']['Y'], 'z': node['position']['Z']}) #position_3d = Position3D({'x': 1, 'y': 2, 'z': 3}) - rspec_node['position_3d'] = position_3d + rspec_node['position_3d'] = position_3d # Granularity granularity = Granularity({'grain': grain}) rspec_node['granularity'] = granularity @@ -115,8 +120,8 @@ class NitosAggregate: # HardwareType rspec_node['hardware_type'] = node['node_type'] #rspec_node['hardware_type'] = "orbit" - - #slivers + + # slivers if node['node_id'] in slivers: # add sliver info sliver = slivers[node['node_id']] @@ -124,13 +129,13 @@ class NitosAggregate: rspec_node['client_id'] = node['hostname'] rspec_node['slivers'] = [sliver] - rspec_nodes.append(rspec_node) - return rspec_nodes + return rspec_nodes def get_leases_and_channels(self, slice=None, slice_xrn=None, options=None): - if options is None: options={} + if options is None: + options = {} slices = self.driver.shell.getSlices({}, []) nodes = self.driver.shell.getNodes({}, []) leases = self.driver.shell.getReservedNodes({}, []) @@ -147,80 +152,85 @@ class NitosAggregate: all_reserved_channels = [] all_reserved_channels.extend(reserved_channels) for lease in all_leases: - if lease['slice_id'] != slice['slice_id']: - leases.remove(lease) + if lease['slice_id'] != slice['slice_id']: + leases.remove(lease) for channel in all_reserved_channels: - if channel['slice_id'] != slice['slice_id']: - reserved_channels.remove(channel) + if channel['slice_id'] != slice['slice_id']: + reserved_channels.remove(channel) rspec_channels = [] for channel in reserved_channels: - + rspec_channel = {} - #retrieve channel number + # retrieve channel number for chl in channels: - if chl['channel_id'] == channel['channel_id']: - channel_number = chl['channel'] - break + if chl['channel_id'] == channel['channel_id']: + channel_number = chl['channel'] + break rspec_channel['channel_num'] = channel_number rspec_channel['start_time'] = channel['start_time'] - rspec_channel['duration'] = (int(channel['end_time']) - int(channel['start_time'])) / int(grain) - rspec_channel['component_id'] = channel_to_urn(self.driver.hrn, self.driver.testbedInfo['name'], channel_number) - + rspec_channel['duration'] = ( + int(channel['end_time']) - int(channel['start_time'])) / int(grain) + rspec_channel['component_id'] = channel_to_urn( + self.driver.hrn, self.driver.testbedInfo['name'], channel_number) + # retreive slicename for slc in slices: - if slc['slice_id'] == channel['slice_id']: - slicename = slc['slice_name'] - break + if slc['slice_id'] == channel['slice_id']: + slicename = slc['slice_name'] + break if slice_xrn: slice_urn = slice_xrn slice_hrn = urn_to_hrn(slice_urn) else: - slice_hrn = slicename_to_hrn(self.driver.hrn, self.driver.testbedInfo['name'], slicename) + slice_hrn = slicename_to_hrn( + self.driver.hrn, self.driver.testbedInfo['name'], slicename) slice_urn = hrn_to_urn(slice_hrn, 'slice') rspec_channel['slice_id'] = slice_urn rspec_channels.append(rspec_channel) - rspec_leases = [] for lease in leases: rspec_lease = Lease() - + rspec_lease['lease_id'] = lease['reservation_id'] # retreive node name for node in nodes: - if node['node_id'] == lease['node_id']: - nodename = node['hostname'] - break - - rspec_lease['component_id'] = hostname_to_urn(self.driver.hrn, self.driver.testbedInfo['name'], nodename) + if node['node_id'] == lease['node_id']: + nodename = node['hostname'] + break + + rspec_lease['component_id'] = hostname_to_urn( + self.driver.hrn, self.driver.testbedInfo['name'], nodename) # retreive slicename for slc in slices: - if slc['slice_id'] == lease['slice_id']: - slicename = slc['slice_name'] - break - + if slc['slice_id'] == lease['slice_id']: + slicename = slc['slice_name'] + break + if slice_xrn: slice_urn = slice_xrn slice_hrn = urn_to_hrn(slice_urn) else: - slice_hrn = slicename_to_hrn(self.driver.hrn, self.driver.testbedInfo['name'], slicename) + slice_hrn = slicename_to_hrn( + self.driver.hrn, self.driver.testbedInfo['name'], slicename) slice_urn = hrn_to_urn(slice_hrn, 'slice') rspec_lease['slice_id'] = slice_urn rspec_lease['start_time'] = lease['start_time'] - rspec_lease['duration'] = (int(lease['end_time']) - int(lease['start_time'])) / int(grain) + rspec_lease['duration'] = ( + int(lease['end_time']) - int(lease['start_time'])) / int(grain) rspec_leases.append(rspec_lease) return (rspec_leases, rspec_channels) - def get_channels(self, slice=None, options=None): - if options is None: options={} + if options is None: + options = {} all_channels = self.driver.shell.getChannels({}, []) channels = [] @@ -228,12 +238,12 @@ class NitosAggregate: reserved_channels = self.driver.shell.getReservedChannels() reserved_channel_ids = [] for channel in reserved_channels: - if channel['slice_id'] == slice['slice_id']: - reserved_channel_ids.append(channel['channel_id']) + if channel['slice_id'] == slice['slice_id']: + reserved_channel_ids.append(channel['channel_id']) for channel in all_channels: - if channel['channel_id'] in reserved_channel_ids: - channels.append(channel) + if channel['channel_id'] in reserved_channel_ids: + channels.append(channel) else: channels = all_channels @@ -243,48 +253,50 @@ class NitosAggregate: rspec_channel['channel_num'] = channel['channel'] rspec_channel['frequency'] = channel['frequency'] rspec_channel['standard'] = channel['modulation'] - rspec_channel['component_id'] = channel_to_urn(self.driver.hrn, self.driver.testbedInfo['name'], channel['channel']) + rspec_channel['component_id'] = channel_to_urn( + self.driver.hrn, self.driver.testbedInfo['name'], channel['channel']) rspec_channels.append(rspec_channel) return rspec_channels - - - def get_rspec(self, slice_xrn=None, version = None, options=None): - if options is None: options={} + def get_rspec(self, slice_xrn=None, version=None, options=None): + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) if not slice_xrn: - rspec_version = version_manager._get_version(version.type, version.version, 'ad') + rspec_version = version_manager._get_version( + version.type, version.version, 'ad') else: - rspec_version = version_manager._get_version(version.type, version.version, 'manifest') + rspec_version = version_manager._get_version( + version.type, version.version, 'manifest') slice, slivers = self.get_slice_and_slivers(slice_xrn) rspec = RSpec(version=rspec_version, user_options=options) if slice and 'expires' in slice: - rspec.xml.set('expires', datetime_to_string(utcparse(slice['expires']))) + rspec.xml.set('expires', datetime_to_string( + utcparse(slice['expires']))) if not options.get('list_leases') or options.get('list_leases') and options['list_leases'] != 'leases': - nodes = self.get_nodes(slice_xrn, slice, slivers, options) - rspec.version.add_nodes(nodes) - # add sliver defaults - default_sliver = slivers.get(None, []) - if default_sliver: - default_sliver_attribs = default_sliver.get('tags', []) - for attrib in default_sliver_attribs: - logger.info(attrib) - rspec.version.add_default_sliver_attribute(attrib['tagname'], attrib['value']) - # add wifi channels - channels = self.get_channels(slice, options) - rspec.version.add_channels(channels) + nodes = self.get_nodes(slice_xrn, slice, slivers, options) + rspec.version.add_nodes(nodes) + # add sliver defaults + default_sliver = slivers.get(None, []) + if default_sliver: + default_sliver_attribs = default_sliver.get('tags', []) + for attrib in default_sliver_attribs: + logger.info(attrib) + rspec.version.add_default_sliver_attribute( + attrib['tagname'], attrib['value']) + # add wifi channels + channels = self.get_channels(slice, options) + rspec.version.add_channels(channels) if not options.get('list_leases') or options.get('list_leases') and options['list_leases'] != 'resources': - leases_channels = self.get_leases_and_channels(slice, slice_xrn) - rspec.version.add_leases(leases_channels) + leases_channels = self.get_leases_and_channels(slice, slice_xrn) + rspec.version.add_leases(leases_channels) return rspec.toxml() - - diff --git a/sfa/nitos/nitosdriver.py b/sfa/nitos/nitosdriver.py index 8e3da348..8bbf1153 100644 --- a/sfa/nitos/nitosdriver.py +++ b/sfa/nitos/nitosdriver.py @@ -28,79 +28,80 @@ from sfa.nitos.nitosslices import NitosSlices from sfa.nitos.nitosxrn import NitosXrn, slicename_to_hrn, hostname_to_hrn, hrn_to_nitos_slicename, xrn_to_hostname + def list_to_dict(recs, key): """ convert a list of dictionaries into a dictionary keyed on the specified dictionary key """ - return dict ( [ (rec[key],rec) for rec in recs ] ) + return dict([(rec[key], rec) for rec in recs]) # # NitosShell is just an xmlrpc serverproxy where methods # can be sent as-is; it takes care of authentication # from the global config -# +# + + class NitosDriver (Driver): - # the cache instance is a class member so it survives across incoming requests + # the cache instance is a class member so it survives across incoming + # requests cache = None - def __init__ (self, api): - Driver.__init__ (self, api) + def __init__(self, api): + Driver.__init__(self, api) config = api.config - self.shell = NitosShell (config) - self.cache=None + self.shell = NitosShell(config) + self.cache = None self.testbedInfo = self.shell.getTestbedInfo() # un-comment below lines to enable caching # if config.SFA_AGGREGATE_CACHING: # if NitosDriver.cache is None: # NitosDriver.cache = Cache() # self.cache = NitosDriver.cache - + ########################################### - ########## utility methods for NITOS driver + # utility methods for NITOS driver ########################################### - - def filter_nitos_results (self, listo, filters_dict): + def filter_nitos_results(self, listo, filters_dict): """ the Nitos scheduler API does not provide a get result filtring so we do it here """ mylist = [] mylist.extend(listo) for dicto in mylist: - for filter in filters_dict: - if filter not in dicto or dicto[filter] != filters_dict[filter]: - listo.remove(dicto) - break + for filter in filters_dict: + if filter not in dicto or dicto[filter] != filters_dict[filter]: + listo.remove(dicto) + break return listo - def convert_id (self, list_of_dict): + def convert_id(self, list_of_dict): """ convert object id retrived in string format to int format """ for dicto in list_of_dict: - for key in dicto: - if key in ['node_id', 'slice_id', 'user_id', 'channel_id', 'reservation_id'] and isinstance(dicto[key], str): - dicto[key] = int(dicto[key]) - elif key in ['user_ids']: - user_ids2 = [] - for user_id in dicto['user_ids']: - user_ids2.append(int(user_id)) - dicto['user_ids'] = user_ids2 + for key in dicto: + if key in ['node_id', 'slice_id', 'user_id', 'channel_id', 'reservation_id'] and isinstance(dicto[key], str): + dicto[key] = int(dicto[key]) + elif key in ['user_ids']: + user_ids2 = [] + for user_id in dicto['user_ids']: + user_ids2.append(int(user_id)) + dicto['user_ids'] = user_ids2 return list_of_dict - - ######################################## - ########## registry oriented + # registry oriented ######################################## - def augment_records_with_testbed_info (self, sfa_records): - return self.fill_record_info (sfa_records) + def augment_records_with_testbed_info(self, sfa_records): + return self.fill_record_info(sfa_records) - ########## - def register (self, sfa_record, hrn, pub_key): + ########## + def register(self, sfa_record, hrn, pub_key): type = sfa_record['type'] nitos_record = self.sfa_fields_to_nitos_fields(type, hrn, sfa_record) @@ -111,86 +112,90 @@ class NitosDriver (Driver): slices = self.shell.getSlices() # filter slices for slice in slices: - if slice['slice_name'] == nitos_record['name']: - slice_id = slice['slice_id'] - break - + if slice['slice_name'] == nitos_record['name']: + slice_id = slice['slice_id'] + break + if not slice_id: - pointer = self.shell.addSlice({'slice_name' : nitos_record['name']}) + pointer = self.shell.addSlice( + {'slice_name': nitos_record['name']}) else: - pointer = slice_id + pointer = slice_id elif type == 'user': users = self.shell.getUsers() # filter users for user in users: - if user['user_name'] == nitos_record['name']: - user_id = user['user_id'] - break + if user['user_name'] == nitos_record['name']: + user_id = user['user_id'] + break if not user_id: - pointer = self.shell.addUser({'username' : nitos_record['name'], 'email' : nitos_record['email']}) + pointer = self.shell.addUser( + {'username': nitos_record['name'], 'email': nitos_record['email']}) else: pointer = user_id - # Add the user's key if pub_key: - self.shell.addUserKey({'user_id' : pointer,'key' : pub_key}) + self.shell.addUserKey({'user_id': pointer, 'key': pub_key}) elif type == 'node': nodes = self.shell.GetNodes({}, []) # filter nodes for node in nodes: - if node['hostname'] == nitos_record['name']: - node_id = node['node_id'] - break + if node['hostname'] == nitos_record['name']: + node_id = node['node_id'] + break if not node_id: pointer = self.shell.addNode(nitos_record) else: pointer = node_id - + return pointer - + ########## - def update (self, old_sfa_record, new_sfa_record, hrn, new_key): - + def update(self, old_sfa_record, new_sfa_record, hrn, new_key): + pointer = old_sfa_record['pointer'] type = old_sfa_record['type'] - new_nitos_record = self.sfa_fields_to_nitos_fields(type, hrn, new_sfa_record) + new_nitos_record = self.sfa_fields_to_nitos_fields( + type, hrn, new_sfa_record) # new_key implemented for users only - if new_key and type not in [ 'user' ]: + if new_key and type not in ['user']: raise UnknownSfaType(type) if type == "slice": if 'name' in new_sfa_record: - self.shell.updateSlice({'slice_id': pointer, 'fields': {'slice_name': new_sfa_record['name']}}) - + self.shell.updateSlice({'slice_id': pointer, 'fields': { + 'slice_name': new_sfa_record['name']}}) + elif type == "user": update_fields = {} if 'name' in new_sfa_record: update_fields['username'] = new_sfa_record['name'] if 'email' in new_sfa_record: update_fields['email'] = new_sfa_record['email'] - - self.shell.updateUser({'user_id': pointer, 'fields': update_fields}) - + + self.shell.updateUser( + {'user_id': pointer, 'fields': update_fields}) + if new_key: - # needs to be improved - self.shell.addUserKey({'user_id': pointer, 'key': new_key}) - + # needs to be improved + self.shell.addUserKey({'user_id': pointer, 'key': new_key}) + elif type == "node": - self.shell.updateNode({'node_id': pointer, 'fields': new_sfa_record}) + self.shell.updateNode( + {'node_id': pointer, 'fields': new_sfa_record}) return True - ########## - def remove (self, sfa_record): + def remove(self, sfa_record): - type=sfa_record['type'] - pointer=sfa_record['pointer'] + type = sfa_record['type'] + pointer = sfa_record['pointer'] if type == 'user': self.shell.deleteUser({'user_id': pointer}) elif type == 'slice': @@ -199,10 +204,6 @@ class NitosDriver (Driver): self.shell.deleteNode({'node_id': pointer}) return True - - - - ## # Convert SFA fields to NITOS fields for use when registering or updating @@ -212,7 +213,7 @@ class NitosDriver (Driver): def sfa_fields_to_nitos_fields(self, type, hrn, sfa_record): nitos_record = {} - + if type == "slice": nitos_record["slice_name"] = hrn_to_nitos_slicename(hrn) elif type == "node": @@ -241,15 +242,15 @@ class NitosDriver (Driver): Fill in the nitos specific fields of a SFA record. This involves calling the appropriate NITOS API method to retrieve the database record for the object. - + @param record: record to fill in field (in/out param) """ - + # get ids by type - node_ids, slice_ids = [], [] + node_ids, slice_ids = [], [] user_ids, key_ids = [], [] type_map = {'node': node_ids, 'slice': slice_ids, 'user': user_ids} - + for record in records: for type in type_map: if type == record['type']: @@ -259,25 +260,27 @@ class NitosDriver (Driver): nodes, slices, users, keys = {}, {}, {}, {} if node_ids: all_nodes = self.convert_id(self.shell.getNodes({}, [])) - node_list = [node for node in all_nodes if node['node_id'] in node_ids] + node_list = [node for node in all_nodes if node[ + 'node_id'] in node_ids] nodes = list_to_dict(node_list, 'node_id') if slice_ids: all_slices = self.convert_id(self.shell.getSlices({}, [])) - slice_list = [slice for slice in all_slices if slice['slice_id'] in slice_ids] + slice_list = [slice for slice in all_slices if slice[ + 'slice_id'] in slice_ids] slices = list_to_dict(slice_list, 'slice_id') if user_ids: all_users = self.convert_id(self.shell.getUsers()) - user_list = [user for user in all_users if user['user_id'] in user_ids] + user_list = [user for user in all_users if user[ + 'user_id'] in user_ids] users = list_to_dict(user_list, 'user_id') nitos_records = {'node': nodes, 'slice': slices, 'user': users} - # fill record info for record in records: if record['pointer'] == -1: continue - + for type in nitos_records: if record['type'] == type: if record['pointer'] in nitos_records[type]: @@ -286,17 +289,16 @@ class NitosDriver (Driver): # fill in key info if record['type'] == 'user': if record['pointer'] in nitos_records['user']: - record['keys'] = nitos_records['user'][record['pointer']]['keys'] + record['keys'] = nitos_records[ + 'user'][record['pointer']]['keys'] return records - - + def fill_record_hrns(self, records): """ convert nitos ids to hrns """ - # get ids slice_ids, user_ids, node_ids = [], [], [] for record in records: @@ -311,18 +313,20 @@ class NitosDriver (Driver): slices, users, nodes = {}, {}, {} if node_ids: all_nodes = self.convert_id(self.shell.getNodes({}, [])) - node_list = [node for node in all_nodes if node['node_id'] in node_ids] + node_list = [node for node in all_nodes if node[ + 'node_id'] in node_ids] nodes = list_to_dict(node_list, 'node_id') if slice_ids: all_slices = self.convert_id(self.shell.getSlices({}, [])) - slice_list = [slice for slice in all_slices if slice['slice_id'] in slice_ids] + slice_list = [slice for slice in all_slices if slice[ + 'slice_id'] in slice_ids] slices = list_to_dict(slice_list, 'slice_id') if user_ids: all_users = self.convert_id(self.shell.getUsers()) - user_list = [user for user in all_users if user['user_id'] in user_ids] + user_list = [user for user in all_users if user[ + 'user_id'] in user_ids] users = list_to_dict(user_list, 'user_id') - # convert ids to hrns for record in records: # get all relevant data @@ -333,30 +337,33 @@ class NitosDriver (Driver): if pointer == -1: continue if 'user_ids' in record: - usernames = [users[user_id]['username'] for user_id in record['user_ids'] \ - if user_id in users] - user_hrns = [".".join([auth_hrn, testbed_name, username]) for username in usernames] - record['users'] = user_hrns + usernames = [users[user_id]['username'] for user_id in record['user_ids'] + if user_id in users] + user_hrns = [".".join([auth_hrn, testbed_name, username]) + for username in usernames] + record['users'] = user_hrns if 'slice_ids' in record: - slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] \ + slicenames = [slices[slice_id]['slice_name'] for slice_id in record['slice_ids'] if slice_id in slices] - slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames] + slice_hrns = [slicename_to_hrn( + auth_hrn, slicename) for slicename in slicenames] record['slices'] = slice_hrns if 'node_ids' in record: - hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \ + hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] if node_id in nodes] - node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames] + node_hrns = [hostname_to_hrn( + auth_hrn, login_base, hostname) for hostname in hostnames] record['nodes'] = node_hrns if 'expires' in record: date = utcparse(record['expires']) datestring = datetime_to_string(date) - record['expires'] = datestring - - return records - + record['expires'] = datestring + + return records + def fill_record_sfa_info(self, records): - + def startswith(prefix, values): return [value for value in values if value.startswith(prefix)] @@ -364,10 +371,11 @@ class NitosDriver (Driver): user_ids = [] for record in records: user_ids.extend(record.get("user_ids", [])) - + # get the registry records user_list, users = [], {} - user_list = self.api.dbsession().query(RegRecord).filter(RegRecord.pointer.in_(user_ids)).all() + user_list = self.api.dbsession().query(RegRecord).filter( + RegRecord.pointer.in_(user_ids)).all() # create a hrns keyed on the sfa record's pointer. # Its possible for multiple records to have the same pointer so # the dict's value will be a list of hrns. @@ -378,65 +386,73 @@ class NitosDriver (Driver): # get the nitos records nitos_user_list, nitos_users = [], {} nitos_all_users = self.convert_id(self.shell.getUsers()) - nitos_user_list = [user for user in nitos_all_users if user['user_id'] in user_ids] + nitos_user_list = [ + user for user in nitos_all_users if user['user_id'] in user_ids] nitos_users = list_to_dict(nitos_user_list, 'user_id') - # fill sfa info for record in records: if record['pointer'] == -1: - continue + continue sfa_info = {} type = record['type'] - logger.info("fill_record_sfa_info - incoming record typed %s"%type) + logger.info( + "fill_record_sfa_info - incoming record typed %s" % type) if (type == "slice"): # all slice users are researchers record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice') record['researcher'] = [] for user_id in record.get('user_ids', []): hrns = [user.hrn for user in users[user_id]] - record['researcher'].extend(hrns) - + record['researcher'].extend(hrns) + elif (type == "node"): sfa_info['dns'] = record.get("hostname", "") # xxx TODO: URI, LatLong, IP, DNS - + elif (type == "user"): logger.info('setting user.email') sfa_info['email'] = record.get("email", "") sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user') - sfa_info['geni_certificate'] = record['gid'] + sfa_info['geni_certificate'] = record['gid'] # xxx TODO: PostalAddress, Phone record.update(sfa_info) #################### - def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids): - - if subject_type =='slice' and target_type == 'user' and relation_name == 'researcher': - subject=self.shell.getSlices ({'slice_id': subject_id}, [])[0] + def update_relation(self, subject_type, target_type, relation_name, subject_id, target_ids): + + if subject_type == 'slice' and target_type == 'user' and relation_name == 'researcher': + subject = self.shell.getSlices({'slice_id': subject_id}, [])[0] current_target_ids = subject['user_ids'] - add_target_ids = list ( set (target_ids).difference(current_target_ids)) - del_target_ids = list ( set (current_target_ids).difference(target_ids)) - logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id))) + add_target_ids = list( + set(target_ids).difference(current_target_ids)) + del_target_ids = list( + set(current_target_ids).difference(target_ids)) + logger.debug("subject_id = %s (type=%s)" % + (subject_id, type(subject_id))) for target_id in add_target_ids: - self.shell.addUserToSlice ({'user_id': target_id, 'slice_id': subject_id}) - logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id))) + self.shell.addUserToSlice( + {'user_id': target_id, 'slice_id': subject_id}) + logger.debug("add_target_id = %s (type=%s)" % + (target_id, type(target_id))) for target_id in del_target_ids: - logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id))) - self.shell.deleteUserFromSlice ({'user_id': target_id, 'slice_id': subject_id}) + logger.debug("del_target_id = %s (type=%s)" % + (target_id, type(target_id))) + self.shell.deleteUserFromSlice( + {'user_id': target_id, 'slice_id': subject_id}) else: - logger.info('unexpected relation %s to maintain, %s -> %s'%(relation_name,subject_type,target_type)) - + logger.info('unexpected relation %s to maintain, %s -> %s' % + (relation_name, subject_type, target_type)) ######################################## - ########## aggregate oriented + # aggregate oriented ######################################## - def testbed_name (self): return "nitos" + def testbed_name(self): return "nitos" # 'geni_request_rspec_versions' and 'geni_ad_rspec_versions' are mandatory - def aggregate_version (self): + def aggregate_version(self): version_manager = VersionManager() ad_rspec_versions = [] request_rspec_versions = [] @@ -444,14 +460,14 @@ class NitosDriver (Driver): if rspec_version.content_type in ['*', 'ad']: ad_rspec_versions.append(rspec_version.to_dict()) if rspec_version.content_type in ['*', 'request']: - request_rspec_versions.append(rspec_version.to_dict()) + request_rspec_versions.append(rspec_version.to_dict()) return { - 'testbed':self.testbed_name(), + 'testbed': self.testbed_name(), 'geni_request_rspec_versions': request_rspec_versions, 'geni_ad_rspec_versions': ad_rspec_versions, - } + } - def list_slices (self, creds, options): + def list_slices(self, creds, options): # look in cache first if self.cache: slices = self.cache.get('slices') @@ -459,77 +475,86 @@ class NitosDriver (Driver): logger.debug("NitosDriver.list_slices returns from cache") return slices - # get data from db + # get data from db slices = self.shell.getSlices({}, []) testbed_name = self.testbedInfo['name'] - slice_hrns = [slicename_to_hrn(self.hrn, testbed_name, slice['slice_name']) for slice in slices] - slice_urns = [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slice_hrns] + slice_hrns = [slicename_to_hrn(self.hrn, testbed_name, slice[ + 'slice_name']) for slice in slices] + slice_urns = [hrn_to_urn(slice_hrn, 'slice') + for slice_hrn in slice_hrns] # cache the result if self.cache: - logger.debug ("NitosDriver.list_slices stores value in cache") - self.cache.add('slices', slice_urns) - + logger.debug("NitosDriver.list_slices stores value in cache") + self.cache.add('slices', slice_urns) + return slice_urns - + # first 2 args are None in case of resource discovery - def list_resources (self, slice_urn, slice_hrn, creds, options): - cached_requested = options.get('cached', True) + def list_resources(self, slice_urn, slice_hrn, creds, options): + cached_requested = options.get('cached', True) version_manager = VersionManager() # get the rspec's return format from options #rspec_version = version_manager.get_version(options.get('geni_rspec_version')) # rspec's return format for nitos aggregate is version NITOS 1 rspec_version = version_manager.get_version('NITOS 1') version_string = "rspec_%s" % (rspec_version) - - #panos adding the info option to the caching key (can be improved) + + # panos adding the info option to the caching key (can be improved) if options.get('info'): - version_string = version_string + "_"+options.get('info', 'default') + version_string = version_string + \ + "_" + options.get('info', 'default') # Adding the list_leases option to the caching key if options.get('list_leases'): - version_string = version_string + "_"+options.get('list_leases', 'default') + version_string = version_string + "_" + \ + options.get('list_leases', 'default') # Adding geni_available to caching key if options.get('geni_available'): - version_string = version_string + "_" + str(options.get('geni_available')) - + version_string = version_string + "_" + \ + str(options.get('geni_available')) + # look in cache first if cached_requested and self.cache and not slice_hrn: rspec = self.cache.get(version_string) if rspec: - logger.debug("NitosDriver.ListResources: returning cached advertisement") - return rspec - - #panos: passing user-defined options - #print "manager options = ",options + logger.debug( + "NitosDriver.ListResources: returning cached advertisement") + return rspec + + # panos: passing user-defined options + # print "manager options = ",options aggregate = NitosAggregate(self) - rspec = aggregate.get_rspec(slice_xrn=slice_urn, version=rspec_version, - options=options) - + rspec = aggregate.get_rspec(slice_xrn=slice_urn, version=rspec_version, + options=options) + # cache the result if self.cache and not slice_hrn: - logger.debug("NitosDriver.ListResources: stores advertisement in cache") + logger.debug( + "NitosDriver.ListResources: stores advertisement in cache") self.cache.add(version_string, rspec) - + return rspec - - def sliver_status (self, slice_urn, slice_hrn): + + def sliver_status(self, slice_urn, slice_hrn): # find out where this slice is currently running slicename = hrn_to_nitos_slicename(slice_hrn) - + slices = self.shell.getSlices({}, []) # filter slicename - if len(slices) == 0: - raise SliverDoesNotExist("%s (used %s as slicename internally)" % (slice_hrn, slicename)) - + if len(slices) == 0: + raise SliverDoesNotExist( + "%s (used %s as slicename internally)" % (slice_hrn, slicename)) + for slice in slices: - if slice['slice_name'] == slicename: - user_slice = slice - break + if slice['slice_name'] == slicename: + user_slice = slice + break if not user_slice: - raise SliverDoesNotExist("%s (used %s as slicename internally)" % (slice_hrn, slicename)) + raise SliverDoesNotExist( + "%s (used %s as slicename internally)" % (slice_hrn, slicename)) # report about the reserved nodes only reserved_nodes = self.shell.getReservedNodes({}, []) @@ -537,18 +562,15 @@ class NitosDriver (Driver): slice_reserved_nodes = [] for r_node in reserved_nodes: - if r_node['slice_id'] == slice['slice_id']: - for node in nodes: - if node['node_id'] == r_node['node_id']: - slice_reserved_nodes.append(node) - - - + if r_node['slice_id'] == slice['slice_id']: + for node in nodes: + if node['node_id'] == r_node['node_id']: + slice_reserved_nodes.append(node) if len(slice_reserved_nodes) == 0: - raise SliverDoesNotExist("You have not allocated any slivers here") + raise SliverDoesNotExist("You have not allocated any slivers here") -##### continue from here +# continue from here # get login info user = {} keys = [] @@ -556,9 +578,8 @@ class NitosDriver (Driver): users = self.shell.getUsers() # filter users on slice['user_ids'] for usr in users: - if usr['user_id'] in slice['user_ids']: - keys.extend(usr['keys']) - + if usr['user_id'] in slice['user_ids']: + keys.extend(usr['keys']) user.update({'urn': slice_urn, 'login': slice['slice_name'], @@ -566,7 +587,6 @@ class NitosDriver (Driver): 'port': ['22'], 'keys': keys}) - result = {} top_level_status = 'unknown' if slice_reserved_nodes: @@ -575,7 +595,7 @@ class NitosDriver (Driver): result['nitos_gateway_login'] = slice['slice_name'] #result['pl_expires'] = datetime_to_string(utcparse(slice['expires'])) #result['geni_expires'] = datetime_to_string(utcparse(slice['expires'])) - + resources = [] for node in slice_reserved_nodes: res = {} @@ -584,72 +604,83 @@ class NitosDriver (Driver): res['geni_urn'] = sliver_id res['geni_status'] = 'ready' res['geni_error'] = '' - res['users'] = [user] - + res['users'] = [user] + resources.append(res) - + result['geni_status'] = top_level_status result['geni_resources'] = resources - + return result - def create_sliver (self, slice_urn, slice_hrn, creds, rspec_string, users, options): + def create_sliver(self, slice_urn, slice_hrn, creds, rspec_string, users, options): aggregate = NitosAggregate(self) slices = NitosSlices(self) sfa_peer = slices.get_sfa_peer(slice_hrn) - slice_record=None + slice_record = None if users: slice_record = users[0].get('slice_record', {}) - + # parse rspec rspec = RSpec(rspec_string, version='NITOS 1') # ensure slice record exists - slice = slices.verify_slice(slice_hrn, slice_record, sfa_peer, options=options) + slice = slices.verify_slice( + slice_hrn, slice_record, sfa_peer, options=options) # ensure user records exists - users = slices.verify_users(slice_hrn, slice, users, sfa_peer, options=options) - + users = slices.verify_users( + slice_hrn, slice, users, sfa_peer, options=options) + # add/remove leases (nodes and channels) - # a lease in Nitos RSpec case is a reservation of nodes and channels grouped by (slice,timeslot) + # a lease in Nitos RSpec case is a reservation of nodes and channels + # grouped by (slice,timeslot) rspec_requested_leases = rspec.version.get_leases() rspec_requested_nodes = [] rspec_requested_channels = [] for lease in rspec_requested_leases: - if lease['type'] == 'node': - lease.pop('type', None) - rspec_requested_nodes.append(lease) - else: - lease.pop('type', None) - rspec_requested_channels.append(lease) - + if lease['type'] == 'node': + lease.pop('type', None) + rspec_requested_nodes.append(lease) + else: + lease.pop('type', None) + rspec_requested_channels.append(lease) + nodes = slices.verify_slice_leases_nodes(slice, rspec_requested_nodes) - channels = slices.verify_slice_leases_channels(slice, rspec_requested_channels) + channels = slices.verify_slice_leases_channels( + slice, rspec_requested_channels) return aggregate.get_rspec(slice_xrn=slice_urn, version=rspec.version) - def delete_sliver (self, slice_urn, slice_hrn, creds, options): + def delete_sliver(self, slice_urn, slice_hrn, creds, options): slicename = hrn_to_nitos_slicename(slice_hrn) - slices = self.filter_nitos_results(self.shell.getSlices({}, []), {'slice_name': slicename}) + slices = self.filter_nitos_results( + self.shell.getSlices({}, []), {'slice_name': slicename}) if not slices: return 1 slice = slices[0] - slice_reserved_nodes = self.filter_nitos_results(self.shell.getReservedNodes({}, []), {'slice_id': slice['slice_id'] }) - slice_reserved_channels = self.filter_nitos_results(self.shell.getReservedChannels(), {'slice_id': slice['slice_id'] }) + slice_reserved_nodes = self.filter_nitos_results( + self.shell.getReservedNodes({}, []), {'slice_id': slice['slice_id']}) + slice_reserved_channels = self.filter_nitos_results( + self.shell.getReservedChannels(), {'slice_id': slice['slice_id']}) - slice_reserved_nodes_ids = [node['reservation_id'] for node in slice_reserved_nodes] - slice_reserved_channels_ids = [channel['reservation_id'] for channel in slice_reserved_channels] + slice_reserved_nodes_ids = [node['reservation_id'] + for node in slice_reserved_nodes] + slice_reserved_channels_ids = [ + channel['reservation_id'] for channel in slice_reserved_channels] # release all reserved nodes and channels for that slice try: - released_nodes = self.shell.releaseNodes({'reservation_ids': slice_reserved_nodes_ids}) - released_channels = self.shell.releaseChannels({'reservation_ids': slice_reserved_channels_ids}) + released_nodes = self.shell.releaseNodes( + {'reservation_ids': slice_reserved_nodes_ids}) + released_channels = self.shell.releaseChannels( + {'reservation_ids': slice_reserved_channels_ids}) except: pass return 1 - def renew_sliver (self, slice_urn, slice_hrn, creds, expiration_time, options): + def renew_sliver(self, slice_urn, slice_hrn, creds, expiration_time, options): slicename = hrn_to_nitos_slicename(slice_hrn) slices = self.shell.GetSlices({'slicename': slicename}, ['slice_id']) if not slices: @@ -664,22 +695,21 @@ class NitosDriver (Driver): except: return False - # xxx this code is quite old and has not run for ages # it is obviously totally broken and needs a rewrite - def get_ticket (self, slice_urn, slice_hrn, creds, rspec_string, options): + def get_ticket(self, slice_urn, slice_hrn, creds, rspec_string, options): raise SfaNotImplemented("NitosDriver.get_ticket needs a rewrite") # please keep this code for future reference # slices = PlSlices(self) # peer = slices.get_peer(slice_hrn) # sfa_peer = slices.get_sfa_peer(slice_hrn) -# +# # # get the slice record # credential = api.getCredential() # interface = api.registries[api.hrn] # registry = api.server_proxy(interface, credential) # records = registry.Resolve(xrn, credential) -# +# # # make sure we get a local slice record # record = None # for tmp_record in records: @@ -689,13 +719,13 @@ class NitosDriver (Driver): # slice_record = SliceRecord(dict=tmp_record) # if not record: # raise RecordNotFound(slice_hrn) -# +# # # similar to CreateSliver, we must verify that the required records exist # # at this aggregate before we can issue a ticket # # parse rspec # rspec = RSpec(rspec_string) # requested_attributes = rspec.version.get_slice_attributes() -# +# # # ensure site record exists # site = slices.verify_site(slice_hrn, slice_record, peer, sfa_peer) # # ensure slice record exists @@ -705,13 +735,13 @@ class NitosDriver (Driver): # persons = slices.verify_persons(slice_hrn, slice, users, peer, sfa_peer) # # ensure slice attributes exists # slices.verify_slice_attributes(slice, requested_attributes) -# +# # # get sliver info # slivers = slices.get_slivers(slice_hrn) -# +# # if not slivers: # raise SliverDoesNotExist(slice_hrn) -# +# # # get initscripts # initscripts = [] # data = { @@ -719,7 +749,7 @@ class NitosDriver (Driver): # 'initscripts': initscripts, # 'slivers': slivers # } -# +# # # create the ticket # object_gid = record.get_gid_object() # new_ticket = SfaTicket(subject = object_gid.get_subject()) @@ -732,5 +762,5 @@ class NitosDriver (Driver): # #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn)) # new_ticket.encode() # new_ticket.sign() -# +# # return new_ticket.save_to_string(save_parents=True) diff --git a/sfa/nitos/nitosshell.py b/sfa/nitos/nitosshell.py index cf543f29..8142b1e4 100644 --- a/sfa/nitos/nitosshell.py +++ b/sfa/nitos/nitosshell.py @@ -5,38 +5,42 @@ from urlparse import urlparse from sfa.util.sfalogging import logger from sfa.util.py23 import xmlrpc_client + class NitosShell: """ A simple xmlrpc shell to a NITOS Scheduler instance This class can receive all NITOS API calls to the underlying testbed For safety this is limited to a set of hard-coded calls """ - - direct_calls = ['getNodes','getChannels','getSlices','getUsers','getReservedNodes', - 'getReservedChannels','getTestbedInfo', - 'reserveNodes','reserveChannels','addSlice','addUser','addUserToSlice', - 'addUserKey','addNode', 'addChannel', - 'updateReservedNodes','updateReservedChannels','updateSlice','updateUser', + + direct_calls = ['getNodes', 'getChannels', 'getSlices', 'getUsers', 'getReservedNodes', + 'getReservedChannels', 'getTestbedInfo', + 'reserveNodes', 'reserveChannels', 'addSlice', 'addUser', 'addUserToSlice', + 'addUserKey', 'addNode', 'addChannel', + 'updateReservedNodes', 'updateReservedChannels', 'updateSlice', 'updateUser', 'updateNode', 'updateChannel', - 'deleteNode','deleteChannel','deleteSlice','deleteUser', 'deleteUserFromSLice', + 'deleteNode', 'deleteChannel', 'deleteSlice', 'deleteUser', 'deleteUserFromSLice', 'deleteKey', 'releaseNodes', 'releaseChannels' ] - - # use the 'capability' auth mechanism for higher performance when the PLC db is local - def __init__ ( self, config ) : + # use the 'capability' auth mechanism for higher performance when the PLC + # db is local + def __init__(self, config): url = config.SFA_NITOS_URL - self.proxy = xmlrpc_client.ServerProxy(url, verbose = False, allow_none = True) + self.proxy = xmlrpc_client.ServerProxy( + url, verbose=False, allow_none=True) def __getattr__(self, name): def func(*args, **kwds): - actual_name=None - if name in NitosShell.direct_calls: actual_name=name + actual_name = None + if name in NitosShell.direct_calls: + actual_name = name if not actual_name: - raise Exception("Illegal method call %s for NITOS driver"%(name)) + raise Exception( + "Illegal method call %s for NITOS driver" % (name)) actual_name = "scheduler.server." + actual_name - result=getattr(self.proxy, actual_name)(*args, **kwds) - logger.debug('NitosShell %s (%s) returned ... '%(name,actual_name)) + result = getattr(self.proxy, actual_name)(*args, **kwds) + logger.debug('NitosShell %s (%s) returned ... ' % + (name, actual_name)) return result return func - diff --git a/sfa/nitos/nitosslices.py b/sfa/nitos/nitosslices.py index ffdb6e96..46005cce 100644 --- a/sfa/nitos/nitosslices.py +++ b/sfa/nitos/nitosslices.py @@ -8,14 +8,14 @@ from sfa.rspecs.rspec import RSpec from sfa.nitos.nitosxrn import NitosXrn, hrn_to_nitos_slicename, xrn_to_hostname, xrn_to_channel -MAXINT = 2L**31-1 +MAXINT = 2L**31 - 1 + class NitosSlices: def __init__(self, driver): self.driver = driver - def get_sfa_peer(self, xrn): hrn, type = urn_to_hrn(xrn) @@ -31,115 +31,122 @@ class NitosSlices: def verify_slice_leases_nodes(self, slice, rspec_requested_nodes): nodes = self.driver.shell.getNodes({}, []) - + requested_nodes = [] for node in rspec_requested_nodes: - requested_node = {} - nitos_nodes = [] - nitos_nodes.extend(nodes) - slice_name = hrn_to_nitos_slicename(node['slice_id']) - if slice_name != slice['slice_name']: - continue - hostname = xrn_to_hostname(node['component_id']) - nitos_node = self.driver.filter_nitos_results(nitos_nodes, {'hostname': hostname}) - if not nitos_node: - continue - nitos_node = nitos_node[0] - # fill the requested node with nitos ids - requested_node['slice_id'] = slice['slice_id'] - requested_node['node_id'] = nitos_node['node_id'] - requested_node['start_time'] = node['start_time'] - requested_node['end_time'] = str(int(node['duration']) * int(self.driver.testbedInfo['grain']) + int(node['start_time'])) - requested_nodes.append(requested_node) + requested_node = {} + nitos_nodes = [] + nitos_nodes.extend(nodes) + slice_name = hrn_to_nitos_slicename(node['slice_id']) + if slice_name != slice['slice_name']: + continue + hostname = xrn_to_hostname(node['component_id']) + nitos_node = self.driver.filter_nitos_results( + nitos_nodes, {'hostname': hostname}) + if not nitos_node: + continue + nitos_node = nitos_node[0] + # fill the requested node with nitos ids + requested_node['slice_id'] = slice['slice_id'] + requested_node['node_id'] = nitos_node['node_id'] + requested_node['start_time'] = node['start_time'] + requested_node['end_time'] = str(int( + node['duration']) * int(self.driver.testbedInfo['grain']) + int(node['start_time'])) + requested_nodes.append(requested_node) # get actual nodes reservation data for the slice - reserved_nodes = self.driver.filter_nitos_results(self.driver.shell.getReservedNodes({}, []), {'slice_id': slice['slice_id']}) - + reserved_nodes = self.driver.filter_nitos_results( + self.driver.shell.getReservedNodes({}, []), {'slice_id': slice['slice_id']}) + reserved_nodes_by_id = {} for node in reserved_nodes: - reserved_nodes_by_id[node['reservation_id']] = {'slice_id': node['slice_id'], \ - 'node_id': node['node_id'], 'start_time': node['start_time'], \ - 'end_time': node['end_time']} + reserved_nodes_by_id[node['reservation_id']] = {'slice_id': node['slice_id'], + 'node_id': node['node_id'], 'start_time': node['start_time'], + 'end_time': node['end_time']} added_nodes = [] kept_nodes_id = [] deleted_nodes_id = [] for reservation_id in reserved_nodes_by_id: - if reserved_nodes_by_id[reservation_id] not in requested_nodes: - deleted_nodes_id.append(reservation_id) - else: - kept_nodes_id.append(reservation_id) - requested_nodes.remove(reserved_nodes_by_id[reservation_id]) + if reserved_nodes_by_id[reservation_id] not in requested_nodes: + deleted_nodes_id.append(reservation_id) + else: + kept_nodes_id.append(reservation_id) + requested_nodes.remove(reserved_nodes_by_id[reservation_id]) added_nodes = requested_nodes - try: - deleted=self.driver.shell.releaseNodes({'reservation_ids': deleted_nodes_id}) + deleted = self.driver.shell.releaseNodes( + {'reservation_ids': deleted_nodes_id}) for node in added_nodes: - added=self.driver.shell.reserveNodes({'slice_id': slice['slice_id'], 'start_time': node['start_time'], 'end_time': node['end_time'], 'nodes': [node['node_id']]}) + added = self.driver.shell.reserveNodes({'slice_id': slice['slice_id'], 'start_time': node[ + 'start_time'], 'end_time': node['end_time'], 'nodes': [node['node_id']]}) except: logger.log_exc('Failed to add/remove slice leases nodes') return added_nodes - def verify_slice_leases_channels(self, slice, rspec_requested_channels): channels = self.driver.shell.getChannels({}, []) requested_channels = [] for channel in rspec_requested_channels: - requested_channel = {} - nitos_channels = [] - nitos_channels.extend(channels) - slice_name = hrn_to_nitos_slicename(channel['slice_id']) - if slice_name != slice['slice_name']: - continue - channel_num = xrn_to_channel(channel['component_id']) - nitos_channel = self.driver.filter_nitos_results(nitos_channels, {'channel': channel_num})[0] - # fill the requested channel with nitos ids - requested_channel['slice_id'] = slice['slice_id'] - requested_channel['channel_id'] = nitos_channel['channel_id'] - requested_channel['start_time'] = channel['start_time'] - requested_channel['end_time'] = str(int(channel['duration']) * int(self.driver.testbedInfo['grain']) + int(channel['start_time'])) - requested_channels.append(requested_channel) + requested_channel = {} + nitos_channels = [] + nitos_channels.extend(channels) + slice_name = hrn_to_nitos_slicename(channel['slice_id']) + if slice_name != slice['slice_name']: + continue + channel_num = xrn_to_channel(channel['component_id']) + nitos_channel = self.driver.filter_nitos_results( + nitos_channels, {'channel': channel_num})[0] + # fill the requested channel with nitos ids + requested_channel['slice_id'] = slice['slice_id'] + requested_channel['channel_id'] = nitos_channel['channel_id'] + requested_channel['start_time'] = channel['start_time'] + requested_channel['end_time'] = str(int( + channel['duration']) * int(self.driver.testbedInfo['grain']) + int(channel['start_time'])) + requested_channels.append(requested_channel) # get actual channel reservation data for the slice - reserved_channels = self.driver.filter_nitos_results(self.driver.shell.getReservedChannels(), {'slice_id': slice['slice_id']}) - + reserved_channels = self.driver.filter_nitos_results( + self.driver.shell.getReservedChannels(), {'slice_id': slice['slice_id']}) + reserved_channels_by_id = {} for channel in reserved_channels: - reserved_channels_by_id[channel['reservation_id']] = {'slice_id': channel['slice_id'], \ - 'channel_id': channel['channel_id'], 'start_time': channel['start_time'], \ - 'end_time': channel['end_time']} + reserved_channels_by_id[channel['reservation_id']] = {'slice_id': channel['slice_id'], + 'channel_id': channel['channel_id'], 'start_time': channel['start_time'], + 'end_time': channel['end_time']} added_channels = [] kept_channels_id = [] deleted_channels_id = [] for reservation_id in reserved_channels_by_id: - if reserved_channels_by_id[reservation_id] not in requested_channels: - deleted_channels_id.append(reservation_id) - else: - kept_channels_id.append(reservation_id) - requested_channels.remove(reserved_channels_by_id[reservation_id]) + if reserved_channels_by_id[reservation_id] not in requested_channels: + deleted_channels_id.append(reservation_id) + else: + kept_channels_id.append(reservation_id) + requested_channels.remove( + reserved_channels_by_id[reservation_id]) added_channels = requested_channels - try: - deleted=self.driver.shell.releaseChannels({'reservation_ids': deleted_channels_id}) + deleted = self.driver.shell.releaseChannels( + {'reservation_ids': deleted_channels_id}) for channel in added_channels: - added=self.driver.shell.reserveChannels({'slice_id': slice['slice_id'], 'start_time': channel['start_time'], 'end_time': channel['end_time'], 'channels': [channel['channel_id']]}) + added = self.driver.shell.reserveChannels({'slice_id': slice['slice_id'], 'start_time': channel[ + 'start_time'], 'end_time': channel['end_time'], 'channels': [channel['channel_id']]}) except: logger.log_exc('Failed to add/remove slice leases channels') - - return added_channels + return added_channels def free_egre_key(self): used = set() for tag in self.driver.shell.GetSliceTags({'tagname': 'egre_key'}): - used.add(int(tag['value'])) + used.add(int(tag['value'])) for i in range(1, 256): if i not in used: @@ -150,76 +157,84 @@ class NitosSlices: return str(key) - - def verify_slice(self, slice_hrn, slice_record, sfa_peer, options=None): - if options is None: options={} + if options is None: + options = {} slicename = hrn_to_nitos_slicename(slice_hrn) - slices = self.driver.shell.getSlices({}, []) - slices = self.driver.filter_nitos_results(slices, {'slice_name': slicename}) + slices = self.driver.shell.getSlices({}, []) + slices = self.driver.filter_nitos_results( + slices, {'slice_name': slicename}) if not slices: slice = {'slice_name': slicename} - # add the slice + # add the slice slice['slice_id'] = self.driver.shell.addSlice(slice) slice['node_ids'] = [] slice['user_ids'] = [] else: slice = slices[0] - + return slice def verify_users(self, slice_hrn, slice_record, users, sfa_peer, options=None): - if options is None: options={} + if options is None: + options = {} # get slice info slicename = hrn_to_nitos_slicename(slice_hrn) slices = self.driver.shell.getSlices({}, []) - slice = self.driver.filter_nitos_results(slices, {'slice_name': slicename})[0] + slice = self.driver.filter_nitos_results( + slices, {'slice_name': slicename})[0] added_users = [] - #get users info + # get users info users_info = [] for user in users: - user_urn = user['urn'] - user_hrn, type = urn_to_hrn(user_urn) - username = str(user_hrn).split('.')[-1] - email = user['email'] - # look for the user according to his username, email... - nitos_users = self.driver.filter_nitos_results(self.driver.shell.getUsers(), {'username': username}) - if not nitos_users: - nitos_users = self.driver.filter_nitos_results(self.driver.shell.getUsers(), {'email': email}) - - if not nitos_users: - # create the user - user_id = self.driver.shell.addUser({'username': email.split('@')[0], 'email': email}) - added_users.append(user_id) - # add user keys - for key in user['keys']: - self.driver.shell.addUserKey({'user_id': user_id, 'key': key, 'slice_id': slice['slice_id']}) - # add the user to the slice - self.driver.shell.addUserToSlice({'slice_id': slice['slice_id'], 'user_id': user_id}) - else: - # check if the users are in the slice - for user in nitos_users: - if not user['user_id'] in slice['user_ids']: - self.driver.shell.addUserToSlice({'slice_id': slice['slice_id'], 'user_id': user['user_id']}) + user_urn = user['urn'] + user_hrn, type = urn_to_hrn(user_urn) + username = str(user_hrn).split('.')[-1] + email = user['email'] + # look for the user according to his username, email... + nitos_users = self.driver.filter_nitos_results( + self.driver.shell.getUsers(), {'username': username}) + if not nitos_users: + nitos_users = self.driver.filter_nitos_results( + self.driver.shell.getUsers(), {'email': email}) + + if not nitos_users: + # create the user + user_id = self.driver.shell.addUser( + {'username': email.split('@')[0], 'email': email}) + added_users.append(user_id) + # add user keys + for key in user['keys']: + self.driver.shell.addUserKey( + {'user_id': user_id, 'key': key, 'slice_id': slice['slice_id']}) + # add the user to the slice + self.driver.shell.addUserToSlice( + {'slice_id': slice['slice_id'], 'user_id': user_id}) + else: + # check if the users are in the slice + for user in nitos_users: + if not user['user_id'] in slice['user_ids']: + self.driver.shell.addUserToSlice( + {'slice_id': slice['slice_id'], 'user_id': user['user_id']}) return added_users - def verify_keys(self, persons, users, options=None): - if options is None: options={} - # existing keys + if options is None: + options = {} + # existing keys key_ids = [] for person in persons: key_ids.extend(person['key_ids']) keylist = self.driver.shell.GetKeys(key_ids, ['key_id', 'key']) keydict = {} for key in keylist: - keydict[key['key']] = key['key_id'] + keydict[key['key']] = key['key_id'] existing_keys = keydict.keys() persondict = {} for person in persons: - persondict[person['email']] = person - + persondict[person['email']] = person + # add new keys requested_keys = [] updated_persons = [] @@ -233,28 +248,31 @@ class NitosSlices: try: if peer: person = persondict[user['email']] - self.driver.shell.UnBindObjectFromPeer('person', person['person_id'], peer['shortname']) - key['key_id'] = self.driver.shell.AddPersonKey(user['email'], key) + self.driver.shell.UnBindObjectFromPeer( + 'person', person['person_id'], peer['shortname']) + key['key_id'] = self.driver.shell.AddPersonKey( + user['email'], key) if peer: key_index = user_keys.index(key['key']) remote_key_id = user['key_ids'][key_index] - self.driver.shell.BindObjectToPeer('key', key['key_id'], peer['shortname'], remote_key_id) - + self.driver.shell.BindObjectToPeer('key', key['key_id'], peer[ + 'shortname'], remote_key_id) + finally: if peer: - self.driver.shell.BindObjectToPeer('person', person['person_id'], peer['shortname'], user['person_id']) - + self.driver.shell.BindObjectToPeer('person', person['person_id'], peer[ + 'shortname'], user['person_id']) + # remove old keys (only if we are not appending) append = options.get('append', True) - if append == False: + if append == False: removed_keys = set(existing_keys).difference(requested_keys) for existing_key_id in keydict: if keydict[existing_key_id] in removed_keys: try: if peer: - self.driver.shell.UnBindObjectFromPeer('key', existing_key_id, peer['shortname']) + self.driver.shell.UnBindObjectFromPeer( + 'key', existing_key_id, peer['shortname']) self.driver.shell.DeleteKey(existing_key_id) except: - pass - - + pass diff --git a/sfa/nitos/nitosxrn.py b/sfa/nitos/nitosxrn.py index 9722ffed..375a75fe 100644 --- a/sfa/nitos/nitosxrn.py +++ b/sfa/nitos/nitosxrn.py @@ -5,82 +5,106 @@ import re from sfa.util.xrn import Xrn # temporary helper functions to use this module instead of namespace -def hostname_to_hrn (auth, login_base, hostname): - return NitosXrn(auth=auth+'.'+login_base,hostname=hostname).get_hrn() + + +def hostname_to_hrn(auth, login_base, hostname): + return NitosXrn(auth=auth + '.' + login_base, hostname=hostname).get_hrn() + + def hostname_to_urn(auth, login_base, hostname): - return NitosXrn(auth=auth+'.'+login_base,hostname=hostname).get_urn() -def slicename_to_hrn (auth_hrn,site_name,slicename): - return NitosXrn(auth=auth_hrn+'.'+site_name,slicename=slicename).get_hrn() + return NitosXrn(auth=auth + '.' + login_base, hostname=hostname).get_urn() + + +def slicename_to_hrn(auth_hrn, site_name, slicename): + return NitosXrn(auth=auth_hrn + '.' + site_name, slicename=slicename).get_hrn() # hack to convert nitos user name to hrn -def username_to_hrn (auth_hrn,site_name,username): - return NitosXrn(auth=auth_hrn+'.'+site_name,slicename=username).get_hrn() -def email_to_hrn (auth_hrn, email): + + +def username_to_hrn(auth_hrn, site_name, username): + return NitosXrn(auth=auth_hrn + '.' + site_name, slicename=username).get_hrn() + + +def email_to_hrn(auth_hrn, email): return NitosXrn(auth=auth_hrn, email=email).get_hrn() -def hrn_to_nitos_slicename (hrn): - return NitosXrn(xrn=hrn,type='slice').nitos_slicename() + + +def hrn_to_nitos_slicename(hrn): + return NitosXrn(xrn=hrn, type='slice').nitos_slicename() # removed-dangerous - was used for non-slice objects -#def hrn_to_nitos_login_base (hrn): +# def hrn_to_nitos_login_base (hrn): # return NitosXrn(xrn=hrn,type='slice').nitos_login_base() -def hrn_to_nitos_authname (hrn): - return NitosXrn(xrn=hrn,type='any').nitos_authname() + + +def hrn_to_nitos_authname(hrn): + return NitosXrn(xrn=hrn, type='any').nitos_authname() + + def xrn_to_hostname(hrn): return Xrn.unescape(NitosXrn(xrn=hrn, type='node').get_leaf()) -def channel_to_hrn (auth, login_base, channel): - return NitosXrn(auth=auth+'.'+login_base, channel=channel).get_hrn() -def channel_to_urn (auth, login_base, channel): - return NitosXrn(auth=auth+'.'+login_base, channel=channel).get_urn() + + +def channel_to_hrn(auth, login_base, channel): + return NitosXrn(auth=auth + '.' + login_base, channel=channel).get_hrn() + + +def channel_to_urn(auth, login_base, channel): + return NitosXrn(auth=auth + '.' + login_base, channel=channel).get_urn() + + def xrn_to_channel(hrn): return Xrn.unescape(NitosXrn(xrn=hrn, type='channel').get_leaf()) + class NitosXrn (Xrn): - @staticmethod - def site_hrn (auth, login_base): - return '.'.join([auth,login_base]) + @staticmethod + def site_hrn(auth, login_base): + return '.'.join([auth, login_base]) - def __init__ (self, auth=None, hostname=None, slicename=None, email=None, interface=None, channel=None, **kwargs): - #def hostname_to_hrn(auth_hrn, login_base, hostname): + def __init__(self, auth=None, hostname=None, slicename=None, email=None, interface=None, channel=None, **kwargs): + # def hostname_to_hrn(auth_hrn, login_base, hostname): if hostname is not None: - self.type='node' + self.type = 'node' # keep only the first part of the DNS name #self.hrn='.'.join( [auth,hostname.split(".")[0] ] ) # escape the '.' in the hostname - self.hrn='.'.join( [auth,Xrn.escape(hostname)] ) + self.hrn = '.'.join([auth, Xrn.escape(hostname)]) self.hrn_to_urn() - #def slicename_to_hrn(auth_hrn, slicename): + # def slicename_to_hrn(auth_hrn, slicename): elif slicename is not None: - self.type='slice' + self.type = 'slice' self.hrn = ".".join([auth] + [slicename.replace(".", "_")]) self.hrn_to_urn() - #def email_to_hrn(auth_hrn, email): + # def email_to_hrn(auth_hrn, email): elif email is not None: - self.type='person' + self.type = 'person' # keep only the part before '@' and replace special chars into _ - self.hrn='.'.join([auth,email.split('@')[0].replace(".", "_").replace("+", "_")]) + self.hrn = '.'.join( + [auth, email.split('@')[0].replace(".", "_").replace("+", "_")]) self.hrn_to_urn() elif interface is not None: self.type = 'interface' self.hrn = auth + '.' + interface self.hrn_to_urn() elif channel is not None: - self.type='channel' - self.hrn=".".join([auth] + [channel]) + self.type = 'channel' + self.hrn = ".".join([auth] + [channel]) self.hrn_to_urn() else: - Xrn.__init__ (self,**kwargs) + Xrn.__init__(self, **kwargs) - #def hrn_to_pl_slicename(hrn): - def nitos_slicename (self): + # def hrn_to_pl_slicename(hrn): + def nitos_slicename(self): self._normalize() leaf = self.leaf sliver_id_parts = leaf.split(':') name = sliver_id_parts[0] name = re.sub('[^a-zA-Z0-9_]', '', name) - #return self.nitos_login_base() + '_' + name + # return self.nitos_login_base() + '_' + name return name - #def hrn_to_pl_authname(hrn): - def nitos_authname (self): + # def hrn_to_pl_authname(hrn): + def nitos_authname(self): self._normalize() return self.authority[-1] @@ -88,20 +112,20 @@ class NitosXrn (Xrn): self._normalize() return self.leaf - def nitos_login_base (self): + def nitos_login_base(self): self._normalize() if self.type and self.type.startswith('authority'): - base = self.leaf + base = self.leaf else: base = self.authority[-1] - + # Fix up names of GENI Federates base = base.lower() base = re.sub('\\\[^a-zA-Z0-9]', '', base) if len(base) > 20: - base = base[len(base)-20:] - + base = base[len(base) - 20:] + return base @@ -110,6 +134,6 @@ if __name__ == '__main__': #nitosxrn = NitosXrn(auth="omf.nitos",slicename="aminesl") #slice_hrn = nitosxrn.get_hrn() #slice_name = NitosXrn(xrn="omf.nitos.aminesl",type='slice').nitos_slicename() - slicename = "giorgos_n" - hrn = slicename_to_hrn("pla", "nitos", slicename) - print(hrn) + slicename = "giorgos_n" + hrn = slicename_to_hrn("pla", "nitos", slicename) + print(hrn) diff --git a/sfa/openstack/client.py b/sfa/openstack/client.py index 215d3330..40a3f85d 100644 --- a/sfa/openstack/client.py +++ b/sfa/openstack/client.py @@ -14,7 +14,7 @@ def parse_novarc(filename): parts = line.split('=') if len(parts) > 1: value = parts[1].replace("\'", "") - value = value.replace('\"', '') + value = value.replace('\"', '') opts[parts[0]] = value except: pass @@ -23,6 +23,7 @@ def parse_novarc(filename): class KeystoneClient: + def __init__(self, username=None, password=None, tenant=None, url=None, config=None): if not config: config = Config() @@ -35,33 +36,41 @@ class KeystoneClient: opts['OS_TENANT_NAME'] = tenant if url: opts['OS_AUTH_URL'] = url - self.opts = opts + self.opts = opts self.client = keystone_client.Client(username=opts.get('OS_USERNAME'), password=opts.get('OS_PASSWORD'), - tenant_name=opts.get('OS_TENANT_NAME'), + tenant_name=opts.get( + 'OS_TENANT_NAME'), auth_url=opts.get('OS_AUTH_URL')) def connect(self, *args, **kwds): self.__init__(*args, **kwds) - + def __getattr__(self, name): - return getattr(self.client, name) + return getattr(self.client, name) class GlanceClient: + def __init__(self, config=None): if not config: config = Config() opts = parse_novarc(config.SFA_NOVA_NOVARC) self.client = glance_client.get_client(host='0.0.0.0', - username=opts.get('OS_USERNAME'), - password=opts.get('OS_PASSWORD'), - tenant=opts.get('OS_TENANT_NAME'), + username=opts.get( + 'OS_USERNAME'), + password=opts.get( + 'OS_PASSWORD'), + tenant=opts.get( + 'OS_TENANT_NAME'), auth_url=opts.get('OS_AUTH_URL')) + def __getattr__(self, name): return getattr(self.client, name) + class NovaClient: + def __init__(self, username=None, password=None, tenant=None, url=None, config=None): if not config: config = Config() @@ -82,11 +91,11 @@ class NovaClient: region_name='', extensions=[], service_type='compute', - service_name='', + service_name='', ) def connect(self, *args, **kwds): self.__init__(*args, **kwds) - + def __getattr__(self, name): - return getattr(self.client, name) + return getattr(self.client, name) diff --git a/sfa/openstack/euca_shell.py b/sfa/openstack/euca_shell.py index e2bdf7f8..90c22225 100644 --- a/sfa/openstack/euca_shell.py +++ b/sfa/openstack/euca_shell.py @@ -2,18 +2,19 @@ try: import boto from boto.ec2.regioninfo import RegionInfo from boto.exception import EC2ResponseError - has_boto=True + has_boto = True except: - has_boto=False + has_boto = False from sfa.util.sfalogging import logger from sfa.openstack.nova_shell import NovaShell from sfa.util.config import Config + class EucaShell: """ A xmlrpc connection to the euca api. - """ + """ def __init__(self, config): self.config = config @@ -22,16 +23,18 @@ class EucaShell: self.secret_key = None def init_context(self, project_name=None): - + # use the context of the specified project's project - # manager. + # manager. if project_name: project = self.nova_shell.auth_manager.get_project(project_name) - self.access_key = "%s:%s" % (project.project_manager.name, project_name) + self.access_key = "%s:%s" % ( + project.project_manager.name, project_name) self.secret_key = project.project_manager.secret else: # use admin user's context - admin_user = self.nova_shell.auth_manager.get_user(self.config.SFA_NOVA_USER) + admin_user = self.nova_shell.auth_manager.get_user( + self.config.SFA_NOVA_USER) #access_key = admin_user.access self.access_key = '%s' % admin_user.name self.secret_key = admin_user.secret @@ -43,18 +46,18 @@ class EucaShell: if not self.access_key or not self.secret_key: self.init_context(project_name) - + url = self.config.SFA_NOVA_API_URL host = None - port = None + port = None path = "/" use_ssl = False - # Split the url into parts + # Split the url into parts if url.find('https://') >= 0: - use_ssl = True + use_ssl = True url = url.replace('https://', '') elif url.find('http://') >= 0: - use_ssl = False + use_ssl = False url = url.replace('http://', '') parts = url.split(':') host = parts[0] @@ -62,16 +65,15 @@ class EucaShell: parts = parts[1].split('/') port = int(parts[0]) parts = parts[1:] - path = '/'+'/'.join(parts) + path = '/' + '/'.join(parts) return boto.connect_ec2(aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, is_secure=use_ssl, region=RegionInfo(None, 'eucalyptus', host), host=host, port=port, - path=path) + path=path) def __getattr__(self, name): def func(*args, **kwds): conn = self.get_euca_connection() - diff --git a/sfa/openstack/image.py b/sfa/openstack/image.py index 555b1b9b..4e511df1 100644 --- a/sfa/openstack/image.py +++ b/sfa/openstack/image.py @@ -3,9 +3,10 @@ from sfa.rspecs.elements.disk_image import DiskImage class Image: - + def __init__(self, image=None): - if image is None: image={} + if image is None: + image = {} self.id = None self.container_format = None self.kernel_id = None @@ -21,22 +22,23 @@ class Image: def parse_image(self, image): if isinstance(image, dict): - self.id = image['id'] + self.id = image['id'] self.name = image['name'] self.container_format = image['container_format'] - self.properties = image['properties'] + self.properties = image['properties'] if 'kernel_id' in self.properties: self.kernel_id = self.properties['kernel_id'] if 'ramdisk_id' in self.properties: self.ramdisk_id = self.properties['ramdisk_id'] - + def to_rspec_object(self): img = DiskImage() img['name'] = self.name img['description'] = self.name img['os'] = self.name img['version'] = self.name - return img + return img + class ImageManager: @@ -68,7 +70,5 @@ class ImageManager: elif name: image = self.driver.shell.nova_manager.images.find(name=name) except ImageNotFound: - pass + pass return Image(image) - - diff --git a/sfa/openstack/nova_driver.py b/sfa/openstack/nova_driver.py index 39ea2f91..31754803 100644 --- a/sfa/openstack/nova_driver.py +++ b/sfa/openstack/nova_driver.py @@ -7,7 +7,7 @@ from sfa.util.faults import MissingSfaInfo, UnknownSfaType, \ from sfa.util.sfalogging import logger from sfa.util.defaultdict import defaultdict from sfa.util.sfatime import utcparse, datetime_to_string, datetime_to_epoch -from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf +from sfa.util.xrn import Xrn, hrn_to_urn, get_leaf from sfa.openstack.osxrn import OSXrn, hrn_to_os_slicename, hrn_to_os_tenant_name from sfa.util.cache import Cache from sfa.trust.credential import Credential @@ -23,28 +23,32 @@ from sfa.openstack.shell import Shell from sfa.openstack.osaggregate import OSAggregate from sfa.planetlab.plslices import PlSlices + def list_to_dict(recs, key): """ convert a list of dictionaries into a dictionary keyed on the specified dictionary key """ - return dict ( [ (rec[key],rec) for rec in recs ] ) + return dict([(rec[key], rec) for rec in recs]) # # PlShell is just an xmlrpc serverproxy where methods # can be sent as-is; it takes care of authentication # from the global config -# +# + + class NovaDriver(Driver): - # the cache instance is a class member so it survives across incoming requests + # the cache instance is a class member so it survives across incoming + # requests cache = None - def __init__ (self, api): + def __init__(self, api): Driver.__init__(self, api) config = api.config self.shell = Shell(config=config) - self.cache=None + self.cache = None if config.SFA_AGGREGATE_CACHING: if NovaDriver.cache is None: NovaDriver.cache = Cache() @@ -54,7 +58,8 @@ class NovaDriver(Driver): sliver_id_parts = Xrn(xrn).get_sliver_id_parts() slice = self.shell.auth_manager.tenants.find(id=sliver_id_parts[0]) if not slice: - raise Forbidden("Unable to locate slice record for sliver: %s" % xrn) + raise Forbidden( + "Unable to locate slice record for sliver: %s" % xrn) slice_xrn = OSXrn(name=slice.name, type='slice') return slice_xrn @@ -72,11 +77,11 @@ class NovaDriver(Driver): slice_ids.append(sliver_id_parts[0]) if not slice_ids: - raise Forbidden("sliver urn not provided") + raise Forbidden("sliver urn not provided") sliver_names = [] for slice_id in slice_ids: - slice = self.shell.auth_manager.tenants.find(slice_id) + slice = self.shell.auth_manager.tenants.find(slice_id) sliver_names.append(slice['name']) # make sure we have a credential for every specified sliver ierd @@ -84,32 +89,32 @@ class NovaDriver(Driver): if sliver_name not in slice_cred_names: msg = "Valid credential not found for target: %s" % sliver_name raise Forbidden(msg) - + ######################################## - ########## registry oriented + # registry oriented ######################################## - ########## disabled users - def is_enabled (self, record): + # disabled users + def is_enabled(self, record): # all records are enabled return True - def augment_records_with_testbed_info (self, sfa_records): - return self.fill_record_info (sfa_records) + def augment_records_with_testbed_info(self, sfa_records): + return self.fill_record_info(sfa_records) + + ########## + def register(self, sfa_record, hrn, pub_key): - ########## - def register (self, sfa_record, hrn, pub_key): - if sfa_record['type'] == 'slice': - record = self.register_slice(sfa_record, hrn) + record = self.register_slice(sfa_record, hrn) elif sfa_record['type'] == 'user': record = self.register_user(sfa_record, hrn, pub_key) - elif sfa_record['type'].startswith('authority'): + elif sfa_record['type'].startswith('authority'): record = self.register_authority(sfa_record, hrn) # We should be returning the records id as a pointer but - # this is a string and the records table expects this to be an + # this is a string and the records table expects this to be an # int. - #return record.id + # return record.id return -1 def register_slice(self, sfa_record, hrn): @@ -119,40 +124,46 @@ class NovaDriver(Driver): self.shell.auth_manager.tenants.create(name, description) tenant = self.shell.auth_manager.tenants.find(name=name) auth_hrn = OSXrn(xrn=hrn, type='slice').get_authority_hrn() - parent_tenant_name = OSXrn(xrn=auth_hrn, type='slice').get_tenant_name() - parent_tenant = self.shell.auth_manager.tenants.find(name=parent_tenant_name) + parent_tenant_name = OSXrn( + xrn=auth_hrn, type='slice').get_tenant_name() + parent_tenant = self.shell.auth_manager.tenants.find( + name=parent_tenant_name) researchers = sfa_record.get('researchers', []) for researcher in researchers: name = Xrn(researcher).get_leaf() user = self.shell.auth_manager.users.find(name=name) self.shell.auth_manager.roles.add_user_role(user, 'Member', tenant) self.shell.auth_manager.roles.add_user_role(user, 'user', tenant) - pis = sfa_record.get('pis', []) for pi in pis: name = Xrn(pi).get_leaf() user = self.shell.auth_manager.users.find(name=name) self.shell.auth_manager.roles.add_user_role(user, 'pi', tenant) - self.shell.auth_manager.roles.add_user_role(user, 'pi', parent_tenant) + self.shell.auth_manager.roles.add_user_role( + user, 'pi', parent_tenant) return tenant - + def register_user(self, sfa_record, hrn, pub_key): # add person roles, projects and keys email = sfa_record.get('email', None) xrn = Xrn(hrn) name = xrn.get_leaf() auth_hrn = xrn.get_authority_hrn() - tenant_name = OSXrn(xrn=auth_hrn, type='authority').get_tenant_name() - tenant = self.shell.auth_manager.tenants.find(name=tenant_name) - self.shell.auth_manager.users.create(name, email=email, tenant_id=tenant.id) + tenant_name = OSXrn(xrn=auth_hrn, type='authority').get_tenant_name() + tenant = self.shell.auth_manager.tenants.find(name=tenant_name) + self.shell.auth_manager.users.create( + name, email=email, tenant_id=tenant.id) user = self.shell.auth_manager.users.find(name=name) slices = sfa_records.get('slices', []) for slice in projects: - slice_tenant_name = OSXrn(xrn=slice, type='slice').get_tenant_name() - slice_tenant = self.shell.auth_manager.tenants.find(name=slice_tenant_name) - self.shell.auth_manager.roles.add_user_role(user, slice_tenant, 'user') + slice_tenant_name = OSXrn( + xrn=slice, type='slice').get_tenant_name() + slice_tenant = self.shell.auth_manager.tenants.find( + name=slice_tenant_name) + self.shell.auth_manager.roles.add_user_role( + user, slice_tenant, 'user') keys = sfa_records.get('keys', []) for key in keys: keyname = OSXrn(xrn=hrn, type='user').get_slicename() @@ -161,18 +172,19 @@ class NovaDriver(Driver): def register_authority(self, sfa_record, hrn): name = OSXrn(xrn=hrn, type='authority').get_tenant_name() - self.shell.auth_manager.tenants.create(name, sfa_record.get('description', '')) + self.shell.auth_manager.tenants.create( + name, sfa_record.get('description', '')) tenant = self.shell.auth_manager.tenants.find(name=name) return tenant - - + ########## - # xxx actually old_sfa_record comes filled with plc stuff as well in the original code - def update (self, old_sfa_record, new_sfa_record, hrn, new_key): - type = new_sfa_record['type'] - + # xxx actually old_sfa_record comes filled with plc stuff as well in the + # original code + def update(self, old_sfa_record, new_sfa_record, hrn, new_key): + type = new_sfa_record['type'] + # new_key implemented for users only - if new_key and type not in [ 'user' ]: + if new_key and type not in ['user']: raise UnknownSfaType(type) elif type == "slice": @@ -186,29 +198,28 @@ class NovaDriver(Driver): project_manager = Xrn(pis[0], 'user').get_leaf() elif researchers: project_manager = Xrn(researchers[0], 'user').get_leaf() - self.shell.auth_manager.modify_project(name, project_manager, description) + self.shell.auth_manager.modify_project( + name, project_manager, description) elif type == "user": # can techinally update access_key and secret_key, - # but that is not in our scope, so we do nothing. + # but that is not in our scope, so we do nothing. pass return True - ########## - def remove (self, sfa_record): - type=sfa_record['type'] + def remove(self, sfa_record): + type = sfa_record['type'] if type == 'user': - name = Xrn(sfa_record['hrn']).get_leaf() + name = Xrn(sfa_record['hrn']).get_leaf() if self.shell.auth_manager.get_user(name): self.shell.auth_manager.delete_user(name) elif type == 'slice': - name = hrn_to_os_slicename(sfa_record['hrn']) + name = hrn_to_os_slicename(sfa_record['hrn']) if self.shell.auth_manager.get_project(name): self.shell.auth_manager.delete_project(name) return True - #################### def fill_record_info(self, records): """ @@ -228,12 +239,12 @@ class NovaDriver(Driver): else: continue record['geni_urn'] = hrn_to_urn(record['hrn'], record['type']) - record['geni_certificate'] = record['gid'] - #if os_record.created_at is not None: + record['geni_certificate'] = record['gid'] + # if os_record.created_at is not None: # record['date_created'] = datetime_to_string(utcparse(os_record.created_at)) - #if os_record.updated_at is not None: + # if os_record.updated_at is not None: # record['last_updated'] = datetime_to_string(utcparse(os_record.updated_at)) - + return records def fill_user_record_info(self, record): @@ -246,14 +257,14 @@ class NovaDriver(Driver): slices = [] all_tenants = self.shell.auth_manager.tenants.list() for tmp_tenant in all_tenants: - if tmp_tenant.name.startswith(tenant.name +"."): + if tmp_tenant.name.startswith(tenant.name + "."): for tmp_user in tmp_tenant.list_users(): if tmp_user.name == user.name: - slice_hrn = ".".join([self.hrn, tmp_tenant.name]) - slices.append(slice_hrn) + slice_hrn = ".".join([self.hrn, tmp_tenant.name]) + slices.append(slice_hrn) record['slices'] = slices roles = self.shell.auth_manager.roles.roles_for_user(user, tenant) - record['roles'] = [role.name for role in roles] + record['roles'] = [role.name for role in roles] keys = self.shell.nova_manager.keypairs.findall(name=record['hrn']) record['keys'] = [key.public_key for key in keys] return record @@ -262,7 +273,8 @@ class NovaDriver(Driver): tenant_name = hrn_to_os_tenant_name(record['hrn']) tenant = self.shell.auth_manager.tenants.find(name=tenant_name) parent_tenant_name = OSXrn(xrn=tenant_name).get_authority_hrn() - parent_tenant = self.shell.auth_manager.tenants.find(name=parent_tenant_name) + parent_tenant = self.shell.auth_manager.tenants.find( + name=parent_tenant_name) researchers = [] pis = [] @@ -270,11 +282,13 @@ class NovaDriver(Driver): for user in tenant.list_users(): for role in self.shell.auth_manager.roles.roles_for_user(user, tenant): if role.name.lower() == 'pi': - user_tenant = self.shell.auth_manager.tenants.find(id=user.tenantId) + user_tenant = self.shell.auth_manager.tenants.find( + id=user.tenantId) hrn = ".".join([self.hrn, user_tenant.name, user.name]) pis.append(hrn) elif role.name.lower() in ['user', 'member']: - user_tenant = self.shell.auth_manager.tenants.find(id=user.tenantId) + user_tenant = self.shell.auth_manager.tenants.find( + id=user.tenantId) hrn = ".".join([self.hrn, user_tenant.name, user.name]) researchers.append(hrn) @@ -282,7 +296,8 @@ class NovaDriver(Driver): for user in parent_tenant.list_users(): for role in self.shell.auth_manager.roles.roles_for_user(user, parent_tenant): if role.name.lower() == 'pi': - user_tenant = self.shell.auth_manager.tenants.find(id=user.tenantId) + user_tenant = self.shell.auth_manager.tenants.find( + id=user.tenantId) hrn = ".".join([self.hrn, user_tenant.name, user.name]) pis.append(hrn) record['name'] = tenant_name @@ -312,10 +327,10 @@ class NovaDriver(Driver): # look for slices slices = [] - all_tenants = self.shell.auth_manager.tenants.list() + all_tenants = self.shell.auth_manager.tenants.list() for tmp_tenant in all_tenants: - if tmp_tenant.name.startswith(tenant.name+"."): - slices.append(".".join([self.hrn, tmp_tenant.name])) + if tmp_tenant.name.startswith(tenant.name + "."): + slices.append(".".join([self.hrn, tmp_tenant.name])) record['name'] = tenant_name record['description'] = tenant.description @@ -327,56 +342,65 @@ class NovaDriver(Driver): #################### # plcapi works by changes, compute what needs to be added/deleted - def update_relation (self, subject_type, target_type, subject_id, target_ids): + def update_relation(self, subject_type, target_type, subject_id, target_ids): # hard-wire the code for slice/user for now, could be smarter if needed - if subject_type =='slice' and target_type == 'user': - subject=self.shell.project_get(subject_id)[0] + if subject_type == 'slice' and target_type == 'user': + subject = self.shell.project_get(subject_id)[0] current_target_ids = [user.name for user in subject.members] - add_target_ids = list ( set (target_ids).difference(current_target_ids)) - del_target_ids = list ( set (current_target_ids).difference(target_ids)) - logger.debug ("subject_id = %s (type=%s)"%(subject_id,type(subject_id))) + add_target_ids = list( + set(target_ids).difference(current_target_ids)) + del_target_ids = list( + set(current_target_ids).difference(target_ids)) + logger.debug("subject_id = %s (type=%s)" % + (subject_id, type(subject_id))) for target_id in add_target_ids: - self.shell.project_add_member(target_id,subject_id) - logger.debug ("add_target_id = %s (type=%s)"%(target_id,type(target_id))) + self.shell.project_add_member(target_id, subject_id) + logger.debug("add_target_id = %s (type=%s)" % + (target_id, type(target_id))) for target_id in del_target_ids: - logger.debug ("del_target_id = %s (type=%s)"%(target_id,type(target_id))) + logger.debug("del_target_id = %s (type=%s)" % + (target_id, type(target_id))) self.shell.project_remove_member(target_id, subject_id) else: - logger.info('unexpected relation to maintain, %s -> %s'%(subject_type,target_type)) + logger.info('unexpected relation to maintain, %s -> %s' % + (subject_type, target_type)) - ######################################## - ########## aggregate oriented + # aggregate oriented ######################################## - def testbed_name (self): return "openstack" + def testbed_name(self): return "openstack" - def aggregate_version (self): + def aggregate_version(self): return {} # first 2 args are None in case of resource discovery - def list_resources (self, version=None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} aggregate = OSAggregate(self) - rspec = aggregate.list_resources(version=version, options=options) + rspec = aggregate.list_resources(version=version, options=options) return rspec def describe(self, urns, version=None, options=None): - if options is None: options={} + if options is None: + options = {} aggregate = OSAggregate(self) return aggregate.describe(urns, version=version, options=options) - - def status (self, urns, options=None): - if options is None: options={} + + def status(self, urns, options=None): + if options is None: + options = {} aggregate = OSAggregate(self) - desc = aggregate.describe(urns) + desc = aggregate.describe(urns) status = {'geni_urn': desc['geni_urn'], 'geni_slivers': desc['geni_slivers']} return status - def allocate (self, urn, rspec_string, expiration, options=None): - if options is None: options={} - xrn = Xrn(urn) + def allocate(self, urn, rspec_string, expiration, options=None): + if options is None: + options = {} + xrn = Xrn(urn) aggregate = OSAggregate(self) # assume first user is the caller and use their context @@ -391,22 +415,24 @@ class NovaDriver(Driver): pubkeys = [] for user in users: pubkeys.extend(user['keys']) - + rspec = RSpec(rspec_string) instance_name = hrn_to_os_slicename(slice_hrn) tenant_name = OSXrn(xrn=slice_hrn, type='slice').get_tenant_name() - slivers = aggregate.run_instances(instance_name, tenant_name, \ + slivers = aggregate.run_instances(instance_name, tenant_name, rspec_string, key_name, pubkeys) - - # update all sliver allocation states setting then to geni_allocated + + # update all sliver allocation states setting then to geni_allocated sliver_ids = [sliver.id for sliver in slivers] - dbsession=self.api.dbsession() - SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession) - + dbsession = self.api.dbsession() + SliverAllocation.set_allocations( + sliver_ids, 'geni_provisioned', dbsession) + return aggregate.describe(urns=[urn], version=rspec.version) def provision(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} # update sliver allocation states and set them to geni_provisioned aggregate = OSAggregate(self) instances = aggregate.get_instances(urns) @@ -414,14 +440,17 @@ class NovaDriver(Driver): for instance in instances: sliver_hrn = "%s.%s" % (self.driver.hrn, instance.id) sliver_ids.append(Xrn(sliver_hrn, type='sliver').urn) - dbsession=self.api.dbsession() - SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession) + dbsession = self.api.dbsession() + SliverAllocation.set_allocations( + sliver_ids, 'geni_provisioned', dbsession) version_manager = VersionManager() - rspec_version = version_manager.get_version(options['geni_rspec_version']) - return self.describe(urns, rspec_version, options=options) + rspec_version = version_manager.get_version( + options['geni_rspec_version']) + return self.describe(urns, rspec_version, options=options) - def delete (self, urns, options=None): - if options is None: options={} + def delete(self, urns, options=None): + if options is None: + options = {} # collect sliver ids so we can update sliver allocation states after # we remove the slivers. aggregate = OSAggregate(self) @@ -430,12 +459,12 @@ class NovaDriver(Driver): for instance in instances: sliver_hrn = "%s.%s" % (self.driver.hrn, instance.id) sliver_ids.append(Xrn(sliver_hrn, type='sliver').urn) - + # delete the instance aggregate.delete_instance(instance) - + # delete sliver allocation states - dbsession=self.api.dbsession() + dbsession = self.api.dbsession() SliverAllocation.delete_allocations(sliver_ids, dbsession) # return geni_slivers @@ -444,18 +473,20 @@ class NovaDriver(Driver): geni_slivers.append( {'geni_sliver_urn': sliver['sliver_id'], 'geni_allocation_status': 'geni_unallocated', - 'geni_expires': None}) + 'geni_expires': None}) return geni_slivers - def renew (self, urns, expiration_time, options=None): - if options is None: options={} + def renew(self, urns, expiration_time, options=None): + if options is None: + options = {} description = self.describe(urns, None, options) return description['geni_slivers'] - def perform_operational_action (self, urns, action, options=None): - if options is None: options={} + def perform_operational_action(self, urns, action, options=None): + if options is None: + options = {} aggregate = OSAggregate(self) - action = action.lower() + action = action.lower() if action == 'geni_start': action_method = aggregate.start_instances elif action == 'geni_stop': @@ -465,16 +496,18 @@ class NovaDriver(Driver): else: raise UnsupportedOperation(action) - # fault if sliver is not full allocated (operational status is geni_pending_allocation) + # fault if sliver is not full allocated (operational status is + # geni_pending_allocation) description = self.describe(urns, None, options) for sliver in description['geni_slivers']: if sliver['geni_operational_status'] == 'geni_pending_allocation': - raise UnsupportedOperation(action, "Sliver must be fully allocated (operational status is not geni_pending_allocation)") + raise UnsupportedOperation( + action, "Sliver must be fully allocated (operational status is not geni_pending_allocation)") # # Perform Operational Action Here # - instances = aggregate.get_instances(urns) + instances = aggregate.get_instances(urns) for instance in instances: tenant_name = self.driver.shell.auth_manager.client.tenant_name action_method(tenant_name, instance.name, instance.id) @@ -483,7 +516,8 @@ class NovaDriver(Driver): return geni_slivers def shutdown(self, xrn, options=None): - if options is None: options={} + if options is None: + options = {} xrn = OSXrn(xrn=xrn, type='slice') tenant_name = xrn.get_tenant_name() name = xrn.get_slicename() diff --git a/sfa/openstack/osaggregate.py b/sfa/openstack/osaggregate.py index 29681a05..0a87704e 100644 --- a/sfa/openstack/osaggregate.py +++ b/sfa/openstack/osaggregate.py @@ -4,7 +4,7 @@ import socket import base64 import string import random -import time +import time from collections import defaultdict from nova.exception import ImageNotFound from nova.api.ec2.cloud import CloudController @@ -20,13 +20,14 @@ from sfa.rspecs.elements.services import Services from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.fw_rule import FWRule from sfa.util.xrn import Xrn -from sfa.planetlab.plxrn import PlXrn +from sfa.planetlab.plxrn import PlXrn from sfa.openstack.osxrn import OSXrn, hrn_to_os_slicename from sfa.rspecs.version_manager import VersionManager from sfa.openstack.security_group import SecurityGroup from sfa.client.multiclient import MultiClient from sfa.util.sfalogging import logger + def pubkeys_to_user_data(pubkeys): user_data = "#!/bin/bash\n\n" for pubkey in pubkeys: @@ -37,14 +38,16 @@ def pubkeys_to_user_data(pubkeys): user_data += "\n" return user_data + def image_to_rspec_disk_image(image): img = DiskImage() img['name'] = image['name'] img['description'] = image['name'] img['os'] = image['name'] - img['version'] = image['name'] + img['version'] = image['name'] return img - + + class OSAggregate: def __init__(self, driver): @@ -59,17 +62,20 @@ class OSAggregate: return zones def list_resources(self, version=None, options=None): - if options is None: options={} + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'ad') + rspec_version = version_manager._get_version( + version.type, version.version, 'ad') rspec = RSpec(version=version, user_options=options) nodes = self.get_aggregate_nodes() rspec.version.add_nodes(nodes) return rspec.toxml() def describe(self, urns, version=None, options=None): - if options is None: options={} + if options is None: + options = {} # update nova connection tenant_name = OSXrn(xrn=urns[0], type='slice').get_tenant_name() self.driver.shell.nova_manager.connect(tenant=tenant_name) @@ -77,27 +83,31 @@ class OSAggregate: # lookup the sliver allocations sliver_ids = [sliver['sliver_id'] for sliver in slivers] constraint = SliverAllocation.sliver_id.in_(sliver_ids) - sliver_allocations = self.driver.api.dbsession().query(SliverAllocation).filter(constraint) + sliver_allocations = self.driver.api.dbsession().query( + SliverAllocation).filter(constraint) sliver_allocation_dict = {} for sliver_allocation in sliver_allocations: - sliver_allocation_dict[sliver_allocation.sliver_id] = sliver_allocation + sliver_allocation_dict[ + sliver_allocation.sliver_id] = sliver_allocation geni_slivers = [] rspec_nodes = [] for instance in instances: rspec_nodes.append(self.instance_to_rspec_node(instance)) - geni_sliver = self.instance_to_geni_sliver(instance, sliver_sllocation_dict) + geni_sliver = self.instance_to_geni_sliver( + instance, sliver_sllocation_dict) geni_slivers.append(geni_sliver) version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'manifest') + rspec_version = version_manager._get_version( + version.type, version.version, 'manifest') rspec = RSpec(version=rspec_version, user_options=options) rspec.xml.set('expires', datetime_to_string(utcparse(time.time()))) rspec.version.add_nodes(rspec_nodes) result = {'geni_urn': Xrn(urns[0]).get_urn(), - 'geni_rspec': rspec.toxml(), + 'geni_rspec': rspec.toxml(), 'geni_slivers': geni_slivers} - + return result def get_instances(self, urns): @@ -117,7 +127,7 @@ class OSAggregate: if names: filter['name'] = names if ids: - filter['id'] = ids + filter['id'] = ids servers = self.driver.shell.nova_manager.servers.findall(**filter) instances.extend(servers) @@ -134,72 +144,78 @@ class OSAggregate: rspec_node = Node() rspec_node['component_id'] = node_xrn.urn rspec_node['component_name'] = node_xrn.name - rspec_node['component_manager_id'] = Xrn(self.driver.hrn, 'authority+cm').get_urn() - rspec_node['sliver_id'] = OSXrn(name=instance.name, type='slice', id=instance.id).get_urn() + rspec_node['component_manager_id'] = Xrn( + self.driver.hrn, 'authority+cm').get_urn() + rspec_node['sliver_id'] = OSXrn( + name=instance.name, type='slice', id=instance.id).get_urn() if instance.metadata.get('client_id'): rspec_node['client_id'] = instance.metadata.get('client_id') # get sliver details - flavor = self.driver.shell.nova_manager.flavors.find(id=instance.flavor['id']) + flavor = self.driver.shell.nova_manager.flavors.find( + id=instance.flavor['id']) sliver = self.instance_to_sliver(flavor) # get firewall rules fw_rules = [] group_name = instance.metadata.get('security_groups') if group_name: - group = self.driver.shell.nova_manager.security_groups.find(name=group_name) + group = self.driver.shell.nova_manager.security_groups.find( + name=group_name) for rule in group.rules: - port_range ="%s:%s" % (rule['from_port'], rule['to_port']) + port_range = "%s:%s" % (rule['from_port'], rule['to_port']) fw_rule = FWRule({'protocol': rule['ip_protocol'], 'port_range': port_range, 'cidr_ip': rule['ip_range']['cidr']}) fw_rules.append(fw_rule) - sliver['fw_rules'] = fw_rules + sliver['fw_rules'] = fw_rules rspec_node['slivers'] = [sliver] # get disk image - image = self.driver.shell.image_manager.get_images(id=instance.image['id']) + image = self.driver.shell.image_manager.get_images( + id=instance.image['id']) if isinstance(image, list) and len(image) > 0: image = image[0] disk_image = image_to_rspec_disk_image(image) sliver['disk_image'] = [disk_image] - # get interfaces + # get interfaces rspec_node['services'] = [] rspec_node['interfaces'] = [] addresses = instance.addresses - # HACK: public ips are stored in the list of private, but - # this seems wrong. Assume pub ip is the last in the list of - # private ips until openstack bug is fixed. + # HACK: public ips are stored in the list of private, but + # this seems wrong. Assume pub ip is the last in the list of + # private ips until openstack bug is fixed. if addresses.get('private'): login = Login({'authentication': 'ssh-keys', 'hostname': addresses.get('private')[-1]['addr'], - 'port':'22', 'username': 'root'}) + 'port': '22', 'username': 'root'}) service = Services({'login': login}) - rspec_node['services'].append(service) - + rspec_node['services'].append(service) + for private_ip in addresses.get('private', []): - if_xrn = PlXrn(auth=self.driver.hrn, - interface='node%s' % (instance.hostId)) - if_client_id = Xrn(if_xrn.urn, type='interface', id="eth%s" %if_index).urn - if_sliver_id = Xrn(rspec_node['sliver_id'], type='slice', id="eth%s" %if_index).urn + if_xrn = PlXrn(auth=self.driver.hrn, + interface='node%s' % (instance.hostId)) + if_client_id = Xrn(if_xrn.urn, type='interface', + id="eth%s" % if_index).urn + if_sliver_id = Xrn( + rspec_node['sliver_id'], type='slice', id="eth%s" % if_index).urn interface = Interface({'component_id': if_xrn.urn, 'client_id': if_client_id, 'sliver_id': if_sliver_id}) - interface['ips'] = [{'address': private_ip['addr'], + interface['ips'] = [{'address': private_ip['addr'], #'netmask': private_ip['network'], 'type': private_ip['version']}] - rspec_node['interfaces'].append(interface) - + rspec_node['interfaces'].append(interface) + # slivers always provide the ssh service for public_ip in addresses.get('public', []): - login = Login({'authentication': 'ssh-keys', - 'hostname': public_ip['addr'], - 'port':'22', 'username': 'root'}) + login = Login({'authentication': 'ssh-keys', + 'hostname': public_ip['addr'], + 'port': '22', 'username': 'root'}) service = Services({'login': login}) rspec_node['services'].append(service) return rspec_node - def instance_to_sliver(self, instance, xrn=None): if xrn: sliver_hrn = '%s.%s' % (self.driver.hrn, instance.id) @@ -211,19 +227,20 @@ class OSAggregate: 'cpus': str(instance.vcpus), 'memory': str(instance.ram), 'storage': str(instance.disk)}) - return sliver + return sliver def instance_to_geni_sliver(self, instance, sliver_allocations=None): - if sliver_allocations is None: sliver_allocations={} + if sliver_allocations is None: + sliver_allocations = {} sliver_hrn = '%s.%s' % (self.driver.hrn, instance.id) sliver_id = Xrn(sliver_hrn, type='sliver').urn - + # set sliver allocation and operational status sliver_allocation = sliver_allocations[sliver_id] if sliver_allocation: allocation_status = sliver_allocation.allocation_state if allocation_status == 'geni_allocated': - op_status = 'geni_pending_allocation' + op_status = 'geni_pending_allocation' elif allocation_status == 'geni_provisioned': state = instance.state.lower() if state == 'active': @@ -231,23 +248,23 @@ class OSAggregate: elif state == 'building': op_status = 'geni_notready' elif state == 'failed': - op_status =' geni_failed' + op_status = ' geni_failed' else: op_status = 'geni_unknown' else: - allocation_status = 'geni_unallocated' + allocation_status = 'geni_unallocated' # required fields - geni_sliver = {'geni_sliver_urn': sliver_id, + geni_sliver = {'geni_sliver_urn': sliver_id, 'geni_expires': None, 'geni_allocation_status': allocation_status, 'geni_operational_status': op_status, 'geni_error': None, 'plos_created_at': datetime_to_string(utcparse(instance.created)), 'plos_sliver_type': self.shell.nova_manager.flavors.find(id=instance.flavor['id']).name, - } + } return geni_sliver - + def get_aggregate_nodes(self): zones = self.get_availability_zones() # available sliver/instance/vm types @@ -256,17 +273,19 @@ class OSAggregate: instances = instances.values() # available images images = self.driver.shell.image_manager.get_images_detailed() - disk_images = [image_to_rspec_disk_image(img) for img in images if img['container_format'] in ['ami', 'ovf']] + disk_images = [image_to_rspec_disk_image(img) for img in images if img[ + 'container_format'] in ['ami', 'ovf']] rspec_nodes = [] for zone in zones: rspec_node = Node() xrn = OSXrn(zone, type='node') rspec_node['component_id'] = xrn.urn rspec_node['component_name'] = xrn.name - rspec_node['component_manager_id'] = Xrn(self.driver.hrn, 'authority+cm').get_urn() + rspec_node['component_manager_id'] = Xrn( + self.driver.hrn, 'authority+cm').get_urn() rspec_node['exclusive'] = 'false' rspec_node['hardware_types'] = [HardwareType({'name': 'plos-pc'}), - HardwareType({'name': 'pc'})] + HardwareType({'name': 'pc'})] slivers = [] for instance in instances: sliver = self.instance_to_sliver(instance) @@ -274,26 +293,30 @@ class OSAggregate: slivers.append(sliver) rspec_node['available'] = 'true' rspec_node['slivers'] = slivers - rspec_nodes.append(rspec_node) + rspec_nodes.append(rspec_node) - return rspec_nodes + return rspec_nodes def create_tenant(self, tenant_name): - tenants = self.driver.shell.auth_manager.tenants.findall(name=tenant_name) + tenants = self.driver.shell.auth_manager.tenants.findall( + name=tenant_name) if not tenants: - self.driver.shell.auth_manager.tenants.create(tenant_name, tenant_name) - tenant = self.driver.shell.auth_manager.tenants.find(name=tenant_name) + self.driver.shell.auth_manager.tenants.create( + tenant_name, tenant_name) + tenant = self.driver.shell.auth_manager.tenants.find( + name=tenant_name) else: tenant = tenants[0] return tenant - + def create_instance_key(self, slice_hrn, user): slice_name = Xrn(slice_hrn).leaf user_name = Xrn(user['urn']).leaf key_name = "%s_%s" % (slice_name, user_name) pubkey = user['keys'][0] key_found = False - existing_keys = self.driver.shell.nova_manager.keypairs.findall(name=key_name) + existing_keys = self.driver.shell.nova_manager.keypairs.findall( + name=key_name) for existing_key in existing_keys: if existing_key.public_key != pubkey: self.driver.shell.nova_manager.keypairs.delete(existing_key) @@ -302,67 +325,69 @@ class OSAggregate: if not key_found: self.driver.shell.nova_manager.keypairs.create(key_name, pubkey) - return key_name - + return key_name def create_security_group(self, slicename, fw_rules=None): - if fw_rules is None: fw_rules=[] + if fw_rules is None: + fw_rules = [] # use default group by default - group_name = 'default' + group_name = 'default' if isinstance(fw_rules, list) and fw_rules: # Each sliver get's its own security group. # Keep security group names unique by appending some random # characters on end. - random_name = "".join([random.choice(string.letters+string.digits) - for i in xrange(6)]) - group_name = slicename + random_name + random_name = "".join([random.choice(string.letters + string.digits) + for i in xrange(6)]) + group_name = slicename + random_name security_group = SecurityGroup(self.driver) security_group.create_security_group(group_name) for rule in fw_rules: - security_group.add_rule_to_group(group_name, - protocol = rule.get('protocol'), - cidr_ip = rule.get('cidr_ip'), - port_range = rule.get('port_range'), - icmp_type_code = rule.get('icmp_type_code')) + security_group.add_rule_to_group(group_name, + protocol=rule.get('protocol'), + cidr_ip=rule.get('cidr_ip'), + port_range=rule.get( + 'port_range'), + icmp_type_code=rule.get('icmp_type_code')) # Open ICMP by default security_group.add_rule_to_group(group_name, - protocol = "icmp", - cidr_ip = "0.0.0.0/0", - icmp_type_code = "-1:-1") + protocol="icmp", + cidr_ip="0.0.0.0/0", + icmp_type_code="-1:-1") return group_name def add_rule_to_security_group(self, group_name, **kwds): security_group = SecurityGroup(self.driver) - security_group.add_rule_to_group(group_name=group_name, - protocol=kwds.get('protocol'), - cidr_ip =kwds.get('cidr_ip'), - icmp_type_code = kwds.get('icmp_type_code')) - - + security_group.add_rule_to_group(group_name=group_name, + protocol=kwds.get('protocol'), + cidr_ip=kwds.get('cidr_ip'), + icmp_type_code=kwds.get('icmp_type_code')) def run_instances(self, instance_name, tenant_name, rspec, key_name, pubkeys): - #logger.debug('Reserving an instance: image: %s, flavor: ' \ + # logger.debug('Reserving an instance: image: %s, flavor: ' \ # '%s, key: %s, name: %s' % \ # (image_id, flavor_id, key_name, slicename)) # make sure a tenant exists for this slice - tenant = self.create_tenant(tenant_name) + tenant = self.create_tenant(tenant_name) # add the sfa admin user to this tenant and update our nova client connection # to use these credentials for the rest of this session. This emsures that the instances # we create will be assigned to the correct tenant. - sfa_admin_user = self.driver.shell.auth_manager.users.find(name=self.driver.shell.auth_manager.opts['OS_USERNAME']) + sfa_admin_user = self.driver.shell.auth_manager.users.find( + name=self.driver.shell.auth_manager.opts['OS_USERNAME']) user_role = self.driver.shell.auth_manager.roles.find(name='user') admin_role = self.driver.shell.auth_manager.roles.find(name='admin') - self.driver.shell.auth_manager.roles.add_user_role(sfa_admin_user, admin_role, tenant) - self.driver.shell.auth_manager.roles.add_user_role(sfa_admin_user, user_role, tenant) - self.driver.shell.nova_manager.connect(tenant=tenant.name) + self.driver.shell.auth_manager.roles.add_user_role( + sfa_admin_user, admin_role, tenant) + self.driver.shell.auth_manager.roles.add_user_role( + sfa_admin_user, user_role, tenant) + self.driver.shell.nova_manager.connect(tenant=tenant.name) authorized_keys = "\n".join(pubkeys) files = {'/root/.ssh/authorized_keys': authorized_keys} rspec = RSpec(rspec) requested_instances = defaultdict(list) - + # iterate over clouds/zones/nodes slivers = [] for node in rspec.version.get_nodes_with_slivers(): @@ -370,54 +395,60 @@ class OSAggregate: if not instances: continue for instance in instances: - try: + try: metadata = {} - flavor_id = self.driver.shell.nova_manager.flavors.find(name=instance['name']) + flavor_id = self.driver.shell.nova_manager.flavors.find(name=instance[ + 'name']) image = instance.get('disk_image') if image and isinstance(image, list): image = image[0] else: - raise InvalidRSpec("Must specify a disk_image for each VM") - image_id = self.driver.shell.nova_manager.images.find(name=image['name']) + raise InvalidRSpec( + "Must specify a disk_image for each VM") + image_id = self.driver.shell.nova_manager.images.find(name=image[ + 'name']) fw_rules = instance.get('fw_rules', []) - group_name = self.create_security_group(instance_name, fw_rules) + group_name = self.create_security_group( + instance_name, fw_rules) metadata['security_groups'] = group_name if node.get('component_id'): metadata['component_id'] = node['component_id'] if node.get('client_id'): - metadata['client_id'] = node['client_id'] + metadata['client_id'] = node['client_id'] server = self.driver.shell.nova_manager.servers.create( - flavor=flavor_id, - image=image_id, - key_name = key_name, - security_groups = [group_name], - files=files, - meta=metadata, - name=instance_name) + flavor=flavor_id, + image=image_id, + key_name=key_name, + security_groups=[group_name], + files=files, + meta=metadata, + name=instance_name) slivers.append(server) - except Exception as err: - logger.log_exc(err) - - return slivers + except Exception as err: + logger.log_exc(err) + + return slivers def delete_instance(self, instance): - + def _delete_security_group(inst): security_group = inst.metadata.get('security_groups', '') if security_group: manager = SecurityGroup(self.driver) - timeout = 10.0 # wait a maximum of 10 seconds before forcing the security group delete + timeout = 10.0 # wait a maximum of 10 seconds before forcing the security group delete start_time = time.time() instance_deleted = False while instance_deleted == False and (time.time() - start_time) < timeout: - tmp_inst = self.driver.shell.nova_manager.servers.findall(id=inst.id) + tmp_inst = self.driver.shell.nova_manager.servers.findall( + id=inst.id) if not tmp_inst: instance_deleted = True time.sleep(.5) manager.delete_security_group(security_group) multiclient = MultiClient() - tenant = self.driver.shell.auth_manager.tenants.find(id=instance.tenant_id) + tenant = self.driver.shell.auth_manager.tenants.find( + id=instance.tenant_id) self.driver.shell.nova_manager.connect(tenant=tenant.name) args = {'name': instance.name, 'id': instance.id} @@ -453,7 +484,7 @@ class OSAggregate: def restart_instances(self, instacne_name, tenant_name, id=None): self.stop_instances(instance_name, tenant_name, id) self.start_instances(instance_name, tenant_name, id) - return 1 + return 1 def update_instances(self, project_name): pass diff --git a/sfa/openstack/osxrn.py b/sfa/openstack/osxrn.py index 6a3944c9..8584b397 100644 --- a/sfa/openstack/osxrn.py +++ b/sfa/openstack/osxrn.py @@ -2,6 +2,7 @@ import re from sfa.util.xrn import Xrn from sfa.util.config import Config + def hrn_to_os_slicename(hrn): return OSXrn(xrn=hrn, type='slice').get_slicename() @@ -9,13 +10,15 @@ def hrn_to_os_slicename(hrn): def hrn_to_os_tenant_name(hrn): return OSXrn(xrn=hrn, type='slice').get_tenant_name() + def cleanup_name(name): - return name.replace(".", "_").replace("+", "_") + return name.replace(".", "_").replace("+", "_") + class OSXrn(Xrn): def __init__(self, name=None, auth=None, **kwds): - + config = Config() self.id = id if name is not None: @@ -23,15 +26,15 @@ class OSXrn(Xrn): if 'type' in kwds: self.type = kwds['type'] if auth is not None: - self.hrn='.'.join([auth, cleanup_name(name)]) + self.hrn = '.'.join([auth, cleanup_name(name)]) else: self.hrn = name.replace('_', '.') self.hrn_to_urn() else: - Xrn.__init__(self, **kwds) - - self.name = self.get_name() - + Xrn.__init__(self, **kwds) + + self.name = self.get_name() + def get_name(self): self._normalize() leaf = self.leaf @@ -40,7 +43,6 @@ class OSXrn(Xrn): name = re.sub('[^a-zA-Z0-9_]', '', name) return name - def get_slicename(self): self._normalize() slicename = self.hrn @@ -52,4 +54,3 @@ class OSXrn(Xrn): self._normalize() tenant_name = self.hrn.replace('\.', '') return tenant_name - diff --git a/sfa/openstack/security_group.py b/sfa/openstack/security_group.py index 6aced8c6..ca0e2661 100644 --- a/sfa/openstack/security_group.py +++ b/sfa/openstack/security_group.py @@ -1,11 +1,11 @@ from sfa.util.sfalogging import logger + class SecurityGroup: def __init__(self, driver): self.client = driver.shell.nova_manager - def create_security_group(self, name): try: self.client.security_groups.create(name=name, description=name) @@ -20,7 +20,6 @@ class SecurityGroup: except Exception as ex: logger.log_exc("Failed to delete security group") - def _validate_port_range(self, port_range): from_port = to_port = None if isinstance(port_range, str): @@ -44,7 +43,6 @@ class SecurityGroup: logger.error('port must be an integer.') return (from_port, to_port) - def add_rule_to_group(self, group_name=None, protocol='tcp', cidr_ip='0.0.0.0/0', port_range=None, icmp_type_code=None, source_group_name=None, source_group_owner_id=None): @@ -56,15 +54,14 @@ class SecurityGroup: from_port, to_port = icmp_type[0], icmp_type[1] group = self.client.security_groups.find(name=group_name) - self.client.security_group_rules.create(group.id, \ - protocol, from_port, to_port,cidr_ip) + self.client.security_group_rules.create(group.id, + protocol, from_port, to_port, cidr_ip) except Exception as ex: logger.log_exc("Failed to add rule to group %s" % group_name) - def remove_rule_from_group(self, group_name=None, protocol='tcp', cidr_ip='0.0.0.0/0', - port_range=None, icmp_type_code=None, - source_group_name=None, source_group_owner_id=None): + port_range=None, icmp_type_code=None, + source_group_name=None, source_group_owner_id=None): try: from_port, to_port = self._validate_port_range(port_range) icmp_type = self._validate_icmp_type_code(icmp_type_code) @@ -72,15 +69,14 @@ class SecurityGroup: from_port, to_port = icmp_type[0], icmp_type[1] group = self.client.security_groups.find(name=group_name) filter = { - 'id': group.id, + 'id': group.id, 'from_port': from_port, 'to_port': to_port, 'cidr_ip': ip, - 'ip_protocol':protocol, + 'ip_protocol': protocol, } rule = self.client.security_group_rules.find(**filter) if rule: self.client.security_group_rules.delete(rule) except Exception as ex: - logger.log_exc("Failed to remove rule from group %s" % group_name) - + logger.log_exc("Failed to remove rule from group %s" % group_name) diff --git a/sfa/openstack/shell.py b/sfa/openstack/shell.py index e31be9dc..fb6320a7 100644 --- a/sfa/openstack/shell.py +++ b/sfa/openstack/shell.py @@ -12,24 +12,23 @@ except: has_nova = False - class Shell: """ A simple native shell to a nova backend. This class can receive all nova calls to the underlying testbed """ - - # dont care about limiting calls yet + + # dont care about limiting calls yet direct_calls = [] alias_calls = {} - - # use the 'capability' auth mechanism for higher performance when the PLC db is local - def __init__ ( self, config=None) : + # use the 'capability' auth mechanism for higher performance when the PLC + # db is local + def __init__(self, config=None): if not config: config = Config() if has_nova: - # instantiate managers + # instantiate managers self.auth_manager = KeystoneClient(config=config) self.image_manager = GlanceClient(config=config) self.nova_manager = NovaClient(config=config) diff --git a/sfa/planetlab/nodemanager.py b/sfa/planetlab/nodemanager.py index 12e7f22e..eb00fdcf 100644 --- a/sfa/planetlab/nodemanager.py +++ b/sfa/planetlab/nodemanager.py @@ -2,6 +2,7 @@ import tempfile import commands import os + class NodeManager: method = None @@ -12,16 +13,16 @@ class NodeManager: def __getattr__(self, method): self.method = method return self.__call__ - + def __call__(self, *args): method = self.method - sfa_slice_prefix = self.config.SFA_CM_SLICE_PREFIX + sfa_slice_prefix = self.config.SFA_CM_SLICE_PREFIX sfa_slice = sfa_slice_prefix + "_sfacm" python = "/usr/bin/python" vserver_path = "/vservers/%s" % (sfa_slice) script_path = "/tmp/" path = "%(vserver_path)s/%(script_path)s" % locals() - (fd, filename) = tempfile.mkstemp(dir=path) + (fd, filename) = tempfile.mkstemp(dir=path) scriptname = script_path + os.sep + filename.split(os.sep)[-1:][0] # define the script to execute # when providing support for python3 wrt xmlrpclib @@ -32,7 +33,7 @@ import xmlrpclib s = xmlrpclib.ServerProxy('http://127.0.0.1:812') print s.%(method)s%(args)s""" % locals() - try: + try: # write the script to a temporary file f = open(filename, 'w') f.write(script % locals()) @@ -41,8 +42,9 @@ print s.%(method)s%(args)s""" % locals() chmod_cmd = "/bin/chmod 775 %(filename)s" % locals() (status, output) = commands.getstatusoutput(chmod_cmd) - # execute the commad as a slice with root NM privs + # execute the commad as a slice with root NM privs cmd = 'su - %(sfa_slice)s -c "%(python)s %(scriptname)s"' % locals() (status, output) = commands.getstatusoutput(cmd) - return (status, output) - finally: os.unlink(filename) + return (status, output) + finally: + os.unlink(filename) diff --git a/sfa/planetlab/peers.py b/sfa/planetlab/peers.py index 7c6e1b7c..8c676f3b 100644 --- a/sfa/planetlab/peers.py +++ b/sfa/planetlab/peers.py @@ -2,6 +2,7 @@ from sfa.util.xrn import get_authority from sfa.util.py23 import StringType + def get_peer(pldriver, hrn): # Because of myplc native federation, we first need to determine if this # slice belongs to out local plc or a myplc peer. We will assume it @@ -14,16 +15,18 @@ def get_peer(pldriver, hrn): # get this site's authority (sfa root authority or sub authority) site_authority = get_authority(slice_authority).lower() # check if we are already peered with this site_authority, if so - peers = pldriver.shell.GetPeers( {}, ['peer_id', 'peername', 'shortname', 'hrn_root']) + peers = pldriver.shell.GetPeers( + {}, ['peer_id', 'peername', 'shortname', 'hrn_root']) for peer_record in peers: - names = [name.lower() for name in peer_record.values() if isinstance(name, StringType)] + names = [name.lower() for name in peer_record.values() + if isinstance(name, StringType)] if site_authority in names: peer = peer_record['shortname'] return peer -#def get_sfa_peer(pldriver, hrn): +# def get_sfa_peer(pldriver, hrn): # # return the authority for this hrn or None if we are the authority # sfa_peer = None # slice_authority = get_authority(hrn) @@ -33,4 +36,3 @@ def get_peer(pldriver, hrn): # sfa_peer = site_authority # # return sfa_peer - diff --git a/sfa/planetlab/plaggregate.py b/sfa/planetlab/plaggregate.py index 32d6d933..8095291d 100644 --- a/sfa/planetlab/plaggregate.py +++ b/sfa/planetlab/plaggregate.py @@ -26,41 +26,45 @@ from sfa.storage.model import SliverAllocation import time + class PlAggregate: def __init__(self, driver): self.driver = driver def get_nodes(self, options=None): - if options is None: options={} + if options is None: + options = {} filter = {'peer_id': None} - geni_available = options.get('geni_available') + geni_available = options.get('geni_available') if geni_available == True: filter['boot_state'] = 'boot' nodes = self.driver.shell.GetNodes(filter) - - return nodes - + + return nodes + def get_sites(self, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} sites = {} for site in self.driver.shell.GetSites(filter): sites[site['site_id']] = site return sites def get_interfaces(self, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} interfaces = {} for interface in self.driver.shell.GetInterfaces(filter): iface = Interface() if interface['bwlimit']: - interface['bwlimit'] = str(int(interface['bwlimit'])/1000) + interface['bwlimit'] = str(int(interface['bwlimit']) / 1000) interfaces[interface['interface_id']] = interface return interfaces def get_links(self, sites, nodes, interfaces): - - topology = Topology() + + topology = Topology() links = [] for (site_id1, site_id2) in topology: site_id1 = int(site_id1) @@ -82,34 +86,44 @@ class PlAggregate: node2 = nodes[s2_node_id] # set interfaces # just get first interface of the first node - if1_xrn = PlXrn(auth=self.driver.hrn, interface='node%s:eth0' % (node1['node_id'])) + if1_xrn = PlXrn(auth=self.driver.hrn, + interface='node%s:eth0' % (node1['node_id'])) if1_ipv4 = interfaces[node1['interface_ids'][0]]['ip'] - if2_xrn = PlXrn(auth=self.driver.hrn, interface='node%s:eth0' % (node2['node_id'])) + if2_xrn = PlXrn(auth=self.driver.hrn, + interface='node%s:eth0' % (node2['node_id'])) if2_ipv4 = interfaces[node2['interface_ids'][0]]['ip'] - if1 = Interface({'component_id': if1_xrn.urn, 'ipv4': if1_ipv4} ) - if2 = Interface({'component_id': if2_xrn.urn, 'ipv4': if2_ipv4} ) + if1 = Interface( + {'component_id': if1_xrn.urn, 'ipv4': if1_ipv4}) + if2 = Interface( + {'component_id': if2_xrn.urn, 'ipv4': if2_ipv4}) # set link - link = Link({'capacity': '1000000', 'latency': '0', 'packet_loss': '0', 'type': 'ipv4'}) + link = Link({'capacity': '1000000', 'latency': '0', + 'packet_loss': '0', 'type': 'ipv4'}) link['interface1'] = if1 link['interface2'] = if2 - link['component_name'] = "%s:%s" % (site1['login_base'], site2['login_base']) - link['component_id'] = PlXrn(auth=self.driver.hrn, interface=link['component_name']).get_urn() - link['component_manager_id'] = hrn_to_urn(self.driver.hrn, 'authority+am') + link['component_name'] = "%s:%s" % ( + site1['login_base'], site2['login_base']) + link['component_id'] = PlXrn(auth=self.driver.hrn, interface=link[ + 'component_name']).get_urn() + link['component_manager_id'] = hrn_to_urn( + self.driver.hrn, 'authority+am') links.append(link) return links def get_node_tags(self, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} node_tags = {} - for node_tag in self.driver.shell.GetNodeTags(filter, ['tagname', 'value', 'node_id', 'node_tag_id'] ): + for node_tag in self.driver.shell.GetNodeTags(filter, ['tagname', 'value', 'node_id', 'node_tag_id']): node_tags[node_tag['node_tag_id']] = node_tag return node_tags def get_pl_initscripts(self, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} pl_initscripts = {} filter.update({'enabled': True}) for initscript in self.driver.shell.GetInitScripts(filter): @@ -117,7 +131,8 @@ class PlAggregate: return pl_initscripts def get_slivers(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} names = set() slice_ids = set() node_ids = [] @@ -128,13 +143,13 @@ class PlAggregate: # id: slice_id-node_id try: sliver_id_parts = xrn.get_sliver_id_parts() - slice_id = int(sliver_id_parts[0]) + slice_id = int(sliver_id_parts[0]) node_id = int(sliver_id_parts[1]) - slice_ids.add(slice_id) + slice_ids.add(slice_id) node_ids.append(node_id) except ValueError: - pass - else: + pass + else: slice_hrn = xrn.get_hrn() filter = {} @@ -142,21 +157,25 @@ class PlAggregate: if slice_ids: filter['slice_id'] = list(slice_ids) # get all slices - fields = ['slice_id', 'name', 'hrn', 'person_ids', 'node_ids', 'slice_tag_ids', 'expires'] + fields = ['slice_id', 'name', 'hrn', 'person_ids', + 'node_ids', 'slice_tag_ids', 'expires'] all_slices = self.driver.shell.GetSlices(filter, fields) if slice_hrn: - slices = [slice for slice in all_slices if slice['hrn'] == slice_hrn] + slices = [slice for slice in all_slices if slice[ + 'hrn'] == slice_hrn] else: slices = all_slices - + if not slices: if slice_hrn: - logger.error("PlAggregate.get_slivers : no slice found with hrn {}".format(slice_hrn)) + logger.error( + "PlAggregate.get_slivers : no slice found with hrn {}".format(slice_hrn)) else: - logger.error("PlAggregate.get_slivers : no sliver found with urns {}".format(urns)) + logger.error( + "PlAggregate.get_slivers : no sliver found with urns {}".format(urns)) return [] - slice = slices[0] - slice['hrn'] = slice_hrn + slice = slices[0] + slice['hrn'] = slice_hrn # get sliver users persons = [] @@ -165,37 +184,39 @@ class PlAggregate: person_ids.extend(slice['person_ids']) if person_ids: persons = self.driver.shell.GetPersons(person_ids) - + # get user keys keys = {} key_ids = [] for person in persons: key_ids.extend(person['key_ids']) - + if key_ids: key_list = self.driver.shell.GetKeys(key_ids) for key in key_list: - keys[key['key_id']] = key + keys[key['key_id']] = key # construct user key info users = [] for person in persons: - person_urn = hrn_to_urn(self.driver.shell.GetPersonHrn(int(person['person_id'])), 'user') + person_urn = hrn_to_urn(self.driver.shell.GetPersonHrn( + int(person['person_id'])), 'user') user = { - 'login': slice['name'], + 'login': slice['name'], 'user_urn': person_urn, 'keys': [keys[k_id]['key'] for k_id in person['key_ids'] if k_id in keys] } users.append(user) if node_ids: - node_ids = [node_id for node_id in node_ids if node_id in slice['node_ids']] + node_ids = [ + node_id for node_id in node_ids if node_id in slice['node_ids']] slice['node_ids'] = node_ids pltags_dict = self.get_pltags_by_node_id(slice) nodes_dict = self.get_slice_nodes(slice, options) slivers = [] for node in nodes_dict.values(): - node.update(slice) + node.update(slice) # slice-global tags node['slice-tags'] = pltags_dict['slice-global'] # xxx @@ -204,33 +225,40 @@ class PlAggregate: # xxx # sliver tags node['slice-tags'] += pltags_dict[node['node_id']] - sliver_hrn = '%s.%s-%s' % (self.driver.hrn, slice['slice_id'], node['node_id']) + sliver_hrn = '%s.%s-%s' % (self.driver.hrn, + slice['slice_id'], node['node_id']) node['sliver_id'] = Xrn(sliver_hrn, type='sliver').urn - node['urn'] = node['sliver_id'] + node['urn'] = node['sliver_id'] node['services_user'] = users slivers.append(node) if not slivers: - logger.warning("PlAggregate.get_slivers : slice(s) found but with no sliver {}".format(urns)) + logger.warning( + "PlAggregate.get_slivers : slice(s) found but with no sliver {}".format(urns)) return slivers def node_to_rspec_node(self, node, sites, interfaces, node_tags, pl_initscripts=None, grain=None, options=None): - if pl_initscripts is None: pl_initscripts=[] - if options is None: options={} + if pl_initscripts is None: + pl_initscripts = [] + if options is None: + options = {} rspec_node = NodeElement() # xxx how to retrieve site['login_base'] - site=sites[node['site_id']] - rspec_node['component_id'] = hostname_to_urn(self.driver.hrn, site['login_base'], node['hostname']) + site = sites[node['site_id']] + rspec_node['component_id'] = hostname_to_urn( + self.driver.hrn, site['login_base'], node['hostname']) rspec_node['component_name'] = node['hostname'] - rspec_node['component_manager_id'] = Xrn(self.driver.hrn, 'authority+cm').get_urn() - rspec_node['authority_id'] = hrn_to_urn(PlXrn.site_hrn(self.driver.hrn, site['login_base']), 'authority+sa') + rspec_node['component_manager_id'] = Xrn( + self.driver.hrn, 'authority+cm').get_urn() + rspec_node['authority_id'] = hrn_to_urn(PlXrn.site_hrn( + self.driver.hrn, site['login_base']), 'authority+sa') # do not include boot state ( element) in the manifest rspec rspec_node['boot_state'] = node['boot_state'] - if node['boot_state'] == 'boot': + if node['boot_state'] == 'boot': rspec_node['available'] = 'true' else: rspec_node['available'] = 'false' - #distinguish between Shared and Reservable nodes + # distinguish between Shared and Reservable nodes if node['node_type'] == 'reservable': rspec_node['exclusive'] = 'true' else: @@ -244,13 +272,14 @@ class PlAggregate: # add site/interface info to nodes. # assumes that sites, interfaces and tags have already been prepared. if site['longitude'] and site['latitude']: - location = Location({'longitude': site['longitude'], 'latitude': site['latitude'], 'country': 'unknown'}) + location = Location({'longitude': site['longitude'], 'latitude': site[ + 'latitude'], 'country': 'unknown'}) rspec_node['location'] = location # Granularity granularity = Granularity({'grain': grain}) rspec_node['granularity'] = granularity rspec_node['interfaces'] = [] - if_count=0 + if_count = 0 for if_id in node['interface_ids']: interface = Interface(interfaces[if_id]) interface['ipv4'] = interface['ip'] @@ -260,17 +289,19 @@ class PlAggregate: if slice: interface['client_id'] = "%s:%s" % (node['node_id'], if_id) rspec_node['interfaces'].append(interface) - if_count+=1 + if_count += 1 # this is what describes a particular node - node_level_tags = [PLTag(node_tags[tag_id]) for tag_id in node['node_tag_ids'] if tag_id in node_tags] + node_level_tags = [PLTag(node_tags[tag_id]) for tag_id in node[ + 'node_tag_ids'] if tag_id in node_tags] rspec_node['tags'] = node_level_tags return rspec_node - def sliver_to_rspec_node(self, sliver, sites, interfaces, node_tags, sliver_pltags, \ + def sliver_to_rspec_node(self, sliver, sites, interfaces, node_tags, sliver_pltags, pl_initscripts, sliver_allocations): # get the granularity in second for the reservation system grain = self.driver.shell.GetLeaseGranularity() - rspec_node = self.node_to_rspec_node(sliver, sites, interfaces, node_tags, pl_initscripts, grain) + rspec_node = self.node_to_rspec_node( + sliver, sites, interfaces, node_tags, pl_initscripts, grain) for pltag in sliver_pltags: logger.debug("Need to expose {}".format(pltag)) # xxx how to retrieve site['login_base'] @@ -282,21 +313,23 @@ class PlAggregate: 'name': sliver['name'], 'type': 'plab-vserver', 'tags': sliver_pltags, - }) + }) rspec_node['sliver_id'] = rspec_sliver['sliver_id'] if sliver['urn'] in sliver_allocations: - rspec_node['client_id'] = sliver_allocations[sliver['urn']].client_id + rspec_node['client_id'] = sliver_allocations[ + sliver['urn']].client_id if sliver_allocations[sliver['urn']].component_id: - rspec_node['component_id'] = sliver_allocations[sliver['urn']].component_id + rspec_node['component_id'] = sliver_allocations[ + sliver['urn']].component_id rspec_node['slivers'] = [rspec_sliver] # slivers always provide the ssh service login = Login({'authentication': 'ssh-keys', 'hostname': sliver['hostname'], - 'port':'22', + 'port': '22', 'username': sliver['name'], 'login': sliver['name'] - }) + }) service = ServicesElement({'login': login, 'services_user': sliver['services_user']}) rspec_node['services'] = [service] @@ -316,7 +349,8 @@ class PlAggregate: pltags_dict[tag['node_id']].append(PLTag(tag)) # restricted to a nodegroup # for now such tags are not exposed to describe - # xxx we should also expose the nodegroup name in this case to be complete.. + # xxx we should also expose the nodegroup name in this case to be + # complete.. elif tag['nodegroup_id']: tag['scope'] = 'nodegroup' pltags_dict['nodegroup'].append(PLTag(tag)) @@ -327,7 +361,8 @@ class PlAggregate: return pltags_dict def get_slice_nodes(self, slice, options=None): - if options is None: options={} + if options is None: + options = {} nodes_dict = {} filter = {'peer_id': None} tags_filter = {} @@ -336,7 +371,7 @@ class PlAggregate: else: # there are no nodes to look up return nodes_dict - tags_filter=filter.copy() + tags_filter = filter.copy() geni_available = options.get('geni_available') if geni_available == True: filter['boot_state'] = 'boot' @@ -346,14 +381,15 @@ class PlAggregate: return nodes_dict def rspec_node_to_geni_sliver(self, rspec_node, sliver_allocations=None): - if sliver_allocations is None: sliver_allocations={} + if sliver_allocations is None: + sliver_allocations = {} if rspec_node['sliver_id'] in sliver_allocations: # set sliver allocation and operational status sliver_allocation = sliver_allocations[rspec_node['sliver_id']] if sliver_allocation: allocation_status = sliver_allocation.allocation_state if allocation_status == 'geni_allocated': - op_status = 'geni_pending_allocation' + op_status = 'geni_pending_allocation' elif allocation_status == 'geni_provisioned': if rspec_node['boot_state'] == 'boot': op_status = 'geni_ready' @@ -362,28 +398,30 @@ class PlAggregate: else: op_status = 'geni_unknown' else: - allocation_status = 'geni_unallocated' + allocation_status = 'geni_unallocated' else: allocation_status = 'geni_unallocated' op_status = 'geni_failed' # required fields geni_sliver = {'geni_sliver_urn': rspec_node['sliver_id'], 'geni_expires': rspec_node['expires'], - 'geni_allocation_status' : allocation_status, + 'geni_allocation_status': allocation_status, 'geni_operational_status': op_status, 'geni_error': '', } - return geni_sliver + return geni_sliver def get_leases(self, slice=None, options=None): - if options is None: options={} - + if options is None: + options = {} + now = int(time.time()) - filter={} - filter.update({'clip':now}) + filter = {} + filter.update({'clip': now}) if slice: - filter.update({'name':slice['name']}) - return_fields = ['lease_id', 'hostname', 'site_id', 'name', 't_from', 't_until'] + filter.update({'name': slice['name']}) + return_fields = ['lease_id', 'hostname', + 'site_id', 'name', 't_from', 't_until'] leases = self.driver.shell.GetLeases(filter) grain = self.driver.shell.GetLeaseGranularity() @@ -392,38 +430,41 @@ class PlAggregate: site_ids.append(lease['site_id']) # get sites - sites_dict = self.get_sites({'site_id': site_ids}) - + sites_dict = self.get_sites({'site_id': site_ids}) + rspec_leases = [] for lease in leases: rspec_lease = Lease() - + # xxx how to retrieve site['login_base'] - site_id=lease['site_id'] - site=sites_dict[site_id] + site_id = lease['site_id'] + site = sites_dict[site_id] - rspec_lease['component_id'] = hrn_to_urn(self.driver.shell.GetNodeHrn(lease['hostname']), 'node') + rspec_lease['component_id'] = hrn_to_urn( + self.driver.shell.GetNodeHrn(lease['hostname']), 'node') slice_hrn = self.driver.shell.GetSliceHrn(lease['slice_id']) slice_urn = hrn_to_urn(slice_hrn, 'slice') rspec_lease['slice_id'] = slice_urn rspec_lease['start_time'] = lease['t_from'] - rspec_lease['duration'] = (lease['t_until'] - lease['t_from']) / grain + rspec_lease['duration'] = ( + lease['t_until'] - lease['t_from']) / grain rspec_leases.append(rspec_lease) return rspec_leases - - def list_resources(self, version = None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'ad') + rspec_version = version_manager._get_version( + version.type, version.version, 'ad') rspec = RSpec(version=rspec_version, user_options=options) - + if not options.get('list_leases') or options['list_leases'] != 'leases': # get nodes - nodes = self.get_nodes(options) + nodes = self.get_nodes(options) site_ids = [] interface_ids = [] tag_ids = [] @@ -434,31 +475,34 @@ class PlAggregate: tag_ids.extend(node['node_tag_ids']) nodes_dict[node['node_id']] = node sites = self.get_sites({'site_id': site_ids}) - interfaces = self.get_interfaces({'interface_id':interface_ids}) + interfaces = self.get_interfaces({'interface_id': interface_ids}) node_tags = self.get_node_tags({'node_tag_id': tag_ids}) pl_initscripts = self.get_pl_initscripts() # convert nodes to rspec nodes rspec_nodes = [] for node in nodes: - rspec_node = self.node_to_rspec_node(node, sites, interfaces, node_tags, pl_initscripts) + rspec_node = self.node_to_rspec_node( + node, sites, interfaces, node_tags, pl_initscripts) rspec_nodes.append(rspec_node) rspec.version.add_nodes(rspec_nodes) # add links - links = self.get_links(sites, nodes_dict, interfaces) + links = self.get_links(sites, nodes_dict, interfaces) rspec.version.add_links(links) if not options.get('list_leases') or options.get('list_leases') and options['list_leases'] != 'resources': - leases = self.get_leases() - rspec.version.add_leases(leases) + leases = self.get_leases() + rspec.version.add_leases(leases) return rspec.toxml() def describe(self, urns, version=None, options=None): - if options is None: options={} + if options is None: + options = {} version_manager = VersionManager() version = version_manager.get_version(version) - rspec_version = version_manager._get_version(version.type, version.version, 'manifest') + rspec_version = version_manager._get_version( + version.type, version.version, 'manifest') rspec = RSpec(version=rspec_version, user_options=options) # get slivers @@ -467,19 +511,21 @@ class PlAggregate: if slivers: rspec_expires = datetime_to_string(utcparse(slivers[0]['expires'])) else: - rspec_expires = datetime_to_string(utcparse(time.time())) + rspec_expires = datetime_to_string(utcparse(time.time())) rspec.xml.set('expires', rspec_expires) # lookup the sliver allocations geni_urn = urns[0] sliver_ids = [sliver['sliver_id'] for sliver in slivers] constraint = SliverAllocation.sliver_id.in_(sliver_ids) - sliver_allocations = self.driver.api.dbsession().query(SliverAllocation).filter(constraint) + sliver_allocations = self.driver.api.dbsession().query( + SliverAllocation).filter(constraint) sliver_allocation_dict = {} for sliver_allocation in sliver_allocations: geni_urn = sliver_allocation.slice_urn - sliver_allocation_dict[sliver_allocation.sliver_id] = sliver_allocation - + sliver_allocation_dict[ + sliver_allocation.sliver_id] = sliver_allocation + if not options.get('list_leases') or options['list_leases'] != 'leases': # add slivers site_ids = [] @@ -492,7 +538,7 @@ class PlAggregate: tag_ids.extend(sliver['node_tag_ids']) nodes_dict[sliver['node_id']] = sliver sites = self.get_sites({'site_id': site_ids}) - interfaces = self.get_interfaces({'interface_id':interface_ids}) + interfaces = self.get_interfaces({'interface_id': interface_ids}) node_tags = self.get_node_tags({'node_tag_id': tag_ids}) pl_initscripts = self.get_pl_initscripts() rspec_nodes = [] @@ -500,25 +546,27 @@ class PlAggregate: if sliver['slice_ids_whitelist'] and sliver['slice_id'] not in sliver['slice_ids_whitelist']: continue sliver_pltags = sliver['slice-tags'] - rspec_node = self.sliver_to_rspec_node(sliver, sites, interfaces, node_tags, sliver_pltags, + rspec_node = self.sliver_to_rspec_node(sliver, sites, interfaces, node_tags, sliver_pltags, pl_initscripts, sliver_allocation_dict) - logger.debug('rspec of type {}'.format(rspec_node.__class__.__name__)) + logger.debug('rspec of type {}'.format( + rspec_node.__class__.__name__)) # manifest node element shouldn't contain available attribute rspec_node.pop('available') - rspec_nodes.append(rspec_node) - geni_sliver = self.rspec_node_to_geni_sliver(rspec_node, sliver_allocation_dict) + rspec_nodes.append(rspec_node) + geni_sliver = self.rspec_node_to_geni_sliver( + rspec_node, sliver_allocation_dict) geni_slivers.append(geni_sliver) rspec.version.add_nodes(rspec_nodes) # add sliver defaults #default_sliver = slivers.get(None, []) - #if default_sliver: + # if default_sliver: # default_sliver_attribs = default_sliver.get('tags', []) # for attrib in default_sliver_attribs: # rspec.version.add_default_sliver_attribute(attrib['tagname'], attrib['value']) - # add links - links = self.get_links(sites, nodes_dict, interfaces) + # add links + links = self.get_links(sites, nodes_dict, interfaces) rspec.version.add_links(links) if not options.get('list_leases') or options['list_leases'] != 'resources': @@ -526,7 +574,6 @@ class PlAggregate: leases = self.get_leases(slivers[0]) rspec.version.add_leases(leases) - - return {'geni_urn': geni_urn, + return {'geni_urn': geni_urn, 'geni_rspec': rspec.toxml(), 'geni_slivers': geni_slivers} diff --git a/sfa/planetlab/plcomponentdriver.py b/sfa/planetlab/plcomponentdriver.py index 25f9b7af..39788bb5 100644 --- a/sfa/planetlab/plcomponentdriver.py +++ b/sfa/planetlab/plcomponentdriver.py @@ -9,6 +9,8 @@ from sfa.trust.certificate import Certificate, Keypair from sfa.trust.gid import GID #################### + + class PlComponentDriver: """ This class is the type for the toplevel 'api' object @@ -18,12 +20,12 @@ class PlComponentDriver: some tweaks as compared with a service running in the infrastructure. """ - def __init__ (self, config): + def __init__(self, config): self.nodemanager = NodeManager(config) def sliver_exists(self): sliver_dict = self.nodemanager.GetXIDs() - ### xxx slicename is undefined + # xxx slicename is undefined if slicename in sliver_dict.keys(): return True else: @@ -32,14 +34,14 @@ class PlComponentDriver: def get_registry(self): addr, port = self.config.SFA_REGISTRY_HOST, self.config.SFA_REGISTRY_PORT url = "http://%(addr)s:%(port)s" % locals() - ### xxx this would require access to the api... + # xxx this would require access to the api... server = SfaServerProxy(url, self.key_file, self.cert_file) return server def get_node_key(self): # this call requires no authentication, # so we can generate a random keypair here - subject="component" + subject = "component" (kfd, keyfile) = tempfile.mkstemp() (cfd, certfile) = tempfile.mkstemp() key = Keypair(create=True) @@ -51,7 +53,7 @@ class PlComponentDriver: cert.save_to_file(certfile) registry = self.get_registry() # the registry will scp the key onto the node - registry.get_key_from_incoming_ip() + registry.get_key_from_incoming_ip() # override the method in SfaApi def getCredential(self): @@ -62,7 +64,7 @@ class PlComponentDriver: config_dir = self.config.config_path cred_filename = path + os.sep + 'node.cred' try: - credential = Credential(filename = cred_filename) + credential = Credential(filename=cred_filename) return credential.save_to_string(save_parents=True) except IOError: node_pkey_file = config_dir + os.sep + "node.key" @@ -76,11 +78,12 @@ class PlComponentDriver: gid = GID(filename=node_gid_file) hrn = gid.get_hrn() # get credential from registry - cert_str = Certificate(filename=cert_filename).save_to_string(save_parents=True) + cert_str = Certificate( + filename=cert_filename).save_to_string(save_parents=True) registry = self.get_registry() cred = registry.GetSelfCredential(cert_str, hrn, 'node') # xxx credfile is undefined - Credential(string=cred).save_to_file(credfile, save_parents=True) + Credential(string=cred).save_to_file(credfile, save_parents=True) return cred @@ -90,7 +93,8 @@ class PlComponentDriver: """ files = ["server.key", "server.cert", "node.cred"] for f in files: - # xxx KEYDIR is undefined, could be meant to be "/var/lib/sfa/" from sfa_component_setup.py + # xxx KEYDIR is undefined, could be meant to be "/var/lib/sfa/" + # from sfa_component_setup.py filepath = KEYDIR + os.sep + f if os.path.isfile(filepath): os.unlink(f) diff --git a/sfa/planetlab/pldriver.py b/sfa/planetlab/pldriver.py index 36d8bd04..6a917a46 100644 --- a/sfa/planetlab/pldriver.py +++ b/sfa/planetlab/pldriver.py @@ -2,7 +2,7 @@ import datetime # from sfa.util.faults import MissingSfaInfo, UnknownSfaType, \ RecordNotFound, SfaNotImplemented, SliverDoesNotExist, SearchFailed, \ - UnsupportedOperation, Forbidden + UnsupportedOperation, Forbidden from sfa.util.sfalogging import logger from sfa.util.defaultdict import defaultdict from sfa.util.sfatime import utcparse, datetime_to_string, datetime_to_epoch @@ -31,22 +31,25 @@ def list_to_dict(recs, key): convert a list of dictionaries into a dictionary keyed on the specified dictionary key """ - return { rec[key] : rec for rec in recs } + return {rec[key]: rec for rec in recs} # # PlShell is just an xmlrpc serverproxy where methods # can be sent as-is; it takes care of authentication # from the global config -# +# + + class PlDriver (Driver): - # the cache instance is a class member so it survives across incoming requests + # the cache instance is a class member so it survives across incoming + # requests cache = None - def __init__ (self, api): - Driver.__init__ (self, api) + def __init__(self, api): + Driver.__init__(self, api) config = api.config - self.shell = PlShell (config) + self.shell = PlShell(config) self.cache = None if config.SFA_AGGREGATE_CACHING: if PlDriver.cache is None: @@ -59,19 +62,20 @@ class PlDriver (Driver): try: filter['slice_id'] = int(sliver_id_parts[0]) except ValueError: - filter['name'] = sliver_id_parts[0] + filter['name'] = sliver_id_parts[0] slices = self.shell.GetSlices(filter, ['hrn']) if not slices: - raise Forbidden("Unable to locate slice record for sliver: {}".format(xrn)) + raise Forbidden( + "Unable to locate slice record for sliver: {}".format(xrn)) slice = slices[0] slice_xrn = slice['hrn'] - return slice_xrn - + return slice_xrn + def check_sliver_credentials(self, creds, urns): # build list of cred object hrns slice_cred_names = [] for cred in creds: - slice_cred_hrn = Credential(cred=cred).get_gid_object().get_hrn() + slice_cred_hrn = Credential(cred=cred).get_gid_object().get_hrn() top_auth_hrn = top_auth(slice_cred_hrn) site_hrn = '.'.join(slice_cred_hrn.split('.')[:-1]) slice_part = slice_cred_hrn.split('.')[-1] @@ -80,8 +84,8 @@ class PlDriver (Driver): else: login_base = hash_loginbase(site_hrn) - slicename = '_'.join([login_base, slice_part]) - slice_cred_names.append(slicename) + slicename = '_'.join([login_base, slice_part]) + slice_cred_names.append(slicename) # look up slice name of slivers listed in urns arg slice_ids = [] @@ -89,11 +93,11 @@ class PlDriver (Driver): sliver_id_parts = Xrn(xrn=urn).get_sliver_id_parts() try: slice_ids.append(int(sliver_id_parts[0])) - except ValueError: + except ValueError: pass if not slice_ids: - raise Forbidden("sliver urn not provided") + raise Forbidden("sliver urn not provided") slices = self.shell.GetSlices(slice_ids) sliver_names = [slice['name'] for slice in slices] @@ -101,25 +105,28 @@ class PlDriver (Driver): # make sure we have a credential for every specified sliver ierd for sliver_name in sliver_names: if sliver_name not in slice_cred_names: - msg = "Valid credential not found for target: {}".format(sliver_name) + msg = "Valid credential not found for target: {}".format( + sliver_name) raise Forbidden(msg) ######################################## - ########## registry oriented + # registry oriented ######################################## - def augment_records_with_testbed_info (self, sfa_records): - return self.fill_record_info (sfa_records) + def augment_records_with_testbed_info(self, sfa_records): + return self.fill_record_info(sfa_records) - ########## - def register (self, sfa_record, hrn, pub_key): + ########## + def register(self, sfa_record, hrn, pub_key): type = sfa_record['type'] pl_record = self.sfa_fields_to_pl_fields(type, hrn, sfa_record) if type == 'authority': - sites = self.shell.GetSites({'peer_id': None, 'login_base': pl_record['login_base']}) + sites = self.shell.GetSites( + {'peer_id': None, 'login_base': pl_record['login_base']}) if not sites: - # xxx when a site gets registered through SFA we need to set its max_slices + # xxx when a site gets registered through SFA we need to set + # its max_slices if 'max_slices' not in pl_record: pl_record['max_slices'] = 2 pointer = self.shell.AddSite(pl_record) @@ -132,42 +139,47 @@ class PlDriver (Driver): for key in pl_record.keys(): if key not in acceptable_fields: pl_record.pop(key) - slices = self.shell.GetSlices({'peer_id': None, 'name': pl_record['name']}) + slices = self.shell.GetSlices( + {'peer_id': None, 'name': pl_record['name']}) if not slices: - if not pl_record.get('url', None) or not pl_record.get('description', None): - pl_record['url'] = hrn - pl_record['description'] = hrn + if not pl_record.get('url', None) or not pl_record.get('description', None): + pl_record['url'] = hrn + pl_record['description'] = hrn - pointer = self.shell.AddSlice(pl_record) - self.shell.SetSliceHrn(int(pointer), hrn) + pointer = self.shell.AddSlice(pl_record) + self.shell.SetSliceHrn(int(pointer), hrn) else: - pointer = slices[0]['slice_id'] + pointer = slices[0]['slice_id'] elif type == 'user': - persons = self.shell.GetPersons({'peer_id': None, 'email': sfa_record['email']}) + persons = self.shell.GetPersons( + {'peer_id': None, 'email': sfa_record['email']}) if not persons: - for key in ['first_name','last_name']: + for key in ['first_name', 'last_name']: if key not in sfa_record: sfa_record[key] = '*from*sfa*' # AddPerson does not allow everything to be set - can_add = ['first_name', 'last_name', 'title','email', 'password', 'phone', 'url', 'bio'] - add_person_dict = { k : sfa_record[k] for k in sfa_record if k in can_add } + can_add = ['first_name', 'last_name', 'title', + 'email', 'password', 'phone', 'url', 'bio'] + add_person_dict = {k: sfa_record[k] + for k in sfa_record if k in can_add} pointer = self.shell.AddPerson(add_person_dict) self.shell.SetPersonHrn(int(pointer), hrn) else: pointer = persons[0]['person_id'] - + # enable the person's account self.shell.UpdatePerson(pointer, {'enabled': True}) # add this person to the site login_base = get_leaf(sfa_record['authority']) self.shell.AddPersonToSite(pointer, login_base) - + # What roles should this user have? roles = [] - if 'roles' in sfa_record: + if 'roles' in sfa_record: # if specified in xml, but only low-level roles - roles = [ role for role in sfa_record['roles'] if role in ['user','tech'] ] + roles = [role for role in sfa_record[ + 'roles'] if role in ['user', 'tech']] # at least user if no other cluse could be found if not roles: roles = ['user'] @@ -175,42 +187,47 @@ class PlDriver (Driver): self.shell.AddRoleToPerson(role, pointer) # Add the user's key if pub_key: - self.shell.AddPersonKey(pointer, {'key_type' : 'ssh', 'key' : pub_key}) + self.shell.AddPersonKey( + pointer, {'key_type': 'ssh', 'key': pub_key}) elif type == 'node': - login_base = PlXrn(xrn=sfa_record['authority'], type='authority').pl_login_base() - nodes = self.shell.GetNodes({'peer_id': None, 'hostname': pl_record['hostname']}) + login_base = PlXrn( + xrn=sfa_record['authority'], type='authority').pl_login_base() + nodes = self.shell.GetNodes( + {'peer_id': None, 'hostname': pl_record['hostname']}) if not nodes: pointer = self.shell.AddNode(login_base, pl_record) self.shell.SetNodeHrn(int(pointer), hrn) else: pointer = nodes[0]['node_id'] - + return pointer - + ########## - # xxx actually old_sfa_record comes filled with plc stuff as well in the original code - def update (self, old_sfa_record, new_sfa_record, hrn, new_key): + # xxx actually old_sfa_record comes filled with plc stuff as well in the + # original code + def update(self, old_sfa_record, new_sfa_record, hrn, new_key): pointer = old_sfa_record['pointer'] type = old_sfa_record['type'] new_key_pointer = None # new_key implemented for users only - if new_key and type not in [ 'user' ]: + if new_key and type not in ['user']: raise UnknownSfaType(type) if (type == "authority"): - logger.debug("pldriver.update: calling UpdateSite with {}".format(new_sfa_record)) + logger.debug( + "pldriver.update: calling UpdateSite with {}".format(new_sfa_record)) self.shell.UpdateSite(pointer, new_sfa_record) self.shell.SetSiteHrn(pointer, hrn) - + elif type == "slice": pl_record = self.sfa_fields_to_pl_fields(type, hrn, new_sfa_record) if 'name' in pl_record: pl_record.pop('name') self.shell.UpdateSlice(pointer, pl_record) self.shell.SetSliceHrn(pointer, hrn) - + elif type == "user": # SMBAKER: UpdatePerson only allows a limited set of fields to be # updated. Ideally we should have a more generic way of doing @@ -223,19 +240,21 @@ class PlDriver (Driver): 'enabled']: update_fields[key] = all_fields[key] # when updating a user, we always get a 'email' field at this point - # this is because 'email' is a native field in the RegUser object... + # this is because 'email' is a native field in the RegUser + # object... if 'email' in update_fields and not update_fields['email']: del update_fields['email'] self.shell.UpdatePerson(pointer, update_fields) self.shell.SetPersonHrn(pointer, hrn) - + if new_key: # must check this key against the previous one if it exists - persons = self.shell.GetPersons({'peer_id': None, 'person_id': pointer}, ['key_ids']) + persons = self.shell.GetPersons( + {'peer_id': None, 'person_id': pointer}, ['key_ids']) person = persons[0] keys = person['key_ids'] keys = self.shell.GetKeys(person['key_ids']) - + key_exists = False for key in keys: if new_key == key['key']: @@ -243,20 +262,21 @@ class PlDriver (Driver): new_key_pointer = key['key_id'] break if not key_exists: - new_key_pointer = self.shell.AddPersonKey(pointer, {'key_type': 'ssh', 'key': new_key}) - + new_key_pointer = self.shell.AddPersonKey( + pointer, {'key_type': 'ssh', 'key': new_key}) + elif type == "node": self.shell.UpdateNode(pointer, new_sfa_record) return (pointer, new_key_pointer) - ########## - def remove (self, sfa_record): + def remove(self, sfa_record): type = sfa_record['type'] pointer = sfa_record['pointer'] if type == 'user': - persons = self.shell.GetPersons({'peer_id': None, 'person_id': pointer}) + persons = self.shell.GetPersons( + {'peer_id': None, 'person_id': pointer}) # only delete this person if he has site ids. if he doesnt, it probably means # he was just removed from a site, not actually deleted if persons and persons[0]['site_ids']: @@ -273,7 +293,6 @@ class PlDriver (Driver): return True - ## # Convert SFA fields to PLC fields for use when registering or updating # registry record in the PLC database @@ -282,17 +301,17 @@ class PlDriver (Driver): def sfa_fields_to_pl_fields(self, type, hrn, sfa_record): pl_record = {} - + if type == "slice": pl_record["name"] = hrn_to_pl_slicename(hrn) if "instantiation" in sfa_record: pl_record['instantiation'] = sfa_record['instantiation'] else: pl_record["instantiation"] = "plc-instantiated" - if "url" in sfa_record: - pl_record["url"] = sfa_record["url"] - if "description" in sfa_record: - pl_record["description"] = sfa_record["description"] + if "url" in sfa_record: + pl_record["url"] = sfa_record["url"] + if "description" in sfa_record: + pl_record["description"] = sfa_record["description"] if "expires" in sfa_record: date = utcparse(sfa_record['expires']) expires = datetime_to_epoch(date) @@ -304,13 +323,14 @@ class PlDriver (Driver): if "hostname" not in sfa_record: raise MissingSfaInfo("hostname") pl_record["hostname"] = sfa_record["hostname"] - if "model" in sfa_record: + if "model" in sfa_record: pl_record["model"] = sfa_record["model"] else: pl_record["model"] = "geni" elif type == "authority": - pl_record["login_base"] = PlXrn(xrn=hrn,type='authority').pl_login_base() + pl_record["login_base"] = PlXrn( + xrn=hrn, type='authority').pl_login_base() if "name" not in sfa_record or not sfa_record['name']: pl_record["name"] = hrn if "abbreviated_name" not in sfa_record: @@ -341,15 +361,15 @@ class PlDriver (Driver): Fill in the planetlab specific fields of a SFA record. This involves calling the appropriate PLC method to retrieve the database record for the object. - + @param record: record to fill in field (in/out param) """ # get ids by type - node_ids, site_ids, slice_ids = [], [], [] + node_ids, site_ids, slice_ids = [], [], [] person_ids, key_ids = [], [] type_map = {'node': node_ids, 'authority': site_ids, 'slice': slice_ids, 'user': person_ids} - + for record in records: for type in type_map: if type == record['type']: @@ -358,16 +378,20 @@ class PlDriver (Driver): # get pl records nodes, sites, slices, persons, keys = {}, {}, {}, {}, {} if node_ids: - node_list = self.shell.GetNodes({'peer_id': None, 'node_id': node_ids}) + node_list = self.shell.GetNodes( + {'peer_id': None, 'node_id': node_ids}) nodes = list_to_dict(node_list, 'node_id') if site_ids: - site_list = self.shell.GetSites({'peer_id': None, 'site_id': site_ids}) + site_list = self.shell.GetSites( + {'peer_id': None, 'site_id': site_ids}) sites = list_to_dict(site_list, 'site_id') if slice_ids: - slice_list = self.shell.GetSlices({'peer_id': None, 'slice_id': slice_ids}) + slice_list = self.shell.GetSlices( + {'peer_id': None, 'slice_id': slice_ids}) slices = list_to_dict(slice_list, 'slice_id') if person_ids: - person_list = self.shell.GetPersons({'peer_id': None, 'person_id': person_ids}) + person_list = self.shell.GetPersons( + {'peer_id': None, 'person_id': person_ids}) persons = list_to_dict(person_list, 'person_id') for person in persons: key_ids.extend(persons[person]['key_ids']) @@ -386,7 +410,7 @@ class PlDriver (Driver): # authorities, but not PL "sites" if record['pointer'] == -1: continue - + for type in pl_records: if record['type'] == type: if record['pointer'] in pl_records[type]: @@ -395,9 +419,11 @@ class PlDriver (Driver): # fill in key info if record['type'] == 'user': if 'key_ids' not in record: - logger.info("user record has no 'key_ids' - need to import from myplc ?") + logger.info( + "user record has no 'key_ids' - need to import from myplc ?") else: - pubkeys = [keys[key_id]['key'] for key_id in record['key_ids'] if key_id in keys] + pubkeys = [keys[key_id]['key'] + for key_id in record['key_ids'] if key_id in keys] record['keys'] = pubkeys return records @@ -424,18 +450,22 @@ class PlDriver (Driver): # get pl records slices, persons, sites, nodes = {}, {}, {}, {} if site_ids: - site_list = self.shell.GetSites({'peer_id': None, 'site_id': site_ids}, ['site_id', 'login_base']) + site_list = self.shell.GetSites({'peer_id': None, 'site_id': site_ids}, [ + 'site_id', 'login_base']) sites = list_to_dict(site_list, 'site_id') if person_ids: - person_list = self.shell.GetPersons({'peer_id': None, 'person_id': person_ids}, ['person_id', 'email']) + person_list = self.shell.GetPersons( + {'peer_id': None, 'person_id': person_ids}, ['person_id', 'email']) persons = list_to_dict(person_list, 'person_id') if slice_ids: - slice_list = self.shell.GetSlices({'peer_id': None, 'slice_id': slice_ids}, ['slice_id', 'name']) - slices = list_to_dict(slice_list, 'slice_id') + slice_list = self.shell.GetSlices( + {'peer_id': None, 'slice_id': slice_ids}, ['slice_id', 'name']) + slices = list_to_dict(slice_list, 'slice_id') if node_ids: - node_list = self.shell.GetNodes({'peer_id': None, 'node_id': node_ids}, ['node_id', 'hostname']) + node_list = self.shell.GetNodes( + {'peer_id': None, 'node_id': node_ids}, ['node_id', 'hostname']) nodes = list_to_dict(node_list, 'node_id') - + # convert ids to hrns for record in records: # get all relevant data @@ -451,33 +481,37 @@ class PlDriver (Driver): login_base = site['login_base'] record['site'] = ".".join([auth_hrn, login_base]) if 'person_ids' in record: - emails = [persons[person_id]['email'] for person_id in record['person_ids'] \ - if person_id in persons] + emails = [persons[person_id]['email'] for person_id in record['person_ids'] + if person_id in persons] usernames = [email.split('@')[0] for email in emails] - person_hrns = [".".join([auth_hrn, login_base, username]) for username in usernames] - record['persons'] = person_hrns + person_hrns = [".".join([auth_hrn, login_base, username]) + for username in usernames] + record['persons'] = person_hrns if 'slice_ids' in record: - slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids'] \ + slicenames = [slices[slice_id]['name'] for slice_id in record['slice_ids'] if slice_id in slices] - slice_hrns = [slicename_to_hrn(auth_hrn, slicename) for slicename in slicenames] + slice_hrns = [slicename_to_hrn( + auth_hrn, slicename) for slicename in slicenames] record['slices'] = slice_hrns if 'node_ids' in record: - hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] \ + hostnames = [nodes[node_id]['hostname'] for node_id in record['node_ids'] if node_id in nodes] - node_hrns = [hostname_to_hrn(auth_hrn, login_base, hostname) for hostname in hostnames] + node_hrns = [hostname_to_hrn( + auth_hrn, login_base, hostname) for hostname in hostnames] record['nodes'] = node_hrns if 'site_ids' in record: - login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids'] \ + login_bases = [sites[site_id]['login_base'] for site_id in record['site_ids'] if site_id in sites] - site_hrns = [".".join([auth_hrn, lbase]) for lbase in login_bases] + site_hrns = [".".join([auth_hrn, lbase]) + for lbase in login_bases] record['sites'] = site_hrns if 'expires' in record: date = utcparse(record['expires']) datestring = datetime_to_string(date) - record['expires'] = datestring - - return records + record['expires'] = datestring + + return records def fill_record_sfa_info(self, records): @@ -489,20 +523,22 @@ class PlDriver (Driver): site_ids = [] for record in records: person_ids.extend(record.get("person_ids", [])) - site_ids.extend(record.get("site_ids", [])) + site_ids.extend(record.get("site_ids", [])) if 'site_id' in record: - site_ids.append(record['site_id']) - + site_ids.append(record['site_id']) + # get all pis from the sites we've encountered - # and store them in a dictionary keyed on site_id + # and store them in a dictionary keyed on site_id site_pis = {} if site_ids: - pi_filter = {'peer_id': None, '|roles': ['pi'], '|site_ids': site_ids} - pi_list = self.shell.GetPersons(pi_filter, ['person_id', 'site_ids']) + pi_filter = {'peer_id': None, '|roles': [ + 'pi'], '|site_ids': site_ids} + pi_list = self.shell.GetPersons( + pi_filter, ['person_id', 'site_ids']) for pi in pi_list: # we will need the pi's hrns also person_ids.append(pi['person_id']) - + # we also need to keep track of the sites these pis # belong to for site_id in pi['site_ids']: @@ -510,14 +546,15 @@ class PlDriver (Driver): site_pis[site_id].append(pi) else: site_pis[site_id] = [pi] - - # get sfa records for all records associated with these records. + + # get sfa records for all records associated with these records. # we'll replace pl ids (person_ids) with hrns from the sfa records # we obtain - + # get the registry records person_list, persons = [], {} - person_list = self.api.dbsession().query (RegRecord).filter(RegRecord.pointer.in_(person_ids)) + person_list = self.api.dbsession().query( + RegRecord).filter(RegRecord.pointer.in_(person_ids)) # create a hrns keyed on the sfa record's pointer. # Its possible for multiple records to have the same pointer so # the dict's value will be a list of hrns. @@ -527,17 +564,19 @@ class PlDriver (Driver): # get the pl records pl_person_list, pl_persons = [], {} - pl_person_list = self.shell.GetPersons(person_ids, ['person_id', 'roles']) + pl_person_list = self.shell.GetPersons( + person_ids, ['person_id', 'roles']) pl_persons = list_to_dict(pl_person_list, 'person_id') # fill sfa info for record in records: # skip records with no pl info (top level authorities) - #if record['pointer'] == -1: - # continue + # if record['pointer'] == -1: + # continue sfa_info = {} type = record['type'] - logger.info("fill_record_sfa_info - incoming record typed {}".format(type)) + logger.info( + "fill_record_sfa_info - incoming record typed {}".format(type)) if (type == "slice"): # all slice users are researchers record['geni_urn'] = hrn_to_urn(record['hrn'], 'slice') @@ -545,7 +584,7 @@ class PlDriver (Driver): record['researcher'] = [] for person_id in record.get('person_ids', []): hrns = [person.hrn for person in persons[person_id]] - record['researcher'].extend(hrns) + record['researcher'].extend(hrns) # pis at the slice's site if 'site_id' in record and record['site_id'] in site_pis: @@ -554,8 +593,8 @@ class PlDriver (Driver): for person_id in pi_ids: hrns = [person.hrn for person in persons[person_id]] record['PI'].extend(hrns) - record['geni_creator'] = record['PI'] - + record['geni_creator'] = record['PI'] + elif (type.startswith("authority")): record['url'] = None logger.info("fill_record_sfa_info - authority xherex") @@ -565,10 +604,11 @@ class PlDriver (Driver): record['owner'] = [] for pointer in record.get('person_ids', []): if pointer not in persons or pointer not in pl_persons: - # this means there is not sfa or pl record for this user - continue - hrns = [person.hrn for person in persons[pointer]] - roles = pl_persons[pointer]['roles'] + # this means there is not sfa or pl record for this + # user + continue + hrns = [person.hrn for person in persons[pointer]] + roles = pl_persons[pointer]['roles'] if 'pi' in roles: record['PI'].extend(hrns) if 'tech' in roles: @@ -579,74 +619,82 @@ class PlDriver (Driver): elif (type == "node"): sfa_info['dns'] = record.get("hostname", "") # xxx TODO: URI, LatLong, IP, DNS - + elif (type == "user"): logger.info('setting user.email') sfa_info['email'] = record.get("email", "") sfa_info['geni_urn'] = hrn_to_urn(record['hrn'], 'user') - sfa_info['geni_certificate'] = record['gid'] + sfa_info['geni_certificate'] = record['gid'] # xxx TODO: PostalAddress, Phone record.update(sfa_info) - #################### # plcapi works by changes, compute what needs to be added/deleted - def update_relation (self, subject_type, target_type, relation_name, subject_id, target_ids): + def update_relation(self, subject_type, target_type, relation_name, subject_id, target_ids): # hard-wire the code for slice/user for now, could be smarter if needed if subject_type == 'slice' and target_type == 'user' and relation_name == 'researcher': - subject = self.shell.GetSlices (subject_id)[0] + subject = self.shell.GetSlices(subject_id)[0] current_target_ids = subject['person_ids'] - add_target_ids = list ( set (target_ids).difference(current_target_ids)) - del_target_ids = list ( set (current_target_ids).difference(target_ids)) - logger.debug ("subject_id = {} (type={})".format(subject_id, type(subject_id))) + add_target_ids = list( + set(target_ids).difference(current_target_ids)) + del_target_ids = list( + set(current_target_ids).difference(target_ids)) + logger.debug("subject_id = {} (type={})".format( + subject_id, type(subject_id))) for target_id in add_target_ids: - self.shell.AddPersonToSlice (target_id,subject_id) - logger.debug ("add_target_id = {} (type={})".format(target_id, type(target_id))) + self.shell.AddPersonToSlice(target_id, subject_id) + logger.debug("add_target_id = {} (type={})".format( + target_id, type(target_id))) for target_id in del_target_ids: - logger.debug ("del_target_id = {} (type={})".format(target_id, type(target_id))) - self.shell.DeletePersonFromSlice (target_id, subject_id) + logger.debug("del_target_id = {} (type={})".format( + target_id, type(target_id))) + self.shell.DeletePersonFromSlice(target_id, subject_id) elif subject_type == 'authority' and target_type == 'user' and relation_name == 'pi': # due to the plcapi limitations this means essentially adding pi role to all people in the list - # it's tricky to remove any pi role here, although it might be desirable - persons = self.shell.GetPersons ({'peer_id': None, 'person_id': target_ids}) - for person in persons: + # it's tricky to remove any pi role here, although it might be + # desirable + persons = self.shell.GetPersons( + {'peer_id': None, 'person_id': target_ids}) + for person in persons: if 'pi' not in person['roles']: - self.shell.AddRoleToPerson('pi',person['person_id']) + self.shell.AddRoleToPerson('pi', person['person_id']) else: - logger.info('unexpected relation {} to maintain, {} -> {}'\ + logger.info('unexpected relation {} to maintain, {} -> {}' .format(relation_name, subject_type, target_type)) - ######################################## - ########## aggregate oriented + # aggregate oriented ######################################## - def testbed_name (self): return "myplc" + def testbed_name(self): return "myplc" - def aggregate_version (self): + def aggregate_version(self): return {} # first 2 args are None in case of resource discovery - def list_resources (self, version=None, options=None): - if options is None: options={} + def list_resources(self, version=None, options=None): + if options is None: + options = {} aggregate = PlAggregate(self) - rspec = aggregate.list_resources(version=version, options=options) + rspec = aggregate.list_resources(version=version, options=options) return rspec def describe(self, urns, version, options=None): - if options is None: options={} + if options is None: + options = {} aggregate = PlAggregate(self) return aggregate.describe(urns, version=version, options=options) - - def status (self, urns, options=None): - if options is None: options={} + + def status(self, urns, options=None): + if options is None: + options = {} aggregate = PlAggregate(self) - desc = aggregate.describe(urns, version='GENI 3') + desc = aggregate.describe(urns, version='GENI 3') status = {'geni_urn': desc['geni_urn'], 'geni_slivers': desc['geni_slivers']} return status - def allocate (self, urn, rspec_string, expiration, options=None): + def allocate(self, urn, rspec_string, expiration, options=None): """ Allocate a PL slice @@ -658,36 +706,41 @@ class PlDriver (Driver): the exact set of tags at the end of the call, meaning pre-existing tags are deleted if not repeated in the incoming request """ - if options is None: options={} + if options is None: + options = {} xrn = Xrn(urn) aggregate = PlAggregate(self) slices = PlSlices(self) sfa_peer = slices.get_sfa_peer(xrn.get_hrn()) - slice_record = None + slice_record = None users = options.get('geni_users', []) if users: slice_record = users[0].get('slice_record', {}) - + # parse rspec rspec = RSpec(rspec_string) requested_attributes = rspec.version.get_slice_attributes() - + # ensure site record exists - site = slices.verify_site(xrn.hrn, slice_record, sfa_peer, options=options) + site = slices.verify_site( + xrn.hrn, slice_record, sfa_peer, options=options) # ensure slice record exists - slice = slices.verify_slice(xrn.hrn, slice_record, sfa_peer, expiration=expiration, options=options) + slice = slices.verify_slice( + xrn.hrn, slice_record, sfa_peer, expiration=expiration, options=options) # ensure person records exists - persons = slices.verify_persons(xrn.hrn, slice, users, sfa_peer, options=options) + persons = slices.verify_persons( + xrn.hrn, slice, users, sfa_peer, options=options) # ensure slice attributes exists slices.verify_slice_tags(slice, requested_attributes, options=options) - + # add/remove slice from nodes request_nodes = rspec.version.get_nodes_with_slivers() nodes = slices.verify_slice_nodes(urn, slice, request_nodes) - - # add/remove links links - slices.verify_slice_links(slice, rspec.version.get_link_requests(), nodes) + + # add/remove links links + slices.verify_slice_links( + slice, rspec.version.get_link_requests(), nodes) # add/remove leases rspec_requested_leases = rspec.version.get_leases() @@ -696,68 +749,76 @@ class PlDriver (Driver): return aggregate.describe([xrn.get_urn()], version=rspec.version) def provision(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} # update users slices = PlSlices(self) aggregate = PlAggregate(self) slivers = aggregate.get_slivers(urns) if not slivers: sliver_id_parts = Xrn(urns[0]).get_sliver_id_parts() - # allow to be called with an empty rspec, meaning flush reservations + # allow to be called with an empty rspec, meaning flush + # reservations if sliver_id_parts: filter = {} try: filter['slice_id'] = int(sliver_id_parts[0]) except ValueError: filter['name'] = sliver_id_parts[0] - slices = self.shell.GetSlices(filter,['hrn']) + slices = self.shell.GetSlices(filter, ['hrn']) if not slices: - raise Forbidden("Unable to locate slice record for sliver: {}".format(xrn)) + raise Forbidden( + "Unable to locate slice record for sliver: {}".format(xrn)) slice = slices[0] slice_urn = hrn_to_urn(slice['hrn'], type='slice') - urns = [slice_urn] - else: + urns = [slice_urn] + else: slice_id = slivers[0]['slice_id'] slice_hrn = self.shell.GetSliceHrn(slice_id) slice = self.shell.GetSlices({'slice_id': slice_id})[0] slice['hrn'] = slice_hrn sfa_peer = slices.get_sfa_peer(slice['hrn']) users = options.get('geni_users', []) - persons = slices.verify_persons(slice['hrn'], slice, users, sfa_peer, options=options) + persons = slices.verify_persons( + slice['hrn'], slice, users, sfa_peer, options=options) # update sliver allocation states and set them to geni_provisioned sliver_ids = [sliver['sliver_id'] for sliver in slivers] dbsession = self.api.dbsession() - SliverAllocation.set_allocations(sliver_ids, 'geni_provisioned',dbsession) + SliverAllocation.set_allocations( + sliver_ids, 'geni_provisioned', dbsession) version_manager = VersionManager() - rspec_version = version_manager.get_version(options['geni_rspec_version']) + rspec_version = version_manager.get_version( + options['geni_rspec_version']) return self.describe(urns, rspec_version, options=options) def delete(self, urns, options=None): - if options is None: options={} + if options is None: + options = {} # collect sliver ids so we can update sliver allocation states after # we remove the slivers. aggregate = PlAggregate(self) slivers = aggregate.get_slivers(urns) if slivers: - slice_id = slivers[0]['slice_id'] + slice_id = slivers[0]['slice_id'] slice_name = slivers[0]['name'] node_ids = [] sliver_ids = [] for sliver in slivers: node_ids.append(sliver['node_id']) - sliver_ids.append(sliver['sliver_id']) + sliver_ids.append(sliver['sliver_id']) # leases - leases = self.shell.GetLeases({'name': slice_name, 'node_id': node_ids}) - leases_ids = [lease['lease_id'] for lease in leases ] + leases = self.shell.GetLeases( + {'name': slice_name, 'node_id': node_ids}) + leases_ids = [lease['lease_id'] for lease in leases] slice_hrn = self.shell.GetSliceHrn(int(slice_id)) try: self.shell.DeleteSliceFromNodes(slice_id, node_ids) if len(leases_ids) > 0: self.shell.DeleteLeases(leases_ids) - + # delete sliver allocation states dbsession = self.api.dbsession() SliverAllocation.delete_allocations(sliver_ids, dbsession) @@ -770,11 +831,12 @@ class PlDriver (Driver): geni_slivers.append( {'geni_sliver_urn': sliver['sliver_id'], 'geni_allocation_status': 'geni_unallocated', - 'geni_expires': datetime_to_string(utcparse(sliver['expires']))}) + 'geni_expires': datetime_to_string(utcparse(sliver['expires']))}) return geni_slivers - def renew (self, urns, expiration_time, options=None): - if options is None: options={} + def renew(self, urns, expiration_time, options=None): + if options is None: + options = {} aggregate = PlAggregate(self) slivers = aggregate.get_slivers(urns) if not slivers: @@ -785,17 +847,18 @@ class PlDriver (Driver): self.shell.UpdateSlice(slice['slice_id'], record) description = self.describe(urns, 'GENI 3', options) return description['geni_slivers'] - - def perform_operational_action (self, urns, action, options=None): - if options is None: options={} + def perform_operational_action(self, urns, action, options=None): + if options is None: + options = {} # MyPLC doesn't support operational actions. Lets pretend like it # supports start, but reject everything else. action = action.lower() if action not in ['geni_start']: raise UnsupportedOperation(action) - # fault if sliver is not full allocated (operational status is geni_pending_allocation) + # fault if sliver is not full allocated (operational status is + # geni_pending_allocation) description = self.describe(urns, 'GENI 3', options) for sliver in description['geni_slivers']: if sliver['geni_operational_status'] == 'geni_pending_allocation': @@ -809,8 +872,9 @@ class PlDriver (Driver): return geni_slivers # set the 'enabled' tag to 0 - def shutdown (self, xrn, options=None): - if options is None: options={} + def shutdown(self, xrn, options=None): + if options is None: + options = {} hrn, _ = urn_to_hrn(xrn) top_auth_hrn = top_auth(hrn) site_hrn = '.'.join(hrn.split('.')[:-1]) @@ -822,11 +886,13 @@ class PlDriver (Driver): slicename = '_'.join([login_base, slice_part]) - slices = self.shell.GetSlices({'peer_id': None, 'name': slicename}, ['slice_id']) + slices = self.shell.GetSlices( + {'peer_id': None, 'name': slicename}, ['slice_id']) if not slices: raise RecordNotFound(slice_hrn) slice_id = slices[0]['slice_id'] - slice_tags = self.shell.GetSliceTags({'slice_id': slice_id, 'tagname': 'enabled'}) + slice_tags = self.shell.GetSliceTags( + {'slice_id': slice_id, 'tagname': 'enabled'}) if not slice_tags: self.shell.AddSliceTag(slice_id, 'enabled', '0') elif slice_tags[0]['value'] != "0": diff --git a/sfa/planetlab/plshell.py b/sfa/planetlab/plshell.py index 15c58b69..8877f410 100644 --- a/sfa/planetlab/plshell.py +++ b/sfa/planetlab/plshell.py @@ -5,13 +5,14 @@ from urlparse import urlparse from sfa.util.sfalogging import logger + class PlShell: """ A simple xmlrpc shell to a myplc instance This class can receive all PLCAPI calls to the underlying testbed For safety this is limited to a set of hard-coded calls """ - + direct_calls = ['AddNode', 'AddPerson', 'AddPersonKey', 'AddPersonToSite', 'AddPersonToSlice', 'AddRoleToPerson', 'AddSite', 'AddSiteTag', 'AddSlice', 'AddSliceTag', 'AddSliceToNodes', 'BindObjectToPeer', 'DeleteKey', @@ -22,9 +23,9 @@ class PlShell: 'UnBindObjectFromPeer', 'UpdateNode', 'UpdatePerson', 'UpdateSite', 'UpdateSlice', 'UpdateSliceTag', # also used as-is in importer - 'GetSites','GetNodes', + 'GetSites', 'GetNodes', # Lease management methods - 'GetLeases', 'GetLeaseGranularity', 'DeleteLeases','UpdateLeases', + 'GetLeases', 'GetLeaseGranularity', 'DeleteLeases', 'UpdateLeases', 'AddLeases', # HRN management methods 'SetPersonHrn', 'GetPersonHrn', 'SetSliceHrn', 'GetSliceHrn', @@ -32,65 +33,72 @@ class PlShell: # Tag slice/person/site created by SFA 'SetPersonSfaCreated', 'GetPersonSfaCreated', 'SetSliceSfaCreated', 'GetSliceSfaCreated', 'SetNodeSfaCreated', 'GetNodeSfaCreated', - 'GetSiteSfaCreated', 'SetSiteSfaCreated', + 'GetSiteSfaCreated', 'SetSiteSfaCreated', ] # support for other names - this is experimental - alias_calls = { 'get_authorities':'GetSites', - 'get_nodes':'GetNodes', - } - + alias_calls = {'get_authorities': 'GetSites', + 'get_nodes': 'GetNodes', + } - # use the 'capability' auth mechanism for higher performance when the PLC db is local - def __init__ ( self, config ) : + # use the 'capability' auth mechanism for higher performance when the PLC + # db is local + def __init__(self, config): url = config.SFA_PLC_URL # try to figure if the url is local - hostname=urlparse(url).hostname - is_local=False - if hostname == 'localhost': is_local=True - # otherwise compare IP addresses; + hostname = urlparse(url).hostname + is_local = False + if hostname == 'localhost': + is_local = True + # otherwise compare IP addresses; # this might fail for any number of reasons, so let's harden that try: # xxx todo this seems to result in a DNS request for each incoming request to the AM # should be cached or improved - url_ip=socket.gethostbyname(hostname) - local_ip=socket.gethostbyname(socket.gethostname()) - if url_ip==local_ip: is_local=True + url_ip = socket.gethostbyname(hostname) + local_ip = socket.gethostbyname(socket.gethostname()) + if url_ip == local_ip: + is_local = True except: pass if is_local: try: # too bad this is not installed properly - plcapi_path="/usr/share/plc_api" - if plcapi_path not in sys.path: sys.path.append(plcapi_path) + plcapi_path = "/usr/share/plc_api" + if plcapi_path not in sys.path: + sys.path.append(plcapi_path) import PLC.Shell - plc_direct_access=True + plc_direct_access = True except: - plc_direct_access=False + plc_direct_access = False if is_local and plc_direct_access: logger.info('plshell access - capability') - self.plauth = { 'AuthMethod': 'capability', - 'Username': str(config.SFA_PLC_USER), - 'AuthString': str(config.SFA_PLC_PASSWORD), - } - self.proxy = PLC.Shell.Shell () + self.plauth = {'AuthMethod': 'capability', + 'Username': str(config.SFA_PLC_USER), + 'AuthString': str(config.SFA_PLC_PASSWORD), + } + self.proxy = PLC.Shell.Shell() else: logger.info('plshell access - xmlrpc') - self.plauth = { 'AuthMethod': 'password', - 'Username': str(config.SFA_PLC_USER), - 'AuthString': str(config.SFA_PLC_PASSWORD), - } - self.proxy = xmlrpclib.Server(url, verbose = False, allow_none = True) + self.plauth = {'AuthMethod': 'password', + 'Username': str(config.SFA_PLC_USER), + 'AuthString': str(config.SFA_PLC_PASSWORD), + } + self.proxy = xmlrpclib.Server(url, verbose=False, allow_none=True) def __getattr__(self, name): def func(*args, **kwds): - actual_name=None - if name in PlShell.direct_calls: actual_name=name - if name in PlShell.alias_calls: actual_name=PlShell.alias_calls[name] + actual_name = None + if name in PlShell.direct_calls: + actual_name = name + if name in PlShell.alias_calls: + actual_name = PlShell.alias_calls[name] if not actual_name: - raise Exception("Illegal method call %s for PL driver"%(name)) - result=getattr(self.proxy, actual_name)(self.plauth, *args, **kwds) - logger.debug('PlShell %s (%s) returned ... '%(name,actual_name)) + raise Exception( + "Illegal method call %s for PL driver" % (name)) + result = getattr(self.proxy, actual_name)( + self.plauth, *args, **kwds) + logger.debug('PlShell %s (%s) returned ... ' % (name, actual_name)) return result return func diff --git a/sfa/planetlab/plslices.py b/sfa/planetlab/plslices.py index 2b59a017..635001ce 100644 --- a/sfa/planetlab/plslices.py +++ b/sfa/planetlab/plslices.py @@ -10,25 +10,27 @@ from sfa.planetlab.topology import Topology from sfa.planetlab.plxrn import PlXrn, hrn_to_pl_slicename, xrn_to_hostname, top_auth, hash_loginbase from sfa.storage.model import SliverAllocation -MAXINT = 2L**31-1 +MAXINT = 2L**31 - 1 + class PlSlices: - rspec_to_slice_tag = {'max_rate' : 'net_max_rate'} + rspec_to_slice_tag = {'max_rate': 'net_max_rate'} def __init__(self, driver): self.driver = driver def get_slivers(self, xrn, node=None): hrn, type = urn_to_hrn(xrn) - + slice_name = hrn_to_pl_slicename(hrn) # XX Should we just call PLCAPI.GetSliceTicket(slice_name) instead # of doing all of this? - #return self.driver.shell.GetSliceTicket(self.auth, slice_name) - + # return self.driver.shell.GetSliceTicket(self.auth, slice_name) + # from PLCAPI.GetSlivers.get_slivers() - slice_fields = ['slice_id', 'name', 'instantiation', 'expires', 'person_ids', 'slice_tag_ids'] + slice_fields = ['slice_id', 'name', 'instantiation', + 'expires', 'person_ids', 'slice_tag_ids'] slices = self.driver.shell.GetSlices(slice_name, slice_fields) # Build up list of users and slice attributes person_ids = set() @@ -39,11 +41,11 @@ class PlSlices: person_ids = list(person_ids) all_slice_tag_ids = list(all_slice_tag_ids) # Get user information - all_persons_list = self.driver.shell.GetPersons({'person_id':person_ids,'enabled':True}, + all_persons_list = self.driver.shell.GetPersons({'person_id': person_ids, 'enabled': True}, ['person_id', 'enabled', 'key_ids']) all_persons = {} for person in all_persons_list: - all_persons[person['person_id']] = person + all_persons[person['person_id']] = person # Build up list of keys key_ids = set() @@ -51,7 +53,8 @@ class PlSlices: key_ids.update(person['key_ids']) key_ids = list(key_ids) # Get user account keys - all_keys_list = self.driver.shell.GetKeys(key_ids, ['key_id', 'key', 'key_type']) + all_keys_list = self.driver.shell.GetKeys( + key_ids, ['key_id', 'key', 'key_type']) all_keys = {} for key in all_keys_list: all_keys[key['key_id']] = key @@ -60,7 +63,7 @@ class PlSlices: all_slice_tags = {} for slice_tag in all_slice_tags_list: all_slice_tags[slice_tag['slice_tag_id']] = slice_tag - + slivers = [] for slice in slices: keys = [] @@ -73,13 +76,13 @@ class PlSlices: if key_id in all_keys: key = all_keys[key_id] keys += [{'key_type': key['key_type'], - 'key': key['key']}] + 'key': key['key']}] attributes = [] # All (per-node and global) attributes for this slice slice_tags = [] for slice_tag_id in slice['slice_tag_ids']: if slice_tag_id in all_slice_tags: - slice_tags.append(all_slice_tags[slice_tag_id]) + slice_tags.append(all_slice_tags[slice_tag_id]) # Per-node sliver attributes take precedence over global # slice attributes, so set them first. # Then comes nodegroup slice attributes @@ -90,7 +93,7 @@ class PlSlices: for sliver_attribute in filter(lambda a: a['node_id'] == node['node_id'], slice_tags): sliver_attributes.append(sliver_attribute['tagname']) attributes.append({'tagname': sliver_attribute['tagname'], - 'value': sliver_attribute['value']}) + 'value': sliver_attribute['value']}) # set nodegroup slice attributes for slice_tag in filter(lambda a: a['nodegroup_id'] in node['nodegroup_ids'], slice_tags): @@ -99,7 +102,7 @@ class PlSlices: # already set. if slice_tag not in slice_tags: attributes.append({'tagname': slice_tag['tagname'], - 'value': slice_tag['value']}) + 'value': slice_tag['value']}) for slice_tag in filter(lambda a: a['node_id'] is None, slice_tags): # Do not set any global slice attributes for @@ -107,13 +110,13 @@ class PlSlices: # already set. if slice_tag['tagname'] not in sliver_attributes: attributes.append({'tagname': slice_tag['tagname'], - 'value': slice_tag['value']}) + 'value': slice_tag['value']}) # XXX Sanity check; though technically this should be a system invariant # checked with an assertion if slice['expires'] > MAXINT: slice['expires'] = MAXINT - + slivers.append({ 'hrn': hrn, 'name': slice['name'], @@ -125,7 +128,6 @@ class PlSlices: }) return slivers - def get_sfa_peer(self, xrn): hrn, type = urn_to_hrn(xrn) @@ -142,119 +144,121 @@ class PlSlices: def verify_slice_leases(self, slice, rspec_requested_leases): - leases = self.driver.shell.GetLeases({'name':slice['name'], 'clip':int(time.time())}, - ['lease_id','name', 'hostname', 't_from', 't_until']) + leases = self.driver.shell.GetLeases({'name': slice['name'], 'clip': int(time.time())}, + ['lease_id', 'name', 'hostname', 't_from', 't_until']) grain = self.driver.shell.GetLeaseGranularity() requested_leases = [] for lease in rspec_requested_leases: - requested_lease = {} - slice_hrn, _ = urn_to_hrn(lease['slice_id']) - - top_auth_hrn = top_auth(slice_hrn) - site_hrn = '.'.join(slice_hrn.split('.')[:-1]) - slice_part = slice_hrn.split('.')[-1] - if top_auth_hrn == self.driver.hrn: - login_base = slice_hrn.split('.')[-2][:12] - else: - login_base = hash_loginbase(site_hrn) - - slice_name = '_'.join([login_base, slice_part]) - - if slice_name != slice['name']: - continue - elif Xrn(lease['component_id']).get_authority_urn().split(':')[0] != self.driver.hrn: - continue + requested_lease = {} + slice_hrn, _ = urn_to_hrn(lease['slice_id']) + + top_auth_hrn = top_auth(slice_hrn) + site_hrn = '.'.join(slice_hrn.split('.')[:-1]) + slice_part = slice_hrn.split('.')[-1] + if top_auth_hrn == self.driver.hrn: + login_base = slice_hrn.split('.')[-2][:12] + else: + login_base = hash_loginbase(site_hrn) - hostname = xrn_to_hostname(lease['component_id']) - # fill the requested node with nitos ids - requested_lease['name'] = slice['name'] - requested_lease['hostname'] = hostname - requested_lease['t_from'] = int(lease['start_time']) - requested_lease['t_until'] = int(lease['duration']) * grain + int(lease['start_time']) - requested_leases.append(requested_lease) + slice_name = '_'.join([login_base, slice_part]) + if slice_name != slice['name']: + continue + elif Xrn(lease['component_id']).get_authority_urn().split(':')[0] != self.driver.hrn: + continue + hostname = xrn_to_hostname(lease['component_id']) + # fill the requested node with nitos ids + requested_lease['name'] = slice['name'] + requested_lease['hostname'] = hostname + requested_lease['t_from'] = int(lease['start_time']) + requested_lease['t_until'] = int( + lease['duration']) * grain + int(lease['start_time']) + requested_leases.append(requested_lease) - # prepare actual slice leases by lease_id + # prepare actual slice leases by lease_id leases_by_id = {} for lease in leases: - leases_by_id[lease['lease_id']] = {'name': lease['name'], 'hostname': lease['hostname'], \ - 't_from': lease['t_from'], 't_until': lease['t_until']} - + leases_by_id[lease['lease_id']] = {'name': lease['name'], 'hostname': lease['hostname'], + 't_from': lease['t_from'], 't_until': lease['t_until']} + added_leases = [] kept_leases_id = [] deleted_leases_id = [] for lease_id in leases_by_id: - if leases_by_id[lease_id] not in requested_leases: - deleted_leases_id.append(lease_id) - else: - kept_leases_id.append(lease_id) - requested_leases.remove(leases_by_id[lease_id]) + if leases_by_id[lease_id] not in requested_leases: + deleted_leases_id.append(lease_id) + else: + kept_leases_id.append(lease_id) + requested_leases.remove(leases_by_id[lease_id]) added_leases = requested_leases - try: self.driver.shell.DeleteLeases(deleted_leases_id) for lease in added_leases: - self.driver.shell.AddLeases(lease['hostname'], slice['name'], lease['t_from'], lease['t_until']) + self.driver.shell.AddLeases(lease['hostname'], slice['name'], lease[ + 't_from'], lease['t_until']) - except: + except: logger.log_exc('Failed to add/remove slice leases') return leases - def verify_slice_nodes(self, slice_urn, slice, rspec_nodes): - + slivers = {} for node in rspec_nodes: hostname = node.get('component_name') client_id = node.get('client_id') - component_id = node.get('component_id').strip() + component_id = node.get('component_id').strip() if hostname: hostname = hostname.strip() elif component_id: hostname = xrn_to_hostname(component_id) if hostname: - slivers[hostname] = {'client_id': client_id, 'component_id': component_id} - - nodes = self.driver.shell.GetNodes(slice['node_ids'], ['node_id', 'hostname', 'interface_ids']) + slivers[hostname] = { + 'client_id': client_id, 'component_id': component_id} + + nodes = self.driver.shell.GetNodes( + slice['node_ids'], ['node_id', 'hostname', 'interface_ids']) current_slivers = [node['hostname'] for node in nodes] # remove nodes not in rspec deleted_nodes = list(set(current_slivers).difference(slivers.keys())) # add nodes from rspec - added_nodes = list(set(slivers.keys()).difference(current_slivers)) + added_nodes = list(set(slivers.keys()).difference(current_slivers)) try: self.driver.shell.AddSliceToNodes(slice['name'], added_nodes) - self.driver.shell.DeleteSliceFromNodes(slice['name'], deleted_nodes) - - except: + self.driver.shell.DeleteSliceFromNodes( + slice['name'], deleted_nodes) + + except: logger.log_exc('Failed to add/remove slice from nodes') - slices = self.driver.shell.GetSlices(slice['name'], ['node_ids']) + slices = self.driver.shell.GetSlices(slice['name'], ['node_ids']) resulting_nodes = self.driver.shell.GetNodes(slices[0]['node_ids']) # update sliver allocations for node in resulting_nodes: client_id = slivers[node['hostname']]['client_id'] component_id = slivers[node['hostname']]['component_id'] - sliver_hrn = '{}.{}-{}'.format(self.driver.hrn, slice['slice_id'], node['node_id']) + sliver_hrn = '{}.{}-{}'.format(self.driver.hrn, + slice['slice_id'], node['node_id']) sliver_id = Xrn(sliver_hrn, type='sliver').urn - record = SliverAllocation(sliver_id=sliver_id, client_id=client_id, + record = SliverAllocation(sliver_id=sliver_id, client_id=client_id, component_id=component_id, - slice_urn = slice_urn, - allocation_state='geni_allocated') + slice_urn=slice_urn, + allocation_state='geni_allocated') record.sync(self.driver.api.dbsession()) return resulting_nodes def free_egre_key(self): used = set() for tag in self.driver.shell.GetSliceTags({'tagname': 'egre_key'}): - used.add(int(tag['value'])) + used.add(int(tag['value'])) for i in range(1, 256): if i not in used: @@ -266,16 +270,16 @@ class PlSlices: return str(key) def verify_slice_links(self, slice, requested_links, nodes): - + if not requested_links: return # exit if links are not supported here topology = Topology() if not topology: - return + return - # build dict of nodes + # build dict of nodes nodes_dict = {} interface_ids = [] for node in nodes: @@ -285,20 +289,20 @@ class PlSlices: interfaces = self.driver.shell.GetInterfaces(interface_ids) interfaces_dict = {} for interface in interfaces: - interfaces_dict[interface['interface_id']] = interface + interfaces_dict[interface['interface_id']] = interface slice_tags = [] - + # set egre key slice_tags.append({'name': 'egre_key', 'value': self.free_egre_key()}) - + # set netns slice_tags.append({'name': 'netns', 'value': '1'}) - # set cap_net_admin + # set cap_net_admin # need to update the attribute string? - slice_tags.append({'name': 'capabilities', 'value': 'CAP_NET_ADMIN'}) - + slice_tags.append({'name': 'capabilities', 'value': 'CAP_NET_ADMIN'}) + for link in requested_links: # get the ip address of the first node in the link ifname1 = Xrn(link['interface1']['component_id']).get_leaf() @@ -308,24 +312,28 @@ class PlSlices: node_raw = ifname_parts[0] device = None if len(ifname_parts) > 1: - device = ifname_parts[1] + device = ifname_parts[1] node_id = int(node_raw.replace('node', '')) node = nodes_dict[node_id] if1 = interfaces_dict[node['interface_ids'][0]] ipaddr = if1['ip'] topo_rspec = VLink.get_topo_rspec(link, ipaddr) # set topo_rspec tag - slice_tags.append({'name': 'topo_rspec', 'value': str([topo_rspec]), 'node_id': node_id}) + slice_tags.append({'name': 'topo_rspec', 'value': str( + [topo_rspec]), 'node_id': node_id}) # set vini_topo tag - slice_tags.append({'name': 'vini_topo', 'value': 'manual', 'node_id': node_id}) - #self.driver.shell.AddSliceTag(slice['name'], 'topo_rspec', str([topo_rspec]), node_id) + slice_tags.append( + {'name': 'vini_topo', 'value': 'manual', 'node_id': node_id}) + #self.driver.shell.AddSliceTag(slice['name'], 'topo_rspec', str([topo_rspec]), node_id) - self.verify_slice_tags(slice, slice_tags, {'pltags':'append'}, admin=True) - + self.verify_slice_tags(slice, slice_tags, { + 'pltags': 'append'}, admin=True) def verify_site(self, slice_xrn, slice_record=None, sfa_peer=None, options=None): - if slice_record is None: slice_record={} - if options is None: options={} + if slice_record is None: + slice_record = {} + if options is None: + options = {} (slice_hrn, type) = urn_to_hrn(slice_xrn) top_auth_hrn = top_auth(slice_hrn) site_hrn = '.'.join(slice_hrn.split('.')[:-1]) @@ -335,8 +343,8 @@ class PlSlices: login_base = hash_loginbase(site_hrn) # filter sites by hrn - sites = self.driver.shell.GetSites({'peer_id': None, 'hrn':site_hrn}, - ['site_id','name','abbreviated_name','login_base','hrn']) + sites = self.driver.shell.GetSites({'peer_id': None, 'hrn': site_hrn}, + ['site_id', 'name', 'abbreviated_name', 'login_base', 'hrn']) # alredy exists if sites: @@ -350,21 +358,22 @@ class PlSlices: 'max_slivers': 1000, 'enabled': True, 'peer_site_id': None, - 'hrn':site_hrn, + 'hrn': site_hrn, 'sfa_created': 'True', - } + } site_id = self.driver.shell.AddSite(site) # plcapi tends to mess with the incoming hrn so let's make sure - self.driver.shell.SetSiteHrn (site_id, site_hrn) + self.driver.shell.SetSiteHrn(site_id, site_hrn) site['site_id'] = site_id # exempt federated sites from monitor policies - self.driver.shell.AddSiteTag(site_id, 'exempt_site_until', "20200101") + self.driver.shell.AddSiteTag( + site_id, 'exempt_site_until', "20200101") return site - def verify_slice(self, slice_hrn, slice_record, sfa_peer, expiration, options=None): - if options is None: options={} + if options is None: + options = {} top_auth_hrn = top_auth(slice_hrn) site_hrn = '.'.join(slice_hrn.split('.')[:-1]) slice_part = slice_hrn.split('.')[-1] @@ -376,15 +385,15 @@ class PlSlices: expires = int(datetime_to_epoch(utcparse(expiration))) # Filter slices by HRN - slices = self.driver.shell.GetSlices({'peer_id': None, 'hrn':slice_hrn}, - ['slice_id','name','hrn','expires']) - + slices = self.driver.shell.GetSlices({'peer_id': None, 'hrn': slice_hrn}, + ['slice_id', 'name', 'hrn', 'expires']) + if slices: slice = slices[0] slice_id = slice['slice_id'] - #Update expiration if necessary + # Update expiration if necessary if slice.get('expires', None) != expires: - self.driver.shell.UpdateSlice( slice_id, {'expires' : expires}) + self.driver.shell.UpdateSlice(slice_id, {'expires': expires}) else: if slice_record: url = slice_record.get('url', slice_hrn) @@ -398,37 +407,36 @@ class PlSlices: 'hrn': slice_hrn, 'sfa_created': 'True', #'expires': expires, - } + } # add the slice slice_id = self.driver.shell.AddSlice(slice) # plcapi tends to mess with the incoming hrn so let's make sure - self.driver.shell.SetSliceHrn (slice_id, slice_hrn) + self.driver.shell.SetSliceHrn(slice_id, slice_hrn) # cannot be set with AddSlice # set the expiration self.driver.shell.UpdateSlice(slice_id, {'expires': expires}) return self.driver.shell.GetSlices(slice_id)[0] - # in the following code, we use # 'person' to denote a PLCAPI-like record with typically 'person_id' and 'email' # 'user' to denote an incoming record with typically 'urn' and 'email' - we add 'hrn' in there - # 'slice_record': it seems like the first of these 'users' also contains a 'slice_record' + # 'slice_record': it seems like the first of these 'users' also contains a 'slice_record' # key that holds stuff like 'hrn', 'slice_id', 'authority',... - # - def create_person_from_user (self, user, site_id): + # + def create_person_from_user(self, user, site_id): user_hrn = user['hrn'] # the value to use if 'user' has no 'email' attached - or if the attached email already exists - # typically - ( auth_hrn, _ , leaf ) = user_hrn.rpartition('.') + # typically + (auth_hrn, _, leaf) = user_hrn.rpartition('.') # somehow this has backslashes, get rid of them - auth_hrn = auth_hrn.replace('\\','') + auth_hrn = auth_hrn.replace('\\', '') default_email = "{}@{}.stub".format(leaf, auth_hrn) - person_record = { + person_record = { # required - 'first_name': user.get('first_name',user_hrn), - 'last_name': user.get('last_name',user_hrn), + 'first_name': user.get('first_name', user_hrn), + 'last_name': user.get('last_name', user_hrn), 'email': user.get('email', default_email), # our additions 'enabled': True, @@ -436,38 +444,43 @@ class PlSlices: 'hrn': user_hrn, } - logger.debug ("about to attempt to AddPerson with {}".format(person_record)) + logger.debug( + "about to attempt to AddPerson with {}".format(person_record)) try: # the thing is, the PLE db has a limitation on re-using the same e-mail - # in the case where people have an account on ple.upmc and then then come + # in the case where people have an account on ple.upmc and then then come # again from onelab.upmc, they will most likely have the same e-mail, and so kaboom.. # so we first try with the accurate email - person_id = int (self.driver.shell.AddPerson(person_record)) + person_id = int(self.driver.shell.AddPerson(person_record)) except: logger.log_exc("caught during first attempt at AddPerson") - # and if that fails we start again with the email based on the hrn, which this time is unique.. + # and if that fails we start again with the email based on the hrn, + # which this time is unique.. person_record['email'] = default_email - logger.debug ("second chance with email={}".format(person_record['email'])) - person_id = int (self.driver.shell.AddPerson(person_record)) + logger.debug("second chance with email={}".format( + person_record['email'])) + person_id = int(self.driver.shell.AddPerson(person_record)) self.driver.shell.AddRoleToPerson('user', person_id) self.driver.shell.AddPersonToSite(person_id, site_id) # plcapi tends to mess with the incoming hrn so let's make sure - self.driver.shell.SetPersonHrn (person_id, user_hrn) + self.driver.shell.SetPersonHrn(person_id, user_hrn) # also 'enabled':True does not seem to pass through with AddPerson - self.driver.shell.UpdatePerson (person_id, {'enabled': True}) + self.driver.shell.UpdatePerson(person_id, {'enabled': True}) return person_id def verify_persons(self, slice_hrn, slice_record, users, sfa_peer, options=None): - if options is None: options={} + if options is None: + options = {} # first we annotate the incoming users arg with a 'hrn' key for user in users: - user['hrn'], _ = urn_to_hrn(user['urn']) + user['hrn'], _ = urn_to_hrn(user['urn']) # this is for retrieving users from a hrn - users_by_hrn = { user['hrn'] : user for user in users } + users_by_hrn = {user['hrn']: user for user in users} - for user in users: logger.debug("incoming user {}".format(user)) + for user in users: + logger.debug("incoming user {}".format(user)) # compute the hrn's for the authority and site top_auth_hrn = top_auth(slice_hrn) @@ -481,53 +494,62 @@ class PlSlices: slice_name = '_'.join([login_base, slice_part]) # locate the site object - # due to a limitation in PLCAPI, we have to specify 'hrn' as part of the return fields - site = self.driver.shell.GetSites ({'peer_id':None, 'hrn':site_hrn}, ['site_id','hrn'])[0] + # due to a limitation in PLCAPI, we have to specify 'hrn' as part of + # the return fields + site = self.driver.shell.GetSites( + {'peer_id': None, 'hrn': site_hrn}, ['site_id', 'hrn'])[0] site_id = site['site_id'] # locate the slice object - slice = self.driver.shell.GetSlices ({'peer_id':None, 'hrn':slice_hrn}, ['slice_id','hrn','person_ids'])[0] + slice = self.driver.shell.GetSlices({'peer_id': None, 'hrn': slice_hrn}, [ + 'slice_id', 'hrn', 'person_ids'])[0] slice_id = slice['slice_id'] slice_person_ids = slice['person_ids'] # the common set of attributes for our calls to GetPersons - person_fields = ['person_id','email','hrn'] + person_fields = ['person_id', 'email', 'hrn'] # for the intended set of hrns, locate existing persons - target_hrns = [ user['hrn'] for user in users ] - target_existing_persons = self.driver.shell.GetPersons ({'peer_id':None, 'hrn': target_hrns}, person_fields) - target_existing_person_ids = [ person ['person_id'] for person in target_existing_persons ] + target_hrns = [user['hrn'] for user in users] + target_existing_persons = self.driver.shell.GetPersons( + {'peer_id': None, 'hrn': target_hrns}, person_fields) + target_existing_person_ids = [person['person_id'] + for person in target_existing_persons] # find out the hrns that *do not* have a corresponding person - existing_hrns = [ person['hrn'] for person in target_existing_persons ] - tocreate_hrns = set (target_hrns) - set (existing_hrns) + existing_hrns = [person['hrn'] for person in target_existing_persons] + tocreate_hrns = set(target_hrns) - set(existing_hrns) # create these - target_created_person_ids = [ self.create_person_from_user (users_by_hrn[hrn], site_id) for hrn in tocreate_hrns ] + target_created_person_ids = [self.create_person_from_user( + users_by_hrn[hrn], site_id) for hrn in tocreate_hrns] # we can partition the persons of interest into one of these 3 classes - add_person_ids = set(target_created_person_ids) | set(target_existing_person_ids) - set(slice_person_ids) - keep_person_ids = set(target_existing_person_ids) & set(slice_person_ids) - del_person_ids = set(slice_person_ids) - set(target_existing_person_ids) - - # delete + add_person_ids = set(target_created_person_ids) | set( + target_existing_person_ids) - set(slice_person_ids) + keep_person_ids = set( + target_existing_person_ids) & set(slice_person_ids) + del_person_ids = set(slice_person_ids) - \ + set(target_existing_person_ids) + + # delete for person_id in del_person_ids: - self.driver.shell.DeletePersonFromSlice (person_id, slice_id) + self.driver.shell.DeletePersonFromSlice(person_id, slice_id) # about the last 2 sets, for managing keys, we need to trace back person_id -> user # and for this we need all the Person objects; we already have the target_existing ones # also we avoid issuing a call if possible target_created_persons = [] if not target_created_person_ids \ - else self.driver.shell.GetPersons \ - ({'peer_id':None, 'person_id':target_created_person_ids}, person_fields) - persons_by_person_id = { person['person_id'] : person \ - for person in target_existing_persons + target_created_persons } - - def user_by_person_id (person_id): - person = persons_by_person_id [person_id] - hrn = person ['hrn'] - return users_by_hrn [hrn] - + else self.driver.shell.GetPersons \ + ({'peer_id': None, 'person_id': target_created_person_ids}, person_fields) + persons_by_person_id = {person['person_id']: person + for person in target_existing_persons + target_created_persons} + + def user_by_person_id(person_id): + person = persons_by_person_id[person_id] + hrn = person['hrn'] + return users_by_hrn[hrn] + persons_to_verify_keys = {} - # add + # add for person_id in add_person_ids: self.driver.shell.AddPersonToSlice(person_id, slice_id) persons_to_verify_keys[person_id] = user_by_person_id(person_id) @@ -538,22 +560,23 @@ class PlSlices: # return hrns of the newly added persons - return [ persons_by_person_id[person_id]['hrn'] for person_id in add_person_ids ] + return [persons_by_person_id[person_id]['hrn'] for person_id in add_person_ids] def verify_keys(self, persons_to_verify_keys, options=None): - if options is None: options={} + if options is None: + options = {} # we only add keys that comes from sfa to persons in PL for person_id in persons_to_verify_keys: - person_sfa_keys = persons_to_verify_keys[person_id].get('keys', []) - person_pl_keys = self.driver.shell.GetKeys({'person_id': int(person_id)}) - person_pl_keys_list = [key['key'] for key in person_pl_keys] - - keys_to_add = set(person_sfa_keys).difference(person_pl_keys_list) + person_sfa_keys = persons_to_verify_keys[person_id].get('keys', []) + person_pl_keys = self.driver.shell.GetKeys( + {'person_id': int(person_id)}) + person_pl_keys_list = [key['key'] for key in person_pl_keys] - for key_string in keys_to_add: - key = {'key': key_string, 'key_type': 'ssh'} - self.driver.shell.AddPersonKey(int(person_id), key) + keys_to_add = set(person_sfa_keys).difference(person_pl_keys_list) + for key_string in keys_to_add: + key = {'key': key_string, 'key_type': 'ssh'} + self.driver.shell.AddPersonKey(int(person_id), key) def verify_slice_tags(self, slice, requested_slice_attributes, options=None, admin=False): """ @@ -564,18 +587,20 @@ class PlSlices: (*) 'sync' - tries to do the plain wholesale thing, i.e. to leave the db in sync with incoming tags """ - if options is None: options={} + if options is None: + options = {} # lookup 'pltags' in options to find out which mode is requested here - pltags = options.get('pltags', 'ignore') - # make sure the default is 'ignore' + pltags = options.get('pltags', 'ignore') + # make sure the default is 'ignore' if pltags not in ('ignore', 'append', 'sync'): pltags = 'ignore' if pltags == 'ignore': - logger.info('verify_slice_tags in ignore mode - leaving slice tags as-is') + logger.info( + 'verify_slice_tags in ignore mode - leaving slice tags as-is') return - + # incoming data (attributes) have a (name, value) pair # while PLC data (tags) have a (tagname, value) pair # we must be careful not to mix these up @@ -585,15 +610,17 @@ class PlSlices: if not admin: filter['|roles'] = ['user'] valid_tag_types = self.driver.shell.GetTagTypes(filter) - valid_tag_names = [ tag_type['tagname'] for tag_type in valid_tag_types ] - logger.debug("verify_slice_attributes: valid names={}".format(valid_tag_names)) + valid_tag_names = [tag_type['tagname'] for tag_type in valid_tag_types] + logger.debug( + "verify_slice_attributes: valid names={}".format(valid_tag_names)) # get slice tags slice_attributes_to_add = [] slice_tags_to_remove = [] # we need to keep the slice hrn anyway ignored_slice_tag_names = ['hrn'] - existing_slice_tags = self.driver.shell.GetSliceTags({'slice_id': slice['slice_id']}) + existing_slice_tags = self.driver.shell.GetSliceTags( + {'slice_id': slice['slice_id']}) # get tags that should be removed for slice_tag in existing_slice_tags: @@ -628,26 +655,29 @@ class PlSlices: if not tag_found: slice_attributes_to_add.append(requested_attribute) - def friendly_message (tag_or_att): - name = tag_or_att['tagname'] if 'tagname' in tag_or_att else tag_or_att['name'] + def friendly_message(tag_or_att): + name = tag_or_att[ + 'tagname'] if 'tagname' in tag_or_att else tag_or_att['name'] return "SliceTag slice={}, tagname={} value={}, node_id={}"\ .format(slice['name'], tag_or_att['name'], tag_or_att['value'], tag_or_att.get('node_id')) - + # remove stale tags for tag in slice_tags_to_remove: try: - logger.info("Removing Slice Tag {}".format(friendly_message(tag))) + logger.info("Removing Slice Tag {}".format( + friendly_message(tag))) self.driver.shell.DeleteSliceTag(tag['slice_tag_id']) except Exception as e: - logger.warn("Failed to remove slice tag {}\nCause:{}"\ + logger.warn("Failed to remove slice tag {}\nCause:{}" .format(friendly_message(tag), e)) # add requested_tags for attribute in slice_attributes_to_add: try: - logger.info("Adding Slice Tag {}".format(friendly_message(attribute))) - self.driver.shell.AddSliceTag(slice['name'], attribute['name'], + logger.info("Adding Slice Tag {}".format( + friendly_message(attribute))) + self.driver.shell.AddSliceTag(slice['name'], attribute['name'], attribute['value'], attribute.get('node_id', None)) except Exception as e: - logger.warn("Failed to add slice tag {}\nCause:{}"\ + logger.warn("Failed to add slice tag {}\nCause:{}" .format(friendly_message(attribute), e)) diff --git a/sfa/planetlab/plxrn.py b/sfa/planetlab/plxrn.py index 70ff5e02..f9cf63c3 100644 --- a/sfa/planetlab/plxrn.py +++ b/sfa/planetlab/plxrn.py @@ -5,78 +5,97 @@ import re from sfa.util.xrn import Xrn, get_authority # temporary helper functions to use this module instead of namespace -def hostname_to_hrn (auth, login_base, hostname): - return PlXrn(auth=auth+'.'+login_base,hostname=hostname).get_hrn() + + +def hostname_to_hrn(auth, login_base, hostname): + return PlXrn(auth=auth + '.' + login_base, hostname=hostname).get_hrn() + + def hostname_to_urn(auth, login_base, hostname): - return PlXrn(auth=auth+'.'+login_base,hostname=hostname).get_urn() -def slicename_to_hrn (auth_hrn, slicename): - return PlXrn(auth=auth_hrn,slicename=slicename).get_hrn() -def email_to_hrn (auth_hrn, email): + return PlXrn(auth=auth + '.' + login_base, hostname=hostname).get_urn() + + +def slicename_to_hrn(auth_hrn, slicename): + return PlXrn(auth=auth_hrn, slicename=slicename).get_hrn() + + +def email_to_hrn(auth_hrn, email): return PlXrn(auth=auth_hrn, email=email).get_hrn() -def hrn_to_pl_slicename (hrn): - return PlXrn(xrn=hrn,type='slice').pl_slicename() + + +def hrn_to_pl_slicename(hrn): + return PlXrn(xrn=hrn, type='slice').pl_slicename() # removed-dangerous - was used for non-slice objects -#def hrn_to_pl_login_base (hrn): +# def hrn_to_pl_login_base (hrn): # return PlXrn(xrn=hrn,type='slice').pl_login_base() -def hrn_to_pl_authname (hrn): - return PlXrn(xrn=hrn,type='any').pl_authname() + + +def hrn_to_pl_authname(hrn): + return PlXrn(xrn=hrn, type='any').pl_authname() + + def xrn_to_hostname(hrn): return Xrn.unescape(PlXrn(xrn=hrn, type='node').get_leaf()) -# helpers to handle external objects created via fedaration -def top_auth (hrn): +# helpers to handle external objects created via fedaration + + +def top_auth(hrn): return hrn.split('.')[0] + def hash_loginbase(site_hrn): if len(site_hrn) <= 12: - return site_hrn.replace('.','8').replace('_', '8') + return site_hrn.replace('.', '8').replace('_', '8') ratio = float(12) / len(site_hrn) auths_tab = site_hrn.split('.') auths_tab2 = [] for auth in auths_tab: auth = auth.replace('_', '8') - auth2 = auth[:int(len(auth)*ratio)] + auth2 = auth[:int(len(auth) * ratio)] auths_tab2.append(auth2) return '8'.join(auths_tab2) + class PlXrn (Xrn): - @staticmethod - def site_hrn (auth, login_base): - return '.'.join([auth,login_base]) + @staticmethod + def site_hrn(auth, login_base): + return '.'.join([auth, login_base]) - def __init__ (self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs): - #def hostname_to_hrn(auth_hrn, login_base, hostname): + def __init__(self, auth=None, hostname=None, slicename=None, email=None, interface=None, **kwargs): + # def hostname_to_hrn(auth_hrn, login_base, hostname): if hostname is not None: - self.type='node' + self.type = 'node' # keep only the first part of the DNS name #self.hrn='.'.join( [auth,hostname.split(".")[0] ] ) # escape the '.' in the hostname - self.hrn='.'.join( [auth,Xrn.escape(hostname)] ) + self.hrn = '.'.join([auth, Xrn.escape(hostname)]) self.hrn_to_urn() - #def slicename_to_hrn(auth_hrn, slicename): + # def slicename_to_hrn(auth_hrn, slicename): elif slicename is not None: - self.type='slice' + self.type = 'slice' # split at the first _ - parts = slicename.split("_",1) - self.hrn = ".".join([auth] + parts ) + parts = slicename.split("_", 1) + self.hrn = ".".join([auth] + parts) self.hrn_to_urn() - #def email_to_hrn(auth_hrn, email): + # def email_to_hrn(auth_hrn, email): elif email is not None: - self.type='person' + self.type = 'person' # keep only the part before '@' and replace special chars into _ - self.hrn='.'.join([auth,email.split('@')[0].replace(".", "_").replace("+", "_")]) + self.hrn = '.'.join( + [auth, email.split('@')[0].replace(".", "_").replace("+", "_")]) self.hrn_to_urn() elif interface is not None: self.type = 'interface' self.hrn = auth + '.' + interface self.hrn_to_urn() else: - Xrn.__init__ (self,**kwargs) + Xrn.__init__(self, **kwargs) - #def hrn_to_pl_slicename(hrn): - def pl_slicename (self): + # def hrn_to_pl_slicename(hrn): + def pl_slicename(self): self._normalize() leaf = self.leaf sliver_id_parts = leaf.split(':') @@ -84,8 +103,8 @@ class PlXrn (Xrn): name = re.sub('[^a-zA-Z0-9_]', '', name) return self.pl_login_base() + '_' + name - #def hrn_to_pl_authname(hrn): - def pl_authname (self): + # def hrn_to_pl_authname(hrn): + def pl_authname(self): self._normalize() return self.authority[-1] @@ -93,13 +112,13 @@ class PlXrn (Xrn): self._normalize() return self.leaf - def pl_login_base (self): + def pl_login_base(self): self._normalize() if self.type and self.type.startswith('authority'): - base = self.leaf + base = self.leaf else: base = self.authority[-1] - + # Fix up names of GENI Federates base = base.lower() base = re.sub('[\\\\]*[^a-zA-Z0-9]', '', base) diff --git a/sfa/planetlab/topology.py b/sfa/planetlab/topology.py index c05b198b..8add242a 100644 --- a/sfa/planetlab/topology.py +++ b/sfa/planetlab/topology.py @@ -1,20 +1,21 @@ ## # SFA Topology Info # -# This module holds topology configuration for SFA. It is implemnted as a +# This module holds topology configuration for SFA. It is implemnted as a # list of site_id tuples import os.path import traceback from sfa.util.sfalogging import logger + class Topology(set): """ Parse the topology configuration file. """ - def __init__(self, config_file = "/etc/sfa/topology"): - set.__init__(self) + def __init__(self, config_file="/etc/sfa/topology"): + set.__init__(self) try: # load the links f = open(config_file, 'r') @@ -24,7 +25,8 @@ class Topology(set): line = line[0:ignore] tup = line.split() if len(tup) > 1: - self.add((tup[0], tup[1])) + self.add((tup[0], tup[1])) except Exception as e: - logger.log_exc("Could not find or load the configuration file: %s" % config_file) + logger.log_exc( + "Could not find or load the configuration file: %s" % config_file) raise diff --git a/sfa/planetlab/vlink.py b/sfa/planetlab/vlink.py index b0a83004..75c19ce9 100644 --- a/sfa/planetlab/vlink.py +++ b/sfa/planetlab/vlink.py @@ -12,20 +12,20 @@ suffixes = { "bit": 1, "kibit": 1024, "kbit": 1000, - "mibit": 1024*1024, + "mibit": 1024 * 1024, "mbit": 1000000, - "gibit": 1024*1024*1024, + "gibit": 1024 * 1024 * 1024, "gbit": 1000000000, - "tibit": 1024*1024*1024*1024, + "tibit": 1024 * 1024 * 1024 * 1024, "tbit": 1000000000000, "bps": 8, - "kibps": 8*1024, + "kibps": 8 * 1024, "kbps": 8000, - "mibps": 8*1024*1024, + "mibps": 8 * 1024 * 1024, "mbps": 8000000, - "gibps": 8*1024*1024*1024, + "gibps": 8 * 1024 * 1024 * 1024, "gbps": 8000000000, - "tibps": 8*1024*1024*1024*1024, + "tibps": 8 * 1024 * 1024 * 1024 * 1024, "tbps": 8000000000000 } @@ -46,6 +46,7 @@ def get_tc_rate(s): else: return -1 + def format_tc_rate(rate): """ Formats a bits/second rate into a tc rate string @@ -60,13 +61,15 @@ def format_tc_rate(rate): else: return "%.0fbit" % rate + class VLink: + @staticmethod def get_link_id(if1, if2): if if1['id'] < if2['id']: - link = (if1['id']<<7) + if2['id'] + link = (if1['id'] << 7) + if2['id'] else: - link = (if2['id']<<7) + if1['id'] + link = (if2['id'] << 7) + if1['id'] return link @staticmethod @@ -82,14 +85,14 @@ class VLink: link_id = VLink.get_link_id(if1, if2) iface_id = VLink.get_iface_id(if1, if2) first = link_id >> 6 - second = ((link_id & 0x3f)<<2) + iface_id + second = ((link_id & 0x3f) << 2) + iface_id return "192.168.%d.%s" % (first, second) @staticmethod def get_virt_net(link): link_id = VLink.get_link_id(link['interface1'], link['interface2']) first = link_id >> 6 - second = (link_id & 0x3f)<<2 + second = (link_id & 0x3f) << 2 return "192.168.%d.%d/30" % (first, second) @staticmethod @@ -99,7 +102,6 @@ class VLink: node_id = int(node.replace("node", "")) return node_id - @staticmethod def get_topo_rspec(link, ipaddr): link['interface1']['id'] = VLink.get_interface_id(link['interface1']) @@ -110,6 +112,6 @@ class VLink: bw = format_tc_rate(long(link['capacity'])) return (link['interface2']['id'], ipaddr, bw, my_ip, remote_ip, net) - @staticmethod + @staticmethod def topo_rspec_to_link(topo_rspec): - pass + pass diff --git a/sfa/rspecs/elements/attribute.py b/sfa/rspecs/elements/attribute.py index 740083b7..19bdabe1 100644 --- a/sfa/rspecs/elements/attribute.py +++ b/sfa/rspecs/elements/attribute.py @@ -1,9 +1,9 @@ from sfa.rspecs.elements.element import Element + class Attribute(Element): fields = [ 'name', 'value', ] - diff --git a/sfa/rspecs/elements/bwlimit.py b/sfa/rspecs/elements/bwlimit.py index 6f75161c..083d54c4 100644 --- a/sfa/rspecs/elements/bwlimit.py +++ b/sfa/rspecs/elements/bwlimit.py @@ -1,8 +1,8 @@ from sfa.rspecs.elements.element import Element + class BWlimit(Element): - fields = [ + fields = [ 'units', 'value', ] - diff --git a/sfa/rspecs/elements/channel.py b/sfa/rspecs/elements/channel.py index 96680544..cf4a7cb8 100644 --- a/sfa/rspecs/elements/channel.py +++ b/sfa/rspecs/elements/channel.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element - + + class Channel(Element): - + fields = [ 'reservation_id', 'channel_num', diff --git a/sfa/rspecs/elements/datapath.py b/sfa/rspecs/elements/datapath.py index 5b51e1bb..089cec77 100644 --- a/sfa/rspecs/elements/datapath.py +++ b/sfa/rspecs/elements/datapath.py @@ -1,12 +1,11 @@ from sfa.rspecs.elements.element import Element - + + class Datapath(Element): - + fields = [ 'component_id', 'component_manager_id', 'dp_id', 'ports', ] - - diff --git a/sfa/rspecs/elements/disk_image.py b/sfa/rspecs/elements/disk_image.py index 1f530f67..70f8ba13 100644 --- a/sfa/rspecs/elements/disk_image.py +++ b/sfa/rspecs/elements/disk_image.py @@ -1,9 +1,10 @@ from sfa.rspecs.elements.element import Element + class DiskImage(Element): fields = [ 'name', 'os', 'version', 'description', - ] + ] diff --git a/sfa/rspecs/elements/element.py b/sfa/rspecs/elements/element.py index df46c891..21f0949c 100644 --- a/sfa/rspecs/elements/element.py +++ b/sfa/rspecs/elements/element.py @@ -3,15 +3,15 @@ class Element(dict): fields = {} def __init__(self, fields=None, element=None, keys=None): - if fields is None: fields={} + if fields is None: + fields = {} self.element = element dict.__init__(self, dict.fromkeys(self.fields)) if not keys: keys = fields.keys() for key in keys: if key in fields: - self[key] = fields[key] - + self[key] = fields[key] def __getattr__(self, name): if hasattr(self.__dict__, name): @@ -20,4 +20,4 @@ class Element(dict): return getattr(self.element, name) else: raise AttributeError("class Element of type {} has no attribute {}" - .format(self.__class__.__name__, name)) + .format(self.__class__.__name__, name)) diff --git a/sfa/rspecs/elements/execute.py b/sfa/rspecs/elements/execute.py index e7ee7067..6bbf50da 100644 --- a/sfa/rspecs/elements/execute.py +++ b/sfa/rspecs/elements/execute.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Execute(Element): fields = [ 'shell', diff --git a/sfa/rspecs/elements/fw_rule.py b/sfa/rspecs/elements/fw_rule.py index 9ae66ab9..11592294 100644 --- a/sfa/rspecs/elements/fw_rule.py +++ b/sfa/rspecs/elements/fw_rule.py @@ -1,10 +1,10 @@ from sfa.rspecs.elements.element import Element + class FWRule(Element): - fields = [ + fields = [ 'protocol', 'cidr_ip', 'port_range', 'icmp_type_code', ] - diff --git a/sfa/rspecs/elements/granularity.py b/sfa/rspecs/elements/granularity.py index 16d30a01..e8252447 100644 --- a/sfa/rspecs/elements/granularity.py +++ b/sfa/rspecs/elements/granularity.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Granularity(Element): fields = [ diff --git a/sfa/rspecs/elements/hardware_type.py b/sfa/rspecs/elements/hardware_type.py index 5f20c9bb..ecfbd4e3 100644 --- a/sfa/rspecs/elements/hardware_type.py +++ b/sfa/rspecs/elements/hardware_type.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element + class HardwareType(Element): - + fields = [ 'name' - ] + ] diff --git a/sfa/rspecs/elements/install.py b/sfa/rspecs/elements/install.py index 227a7972..bf15623e 100644 --- a/sfa/rspecs/elements/install.py +++ b/sfa/rspecs/elements/install.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element - + + class Install(Element): fields = [ 'file_type', diff --git a/sfa/rspecs/elements/interface.py b/sfa/rspecs/elements/interface.py index 11045df8..8fa4ec97 100644 --- a/sfa/rspecs/elements/interface.py +++ b/sfa/rspecs/elements/interface.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Interface(Element): fields = ['component_id', 'role', @@ -8,5 +9,5 @@ class Interface(Element): 'bwlimit', 'node_id', 'interface_id', - 'mac_address', - ] + 'mac_address', + ] diff --git a/sfa/rspecs/elements/lease.py b/sfa/rspecs/elements/lease.py index dc3fe587..3fbf55c6 100644 --- a/sfa/rspecs/elements/lease.py +++ b/sfa/rspecs/elements/lease.py @@ -1,11 +1,12 @@ from sfa.rspecs.elements.element import Element - + + class Lease(Element): - + fields = [ 'lease_id', 'component_id', 'slice_id', 'start_time', - 'duration', + 'duration', ] diff --git a/sfa/rspecs/elements/link.py b/sfa/rspecs/elements/link.py index 3bbfe2bb..f9ae02e3 100644 --- a/sfa/rspecs/elements/link.py +++ b/sfa/rspecs/elements/link.py @@ -1,8 +1,9 @@ -from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.element import Element + class Link(Element): fields = [ - 'client_id', + 'client_id', 'component_id', 'component_name', 'component_manager', diff --git a/sfa/rspecs/elements/location.py b/sfa/rspecs/elements/location.py index 57bfe0c1..f99c5432 100644 --- a/sfa/rspecs/elements/location.py +++ b/sfa/rspecs/elements/location.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element + class Location(Element): - + fields = [ 'country', 'longitude', diff --git a/sfa/rspecs/elements/login.py b/sfa/rspecs/elements/login.py index 51741a9b..7badea6b 100644 --- a/sfa/rspecs/elements/login.py +++ b/sfa/rspecs/elements/login.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Login(Element): fields = [ 'authentication', diff --git a/sfa/rspecs/elements/node.py b/sfa/rspecs/elements/node.py index e0f65e40..4f167295 100644 --- a/sfa/rspecs/elements/node.py +++ b/sfa/rspecs/elements/node.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element - + + class NodeElement(Element): - + fields = [ 'client_id', 'component_id', @@ -9,12 +10,12 @@ class NodeElement(Element): 'component_manager_id', 'client_id', 'sliver_id', - 'authority_id', + 'authority_id', 'exclusive', 'location', 'bw_unallocated', 'bw_limit', - 'boot_state', + 'boot_state', 'slivers', 'hardware_types', 'disk_images', @@ -23,5 +24,3 @@ class NodeElement(Element): 'tags', 'pl_initscripts', ] - - diff --git a/sfa/rspecs/elements/pltag.py b/sfa/rspecs/elements/pltag.py index 8234f461..33026c5e 100644 --- a/sfa/rspecs/elements/pltag.py +++ b/sfa/rspecs/elements/pltag.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class PLTag(Element): fields = [ @@ -7,4 +8,3 @@ class PLTag(Element): 'value', 'scope', ] - diff --git a/sfa/rspecs/elements/port.py b/sfa/rspecs/elements/port.py index 2817b6b3..15dd1347 100644 --- a/sfa/rspecs/elements/port.py +++ b/sfa/rspecs/elements/port.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Port(Element): fields = [ 'num', diff --git a/sfa/rspecs/elements/position_3d.py b/sfa/rspecs/elements/position_3d.py index d08a79c7..aaf9bcdb 100644 --- a/sfa/rspecs/elements/position_3d.py +++ b/sfa/rspecs/elements/position_3d.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element + class Position3D(Element): - + fields = [ 'x', 'y', diff --git a/sfa/rspecs/elements/property.py b/sfa/rspecs/elements/property.py index 472dedeb..8f1a37b1 100644 --- a/sfa/rspecs/elements/property.py +++ b/sfa/rspecs/elements/property.py @@ -1,7 +1,8 @@ from sfa.rspecs.elements.element import Element + class Property(Element): - + fields = [ 'source_id', 'dest_id', @@ -9,4 +10,3 @@ class Property(Element): 'latency', 'packet_loss', ] - diff --git a/sfa/rspecs/elements/services.py b/sfa/rspecs/elements/services.py index e159d70c..af737953 100644 --- a/sfa/rspecs/elements/services.py +++ b/sfa/rspecs/elements/services.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class ServicesElement(Element): fields = [ @@ -8,4 +9,3 @@ class ServicesElement(Element): 'login', 'services_user', ] - diff --git a/sfa/rspecs/elements/sliver.py b/sfa/rspecs/elements/sliver.py index 5186d1f8..b6a0d264 100644 --- a/sfa/rspecs/elements/sliver.py +++ b/sfa/rspecs/elements/sliver.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Sliver(Element): fields = [ 'sliver_id', diff --git a/sfa/rspecs/elements/spectrum.py b/sfa/rspecs/elements/spectrum.py index 46eb3fa1..91e63448 100644 --- a/sfa/rspecs/elements/spectrum.py +++ b/sfa/rspecs/elements/spectrum.py @@ -1,5 +1,6 @@ from sfa.rspecs.elements.element import Element + class Spectrum(Element): fields = [] diff --git a/sfa/rspecs/elements/versions/iotlabv1Lease.py b/sfa/rspecs/elements/versions/iotlabv1Lease.py index bfc503aa..ea796b55 100644 --- a/sfa/rspecs/elements/versions/iotlabv1Lease.py +++ b/sfa/rspecs/elements/versions/iotlabv1Lease.py @@ -5,7 +5,6 @@ from sfa.util.xrn import Xrn from sfa.rspecs.elements.lease import Lease - class Iotlabv1Lease: @staticmethod @@ -15,23 +14,27 @@ class Iotlabv1Lease: if len(network_elems) > 0: network_elem = network_elems[0] elif len(leases) > 0: - network_urn = Xrn(leases[0]['component_id']).get_authority_urn().split(':')[0] - network_elem = xml.add_element('network', name = network_urn) + network_urn = Xrn(leases[0]['component_id'] + ).get_authority_urn().split(':')[0] + network_elem = xml.add_element('network', name=network_urn) else: network_elem = xml lease_elems = [] for lease in leases: - lease_fields = ['lease_id', 'component_id', 'slice_id', 'start_time', 'duration'] - lease_elem = network_elem.add_instance('lease', lease, lease_fields) + lease_fields = ['lease_id', 'component_id', + 'slice_id', 'start_time', 'duration'] + lease_elem = network_elem.add_instance( + 'lease', lease, lease_fields) lease_elems.append(lease_elem) - @staticmethod def get_leases(xml, filter=None): - if filter is None: filter={} - xpath = '//lease%s | //default:lease%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//lease%s | //default:lease%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) lease_elems = xml.xpath(xpath) return Iotlabv1Lease.get_lease_objs(lease_elems) @@ -39,14 +42,14 @@ class Iotlabv1Lease: def get_lease_objs(lease_elems): leases = [] for lease_elem in lease_elems: - #get nodes + # get nodes node_elems = lease_elem.xpath('./default:node | ./node') for node_elem in node_elems: - lease = Lease(lease_elem.attrib, lease_elem) - lease['slice_id'] = lease_elem.attrib['slice_id'] - lease['start_time'] = lease_elem.attrib['start_time'] - lease['duration'] = lease_elem.attrib['duration'] - lease['component_id'] = node_elem.attrib['component_id'] - leases.append(lease) + lease = Lease(lease_elem.attrib, lease_elem) + lease['slice_id'] = lease_elem.attrib['slice_id'] + lease['start_time'] = lease_elem.attrib['start_time'] + lease['duration'] = lease_elem.attrib['duration'] + lease['component_id'] = node_elem.attrib['component_id'] + leases.append(lease) return leases diff --git a/sfa/rspecs/elements/versions/iotlabv1Node.py b/sfa/rspecs/elements/versions/iotlabv1Node.py index 553f1c09..99b5dea2 100644 --- a/sfa/rspecs/elements/versions/iotlabv1Node.py +++ b/sfa/rspecs/elements/versions/iotlabv1Node.py @@ -10,19 +10,21 @@ from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.versions.iotlabv1Sliver import Iotlabv1Sliver from sfa.util.sfalogging import logger + class IotlabNode(NodeElement): - #First get the fields already defined in the class Node + # First get the fields already defined in the class Node fields = list(NodeElement.fields) - #Extend it with iotlab's specific fields - fields.extend (['archi', 'radio', 'mobile','position']) + # Extend it with iotlab's specific fields + fields.extend(['archi', 'radio', 'mobile', 'position']) class IotlabPosition(Element): - fields = ['x', 'y','z'] + fields = ['x', 'y', 'z'] + class IotlabLocation(Location): fields = list(Location.fields) - fields.extend (['site']) + fields.extend(['site']) class IotlabMobility(Element): @@ -31,7 +33,6 @@ class IotlabMobility(Element): fields = ['mobile', 'mobility_type'] - class Iotlabv1Node: @staticmethod @@ -42,7 +43,7 @@ class Iotlabv1Node: """ logger.debug(" add_connection_information ") - #Get network item in the xml + # Get network item in the xml network_elems = xml.xpath('//network') if len(network_elems) > 0: network_elem = network_elems[0] @@ -51,14 +52,14 @@ class Iotlabv1Node: iotlab_network_dict['login'] = ldap_username iotlab_network_dict['ssh'] = \ - ['ssh ' + ldap_username + '@'+site+'.iotlab.info' + ['ssh ' + ldap_username + '@' + site + '.iotlab.info' for site in sites_set] network_elem.set('ssh', unicode(iotlab_network_dict['ssh'])) network_elem.set('login', unicode(iotlab_network_dict['login'])) @staticmethod - def add_nodes(xml, nodes,rspec_content_type=None): + def add_nodes(xml, nodes, rspec_content_type=None): """Adds the nodes to the xml. Adds the nodes as well as dedicated iotlab fields to the node xml @@ -72,7 +73,7 @@ class Iotlabv1Node: :rtype: list """ - #Add network item in the xml + # Add network item in the xml network_elems = xml.xpath('//network') if len(network_elems) > 0: network_elem = network_elems[0] @@ -84,23 +85,23 @@ class Iotlabv1Node: network_elem = xml node_elems = [] - #Then add nodes items to the network item in the xml + # Then add nodes items to the network item in the xml for node in nodes: - #Attach this node to the network element + # Attach this node to the network element node_fields = ['component_manager_id', 'component_id', 'exclusive', 'boot_state', 'mobile'] node_elem = network_elem.add_instance('node', node, node_fields) node_elems.append(node_elem) - #Set the attibutes of this node element + # Set the attibutes of this node element for attribute in node: - # set component name + # set component name if attribute is 'component_name': component_name = node['component_name'] node_elem.set('component_name', component_name) # set hardware types, extend fields to add Iotlab's architecture - #and radio type + # and radio type if attribute is 'hardware_types': for hardware_type in node.get('hardware_types', []): @@ -115,10 +116,10 @@ class Iotlabv1Node: # set location if attribute is 'location': node_elem.add_instance('location', node['location'], - IotlabLocation.fields) + IotlabLocation.fields) # add granularity of the reservation system - #TODO put the granularity in network instead SA 18/07/12 + # TODO put the granularity in network instead SA 18/07/12 if attribute is 'granularity': granularity = node['granularity'] if granularity: @@ -134,24 +135,24 @@ class Iotlabv1Node: available_elem = node_elem.add_element('available', now='false') - #set position + # set position if attribute is 'position': node_elem.add_instance('position', node['position'], IotlabPosition.fields) - ## add services + # add services #PGv2Services.add_services(node_elem, node.get('services', [])) # add slivers if attribute is 'slivers': slivers = node.get('slivers', []) if not slivers: - # we must still advertise the available sliver types + # we must still advertise the available sliver types slivers = Sliver({'type': 'iotlab-node'}) # we must also advertise the available initscripts #slivers['tags'] = [] - #if node.get('pl_initscripts'): - #for initscript in node.get('pl_initscripts', []): - #slivers['tags'].append({'name': 'initscript', \ - #'value': initscript['name']}) + # if node.get('pl_initscripts'): + # for initscript in node.get('pl_initscripts', []): + # slivers['tags'].append({'name': 'initscript', \ + #'value': initscript['name']}) Iotlabv1Sliver.add_slivers(node_elem, slivers) # add sliver tag in Request Rspec @@ -161,21 +162,23 @@ class Iotlabv1Node: @staticmethod def get_nodes(xml, filter=None): - if filter is None: filter={} - xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), \ - XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), + XpathFilter.xpath(filter)) node_elems = xml.xpath(xpath) return Iotlabv1Node.get_node_objs(node_elems) @staticmethod def get_nodes_with_slivers(xml, sliver_filter=None): - if sliver_filter is None: sliver_filter={} + if sliver_filter is None: + sliver_filter = {} xpath = '//node[count(sliver)>0] | \ //default:node[count(default:sliver) > 0]' node_elems = xml.xpath(xpath) logger.debug("SLABV1NODE \tget_nodes_with_slivers \ - node_elems %s"%(node_elems)) + node_elems %s" % (node_elems)) return Iotlabv1Node.get_node_objs(node_elems) @staticmethod @@ -191,21 +194,20 @@ class Iotlabv1Node: # get hardware types hardware_type_elems = node_elem.xpath('./default:hardware_type | \ ./hardware_type') - node['hardware_types'] = [hw_type.get_instance(HardwareType) \ - for hw_type in hardware_type_elems] + node['hardware_types'] = [hw_type.get_instance(HardwareType) + for hw_type in hardware_type_elems] # get location location_elems = node_elem.xpath('./default:location | ./location') - locations = [location_elem.get_instance(Location) \ - for location_elem in location_elems] + locations = [location_elem.get_instance(Location) + for location_elem in location_elems] if len(locations) > 0: node['location'] = locations[0] - # get interfaces iface_elems = node_elem.xpath('./default:interface | ./interface') - node['interfaces'] = [iface_elem.get_instance(Interface) \ - for iface_elem in iface_elems] + node['interfaces'] = [iface_elem.get_instance(Interface) + for iface_elem in iface_elems] # get position position_elems = node_elem.xpath('./default:position | ./position') @@ -227,10 +229,9 @@ class Iotlabv1Node: node['boot_state'] = 'disabled' logger.debug("SLABV1NODE \tget_nodes_objs \ - #nodes %s"%(nodes)) + #nodes %s" % (nodes)) return nodes - @staticmethod def add_slivers(xml, slivers): logger.debug("Iotlabv1NODE \tadd_slivers ") @@ -253,12 +254,9 @@ class Iotlabv1Node: @staticmethod def remove_slivers(xml, hostnames): for hostname in hostnames: - nodes = Iotlabv1Node.get_nodes(xml, \ - {'component_id': '*%s*' % hostname}) + nodes = Iotlabv1Node.get_nodes(xml, + {'component_id': '*%s*' % hostname}) for node in nodes: slivers = Iotlabv1Sliver.get_slivers(node.element) for sliver in slivers: node.element.remove(sliver.element) - - - diff --git a/sfa/rspecs/elements/versions/iotlabv1Sliver.py b/sfa/rspecs/elements/versions/iotlabv1Sliver.py index 269de566..110b9322 100644 --- a/sfa/rspecs/elements/versions/iotlabv1Sliver.py +++ b/sfa/rspecs/elements/versions/iotlabv1Sliver.py @@ -5,6 +5,8 @@ from sfa.rspecs.elements.sliver import Sliver #from sfa.rspecs.elements.versions.pgv2DiskImage import PGv2DiskImage import sys + + class Iotlabv1Sliver: @staticmethod @@ -21,30 +23,35 @@ class Iotlabv1Sliver: if sliver.get('client_id'): sliver_elem.set('client_id', sliver['client_id']) #images = sliver.get('disk_images') - #if images and isinstance(images, list): + # if images and isinstance(images, list): #Iotlabv1DiskImage.add_images(sliver_elem, images) - Iotlabv1Sliver.add_sliver_attributes(sliver_elem, sliver.get('tags', [])) + Iotlabv1Sliver.add_sliver_attributes( + sliver_elem, sliver.get('tags', [])) @staticmethod def add_sliver_attributes(xml, attributes): if attributes: for attribute in attributes: if attribute['name'] == 'initscript': - xml.add_element('{%s}initscript' % xml.namespaces['planetlab'], name=attribute['value']) + xml.add_element('{%s}initscript' % xml.namespaces[ + 'planetlab'], name=attribute['value']) elif tag['tagname'] == 'flack_info': - attrib_elem = xml.add_element('{%s}info' % self.namespaces['flack']) + attrib_elem = xml.add_element( + '{%s}info' % self.namespaces['flack']) attrib_dict = eval(tag['value']) for (key, value) in attrib_dict.items(): attrib_elem.set(key, value) + @staticmethod def get_slivers(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './default:sliver | ./sliver' sliver_elems = xml.xpath(xpath) slivers = [] for sliver_elem in sliver_elems: - sliver = Sliver(sliver_elem.attrib,sliver_elem) + sliver = Sliver(sliver_elem.attrib, sliver_elem) if 'component_id' in xml.attrib: sliver['component_id'] = xml.attrib['component_id'] @@ -52,11 +59,13 @@ class Iotlabv1Sliver: sliver['type'] = sliver_elem.attrib['name'] #sliver['images'] = Iotlabv1DiskImage.get_images(sliver_elem) - print("\r\n \r\n SLABV1SLIVER.PY \t\t\t get_slivers sliver %s " %( sliver), file=sys.stderr) + print("\r\n \r\n SLABV1SLIVER.PY \t\t\t get_slivers sliver %s " % + (sliver), file=sys.stderr) slivers.append(sliver) return slivers @staticmethod def get_sliver_attributes(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} return [] diff --git a/sfa/rspecs/elements/versions/nitosv1Channel.py b/sfa/rspecs/elements/versions/nitosv1Channel.py index cf4a5f61..04f4aa6e 100644 --- a/sfa/rspecs/elements/versions/nitosv1Channel.py +++ b/sfa/rspecs/elements/versions/nitosv1Channel.py @@ -23,18 +23,18 @@ class NITOSv1Channel: @staticmethod def add_channels(xml, channels): - + network_elems = xml.xpath('//network') if len(network_elems) > 0: network_elem = network_elems[0] elif len(channels) > 0: - # dirty hack that handles no resource manifest rspec + # dirty hack that handles no resource manifest rspec network_urn = "omf" - network_elem = xml.add_element('network', name = network_urn) + network_elem = xml.add_element('network', name=network_urn) else: network_elem = xml -# spectrum_elems = xml.xpath('//spectrum') +# spectrum_elems = xml.xpath('//spectrum') # spectrum_elem = xml.add_element('spectrum') # if len(spectrum_elems) > 0: @@ -44,25 +44,28 @@ class NITOSv1Channel: # else: # spectrum_elem = xml - spectrum_elem = network_elem.add_instance('spectrum', []) - - channel_elems = [] + spectrum_elem = network_elem.add_instance('spectrum', []) + + channel_elems = [] for channel in channels: - channel_fields = ['channel_num', 'frequency', 'standard', 'component_id'] - channel_elem = spectrum_elem.add_instance('channel', channel, channel_fields) + channel_fields = ['channel_num', + 'frequency', 'standard', 'component_id'] + channel_elem = spectrum_elem.add_instance( + 'channel', channel, channel_fields) channel_elems.append(channel_elem) - @staticmethod def get_channels(xml, filter=None): - if filter is None: filter={} - xpath = '//channel%s | //default:channel%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//channel%s | //default:channel%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) channel_elems = xml.xpath(xpath) return NITOSv1Channel.get_channel_objs(channel_elems) @staticmethod def get_channel_objs(channel_elems): - channels = [] + channels = [] for channel_elem in channel_elems: channel = Channel(channel_elem.attrib, channel_elem) channel['channel_num'] = channel_elem.attrib['channel_num'] @@ -71,5 +74,4 @@ class NITOSv1Channel: channel['component_id'] = channel_elem.attrib['component_id'] channels.append(channel) - return channels - + return channels diff --git a/sfa/rspecs/elements/versions/nitosv1Lease.py b/sfa/rspecs/elements/versions/nitosv1Lease.py index dd3041c5..b5cb319c 100644 --- a/sfa/rspecs/elements/versions/nitosv1Lease.py +++ b/sfa/rspecs/elements/versions/nitosv1Lease.py @@ -19,93 +19,95 @@ from sfa.rspecs.elements.lease import Lease from sfa.rspecs.elements.channel import Channel - class NITOSv1Lease: @staticmethod def add_leases(xml, leases, channels): - + network_elems = xml.xpath('//network') if len(network_elems) > 0: network_elem = network_elems[0] elif len(leases) > 0: - network_urn = Xrn(leases[0]['component_id']).get_authority_urn().split(':')[0] - network_elem = xml.add_element('network', name = network_urn) + network_urn = Xrn(leases[0]['component_id'] + ).get_authority_urn().split(':')[0] + network_elem = xml.add_element('network', name=network_urn) else: network_elem = xml - + # group the leases by slice and timeslots grouped_leases = [] while leases: - slice_id = leases[0]['slice_id'] - start_time = leases[0]['start_time'] - duration = leases[0]['duration'] - group = [] - - for lease in leases: - if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: - group.append(lease) - - grouped_leases.append(group) - - for lease1 in group: - leases.remove(lease1) - - lease_elems = [] + slice_id = leases[0]['slice_id'] + start_time = leases[0]['start_time'] + duration = leases[0]['duration'] + group = [] + + for lease in leases: + if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: + group.append(lease) + + grouped_leases.append(group) + + for lease1 in group: + leases.remove(lease1) + + lease_elems = [] for lease in grouped_leases: #lease[0]['start_time'] = datetime_to_string(utcparse(lease[0]['start_time'])) lease_fields = ['slice_id', 'start_time', 'duration'] - lease_elem = network_elem.add_instance('lease', lease[0], lease_fields) + lease_elem = network_elem.add_instance( + 'lease', lease[0], lease_fields) lease_elems.append(lease_elem) # add nodes of this lease for node in lease: - lease_elem.add_instance('node', node, ['component_id']) + lease_elem.add_instance('node', node, ['component_id']) # add reserved channels of this lease #channels = [{'channel_id': 1}, {'channel_id': 2}] for channel in channels: - #channel['start_time'] = datetime_to_string(utcparse(channel['start_time'])) - if channel['slice_id'] == lease[0]['slice_id'] and channel['start_time'] == lease[0]['start_time'] and channel['duration'] == lease[0]['duration']: - lease_elem.add_instance('channel', channel, ['component_id']) - + #channel['start_time'] = datetime_to_string(utcparse(channel['start_time'])) + if channel['slice_id'] == lease[0]['slice_id'] and channel['start_time'] == lease[0]['start_time'] and channel['duration'] == lease[0]['duration']: + lease_elem.add_instance( + 'channel', channel, ['component_id']) @staticmethod def get_leases(xml, filter=None): - if filter is None: filter={} - xpath = '//lease%s | //default:lease%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//lease%s | //default:lease%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) lease_elems = xml.xpath(xpath) return NITOSv1Lease.get_lease_objs(lease_elems) @staticmethod def get_lease_objs(lease_elems): - leases = [] + leases = [] channels = [] for lease_elem in lease_elems: - #get nodes + # get nodes node_elems = lease_elem.xpath('./default:node | ./node') for node_elem in node_elems: - lease = Lease(lease_elem.attrib, lease_elem) - lease['slice_id'] = lease_elem.attrib['slice_id'] - #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) - lease['start_time'] = lease_elem.attrib['start_time'] - lease['duration'] = lease_elem.attrib['duration'] - lease['component_id'] = node_elem.attrib['component_id'] - lease['type'] = 'node' - leases.append(lease) - #get channels + lease = Lease(lease_elem.attrib, lease_elem) + lease['slice_id'] = lease_elem.attrib['slice_id'] + #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) + lease['start_time'] = lease_elem.attrib['start_time'] + lease['duration'] = lease_elem.attrib['duration'] + lease['component_id'] = node_elem.attrib['component_id'] + lease['type'] = 'node' + leases.append(lease) + # get channels channel_elems = lease_elem.xpath('./default:channel | ./channel') for channel_elem in channel_elems: - channel = Channel(channel_elem.attrib, channel_elem) - channel['slice_id'] = lease_elem.attrib['slice_id'] - #channel['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) - channel['start_time'] = lease_elem.attrib['start_time'] - channel['duration'] = lease_elem.attrib['duration'] - channel['component_id'] = channel_elem.attrib['component_id'] - channel['type'] = 'channel' - channels.append(channel) - - return leases + channels + channel = Channel(channel_elem.attrib, channel_elem) + channel['slice_id'] = lease_elem.attrib['slice_id'] + #channel['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) + channel['start_time'] = lease_elem.attrib['start_time'] + channel['duration'] = lease_elem.attrib['duration'] + channel['component_id'] = channel_elem.attrib['component_id'] + channel['type'] = 'channel' + channels.append(channel) + return leases + channels diff --git a/sfa/rspecs/elements/versions/nitosv1Node.py b/sfa/rspecs/elements/versions/nitosv1Node.py index ea59b3d5..c4349398 100644 --- a/sfa/rspecs/elements/versions/nitosv1Node.py +++ b/sfa/rspecs/elements/versions/nitosv1Node.py @@ -26,21 +26,23 @@ class NITOSv1Node: network_elem = network_elems[0] elif len(nodes) > 0 and nodes[0].get('component_manager_id'): network_urn = nodes[0]['component_manager_id'] - network_elem = xml.add_element('network', name = Xrn(network_urn).get_hrn()) + network_elem = xml.add_element( + 'network', name=Xrn(network_urn).get_hrn()) else: network_elem = xml # needs to be improuved to retreive the gateway addr dynamically. gateway_addr = 'nitlab.inf.uth.gr' - node_elems = [] + node_elems = [] for node in nodes: - node_fields = ['component_manager_id', 'component_id', 'boot_state'] + node_fields = ['component_manager_id', + 'component_id', 'boot_state'] node_elem = network_elem.add_instance('node', node, node_fields) node_elems.append(node_elem) # determine network hrn - network_hrn = None + network_hrn = None if 'component_manager_id' in node and node['component_manager_id']: network_hrn = Xrn(node['component_manager_id']).get_hrn() @@ -63,12 +65,13 @@ class NITOSv1Node: # add 3D Position of the node position_3d = node.get('position_3d') if position_3d: - node_elem.add_instance('position_3d', position_3d, Position3D.fields) + node_elem.add_instance( + 'position_3d', position_3d, Position3D.fields) # all nitos nodes are exculsive exclusive_elem = node_elem.add_element('exclusive') exclusive_elem.set_text('TRUE') - + # In order to access nitos nodes, one need to pass through the nitos gateway # here we advertise Nitos access gateway address gateway_elem = node_elem.add_element('gateway') @@ -82,17 +85,17 @@ class NITOSv1Node: granularity_elem.set_text(str(granularity)) # add hardware type #hardware_type = node.get('hardware_type') - #if hardware_type: + # if hardware_type: # node_elem.add_instance('hardware_type', hardware_type) hardware_type_elem = node_elem.add_element('hardware_type') hardware_type_elem.set_text(node.get('hardware_type')) - if isinstance(node.get('interfaces'), list): for interface in node.get('interfaces', []): - node_elem.add_instance('interface', interface, ['component_id', 'client_id', 'ipv4']) - - #if 'bw_unallocated' in node and node['bw_unallocated']: + node_elem.add_instance('interface', interface, [ + 'component_id', 'client_id', 'ipv4']) + + # if 'bw_unallocated' in node and node['bw_unallocated']: # bw_unallocated = etree.SubElement(node_elem, 'bw_unallocated', units='kbps').text = str(int(node['bw_unallocated'])/1000) PGv2Services.add_services(node_elem, node.get('services', [])) @@ -102,12 +105,12 @@ class NITOSv1Node: tag_elem = node_elem.add_element(tag['tagname']) tag_elem.set_text(tag['value']) NITOSv1Sliver.add_slivers(node_elem, node.get('slivers', [])) - + # add sliver tag in Request Rspec if rspec_content_type == "request": node_elem.add_instance('sliver', '', []) - @staticmethod + @staticmethod def add_slivers(xml, slivers): component_ids = [] for sliver in slivers: @@ -118,7 +121,7 @@ class NITOSv1Node: elif 'component_id' in sliver and sliver['component_id']: filter['component_id'] = '*%s*' % sliver['component_id'] if not filter: - continue + continue nodes = NITOSv1Node.get_nodes(xml, filter) if not nodes: continue @@ -128,56 +131,64 @@ class NITOSv1Node: @staticmethod def remove_slivers(xml, hostnames): for hostname in hostnames: - nodes = NITOSv1Node.get_nodes(xml, {'component_id': '*%s*' % hostname}) + nodes = NITOSv1Node.get_nodes( + xml, {'component_id': '*%s*' % hostname}) for node in nodes: slivers = NITOSv1Sliver.get_slivers(node.element) for sliver in slivers: node.element.remove(sliver.element) - + @staticmethod def get_nodes(xml, filter=None): - if filter is None: filter={} - xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//node%s | //default:node%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) node_elems = xml.xpath(xpath) return NITOSv1Node.get_node_objs(node_elems) @staticmethod def get_nodes_with_slivers(xml): - xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' + xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' node_elems = xml.xpath(xpath) return NITOSv1Node.get_node_objs(node_elems) - @staticmethod def get_node_objs(node_elems): - nodes = [] + nodes = [] for node_elem in node_elems: node = NodeElement(node_elem.attrib, node_elem) if 'site_id' in node_elem.attrib: node['authority_id'] = node_elem.attrib['site_id'] # get location location_elems = node_elem.xpath('./default:location | ./location') - locations = [loc_elem.get_instance(Location) for loc_elem in location_elems] + locations = [loc_elem.get_instance( + Location) for loc_elem in location_elems] if len(locations) > 0: node['location'] = locations[0] # get bwlimit bwlimit_elems = node_elem.xpath('./default:bw_limit | ./bw_limit') - bwlimits = [bwlimit_elem.get_instance(BWlimit) for bwlimit_elem in bwlimit_elems] + bwlimits = [bwlimit_elem.get_instance( + BWlimit) for bwlimit_elem in bwlimit_elems] if len(bwlimits) > 0: node['bwlimit'] = bwlimits[0] # get interfaces iface_elems = node_elem.xpath('./default:interface | ./interface') - ifaces = [iface_elem.get_instance(Interface) for iface_elem in iface_elems] + ifaces = [iface_elem.get_instance( + Interface) for iface_elem in iface_elems] node['interfaces'] = ifaces # get services - node['services'] = PGv2Services.get_services(node_elem) + node['services'] = PGv2Services.get_services(node_elem) # get slivers node['slivers'] = NITOSv1Sliver.get_slivers(node_elem) # get tags - node['tags'] = NITOSv1PLTag.get_pl_tags(node_elem, ignore=NodeElement.fields+["hardware_type"]) + node['tags'] = NITOSv1PLTag.get_pl_tags( + node_elem, ignore=NodeElement.fields + ["hardware_type"]) # get hardware types - hardware_type_elems = node_elem.xpath('./default:hardware_type | ./hardware_type') - node['hardware_types'] = [hw_type.get_instance(HardwareType) for hw_type in hardware_type_elems] + hardware_type_elems = node_elem.xpath( + './default:hardware_type | ./hardware_type') + node['hardware_types'] = [hw_type.get_instance( + HardwareType) for hw_type in hardware_type_elems] # temporary... play nice with old slice manager rspec if not node['component_name']: @@ -186,5 +197,4 @@ class NITOSv1Node: node['component_name'] = hostname_elem.text nodes.append(node) - return nodes - + return nodes diff --git a/sfa/rspecs/elements/versions/nitosv1PLTag.py b/sfa/rspecs/elements/versions/nitosv1PLTag.py index ea34ff4e..cca22af3 100644 --- a/sfa/rspecs/elements/versions/nitosv1PLTag.py +++ b/sfa/rspecs/elements/versions/nitosv1PLTag.py @@ -1,20 +1,22 @@ -from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.pltag import PLTag + class NITOSv1PLTag: + @staticmethod def add_pl_tag(xml, name, value): for pl_tag in pl_tags: pl_tag_elem = xml.add_element(name) pl_tag_elem.set_text(value) - + @staticmethod def get_pl_tags(xml, ignore=None): - if ignore is None: ignore=[] + if ignore is None: + ignore = [] pl_tags = [] for elem in xml.iterchildren(): if elem.tag not in ignore: pl_tag = PLTag({'tagname': elem.tag, 'value': elem.text}) - pl_tags.append(pl_tag) + pl_tags.append(pl_tag) return pl_tags - diff --git a/sfa/rspecs/elements/versions/nitosv1Sliver.py b/sfa/rspecs/elements/versions/nitosv1Sliver.py index feac8879..6ceb87fc 100644 --- a/sfa/rspecs/elements/versions/nitosv1Sliver.py +++ b/sfa/rspecs/elements/versions/nitosv1Sliver.py @@ -7,6 +7,7 @@ from sfa.rspecs.elements.versions.nitosv1PLTag import NITOSv1PLTag #from sfa.planetlab.plxrn import PlXrn + class NITOSv1Sliver: @staticmethod @@ -20,16 +21,18 @@ class NITOSv1Sliver: tags = sliver.get('tags', []) if tags: for tag in tags: - NITOSv1Sliver.add_sliver_attribute(sliver_elem, tag['tagname'], tag['value']) + NITOSv1Sliver.add_sliver_attribute( + sliver_elem, tag['tagname'], tag['value']) if sliver.get('sliver_id'): - name = Xrn(xrn=sliver.get('sliver_id')).get_hrn().split('.')[-1] + name = Xrn(xrn=sliver.get('sliver_id') + ).get_hrn().split('.')[-1] sliver_elem.set('name', name) @staticmethod def add_sliver_attribute(xml, name, value): elem = xml.add_element(name) elem.set_text(value) - + @staticmethod def get_sliver_attributes(xml): attribs = [] @@ -40,19 +43,19 @@ class NITOSv1Sliver: instance['name'] = elem.tag instance['value'] = elem.text attribs.append(instance) - return attribs - + return attribs + @staticmethod def get_slivers(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './default:sliver | ./sliver' sliver_elems = xml.xpath(xpath) slivers = [] for sliver_elem in sliver_elems: - sliver = Sliver(sliver_elem.attrib,sliver_elem) - if 'component_id' in xml.attrib: + sliver = Sliver(sliver_elem.attrib, sliver_elem) + if 'component_id' in xml.attrib: sliver['component_id'] = xml.attrib['component_id'] sliver['tags'] = NITOSv1Sliver.get_sliver_attributes(sliver_elem) slivers.append(sliver) - return slivers - + return slivers diff --git a/sfa/rspecs/elements/versions/ofeliav1Port.py b/sfa/rspecs/elements/versions/ofeliav1Port.py index 009d551c..36003743 100644 --- a/sfa/rspecs/elements/versions/ofeliav1Port.py +++ b/sfa/rspecs/elements/versions/ofeliav1Port.py @@ -1,8 +1,9 @@ from sfa.util.xrn import Xrn from sfa.util.xml import XmlElement -from sfa.rspecs.elements.element import Element -from sfa.rspecs.elements.port import Port +from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.port import Port + class Ofeliav1Port: @@ -18,14 +19,15 @@ class Ofeliav1Port: tags = port.get('tags', []) if tags: for tag in tags: - Ofeliav1Port.add_port_attribute(port_elem, tag['tagname'], tag['value']) + Ofeliav1Port.add_port_attribute( + port_elem, tag['tagname'], tag['value']) @staticmethod def add_port_attribute(xml, name, value): raise Exception("not implemented yet") elem = xml.add_element(name) elem.set_text(value) - + @staticmethod def get_port_attributes(xml): attribs = [] @@ -36,19 +38,19 @@ class Ofeliav1Port: instance['name'] = elem.tag instance['value'] = elem.text attribs.append(instance) - return attribs - + return attribs + @staticmethod def get_ports(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './openflow:port | ./port' port_elems = xml.xpath(xpath) ports = [] for port_elem in port_elems: - port = Port(port_elem.attrib,port_elem) - #if 'component_id' in xml.attrib: + port = Port(port_elem.attrib, port_elem) + # if 'component_id' in xml.attrib: # port['component_id'] = xml.attrib['component_id'] #port['tags'] = Ofeliav1Port.get_port_attributes(port_elem) ports.append(port) - return ports - + return ports diff --git a/sfa/rspecs/elements/versions/ofeliav1datapath.py b/sfa/rspecs/elements/versions/ofeliav1datapath.py index 01390c0b..1cf77a4a 100644 --- a/sfa/rspecs/elements/versions/ofeliav1datapath.py +++ b/sfa/rspecs/elements/versions/ofeliav1datapath.py @@ -22,21 +22,23 @@ class Ofeliav1Datapath: @staticmethod def get_datapaths(xml, filter=None): - if filter is None: filter = {} + if filter is None: + filter = {} #xpath = '//datapath%s | //default:datapath%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) - xpath = '//datapath%s | //openflow:datapath%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + xpath = '//datapath%s | //openflow:datapath%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) datapath_elems = xml.xpath(xpath) return Ofeliav1Datapath.get_datapath_objs(datapath_elems) @staticmethod def get_datapath_objs(datapath_elems): - datapaths = [] + datapaths = [] for datapath_elem in datapath_elems: datapath = Datapath(datapath_elem.attrib, datapath_elem) # get ports - datapath['ports'] = Ofeliav1Port.get_ports(datapath_elem) + datapath['ports'] = Ofeliav1Port.get_ports(datapath_elem) datapaths.append(datapath) - return datapaths + return datapaths # @staticmethod # def add_nodes(xml, nodes, rspec_content_type=None): @@ -49,14 +51,14 @@ class Ofeliav1Datapath: # else: # network_elem = xml # -# node_elems = [] +# node_elems = [] # for node in nodes: # node_fields = ['component_manager_id', 'component_id', 'boot_state'] # node_elem = network_elem.add_instance('node', node, node_fields) # node_elems.append(node_elem) # # # determine network hrn -# network_hrn = None +# network_hrn = None # if 'component_manager_id' in node and node['component_manager_id']: # network_hrn = Xrn(node['component_manager_id']).get_hrn() # @@ -90,8 +92,8 @@ class Ofeliav1Datapath: # # if isinstance(node.get('interfaces'), list): # for interface in node.get('interfaces', []): -# node_elem.add_instance('interface', interface, ['component_id', 'client_id', 'ipv4']) -# +# node_elem.add_instance('interface', interface, ['component_id', 'client_id', 'ipv4']) +# # #if 'bw_unallocated' in node and node['bw_unallocated']: # # bw_unallocated = etree.SubElement(node_elem, 'bw_unallocated', units='kbps').text = str(int(node['bw_unallocated'])/1000) # @@ -115,9 +117,9 @@ class Ofeliav1Datapath: # # # add sliver tag in Request Rspec # if rspec_content_type == "request": -# node_elem.add_instance('sliver', '', []) +# node_elem.add_instance('sliver', '', []) # -# @staticmethod +# @staticmethod # def add_slivers(xml, slivers): # component_ids = [] # for sliver in slivers: @@ -128,7 +130,7 @@ class Ofeliav1Datapath: # elif 'component_id' in sliver and sliver['component_id']: # filter['component_id'] = '*%s*' % sliver['component_id'] # if not filter: -# continue +# continue # nodes = SFAv1Node.get_nodes(xml, filter) # if not nodes: # continue @@ -143,7 +145,7 @@ class Ofeliav1Datapath: # slivers = SFAv1Sliver.get_slivers(node.element) # for sliver in slivers: # node.element.remove(sliver.element) -# +# # @staticmethod # def get_nodes(xml, filter={}): # xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) @@ -152,7 +154,7 @@ class Ofeliav1Datapath: # # @staticmethod # def get_nodes_with_slivers(xml): -# xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' +# xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' # node_elems = xml.xpath(xpath) # return SFAv1Node.get_node_objs(node_elems) # diff --git a/sfa/rspecs/elements/versions/ofeliav1link.py b/sfa/rspecs/elements/versions/ofeliav1link.py index 3a32f217..da299c5b 100644 --- a/sfa/rspecs/elements/versions/ofeliav1link.py +++ b/sfa/rspecs/elements/versions/ofeliav1link.py @@ -5,18 +5,21 @@ from sfa.util.xrn import Xrn, get_leaf from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.link import Link + class Ofeliav1Link: @staticmethod def get_links(xml, filter=None): - if filter is None: filter = {} - xpath = '//link%s | //openflow:link%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//link%s | //openflow:link%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) link_elems = xml.xpath(xpath) return Ofeliav1Link.get_link_objs(link_elems) @staticmethod def get_link_objs(link_elems): - links = [] + links = [] for link_elem in link_elems: link = Link(link_elem.attrib, link_elem) links.append(link) diff --git a/sfa/rspecs/elements/versions/pgv2DiskImage.py b/sfa/rspecs/elements/versions/pgv2DiskImage.py index 4a6df82e..cd6801ea 100644 --- a/sfa/rspecs/elements/versions/pgv2DiskImage.py +++ b/sfa/rspecs/elements/versions/pgv2DiskImage.py @@ -1,20 +1,22 @@ from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.disk_image import DiskImage + class PGv2DiskImage: @staticmethod def add_images(xml, images): if not images: - return + return if not isinstance(images, list): images = [images] - for image in images: + for image in images: xml.add_instance('disk_image', image, DiskImage.fields) - + @staticmethod def get_images(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './default:disk_image | ./disk_image' image_elems = xml.xpath(xpath) images = [] @@ -22,4 +24,3 @@ class PGv2DiskImage: image = DiskImage(image_elem.attrib, image_elem) images.append(image) return images - diff --git a/sfa/rspecs/elements/versions/pgv2Interface.py b/sfa/rspecs/elements/versions/pgv2Interface.py index 7144fa9b..5a433826 100644 --- a/sfa/rspecs/elements/versions/pgv2Interface.py +++ b/sfa/rspecs/elements/versions/pgv2Interface.py @@ -2,19 +2,21 @@ from sfa.util.xrn import Xrn from sfa.util.xml import XpathFilter from sfa.rspecs.elements.interface import Interface + class PGv2Interface: @staticmethod def add_interfaces(xml, interfaces): if isinstance(interfaces, list): for interface in interfaces: - if_elem = xml.add_instance('interface', interface, ['component_id', 'client_id', 'sliver_id']) + if_elem = xml.add_instance('interface', interface, [ + 'component_id', 'client_id', 'sliver_id']) ips = interface.get('ips', []) for ip in ips: if_elem.add_instance('ip', {'address': ip.get('address'), 'netmask': ip.get('netmask'), - 'type': ip.get('type')}) - + 'type': ip.get('type')}) + @staticmethod def get_interfaces(xml): pass diff --git a/sfa/rspecs/elements/versions/pgv2Lease.py b/sfa/rspecs/elements/versions/pgv2Lease.py index b04f7dc2..19ee9392 100644 --- a/sfa/rspecs/elements/versions/pgv2Lease.py +++ b/sfa/rspecs/elements/versions/pgv2Lease.py @@ -10,32 +10,33 @@ from sfa.rspecs.elements.disk_image import DiskImage from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.bwlimit import BWlimit from sfa.rspecs.elements.pltag import PLTag -from sfa.rspecs.elements.versions.pgv2Services import PGv2Services -from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType -from sfa.rspecs.elements.versions.pgv2Interface import PGv2Interface +from sfa.rspecs.elements.versions.pgv2Services import PGv2Services +from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType +from sfa.rspecs.elements.versions.pgv2Interface import PGv2Interface from sfa.rspecs.elements.lease import Lease class PGv2Lease: + @staticmethod def add_leases(xml, leases): # group the leases by slice and timeslots grouped_leases = [] while leases: - slice_id = leases[0]['slice_id'] - start_time = leases[0]['start_time'] - duration = leases[0]['duration'] - group = [] + slice_id = leases[0]['slice_id'] + start_time = leases[0]['start_time'] + duration = leases[0]['duration'] + group = [] - for lease in leases: - if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: - group.append(lease) + for lease in leases: + if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: + group.append(lease) - grouped_leases.append(group) + grouped_leases.append(group) - for lease1 in group: - leases.remove(lease1) + for lease1 in group: + leases.remove(lease1) lease_elems = [] for lease in grouped_leases: @@ -47,30 +48,30 @@ class PGv2Lease: # add nodes of this lease for node in lease: - lease_elem.add_instance('node', node, ['component_id']) - + lease_elem.add_instance('node', node, ['component_id']) @staticmethod def get_leases(xml, filter=None): - if filter is None: filter={} - xpath = '//lease%s | //default:lease%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//lease%s | //default:lease%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) lease_elems = xml.xpath(xpath) return PGv2Lease.get_lease_objs(lease_elems) - @staticmethod def get_lease_objs(lease_elems): leases = [] for lease_elem in lease_elems: - #get nodes + # get nodes node_elems = lease_elem.xpath('./default:node | ./node') for node_elem in node_elems: - lease = Lease(lease_elem.attrib, lease_elem) - lease['slice_id'] = lease_elem.attrib['slice_id'] - #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) - lease['start_time'] = lease_elem.attrib['start_time'] - lease['duration'] = lease_elem.attrib['duration'] - lease['component_id'] = node_elem.attrib['component_id'] - leases.append(lease) + lease = Lease(lease_elem.attrib, lease_elem) + lease['slice_id'] = lease_elem.attrib['slice_id'] + #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) + lease['start_time'] = lease_elem.attrib['start_time'] + lease['duration'] = lease_elem.attrib['duration'] + lease['component_id'] = node_elem.attrib['component_id'] + leases.append(lease) return leases diff --git a/sfa/rspecs/elements/versions/pgv2Link.py b/sfa/rspecs/elements/versions/pgv2Link.py index 1c0d3ce1..95c708e6 100644 --- a/sfa/rspecs/elements/versions/pgv2Link.py +++ b/sfa/rspecs/elements/versions/pgv2Link.py @@ -2,32 +2,40 @@ from sfa.util.xrn import Xrn from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.link import Link from sfa.rspecs.elements.interface import Interface -from sfa.rspecs.elements.property import Property +from sfa.rspecs.elements.property import Property + class PGv2Link: + @staticmethod def add_links(xml, links): for link in links: - - link_elem = xml.add_instance('link', link, ['component_name', 'component_id', 'client_id']) - # set component manager element + + link_elem = xml.add_instance( + 'link', link, ['component_name', 'component_id', 'client_id']) + # set component manager element if 'component_manager' in link and link['component_manager']: - cm_element = link_elem.add_element('component_manager', name=link['component_manager']) + cm_element = link_elem.add_element( + 'component_manager', name=link['component_manager']) # set interface_ref elements if link.get('interface1') and link.get('interface2'): for if_ref in [link['interface1'], link['interface2']]: - link_elem.add_instance('interface_ref', if_ref, Interface.fields) + link_elem.add_instance( + 'interface_ref', if_ref, Interface.fields) # set property elements - prop1 = link_elem.add_element('property', source_id = link['interface1']['component_id'], - dest_id = link['interface2']['component_id'], capacity=link['capacity'], - latency=link['latency'], packet_loss=link['packet_loss']) - prop2 = link_elem.add_element('property', source_id = link['interface2']['component_id'], - dest_id = link['interface1']['component_id'], capacity=link['capacity'], - latency=link['latency'], packet_loss=link['packet_loss']) + prop1 = link_elem.add_element('property', source_id=link['interface1']['component_id'], + dest_id=link['interface2'][ + 'component_id'], capacity=link['capacity'], + latency=link['latency'], packet_loss=link['packet_loss']) + prop2 = link_elem.add_element('property', source_id=link['interface2']['component_id'], + dest_id=link['interface1'][ + 'component_id'], capacity=link['capacity'], + latency=link['latency'], packet_loss=link['packet_loss']) if link.get('type'): - type_elem = link_elem.add_element('link_type', name=link['type']) - - @staticmethod + type_elem = link_elem.add_element( + 'link_type', name=link['type']) + + @staticmethod def get_links(xml): links = [] link_elems = xml.xpath('//default:link | //link') @@ -36,7 +44,8 @@ class PGv2Link: link = Link(link_elem.attrib, link_elem) # set component manager - component_managers = link_elem.xpath('./default:component_manager | ./component_manager') + component_managers = link_elem.xpath( + './default:component_manager | ./component_manager') if len(component_managers) > 0 and 'name' in component_managers[0].attrib: link['component_manager'] = component_managers[0].attrib['name'] @@ -44,8 +53,8 @@ class PGv2Link: link_types = link_elem.xpath('./default:link_type | ./link_type') if len(link_types) > 0 and 'name' in link_types[0].attrib: link['type'] = link_types[0].attrib['name'] - - # get capacity, latency and packet_loss from first property + + # get capacity, latency and packet_loss from first property property_fields = ['capacity', 'latency', 'packet_loss'] property_elems = link_elem.xpath('./default:property | ./property') if len(property_elems) > 0: @@ -53,15 +62,17 @@ class PGv2Link: for attrib in ['capacity', 'latency', 'packet_loss']: if attrib in prop.attrib: link[attrib] = prop.attrib[attrib] - + # get interfaces - iface_elems = link_elem.xpath('./default:interface_ref | ./interface_ref') - interfaces = [iface_elem.get_instance(Interface) for iface_elem in iface_elems] + iface_elems = link_elem.xpath( + './default:interface_ref | ./interface_ref') + interfaces = [iface_elem.get_instance( + Interface) for iface_elem in iface_elems] if len(interfaces) > 1: link['interface1'] = interfaces[0] - link['interface2'] = interfaces[1] + link['interface2'] = interfaces[1] links.append(link) - return links + return links @staticmethod def add_link_requests(xml, link_tuples, append=False): @@ -71,34 +82,36 @@ class PGv2Link: available_links = PGv2Link.get_links(xml) recently_added = [] for link in available_links: - if_name1 = Xrn(link['interface1']['component_id']).get_leaf() - if_name2 = Xrn(link['interface2']['component_id']).get_leaf() - + if_name1 = Xrn(link['interface1']['component_id']).get_leaf() + if_name2 = Xrn(link['interface2']['component_id']).get_leaf() + requested_link = None l_tup_1 = (if_name1, if_name2) l_tup_2 = (if_name2, if_name1) if link_tuples.issuperset([(if_name1, if_name2)]): - requested_link = (if_name1, if_name2) + requested_link = (if_name1, if_name2) elif link_tuples.issuperset([(if_name2, if_name2)]): requested_link = (if_name2, if_name1) if requested_link: - # add client id to link ane interface elements + # add client id to link ane interface elements link.element.set('client_id', link['component_name']) - link['interface1'].element.set('client_id', Xrn(link['interface1']['component_id']).get_leaf()) - link['interface2'].element.set('client_id', Xrn(link['interface2']['component_id']).get_leaf()) + link['interface1'].element.set('client_id', Xrn( + link['interface1']['component_id']).get_leaf()) + link['interface2'].element.set('client_id', Xrn( + link['interface2']['component_id']).get_leaf()) recently_added.append(link['component_name']) if not append: - # remove all links that don't have a client id + # remove all links that don't have a client id for link in PGv2Link.get_links(xml): if not link['client_id'] or link['component_name'] not in recently_added: parent = link.element.getparent() - parent.remove(link.element) - + parent.remove(link.element) + @staticmethod def get_link_requests(xml): link_requests = [] for link in PGv2Link.get_links(xml): if link['client_id'] != None: link_requests.append(link) - return link_requests + return link_requests diff --git a/sfa/rspecs/elements/versions/pgv2Node.py b/sfa/rspecs/elements/versions/pgv2Node.py index a61749c1..3409dc0c 100644 --- a/sfa/rspecs/elements/versions/pgv2Node.py +++ b/sfa/rspecs/elements/versions/pgv2Node.py @@ -9,48 +9,55 @@ from sfa.rspecs.elements.disk_image import DiskImage from sfa.rspecs.elements.interface import Interface from sfa.rspecs.elements.bwlimit import BWlimit from sfa.rspecs.elements.pltag import PLTag -from sfa.rspecs.elements.versions.pgv2Services import PGv2Services -from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType -from sfa.rspecs.elements.versions.pgv2Interface import PGv2Interface +from sfa.rspecs.elements.versions.pgv2Services import PGv2Services +from sfa.rspecs.elements.versions.pgv2SliverType import PGv2SliverType +from sfa.rspecs.elements.versions.pgv2Interface import PGv2Interface from sfa.rspecs.elements.versions.sfav1PLTag import SFAv1PLTag from sfa.rspecs.elements.granularity import Granularity from sfa.rspecs.elements.attribute import Attribute class PGv2Node: + @staticmethod def add_nodes(xml, nodes, rspec_content_type=None): node_elems = [] for node in nodes: - node_fields = ['component_manager_id', 'component_id', 'client_id', 'sliver_id', 'exclusive'] + node_fields = ['component_manager_id', 'component_id', + 'client_id', 'sliver_id', 'exclusive'] node_elem = xml.add_instance('node', node, node_fields) node_elems.append(node_elem) # set component name if node.get('component_id'): - component_name = Xrn.unescape(get_leaf(Xrn(node['component_id']).get_hrn())) + component_name = Xrn.unescape( + get_leaf(Xrn(node['component_id']).get_hrn())) node_elem.set('component_name', component_name) # set hardware types if node.get('hardware_types'): - for hardware_type in node.get('hardware_types', []): - node_elem.add_instance('hardware_type', hardware_type, HardwareType.fields) + for hardware_type in node.get('hardware_types', []): + node_elem.add_instance( + 'hardware_type', hardware_type, HardwareType.fields) # set location if node.get('location'): - node_elem.add_instance('location', node['location'], Location.fields) + node_elem.add_instance( + 'location', node['location'], Location.fields) # set granularity if node.get('exclusive') == "true": granularity = node.get('granularity') - node_elem.add_instance('granularity', granularity, granularity.fields) + node_elem.add_instance( + 'granularity', granularity, granularity.fields) # set interfaces PGv2Interface.add_interfaces(node_elem, node.get('interfaces')) - #if node.get('interfaces'): + # if node.get('interfaces'): # for interface in node.get('interfaces', []): # node_elem.add_instance('interface', interface, ['component_id', 'client_id']) # set available element if node.get('available'): - available_elem = node_elem.add_element('available', now=node['available']) + available_elem = node_elem.add_element( + 'available', now=node['available']) # add services - PGv2Services.add_services(node_elem, node.get('services', [])) + PGv2Services.add_services(node_elem, node.get('services', [])) # add slivers slivers = node.get('slivers', []) if not slivers: @@ -58,41 +65,45 @@ class PGv2Node: if node.get('sliver_type'): slivers = Sliver({'type': node['sliver_type']}) else: - # Planet lab + # Planet lab slivers = Sliver({'type': 'plab-vserver'}) # we must also advertise the available initscripts slivers['tags'] = [] - if node.get('pl_initscripts'): + if node.get('pl_initscripts'): for initscript in node.get('pl_initscripts', []): - slivers['tags'].append({'name': 'initscript', 'value': initscript['name']}) + slivers['tags'].append( + {'name': 'initscript', 'value': initscript['name']}) PGv2SliverType.add_slivers(node_elem, slivers) # advertise the node tags tags = node.get('tags', []) if tags: - for tag in tags: + for tag in tags: tag['name'] = tag.pop('tagname') - node_elem.add_instance('{%s}attribute' % xml.namespaces['planetlab'], tag, ['name', 'value']) + node_elem.add_instance('{%s}attribute' % xml.namespaces[ + 'planetlab'], tag, ['name', 'value']) # add sliver tag in Request Rspec - #if rspec_content_type == "request": + # if rspec_content_type == "request": # node_elem.add_instance('sliver', '', []) return node_elems - @staticmethod def get_nodes(xml, filter=None): - if filter is None: filter={} - xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//node%s | //default:node%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) node_elems = xml.xpath(xpath) return PGv2Node.get_node_objs(node_elems) @staticmethod def get_nodes_with_slivers(xml, filter=None): - if filter is None: filter={} - xpath = '//node[count(sliver_type)>0] | //default:node[count(default:sliver_type) > 0]' - node_elems = xml.xpath(xpath) + if filter is None: + filter = {} + xpath = '//node[count(sliver_type)>0] | //default:node[count(default:sliver_type) > 0]' + node_elems = xml.xpath(xpath) return PGv2Node.get_node_objs(node_elems) @staticmethod @@ -100,68 +111,78 @@ class PGv2Node: nodes = [] for node_elem in node_elems: node = NodeElement(node_elem.attrib, node_elem) - nodes.append(node) + nodes.append(node) if 'component_id' in node_elem.attrib: - node['authority_id'] = Xrn(node_elem.attrib['component_id']).get_authority_urn() - + node['authority_id'] = Xrn( + node_elem.attrib['component_id']).get_authority_urn() + # get hardware types - hardware_type_elems = node_elem.xpath('./default:hardware_type | ./hardware_type') - node['hardware_types'] = [dict(hw_type.get_instance(HardwareType)) for hw_type in hardware_type_elems] - + hardware_type_elems = node_elem.xpath( + './default:hardware_type | ./hardware_type') + node['hardware_types'] = [dict(hw_type.get_instance( + HardwareType)) for hw_type in hardware_type_elems] + # get location location_elems = node_elem.xpath('./default:location | ./location') - locations = [dict(location_elem.get_instance(Location)) for location_elem in location_elems] + locations = [dict(location_elem.get_instance(Location)) + for location_elem in location_elems] if len(locations) > 0: node['location'] = locations[0] # get granularity - granularity_elems = node_elem.xpath('./default:granularity | ./granularity') + granularity_elems = node_elem.xpath( + './default:granularity | ./granularity') if len(granularity_elems) > 0: - node['granularity'] = granularity_elems[0].get_instance(Granularity) + node['granularity'] = granularity_elems[ + 0].get_instance(Granularity) # get interfaces iface_elems = node_elem.xpath('./default:interface | ./interface') - node['interfaces'] = [dict(iface_elem.get_instance(Interface)) for iface_elem in iface_elems] + node['interfaces'] = [dict(iface_elem.get_instance( + Interface)) for iface_elem in iface_elems] # get services node['services'] = PGv2Services.get_services(node_elem) - + # get slivers - node['slivers'] = PGv2SliverType.get_slivers(node_elem) - + node['slivers'] = PGv2SliverType.get_slivers(node_elem) + # get boot state - available_elems = node_elem.xpath('./default:available | ./available') + available_elems = node_elem.xpath( + './default:available | ./available') if len(available_elems) > 0 and 'now' in available_elems[0].attrib: - if available_elems[0].attrib.get('now', '').lower() == 'true': + if available_elems[0].attrib.get('now', '').lower() == 'true': node['boot_state'] = 'boot' - else: - node['boot_state'] = 'disabled' + else: + node['boot_state'] = 'disabled' # get initscripts try: - node['pl_initscripts'] = [] - initscript_elems = node_elem.xpath('./default:sliver_type/planetlab:initscript | ./sliver_type/initscript') - if len(initscript_elems) > 0: - for initscript_elem in initscript_elems: + node['pl_initscripts'] = [] + initscript_elems = node_elem.xpath( + './default:sliver_type/planetlab:initscript | ./sliver_type/initscript') + if len(initscript_elems) > 0: + for initscript_elem in initscript_elems: if 'name' in initscript_elem.attrib: - node['pl_initscripts'].append(dict(initscript_elem.attrib)) + node['pl_initscripts'].append( + dict(initscript_elem.attrib)) except: - pass + pass # get node tags try: - tag_elems = node_elem.xpath('./planetlab:attribute | ./attribute') - node['tags'] = [] - if len(tag_elems) > 0: - for tag_elem in tag_elems: + tag_elems = node_elem.xpath( + './planetlab:attribute | ./attribute') + node['tags'] = [] + if len(tag_elems) > 0: + for tag_elem in tag_elems: tag = dict(tag_elem.get_instance(Attribute)) tag['tagname'] = tag.pop('name') node['tags'].append(tag) except: - pass - - return nodes + pass + return nodes @staticmethod def add_slivers(xml, slivers): @@ -173,7 +194,7 @@ class PGv2Node: sliver = {} elif 'component_id' in sliver and sliver['component_id']: filter['component_id'] = '*%s*' % sliver['component_id'] - if not filter: + if not filter: continue nodes = PGv2Node.get_nodes(xml, filter) if not nodes: @@ -184,18 +205,17 @@ class PGv2Node: @staticmethod def remove_slivers(xml, hostnames): for hostname in hostnames: - nodes = PGv2Node.get_nodes(xml, {'component_id': '*%s*' % hostname}) + nodes = PGv2Node.get_nodes( + xml, {'component_id': '*%s*' % hostname}) for node in nodes: slivers = PGv2SliverType.get_slivers(node.element) for sliver in slivers: - node.element.remove(sliver.element) + node.element.remove(sliver.element) if __name__ == '__main__': from sfa.rspecs.rspec import RSpec import pdb r = RSpec('/tmp/emulab.rspec') - r2 = RSpec(version = 'ProtoGENI') + r2 = RSpec(version='ProtoGENI') nodes = PGv2Node.get_nodes(r.xml) PGv2Node.add_nodes(r2.xml.root, nodes) - #pdb.set_trace() - - + # pdb.set_trace() diff --git a/sfa/rspecs/elements/versions/pgv2Services.py b/sfa/rspecs/elements/versions/pgv2Services.py index e0acc714..61e73aea 100644 --- a/sfa/rspecs/elements/versions/pgv2Services.py +++ b/sfa/rspecs/elements/versions/pgv2Services.py @@ -1,14 +1,16 @@ -from sfa.rspecs.elements.element import Element -from sfa.rspecs.elements.execute import Execute -from sfa.rspecs.elements.install import Install -from sfa.rspecs.elements.services import ServicesElement +from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.execute import Execute +from sfa.rspecs.elements.install import Install +from sfa.rspecs.elements.services import ServicesElement from sfa.rspecs.elements.login import Login + class PGv2Services: + @staticmethod def add_services(xml, services): if not services: - return + return for service in services: service_elem = xml.add_element('services') child_elements = {'install': Install.fields, @@ -16,7 +18,7 @@ class PGv2Services: 'login': Login.fields} for (name, fields) in child_elements.items(): child = service.get(name) - if not child: + if not child: continue if isinstance(child, dict): service_elem.add_instance(name, child, fields) @@ -33,21 +35,26 @@ class PGv2Services: # for key in ssh_user['keys']: # pkey_elem = ssh_user_elem.add_element('{%s}public_key' % xml.namespaces['ssh-user']) # pkey_elem.element.text=key - + @staticmethod def get_services(xml): services = [] for services_elem in xml.xpath('./default:services | ./services'): service = ServicesElement(services_elem.attrib, services_elem) - # get install - install_elems = services_elem.xpath('./default:install | ./install') - service['install'] = [install_elem.get_instance(Install) for install_elem in install_elems] + # get install + install_elems = services_elem.xpath( + './default:install | ./install') + service['install'] = [install_elem.get_instance( + Install) for install_elem in install_elems] # get execute - execute_elems = services_elem.xpath('./default:execute | ./execute') - service['execute'] = [execute_elem.get_instance(Execute) for execute_elem in execute_elems] + execute_elems = services_elem.xpath( + './default:execute | ./execute') + service['execute'] = [execute_elem.get_instance( + Execute) for execute_elem in execute_elems] # get login login_elems = services_elem.xpath('./default:login | ./login') - service['login'] = [login_elem.get_instance(Login) for login_elem in login_elems] + service['login'] = [login_elem.get_instance( + Login) for login_elem in login_elems] # ssh_user_elems = services_elem.xpath('./ssh-user:service_user | ./service_user') # services_user = [] @@ -55,6 +62,5 @@ class PGv2Services: # services_user = ssh_user_elem.get_instance(None, fields=['login', 'user_urn']) # service['services_user'] = services_user - services.append(service) + services.append(service) return services - diff --git a/sfa/rspecs/elements/versions/pgv2SliverType.py b/sfa/rspecs/elements/versions/pgv2SliverType.py index b8ca2d11..9a5e1d71 100644 --- a/sfa/rspecs/elements/versions/pgv2SliverType.py +++ b/sfa/rspecs/elements/versions/pgv2SliverType.py @@ -5,15 +5,16 @@ from sfa.rspecs.elements.versions.plosv1FWRule import PLOSv1FWRule from sfa.util.sfalogging import logger + class PGv2SliverType: @staticmethod def add_slivers(xml, slivers): if not slivers: - return + return if not isinstance(slivers, list): slivers = [slivers] - for sliver in slivers: + for sliver in slivers: sliver_elem = xml.add_element('sliver_type') if sliver.get('type'): sliver_elem.set('name', sliver['type']) @@ -21,15 +22,16 @@ class PGv2SliverType: for attr in attrs: if sliver.get(attr): sliver_elem.set(attr, sliver[attr]) - + images = sliver.get('disk_image') if images and isinstance(images, list): - PGv2DiskImage.add_images(sliver_elem, images) + PGv2DiskImage.add_images(sliver_elem, images) fw_rules = sliver.get('fw_rules') if fw_rules and isinstance(fw_rules, list): PLOSv1FWRule.add_rules(sliver_elem, fw_rules) - PGv2SliverType.add_sliver_attributes(sliver_elem, sliver.get('tags', [])) - + PGv2SliverType.add_sliver_attributes( + sliver_elem, sliver.get('tags', [])) + @staticmethod def add_sliver_attributes(xml, tags): if tags is None: @@ -37,32 +39,35 @@ class PGv2SliverType: for tag in tags: tagname = tag['tagname'] if 'tagname' in tag else tag['name'] if tagname == 'flack_info': - attrib_elem = xml.add_element('{%s}info' % self.namespaces['flack']) + attrib_elem = xml.add_element( + '{%s}info' % self.namespaces['flack']) try: attrib_dict = eval(tag['value']) for (key, value) in attrib_dict.items(): attrib_elem.set(key, value) except Exception as e: - logger.warning("Could not parse dictionary in flack tag -- {}".format(e)) + logger.warning( + "Could not parse dictionary in flack tag -- {}".format(e)) elif tagname == 'initscript': xml.add_element('{%s}initscript' % xml.namespaces['planetlab'], name=tag['value']) else: xml.add_element('{%s}attribute' % (xml.namespaces['planetlab']), - name = tagname, - value = tag['value'], - scope = tag.get('scope', 'unknown'), - ) - + name=tagname, + value=tag['value'], + scope=tag.get('scope', 'unknown'), + ) + @staticmethod def get_slivers(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './default:sliver_type | ./sliver_type' sliver_elems = xml.xpath(xpath) slivers = [] for sliver_elem in sliver_elems: - sliver = Sliver(sliver_elem.attrib,sliver_elem) - if 'component_id' in xml.attrib: + sliver = Sliver(sliver_elem.attrib, sliver_elem) + if 'component_id' in xml.attrib: sliver['component_id'] = xml.attrib['component_id'] if 'name' in sliver_elem.attrib: sliver['type'] = sliver_elem.attrib['name'] @@ -73,5 +78,6 @@ class PGv2SliverType: @staticmethod def get_sliver_attributes(xml, filter=None): - if filter is None: filter={} - return [] + if filter is None: + filter = {} + return [] diff --git a/sfa/rspecs/elements/versions/plosv1FWRule.py b/sfa/rspecs/elements/versions/plosv1FWRule.py index 744a36f3..e62a34f2 100644 --- a/sfa/rspecs/elements/versions/plosv1FWRule.py +++ b/sfa/rspecs/elements/versions/plosv1FWRule.py @@ -1,11 +1,13 @@ -from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.fw_rule import FWRule + class PLOSv1FWRule: + @staticmethod def add_rules(xml, rules): if not rules: - return + return for rule in rules: rule_elem = xml.add_element('{%s}fw_rule' % xml.namespaces['plos']) rule_elem.set('protocol', rule.get('protocol')) @@ -13,13 +15,12 @@ class PLOSv1FWRule: rule_elem.set('cidr_ip', rule.get('cidr_ip')) if rule.get('icmp_type_code'): rule_elem.set('icmp_type_code', rule.get('icmp_type_code')) - + @staticmethod def get_rules(xml): rules = [] - if 'plos' in xml.namespaces: + if 'plos' in xml.namespaces: for rule_elem in xml.xpath('./plos:fw_rule | ./fw_rule'): rule = FWRule(rule_elem.attrib, rule_elem) - rules.append(rule) + rules.append(rule) return rules - diff --git a/sfa/rspecs/elements/versions/sfav1Lease.py b/sfa/rspecs/elements/versions/sfav1Lease.py index 0c7cb26b..2a14b47b 100644 --- a/sfa/rspecs/elements/versions/sfav1Lease.py +++ b/sfa/rspecs/elements/versions/sfav1Lease.py @@ -22,59 +22,61 @@ class SFAv1Lease: @staticmethod def add_leases(xml, leases): - + network_elems = xml.xpath('//network') if len(network_elems) > 0: network_elem = network_elems[0] elif len(leases) > 0: - network_urn = Xrn(leases[0]['component_id']).get_authority_urn().split(':')[0] - network_elem = xml.add_element('network', name = network_urn) + network_urn = Xrn(leases[0]['component_id'] + ).get_authority_urn().split(':')[0] + network_elem = xml.add_element('network', name=network_urn) else: network_elem = xml - + # group the leases by slice and timeslots grouped_leases = [] while leases: - slice_id = leases[0]['slice_id'] - start_time = leases[0]['start_time'] - duration = leases[0]['duration'] - group = [] + slice_id = leases[0]['slice_id'] + start_time = leases[0]['start_time'] + duration = leases[0]['duration'] + group = [] - for lease in leases: - if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: - group.append(lease) + for lease in leases: + if slice_id == lease['slice_id'] and start_time == lease['start_time'] and duration == lease['duration']: + group.append(lease) - grouped_leases.append(group) + grouped_leases.append(group) - for lease1 in group: - leases.remove(lease1) + for lease1 in group: + leases.remove(lease1) lease_elems = [] for lease in grouped_leases: #lease[0]['start_time'] = datetime_to_string(utcparse(lease[0]['start_time'])) lease_fields = ['slice_id', 'start_time', 'duration'] - lease_elem = network_elem.add_instance('lease', lease[0], lease_fields) + lease_elem = network_elem.add_instance( + 'lease', lease[0], lease_fields) lease_elems.append(lease_elem) # add nodes of this lease for node in lease: - lease_elem.add_instance('node', node, ['component_id']) + lease_elem.add_instance('node', node, ['component_id']) - -# lease_elems = [] +# lease_elems = [] # for lease in leases: # lease_fields = ['lease_id', 'component_id', 'slice_id', 'start_time', 'duration'] # lease_elem = network_elem.add_instance('lease', lease, lease_fields) # lease_elems.append(lease_elem) - @staticmethod def get_leases(xml, filter=None): - if filter is None: filter={} - xpath = '//lease%s | //default:lease%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//lease%s | //default:lease%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) lease_elems = xml.xpath(xpath) return SFAv1Lease.get_lease_objs(lease_elems) @@ -82,24 +84,21 @@ class SFAv1Lease: def get_lease_objs(lease_elems): leases = [] for lease_elem in lease_elems: - #get nodes + # get nodes node_elems = lease_elem.xpath('./default:node | ./node') for node_elem in node_elems: - lease = Lease(lease_elem.attrib, lease_elem) - lease['slice_id'] = lease_elem.attrib['slice_id'] - #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) - lease['start_time'] = lease_elem.attrib['start_time'] - lease['duration'] = lease_elem.attrib['duration'] - lease['component_id'] = node_elem.attrib['component_id'] - leases.append(lease) + lease = Lease(lease_elem.attrib, lease_elem) + lease['slice_id'] = lease_elem.attrib['slice_id'] + #lease['start_time'] = datetime_to_epoch(utcparse(lease_elem.attrib['start_time'])) + lease['start_time'] = lease_elem.attrib['start_time'] + lease['duration'] = lease_elem.attrib['duration'] + lease['component_id'] = node_elem.attrib['component_id'] + leases.append(lease) return leases - - - -# leases = [] +# leases = [] # for lease_elem in lease_elems: # lease = Lease(lease_elem.attrib, lease_elem) # if lease.get('lease_id'): @@ -110,5 +109,4 @@ class SFAv1Lease: # lease['duration'] = lease_elem.attrib['duration'] # leases.append(lease) -# return leases - +# return leases diff --git a/sfa/rspecs/elements/versions/sfav1Node.py b/sfa/rspecs/elements/versions/sfav1Node.py index 1931a584..23defad2 100644 --- a/sfa/rspecs/elements/versions/sfav1Node.py +++ b/sfa/rspecs/elements/versions/sfav1Node.py @@ -25,24 +25,27 @@ class SFAv1Node: network_elem = network_elems[0] elif len(nodes) > 0 and nodes[0].get('component_manager_id'): network_urn = nodes[0]['component_manager_id'] - network_elem = xml.add_element('network', name = Xrn(network_urn).get_hrn()) + network_elem = xml.add_element( + 'network', name=Xrn(network_urn).get_hrn()) else: network_elem = xml - node_elems = [] + node_elems = [] for node in nodes: - node_fields = ['component_manager_id', 'component_id', 'boot_state'] + node_fields = ['component_manager_id', + 'component_id', 'boot_state'] node_elem = network_elem.add_instance('node', node, node_fields) node_elems.append(node_elem) # determine network hrn - network_hrn = None + network_hrn = None if 'component_manager_id' in node and node['component_manager_id']: network_hrn = Xrn(node['component_manager_id']).get_hrn() # set component_name attribute and hostname element if 'component_id' in node and node['component_id']: - component_name = Xrn.unescape(get_leaf(Xrn(node['component_id']).get_hrn())) + component_name = Xrn.unescape( + get_leaf(Xrn(node['component_id']).get_hrn())) node_elem.set('component_name', component_name) hostname_elem = node_elem.add_element('hostname') hostname_elem.set_text(component_name) @@ -56,23 +59,25 @@ class SFAv1Node: if location: node_elem.add_instance('location', location, Location.fields) - # add exclusive tag to distinguish between Reservable and Shared nodes + # add exclusive tag to distinguish between Reservable and Shared + # nodes exclusive_elem = node_elem.add_element('exclusive') if node.get('exclusive') and node.get('exclusive') == 'true': exclusive_elem.set_text('TRUE') # add granularity of the reservation system granularity = node.get('granularity') if granularity: - node_elem.add_instance('granularity', granularity, granularity.fields) + node_elem.add_instance( + 'granularity', granularity, granularity.fields) else: exclusive_elem.set_text('FALSE') - if isinstance(node.get('interfaces'), list): for interface in node.get('interfaces', []): - node_elem.add_instance('interface', interface, ['component_id', 'client_id', 'ipv4']) - - #if 'bw_unallocated' in node and node['bw_unallocated']: + node_elem.add_instance('interface', interface, [ + 'component_id', 'client_id', 'ipv4']) + + # if 'bw_unallocated' in node and node['bw_unallocated']: # bw_unallocated = etree.SubElement(node_elem, 'bw_unallocated', units='kbps').text = str(int(node['bw_unallocated'])/1000) PGv2Services.add_services(node_elem, node.get('services', [])) @@ -81,12 +86,13 @@ class SFAv1Node: for tag in tags: # backdoor for FITeagle # Alexander Willner - if tag['tagname']=="fiteagle_settings": + if tag['tagname'] == "fiteagle_settings": tag_elem = node_elem.add_element(tag['tagname']) for subtag in tag['value']: subtag_elem = tag_elem.add_element('setting') subtag_elem.set('name', str(subtag['tagname'])) - subtag_elem.set('description', str(subtag['description'])) + subtag_elem.set('description', str( + subtag['description'])) subtag_elem.set_text(subtag['value']) else: tag_elem = node_elem.add_element(tag['tagname']) @@ -95,9 +101,9 @@ class SFAv1Node: # add sliver tag in Request Rspec if rspec_content_type == "request": - node_elem.add_instance('sliver', '', []) + node_elem.add_instance('sliver', '', []) - @staticmethod + @staticmethod def add_slivers(xml, slivers): component_ids = [] for sliver in slivers: @@ -108,7 +114,7 @@ class SFAv1Node: elif 'component_id' in sliver and sliver['component_id']: filter['component_id'] = '*%s*' % sliver['component_id'] if not filter: - continue + continue nodes = SFAv1Node.get_nodes(xml, filter) if not nodes: continue @@ -118,56 +124,64 @@ class SFAv1Node: @staticmethod def remove_slivers(xml, hostnames): for hostname in hostnames: - nodes = SFAv1Node.get_nodes(xml, {'component_id': '*%s*' % hostname}) + nodes = SFAv1Node.get_nodes( + xml, {'component_id': '*%s*' % hostname}) for node in nodes: slivers = SFAv1Sliver.get_slivers(node.element) for sliver in slivers: node.element.remove(sliver.element) - + @staticmethod def get_nodes(xml, filter=None): - if filter is None: filter={} - xpath = '//node%s | //default:node%s' % (XpathFilter.xpath(filter), XpathFilter.xpath(filter)) + if filter is None: + filter = {} + xpath = '//node%s | //default:node%s' % ( + XpathFilter.xpath(filter), XpathFilter.xpath(filter)) node_elems = xml.xpath(xpath) return SFAv1Node.get_node_objs(node_elems) @staticmethod def get_nodes_with_slivers(xml): - xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' + xpath = '//node[count(sliver)>0] | //default:node[count(default:sliver)>0]' node_elems = xml.xpath(xpath) return SFAv1Node.get_node_objs(node_elems) - @staticmethod def get_node_objs(node_elems): - nodes = [] + nodes = [] for node_elem in node_elems: node = NodeElement(node_elem.attrib, node_elem) if 'site_id' in node_elem.attrib: node['authority_id'] = node_elem.attrib['site_id'] # get location location_elems = node_elem.xpath('./default:location | ./location') - locations = [dict(loc_elem.get_instance(Location)) for loc_elem in location_elems] + locations = [dict(loc_elem.get_instance(Location)) + for loc_elem in location_elems] if len(locations) > 0: node['location'] = locations[0] # get bwlimit bwlimit_elems = node_elem.xpath('./default:bw_limit | ./bw_limit') - bwlimits = [bwlimit_elem.get_instance(BWlimit) for bwlimit_elem in bwlimit_elems] + bwlimits = [bwlimit_elem.get_instance( + BWlimit) for bwlimit_elem in bwlimit_elems] if len(bwlimits) > 0: node['bwlimit'] = bwlimits[0] # get interfaces iface_elems = node_elem.xpath('./default:interface | ./interface') - ifaces = [dict(iface_elem.get_instance(Interface)) for iface_elem in iface_elems] + ifaces = [dict(iface_elem.get_instance(Interface)) + for iface_elem in iface_elems] node['interfaces'] = ifaces # get services - node['services'] = PGv2Services.get_services(node_elem) + node['services'] = PGv2Services.get_services(node_elem) # get slivers node['slivers'] = SFAv1Sliver.get_slivers(node_elem) # get tags - node['tags'] = SFAv1PLTag.get_pl_tags(node_elem, ignore=NodeElement.fields+["hardware_type"]) + node['tags'] = SFAv1PLTag.get_pl_tags( + node_elem, ignore=NodeElement.fields + ["hardware_type"]) # get hardware types - hardware_type_elems = node_elem.xpath('./default:hardware_type | ./hardware_type') - node['hardware_types'] = [dict(hw_type.get_instance(HardwareType)) for hw_type in hardware_type_elems] + hardware_type_elems = node_elem.xpath( + './default:hardware_type | ./hardware_type') + node['hardware_types'] = [dict(hw_type.get_instance( + HardwareType)) for hw_type in hardware_type_elems] # temporary... play nice with old slice manager rspec if not node['component_name']: @@ -176,5 +190,4 @@ class SFAv1Node: node['component_name'] = hostname_elem.text nodes.append(node) - return nodes - + return nodes diff --git a/sfa/rspecs/elements/versions/sfav1PLTag.py b/sfa/rspecs/elements/versions/sfav1PLTag.py index 907c962a..2666e9c1 100644 --- a/sfa/rspecs/elements/versions/sfav1PLTag.py +++ b/sfa/rspecs/elements/versions/sfav1PLTag.py @@ -1,20 +1,22 @@ -from sfa.rspecs.elements.element import Element +from sfa.rspecs.elements.element import Element from sfa.rspecs.elements.pltag import PLTag + class SFAv1PLTag: + @staticmethod def add_pl_tag(xml, name, value): for pl_tag in pl_tags: pl_tag_elem = xml.add_element(name) pl_tag_elem.set_text(value) - + @staticmethod def get_pl_tags(xml, ignore=None): - if ignore is None: ignore=[] + if ignore is None: + ignore = [] pl_tags = [] for elem in xml.iterchildren(): if elem.tag not in ignore: pl_tag = PLTag({'tagname': elem.tag, 'value': elem.text}) - pl_tags.append(dict(pl_tag)) + pl_tags.append(dict(pl_tag)) return pl_tags - diff --git a/sfa/rspecs/elements/versions/sfav1Sliver.py b/sfa/rspecs/elements/versions/sfav1Sliver.py index 7e9282f2..a9d17e2c 100644 --- a/sfa/rspecs/elements/versions/sfav1Sliver.py +++ b/sfa/rspecs/elements/versions/sfav1Sliver.py @@ -19,13 +19,14 @@ class SFAv1Sliver: tags = sliver.get('tags', []) if tags: for tag in tags: - SFAv1Sliver.add_sliver_attribute(sliver_elem, tag['tagname'], tag['value']) + SFAv1Sliver.add_sliver_attribute( + sliver_elem, tag['tagname'], tag['value']) @staticmethod def add_sliver_attribute(xml, name, value): elem = xml.add_element(name) elem.set_text(value) - + @staticmethod def get_sliver_attributes(xml): attribs = [] @@ -36,19 +37,19 @@ class SFAv1Sliver: instance['name'] = elem.tag instance['value'] = elem.text attribs.append(instance) - return attribs - + return attribs + @staticmethod def get_slivers(xml, filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = './default:sliver | ./sliver' sliver_elems = xml.xpath(xpath) slivers = [] for sliver_elem in sliver_elems: - sliver = Sliver(sliver_elem.attrib,sliver_elem) - if 'component_id' in xml.attrib: + sliver = Sliver(sliver_elem.attrib, sliver_elem) + if 'component_id' in xml.attrib: sliver['component_id'] = xml.attrib['component_id'] sliver['tags'] = SFAv1Sliver.get_sliver_attributes(sliver_elem) slivers.append(sliver) - return slivers - + return slivers diff --git a/sfa/rspecs/pg_rspec_converter.py b/sfa/rspecs/pg_rspec_converter.py index ef021c00..7de30efd 100755 --- a/sfa/rspecs/pg_rspec_converter.py +++ b/sfa/rspecs/pg_rspec_converter.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python from __future__ import print_function from lxml import etree @@ -8,7 +8,7 @@ from sfa.rspecs.version_manager import VersionManager from sfa.util.py23 import StringIO -xslt=''' +xslt = ''' @@ -31,47 +31,49 @@ xslt=''' 1: - print(PGRSpecConverter.to_sfa_rspec(sys.argv[1])) + if len(sys.argv) > 1: + print(PGRSpecConverter.to_sfa_rspec(sys.argv[1])) diff --git a/sfa/rspecs/rspec.py b/sfa/rspecs/rspec.py index 372a8352..5a3bc58b 100755 --- a/sfa/rspecs/rspec.py +++ b/sfa/rspecs/rspec.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python from __future__ import print_function @@ -8,13 +8,15 @@ from sfa.util.xml import XML, XpathFilter from sfa.util.faults import InvalidRSpecElement, InvalidRSpec from sfa.util.sfatime import SFATIME_FORMAT -from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements +from sfa.rspecs.rspec_elements import RSpecElement, RSpecElements from sfa.rspecs.version_manager import VersionManager + class RSpec: - + def __init__(self, rspec="", version=None, user_options=None, ttl=None, expires=None): - if user_options is None: user_options={} + if user_options is None: + user_options = {} self.header = '\n' self.template = """""" self.version = None @@ -33,7 +35,8 @@ class RSpec: elif version: self.create(version, ttl, expires) else: - raise InvalidRSpec("No RSpec or version specified. Must specify a valid rspec string or a valid version") + raise InvalidRSpec( + "No RSpec or version specified. Must specify a valid rspec string or a valid version") def create(self, version=None, ttl=None, expires=None): """ @@ -42,15 +45,16 @@ class RSpec: """ self.version = self.version_manager.get_version(version) self.namespaces = self.version.namespaces - self.parse_xml(self.version.template, self.version) + self.parse_xml(self.version.template, self.version) now = datetime.utcnow() generated_ts = now.strftime(SFATIME_FORMAT) if ttl is None: ttl = 60 if expires is None: - expires_ts = (now + timedelta(minutes=ttl)).strftime(SFATIME_FORMAT) + expires_ts = (now + timedelta(minutes=ttl) + ).strftime(SFATIME_FORMAT) else: - if isinstance(expires,int): + if isinstance(expires, int): expires_date = datetime.fromtimestamp(expires) else: expires_date = expires @@ -62,15 +66,17 @@ class RSpec: self.xml.parse_xml(xml) if not version: if self.xml.schema: - self.version = self.version_manager.get_version_by_schema(self.xml.schema) + self.version = self.version_manager.get_version_by_schema( + self.xml.schema) else: #raise InvalidRSpec('unknown rspec schema: {}'.format(schema)) # TODO: Should start raising an exception once SFA defines a schema. - # for now we just default to sfa - self.version = self.version_manager.get_version({'type':'sfa','version': '1'}) - self.version.xml = self.xml + # for now we just default to sfa + self.version = self.version_manager.get_version( + {'type': 'sfa', 'version': '1'}) + self.version.xml = self.xml self.namespaces = self.xml.namespaces - + def load_rspec_elements(self, rspec_elements): self.elements = {} for rspec_element in rspec_elements: @@ -81,25 +87,30 @@ class RSpec: if element_type not in RSpecElements: raise InvalidRSpecElement(element_type, extra="no such element type: {}. Must specify a valid RSpecElement".format(element_type)) - self.elements[element_type] = RSpecElement(element_type, element_name, element_path) + self.elements[element_type] = RSpecElement( + element_type, element_name, element_path) def get_rspec_element(self, element_type): if element_type not in self.elements: - msg = "ElementType {} not registered for this rspec".format(element_type) + msg = "ElementType {} not registered for this rspec".format( + element_type) raise InvalidRSpecElement(element_type, extra=msg) return self.elements[element_type] def get(self, element_type, filter=None, depth=0): - if filter is None: filter={} + if filter is None: + filter = {} elements = self.get_elements(element_type, filter) - elements = [self.xml.get_element_attributes(elem, depth=depth) for elem in elements] + elements = [self.xml.get_element_attributes( + elem, depth=depth) for elem in elements] return elements def get_elements(self, element_type, filter=None): """ search for a registered element """ - if filter is None: filter={} + if filter is None: + filter = {} if element_type not in self.elements: msg = "Unable to search for element {} in rspec, expath expression not found."\ .format(element_type) @@ -112,25 +123,23 @@ class RSpec: self.version.merge(in_rspec) def filter(self, filter): - if 'component_manager_id' in filter: + if 'component_manager_id' in filter: nodes = self.version.get_nodes() for node in nodes: if 'component_manager_id' not in node.attrib or \ - node.attrib['component_manager_id'] != filter['component_manager_id']: + node.attrib['component_manager_id'] != filter['component_manager_id']: parent = node.getparent() - parent.remove(node.element) - + parent.remove(node.element) def toxml(self, header=True): if header: return self.header + self.xml.toxml() else: return self.xml.toxml() - def save(self, filename): return self.xml.save(filename) - + if __name__ == '__main__': import sys input = sys.argv[1] @@ -141,4 +150,3 @@ if __name__ == '__main__': # rspec.register_rspec_element(RSpecElements.NODE, 'node', '//node') # print rspec.get(RSpecElements.NODE)[0] # print rspec.get(RSpecElements.NODE, depth=1)[0] - diff --git a/sfa/rspecs/rspec_converter.py b/sfa/rspecs/rspec_converter.py index 5d7bdd61..a24b2def 100755 --- a/sfa/rspecs/rspec_converter.py +++ b/sfa/rspecs/rspec_converter.py @@ -7,6 +7,7 @@ from sfa.rspecs.sfa_rspec_converter import SfaRSpecConverter from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager + class RSpecConverter: @staticmethod @@ -15,34 +16,34 @@ class RSpecConverter: version_manager = VersionManager() sfa_version = version_manager._get_version('sfa', '1') pg_version = version_manager._get_version('protogeni', '2') - if rspec.version.type.lower() == sfa_version.type.lower(): - return in_rspec - elif rspec.version.type.lower() == pg_version.type.lower(): + if rspec.version.type.lower() == sfa_version.type.lower(): + return in_rspec + elif rspec.version.type.lower() == pg_version.type.lower(): return PGRSpecConverter.to_sfa_rspec(in_rspec, content_type) else: - return in_rspec + return in_rspec - @staticmethod + @staticmethod def to_pg_rspec(in_rspec, content_type=None): rspec = RSpec(in_rspec) version_manager = VersionManager() sfa_version = version_manager._get_version('sfa', '1') pg_version = version_manager._get_version('protogeni', '2') - if rspec.version.type.lower() == pg_version.type.lower(): + if rspec.version.type.lower() == pg_version.type.lower(): return in_rspec - elif rspec.version.type.lower() == sfa_version.type.lower(): + elif rspec.version.type.lower() == sfa_version.type.lower(): return SfaRSpecConverter.to_pg_rspec(in_rspec, content_type) else: - return in_rspec + return in_rspec if __name__ == '__main__': pg_rspec = 'test/protogeni.rspec' - sfa_rspec = 'test/nodes.rspec' + sfa_rspec = 'test/nodes.rspec' print("converting pg rspec to sfa rspec") print(RSpecConverter.to_sfa_rspec(pg_rspec)) - + print("converting sfa rspec to pg rspec") - print(RSpecConverter.to_pg_rspec(sfa_rspec)) + print(RSpecConverter.to_pg_rspec(sfa_rspec)) diff --git a/sfa/rspecs/rspec_elements.py b/sfa/rspecs/rspec_elements.py index 7f79f68a..1b69f5a4 100644 --- a/sfa/rspecs/rspec_elements.py +++ b/sfa/rspecs/rspec_elements.py @@ -5,29 +5,31 @@ RSpecElements = Enum( AVAILABLE='AVAILABLE', BWLIMIT='BWLIMIT', EXECUTE='EXECUTE', - NETWORK='NETWORK', + NETWORK='NETWORK', COMPONENT_MANAGER='COMPONENT_MANAGER', - HARDWARE_TYPE='HARDWARE_TYPE', - INSTALL='INSTALL', - INTERFACE='INTERFACE', + HARDWARE_TYPE='HARDWARE_TYPE', + INSTALL='INSTALL', + INTERFACE='INTERFACE', INTERFACE_REF='INTERFACE_REF', - LOCATION='LOCATION', - LOGIN='LOGIN', - LINK='LINK', - LINK_TYPE='LINK_TYPE', - NODE='NODE', + LOCATION='LOCATION', + LOGIN='LOGIN', + LINK='LINK', + LINK_TYPE='LINK_TYPE', + NODE='NODE', PROPERTY='PROPERTY', SERVICES='SERVICES', - SLIVER='SLIVER', - SLIVER_TYPE='SLIVER_TYPE', + SLIVER='SLIVER', + SLIVER_TYPE='SLIVER_TYPE', LEASE='LEASE', GRANULARITY='GRANULARITY', SPECTRUM='SPECTRUM', CHANNEL='CHANNEL', - POSITION_3D ='POSITION_3D', + POSITION_3D='POSITION_3D', ) + class RSpecElement: + def __init__(self, element_type, path): if not element_type in RSpecElements: raise InvalidRSpecElement(element_type) diff --git a/sfa/rspecs/sfa_rspec_converter.py b/sfa/rspecs/sfa_rspec_converter.py index 577b788d..62760048 100755 --- a/sfa/rspecs/sfa_rspec_converter.py +++ b/sfa/rspecs/sfa_rspec_converter.py @@ -6,83 +6,96 @@ from sfa.util.xrn import hrn_to_urn from sfa.rspecs.rspec import RSpec from sfa.rspecs.version_manager import VersionManager + class SfaRSpecConverter: @staticmethod - def to_pg_rspec(rspec, content_type = None): + def to_pg_rspec(rspec, content_type=None): if not isinstance(rspec, RSpec): sfa_rspec = RSpec(rspec) else: sfa_rspec = rspec - + if not content_type or content_type not in \ - ['ad', 'request', 'manifest']: + ['ad', 'request', 'manifest']: content_type = sfa_rspec.version.content_type - - + version_manager = VersionManager() pg_version = version_manager._get_version('protogeni', '2', 'request') pg_rspec = RSpec(version=pg_version) - + # get networks networks = sfa_rspec.version.get_networks() - + for network in networks: # get nodes - sfa_node_elements = sfa_rspec.version.get_node_elements(network=network) + sfa_node_elements = sfa_rspec.version.get_node_elements( + network=network) for sfa_node_element in sfa_node_elements: # create node element node_attrs = {} node_attrs['exclusive'] = 'false' if 'component_manager_id' in sfa_node_element.attrib: - node_attrs['component_manager_id'] = sfa_node_element.attrib['component_manager_id'] + node_attrs['component_manager_id'] = sfa_node_element.attrib[ + 'component_manager_id'] else: - node_attrs['component_manager_id'] = hrn_to_urn(network, 'authority+cm') + node_attrs['component_manager_id'] = hrn_to_urn( + network, 'authority+cm') if 'component_id' in sfa_node_element.attrib: - node_attrs['compoenent_id'] = sfa_node_element.attrib['component_id'] + node_attrs['compoenent_id'] = sfa_node_element.attrib[ + 'component_id'] if sfa_node_element.find('hostname') != None: hostname = sfa_node_element.find('hostname').text node_attrs['component_name'] = hostname node_attrs['client_id'] = hostname - node_element = pg_rspec.xml.add_element('node', node_attrs) - + node_element = pg_rspec.xml.add_element('node', node_attrs) + if content_type == 'request': sliver_element = sfa_node_element.find('sliver') - sliver_type_elements = sfa_node_element.xpath('./sliver_type', namespaces=sfa_rspec.namespaces) - available_sliver_types = [element.attrib['name'] for element in sliver_type_elements] + sliver_type_elements = sfa_node_element.xpath( + './sliver_type', namespaces=sfa_rspec.namespaces) + available_sliver_types = [element.attrib[ + 'name'] for element in sliver_type_elements] valid_sliver_types = ['emulab-openvz', 'raw-pc'] - - # determine sliver type + + # determine sliver type requested_sliver_type = 'emulab-openvz' for available_sliver_type in available_sliver_types: if available_sliver_type in valid_sliver_types: requested_sliver_type = available_sliver_type - + if sliver_element != None: - pg_rspec.xml.add_element('sliver_type', {'name': requested_sliver_type}, parent=node_element) + pg_rspec.xml.add_element( + 'sliver_type', {'name': requested_sliver_type}, parent=node_element) else: # create node_type element for hw_type in ['plab-pc', 'pc']: - hdware_type_element = pg_rspec.xml.add_element('hardware_type', {'name': hw_type}, parent=node_element) + hdware_type_element = pg_rspec.xml.add_element( + 'hardware_type', {'name': hw_type}, parent=node_element) # create available element - pg_rspec.xml.add_element('available', {'now': 'true'}, parent=node_element) + pg_rspec.xml.add_element( + 'available', {'now': 'true'}, parent=node_element) # create locaiton element - # We don't actually associate nodes with a country. + # We don't actually associate nodes with a country. # Set country to "unknown" until we figure out how to make # sure this value is always accurate. location = sfa_node_element.find('location') if location != None: - location_attrs = {} - location_attrs['country'] = location.get('country', 'unknown') - location_attrs['latitude'] = location.get('latitude', 'None') - location_attrs['longitude'] = location.get('longitude', 'None') - pg_rspec.xml.add_element('location', location_attrs, parent=node_element) + location_attrs = {} + location_attrs['country'] = location.get( + 'country', 'unknown') + location_attrs['latitude'] = location.get( + 'latitude', 'None') + location_attrs['longitude'] = location.get( + 'longitude', 'None') + pg_rspec.xml.add_element( + 'location', location_attrs, parent=node_element) return pg_rspec.toxml() if __name__ == '__main__': import sys - if len(sys.argv) > 1: + if len(sys.argv) > 1: print(SfaRSpecConverter.to_pg_rspec(sys.argv[1])) diff --git a/sfa/rspecs/version.py b/sfa/rspecs/version.py index 91c96e09..7faf2abe 100644 --- a/sfa/rspecs/version.py +++ b/sfa/rspecs/version.py @@ -1,6 +1,7 @@ #!/usr/bin/python from sfa.util.sfalogging import logger + class RSpecVersion: type = None content_type = None @@ -26,5 +27,3 @@ class RSpecVersion: def __str__(self): return "%s %s" % (self.type, self.version) - - diff --git a/sfa/rspecs/version_manager.py b/sfa/rspecs/version_manager.py index e0a604e4..e71c7ca6 100644 --- a/sfa/rspecs/version_manager.py +++ b/sfa/rspecs/version_manager.py @@ -2,11 +2,13 @@ from __future__ import print_function import os from sfa.util.faults import InvalidRSpec, UnsupportedRSpecVersion -from sfa.rspecs.version import RSpecVersion -from sfa.util.sfalogging import logger +from sfa.rspecs.version import RSpecVersion +from sfa.util.sfalogging import logger from sfa.util.py23 import StringType + class VersionManager: + def __init__(self): self.versions = [] self.load_versions() @@ -14,18 +16,18 @@ class VersionManager: def __repr__(self): return ""\ .format(len(self.versions), - ", ".join( [ str(x) for x in self.versions ])) - + ", ".join([str(x) for x in self.versions])) + def load_versions(self): - path = os.path.dirname(os.path.abspath( __file__ )) + path = os.path.dirname(os.path.abspath(__file__)) versions_path = path + os.sep + 'versions' versions_module_path = 'sfa.rspecs.versions' valid_module = lambda x: os.path.isfile(os.sep.join([versions_path, x])) \ - and x.endswith('.py') and x != '__init__.py' + and x.endswith('.py') and x != '__init__.py' files = [f for f in os.listdir(versions_path) if valid_module(f)] for filename in files: basename = filename.split('.')[0] - module_path = versions_module_path +'.'+basename + module_path = versions_module_path + '.' + basename module = __import__(module_path, fromlist=module_path) for attr_name in dir(module): attr = getattr(module, attr_name) @@ -38,20 +40,23 @@ class VersionManager: if type is None or type.lower() == version.type.lower(): if version_num is None or str(float(version_num)) == str(float(version.version)): if content_type is None or content_type.lower() == version.content_type.lower() \ - or version.content_type == '*': + or version.content_type == '*': retval = version - ### sounds like we should be glad with the first match, not the last one + # sounds like we should be glad with the first match, + # not the last one break if not retval: - raise UnsupportedRSpecVersion("[%s %s %s] is not suported here"% (type, version_num, content_type)) + raise UnsupportedRSpecVersion( + "[%s %s %s] is not suported here" % (type, version_num, content_type)) return retval def get_version(self, version=None): retval = None if isinstance(version, dict): - retval = self._get_version(version.get('type'), version.get('version'), version.get('content_type')) + retval = self._get_version(version.get('type'), version.get( + 'version'), version.get('content_type')) elif isinstance(version, StringType): - version_parts = version.split(' ') + version_parts = version.split(' ') num_parts = len(version_parts) type = version_parts[0] version_num = None @@ -60,14 +65,15 @@ class VersionManager: version_num = version_parts[1] if num_parts > 2: content_type = version_parts[2] - retval = self._get_version(type, version_num, content_type) + retval = self._get_version(type, version_num, content_type) elif isinstance(version, RSpecVersion): retval = version elif not version: retval = self.versions[0] else: - raise UnsupportedRSpecVersion("No such version: %s "% str(version)) - + raise UnsupportedRSpecVersion( + "No such version: %s " % str(version)) + return retval def get_version_by_schema(self, schema): @@ -94,9 +100,8 @@ class VersionManager: if __name__ == '__main__': manager = VersionManager() print(manager) - manager.show_by_string('sfa 1') - manager.show_by_string('protogeni 2') - manager.show_by_string('protogeni 2 advertisement') - manager.show_by_schema('http://www.protogeni.net/resources/rspec/2/ad.xsd') + manager.show_by_string('sfa 1') + manager.show_by_string('protogeni 2') + manager.show_by_string('protogeni 2 advertisement') + manager.show_by_schema('http://www.protogeni.net/resources/rspec/2/ad.xsd') manager.show_by_schema('http://sorch.netmode.ntua.gr/ws/RSpec/ad.xsd') - diff --git a/sfa/rspecs/versions/federica.py b/sfa/rspecs/versions/federica.py index 8ff9f5e8..4e245872 100644 --- a/sfa/rspecs/versions/federica.py +++ b/sfa/rspecs/versions/federica.py @@ -1,5 +1,6 @@ from sfa.rspecs.versions.pgv2 import PGv2 + class FedericaAd (PGv2): enabled = True type = 'Fedrica' @@ -7,6 +8,7 @@ class FedericaAd (PGv2): schema = 'http://sorch.netmode.ntua.gr/ws/RSpec/ad.xsd' namespace = 'http://sorch.netmode.ntua.gr/ws/RSpec' + class FedericaRequest (PGv2): enabled = True type = 'Fedrica' @@ -14,10 +16,10 @@ class FedericaRequest (PGv2): schema = 'http://sorch.netmode.ntua.gr/ws/RSpec/request.xsd' namespace = 'http://sorch.netmode.ntua.gr/ws/RSpec' + class FedericaManifest (PGv2): enabled = True type = 'Fedrica' content_type = 'manifest' schema = 'http://sorch.netmode.ntua.gr/ws/RSpec/manifest.xsd' namespace = 'http://sorch.netmode.ntua.gr/ws/RSpec' - diff --git a/sfa/rspecs/versions/iotlabv1.py b/sfa/rspecs/versions/iotlabv1.py index 0d393237..5d4383c8 100644 --- a/sfa/rspecs/versions/iotlabv1.py +++ b/sfa/rspecs/versions/iotlabv1.py @@ -37,11 +37,11 @@ class Iotlabv1(RSpecVersion): # Network def get_networks(self): - #WARNING Added //default:network to the xpath - #otherwise network element not detected 16/07/12 SA + # WARNING Added //default:network to the xpath + # otherwise network element not detected 16/07/12 SA network_elems = self.xml.xpath('//network | //default:network') - networks = [network_elem.get_instance(fields=['name', 'slice']) for \ + networks = [network_elem.get_instance(fields=['name', 'slice']) for network_elem in network_elems] return networks @@ -60,33 +60,35 @@ class Iotlabv1(RSpecVersion): def get_nodes_with_slivers(self): return Iotlabv1Node.get_nodes_with_slivers(self.xml) - def get_slice_timeslot(self ): + def get_slice_timeslot(self): return Iotlabv1Timeslot.get_slice_timeslot(self.xml) def add_connection_information(self, ldap_username, sites_set): - return Iotlabv1Node.add_connection_information(self.xml,ldap_username, sites_set) + return Iotlabv1Node.add_connection_information(self.xml, ldap_username, sites_set) def add_nodes(self, nodes, check_for_dupes=False, rspec_content_type=None): - return Iotlabv1Node.add_nodes(self.xml,nodes, rspec_content_type) + return Iotlabv1Node.add_nodes(self.xml, nodes, rspec_content_type) - def merge_node(self, source_node_tag, network, no_dupes = False): + def merge_node(self, source_node_tag, network, no_dupes=False): logger.debug("SLABV1 merge_node") - #if no_dupes and self.get_node_element(node['hostname']): - ## node already exists - #return + # if no_dupes and self.get_node_element(node['hostname']): + # node already exists + # return network_tag = self.add_network(network) network_tag.append(deepcopy(source_node_tag)) # Slivers def get_sliver_attributes(self, hostname, node, network=None): - print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes hostname %s " %(hostname), file=sys.stderr) - nodes = self.get_nodes({'component_id': '*%s*' %hostname}) + print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes hostname %s " % ( + hostname), file=sys.stderr) + nodes = self.get_nodes({'component_id': '*%s*' % hostname}) attribs = [] - print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes-----------------nodes %s " %(nodes), file=sys.stderr) + print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes-----------------nodes %s " % + (nodes), file=sys.stderr) if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: node = nodes[0] - #if node : + # if node : #sliver = node.xpath('./default:sliver | ./sliver') #sliver = node.xpath('./default:sliver', namespaces=self.namespaces) sliver = node['slivers'] @@ -95,7 +97,8 @@ class Iotlabv1(RSpecVersion): sliver = sliver[0] attribs = sliver #attribs = self.attributes_list(sliver) - print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes----------NN------- sliver %s self.namespaces %s attribs %s " %(sliver, self.namespaces,attribs), file=sys.stderr) + print("\r\n \r\n \r\n \t\t SLABV1.PY get_sliver_attributes----------NN------- sliver %s self.namespaces %s attribs %s " % + (sliver, self.namespaces, attribs), file=sys.stderr) return attribs def get_slice_attributes(self, network=None): @@ -107,25 +110,26 @@ class Iotlabv1(RSpecVersion): # TODO: default sliver attributes in the PG rspec? default_ns_prefix = self.namespaces['default'] for node in nodes_with_slivers: - sliver_attributes = self.get_sliver_attributes(node['component_id'],node, network) + sliver_attributes = self.get_sliver_attributes( + node['component_id'], node, network) for sliver_attribute in sliver_attributes: name = str(sliver_attribute[0]) text = str(sliver_attribute[1]) attribs = sliver_attribute[2] # we currently only suppor the and attributes - #if 'info' in name: - #attribute = {'name': 'flack_info', 'value': str(attribs), 'node_id': node} - #slice_attributes.append(attribute) - #elif 'initscript' in name: + # if 'info' in name: + #attribute = {'name': 'flack_info', 'value': str(attribs), 'node_id': node} + # slice_attributes.append(attribute) + # elif 'initscript' in name: if 'initscript' in name: if attribs is not None and 'name' in attribs: value = attribs['name'] else: value = text - attribute = {'name': 'initscript', 'value': value, 'node_id': node} + attribute = {'name': 'initscript', + 'value': value, 'node_id': node} slice_attributes.append(attribute) - return slice_attributes def attributes_list(self, elem): @@ -142,7 +146,8 @@ class Iotlabv1(RSpecVersion): pass def add_slivers(self, hostnames, attributes=None, sliver_urn=None, append=False): - if attributes is None: attributes=[] + if attributes is None: + attributes = [] # all nodes hould already be present in the rspec. Remove all # nodes that done have slivers print("\r\n \r\n \r\n \t\t\t Iotlabv1.PY add_slivers ----->get_node ", file=sys.stderr) @@ -153,8 +158,10 @@ class Iotlabv1(RSpecVersion): node_elem = node_elems[0] # determine sliver types for this node - #TODO : add_slivers valid type of sliver needs to be changed 13/07/12 SA - valid_sliver_types = ['iotlab-node', 'emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] + # TODO : add_slivers valid type of sliver needs to be changed + # 13/07/12 SA + valid_sliver_types = [ + 'iotlab-node', 'emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] #valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] requested_sliver_type = None for sliver_type in node_elem.get('slivers', []): @@ -164,8 +171,9 @@ class Iotlabv1(RSpecVersion): if not requested_sliver_type: continue sliver = {'type': requested_sliver_type, - 'pl_tags': attributes} - print("\r\n \r\n \r\n \t\t\t Iotlabv1.PY add_slivers node_elem %s sliver_type %s \r\n \r\n " %(node_elem, sliver_type), file=sys.stderr) + 'pl_tags': attributes} + print("\r\n \r\n \r\n \t\t\t Iotlabv1.PY add_slivers node_elem %s sliver_type %s \r\n \r\n " % ( + node_elem, sliver_type), file=sys.stderr) # remove available element for available_elem in node_elem.xpath('./default:available | ./available'): node_elem.remove(available_elem) @@ -203,7 +211,6 @@ class Iotlabv1(RSpecVersion): def remove_slivers(self, slivers, network=None, no_dupes=False): Iotlabv1Node.remove_slivers(self.xml, slivers) - # Utility def merge(self, in_rspec): @@ -225,7 +232,7 @@ class Iotlabv1(RSpecVersion): in_rspec = RSpecConverter.to_sfa_rspec(rspec.toxml()) rspec = RSpec(in_rspec) # just copy over all networks - #Attention special get_networks using //default:network xpath + # Attention special get_networks using //default:network xpath current_networks = self.get_networks() networks = rspec.version.get_networks() for network in networks: @@ -234,17 +241,13 @@ class Iotlabv1(RSpecVersion): self.xml.append(network.element) current_networks.append(current_network) - - - - # Leases def get_leases(self, lease_filter=None): return SFAv1Lease.get_leases(self.xml, lease_filter) - #return Iotlabv1Lease.get_leases(self.xml, lease_filter) + # return Iotlabv1Lease.get_leases(self.xml, lease_filter) - def add_leases(self, leases, network = None, no_dupes=False): + def add_leases(self, leases, network=None, no_dupes=False): SFAv1Lease.add_leases(self.xml, leases) #Iotlabv1Lease.add_leases(self.xml, leases) @@ -253,7 +256,7 @@ class Iotlabv1(RSpecVersion): def get_channels(self, filter=None): return [] - def add_channels(self, channels, network = None, no_dupes=False): + def add_channels(self, channels, network=None, no_dupes=False): pass # Links @@ -266,10 +269,10 @@ class Iotlabv1(RSpecVersion): def add_links(self, links): pass + def add_link_requests(self, links): pass - def cleanup(self): # remove unncecessary elements, attributes if self.type in ['request', 'manifest']: @@ -281,21 +284,23 @@ class Iotlabv1Ad(Iotlabv1): enabled = True content_type = 'ad' schema = 'http://senslab.info/resources/rspec/1/ad.xsd' - #http://www.geni.net/resources/rspec/3/ad.xsd' + # http://www.geni.net/resources/rspec/3/ad.xsd' template = '' + class Iotlabv1Request(Iotlabv1): enabled = True content_type = 'request' schema = 'http://senslab.info/resources/rspec/1/request.xsd' - #http://www.geni.net/resources/rspec/3/request.xsd + # http://www.geni.net/resources/rspec/3/request.xsd template = '' + class Iotlabv1Manifest(Iotlabv1): enabled = True content_type = 'manifest' schema = 'http://senslab.info/resources/rspec/1/manifest.xsd' - #http://www.geni.net/resources/rspec/3/manifest.xsd + # http://www.geni.net/resources/rspec/3/manifest.xsd template = '' diff --git a/sfa/rspecs/versions/nitosv1.py b/sfa/rspecs/versions/nitosv1.py index 60caf1d1..99485c6f 100644 --- a/sfa/rspecs/versions/nitosv1.py +++ b/sfa/rspecs/versions/nitosv1.py @@ -13,6 +13,7 @@ from sfa.rspecs.elements.versions.nitosv1Sliver import NITOSv1Sliver from sfa.rspecs.elements.versions.nitosv1Lease import NITOSv1Lease from sfa.rspecs.elements.versions.nitosv1Channel import NITOSv1Channel + class NITOSv1(RSpecVersion): enabled = True type = 'NITOS' @@ -24,13 +25,12 @@ class NITOSv1(RSpecVersion): namespaces = None template = '' % type - # Network + # Network def get_networks(self): network_elems = self.xml.xpath('//network') - networks = [network_elem.get_instance(fields=['name', 'slice']) for \ + networks = [network_elem.get_instance(fields=['name', 'slice']) for network_elem in network_elems] - return networks - + return networks def add_network(self, network): network_tags = self.xml.xpath('//network[@name="%s"]' % network) @@ -40,16 +40,15 @@ class NITOSv1(RSpecVersion): network_tag = network_tags[0] return network_tag - # Nodes - + def get_nodes(self, filter=None): return NITOSv1Node.get_nodes(self.xml, filter) def get_nodes_with_slivers(self): return NITOSv1Node.get_nodes_with_slivers(self.xml) - def add_nodes(self, nodes, network = None, no_dupes=False, rspec_content_type=None): + def add_nodes(self, nodes, network=None, no_dupes=False, rspec_content_type=None): NITOSv1Node.add_nodes(self.xml, nodes, rspec_content_type) def merge_node(self, source_node_tag, network, no_dupes=False): @@ -61,9 +60,10 @@ class NITOSv1(RSpecVersion): network_tag.append(deepcopy(source_node_tag)) # Slivers - + def add_slivers(self, hostnames, attributes=None, sliver_urn=None, append=False): - if attributes is None: attributes=[] + if attributes is None: + attributes = [] # add slice name to network tag network_tags = self.xml.xpath('//network') if network_tags: @@ -71,7 +71,7 @@ class NITOSv1(RSpecVersion): network_tag.set('slice', urn_to_hrn(sliver_urn)[0]) # add slivers - sliver = {'name':sliver_urn, + sliver = {'name': sliver_urn, 'pl_tags': attributes} for hostname in hostnames: if sliver_urn: @@ -89,10 +89,9 @@ class NITOSv1(RSpecVersion): parent = node_elem.element.getparent() parent.remove(node_elem.element) - def remove_slivers(self, slivers, network=None, no_dupes=False): NITOSv1Node.remove_slivers(self.xml, slivers) - + def get_slice_attributes(self, network=None): attributes = [] nodes_with_slivers = self.get_nodes_with_slivers() @@ -101,14 +100,13 @@ class NITOSv1(RSpecVersion): attribute['node_id'] = None attributes.append(attribute) for node in nodes_with_slivers: - nodename=node['component_name'] + nodename = node['component_name'] sliver_attributes = self.get_sliver_attributes(nodename, network) for sliver_attribute in sliver_attributes: sliver_attribute['node_id'] = nodename attributes.append(sliver_attribute) return attributes - def add_sliver_attribute(self, component_id, name, value, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: @@ -119,7 +117,8 @@ class NITOSv1(RSpecVersion): NITOSv1Sliver.add_sliver_attribute(sliver, name, value) else: # should this be an assert / raise an exception? - logger.error("WARNING: failed to find component_id %s" % component_id) + logger.error("WARNING: failed to find component_id %s" % + component_id) def get_sliver_attributes(self, component_id, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) @@ -136,20 +135,21 @@ class NITOSv1(RSpecVersion): attribs = self.get_sliver_attributes(component_id) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) def add_default_sliver_attribute(self, name, value, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") if not defaults: if network: network_tag = self.xml.xpath("//network[@name='%s']" % network) else: - network_tag = self.xml.xpath("//network") + network_tag = self.xml.xpath("//network") if isinstance(network_tag, list): network_tag = network_tag[0] defaults = network_tag.add_element('sliver_defaults') @@ -159,17 +159,19 @@ class NITOSv1(RSpecVersion): def get_default_sliver_attributes(self, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") - if not defaults: return [] + if not defaults: + return [] return NITOSv1Sliver.get_sliver_attributes(defaults[0]) - + def remove_default_sliver_attribute(self, name, value, network=None): attribs = self.get_default_sliver_attributes(network) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) @@ -183,6 +185,7 @@ class NITOSv1(RSpecVersion): def add_links(self, links): pass + def add_link_requests(self, links): pass @@ -220,7 +223,7 @@ class NITOSv1(RSpecVersion): def get_leases(self, filter=None): return NITOSv1Lease.get_leases(self.xml, filter) - def add_leases(self, leases_channels, network = None, no_dupes=False): + def add_leases(self, leases_channels, network=None, no_dupes=False): leases, channels = leases_channels NITOSv1Lease.add_leases(self.xml, leases, channels) @@ -229,11 +232,10 @@ class NITOSv1(RSpecVersion): def get_channels(self, filter=None): return NITOSv1Channel.get_channels(self.xml, filter) - def add_channels(self, channels, network = None, no_dupes=False): + def add_channels(self, channels, network=None, no_dupes=False): NITOSv1Channel.add_channels(self.xml, channels) - if __name__ == '__main__': from sfa.rspecs.rspec import RSpec from sfa.rspecs.rspec_elements import * diff --git a/sfa/rspecs/versions/ofeliav1.py b/sfa/rspecs/versions/ofeliav1.py index 0a001cbd..c38dee30 100755 --- a/sfa/rspecs/versions/ofeliav1.py +++ b/sfa/rspecs/versions/ofeliav1.py @@ -16,6 +16,7 @@ from sfa.rspecs.elements.versions.sfav1Lease import SFAv1Lease from sfa.rspecs.elements.versions.ofeliav1datapath import Ofeliav1Datapath from sfa.rspecs.elements.versions.ofeliav1link import Ofeliav1Link + class Ofelia(RSpecVersion): enabled = True type = 'OFELIA' @@ -28,14 +29,13 @@ class Ofelia(RSpecVersion): #template = '' % type template = '' - # Network + # Network def get_networks(self): raise Exception("Not implemented") network_elems = self.xml.xpath('//network') - networks = [network_elem.get_instance(fields=['name', 'slice']) for \ + networks = [network_elem.get_instance(fields=['name', 'slice']) for network_elem in network_elems] - return networks - + return networks def add_network(self, network): raise Exception("Not implemented") @@ -46,19 +46,20 @@ class Ofelia(RSpecVersion): network_tag = network_tags[0] return network_tag -# These are all resources -# get_resources function can return all resources or a specific type of resource +# These are all resources +# get_resources function can return all resources or a specific type of +# resource def get_resources(self, filter=None, type=None): resources = list() - if not type or type=='datapath': + if not type or type == 'datapath': datapaths = self.get_datapaths(filter) for datapath in datapaths: - datapath['type']='datapath' + datapath['type'] = 'datapath' resources.extend(datapaths) - if not type or type=='link': + if not type or type == 'link': links = self.get_links(filter) for link in links: - link['type']='link' + link['type'] = 'link' resources.extend(links) return resources @@ -71,7 +72,7 @@ class Ofelia(RSpecVersion): return Ofeliav1Link.get_links(self.xml, filter) # def get_link_requests(self): -# return PGv2Link.get_link_requests(self.xml) +# return PGv2Link.get_link_requests(self.xml) # # def add_links(self, links): # networks = self.get_networks() @@ -84,12 +85,11 @@ class Ofelia(RSpecVersion): # def add_link_requests(self, links): # PGv2Link.add_link_requests(self.xml, links) - - # Slivers - + def add_slivers(self, hostnames, attributes=None, sliver_urn=None, append=False): - if attributes is None: attributes=[] + if attributes is None: + attributes = [] # add slice name to network tag network_tags = self.xml.xpath('//network') if network_tags: @@ -97,7 +97,7 @@ class Ofelia(RSpecVersion): network_tag.set('slice', urn_to_hrn(sliver_urn)[0]) # add slivers - sliver = {'name':sliver_urn, + sliver = {'name': sliver_urn, 'pl_tags': attributes} for hostname in hostnames: if sliver_urn: @@ -115,10 +115,9 @@ class Ofelia(RSpecVersion): parent = node_elem.element.getparent() parent.remove(node_elem.element) - def remove_slivers(self, slivers, network=None, no_dupes=False): SFAv1Node.remove_slivers(self.xml, slivers) - + def get_slice_attributes(self, network=None): attributes = [] nodes_with_slivers = self.get_nodes_with_slivers() @@ -127,14 +126,13 @@ class Ofelia(RSpecVersion): attribute['node_id'] = None attributes.append(attribute) for node in nodes_with_slivers: - nodename=node['component_name'] + nodename = node['component_name'] sliver_attributes = self.get_sliver_attributes(nodename, network) for sliver_attribute in sliver_attributes: sliver_attribute['node_id'] = nodename attributes.append(sliver_attribute) return attributes - def add_sliver_attribute(self, component_id, name, value, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: @@ -145,7 +143,8 @@ class Ofelia(RSpecVersion): SFAv1Sliver.add_sliver_attribute(sliver, name, value) else: # should this be an assert / raise an exception? - logger.error("WARNING: failed to find component_id %s" % component_id) + logger.error("WARNING: failed to find component_id %s" % + component_id) def get_sliver_attributes(self, component_id, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) @@ -162,20 +161,21 @@ class Ofelia(RSpecVersion): attribs = self.get_sliver_attributes(component_id) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) def add_default_sliver_attribute(self, name, value, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") if not defaults: if network: network_tag = self.xml.xpath("//network[@name='%s']" % network) else: - network_tag = self.xml.xpath("//network") + network_tag = self.xml.xpath("//network") if isinstance(network_tag, list): network_tag = network_tag[0] defaults = network_tag.add_element('sliver_defaults') @@ -185,17 +185,19 @@ class Ofelia(RSpecVersion): def get_default_sliver_attributes(self, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") - if not defaults: return [] + if not defaults: + return [] return SFAv1Sliver.get_sliver_attributes(defaults[0]) - + def remove_default_sliver_attribute(self, name, value, network=None): attribs = self.get_default_sliver_attributes(network) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) @@ -234,13 +236,13 @@ if __name__ == '__main__': from sfa.rspecs.rspec import RSpec from sfa.rspecs.rspec_elements import * print("main ofeliav1") - if len(sys.argv)!=2: + if len(sys.argv) != 2: r = RSpec('/tmp/resources.rspec') else: - r = RSpec(sys.argv[1], version = 'OFELIA 1') - #print r.version.get_datapaths() + r = RSpec(sys.argv[1], version='OFELIA 1') + # print r.version.get_datapaths() resources = r.version.get_resources() pprint.pprint(resources) - #r.load_rspec_elements(SFAv1.elements) - #print r.get(RSpecElements.NODE) + # r.load_rspec_elements(SFAv1.elements) + # print r.get(RSpecElements.NODE) diff --git a/sfa/rspecs/versions/pgv2.py b/sfa/rspecs/versions/pgv2.py index 3e25a376..d9ca6830 100644 --- a/sfa/rspecs/versions/pgv2.py +++ b/sfa/rspecs/versions/pgv2.py @@ -10,6 +10,7 @@ from sfa.rspecs.elements.versions.pgv2Lease import PGv2Lease from sfa.util.sfalogging import logger from sfa.util.py23 import StringType + class PGv2(RSpecVersion): type = 'ProtoGENI' content_type = 'ad' @@ -51,18 +52,19 @@ class PGv2(RSpecVersion): def add_nodes(self, nodes, check_for_dupes=False, rspec_content_type=None): return PGv2Node.add_nodes(self.xml, nodes, rspec_content_type) - + def merge_node(self, source_node_tag): # this is untested self.xml.root.append(deepcopy(source_node_tag)) # Slivers - + def get_sliver_attributes(self, component_id, network=None): - nodes = self.get_nodes({'component_id': '*%s*' %component_id}) + nodes = self.get_nodes({'component_id': '*%s*' % component_id}) try: node = nodes[0] - sliver = node.xpath('./default:sliver_type', namespaces=self.namespaces) + sliver = node.xpath('./default:sliver_type', + namespaces=self.namespaces) if sliver is not None and isinstance(sliver, list) and len(sliver) > 0: sliver = sliver[0] return self.attributes_list(sliver) @@ -77,21 +79,25 @@ class PGv2(RSpecVersion): # TODO: default sliver attributes in the PG rspec? default_ns_prefix = self.namespaces['default'] for node in nodes_with_slivers: - sliver_attributes = self.get_sliver_attributes(node['component_id'], network) + sliver_attributes = self.get_sliver_attributes( + node['component_id'], network) for sliver_attribute in sliver_attributes: - name=str(sliver_attribute[0]) - text =str(sliver_attribute[1]) + name = str(sliver_attribute[0]) + text = str(sliver_attribute[1]) attribs = sliver_attribute[2] - # we currently only suppor the and attributes - if 'info' in name: - attribute = {'name': 'flack_info', 'value': str(attribs), 'node_id': node} + # we currently only suppor the and + # attributes + if 'info' in name: + attribute = {'name': 'flack_info', + 'value': str(attribs), 'node_id': node} slice_attributes.append(attribute) elif 'initscript' in name: if attribs is not None and 'name' in attribs: value = attribs['name'] else: value = text - attribute = {'name': 'initscript', 'value': value, 'node_id': node} + attribute = {'name': 'initscript', + 'value': value, 'node_id': node} slice_attributes.append(attribute) return slice_attributes @@ -110,7 +116,8 @@ class PGv2(RSpecVersion): pass def add_slivers(self, hostnames, attributes=None, sliver_urn=None, append=False): - if attributes is None: attributes=[] + if attributes is None: + attributes = [] # all nodes hould already be present in the rspec. Remove all # nodes that done have slivers for hostname in hostnames: @@ -118,27 +125,28 @@ class PGv2(RSpecVersion): if not node_elems: continue node_elem = node_elems[0] - + # determine sliver types for this node - valid_sliver_types = ['emulab-openvz', 'raw-pc', 'plab-vserver', 'plab-vnode'] + valid_sliver_types = ['emulab-openvz', + 'raw-pc', 'plab-vserver', 'plab-vnode'] requested_sliver_type = None for sliver_type in node_elem.get('slivers', []): if sliver_type.get('type') in valid_sliver_types: requested_sliver_type = sliver_type['type'] - + if not requested_sliver_type: continue sliver = {'type': requested_sliver_type, - 'pl_tags': attributes} + 'pl_tags': attributes} # remove available element for available_elem in node_elem.xpath('./default:available | ./available'): node_elem.remove(available_elem) - + # remove interface elements for interface_elem in node_elem.xpath('./default:interface | ./interface'): node_elem.remove(interface_elem) - + # remove existing sliver_type elements for sliver_type in node_elem.get('slivers', []): node_elem.element.remove(sliver_type.element) @@ -154,8 +162,8 @@ class PGv2(RSpecVersion): #sliver_id = Xrn(xrn=sliver_urn, type='slice', id=str(node_id)).get_urn() #node_elem.set('sliver_id', sliver_id) - # add the sliver type elemnt - PGv2SliverType.add_slivers(node_elem.element, sliver) + # add the sliver type elemnt + PGv2SliverType.add_slivers(node_elem.element, sliver) # remove all nodes without slivers if not append: @@ -165,7 +173,7 @@ class PGv2(RSpecVersion): parent.remove(node_elem.element) def remove_slivers(self, slivers, network=None, no_dupes=False): - PGv2Node.remove_slivers(self.xml, slivers) + PGv2Node.remove_slivers(self.xml, slivers) # Links @@ -173,7 +181,7 @@ class PGv2(RSpecVersion): return PGv2Link.get_links(self.xml) def get_link_requests(self): - return PGv2Link.get_link_requests(self.xml) + return PGv2Link.get_link_requests(self.xml) def add_links(self, links): PGv2Link.add_links(self.xml.root, links) @@ -186,7 +194,7 @@ class PGv2(RSpecVersion): def get_leases(self, filter=None): return PGv2Lease.get_leases(self.xml, filter) - def add_leases(self, leases, network = None, no_dupes=False): + def add_leases(self, leases, network=None, no_dupes=False): PGv2Lease.add_leases(self.xml, leases) # Spectrum @@ -194,7 +202,7 @@ class PGv2(RSpecVersion): def get_channels(self, filter=None): return [] - def add_channels(self, channels, network = None, no_dupes=False): + def add_channels(self, channels, network=None, no_dupes=False): pass # Utility @@ -220,16 +228,14 @@ class PGv2(RSpecVersion): main_nodes.append(node) self.add_nodes(main_nodes) self.add_links(in_rspec.version.get_links()) - + # Leases leases = in_rspec.version.get_leases() self.add_leases(leases) # #rspec = RSpec(in_rspec) - #for child in rspec.xml.iterchildren(): + # for child in rspec.xml.iterchildren(): # self.xml.root.append(child) - - def cleanup(self): # remove unncecessary elements, attributes @@ -237,26 +243,27 @@ class PGv2(RSpecVersion): # remove 'available' element from remaining node elements self.xml.remove_element('//default:available | //available') + class PGv2Ad(PGv2): enabled = True content_type = 'ad' schema = 'http://www.protogeni.net/resources/rspec/2/ad.xsd' template = '' + class PGv2Request(PGv2): enabled = True content_type = 'request' schema = 'http://www.protogeni.net/resources/rspec/2/request.xsd' template = '' + class PGv2Manifest(PGv2): enabled = True content_type = 'manifest' schema = 'http://www.protogeni.net/resources/rspec/2/manifest.xsd' template = '' - - if __name__ == '__main__': from sfa.rspecs.rspec import RSpec diff --git a/sfa/rspecs/versions/pgv3.py b/sfa/rspecs/versions/pgv3.py index eb15b82f..149d8777 100644 --- a/sfa/rspecs/versions/pgv3.py +++ b/sfa/rspecs/versions/pgv3.py @@ -1,5 +1,6 @@ from sfa.rspecs.versions.pgv2 import PGv2 + class GENIv3(PGv2): type = 'GENI' content_type = 'ad' @@ -41,12 +42,14 @@ class GENIv3Ad(GENIv3): """ + class GENIv3Request(GENIv3): enabled = True content_type = 'request' schema = 'http://www.geni.net/resources/rspec/3/request.xsd' template = '' + class GENIv2Manifest(GENIv3): enabled = True content_type = 'manifest' diff --git a/sfa/rspecs/versions/sfav1.py b/sfa/rspecs/versions/sfav1.py index 645a6604..588e3c05 100644 --- a/sfa/rspecs/versions/sfav1.py +++ b/sfa/rspecs/versions/sfav1.py @@ -12,6 +12,7 @@ from sfa.rspecs.elements.versions.sfav1Node import SFAv1Node from sfa.rspecs.elements.versions.sfav1Sliver import SFAv1Sliver from sfa.rspecs.elements.versions.sfav1Lease import SFAv1Lease + class SFAv1(RSpecVersion): enabled = True type = 'SFA' @@ -23,13 +24,12 @@ class SFAv1(RSpecVersion): namespaces = None template = '' % type - # Network + # Network def get_networks(self): network_elems = self.xml.xpath('//network') - networks = [network_elem.get_instance(fields=['name', 'slice']) for \ + networks = [network_elem.get_instance(fields=['name', 'slice']) for network_elem in network_elems] - return networks - + return networks def add_network(self, network): network_tags = self.xml.xpath('//network[@name="%s"]' % network) @@ -39,16 +39,15 @@ class SFAv1(RSpecVersion): network_tag = network_tags[0] return network_tag - # Nodes - + def get_nodes(self, filter=None): return SFAv1Node.get_nodes(self.xml, filter) def get_nodes_with_slivers(self): return SFAv1Node.get_nodes_with_slivers(self.xml) - def add_nodes(self, nodes, network = None, no_dupes=False, rspec_content_type=None): + def add_nodes(self, nodes, network=None, no_dupes=False, rspec_content_type=None): SFAv1Node.add_nodes(self.xml, nodes, rspec_content_type) def merge_node(self, source_node_tag, network, no_dupes=False): @@ -60,9 +59,10 @@ class SFAv1(RSpecVersion): network_tag.append(deepcopy(source_node_tag)) # Slivers - + def add_slivers(self, hostnames, attributes=None, sliver_urn=None, append=False): - if attributes is None: attributes=[] + if attributes is None: + attributes = [] # add slice name to network tag network_tags = self.xml.xpath('//network') if network_tags: @@ -70,7 +70,7 @@ class SFAv1(RSpecVersion): network_tag.set('slice', urn_to_hrn(sliver_urn)[0]) # add slivers - sliver = {'name':sliver_urn, + sliver = {'name': sliver_urn, 'pl_tags': attributes} for hostname in hostnames: if sliver_urn: @@ -88,10 +88,9 @@ class SFAv1(RSpecVersion): parent = node_elem.element.getparent() parent.remove(node_elem.element) - def remove_slivers(self, slivers, network=None, no_dupes=False): SFAv1Node.remove_slivers(self.xml, slivers) - + def get_slice_attributes(self, network=None): attributes = [] nodes_with_slivers = self.get_nodes_with_slivers() @@ -100,14 +99,13 @@ class SFAv1(RSpecVersion): attribute['node_id'] = None attributes.append(attribute) for node in nodes_with_slivers: - nodename=node['component_name'] + nodename = node['component_name'] sliver_attributes = self.get_sliver_attributes(nodename, network) for sliver_attribute in sliver_attributes: sliver_attribute['node_id'] = nodename attributes.append(sliver_attribute) return attributes - def add_sliver_attribute(self, component_id, name, value, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) if nodes is not None and isinstance(nodes, list) and len(nodes) > 0: @@ -118,7 +116,8 @@ class SFAv1(RSpecVersion): SFAv1Sliver.add_sliver_attribute(sliver, name, value) else: # should this be an assert / raise an exception? - logger.error("WARNING: failed to find component_id %s" % component_id) + logger.error("WARNING: failed to find component_id %s" % + component_id) def get_sliver_attributes(self, component_id, network=None): nodes = self.get_nodes({'component_id': '*%s*' % component_id}) @@ -135,20 +134,21 @@ class SFAv1(RSpecVersion): attribs = self.get_sliver_attributes(component_id) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) def add_default_sliver_attribute(self, name, value, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") if not defaults: if network: network_tag = self.xml.xpath("//network[@name='%s']" % network) else: - network_tag = self.xml.xpath("//network") + network_tag = self.xml.xpath("//network") if isinstance(network_tag, list): network_tag = network_tag[0] defaults = network_tag.add_element('sliver_defaults') @@ -158,17 +158,19 @@ class SFAv1(RSpecVersion): def get_default_sliver_attributes(self, network=None): if network: - defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network) + defaults = self.xml.xpath( + "//network[@name='%s']/sliver_defaults" % network) else: defaults = self.xml.xpath("//sliver_defaults") - if not defaults: return [] + if not defaults: + return [] return SFAv1Sliver.get_sliver_attributes(defaults[0]) - + def remove_default_sliver_attribute(self, name, value, network=None): attribs = self.get_default_sliver_attributes(network) for attrib in attribs: if attrib['name'] == name and attrib['value'] == value: - #attrib.element.delete() + # attrib.element.delete() parent = attrib.element.getparent() parent.remove(attrib.element) @@ -178,7 +180,7 @@ class SFAv1(RSpecVersion): return PGv2Link.get_links(self.xml) def get_link_requests(self): - return PGv2Link.get_link_requests(self.xml) + return PGv2Link.get_link_requests(self.xml) def add_links(self, links): networks = self.get_networks() @@ -225,7 +227,7 @@ class SFAv1(RSpecVersion): def get_leases(self, filter=None): return SFAv1Lease.get_leases(self.xml, filter) - def add_leases(self, leases, network = None, no_dupes=False): + def add_leases(self, leases, network=None, no_dupes=False): SFAv1Lease.add_leases(self.xml, leases) # Spectrum @@ -233,7 +235,7 @@ class SFAv1(RSpecVersion): def get_channels(self, filter=None): return [] - def add_channels(self, channels, network = None, no_dupes=False): + def add_channels(self, channels, network=None, no_dupes=False): pass if __name__ == '__main__': diff --git a/sfa/server/aggregate.py b/sfa/server/aggregate.py index 90fcaf49..a8f0a0a4 100644 --- a/sfa/server/aggregate.py +++ b/sfa/server/aggregate.py @@ -1,9 +1,11 @@ from sfa.server.sfaserver import SfaServer from sfa.util.xrn import hrn_to_urn from sfa.server.interface import Interfaces, Interface -from sfa.util.config import Config +from sfa.util.config import Config # this truly is a server-side object + + class Aggregate(SfaServer): ## @@ -12,21 +14,23 @@ class Aggregate(SfaServer): # @param ip the ip address to listen on # @param port the port to listen on # @param key_file private key filename of registry - # @param cert_file certificate filename containing public key (could be a GID file) + # @param cert_file certificate filename containing public key (could be a GID file) def __init__(self, ip, port, key_file, cert_file): - SfaServer.__init__(self, ip, port, key_file, cert_file,'aggregate') + SfaServer.__init__(self, ip, port, key_file, cert_file, 'aggregate') # # Aggregates is a dictionary of aggregate connections keyed on the aggregate hrn # as such it's more of a client-side thing for aggregate servers to reach their peers # + + class Aggregates(Interfaces): default_dict = {'aggregates': {'aggregate': [Interfaces.default_fields]}} - - def __init__(self, conf_file = "/etc/sfa/aggregates.xml"): + + def __init__(self, conf_file="/etc/sfa/aggregates.xml"): Interfaces.__init__(self, conf_file) - sfa_config = Config() + sfa_config = Config() # set up a connection to the local aggregate if sfa_config.SFA_AGGREGATE_ENABLED: addr = sfa_config.SFA_AGGREGATE_HOST diff --git a/sfa/server/api_versions.py b/sfa/server/api_versions.py index 9138a748..6e584922 100644 --- a/sfa/server/api_versions.py +++ b/sfa/server/api_versions.py @@ -2,13 +2,14 @@ import os from sfa.util.xml import XML from sfa.util.config import Config + class ApiVersions: required_fields = ['version', 'url'] - + template = """ -""" +""" def __init__(self, string=None, filename=None, create=False): self.xml = None @@ -22,9 +23,10 @@ class ApiVersions: else: # load the default file c = Config() - api_versions_file = os.path.sep.join([c.config_path, 'api_versions.xml']) + api_versions_file = os.path.sep.join( + [c.config_path, 'api_versions.xml']) self.load(api_versions_file) - + def create(self): self.xml = XML(string=ApiVersions.template) @@ -42,6 +44,4 @@ class ApiVersions: set(ApiVersions.required_fields).issubset(item.keys()) and \ item['version'] != '' and item['url'] != '': versions[str(item['version'])] = item['url'] - return versions - - + return versions diff --git a/sfa/server/component.py b/sfa/server/component.py index 9baa6c22..df91917b 100644 --- a/sfa/server/component.py +++ b/sfa/server/component.py @@ -7,16 +7,17 @@ import time import sys from sfa.server.sfaserver import SfaServer - + # GeniLight client support is optional try: from egeni.geniLight_client import * except ImportError: - GeniClientLight = None + GeniClientLight = None ## # Component is a SfaServer that serves component operations. + class Component(SfaServer): ## # Create a new registry object. @@ -27,4 +28,5 @@ class Component(SfaServer): # @param cert_file certificate filename containing public key (could be a GID file) def __init__(self, ip, port, key_file, cert_file): - SfaServer.__init__(self, ip, port, key_file, cert_file, interface='component') + SfaServer.__init__(self, ip, port, key_file, + cert_file, interface='component') diff --git a/sfa/server/interface.py b/sfa/server/interface.py index 2f461461..093dcb80 100644 --- a/sfa/server/interface.py +++ b/sfa/server/interface.py @@ -5,38 +5,43 @@ from sfa.util.xml import XML try: from egeni.geniLight_client import * except ImportError: - GeniClientLight = None + GeniClientLight = None + class Interface: """ Interface to another SFA service, typically a peer, or the local aggregate can retrieve a xmlrpclib.ServerProxy object for issuing calls there """ + def __init__(self, hrn, addr, port, client_type='sfa'): self.hrn = hrn self.addr = addr self.port = port self.client_type = client_type - + def get_url(self): address_parts = self.addr.split('/') address_parts[0] = address_parts[0] + ":" + str(self.port) - url = "http://%s" % "/".join(address_parts) + url = "http://%s" % "/".join(address_parts) return url def server_proxy(self, key_file, cert_file, timeout=30): - server = None - if self.client_type == 'geniclientlight' and GeniClientLight: + server = None + if self.client_type == 'geniclientlight' and GeniClientLight: # xxx url and self.api are undefined - server = GeniClientLight(url, self.api.key_file, self.api.cert_file) + server = GeniClientLight( + url, self.api.key_file, self.api.cert_file) else: - server = SfaServerProxy(self.get_url(), key_file, cert_file, timeout) - - return server + server = SfaServerProxy( + self.get_url(), key_file, cert_file, timeout) + + return server ## # In is a dictionary of registry connections keyed on the registry # hrn + class Interfaces(dict): """ Interfaces is a base class for managing information on the @@ -46,11 +51,11 @@ class Interfaces(dict): # fields that must be specified in the config file default_fields = { 'hrn': '', - 'addr': '', - 'port': '', + 'addr': '', + 'port': '', } - # defined by the class + # defined by the class default_dict = {} def __init__(self, conf_file): @@ -62,14 +67,15 @@ class Interfaces(dict): if isinstance(value, list): for record in value: if isinstance(record, dict) and \ - required_fields.issubset(record.keys()): - hrn, address, port = record['hrn'], record['addr'], record['port'] + required_fields.issubset(record.keys()): + hrn, address, port = record[ + 'hrn'], record['addr'], record['port'] # sometime this is called at a very early stage with no config loaded # avoid to remember this instance in such a case if not address or not port: - continue + continue interface = Interface(hrn, address, port) - self[hrn] = interface + self[hrn] = interface def server_proxy(self, hrn, key_file, cert_file, timeout=30): return self[hrn].server_proxy(key_file, cert_file, timeout) diff --git a/sfa/server/modpython/SfaAggregateModPython.py b/sfa/server/modpython/SfaAggregateModPython.py index 5d38a553..3a87017e 100755 --- a/sfa/server/modpython/SfaAggregateModPython.py +++ b/sfa/server/modpython/SfaAggregateModPython.py @@ -16,6 +16,7 @@ from sfa.planetlab.server import SfaApi api = SfaApi(interface='aggregate') + def handler(req): try: if req.method != "POST": @@ -52,5 +53,5 @@ def handler(req): except Exception as err: # Log error in /var/log/httpd/(ssl_)?error_log - logger.log_exc('%r'%err) + logger.log_exc('%r' % err) return apache.HTTP_INTERNAL_SERVER_ERROR diff --git a/sfa/server/modpython/SfaRegistryModPython.py b/sfa/server/modpython/SfaRegistryModPython.py index 6d17cd2c..49018b78 100755 --- a/sfa/server/modpython/SfaRegistryModPython.py +++ b/sfa/server/modpython/SfaRegistryModPython.py @@ -16,6 +16,7 @@ from sfa.planetlab.server import SfaApi api = SfaApi(interface='registry') + def handler(req): try: if req.method != "POST": @@ -52,5 +53,5 @@ def handler(req): except Exception as err: # Log error in /var/log/httpd/(ssl_)?error_log - logger.log_exc('%r'%err) + logger.log_exc('%r' % err) return apache.HTTP_INTERNAL_SERVER_ERROR diff --git a/sfa/server/modpython/SfaSliceMgrModPython.py b/sfa/server/modpython/SfaSliceMgrModPython.py index dcb85626..b0595229 100755 --- a/sfa/server/modpython/SfaSliceMgrModPython.py +++ b/sfa/server/modpython/SfaSliceMgrModPython.py @@ -16,6 +16,7 @@ from sfa.planetlab.server import SfaApi api = SfaApi(interface='slicemgr') + def handler(req): try: if req.method != "POST": @@ -52,5 +53,5 @@ def handler(req): except Exception as err: # Log error in /var/log/httpd/(ssl_)?error_log - logger.log_exc('%r'%err) + logger.log_exc('%r' % err) return apache.HTTP_INTERNAL_SERVER_ERROR diff --git a/sfa/server/registry.py b/sfa/server/registry.py index 13a75fc7..edf07117 100644 --- a/sfa/server/registry.py +++ b/sfa/server/registry.py @@ -3,12 +3,14 @@ # from sfa.server.sfaserver import SfaServer from sfa.server.interface import Interfaces, Interface -from sfa.util.config import Config +from sfa.util.config import Config # # Registry is a SfaServer that serves registry and slice operations at PLC. # this truly is a server-side object # + + class Registry(SfaServer): ## # Create a new registry object. @@ -17,11 +19,11 @@ class Registry(SfaServer): # @param port the port to listen on # @param key_file private key filename of registry # @param cert_file certificate filename containing public key (could be a GID file) - + def __init__(self, ip, port, key_file, cert_file): - SfaServer.__init__(self, ip, port, key_file, cert_file,'registry') - sfa_config=Config() - if Config().SFA_REGISTRY_ENABLED: + SfaServer.__init__(self, ip, port, key_file, cert_file, 'registry') + sfa_config = Config() + if Config().SFA_REGISTRY_ENABLED: from sfa.storage.alchemy import engine from sfa.storage.dbschema import DBSchema DBSchema().init_or_upgrade() @@ -30,13 +32,15 @@ class Registry(SfaServer): # Registries is a dictionary of registry connections keyed on the registry hrn # as such it's more of a client-side thing for registry servers to reach their peers # + + class Registries(Interfaces): - + default_dict = {'registries': {'registry': [Interfaces.default_fields]}} - def __init__(self, conf_file = "/etc/sfa/registries.xml"): - Interfaces.__init__(self, conf_file) - sfa_config = Config() + def __init__(self, conf_file="/etc/sfa/registries.xml"): + Interfaces.__init__(self, conf_file) + sfa_config = Config() if sfa_config.SFA_REGISTRY_ENABLED: addr = sfa_config.SFA_REGISTRY_HOST port = sfa_config.SFA_REGISTRY_PORT diff --git a/sfa/server/sfa-start.py b/sfa/server/sfa-start.py index 0dd53b6e..bca06ee8 100755 --- a/sfa/server/sfa-start.py +++ b/sfa/server/sfa-start.py @@ -25,9 +25,10 @@ # TODO: Can all three servers use the same "registry" certificate? ## -### xxx todo not in the config yet -component_port=12346 -import os, os.path +# xxx todo not in the config yet +component_port = 12346 +import os +import os.path import traceback import sys from optparse import OptionParser @@ -46,19 +47,25 @@ from sfa.server.aggregate import Aggregates from sfa.client.return_value import ReturnValue # after http://www.erlenstar.demon.co.uk/unix/faq_2.html + + def daemon(): """Daemonize the current process.""" - if os.fork() != 0: os._exit(0) + if os.fork() != 0: + os._exit(0) os.setsid() - if os.fork() != 0: os._exit(0) + if os.fork() != 0: + os._exit(0) os.umask(0) devnull = os.open(os.devnull, os.O_RDWR) os.dup2(devnull, 0) - # xxx fixme - this is just to make sure that nothing gets stupidly lost - should use devnull - logdir='/var/log/httpd' + # xxx fixme - this is just to make sure that nothing gets stupidly lost - + # should use devnull + logdir = '/var/log/httpd' # when installed in standalone we might not have httpd installed - if not os.path.isdir(logdir): os.mkdir('/var/log/httpd') - crashlog = os.open('%s/sfa_access_log'%logdir, os.O_RDWR | os.O_APPEND | os.O_CREAT, 0644) + if not os.path.isdir(logdir): + os.mkdir('/var/log/httpd') + crashlog = os.open('%s/sfa_access_log' % logdir, os.O_RDWR | os.O_APPEND | os.O_CREAT, 0644) os.dup2(crashlog, 1) os.dup2(crashlog, 2) @@ -72,7 +79,7 @@ def install_peer_certs(server_key_file, server_cert_file): # There should be a gid file in /etc/sfa/trusted_roots for every # peer registry found in in the registries.xml config file. If there # are any missing gids, request a new one from the peer registry. - api = SfaApi(key_file = server_key_file, cert_file = server_cert_file) + api = SfaApi(key_file=server_key_file, cert_file=server_cert_file) registries = Registries() aggregates = Aggregates() interfaces = dict(registries.items() + aggregates.items()) @@ -83,23 +90,27 @@ def install_peer_certs(server_key_file, server_cert_file): #gids = self.get_peer_gids(new_hrns) + gids_current peer_gids = [] if not new_hrns: - return + return trusted_certs_dir = api.config.get_trustedroots_dir() for new_hrn in new_hrns: - if not new_hrn: continue + if not new_hrn: + continue # the gid for this interface should already be installed - if new_hrn == api.config.SFA_INTERFACE_HRN: continue + if new_hrn == api.config.SFA_INTERFACE_HRN: + continue try: # get gid from the registry url = interfaces[new_hrn].get_url() - interface = interfaces[new_hrn].server_proxy(server_key_file, server_cert_file, timeout=30) + interface = interfaces[new_hrn].server_proxy( + server_key_file, server_cert_file, timeout=30) # skip non sfa aggregates server_version = api.get_cached_server_version(interface) if 'sfa' not in server_version: - logger.info("get_trusted_certs: skipping non sfa aggregate: %s" % new_hrn) + logger.info( + "get_trusted_certs: skipping non sfa aggregate: %s" % new_hrn) continue - + trusted_gids = ReturnValue.get_value(interface.get_trusted_certs()) if trusted_gids: # the gid we want should be the first one in the list, @@ -112,18 +123,20 @@ def install_peer_certs(server_key_file, server_cert_file): gid = GID(string=trusted_gid) peer_gids.append(gid) if gid.get_hrn() == new_hrn: - gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn) + gid_filename = os.path.join( + trusted_certs_dir, '%s.gid' % new_hrn) gid.save_to_file(gid_filename, save_parents=True) message = "installed trusted cert for %s" % new_hrn # log the message api.logger.info(message) except: message = "interface: %s\tunable to install trusted gid for %s" % \ - (api.interface, new_hrn) + (api.interface, new_hrn) api.logger.log_exc(message) # doesnt matter witch one update_cert_records(peer_gids) + def update_cert_records(gids): """ Make sure there is a record in the registry for the specified gids. @@ -144,42 +157,43 @@ def update_cert_records(gids): # remove old records for record in records_found: if record.hrn not in hrns_expected and \ - record.hrn != self.api.config.SFA_INTERFACE_HRN: + record.hrn != self.api.config.SFA_INTERFACE_HRN: dbsession.delete(record) - # TODO: store urn in the db so we do this in 1 query + # TODO: store urn in the db so we do this in 1 query for gid in gids: hrn, type = gid.get_hrn(), gid.get_type() - record = dbsession.query(RegRecord).filter_by(hrn=hrn, type=type,pointer=-1).first() + record = dbsession.query(RegRecord).filter_by( + hrn=hrn, type=type, pointer=-1).first() if not record: - record = RegRecord (dict= {'type':type, - 'hrn': hrn, - 'authority': get_authority(hrn), - 'gid': gid.save_to_string(save_parents=True), - }) + record = RegRecord(dict={'type': type, + 'hrn': hrn, + 'authority': get_authority(hrn), + 'gid': gid.save_to_string(save_parents=True), + }) dbsession.add(record) dbsession.commit() - + + def main(): # Generate command line parser parser = OptionParser(usage="sfa-start.py [options]") parser.add_option("-r", "--registry", dest="registry", action="store_true", - help="run registry server", default=False) + help="run registry server", default=False) parser.add_option("-s", "--slicemgr", dest="sm", action="store_true", - help="run slice manager", default=False) + help="run slice manager", default=False) parser.add_option("-a", "--aggregate", dest="am", action="store_true", - help="run aggregate manager", default=False) + help="run aggregate manager", default=False) parser.add_option("-c", "--component", dest="cm", action="store_true", - help="run component server", default=False) + help="run component server", default=False) parser.add_option("-t", "--trusted-certs", dest="trusted_certs", action="store_true", - help="refresh trusted certs", default=False) + help="refresh trusted certs", default=False) parser.add_option("-d", "--daemon", dest="daemon", action="store_true", - help="Run as daemon.", default=False) + help="Run as daemon.", default=False) (options, args) = parser.parse_args() - + config = Config() logger.setLevelFromOptVerbose(config.SFA_API_LOGLEVEL) - # ge the server's key and cert hierarchy = Hierarchy() @@ -190,20 +204,23 @@ def main(): # ensure interface cert is present in trusted roots dir trusted_roots = TrustedRoots(config.get_trustedroots_dir()) trusted_roots.add_gid(GID(filename=server_cert_file)) - if (options.daemon): daemon() - + if (options.daemon): + daemon() + if options.trusted_certs: - install_peer_certs(server_key_file, server_cert_file) - + install_peer_certs(server_key_file, server_cert_file) + # start registry server if (options.registry): from sfa.server.registry import Registry - r = Registry("", config.SFA_REGISTRY_PORT, server_key_file, server_cert_file) + r = Registry("", config.SFA_REGISTRY_PORT, + server_key_file, server_cert_file) r.start() if (options.am): from sfa.server.aggregate import Aggregate - a = Aggregate("", config.SFA_AGGREGATE_PORT, server_key_file, server_cert_file) + a = Aggregate("", config.SFA_AGGREGATE_PORT, + server_key_file, server_cert_file) a.start() # start slice manager @@ -214,7 +231,8 @@ def main(): if (options.cm): from sfa.server.component import Component - c = Component("", config.component_port, server_key_file, server_cert_file) + c = Component("", config.component_port, + server_key_file, server_cert_file) # c = Component("", config.SFA_COMPONENT_PORT, server_key_file, server_cert_file) c.start() diff --git a/sfa/server/sfa_component_setup.py b/sfa/server/sfa_component_setup.py index e35a40a6..236d253a 100755 --- a/sfa/server/sfa_component_setup.py +++ b/sfa/server/sfa_component_setup.py @@ -21,9 +21,11 @@ from sfa.planetlab.plxrn import hrn_to_pl_slicename, slicename_to_hrn KEYDIR = "/var/lib/sfa/" CONFDIR = "/etc/sfa/" + def handle_gid_mismatch_exception(f): def wrapper(*args, **kwds): - try: return f(*args, **kwds) + try: + return f(*args, **kwds) except ConnectionKeyGIDMismatch: # clean regen server keypair and try again print("cleaning keys and trying again") @@ -32,14 +34,15 @@ def handle_gid_mismatch_exception(f): return wrapper -def server_proxy(url=None, port=None, keyfile=None, certfile=None,verbose=False): + +def server_proxy(url=None, port=None, keyfile=None, certfile=None, verbose=False): """ returns an xmlrpc connection to the service a the specified address """ if url: url_parts = url.split(":") - if len(url_parts) >1: + if len(url_parts) > 1: pass else: url = "http://%(url)s:%(port)s" % locals() @@ -53,8 +56,8 @@ def server_proxy(url=None, port=None, keyfile=None, certfile=None,verbose=False) print("Contacting registry at: %(url)s" % locals()) server = SfaServerProxy(url, keyfile, certfile) - return server - + return server + def create_default_dirs(): config = Config() @@ -67,9 +70,11 @@ def create_default_dirs(): if not os.path.exists(dir): os.makedirs(dir) + def has_node_key(): key_file = KEYDIR + os.sep + 'server.key' - return os.path.exists(key_file) + return os.path.exists(key_file) + def clean_key_cred(): """ @@ -80,17 +85,17 @@ def clean_key_cred(): filepath = KEYDIR + os.sep + f if os.path.isfile(filepath): os.unlink(f) - + # install the new key pair # GetCredential will take care of generating the new keypair - # and credential + # and credential GetCredential() - - + + def get_node_key(registry=None, verbose=False): - # this call requires no authentication, + # this call requires no authentication, # so we can generate a random keypair here - subject="component" + subject = "component" (kfd, keyfile) = tempfile.mkstemp() (cfd, certfile) = tempfile.mkstemp() key = Keypair(create=True) @@ -100,10 +105,11 @@ def get_node_key(registry=None, verbose=False): cert.set_pubkey(key) cert.sign() cert.save_to_file(certfile) - - registry = server_proxy(url = registry, keyfile=keyfile, certfile=certfile) + + registry = server_proxy(url=registry, keyfile=keyfile, certfile=certfile) registry.get_key_from_incoming_ip() + def create_server_keypair(keyfile=None, certfile=None, hrn="component", verbose=False): """ create the server key/cert pair in the right place @@ -114,49 +120,52 @@ def create_server_keypair(keyfile=None, certfile=None, hrn="component", verbose= cert.set_issuer(key=key, subject=hrn) cert.set_pubkey(key) cert.sign() - cert.save_to_file(certfile, save_parents=True) + cert.save_to_file(certfile, save_parents=True) + @handle_gid_mismatch_exception def GetCredential(registry=None, force=False, verbose=False): config = Config() hierarchy = Hierarchy() - key_dir= hierarchy.basedir + key_dir = hierarchy.basedir data_dir = config.data_path config_dir = config.config_path credfile = data_dir + os.sep + 'node.cred' # check for existing credential if not force and os.path.exists(credfile): if verbose: - print("Loading Credential from %(credfile)s " % locals()) + print("Loading Credential from %(credfile)s " % locals()) cred = Credential(filename=credfile).save_to_string(save_parents=True) else: if verbose: - print("Getting credential from registry") + print("Getting credential from registry") # make sure node private key exists node_pkey_file = config_dir + os.sep + "node.key" node_gid_file = config_dir + os.sep + "node.gid" if not os.path.exists(node_pkey_file) or \ not os.path.exists(node_gid_file): get_node_key(registry=registry, verbose=verbose) - + gid = GID(filename=node_gid_file) hrn = gid.get_hrn() # create server key and certificate - keyfile =data_dir + os.sep + "server.key" + keyfile = data_dir + os.sep + "server.key" certfile = data_dir + os.sep + "server.cert" key = Keypair(filename=node_pkey_file) key.save_to_file(keyfile) create_server_keypair(keyfile, certfile, hrn, verbose) - # get credential from registry - registry = server_proxy(url=registry, keyfile=keyfile, certfile=certfile) + # get credential from registry + registry = server_proxy( + url=registry, keyfile=keyfile, certfile=certfile) cert = Certificate(filename=certfile) cert_str = cert.save_to_string(save_parents=True) cred = registry.GetSelfCredential(cert_str, 'node', hrn) Credential(string=cred).save_to_file(credfile, save_parents=True) - + return cred + @handle_gid_mismatch_exception def get_trusted_certs(registry=None, verbose=False): """ @@ -175,13 +184,14 @@ def get_trusted_certs(registry=None, verbose=False): # get credential cred = GetCredential(registry=registry, verbose=verbose) # make sure server key cert pair exists - create_server_keypair(keyfile=keyfile, certfile=certfile, hrn=hrn, verbose=verbose) + create_server_keypair( + keyfile=keyfile, certfile=certfile, hrn=hrn, verbose=verbose) registry = server_proxy(url=registry, keyfile=keyfile, certfile=certfile) # get the trusted certs and save them in the right place if verbose: print("Getting trusted certs from registry") trusted_certs = registry.get_trusted_certs(cred) - trusted_gid_names = [] + trusted_gid_names = [] for gid_str in trusted_certs: gid = GID(string=gid_str) gid.decode() @@ -189,7 +199,7 @@ def get_trusted_certs(registry=None, verbose=False): trusted_gid_names.append(relative_filename) gid_filename = trusted_certs_dir + os.sep + relative_filename if verbose: - print("Writing GID for %s as %s" % (gid.get_hrn(), gid_filename)) + print("Writing GID for %s as %s" % (gid.get_hrn(), gid_filename)) gid.save_to_file(gid_filename, save_parents=True) # remove old certs @@ -198,7 +208,8 @@ def get_trusted_certs(registry=None, verbose=False): if gid_name not in trusted_gid_names: if verbose: print("Removing old gid ", gid_name) - os.unlink(trusted_certs_dir + os.sep + gid_name) + os.unlink(trusted_certs_dir + os.sep + gid_name) + @handle_gid_mismatch_exception def get_gids(registry=None, verbose=False): @@ -220,14 +231,15 @@ def get_gids(registry=None, verbose=False): # get credential cred = GetCredential(registry=registry, verbose=verbose) # make sure server key cert pair exists - create_server_keypair(keyfile=keyfile, certfile=certfile, hrn=hrn, verbose=verbose) + create_server_keypair( + keyfile=keyfile, certfile=certfile, hrn=hrn, verbose=verbose) registry = server_proxy(url=registry, keyfile=keyfile, certfile=certfile) - + if verbose: print("Getting current slices on this node") # get a list of slices on this node from sfa.generic import Generic - generic=Generic.the_flavour() + generic = Generic.the_flavour() api = generic.make_api(interface='component') xids_tuple = api.driver.nodemanager.GetXIDs() slices = eval(xids_tuple[1]) @@ -237,19 +249,19 @@ def get_gids(registry=None, verbose=False): slices_without_gids = [] for slicename in slicenames: if not os.path.isfile("/vservers/%s/etc/slice.gid" % slicename) \ - or not os.path.isfile("/vservers/%s/etc/node.gid" % slicename): - slices_without_gids.append(slicename) - + or not os.path.isfile("/vservers/%s/etc/node.gid" % slicename): + slices_without_gids.append(slicename) + # convert slicenames to hrns - hrns = [slicename_to_hrn(interface_hrn, slicename) \ + hrns = [slicename_to_hrn(interface_hrn, slicename) for slicename in slices_without_gids] - + # exit if there are no gids to install if not hrns: return - + if verbose: - print("Getting gids for slices on this node from registry") + print("Getting gids for slices on this node from registry") # get the gids # and save them in the right palce records = registry.GetGids(hrns, cred) @@ -261,7 +273,7 @@ def get_gids(registry=None, verbose=False): # if this slice isnt really instatiated skip it if not os.path.exists("/vservers/%(slicename)s" % locals()): continue - + # save the slice gid in /etc/sfa/ in the vservers filesystem vserver_path = "/vservers/%(slicename)s" % locals() gid = record['gid'] @@ -273,8 +285,8 @@ def get_gids(registry=None, verbose=False): node_gid_filename = os.sep.join([vserver_path, "etc", "node.gid"]) if verbose: print("Saving node GID for %(slicename)s as %(node_gid_filename)s" % locals()) - node_gid.save_to_file(node_gid_filename, save_parents=True) - + node_gid.save_to_file(node_gid_filename, save_parents=True) + def dispatch(options, args): @@ -287,26 +299,27 @@ def dispatch(options, args): if options.verbose: print("Getting the component's trusted certs") get_trusted_certs(verbose=options.verbose) - if options.gids: + if options.gids: if options.verbose: print("Geting the component's GIDs") get_gids(verbose=options.verbose) + def main(): args = sys.argv prog_name = args[0] parser = OptionParser(usage="%(prog_name)s [options]" % locals()) parser.add_option("-v", "--verbose", dest="verbose", action="store_true", - default=False, help="Be verbose") + default=False, help="Be verbose") parser.add_option("-r", "--registry", dest="registry", default=None, - help="Url of registry to contact") - parser.add_option("-k", "--key", dest="key", action="store_true", - default=False, - help="Get the node's pkey from the registry") + help="Url of registry to contact") + parser.add_option("-k", "--key", dest="key", action="store_true", + default=False, + help="Get the node's pkey from the registry") parser.add_option("-c", "--certs", dest="certs", action="store_true", default=False, help="Get the trusted certs from the registry") - parser.add_option("-g", "--gids", dest="gids", action="store_true", + parser.add_option("-g", "--gids", dest="gids", action="store_true", default=False, help="Get gids for all the slices on the component") @@ -315,4 +328,4 @@ def main(): dispatch(options, args) if __name__ == '__main__': - main() + main() diff --git a/sfa/server/sfaapi.py b/sfa/server/sfaapi.py index 499e3bdb..a2950762 100644 --- a/sfa/server/sfaapi.py +++ b/sfa/server/sfaapi.py @@ -1,4 +1,5 @@ -import os, os.path +import os +import os.path import datetime from sfa.util.faults import SfaFault, SfaAPIError, RecordNotFound @@ -17,7 +18,9 @@ from sfa.client.return_value import ReturnValue from sfa.storage.alchemy import alchemy #################### -class SfaApi (XmlrpcApi): + + +class SfaApi (XmlrpcApi): """ An SfaApi instance is a basic xmlrpc service augmented with the local cryptographic material and hrn @@ -35,13 +38,13 @@ class SfaApi (XmlrpcApi): (*) an instance of a testbed driver """ - def __init__ (self, encoding="utf-8", methods='sfa.methods', - config = "/etc/sfa/sfa_config", - peer_cert = None, interface = None, - key_file = None, cert_file = None, cache = None): - - XmlrpcApi.__init__ (self, encoding) - + def __init__(self, encoding="utf-8", methods='sfa.methods', + config="/etc/sfa/sfa_config", + peer_cert=None, interface=None, + key_file=None, cert_file=None, cache=None): + + XmlrpcApi.__init__(self, encoding) + # we may be just be documenting the API if config is None: return @@ -61,22 +64,22 @@ class SfaApi (XmlrpcApi): # load registries from sfa.server.registry import Registries - self.registries = Registries() + self.registries = Registries() # load aggregates from sfa.server.aggregate import Aggregates self.aggregates = Aggregates() - + # filled later on by generic/Generic - self.manager=None - self._dbsession=None + self.manager = None + self._dbsession = None def server_proxy(self, interface, cred, timeout=30): """ Returns a connection to the specified interface. Use the specified credential to determine the caller and look for the caller's key/cert in the registry hierarchy cache. - """ + """ from sfa.trust.hierarchy import Hierarchy if not isinstance(cred, Credential): cred_obj = Credential(string=cred) @@ -89,16 +92,17 @@ class SfaApi (XmlrpcApi): cert_file = auth_info.get_gid_filename() server = interface.server_proxy(key_file, cert_file, timeout) return server - + def dbsession(self): if self._dbsession is None: - self._dbsession=alchemy.session() + self._dbsession = alchemy.session() return self._dbsession def close_dbsession(self): - if self._dbsession is None: return + if self._dbsession is None: + return alchemy.close_session(self._dbsession) - self._dbsession=None + self._dbsession = None def getCredential(self, minimumExpiration=0): """ @@ -107,10 +111,10 @@ class SfaApi (XmlrpcApi): type = 'authority' path = self.config.SFA_DATA_DIR filename = ".".join([self.interface, self.hrn, type, "cred"]) - cred_filename = os.path.join(path,filename) + cred_filename = os.path.join(path, filename) cred = None if os.path.isfile(cred_filename): - cred = Credential(filename = cred_filename) + cred = Credential(filename=cred_filename) # make sure cred isnt expired if not cred.get_expiration or \ datetime.datetime.utcnow() + datetime.timedelta(seconds=minimumExpiration) < cred.get_expiration(): @@ -118,41 +122,42 @@ class SfaApi (XmlrpcApi): # get a new credential if self.interface in ['registry']: - cred = self._getCredentialRaw() + cred = self._getCredentialRaw() else: - cred = self._getCredential() + cred = self._getCredential() cred.save_to_file(cred_filename, save_parents=True) return cred.save_to_string(save_parents=True) - def getDelegatedCredential(self, creds): """ Attempt to find a credential delegated to us in the specified list of creds. """ from sfa.trust.hierarchy import Hierarchy - if creds and not isinstance(creds, list): + if creds and not isinstance(creds, list): creds = [creds] hierarchy = Hierarchy() - + delegated_cred = None for cred in creds: if hierarchy.auth_exists(Credential(cred=cred).get_gid_caller().get_hrn()): delegated_cred = cred break return delegated_cred - + def _getCredential(self): """ Get our credential from a remote registry """ from sfa.server.registry import Registries registries = Registries() - registry = registries.server_proxy(self.hrn, self.key_file, self.cert_file) - cert_string=self.cert.save_to_string(save_parents=True) + registry = registries.server_proxy( + self.hrn, self.key_file, self.cert_file) + cert_string = self.cert.save_to_string(save_parents=True) # get self credential - self_cred = registry.GetSelfCredential(cert_string, self.hrn, 'authority') + self_cred = registry.GetSelfCredential( + cert_string, self.hrn, 'authority') # get credential cred = registry.GetCredential(self_cred, self.hrn, 'authority') return Credential(string=cred) @@ -164,32 +169,34 @@ class SfaApi (XmlrpcApi): hrn = self.hrn auth_hrn = self.auth.get_authority(hrn) - + # is this a root or sub authority if not auth_hrn or hrn == self.config.SFA_INTERFACE_HRN: auth_hrn = hrn auth_info = self.auth.get_auth_info(auth_hrn) # xxx although unlikely we might want to check for a potential leak - dbsession=self.dbsession() + dbsession = self.dbsession() from sfa.storage.model import RegRecord - record = dbsession.query(RegRecord).filter_by(type='authority+sa', hrn=hrn).first() + record = dbsession.query(RegRecord).filter_by( + type='authority+sa', hrn=hrn).first() if not record: raise RecordNotFound(hrn) type = record.type object_gid = record.get_gid_object() - new_cred = Credential(subject = object_gid.get_subject()) + new_cred = Credential(subject=object_gid.get_subject()) new_cred.set_gid_caller(object_gid) new_cred.set_gid_object(object_gid) - new_cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) - + new_cred.set_issuer_keys( + auth_info.get_privkey_filename(), auth_info.get_gid_filename()) + r1 = determine_rights(type, hrn) new_cred.set_privileges(r1) new_cred.encode() new_cred.sign() return new_cred - - def loadCredential (self): + + def loadCredential(self): """ Attempt to load credential from file if it exists. If it doesnt get credential from registry. @@ -199,9 +206,9 @@ class SfaApi (XmlrpcApi): # XX This is really the aggregate's credential. Using this is easier than getting # the registry's credential from iteslf (ssl errors). filename = self.interface + self.hrn + ".ma.cred" - ma_cred_path = os.path.join(self.config.SFA_DATA_DIR,filename) + ma_cred_path = os.path.join(self.config.SFA_DATA_DIR, filename) try: - self.credential = Credential(filename = ma_cred_path) + self.credential = Credential(filename=ma_cred_path) except IOError: self.credential = self.getCredentialFromRegistry() @@ -214,20 +221,19 @@ class SfaApi (XmlrpcApi): result = server.GetVersion() server_version = ReturnValue.get_value(result) # cache version for 24 hours - self.cache.add(cache_key, server_version, ttl= 60*60*24) + self.cache.add(cache_key, server_version, ttl=60 * 60 * 24) return server_version - def get_geni_code(self, result): code = { - 'geni_code': GENICODE.SUCCESS, + 'geni_code': GENICODE.SUCCESS, 'am_type': 'sfa', 'am_code': None, } if isinstance(result, SfaFault): code['geni_code'] = result.faultCode - code['am_code'] = result.faultCode - + code['am_code'] = result.faultCode + return code def get_geni_value(self, result): @@ -239,26 +245,25 @@ class SfaApi (XmlrpcApi): def get_geni_output(self, result): output = "" if isinstance(result, SfaFault): - output = result.faultString + output = result.faultString return output def prepare_response_am(self, result): - version = version_core() + version = version_core() response = { - 'geni_api': 3, + 'geni_api': 3, 'code': self.get_geni_code(result), 'value': self.get_geni_value(result), 'output': self.get_geni_output(result), } return response - + def prepare_response(self, result, method=""): """ Converts the specified result into a standard GENI compliant response """ # as of dec 13 2011 we only support API v2 - if self.interface.lower() in ['aggregate', 'slicemgr']: + if self.interface.lower() in ['aggregate', 'slicemgr']: result = self.prepare_response_am(result) return XmlrpcApi.prepare_response(self, result, method) - diff --git a/sfa/server/sfaserver.py b/sfa/server/sfaserver.py index 8e7c5f82..223ae774 100644 --- a/sfa/server/sfaserver.py +++ b/sfa/server/sfaserver.py @@ -20,6 +20,7 @@ from sfa.trust.certificate import Keypair, Certificate # the credential, and verify that the user is using the key that matches the # GID supplied in the credential. + class SfaServer(threading.Thread): ## @@ -28,19 +29,20 @@ class SfaServer(threading.Thread): # @param ip the ip address to listen on # @param port the port to listen on # @param key_file private key filename of registry - # @param cert_file certificate filename containing public key + # @param cert_file certificate filename containing public key # (could be a GID file) def __init__(self, ip, port, key_file, cert_file, interface): threading.Thread.__init__(self) - self.key = Keypair(filename = key_file) - self.cert = Certificate(filename = cert_file) + self.key = Keypair(filename=key_file) + self.cert = Certificate(filename=cert_file) #self.server = SecureXMLRPCServer((ip, port), SecureXMLRpcRequestHandler, key_file, cert_file) - self.server = ThreadedServer((ip, int(port)), SecureXMLRpcRequestHandler, key_file, cert_file) - self.server.interface=interface + self.server = ThreadedServer( + (ip, int(port)), SecureXMLRpcRequestHandler, key_file, cert_file) + self.server.interface = interface self.trusted_cert_list = None self.register_functions() - logger.info("Starting SfaServer, interface=%s"%interface) + logger.info("Starting SfaServer, interface=%s" % interface) ## # Register functions that will be served by the XMLRPC server. This @@ -57,9 +59,7 @@ class SfaServer(threading.Thread): return anything ## - # Execute the server, serving requests forever. + # Execute the server, serving requests forever. def run(self): self.server.serve_forever() - - diff --git a/sfa/server/slicemgr.py b/sfa/server/slicemgr.py index 9a7fa4a0..280c2a06 100644 --- a/sfa/server/slicemgr.py +++ b/sfa/server/slicemgr.py @@ -4,16 +4,16 @@ import datetime import time from sfa.server.sfaserver import SfaServer + class SliceMgr(SfaServer): - ## # Create a new slice manager object. # # @param ip the ip address to listen on # @param port the port to listen on # @param key_file private key filename of registry - # @param cert_file certificate filename containing public key (could be a GID file) + # @param cert_file certificate filename containing public key (could be a GID file) - def __init__(self, ip, port, key_file, cert_file, config = "/etc/sfa/sfa_config"): - SfaServer.__init__(self, ip, port, key_file, cert_file,'slicemgr') + def __init__(self, ip, port, key_file, cert_file, config="/etc/sfa/sfa_config"): + SfaServer.__init__(self, ip, port, key_file, cert_file, 'slicemgr') diff --git a/sfa/server/threadedserver.py b/sfa/server/threadedserver.py index 7a26ad27..5d07e69d 100644 --- a/sfa/server/threadedserver.py +++ b/sfa/server/threadedserver.py @@ -18,7 +18,7 @@ from OpenSSL import SSL from sfa.util.sfalogging import logger from sfa.util.config import Config -from sfa.util.cache import Cache +from sfa.util.cache import Cache from sfa.trust.certificate import Certificate from sfa.trust.trustedroots import TrustedRoots from sfa.util.py23 import xmlrpc_client @@ -31,12 +31,12 @@ from sfa.generic import Generic # we have our own authentication spec. Thus we disable several of the normal # prohibitions that OpenSSL places on certificates + def verify_callback(conn, x509, err, depth, preverify): # if the cert has been preverified, then it is ok if preverify: - #print " preverified" - return 1 - + # print " preverified" + return 1 # the certificate verification done by openssl checks a number of things # that we aren't interested in, so we look out for those error messages @@ -46,48 +46,52 @@ def verify_callback(conn, x509, err, depth, preverify): # xxx thierry: this most likely means the cert has a validity range in the future # by newer pl nodes. if err == 9: - #print " X509_V_ERR_CERT_NOT_YET_VALID" - return 1 + # print " X509_V_ERR_CERT_NOT_YET_VALID" + return 1 # allow self-signed certificates if err == 18: - #print " X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT" - return 1 + # print " X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT" + return 1 # allow certs that don't have an issuer if err == 20: - #print " X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY" - return 1 + # print " X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY" + return 1 # allow chained certs with self-signed roots if err == 19: return 1 - + # allow certs that are untrusted if err == 21: - #print " X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE" - return 1 + # print " X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE" + return 1 # allow certs that are untrusted if err == 27: - #print " X509_V_ERR_CERT_UNTRUSTED" - return 1 + # print " X509_V_ERR_CERT_UNTRUSTED" + return 1 # ignore X509_V_ERR_CERT_SIGNATURE_FAILURE if err == 7: - return 1 + return 1 - logger.debug(" error %s in verify_callback"%err) + logger.debug(" error %s in verify_callback" % err) return 0 ## -# taken from the web (XXX find reference). Implements HTTPS xmlrpc request handler +# taken from the web (XXX find reference). Implements HTTPS xmlrpc request +# handler + + class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): """Secure XML-RPC request handler class. It it very similar to SimpleXMLRPCRequestHandler but it uses HTTPS for transporting XML data. """ + def setup(self): self.connection = self.request self.rfile = socket._fileobject(self.request, "rb", self.rbufsize) @@ -101,33 +105,36 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): """ try: peer_cert = Certificate() - peer_cert.load_from_pyopenssl_x509(self.connection.get_peer_certificate()) - generic=Generic.the_flavour() - self.api = generic.make_api (peer_cert = peer_cert, - interface = self.server.interface, - key_file = self.server.key_file, - cert_file = self.server.cert_file, - cache = self.cache) - #logger.info("SecureXMLRpcRequestHandler.do_POST:") - #logger.info("interface=%s"%self.server.interface) - #logger.info("key_file=%s"%self.server.key_file) - #logger.info("api=%s"%self.api) - #logger.info("server=%s"%self.server) - #logger.info("handler=%s"%self) + peer_cert.load_from_pyopenssl_x509( + self.connection.get_peer_certificate()) + generic = Generic.the_flavour() + self.api = generic.make_api(peer_cert=peer_cert, + interface=self.server.interface, + key_file=self.server.key_file, + cert_file=self.server.cert_file, + cache=self.cache) + # logger.info("SecureXMLRpcRequestHandler.do_POST:") + # logger.info("interface=%s"%self.server.interface) + # logger.info("key_file=%s"%self.server.key_file) + # logger.info("api=%s"%self.api) + # logger.info("server=%s"%self.server) + # logger.info("handler=%s"%self) # get arguments request = self.rfile.read(int(self.headers["content-length"])) - remote_addr = (remote_ip, remote_port) = self.connection.getpeername() - self.api.remote_addr = remote_addr - response = self.api.handle(remote_addr, request, self.server.method_map) + remote_addr = ( + remote_ip, remote_port) = self.connection.getpeername() + self.api.remote_addr = remote_addr + response = self.api.handle( + remote_addr, request, self.server.method_map) except Exception as fault: # This should only happen if the module is buggy # internal error, report as HTTP server error logger.log_exc("server.do_POST") response = self.api.prepare_response(fault) - #self.send_response(500) - #self.end_headers() - - # avoid session/connection leaks : do this no matter what + # self.send_response(500) + # self.end_headers() + + # avoid session/connection leaks : do this no matter what finally: self.send_response(200) self.send_header("Content-type", "text/xml") @@ -138,11 +145,13 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): # close db connection self.api.close_dbsession() # shut down the connection - self.connection.shutdown() # Modified here! + self.connection.shutdown() # Modified here! ## # Taken from the web (XXX find reference). Implements an HTTPS xmlrpc server -class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLRPCDispatcher): + + +class SecureXMLRPCServer(BaseHTTPServer.HTTPServer, SimpleXMLRPCServer.SimpleXMLRPCDispatcher): def __init__(self, server_address, HandlerClass, key_file, cert_file, logRequests=True): """ @@ -150,8 +159,8 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR It it very similar to SimpleXMLRPCServer but it uses HTTPS for transporting XML data. """ - logger.debug("SecureXMLRPCServer.__init__, server_address=%s, " - "cert_file=%s, key_file=%s"%(server_address,cert_file,key_file)) + logger.debug("SecureXMLRPCServer.__init__, server_address=%s, " + "cert_file=%s, key_file=%s" % (server_address, cert_file, key_file)) self.logRequests = logRequests self.interface = None self.key_file = key_file @@ -159,22 +168,25 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR self.method_map = {} # add cache to the request handler HandlerClass.cache = Cache() - #for compatibility with python 2.4 (centos53) + # for compatibility with python 2.4 (centos53) if sys.version_info < (2, 5): SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self) else: - SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, True, None) + SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__( + self, True, None) SocketServer.BaseServer.__init__(self, server_address, HandlerClass) ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_privatekey_file(key_file) + ctx.use_privatekey_file(key_file) ctx.use_certificate_file(cert_file) # If you wanted to verify certs against known CAs.. this is how you would do it - #ctx.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid') + # ctx.load_verify_locations('/etc/sfa/trusted_roots/plc.gpo.gid') config = Config() - trusted_cert_files = TrustedRoots(config.get_trustedroots_dir()).get_file_list() + trusted_cert_files = TrustedRoots( + config.get_trustedroots_dir()).get_file_list() for cert_file in trusted_cert_files: ctx.load_verify_locations(cert_file) - ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback) + ctx.set_verify(SSL.VERIFY_PEER | + SSL.VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback) ctx.set_verify_depth(5) ctx.set_app_data(self) self.socket = SSL.Connection(ctx, socket.socket(self.address_family, @@ -188,20 +200,21 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR # the client. def _dispatch(self, method, params): - logger.debug("SecureXMLRPCServer._dispatch, method=%s"%method) + logger.debug("SecureXMLRPCServer._dispatch, method=%s" % method) try: return SimpleXMLRPCServer.SimpleXMLRPCDispatcher._dispatch(self, method, params) except: # can't use format_exc() as it is not available in jython yet # (even in trunk). type, value, tb = sys.exc_info() - raise xmlrpc_client.Fault(1,''.join(traceback.format_exception(type, value, tb))) + raise xmlrpc_client.Fault(1, ''.join( + traceback.format_exception(type, value, tb))) # override this one from the python 2.7 code # originally defined in class TCPServer def shutdown_request(self, request): """Called to shutdown and close an individual request.""" - # ---------- + # ---------- # the std python 2.7 code just attempts a request.shutdown(socket.SHUT_WR) # this works fine with regular sockets # However we are dealing with an instance of OpenSSL.SSL.Connection instead @@ -209,24 +222,28 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR # always perform as expected # ---------- std python 2.7 code try: - #explicitly shutdown. socket.close() merely releases - #the socket and waits for GC to perform the actual close. + # explicitly shutdown. socket.close() merely releases + # the socket and waits for GC to perform the actual close. request.shutdown(socket.SHUT_WR) except socket.error: - pass #some platforms may raise ENOTCONN here + pass # some platforms may raise ENOTCONN here # ---------- except TypeError: - # we are dealing with an OpenSSL.Connection object, + # we are dealing with an OpenSSL.Connection object, # try to shut it down but never mind if that fails - try: request.shutdown() - except: pass + try: + request.shutdown() + except: + pass # ---------- self.close_request(request) -## From Active State code: http://code.activestate.com/recipes/574454/ -# This is intended as a drop-in replacement for the ThreadingMixIn class in -# module SocketServer of the standard lib. Instead of spawning a new thread +# From Active State code: http://code.activestate.com/recipes/574454/ +# This is intended as a drop-in replacement for the ThreadingMixIn class in +# module SocketServer of the standard lib. Instead of spawning a new thread # for each request, requests are processed by of pool of reusable threads. + + class ThreadPoolMixIn(SocketServer.ThreadingMixIn): """ use a thread pool instead of a new thread on every request @@ -245,25 +262,24 @@ class ThreadPoolMixIn(SocketServer.ThreadingMixIn): self.requests = Queue() for x in range(self.numThreads): - t = threading.Thread(target = self.process_request_thread) + t = threading.Thread(target=self.process_request_thread) t.setDaemon(1) t.start() # server main loop while True: self.handle_request() - + self.server_close() - def process_request_thread(self): """ obtain request from queue instead of directly from server socket """ while True: - SocketServer.ThreadingMixIn.process_request_thread(self, *self.requests.get()) + SocketServer.ThreadingMixIn.process_request_thread( + self, *self.requests.get()) - def handle_request(self): """ simply collect requests and put them on the queue for the workers. @@ -275,5 +291,6 @@ class ThreadPoolMixIn(SocketServer.ThreadingMixIn): if self.verify_request(request, client_address): self.requests.put((request, client_address)) + class ThreadedServer(ThreadPoolMixIn, SecureXMLRPCServer): pass diff --git a/sfa/server/xmlrpcapi.py b/sfa/server/xmlrpcapi.py index 74e00266..b4fa748f 100644 --- a/sfa/server/xmlrpcapi.py +++ b/sfa/server/xmlrpcapi.py @@ -28,9 +28,11 @@ from sfa.util.py23 import xmlrpc_client # [#x7F-#x84], [#x86-#x9F], [#xFDD0-#xFDDF] invalid_xml_ascii = map(chr, range(0x0, 0x8) + [0xB, 0xC] + range(0xE, 0x1F)) -xml_escape_table = string.maketrans("".join(invalid_xml_ascii), "?" * len(invalid_xml_ascii)) +xml_escape_table = string.maketrans( + "".join(invalid_xml_ascii), "?" * len(invalid_xml_ascii)) -def xmlrpclib_escape(s, replace = string.replace): + +def xmlrpclib_escape(s, replace=string.replace): """ xmlrpclib does not handle invalid 7-bit control characters. This function augments xmlrpclib.escape, which by default only replaces @@ -45,6 +47,7 @@ def xmlrpclib_escape(s, replace = string.replace): # Replace invalid 7-bit control characters with '?' return s.translate(xml_escape_table) + def xmlrpclib_dump(self, value, write): """ xmlrpclib cannot marshal instances of subclasses of built-in @@ -80,24 +83,26 @@ def xmlrpclib_dump(self, value, write): # the expected behaviour under python3 xmlrpc_client.Marshaller._Marshaller__dump = xmlrpclib_dump + class XmlrpcApi: """ The XmlrpcApi class implements a basic xmlrpc (or soap) service """ protocol = None - - def __init__ (self, encoding="utf-8", methods='sfa.methods'): + + def __init__(self, encoding="utf-8", methods='sfa.methods'): self.encoding = encoding - self.source = None - + self.source = None + # flat list of method names - self.methods_module = methods_module = __import__(methods, fromlist=[methods]) + self.methods_module = methods_module = __import__( + methods, fromlist=[methods]) self.methods = methods_module.all self.logger = logger - + def callable(self, method): """ Return a new instance of the specified method. @@ -105,11 +110,12 @@ class XmlrpcApi: # Look up method if method not in self.methods: raise SfaInvalidAPIMethod(method) - + # Get new instance of method try: classname = method.split(".")[-1] - module = __import__(self.methods_module.__name__ + "." + method, globals(), locals(), [classname]) + module = __import__(self.methods_module.__name__ + + "." + method, globals(), locals(), [classname]) callablemethod = getattr(module, classname)(self) return getattr(module, classname)(self) except (ImportError, AttributeError): @@ -126,7 +132,6 @@ class XmlrpcApi: self.source = source return function(*args) - def handle(self, source, data, method_map): """ Handle an XML-RPC or SOAP request from the specified source. @@ -139,12 +144,13 @@ class XmlrpcApi: if method in method_map: method = method_map[method] methodresponse = True - + except Exception as e: if SOAPpy is not None: self.protocol = 'soap' interface = SOAPpy - (r, header, body, attrs) = parseSOAPRPC(data, header = 1, body = 1, attrs = 1) + (r, header, body, attrs) = parseSOAPRPC( + data, header=1, body=1, attrs=1) method = r._name args = r._aslist() # XXX Support named arguments @@ -155,34 +161,36 @@ class XmlrpcApi: result = self.call(source, method, *args) except SfaFault as fault: result = fault - self.logger.log_exc("XmlrpcApi.handle has caught Exception") + self.logger.log_exc("XmlrpcApi.handle has caught Exception") except Exception as fault: self.logger.log_exc("XmlrpcApi.handle has caught Exception") result = SfaAPIError(fault) - # Return result response = self.prepare_response(result, method) return response - + def prepare_response(self, result, method=""): """ convert result to a valid xmlrpc or soap response - """ - + """ + if self.protocol == 'xmlrpc': if not isinstance(result, SfaFault): result = (result,) - response = xmlrpc_client.dumps(result, methodresponse = True, encoding = self.encoding, allow_none = 1) + response = xmlrpc_client.dumps( + result, methodresponse=True, encoding=self.encoding, allow_none=1) elif self.protocol == 'soap': if isinstance(result, Exception): - result = faultParameter(NS.ENV_T + ":Server", "Method Failed", method) - result._setDetail("Fault %d: %s" % (result.faultCode, result.faultString)) + result = faultParameter( + NS.ENV_T + ":Server", "Method Failed", method) + result._setDetail("Fault %d: %s" % + (result.faultCode, result.faultString)) else: - response = buildSOAP(kw = {'%sResponse' % method: {'Result': result}}, encoding = self.encoding) + response = buildSOAP( + kw={'%sResponse' % method: {'Result': result}}, encoding=self.encoding) else: if isinstance(result, Exception): - raise result - - return response + raise result + return response diff --git a/sfa/storage/alchemy.py b/sfa/storage/alchemy.py index 64c39cf8..de2e55d9 100644 --- a/sfa/storage/alchemy.py +++ b/sfa/storage/alchemy.py @@ -9,72 +9,75 @@ from sfa.util.sfalogging import logger # this module is designed to be loaded when the configured db server is reachable # OTOH model can be loaded from anywhere including the client-side + class Alchemy: - def __init__ (self, config): + def __init__(self, config): dbname = "sfa" # will be created lazily on-demand self._session = None # the former PostgreSQL.py used the psycopg2 directly and was doing - #self.connection.set_client_encoding("UNICODE") + # self.connection.set_client_encoding("UNICODE") # it's unclear how to achieve this in sqlalchemy, nor if it's needed at all # http://www.sqlalchemy.org/docs/dialects/postgresql.html#unicode # we indeed have /var/lib/pgsql/data/postgresql.conf where # this setting is unset, it might be an angle to tweak that if need be # try a unix socket first - omitting the hostname does the trick - unix_url = "postgresql+psycopg2://%s:%s@:%s/%s"%\ - (config.SFA_DB_USER,config.SFA_DB_PASSWORD,config.SFA_DB_PORT,dbname) + unix_url = "postgresql+psycopg2://%s:%s@:%s/%s" %\ + (config.SFA_DB_USER, config.SFA_DB_PASSWORD, config.SFA_DB_PORT, dbname) # the TCP fallback method - tcp_url = "postgresql+psycopg2://%s:%s@%s:%s/%s"%\ - (config.SFA_DB_USER,config.SFA_DB_PASSWORD,config.SFA_DB_HOST,config.SFA_DB_PORT,dbname) - for url in [ unix_url, tcp_url ] : + tcp_url = "postgresql+psycopg2://%s:%s@%s:%s/%s" %\ + (config.SFA_DB_USER, config.SFA_DB_PASSWORD, + config.SFA_DB_HOST, config.SFA_DB_PORT, dbname) + for url in [unix_url, tcp_url]: try: - logger.debug("Trying db URL %s"%url) - self.engine = create_engine (url) + logger.debug("Trying db URL %s" % url) + self.engine = create_engine(url) self.check() - self.url=url + self.url = url return except: pass - self.engine=None - raise Exception("Could not connect to database %s as %s with psycopg2"%(dbname,config.SFA_DB_USER)) - + self.engine = None + raise Exception("Could not connect to database %s as %s with psycopg2" % ( + dbname, config.SFA_DB_USER)) # expects boolean True: debug is ON or False: debug is OFF - def debug (self, echo): - self.engine.echo=echo + def debug(self, echo): + self.engine.echo = echo - def check (self): - self.engine.execute ("select 1").scalar() + def check(self): + self.engine.execute("select 1").scalar() - def global_session (self): + def global_session(self): if self._session is None: - Session=sessionmaker () - self._session=Session(bind=self.engine) - logger.debug('alchemy.global_session created session %s'%self._session) + Session = sessionmaker() + self._session = Session(bind=self.engine) + logger.debug('alchemy.global_session created session %s' % + self._session) return self._session - def close_global_session (self): - if self._session is None: return - logger.debug('alchemy.close_global_session %s'%self._session) + def close_global_session(self): + if self._session is None: + return + logger.debug('alchemy.close_global_session %s' % self._session) self._session.close() - self._session=None + self._session = None # create a dbsession to be managed separately - def session (self): - Session=sessionmaker() - session=Session (bind=self.engine) - logger.debug('alchemy.session created session %s'%session) + def session(self): + Session = sessionmaker() + session = Session(bind=self.engine) + logger.debug('alchemy.session created session %s' % session) return session - def close_session (self, session): - logger.debug('alchemy.close_session closed session %s'%session) + def close_session(self, session): + logger.debug('alchemy.close_session closed session %s' % session) session.close() #################### from sfa.util.config import Config -alchemy=Alchemy (Config()) -engine=alchemy.engine -global_dbsession=alchemy.global_session() - +alchemy = Alchemy(Config()) +engine = alchemy.engine +global_dbsession = alchemy.global_session() diff --git a/sfa/storage/dbschema.py b/sfa/storage/dbschema.py index 5166b93c..1bed5e04 100644 --- a/sfa/storage/dbschema.py +++ b/sfa/storage/dbschema.py @@ -11,11 +11,11 @@ import migrate.versioning.api as migrate from sfa.util.sfalogging import logger import sfa.storage.model as model -########## this class takes care of database upgrades -### upgrade from a pre-2.1 db -# * 1.0 and up to 1.1-4: ('very old') +# this class takes care of database upgrades +# upgrade from a pre-2.1 db +# * 1.0 and up to 1.1-4: ('very old') # was piggybacking the planetlab5 database -# this is kind of out of our scope here, we don't have the credentials +# this is kind of out of our scope here, we don't have the credentials # to connect to planetlab5, but this is documented in # https://svn.planet-lab.org/wiki/SFATutorialConfigureSFA#Upgradingnotes # and essentially this is seamless to users @@ -27,106 +27,115 @@ import sfa.storage.model as model # we have an 'records' table, plus 'users' and the like # and once migrate has kicked in there is a table named (see migrate.cfg) # migrate_db_version (repository_id, repository_path, version) -### after 2.1 +# after 2.1 # Starting with 2.1, we use sqlalchemy-migrate scripts in a standard way -# Note that the model defined in sfa.storage.model needs to be maintained -# as the 'current/latest' version, and newly installed deployments will +# Note that the model defined in sfa.storage.model needs to be maintained +# as the 'current/latest' version, and newly installed deployments will # then 'jump' to the latest version number without going through the migrations ### -# An initial attempt to run this as a 001_*.py migrate script +# An initial attempt to run this as a 001_*.py migrate script # did not quite work out (essentially we need to set the current version # number out of the migrations logic) # also this approach has less stuff in the initscript, which seems just right + class DBSchema: - header="Upgrading to 2.1 or higher" + header = "Upgrading to 2.1 or higher" - def __init__ (self): + def __init__(self): from sfa.storage.alchemy import alchemy - self.url=alchemy.url - self.engine=alchemy.engine - self.repository="/usr/share/sfa/migrations" + self.url = alchemy.url + self.engine = alchemy.engine + self.repository = "/usr/share/sfa/migrations" - def current_version (self): + def current_version(self): try: - return migrate.db_version (self.url, self.repository) + return migrate.db_version(self.url, self.repository) except: return None - def table_exists (self, tablename): + def table_exists(self, tablename): try: - metadata = MetaData (bind=self.engine) - table=Table (tablename, metadata, autoload=True) + metadata = MetaData(bind=self.engine) + table = Table(tablename, metadata, autoload=True) return True except NoSuchTableError: return False - def drop_table (self, tablename): - if self.table_exists (tablename): - print("%s: Dropping table %s"%(DBSchema.header,tablename), file=sys.stderr) - self.engine.execute ("drop table %s cascade"%tablename) + def drop_table(self, tablename): + if self.table_exists(tablename): + print("%s: Dropping table %s" % + (DBSchema.header, tablename), file=sys.stderr) + self.engine.execute("drop table %s cascade" % tablename) else: - print("%s: no need to drop table %s"%(DBSchema.header,tablename), file=sys.stderr) - - def handle_old_releases (self): + print("%s: no need to drop table %s" % + (DBSchema.header, tablename), file=sys.stderr) + + def handle_old_releases(self): try: # try to find out which old version this can be - if not self.table_exists ('records'): - # this likely means + if not self.table_exists('records'): + # this likely means # (.) we've just created the db, so it's either a fresh install, or # (.) we come from a 'very old' depl. - # in either case, an import is required but there's nothing to clean up - print("%s: make sure to run import"%(DBSchema.header,), file=sys.stderr) - elif self.table_exists ('sfa_db_version'): + # in either case, an import is required but there's nothing to + # clean up + print("%s: make sure to run import" % + (DBSchema.header,), file=sys.stderr) + elif self.table_exists('sfa_db_version'): # we come from an 'old' version - self.drop_table ('records') - self.drop_table ('record_types') - self.drop_table ('sfa_db_version') + self.drop_table('records') + self.drop_table('record_types') + self.drop_table('sfa_db_version') else: # we should be good here pass except: - print("%s: unknown exception"%(DBSchema.header,), file=sys.stderr) - traceback.print_exc () + print("%s: unknown exception" % + (DBSchema.header,), file=sys.stderr) + traceback.print_exc() - # after this call the db schema and the version as known by migrate should + # after this call the db schema and the version as known by migrate should # reflect the current data model and the latest known version - def init_or_upgrade (self): + def init_or_upgrade(self): # check if under version control, and initialize it otherwise if self.current_version() is None: - before="Unknown" + before = "Unknown" # can be either a very old version, or a fresh install # for very old versions: self.handle_old_releases() - # in any case, initialize db from current code and reflect in migrate + # in any case, initialize db from current code and reflect in + # migrate model.init_tables(self.engine) - code_version = migrate.version (self.repository) - migrate.version_control (self.url, self.repository, code_version) - after="%s"%self.current_version() - logger.info("DBSchema : jumped to version %s"%(after)) + code_version = migrate.version(self.repository) + migrate.version_control(self.url, self.repository, code_version) + after = "%s" % self.current_version() + logger.info("DBSchema : jumped to version %s" % (after)) else: # use migrate in the usual way - before="%s"%self.current_version() - migrate.upgrade (self.url, self.repository) - after="%s"%self.current_version() + before = "%s" % self.current_version() + migrate.upgrade(self.url, self.repository) + after = "%s" % self.current_version() if before != after: - logger.info("DBSchema : upgraded version from %s to %s"%(before,after)) + logger.info("DBSchema : upgraded version from %s to %s" % + (before, after)) else: - logger.debug("DBSchema : no change needed in db schema (%s==%s)"%(before,after)) - + logger.debug( + "DBSchema : no change needed in db schema (%s==%s)" % (before, after)) + # this trashes the db altogether, from the current model in sfa.storage.model # I hope this won't collide with ongoing migrations and all - # actually, now that sfa uses its own db, this is essentially equivalent to + # actually, now that sfa uses its own db, this is essentially equivalent to # dropping the db entirely, modulo a 'service sfa start' - def nuke (self): + def nuke(self): model.drop_tables(self.engine) # so in this case it's like we haven't initialized the db at all try: - migrate.drop_version_control (self.url, self.repository) + migrate.drop_version_control(self.url, self.repository) except migrate.exceptions.DatabaseNotControlledError: logger.log_exc("Failed to drop version control") - + if __name__ == '__main__': DBSchema().init_or_upgrade() diff --git a/sfa/storage/migrations/versions/001_slice_researchers.py b/sfa/storage/migrations/versions/001_slice_researchers.py index 2106fc63..05583df4 100644 --- a/sfa/storage/migrations/versions/001_slice_researchers.py +++ b/sfa/storage/migrations/versions/001_slice_researchers.py @@ -1,4 +1,4 @@ -# this move is about adding a slice x users many to many relation ship for modelling +# this move is about adding a slice x users many to many relation ship for modelling # regular "membership" of users in a slice from sqlalchemy import Table, MetaData, Column, ForeignKey @@ -8,21 +8,25 @@ metadata = MetaData() # this is needed by migrate so it can locate 'records.record_id' records = \ - Table ( 'records', metadata, - Column ('record_id', Integer, primary_key=True), - ) + Table('records', metadata, + Column('record_id', Integer, primary_key=True), + ) # slice x user (researchers) association slice_researcher_table = \ - Table ( 'slice_researcher', metadata, - Column ('slice_id', Integer, ForeignKey ('records.record_id'), primary_key=True), - Column ('researcher_id', Integer, ForeignKey ('records.record_id'), primary_key=True), - ) + Table('slice_researcher', metadata, + Column('slice_id', Integer, ForeignKey( + 'records.record_id'), primary_key=True), + Column('researcher_id', Integer, ForeignKey( + 'records.record_id'), primary_key=True), + ) + def upgrade(migrate_engine): metadata.bind = migrate_engine slice_researcher_table.create() + def downgrade(migrate_engine): metadata.bind = migrate_engine slice_researcher_table.drop() diff --git a/sfa/storage/migrations/versions/002_authority_pis.py b/sfa/storage/migrations/versions/002_authority_pis.py index e19a6bcd..8a16e071 100644 --- a/sfa/storage/migrations/versions/002_authority_pis.py +++ b/sfa/storage/migrations/versions/002_authority_pis.py @@ -1,5 +1,6 @@ # this move is about adding a authority x user many to many relation ship for modelling PIs -# that is to say users who can vouch for other users in the authority, and can create slices +# that is to say users who can vouch for other users in the authority, and +# can create slices from sqlalchemy import Table, MetaData, Column, ForeignKey from sqlalchemy import Integer, String @@ -8,21 +9,25 @@ metadata = MetaData() # this is needed by migrate so it can locate 'records.record_id' records = \ - Table ( 'records', metadata, - Column ('record_id', Integer, primary_key=True), - ) + Table('records', metadata, + Column('record_id', Integer, primary_key=True), + ) # authority x user (PIs) association authority_pi_table = \ - Table ( 'authority_pi', metadata, - Column ('authority_id', Integer, ForeignKey ('records.record_id'), primary_key=True), - Column ('pi_id', Integer, ForeignKey ('records.record_id'), primary_key=True), - ) + Table('authority_pi', metadata, + Column('authority_id', Integer, ForeignKey( + 'records.record_id'), primary_key=True), + Column('pi_id', Integer, ForeignKey( + 'records.record_id'), primary_key=True), + ) + def upgrade(migrate_engine): metadata.bind = migrate_engine authority_pi_table.create() + def downgrade(migrate_engine): metadata.bind = migrate_engine authority_pi_table.drop() diff --git a/sfa/storage/migrations/versions/003_sliver_allocations.py b/sfa/storage/migrations/versions/003_sliver_allocations.py index 9dba9e8b..5e8b8754 100644 --- a/sfa/storage/migrations/versions/003_sliver_allocations.py +++ b/sfa/storage/migrations/versions/003_sliver_allocations.py @@ -4,18 +4,20 @@ from sqlalchemy import Integer, String metadata = MetaData() sliver_allocation_table = \ - Table ( 'sliver_allocation', metadata, - Column('sliver_id', String, primary_key=True), - Column('client_id', String), - Column('component_id', String), - Column('slice_urn', String), - Column('allocation_state', String), + Table('sliver_allocation', metadata, + Column('sliver_id', String, primary_key=True), + Column('client_id', String), + Column('component_id', String), + Column('slice_urn', String), + Column('allocation_state', String), ) + def upgrade(migrate_engine): metadata.bind = migrate_engine sliver_allocation_table.create() + def downgrade(migrate_engine): metadata.bind = migrate_engine sliver_allocation_table.drop() diff --git a/sfa/storage/migrations/versions/004_authority_name.py b/sfa/storage/migrations/versions/004_authority_name.py index 55ed1ef9..5c09e530 100644 --- a/sfa/storage/migrations/versions/004_authority_name.py +++ b/sfa/storage/migrations/versions/004_authority_name.py @@ -5,13 +5,15 @@ from sqlalchemy import MetaData, Table, Column, String from migrate.changeset.schema import create_column, drop_column + def upgrade(migrate_engine): - metadata = MetaData(bind = migrate_engine) + metadata = MetaData(bind=migrate_engine) authorities = Table('authorities', metadata, autoload=True) name_column = Column('name', String) name_column.create(authorities) + def downgrade(migrate_engine): - metadata = MetaData(bind = migrate_engine) + metadata = MetaData(bind=migrate_engine) authorities = Table('authorities', metadata, autoload=True) authorities.c.name.drop() diff --git a/sfa/storage/model.py b/sfa/storage/model.py index 923576e9..6cb5f1d7 100644 --- a/sfa/storage/model.py +++ b/sfa/storage/model.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import or_, and_ +from sqlalchemy import or_, and_ from sqlalchemy import Column, Integer, String, DateTime from sqlalchemy import Table, Column, MetaData, join, ForeignKey from sqlalchemy.orm import relationship, backref @@ -12,7 +12,7 @@ from sqlalchemy.ext.declarative import declarative_base from sfa.storage.record import Record from sfa.util.sfalogging import logger from sfa.util.sfatime import utcparse, datetime_to_string -from sfa.util.xml import XML +from sfa.util.xml import XML from sfa.util.py23 import StringType from sfa.trust.gid import GID @@ -27,13 +27,13 @@ Base = declarative_base() # sqlalchemy however offers an object interface, meaning that you write obj.id instead of obj['id'] # which is admittedly much nicer # however we still need to deal with dictionaries if only for the xmlrpc layer -# -# here are a few utilities for this -# +# +# here are a few utilities for this +# # (*) first off, when an old pieve of code needs to be used as-is, if only temporarily, the simplest trick # is to use obj.__dict__ # this behaves exactly like required, i.e. obj.__dict__['field']='new value' does change obj.field -# however this depends on sqlalchemy's implementation so it should be avoided +# however this depends on sqlalchemy's implementation so it should be avoided # # (*) second, when an object needs to be exposed to the xmlrpc layer, we need to convert it into a dict # remember though that writing the resulting dictionary won't change the object @@ -48,15 +48,18 @@ Base = declarative_base() # (*) finally for converting a dictionary into an sqlalchemy object, we provide # obj.load_from_dict(dict) + class AlchemyObj(Record): - def __iter__(self): + + def __iter__(self): self._i = iter(object_mapper(self).columns) - return self - def next(self): + return self + + def next(self): n = self._i.next().name return n, getattr(self, n) -# # only intended for debugging +# # only intended for debugging # def inspect (self, logger, message=""): # logger.info("%s -- Inspecting AlchemyObj -- attrs"%message) # for k in dir(self): @@ -72,43 +75,52 @@ class AlchemyObj(Record): # various kinds of records are implemented as an inheritance hierarchy # RegRecord is the base class for all actual variants # a first draft was using 'type' as the discriminator for the inheritance -# but we had to define another more internal column (classtype) so we +# but we had to define another more internal column (classtype) so we # accomodate variants in types like authority+am and the like class RegRecord(Base, AlchemyObj): - __tablename__ = 'records' - record_id = Column (Integer, primary_key=True) + __tablename__ = 'records' + record_id = Column(Integer, primary_key=True) # this is the discriminator that tells which class to use - classtype = Column (String) + classtype = Column(String) # in a first version type was the discriminator # but that could not accomodate for 'authority+sa' and the like - type = Column (String) - hrn = Column (String) - gid = Column (String) - authority = Column (String) - peer_authority = Column (String) - pointer = Column (Integer, default=-1) - date_created = Column (DateTime) - last_updated = Column (DateTime) + type = Column(String) + hrn = Column(String) + gid = Column(String) + authority = Column(String) + peer_authority = Column(String) + pointer = Column(Integer, default=-1) + date_created = Column(DateTime) + last_updated = Column(DateTime) # use the 'type' column to decide which subclass the object is of - __mapper_args__ = { 'polymorphic_on' : classtype } - - fields = [ 'type', 'hrn', 'gid', 'authority', 'peer_authority' ] - def __init__ (self, type=None, hrn=None, gid=None, authority=None, peer_authority=None, - pointer=None, dict=None): - if type: self.type=type - if hrn: self.hrn=hrn - if gid: - if isinstance(gid, StringType): self.gid=gid - else: self.gid=gid.save_to_string(save_parents=True) - if authority: self.authority=authority - if peer_authority: self.peer_authority=peer_authority - if pointer: self.pointer=pointer - if dict: self.load_from_dict (dict) + __mapper_args__ = {'polymorphic_on': classtype} + + fields = ['type', 'hrn', 'gid', 'authority', 'peer_authority'] + + def __init__(self, type=None, hrn=None, gid=None, authority=None, peer_authority=None, + pointer=None, dict=None): + if type: + self.type = type + if hrn: + self.hrn = hrn + if gid: + if isinstance(gid, StringType): + self.gid = gid + else: + self.gid = gid.save_to_string(save_parents=True) + if authority: + self.authority = authority + if peer_authority: + self.peer_authority = peer_authority + if pointer: + self.pointer = pointer + if dict: + self.load_from_dict(dict) def __repr__(self): - result="", " name={}>".format(self.name)) return result - def update_pis (self, pi_hrns, dbsession): + def update_pis(self, pi_hrns, dbsession): # strip that in case we have words - pi_hrns = [ x.strip() for x in pi_hrns ] + pi_hrns = [x.strip() for x in pi_hrns] request = dbsession.query(RegUser).filter(RegUser.hrn.in_(pi_hrns)) - logger.info("RegAuthority.update_pis: %d incoming pis, %d matches found"\ + logger.info("RegAuthority.update_pis: %d incoming pis, %d matches found" % (len(pi_hrns), request.count())) pis = dbsession.query(RegUser).filter(RegUser.hrn.in_(pi_hrns)).all() self.reg_pis = pis #################### + + class RegSlice(RegRecord): - __tablename__ = 'slices' - __mapper_args__ = { 'polymorphic_identity' : 'slice' } - record_id = Column (Integer, ForeignKey ("records.record_id"), primary_key=True) - #### extensions come here + __tablename__ = 'slices' + __mapper_args__ = {'polymorphic_identity': 'slice'} + record_id = Column(Integer, ForeignKey( + "records.record_id"), primary_key=True) + # extensions come here reg_researchers = relationship \ - ('RegUser', + ('RegUser', secondary=slice_researcher_table, - primaryjoin=RegRecord.record_id==slice_researcher_table.c.slice_id, - secondaryjoin=RegRecord.record_id==slice_researcher_table.c.researcher_id, + primaryjoin=RegRecord.record_id == slice_researcher_table.c.slice_id, + secondaryjoin=RegRecord.record_id == slice_researcher_table.c.researcher_id, backref='reg_slices_as_researcher', - ) + ) - def __init__ (self, **kwds): + def __init__(self, **kwds): if 'type' not in kwds: - kwds['type']='slice' + kwds['type'] = 'slice' RegRecord.__init__(self, **kwds) - def __repr__ (self): + def __repr__(self): return RegRecord.__repr__(self).replace("Record", "Slice") - def update_researchers (self, researcher_hrns, dbsession): + def update_researchers(self, researcher_hrns, dbsession): # strip that in case we have words - researcher_hrns = [ x.strip() for x in researcher_hrns ] - request = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns)) - logger.info ("RegSlice.update_researchers: %d incoming researchers, %d matches found"\ - % (len(researcher_hrns), request.count())) - researchers = dbsession.query (RegUser).filter(RegUser.hrn.in_(researcher_hrns)).all() + researcher_hrns = [x.strip() for x in researcher_hrns] + request = dbsession.query(RegUser).filter( + RegUser.hrn.in_(researcher_hrns)) + logger.info("RegSlice.update_researchers: %d incoming researchers, %d matches found" + % (len(researcher_hrns), request.count())) + researchers = dbsession.query(RegUser).filter( + RegUser.hrn.in_(researcher_hrns)).all() self.reg_researchers = researchers # when dealing with credentials, we need to retrieve the PIs attached to a slice # WARNING: with the move to passing dbsessions around, we face a glitch here because this # helper function is called from the trust/ area that - def get_pis (self): + def get_pis(self): from sqlalchemy.orm import sessionmaker Session = sessionmaker() dbsession = Session.object_session(self) from sfa.util.xrn import get_authority authority_hrn = get_authority(self.hrn) - auth_record = dbsession.query(RegAuthority).filter_by(hrn=authority_hrn).first() + auth_record = dbsession.query( + RegAuthority).filter_by(hrn=authority_hrn).first() return auth_record.reg_pis - - @validates ('expires') - def validate_expires (self, key, incoming): - return self.validate_datetime (key, incoming) + + @validates('expires') + def validate_expires(self, key, incoming): + return self.validate_datetime(key, incoming) #################### + + class RegNode(RegRecord): - __tablename__ = 'nodes' - __mapper_args__ = { 'polymorphic_identity' : 'node' } - record_id = Column (Integer, ForeignKey ("records.record_id"), primary_key=True) - + __tablename__ = 'nodes' + __mapper_args__ = {'polymorphic_identity': 'node'} + record_id = Column(Integer, ForeignKey( + "records.record_id"), primary_key=True) + def __init__(self, **kwds): if 'type' not in kwds: - kwds['type']='node' + kwds['type'] = 'node' RegRecord.__init__(self, **kwds) - def __repr__ (self): + def __repr__(self): return RegRecord.__repr__(self).replace("Record", "Node") #################### + + class RegUser(RegRecord): - __tablename__ = 'users' + __tablename__ = 'users' # these objects will have type='user' in the records table - __mapper_args__ = { 'polymorphic_identity' : 'user' } - record_id = Column (Integer, ForeignKey ("records.record_id"), primary_key=True) - #### extensions come here - email = Column ('email', String) + __mapper_args__ = {'polymorphic_identity': 'user'} + record_id = Column(Integer, ForeignKey( + "records.record_id"), primary_key=True) + # extensions come here + email = Column('email', String) # can't use name 'keys' here because when loading from xml we're getting - # a 'keys' tag, and assigning a list of strings in a reference column like this crashes + # a 'keys' tag, and assigning a list of strings in a reference column like + # this crashes reg_keys = relationship \ ('RegKey', backref='reg_user', - cascade = "all, delete, delete-orphan", - ) - + cascade="all, delete, delete-orphan", + ) + # so we can use RegUser (email=.., hrn=..) and the like - def __init__ (self, **kwds): + def __init__(self, **kwds): # handle local settings if 'email' in kwds: self.email = kwds.pop('email') @@ -309,12 +347,12 @@ class RegUser(RegRecord): RegRecord.__init__(self, **kwds) # append stuff at the end of the record __repr__ - def __repr__ (self): + def __repr__(self): result = RegRecord.__repr__(self).replace("Record", "User") result.replace(">", " email={}>".format(self.email)) return result - @validates('email') + @validates('email') def validate_email(self, key, address): assert '@' in address return address @@ -322,34 +360,39 @@ class RegUser(RegRecord): #################### # xxx tocheck : not sure about eager loading of this one # meaning, when querying the whole records, we expect there should -# be a single query to fetch all the keys -# or, is it enough that we issue a single query to retrieve all the keys +# be a single query to fetch all the keys +# or, is it enough that we issue a single query to retrieve all the keys + + class RegKey(Base): - __tablename__ = 'keys' - key_id = Column (Integer, primary_key=True) - record_id = Column (Integer, ForeignKey ("records.record_id")) - key = Column (String) - pointer = Column (Integer, default = -1) - - def __init__ (self, key, pointer=None): + __tablename__ = 'keys' + key_id = Column(Integer, primary_key=True) + record_id = Column(Integer, ForeignKey("records.record_id")) + key = Column(String) + pointer = Column(Integer, default=-1) + + def __init__(self, key, pointer=None): self.key = key if pointer: self.pointer = pointer - def __repr__ (self): + def __repr__(self): result = ":} -# so after that, an 'authority' record will e.g. have a 'reg-pis' field with the hrns of its pi-users -augment_map = {'authority': {'reg-pis' : 'reg_pis',}, - 'slice': {'reg-researchers' : 'reg_researchers',}, - 'user': {'reg-pi-authorities' : 'reg_authorities_as_pi', - 'reg-slices' : 'reg_slices_as_researcher',}, - } +# so after that, an 'authority' record will e.g. have a 'reg-pis' field +# with the hrns of its pi-users +augment_map = {'authority': {'reg-pis': 'reg_pis', }, + 'slice': {'reg-researchers': 'reg_researchers', }, + 'user': {'reg-pi-authorities': 'reg_authorities_as_pi', + 'reg-slices': 'reg_slices_as_researcher', }, + } # xxx mystery @@ -499,16 +557,18 @@ augment_map = {'authority': {'reg-pis' : 'reg_pis',}, # is what gets exposed to the drivers (this is historical and dates back before sqlalchemy) # so it is recommended to always run this function that will make sure # that such built-in fields are properly set in __dict__ too -# +# def augment_with_sfa_builtins(local_record): # don't ruin the import of that file in a client world from sfa.util.xrn import Xrn # add a 'urn' field - setattr(local_record, 'reg-urn', Xrn(xrn=local_record.hrn, type=local_record.type).urn) - # users have keys and this is needed to synthesize 'users' sent over to CreateSliver + setattr(local_record, 'reg-urn', + Xrn(xrn=local_record.hrn, type=local_record.type).urn) + # users have keys and this is needed to synthesize 'users' sent over to + # CreateSliver fields_to_check = [] if local_record.type == 'user': - user_keys = [ key.key for key in local_record.reg_keys ] + user_keys = [key.key for key in local_record.reg_keys] setattr(local_record, 'reg-keys', user_keys) fields_to_check = ['email'] elif local_record.type == 'authority': @@ -524,7 +584,5 @@ def augment_with_sfa_builtins(local_record): for (field_name, attribute) in type_map.items(): # get related objects related_records = getattr(local_record, attribute, []) - hrns = [ r.hrn for r in related_records ] - setattr (local_record, field_name, hrns) - - + hrns = [r.hrn for r in related_records] + setattr(local_record, field_name, hrns) diff --git a/sfa/storage/parameter.py b/sfa/storage/parameter.py index dc9d5b5c..1545cac4 100644 --- a/sfa/storage/parameter.py +++ b/sfa/storage/parameter.py @@ -10,6 +10,7 @@ from sfa.util.faults import SfaAPIError from sfa.util.py23 import StringType + class Parameter: """ Typed value wrapper. Use in accepts and returns to document method @@ -17,11 +18,11 @@ class Parameter: sub-parameters (i.e., dict fields). """ - def __init__(self, type, doc = "", - min = None, max = None, - optional = None, - ro = False, - nullok = False): + def __init__(self, type, doc="", + min=None, max=None, + optional=None, + ro=False, + nullok=False): # Basic type of the parameter. Must be a builtin type # that can be marshalled by XML-RPC. self.type = type @@ -51,6 +52,7 @@ class Parameter: def __repr__(self): return repr(self.type) + class Mixed(tuple): """ A list (technically, a tuple) of types. Use in accepts and returns @@ -75,6 +77,7 @@ def python_type(arg): else: return type(arg) + def xmlrpc_type(arg): """ Returns the XML-RPC type of the specified argument, which may be a diff --git a/sfa/storage/record.py b/sfa/storage/record.py index a03ce30f..e8697040 100644 --- a/sfa/storage/record.py +++ b/sfa/storage/record.py @@ -8,6 +8,7 @@ from sfa.trust.gid import GID from sfa.util.sfalogging import logger from sfa.util.py23 import StringType + class Record: def __init__(self, dict=None, xml_str=None): @@ -16,7 +17,7 @@ class Record: elif xml_str: xml = XML(xml_str) xml_dict = xml.todict() - self.load_from_dict(xml_dict) + self.load_from_dict(xml_dict) def get_field(self, field): return self.__dict__.get(field, None) @@ -25,36 +26,37 @@ class Record: # turns out the date_created field is received by the client as a 'created' int # (and 'last_updated' does not make it at all) # let's be flexible - def date_repr (self,fields): - if not isinstance(fields,list): + def date_repr(self, fields): + if not isinstance(fields, list): fields = [fields] for field in fields: - value = getattr(self,field,None) - if isinstance (value,datetime): - return datetime_to_string (value) - elif isinstance (value,(int,float)): + value = getattr(self, field, None) + if isinstance(value, datetime): + return datetime_to_string(value) + elif isinstance(value, (int, float)): return datetime_to_string(utcparse(value)) # fallback return "** undef_datetime **" - + # # need to filter out results, esp. wrt relationships # exclude_types must be a tuple so we can use isinstance - # - def record_to_dict (self, exclude_types=None): + # + def record_to_dict(self, exclude_types=None): if exclude_types is None: exclude_types = () d = self.__dict__ - def exclude (k, v): - return k.startswith('_') or isinstance (v, exclude_types) - keys = [ k for k, v in d.items() if not exclude(k, v) ] - return { k : d[k] for k in keys } - + + def exclude(k, v): + return k.startswith('_') or isinstance(v, exclude_types) + keys = [k for k, v in d.items() if not exclude(k, v)] + return {k: d[k] for k in keys} + def toxml(self): return self.save_as_xml() - def load_from_dict (self, d): - for (k,v) in d.iteritems(): + def load_from_dict(self, d): + for (k, v) in d.iteritems(): # experimental if isinstance(v, StringType) and v.lower() in ['true']: v = True @@ -65,13 +67,14 @@ class Record: # in addition we provide convenience for converting to and from xml records # for this purpose only, we need the subclasses to define 'fields' as either # a list or a dictionary - def fields (self): + def fields(self): fields = self.__dict__.keys() return fields - def save_as_xml (self): + def save_as_xml(self): # xxx not sure about the scope here - input_dict = dict( [ (key, getattr(self,key)) for key in self.fields() if getattr(self,key,None) ] ) + input_dict = dict([(key, getattr(self, key)) + for key in self.fields() if getattr(self, key, None)]) xml_record = XML("") xml_record.parse_dict(input_dict) return xml_record.toxml() @@ -82,7 +85,7 @@ class Record: else: format = format.lower() if format == 'text': - self.dump_text(dump_parents,sort=sort) + self.dump_text(dump_parents, sort=sort) elif format == 'xml': print(self.save_as_xml()) elif format == 'simple': @@ -91,20 +94,23 @@ class Record: raise Exception("Invalid format %s" % format) def dump_text(self, dump_parents=False, sort=False): - print(40*'=') + print(40 * '=') print("RECORD") # print remaining fields fields = self.fields() - if sort: fields.sort() + if sort: + fields.sort() for attrib_name in fields: attrib = getattr(self, attrib_name) # skip internals - if attrib_name.startswith('_'): continue + if attrib_name.startswith('_'): + continue # skip callables - if callable (attrib): continue - # handle gid + if callable(attrib): + continue + # handle gid if attrib_name == 'gid': - print(" gid:") + print(" gid:") print(GID(string=attrib).dump_string(8, dump_parents)) elif attrib_name in ['date created', 'last updated']: print(" %s: %s" % (attrib_name, self.date_repr(attrib_name))) @@ -112,4 +118,4 @@ class Record: print(" %s: %s" % (attrib_name, attrib)) def dump_simple(self): - return "%s"%self + return "%s" % self diff --git a/sfa/trust/abac_credential.py b/sfa/trust/abac_credential.py index cb6e6867..edf21c99 100644 --- a/sfa/trust/abac_credential.py +++ b/sfa/trust/abac_credential.py @@ -44,13 +44,13 @@ except: # or between a subject and a class of targets (all those satisfying a role). # # An ABAC credential is like a normal SFA credential in that it has -# a validated signature block and is checked for expiration. +# a validated signature block and is checked for expiration. # It does not, however, have 'privileges'. Rather it contains a 'head' and # list of 'tails' of elements, each of which represents a principal and # role. # A special case of an ABAC credential is a speaks_for credential. Such -# a credential is simply an ABAC credential in form, but has a single +# a credential is simply an ABAC credential in form, but has a single # tail and fixed role 'speaks_for'. In ABAC notation, it asserts # AGENT.speaks_for(AGENT)<-CLIENT, or "AGENT asserts that CLIENT may speak # for AGENT". The AGENT in this case is the head and the CLIENT is the @@ -66,16 +66,20 @@ except: # An ABAC element contains a principal (keyid and optional mnemonic) # and optional role and linking_role element class ABACElement: - def __init__(self, principal_keyid, principal_mnemonic=None, \ - role=None, linking_role=None): + + def __init__(self, principal_keyid, principal_mnemonic=None, + role=None, linking_role=None): self._principal_keyid = principal_keyid self._principal_mnemonic = principal_mnemonic self._role = role self._linking_role = linking_role def get_principal_keyid(self): return self._principal_keyid + def get_principal_mnemonic(self): return self._principal_mnemonic + def get_role(self): return self._role + def get_linking_role(self): return self._linking_role def __str__(self): @@ -91,26 +95,28 @@ class ABACElement: # Subclass of Credential for handling ABAC credentials # They have a different cred_type (geni_abac vs. geni_sfa) # and they have a head and tail and role (as opposed to privileges) + + class ABACCredential(Credential): ABAC_CREDENTIAL_TYPE = 'geni_abac' - def __init__(self, create=False, subject=None, + def __init__(self, create=False, subject=None, string=None, filename=None): - self.head = None # An ABACElemenet - self.tails = [] # List of ABACElements - super(ABACCredential, self).__init__(create=create, - subject=subject, - string=string, + self.head = None # An ABACElemenet + self.tails = [] # List of ABACElements + super(ABACCredential, self).__init__(create=create, + subject=subject, + string=string, filename=filename) self.cred_type = ABACCredential.ABAC_CREDENTIAL_TYPE - def get_head(self) : - if not self.head: + def get_head(self): + if not self.head: self.decode() return self.head - def get_tails(self) : + def get_tails(self): if len(self.tails) == 0: self.decode() return self.tails @@ -125,7 +131,8 @@ class ABACCredential(Credential): rt0_root = rt0s[0] heads = self._get_abac_elements(rt0_root, 'head') if len(heads) != 1: - raise CredentialNotVerifiable("ABAC credential should have exactly 1 head element, had %d" % len(heads)) + raise CredentialNotVerifiable( + "ABAC credential should have exactly 1 head element, had %d" % len(heads)) self.head = heads[0] self.tails = self._get_abac_elements(rt0_root, 'tail') @@ -136,7 +143,8 @@ class ABACCredential(Credential): for elt in elements: keyids = elt.getElementsByTagName('keyid') if len(keyids) != 1: - raise CredentialNotVerifiable("ABAC credential element '%s' should have exactly 1 keyid, had %d." % (label, len(keyids))) + raise CredentialNotVerifiable( + "ABAC credential element '%s' should have exactly 1 keyid, had %d." % (label, len(keyids))) keyid_elt = keyids[0] keyid = keyid_elt.childNodes[0].nodeValue.strip() @@ -153,7 +161,8 @@ class ABACCredential(Credential): linking_role = None linking_role_elts = elt.getElementsByTagName('linking_role') if len(linking_role_elts) > 0: - linking_role = linking_role_elts[0].childNodes[0].nodeValue.strip() + linking_role = linking_role_elts[ + 0].childNodes[0].nodeValue.strip() abac_element = ABACElement(keyid, mnemonic, role, linking_role) abac_elements.append(abac_element) @@ -162,12 +171,14 @@ class ABACCredential(Credential): def dump_string(self, dump_parents=False, show_xml=False): result = "ABAC Credential\n" - filename=self.get_filename() - if filename: result += "Filename %s\n"%filename + filename = self.get_filename() + if filename: + result += "Filename %s\n" % filename if self.expiration: - result += "\texpiration: %s \n" % self.expiration.strftime(SFATIME_FORMAT) + result += "\texpiration: %s \n" % self.expiration.strftime( + SFATIME_FORMAT) - result += "\tHead: %s\n" % self.get_head() + result += "\tHead: %s\n" % self.get_head() for tail in self.get_tails(): result += "\tTail: %s\n" % tail if self.get_signature(): @@ -187,7 +198,8 @@ class ABACCredential(Credential): return result # sounds like this should be __repr__ instead ?? - # Produce the ABAC assertion. Something like [ABAC cred: Me.role<-You] or similar + # Produce the ABAC assertion. Something like [ABAC cred: Me.role<-You] or + # similar def pretty_cred(self): result = "[ABAC cred: " + str(self.get_head()) for tail in self.get_tails(): @@ -197,9 +209,9 @@ class ABACCredential(Credential): def createABACElement(self, doc, tagName, abacObj): kid = abacObj.get_principal_keyid() - mnem = abacObj.get_principal_mnemonic() # may be None - role = abacObj.get_role() # may be None - link = abacObj.get_linking_role() # may be None + mnem = abacObj.get_principal_mnemonic() # may be None + role = abacObj.get_role() # may be None + link = abacObj.get_linking_role() # may be None ele = doc.createElement(tagName) prin = doc.createElement('ABACprincipal') ele.appendChild(prin) @@ -218,7 +230,8 @@ class ABACCredential(Credential): # WARNING: # In general, a signed credential obtained externally should # not be changed else the signature is no longer valid. So, once - # you have loaded an existing signed credential, do not call encode() or sign() on it. + # you have loaded an existing signed credential, do not call encode() or + # sign() on it. def encode(self): # Create the XML document @@ -231,9 +244,12 @@ class ABACCredential(Credential): # Note that delegation of credentials between the 2 only really works # cause those schemas are identical. # Also note these PG schemas talk about PG tickets and CM policies. - signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") - signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.geni.net/resources/credential/2/credential.xsd") - signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") + signed_cred.setAttribute( + "xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") + signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", + "http://www.geni.net/resources/credential/2/credential.xsd") + signed_cred.setAttribute( + "xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") # PG says for those last 2: # signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") @@ -256,12 +272,14 @@ class ABACCredential(Credential): append_sub(doc, cred, "uuid", "") if not self.expiration: - self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) + self.set_expiration(datetime.datetime.utcnow( + ) + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) self.expiration = self.expiration.replace(microsecond=0) if self.expiration.tzinfo is not None and self.expiration.tzinfo.utcoffset(self.expiration) is not None: # TZ aware. Make sure it is UTC self.expiration = self.expiration.astimezone(tz.tzutc()) - append_sub(doc, cred, "expires", self.expiration.strftime(SFATIME_FORMAT)) # RFC3339 + append_sub(doc, cred, "expires", self.expiration.strftime( + SFATIME_FORMAT)) # RFC3339 abac = doc.createElement("abac") rt0 = doc.createElement("rt0") diff --git a/sfa/trust/auth.py b/sfa/trust/auth.py index 512c58bd..16eb8a69 100644 --- a/sfa/trust/auth.py +++ b/sfa/trust/auth.py @@ -1,5 +1,5 @@ # -# SfaAPI authentication +# SfaAPI authentication # import sys @@ -27,7 +27,7 @@ class Auth: Credential based authentication """ - def __init__(self, peer_cert = None, config = None ): + def __init__(self, peer_cert=None, config=None): self.peer_cert = peer_cert self.hierarchy = Hierarchy() if not config: @@ -42,34 +42,41 @@ class Auth: # this convenience methods extracts speaking_for_xrn # from the passed options using 'geni_speaking_for' - def checkCredentialsSpeaksFor (self, *args, **kwds): + def checkCredentialsSpeaksFor(self, *args, **kwds): if 'options' not in kwds: - logger.error ("checkCredentialsSpeaksFor was not passed options=options") + logger.error( + "checkCredentialsSpeaksFor was not passed options=options") return # remove the options arg - options = kwds['options']; del kwds['options'] + options = kwds['options'] + del kwds['options'] # compute the speaking_for_xrn arg and pass it to checkCredentials - if options is None: speaking_for_xrn = None - else: speaking_for_xrn = options.get('geni_speaking_for', None) + if options is None: + speaking_for_xrn = None + else: + speaking_for_xrn = options.get('geni_speaking_for', None) kwds['speaking_for_xrn'] = speaking_for_xrn return self.checkCredentials(*args, **kwds) - # do not use mutable as default argument + # do not use mutable as default argument # http://docs.python-guide.org/en/latest/writing/gotchas/#mutable-default-arguments - def checkCredentials(self, creds, operation, xrns=None, - check_sliver_callback=None, + def checkCredentials(self, creds, operation, xrns=None, + check_sliver_callback=None, speaking_for_xrn=None): - if xrns is None: xrns = [] + if xrns is None: + xrns = [] error = (None, None) + def log_invalid_cred(cred): - if not isinstance (cred, StringType): - logger.info("cannot validate credential %s - expecting a string"%cred) + if not isinstance(cred, StringType): + logger.info( + "cannot validate credential %s - expecting a string" % cred) error = ('TypeMismatch', "checkCredentials: expected a string, received {} -- {}" .format(type(cred), cred)) else: cred_obj = Credential(string=cred) - logger.info("failed to validate credential - dump=%s"%\ + logger.info("failed to validate credential - dump=%s" % cred_obj.dump_string(dump_parents=True)) error = sys.exc_info()[:2] return error @@ -80,23 +87,27 @@ class Auth: if not xrn: raise BadArgs("Invalid urn or hrn") - if not isinstance(xrns, list): xrns = [xrns] - slice_xrns = Xrn.filter_type(xrns, 'slice') + slice_xrns = Xrn.filter_type(xrns, 'slice') sliver_xrns = Xrn.filter_type(xrns, 'sliver') - # we are not able to validate slivers in the traditional way so - # we make sure not to include sliver urns/hrns in the core validation loop - hrns = [Xrn(xrn).hrn for xrn in xrns if xrn not in sliver_xrns] + # we are not able to validate slivers in the traditional way so + # we make sure not to include sliver urns/hrns in the core validation + # loop + hrns = [Xrn(xrn).hrn for xrn in xrns if xrn not in sliver_xrns] valid = [] if not isinstance(creds, list): creds = [creds] - logger.debug("Auth.checkCredentials with %d creds on hrns=%s"%(len(creds),hrns)) - # won't work if either creds or hrns is empty - let's make it more explicit - if not creds: raise Forbidden("no credential provided") - if not hrns: hrns = [None] + logger.debug("Auth.checkCredentials with %d creds on hrns=%s" % + (len(creds), hrns)) + # won't work if either creds or hrns is empty - let's make it more + # explicit + if not creds: + raise Forbidden("no credential provided") + if not hrns: + hrns = [None] speaks_for_gid = determine_speaks_for(logger, creds, self.peer_cert, speaking_for_xrn, self.trusted_cert_list) @@ -112,21 +123,22 @@ class Auth: valid.append(cred) except: error = log_invalid_cred(cred) - + # make sure all sliver xrns are validated against the valid credentials if sliver_xrns: if not check_sliver_callback: - msg = "sliver verification callback method not found." + msg = "sliver verification callback method not found." msg += " Unable to validate sliver xrns: %s" % sliver_xrns raise Forbidden(msg) check_sliver_callback(valid, sliver_xrns) - + if not len(valid): - raise Forbidden("Invalid credential %s -- %s"%(error[0],error[1])) - + raise Forbidden("Invalid credential %s -- %s" % + (error[0], error[1])) + return valid - - def check(self, credential, operation, hrn = None): + + def check(self, credential, operation, hrn=None): """ Check the credential against the peer cert (callerGID) included in the credential matches the caller that is connected to the @@ -134,23 +146,24 @@ class Auth: trusted cert and check if the credential is allowed to perform the specified operation. """ - cred = Credential(cred=credential) + cred = Credential(cred=credential) self.client_cred = cred - logger.debug("Auth.check: handling hrn=%s and credential=%s"%\ - (hrn,cred.pretty_cred())) + logger.debug("Auth.check: handling hrn=%s and credential=%s" % + (hrn, cred.pretty_cred())) if cred.type not in ['geni_sfa']: - raise CredentialNotVerifiable(cred.type, "%s not supported" % cred.type) + raise CredentialNotVerifiable( + cred.type, "%s not supported" % cred.type) self.client_gid = self.client_cred.get_gid_caller() self.object_gid = self.client_cred.get_gid_object() - + # make sure the client_gid is not blank if not self.client_gid: raise MissingCallerGID(self.client_cred.pretty_subject()) - + # validate the client cert if it exists if self.peer_cert: - self.verifyPeerCert(self.peer_cert, self.client_gid) + self.verifyPeerCert(self.peer_cert, self.client_gid) # make sure the client is allowed to perform the operation if operation: @@ -161,16 +174,16 @@ class Auth: self.client_cred.verify(self.trusted_cert_file_list, self.config.SFA_CREDENTIAL_SCHEMA) else: - raise MissingTrustedRoots(self.config.get_trustedroots_dir()) - - # Make sure the credential's target matches the specified hrn. - # This check does not apply to trusted peers + raise MissingTrustedRoots(self.config.get_trustedroots_dir()) + + # Make sure the credential's target matches the specified hrn. + # This check does not apply to trusted peers trusted_peers = [gid.get_hrn() for gid in self.trusted_cert_list] if hrn and self.client_gid.get_hrn() not in trusted_peers: target_hrn = self.object_gid.get_hrn() if not hrn == target_hrn: - raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % \ - (target_hrn, hrn) ) + raise PermissionError("Target hrn: %s doesn't match specified hrn: %s " % + (target_hrn, hrn)) return True def check_ticket(self, ticket): @@ -181,14 +194,15 @@ class Auth: client_ticket = SfaTicket(string=ticket) client_ticket.verify_chain(self.trusted_cert_list) else: - raise MissingTrustedRoots(self.config.get_trustedroots_dir()) + raise MissingTrustedRoots(self.config.get_trustedroots_dir()) - return True + return True def verifyPeerCert(self, cert, gid): # make sure the client_gid matches client's certificate if not cert.is_pubkey(gid.get_pubkey()): - raise ConnectionKeyGIDMismatch(gid.get_subject()+":"+cert.get_subject()) + raise ConnectionKeyGIDMismatch( + gid.get_subject() + ":" + cert.get_subject()) def verifyGidRequestHash(self, gid, hash, arglist): key = gid.get_pubkey() @@ -208,7 +222,7 @@ class Auth: cred.verify(self.trusted_cert_file_list) def authenticateGid(self, gidStr, argList, requestHash=None): - gid = GID(string = gidStr) + gid = GID(string=gidStr) self.validateGid(gid) # request_hash is optional if requestHash: @@ -216,7 +230,7 @@ class Auth: return gid def authenticateCred(self, credStr, argList, requestHash=None): - cred = Credential(string = credStr) + cred = Credential(string=credStr) self.validateCred(cred) # request hash is optional if requestHash: @@ -226,7 +240,7 @@ class Auth: def authenticateCert(self, certStr, requestHash): cert = Certificate(string=certStr) # xxx should be validateCred ?? - self.validateCred(cert) + self.validateCred(cert) def gidNoop(self, gidStr, value, requestHash): self.authenticateGid(gidStr, [gidStr, value], requestHash) @@ -237,26 +251,25 @@ class Auth: return value def verify_cred_is_me(self, credential): - is_me = False + is_me = False cred = Credential(string=credential) caller_gid = cred.get_gid_caller() caller_hrn = caller_gid.get_hrn() if caller_hrn != self.config.SFA_INTERFACE_HRN: raise SfaPermissionDenied(self.config.SFA_INTEFACE_HRN) - return - + return + def get_auth_info(self, auth_hrn): """ Given an authority name, return the information for that authority. This is basically a stub that calls the hierarchy module. - + @param auth_hrn human readable name of authority """ return self.hierarchy.get_auth_info(auth_hrn) - def veriry_auth_belongs_to_me(self, name): """ Verify that an authority belongs to our hierarchy. @@ -270,26 +283,24 @@ class Auth: # get auth info will throw an exception if the authority doesnt exist self.get_auth_info(name) - def verify_object_belongs_to_me(self, name): """ Verify that an object belongs to our hierarchy. By extension, this implies that the authority that owns the object belongs to our hierarchy. If it does not an exception is thrown. - + @param name human readable name of object """ auth_name = self.get_authority(name) if not auth_name: - auth_name = name + auth_name = name if name == self.config.SFA_INTERFACE_HRN: return - self.verify_auth_belongs_to_me(auth_name) - + self.verify_auth_belongs_to_me(auth_name) + def verify_auth_belongs_to_me(self, name): # get auth info will throw an exception if the authority doesnt exist - self.get_auth_info(name) - + self.get_auth_info(name) def verify_object_permission(self, name): """ @@ -297,7 +308,7 @@ class Auth: allows permission to the object 'name'. This is done by a simple prefix test. For example, an object_gid for plc.arizona would match the objects plc.arizona.slice1 and plc.arizona. - + @param name human readable name to test """ object_hrn = self.object_gid.get_hrn() @@ -305,16 +316,16 @@ class Auth: return if name.startswith(object_hrn + "."): return - #if name.startswith(get_authority(name)): - #return - + # if name.startswith(get_authority(name)): + # return + raise PermissionError(name) def determine_user_rights(self, caller_hrn, reg_record): """ Given a user credential and a record, determine what set of rights the user should have to that record. - + This is intended to replace determine_user_rights() and verify_cancreate_credential() """ @@ -322,15 +333,15 @@ class Auth: rl = Rights() type = reg_record.type - logger.debug("entering determine_user_rights with record %s and caller_hrn %s"%\ + logger.debug("entering determine_user_rights with record %s and caller_hrn %s" % (reg_record, caller_hrn)) if type == 'slice': # researchers in the slice are in the DB as-is - researcher_hrns = [ user.hrn for user in reg_record.reg_researchers ] + researcher_hrns = [user.hrn for user in reg_record.reg_researchers] # locating PIs attached to that slice slice_pis = reg_record.get_pis() - pi_hrns = [ user.hrn for user in slice_pis ] + pi_hrns = [user.hrn for user in slice_pis] if (caller_hrn in researcher_hrns + pi_hrns): rl.add('refresh') rl.add('embed') @@ -339,7 +350,7 @@ class Auth: rl.add('info') elif type == 'authority': - pi_hrns = [ user.hrn for user in reg_record.reg_pis ] + pi_hrns = [user.hrn for user in reg_record.reg_pis] if (caller_hrn == self.config.SFA_INTERFACE_HRN): rl.add('authority') rl.add('sa') @@ -347,10 +358,10 @@ class Auth: if (caller_hrn in pi_hrns): rl.add('authority') rl.add('sa') - # NOTE: for the PL implementation, this 'operators' list - # amounted to users with 'tech' role in that site + # NOTE: for the PL implementation, this 'operators' list + # amounted to users with 'tech' role in that site # it seems like this is not needed any longer, so for now I just drop that - # operator_hrns = reg_record.get('operator',[]) + # operator_hrns = reg_record.get('operator', []) # if (caller_hrn in operator_hrns): # rl.add('authority') # rl.add('ma') @@ -383,6 +394,6 @@ class Auth: tmp_cred = Credential(string=cred) if tmp_cred.get_gid_caller().get_hrn() in [caller_hrn_list]: creds.append(cred) - except: pass + except: + pass return creds - diff --git a/sfa/trust/certificate.py b/sfa/trust/certificate.py index a30d73aa..55d46d3f 100644 --- a/sfa/trust/certificate.py +++ b/sfa/trust/certificate.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- @@ -67,6 +67,7 @@ glo_passphrase_callback = None # # The callback should return a string containing the passphrase. + def set_passphrase_callback(callback_func): global glo_passphrase_callback @@ -75,24 +76,29 @@ def set_passphrase_callback(callback_func): ## # Sets a fixed passphrase. + def set_passphrase(passphrase): - set_passphrase_callback( lambda k,s,x: passphrase ) + set_passphrase_callback(lambda k, s, x: passphrase) ## # Check to see if a passphrase works for a particular private key string. # Intended to be used by passphrase callbacks for input validation. + def test_passphrase(string, passphrase): try: - OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, string, (lambda x: passphrase)) + OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, string, (lambda x: passphrase)) return True except: return False + def convert_public_key(key): keyconvert_path = "/usr/bin/keyconvert.py" if not os.path.isfile(keyconvert_path): - raise IOError("Could not find keyconvert in {}".format(keyconvert_path)) + raise IOError( + "Could not find keyconvert in {}".format(keyconvert_path)) # we can only convert rsa keys if "ssh-dss" in key: @@ -110,7 +116,8 @@ def convert_public_key(key): # that it can be expected to see why it failed. # TODO: for production, cleanup the temporary files if not os.path.exists(ssl_fn): - raise Exception("keyconvert: generated certificate not found. keyconvert may have failed.") + raise Exception( + "keyconvert: generated certificate not found. keyconvert may have failed.") k = Keypair() try: @@ -131,6 +138,7 @@ def convert_public_key(key): # A Keypair object may represent both a public and private key pair, or it # may represent only a public key (this usage is consistent with OpenSSL). + class Keypair: key = None # public/private keypair m2key = None # public key (m2crypto format) @@ -151,7 +159,8 @@ class Keypair: self.load_from_file(filename) ## - # Create a RSA public/private key pair and store it inside the keypair object + # Create a RSA public/private key pair and store it inside the keypair + # object def create(self): self.key = OpenSSL.crypto.PKey() @@ -166,7 +175,8 @@ class Keypair: self.filename = filename ## - # Load the private key from a file. Implicity the private key includes the public key. + # Load the private key from a file. Implicity the private key includes the + # public key. def load_from_file(self, filename): self.filename = filename @@ -174,7 +184,8 @@ class Keypair: self.load_from_string(buffer) ## - # Load the private key from a string. Implicitly the private key includes the public key. + # Load the private key from a string. Implicitly the private key includes + # the public key. def load_from_string(self, string): import M2Crypto @@ -184,7 +195,8 @@ class Keypair: self.m2key = M2Crypto.EVP.load_key_string( string, functools.partial(glo_passphrase_callback, self, string)) else: - self.key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, string) + self.key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, string) self.m2key = M2Crypto.EVP.load_key_string(string) ## @@ -199,7 +211,8 @@ class Keypair: # create an m2 x509 cert m2name = M2Crypto.X509.X509_Name() - m2name.add_entry_by_txt(field="CN", type=0x1001, entry="junk", len=-1, loc=-1, set=0) + m2name.add_entry_by_txt(field="CN", type=0x1001, + entry="junk", len=-1, loc=-1, set=0) m2x509 = M2Crypto.X509.X509() m2x509.set_pubkey(self.m2key) m2x509.set_serial_number(0) @@ -217,7 +230,8 @@ class Keypair: # convert the m2 x509 cert to a pyopenssl x509 m2pem = m2x509.as_pem() - pyx509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, m2pem) + pyx509 = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, m2pem) # get the pyopenssl pkey from the pyopenssl x509 self.key = pyx509.get_pubkey() @@ -285,16 +299,17 @@ class Keypair: # only informative def get_filename(self): - return getattr(self,'filename',None) + return getattr(self, 'filename', None) def dump(self, *args, **kwargs): print(self.dump_string(*args, **kwargs)) def dump_string(self): - result = "" + result = "" result += "KEYPAIR: pubkey={:>40}...".format(self.get_pubkey_string()) filename = self.get_filename() - if filename: result += "Filename {}\n".format(filename) + if filename: + result += "Filename {}\n".format(filename) return result ## @@ -309,6 +324,7 @@ class Keypair: # When saving a certificate to a file or a string, the caller can choose # whether to save the parent certificates as well. + class Certificate: digest = "sha256" @@ -316,7 +332,7 @@ class Certificate: # issuerKey = None # issuerSubject = None # parent = None - isCA = None # will be a boolean once set + isCA = None # will be a boolean once set separator = "-----parent-----" @@ -358,10 +374,10 @@ class Certificate: self.x509 = OpenSSL.crypto.X509() # FIXME: Use different serial #s self.x509.set_serial_number(3) - self.x509.gmtime_adj_notBefore(0) # 0 means now - self.x509.gmtime_adj_notAfter(lifeDays*60*60*24) # five years is default - self.x509.set_version(2) # x509v3 so it can have extensions - + self.x509.gmtime_adj_notBefore(0) # 0 means now + self.x509.gmtime_adj_notAfter( + lifeDays * 60 * 60 * 24) # five years is default + self.x509.set_version(2) # x509v3 so it can have extensions ## # Given a pyOpenSSL X509 object, store that object inside of this @@ -375,14 +391,15 @@ class Certificate: def load_from_string(self, string): # if it is a chain of multiple certs, then split off the first one and - # load it (support for the ---parent--- tag as well as normal chained certs) + # load it (support for the ---parent--- tag as well as normal chained + # certs) if string is None or string.strip() == "": logger.warn("Empty string in load_from_string") return string = string.strip() - + # If it's not in proper PEM format, wrap it if string.count('-----BEGIN CERTIFICATE') == 0: string = '-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----'\ @@ -392,22 +409,24 @@ class Certificate: # such as the text of the certificate, skip the text beg = string.find('-----BEGIN CERTIFICATE') if beg > 0: - # skipping over non cert beginning + # skipping over non cert beginning string = string[beg:] parts = [] if string.count('-----BEGIN CERTIFICATE-----') > 1 and \ - string.count(Certificate.separator) == 0: - parts = string.split('-----END CERTIFICATE-----',1) + string.count(Certificate.separator) == 0: + parts = string.split('-----END CERTIFICATE-----', 1) parts[0] += '-----END CERTIFICATE-----' else: parts = string.split(Certificate.separator, 1) - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, parts[0]) + self.x509 = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, parts[0]) if self.x509 is None: - logger.warn("Loaded from string but cert is None: {}".format(string)) + logger.warn( + "Loaded from string but cert is None: {}".format(string)) # if there are more certs, then create a parent and let the parent load # itself from the remainder of the string @@ -433,7 +452,8 @@ class Certificate: if self.x509 is None: logger.warn("None cert in certificate.save_to_string") return "" - string = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + string = OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, self.x509) if PY3 and isinstance(string, bytes): string = string.decode() if save_parents and self.parent: @@ -525,24 +545,27 @@ class Certificate: # let's try to make this a little more usable as is makes logs hairy # FIXME: Consider adding 'urn:publicid' and 'uuid' back for GENI? pretty_fields = ['email'] + def filter_chunk(self, chunk): for field in self.pretty_fields: if field in chunk: - return " "+chunk + return " " + chunk def pretty_cert(self): message = "[Cert." x = self.x509.get_subject() ou = getattr(x, "OU") - if ou: message += " OU: {}".format(ou) + if ou: + message += " OU: {}".format(ou) cn = getattr(x, "CN") - if cn: message += " CN: {}".format(cn) + if cn: + message += " CN: {}".format(cn) data = self.get_data(field='subjectAltName') if data: message += " SubjectAltName:" counter = 0 filtered = [self.filter_chunk(chunk) for chunk in data.split()] - message += " ".join( [f for f in filtered if f]) + message += " ".join([f for f in filtered if f]) omitted = len([f for f in filtered if not f]) if omitted: message += "..+{} omitted".format(omitted) @@ -591,8 +614,6 @@ class Certificate: else: self.add_extension('basicConstraints', 1, 'CA:FALSE') - - ## # Add an X509 extension to the certificate. Add_extension can only be called # once for a particular extension name, due to limitations in the underlying @@ -677,7 +698,8 @@ class Certificate: return self.data[field] ## - # Sign the certificate using the issuer private key and issuer subject previous set with set_issuer(). + # Sign the certificate using the issuer private key and issuer subject + # previous set with set_issuer(). def sign(self): logger.debug('certificate.sign') @@ -762,7 +784,7 @@ class Certificate: # @param Trusted_certs is a list of certificates that are trusted. # - def verify_chain(self, trusted_certs = None): + def verify_chain(self, trusted_certs=None): # Verify a chain of certificates. Each certificate must be signed by # the public key contained in it's parent. The chain is recursed # until a certificate is found that is signed by a trusted root. @@ -787,7 +809,7 @@ class Certificate: if debug_verify_chain: logger.debug("verify_chain: NO. Cert {} is signed by trusted_cert {}, " "but that signer is expired..." - .format(self.pretty_cert(),trusted_cert.pretty_cert())) + .format(self.pretty_cert(), trusted_cert.pretty_cert())) raise CertExpired("{} signer trusted_cert {}" .format(self.pretty_cert(), trusted_cert.pretty_cert())) @@ -828,12 +850,12 @@ class Certificate: # if the parent isn't verified... if debug_verify_chain: logger.debug("verify_chain: .. {}, -> verifying parent {}" - .format(self.pretty_cert(),self.parent.pretty_cert())) + .format(self.pretty_cert(), self.parent.pretty_cert())) self.parent.verify_chain(trusted_certs) return - ### more introspection + # more introspection def get_extensions(self): import M2Crypto # pyOpenSSL does not have a way to get extensions @@ -843,7 +865,8 @@ class Certificate: logger.debug("X509 had {} extensions".format(nb_extensions)) for i in range(nb_extensions): ext = m2x509.get_ext_at(i) - triples.append( (ext.get_name(), ext.get_value(), ext.get_critical(),) ) + triples.append( + (ext.get_name(), ext.get_value(), ext.get_critical(),)) return triples def get_data_names(self): @@ -852,12 +875,12 @@ class Certificate: def get_all_datas(self): triples = self.get_extensions() for name in self.get_data_names(): - triples.append( (name,self.get_data(name),'data',) ) + triples.append((name, self.get_data(name), 'data',)) return triples # only informative def get_filename(self): - return getattr(self,'filename',None) + return getattr(self, 'filename', None) def dump(self, *args, **kwargs): print(self.dump_string(*args, **kwargs)) diff --git a/sfa/trust/credential.py b/sfa/trust/credential.py index e4d5e999..54fe3fc1 100644 --- a/sfa/trust/credential.py +++ b/sfa/trust/credential.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- ## @@ -28,7 +28,8 @@ from __future__ import print_function -import os, os.path +import os +import os.path import subprocess import datetime from tempfile import mkstemp @@ -52,7 +53,7 @@ from sfa.trust.rights import Right, Rights, determine_rights from sfa.trust.gid import GID from sfa.util.xrn import urn_to_hrn, hrn_authfor_hrn -# 31 days, in seconds +# 31 days, in seconds DEFAULT_CREDENTIAL_LIFETIME = 86400 * 31 @@ -62,7 +63,7 @@ DEFAULT_CREDENTIAL_LIFETIME = 86400 * 31 # . add namespaces to signed-credential element? signature_format = \ -''' + ''' @@ -90,8 +91,10 @@ signature_format = \ ## # Convert a string into a bool # used to convert an xsd:boolean to a Python boolean + + def str2bool(str): - if str.lower() in ['true','1']: + if str.lower() in ('true', '1'): return True return False @@ -101,16 +104,17 @@ def str2bool(str): def getTextNode(element, subele): sub = element.getElementsByTagName(subele)[0] - if len(sub.childNodes) > 0: + if len(sub.childNodes) > 0: return sub.childNodes[0].nodeValue else: return None - + ## # Utility function to set the text of an XML element # It creates the element, adds the text to it, # and then appends it to the parent. + def append_sub(doc, parent, element, text): ele = doc.createElement(element) ele.appendChild(doc.createTextNode(text)) @@ -121,8 +125,9 @@ def append_sub(doc, parent, element, text): # for a signed-credential # + class Signature(object): - + def __init__(self, string=None): self.refid = None self.issuer_gid = None @@ -131,7 +136,6 @@ class Signature(object): self.xml = string self.decode() - def get_refid(self): if not self.refid: self.decode() @@ -148,13 +152,14 @@ class Signature(object): def get_issuer_gid(self): if not self.gid: self.decode() - return self.gid + return self.gid def set_issuer_gid(self, gid): self.gid = gid def decode(self): - # Helper function to pull characters off the front of a string if present + # Helper function to pull characters off the front of a string if + # present def remove_prefix(text, prefix): if text and prefix and text.startswith(prefix): return text[len(prefix):] @@ -166,15 +171,17 @@ class Signature(object): logger.log_exc("Failed to parse credential, {}".format(self.xml)) raise sig = doc.getElementsByTagName("Signature")[0] - ## This code until the end of function rewritten by Aaron Helsinger + # This code until the end of function rewritten by Aaron Helsinger ref_id = remove_prefix(sig.getAttribute("xml:id").strip(), "Sig_") - # The xml:id tag is optional, and could be in a + # The xml:id tag is optional, and could be in a # Reference xml:id or Reference UID sub element instead if not ref_id or ref_id == '': reference = sig.getElementsByTagName('Reference')[0] - ref_id = remove_prefix(reference.getAttribute('xml:id').strip(), "Sig_") + ref_id = remove_prefix( + reference.getAttribute('xml:id').strip(), "Sig_") if not ref_id or ref_id == '': - ref_id = remove_prefix(reference.getAttribute('URI').strip(), "#") + ref_id = remove_prefix( + reference.getAttribute('URI').strip(), "#") self.set_refid(ref_id) keyinfos = sig.getElementsByTagName("X509Data") gids = None @@ -184,15 +191,17 @@ class Signature(object): if len(cert.childNodes) > 0: szgid = cert.childNodes[0].nodeValue szgid = szgid.strip() - szgid = "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----".format(szgid) + szgid = "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----".format( + szgid) if gids is None: gids = szgid else: gids += "\n" + szgid if gids is None: - raise CredentialNotVerifiable("Malformed XML: No certificate found in signature") + raise CredentialNotVerifiable( + "Malformed XML: No certificate found in signature") self.set_issuer_gid(GID(string=gids)) - + def encode(self): self.xml = signature_format.format(refid=self.get_refid()) @@ -200,33 +209,38 @@ class Signature(object): # A credential provides a caller gid with privileges to an object gid. # A signed credential is signed by the object's authority. # -# Credentials are encoded in one of two ways. +# Credentials are encoded in one of two ways. # The legacy style (now unsupported) places it in the subjectAltName of an X509 certificate. # The new credentials are placed in signed XML. # # WARNING: # In general, a signed credential obtained externally should # not be changed else the signature is no longer valid. So, once -# you have loaded an existing signed credential, do not call encode() or sign() on it. +# you have loaded an existing signed credential, do not call encode() or +# sign() on it. + def filter_creds_by_caller(creds, caller_hrn_list): - """ - Returns a list of creds who's gid caller matches the - specified caller hrn - """ - if not isinstance(creds, list): creds = [creds] - if not isinstance(caller_hrn_list, list): - caller_hrn_list = [caller_hrn_list] - caller_creds = [] - for cred in creds: - try: - tmp_cred = Credential(string=cred) - if tmp_cred.type != Credential.SFA_CREDENTIAL_TYPE: - continue - if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list: - caller_creds.append(cred) - except: pass - return caller_creds + """ + Returns a list of creds who's gid caller matches the + specified caller hrn + """ + if not isinstance(creds, list): + creds = [creds] + if not isinstance(caller_hrn_list, list): + caller_hrn_list = [caller_hrn_list] + caller_creds = [] + for cred in creds: + try: + tmp_cred = Credential(string=cred) + if tmp_cred.type != Credential.SFA_CREDENTIAL_TYPE: + continue + if tmp_cred.get_gid_caller().get_hrn() in caller_hrn_list: + caller_creds.append(cred) + except: + pass + return caller_creds + class Credential(object): @@ -266,15 +280,16 @@ class Credential(object): self.version = cred['geni_version'] if string or filename: - if string: + if string: str = string elif filename: with open(filename) as infile: str = infile.read() - + # if this is a legacy credential, write error and bail out if isinstance(str, StringType) and str.strip().startswith("-----"): - logger.error("Legacy credentials not supported any more - giving up with {}...".format(str[:10])) + logger.error( + "Legacy credentials not supported any more - giving up with {}...".format(str[:10])) return else: self.xml = str @@ -287,16 +302,20 @@ class Credential(object): if not getattr(Credential, 'xmlsec1_path', None): # Find a xmlsec1 binary path Credential.xmlsec1_path = '' - paths = ['/usr/bin', '/usr/local/bin', '/bin', '/opt/bin', '/opt/local/bin'] - try: paths += os.getenv('PATH').split(':') - except: pass + paths = ['/usr/bin', '/usr/local/bin', + '/bin', '/opt/bin', '/opt/local/bin'] + try: + paths += os.getenv('PATH').split(':') + except: + pass for path in paths: xmlsec1 = os.path.join(path, 'xmlsec1') if os.path.isfile(xmlsec1): Credential.xmlsec1_path = xmlsec1 break if not Credential.xmlsec1_path: - logger.error("Could not locate required binary 'xmlsec1' - SFA will be unable to sign stuff !!") + logger.error( + "Could not locate required binary 'xmlsec1' - SFA will be unable to sign stuff !!") return Credential.xmlsec1_path def get_subject(self): @@ -330,7 +349,6 @@ class Credential(object): def set_signature(self, sig): self.signature = sig - ## # Need the issuer's private key and name # @param key Keypair object containing the private key of the issuer @@ -340,7 +358,6 @@ class Credential(object): self.issuer_privkey = privkey self.issuer_gid = gid - ## # Set this credential's parent def set_parent(self, cred): @@ -380,16 +397,17 @@ class Credential(object): if not self.gidObject: self.decode() return self.gidObject - + ## # Expiration: an absolute UTC time of expiration (as either an int or string or datetime) - # + # def set_expiration(self, expiration): expiration_datetime = utcparse(expiration) if expiration_datetime is not None: self.expiration = expiration_datetime else: - logger.error("unexpected input {} in Credential.set_expiration".format(expiration)) + logger.error( + "unexpected input {} in Credential.set_expiration".format(expiration)) ## # get the lifetime of the credential (always in datetime format) @@ -397,7 +415,8 @@ class Credential(object): def get_expiration(self): if not self.expiration: self.decode() - # at this point self.expiration is normalized as a datetime - DON'T call utcparse again + # at this point self.expiration is normalized as a datetime - DON'T + # call utcparse again return self.expiration ## @@ -407,9 +426,9 @@ class Credential(object): def set_privileges(self, privs): if isinstance(privs, str): - self.privileges = Rights(string = privs) + self.privileges = Rights(string=privs) else: - self.privileges = privs + self.privileges = privs ## # return the privileges as a Rights object @@ -427,20 +446,20 @@ class Credential(object): def can_perform(self, op_name): rights = self.get_privileges() - + if not rights: return False return rights.can_perform(op_name) - ## - # Encode the attributes of the credential into an XML string - # This should be done immediately before signing the credential. + # Encode the attributes of the credential into an XML string + # This should be done immediately before signing the credential. # WARNING: # In general, a signed credential obtained externally should # not be changed else the signature is no longer valid. So, once - # you have loaded an existing signed credential, do not call encode() or sign() on it. + # you have loaded an existing signed credential, do not call encode() or + # sign() on it. def encode(self): # Create the XML document @@ -453,18 +472,22 @@ class Credential(object): # Note that delegation of credentials between the 2 only really works # cause those schemas are identical. # Also note these PG schemas talk about PG tickets and CM policies. - signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") - # FIXME: See v2 schema at www.geni.net/resources/credential/2/credential.xsd - signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.planet-lab.org/resources/sfa/credential.xsd") - signed_cred.setAttribute("xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") + signed_cred.setAttribute( + "xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance") + # FIXME: See v2 schema at + # www.geni.net/resources/credential/2/credential.xsd + signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", + "http://www.planet-lab.org/resources/sfa/credential.xsd") + signed_cred.setAttribute( + "xsi:schemaLocation", "http://www.planet-lab.org/resources/sfa/ext/policy/1 http://www.planet-lab.org/resources/sfa/ext/policy/1/policy.xsd") # PG says for those last 2: # signed_cred.setAttribute("xsi:noNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd") # signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd") - doc.appendChild(signed_cred) - - # Fill in the bit + doc.appendChild(signed_cred) + + # Fill in the bit cred = doc.createElement("credential") cred.setAttribute("xml:id", self.get_refid()) signed_cred.appendChild(cred) @@ -476,13 +499,16 @@ class Credential(object): append_sub(doc, cred, "target_urn", self.gidObject.get_urn()) append_sub(doc, cred, "uuid", "") if not self.expiration: - logger.debug("Creating credential valid for {} s".format(DEFAULT_CREDENTIAL_LIFETIME)) - self.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) + logger.debug("Creating credential valid for {} s".format( + DEFAULT_CREDENTIAL_LIFETIME)) + self.set_expiration(datetime.datetime.utcnow( + ) + datetime.timedelta(seconds=DEFAULT_CREDENTIAL_LIFETIME)) self.expiration = self.expiration.replace(microsecond=0) if self.expiration.tzinfo is not None and self.expiration.tzinfo.utcoffset(self.expiration) is not None: # TZ aware. Make sure it is UTC - by Aaron Helsinger self.expiration = self.expiration.astimezone(tz.tzutc()) - append_sub(doc, cred, "expires", self.expiration.strftime(SFATIME_FORMAT)) + append_sub(doc, cred, "expires", + self.expiration.strftime(SFATIME_FORMAT)) privileges = doc.createElement("privileges") cred.appendChild(privileges) @@ -491,7 +517,8 @@ class Credential(object): for right in rights.rights: priv = doc.createElement("privilege") append_sub(doc, priv, "name", right.kind) - append_sub(doc, priv, "can_delegate", str(right.delegate).lower()) + append_sub(doc, priv, "can_delegate", + str(right.delegate).lower()) privileges.appendChild(priv) # Add the parent credential if it exists @@ -534,8 +561,10 @@ class Credential(object): attr = parentRoot.attributes.item(attrIx) # returns the old attribute of same name that was # on the credential - # Below throws InUse exception if we forgot to clone the attribute first - oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True)) + # Below throws InUse exception if we forgot to clone the + # attribute first + oldAttr = signed_cred.setAttributeNode( + attr.cloneNode(True)) if oldAttr and oldAttr.value != attr.value: msg = "Delegating cred from owner {} to {} over {}:\n" "- Replaced attribute {} value '{}' with '{}'"\ @@ -544,7 +573,8 @@ class Credential(object): logger.warn(msg) #raise CredentialNotVerifiable("Can't encode new valid delegated credential: {}".format(msg)) - p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True) + p_cred = doc.importNode( + sdoc.getElementsByTagName("credential")[0], True) p = doc.createElement("parent") p.appendChild(p_cred) cred.appendChild(p) @@ -558,24 +588,24 @@ class Credential(object): if self.parent: for cur_cred in self.get_credential_list()[1:]: sdoc = parseString(cur_cred.get_signature().get_xml()) - ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) + ele = doc.importNode( + sdoc.getElementsByTagName("Signature")[0], True) signatures.appendChild(ele) - + # Get the finished product self.xml = doc.toxml("utf-8") - - def save_to_random_tmp_file(self): + def save_to_random_tmp_file(self): fp, filename = mkstemp(suffix='cred', text=True) fp = os.fdopen(fp, "w") self.save_to_file(filename, save_parents=True, filep=fp) return filename - + def save_to_file(self, filename, save_parents=True, filep=None): if not self.xml: self.encode() if filep: - f = filep + f = filep else: f = open(filename, "w") if PY3 and isinstance(self.xml, bytes): @@ -602,12 +632,12 @@ class Credential(object): # Figure out what refids exist, and update this credential's id # so that it doesn't clobber the others. Returns the refids of # the parents. - + def updateRefID(self): if not self.parent: self.set_refid('ref0') return [] - + refs = [] next_cred = self.parent @@ -618,7 +648,6 @@ class Credential(object): else: next_cred = None - # Find a unique refid for this credential rid = self.get_refid() while rid in refs: @@ -642,7 +671,8 @@ class Credential(object): # WARNING: # In general, a signed credential obtained externally should # not be changed else the signature is no longer valid. So, once - # you have loaded an existing signed credential, do not call encode() or sign() on it. + # you have loaded an existing signed credential, do not call encode() or + # sign() on it. def sign(self): if not self.issuer_privkey: @@ -657,13 +687,13 @@ class Credential(object): # Create the signature template to be signed signature = Signature() signature.set_refid(self.get_refid()) - sdoc = parseString(signature.get_xml()) - sig_ele = doc.importNode(sdoc.getElementsByTagName("Signature")[0], True) + sdoc = parseString(signature.get_xml()) + sig_ele = doc.importNode( + sdoc.getElementsByTagName("Signature")[0], True) sigs.appendChild(sig_ele) self.xml = doc.toxml("utf-8") - # Split the issuer GID into multiple certificates if it's a chain chain = GID(filename=self.issuer_gid) gid_files = [] @@ -674,7 +704,6 @@ class Credential(object): else: chain = None - # Call out to xmlsec1 to sign it ref = 'Sig_{}'.format(self.get_refid()) filename = self.save_to_random_tmp_file() @@ -692,8 +721,7 @@ class Credential(object): self.xml = signed # Update signatures - self.decode() - + self.decode() ## # Retrieve the attributes of the credential from the XML. @@ -721,10 +749,11 @@ class Credential(object): sigs = signatures[0].getElementsByTagName("Signature") else: creds = doc.getElementsByTagName("credential") - + if creds is None or len(creds) == 0: # malformed cred file - raise CredentialNotVerifiable("Malformed XML: No credential tag found") + raise CredentialNotVerifiable( + "Malformed XML: No credential tag found") # Just take the first cred if there are more than one cred = creds[0] @@ -734,8 +763,7 @@ class Credential(object): self.gidCaller = GID(string=getTextNode(cred, "owner_gid")) self.gidObject = GID(string=getTextNode(cred, "target_gid")) - - ## This code until the end of function rewritten by Aaron Helsinger + # This code until the end of function rewritten by Aaron Helsinger # Process privileges rlist = Rights() priv_nodes = cred.getElementsByTagName("privileges") @@ -747,7 +775,7 @@ class Credential(object): if kind == '*': # Convert * into the default privileges for the credential's type # Each inherits the delegatability from the * above - _ , type = urn_to_hrn(self.gidObject.get_urn()) + _, type = urn_to_hrn(self.gidObject.get_urn()) rl = determine_rights(type, self.gidObject.get_urn()) for r in rl.rights: r.delegate = deleg @@ -756,14 +784,14 @@ class Credential(object): rlist.add(Right(kind.strip(), deleg)) self.set_privileges(rlist) - # Is there a parent? parent = cred.getElementsByTagName("parent") if len(parent) > 0: parent_doc = parent[0].getElementsByTagName("credential")[0] parent_xml = parent_doc.toxml("utf-8") if parent_xml is None or parent_xml.strip() == "": - raise CredentialNotVerifiable("Malformed XML: Had parent tag but it is empty") + raise CredentialNotVerifiable( + "Malformed XML: Had parent tag but it is empty") self.parent = Credential(string=parent_xml) self.updateRefID() @@ -774,17 +802,16 @@ class Credential(object): for cur_cred in self.get_credential_list(): if cur_cred.get_refid() == Sig.get_refid(): cur_cred.set_signature(Sig) - - + ## # Verify - # trusted_certs: A list of trusted GID filenames (not GID objects!) + # trusted_certs: A list of trusted GID filenames (not GID objects!) # Chaining is not supported within the GIDs by xmlsec1. # # trusted_certs_required: Should usually be true. Set False means an # empty list of trusted_certs would still let this method pass. # It just skips xmlsec1 verification et al. Only used by some utils - # + # # Verify that: # . All of the signatures are valid and that the issuers trace back # to trusted roots (performed by xmlsec1) @@ -839,12 +866,13 @@ class Credential(object): trusted_cert_objects.append(GID(filename=f)) ok_trusted_certs.append(f) except Exception as exc: - logger.error("Failed to load trusted cert from {}: {}".format(f, exc)) + logger.error( + "Failed to load trusted cert from {}: {}".format(f, exc)) trusted_certs = ok_trusted_certs # make sure it is not expired if self.get_expiration() < datetime.datetime.utcnow(): - raise CredentialNotVerifiable("Credential {} expired at {}" \ + raise CredentialNotVerifiable("Credential {} expired at {}" .format(self.pretty_cred(), self.expiration.strftime(SFATIME_FORMAT))) @@ -877,21 +905,23 @@ class Credential(object): # turns out, with fedora21, there is extra input before this 'OK' thing # looks like we're better off just using the exit code - that's what it is made for #cert_args = " ".join(['--trusted-pem {}'.format(x) for x in trusted_certs]) - #command = '{} --verify --node-id "{}" {} {} 2>&1'.\ + # command = '{} --verify --node-id "{}" {} {} 2>&1'.\ # format(self.xmlsec_path, ref, cert_args, filename) xmlsec1 = self.get_xmlsec1_path() if not xmlsec1: raise Exception("Could not locate required 'xmlsec1' program") - command = [ xmlsec1, '--verify', '--node-id', ref ] + command = [xmlsec1, '--verify', '--node-id', ref] for trusted in trusted_certs: - command += ["--trusted-pem", trusted ] - command += [ filename ] + command += ["--trusted-pem", trusted] + command += [filename] logger.debug("Running " + " ".join(command)) try: - verified = subprocess.check_output(command, stderr=subprocess.STDOUT) + verified = subprocess.check_output( + command, stderr=subprocess.STDOUT) logger.debug("xmlsec command returned {}".format(verified)) if "OK\n" not in verified: - logger.warning("WARNING: xmlsec1 seemed to return fine but without a OK in its output") + logger.warning( + "WARNING: xmlsec1 seemed to return fine but without a OK in its output") except subprocess.CalledProcessError as e: verified = e.output # xmlsec errors have a msg= which is the interesting bit. @@ -901,8 +931,9 @@ class Credential(object): mstart = mstart + 4 mend = verified.find('\\', mstart) msg = verified[mstart:mend] - logger.warning("Credential.verify - failed - xmlsec1 returned {}".format(verified.strip())) - raise CredentialNotVerifiable("xmlsec1 error verifying cred {} using Signature ID {}: {}"\ + logger.warning( + "Credential.verify - failed - xmlsec1 returned {}".format(verified.strip())) + raise CredentialNotVerifiable("xmlsec1 error verifying cred {} using Signature ID {}: {}" .format(self.pretty_cred(), ref, msg)) os.remove(filename) @@ -916,9 +947,9 @@ class Credential(object): return True ## - # Creates a list of the credential and its parents, with the root + # Creates a list of the credential and its parents, with the root # (original delegated credential) as the last item in the list - def get_credential_list(self): + def get_credential_list(self): cur_cred = self list = [] while cur_cred: @@ -928,7 +959,7 @@ class Credential(object): else: cur_cred = None return list - + ## # Make sure the credential's target gid (a) was signed by or (b) # is the same as the entity that signed the original credential, @@ -941,7 +972,7 @@ class Credential(object): if root_cred.get_signature() is None: # malformed raise CredentialNotVerifiable("Could not verify credential owned by {} for object {}. " - "Cred has no signature" \ + "Cred has no signature" .format(self.gidCaller.get_urn(), self.gidObject.get_urn())) root_cred_signer = root_cred.get_signature().get_issuer_gid() @@ -965,7 +996,7 @@ class Credential(object): # If not, remove this. #root_target_gid_str = root_target_gid.save_to_string() #root_cred_signer_str = root_cred_signer.save_to_string() - #if root_target_gid_str == root_cred_signer_str: + # if root_target_gid_str == root_cred_signer_str: # # cred signer is target, return success # return @@ -1020,7 +1051,7 @@ class Credential(object): # . The privileges must have "can_delegate" set for each delegated privilege # . The target gid must be the same between child and parents # . The expiry time on the child must be no later than the parent - # . The signer of the child must be the owner of the parent + # . The signer of the child must be the owner of the parent def verify_parent(self, parent_cred): # make sure the rights given to the child are a subset of the # parents rights (and check delegate bits) @@ -1028,13 +1059,15 @@ class Credential(object): message = ( "Parent cred {} (ref {}) rights {} " " not superset of delegated cred {} (ref {}) rights {}" - .format(parent_cred.pretty_cred(),parent_cred.get_refid(), + .format(parent_cred.pretty_cred(), parent_cred.get_refid(), parent_cred.get_privileges().pretty_rights(), self.pretty_cred(), self.get_refid(), self.get_privileges().pretty_rights())) logger.error(message) - logger.error("parent details {}".format(parent_cred.get_privileges().save_to_string())) - logger.error("self details {}".format(self.get_privileges().save_to_string())) + logger.error("parent details {}".format( + parent_cred.get_privileges().save_to_string())) + logger.error("self details {}".format( + self.get_privileges().save_to_string())) raise ChildRightsNotSubsetOfParent(message) # make sure my target gid is the same as the parent's @@ -1044,7 +1077,8 @@ class Credential(object): "Delegated cred {}: Target gid not equal between parent and child. Parent {}" .format(self.pretty_cred(), parent_cred.pretty_cred())) logger.error(message) - logger.error("parent details {}".format(parent_cred.save_to_string())) + logger.error("parent details {}".format( + parent_cred.save_to_string())) logger.error("self details {}".format(self.save_to_string())) raise CredentialNotVerifiable(message) @@ -1060,17 +1094,20 @@ class Credential(object): message = "Delegated credential {} not signed by parent {}'s caller"\ .format(self.pretty_cred(), parent_cred.pretty_cred()) logger.error(message) - logger.error("compare1 parent {}".format(parent_cred.get_gid_caller().pretty_cert())) - logger.error("compare1 parent details {}".format(parent_cred.get_gid_caller().save_to_string())) - logger.error("compare2 self {}".format(self.get_signature().get_issuer_gid().pretty_crert())) - logger.error("compare2 self details {}".format(self.get_signature().get_issuer_gid().save_to_string())) + logger.error("compare1 parent {}".format( + parent_cred.get_gid_caller().pretty_cert())) + logger.error("compare1 parent details {}".format( + parent_cred.get_gid_caller().save_to_string())) + logger.error("compare2 self {}".format( + self.get_signature().get_issuer_gid().pretty_crert())) + logger.error("compare2 self details {}".format( + self.get_signature().get_issuer_gid().save_to_string())) raise CredentialNotVerifiable(message) - + # Recurse if parent_cred.parent: parent_cred.verify_parent(parent_cred.parent) - def delegate(self, delegee_gidfile, caller_keyfile, caller_gidfile): """ Return a delegated copy of this credential, delegated to the @@ -1078,12 +1115,12 @@ class Credential(object): """ # get the gid of the object we are delegating object_gid = self.get_gid_object() - object_hrn = object_gid.get_hrn() - + object_hrn = object_gid.get_hrn() + # the hrn of the user who will be delegated to delegee_gid = GID(filename=delegee_gidfile) delegee_hrn = delegee_gid.get_hrn() - + #user_key = Keypair(filename=keyfile) #user_hrn = self.get_gid_caller().get_hrn() subject_string = "{} delegated to {}".format(object_hrn, delegee_hrn) @@ -1103,7 +1140,7 @@ class Credential(object): # only informative def get_filename(self): - return getattr(self,'filename',None) + return getattr(self, 'filename', None) def actual_caller_hrn(self): """ @@ -1121,7 +1158,8 @@ class Credential(object): """ caller_hrn, caller_type = urn_to_hrn(self.get_gid_caller().get_urn()) - issuer_hrn, issuer_type = urn_to_hrn(self.get_signature().get_issuer_gid().get_urn()) + issuer_hrn, issuer_type = urn_to_hrn( + self.get_signature().get_issuer_gid().get_urn()) subject_hrn = self.get_gid_object().get_hrn() # if the caller is a user and the issuer is not # it's probably the former @@ -1131,7 +1169,8 @@ class Credential(object): # this seems to be a 'regular' credential elif caller_hrn.startswith(issuer_hrn): actual_caller_hrn = caller_hrn - # else this looks like a delegated credential, and the real caller is the issuer + # else this looks like a delegated credential, and the real caller is + # the issuer else: actual_caller_hrn = issuer_hrn logger.info("actual_caller_hrn: caller_hrn={}, issuer_hrn={}, returning {}" @@ -1147,10 +1186,11 @@ class Credential(object): # SFA code ignores show_xml and disables printing the cred xml def dump_string(self, dump_parents=False, show_xml=False): - result="" + result = "" result += "CREDENTIAL {}\n".format(self.pretty_subject()) - filename=self.get_filename() - if filename: result += "Filename {}\n".format(filename) + filename = self.get_filename() + if filename: + result += "Filename {}\n".format(filename) privileges = self.get_privileges() if privileges: result += " privs: {}\n".format(privileges.save_to_string()) @@ -1166,7 +1206,8 @@ class Credential(object): result += self.get_signature().get_issuer_gid().dump_string(8, dump_parents) if self.expiration: - result += " expiration: " + self.expiration.strftime(SFATIME_FORMAT) + "\n" + result += " expiration: " + \ + self.expiration.strftime(SFATIME_FORMAT) + "\n" gidObject = self.get_gid_object() if gidObject: diff --git a/sfa/trust/credential_factory.py b/sfa/trust/credential_factory.py index cf5a8fbf..fb8206a3 100644 --- a/sfa/trust/credential_factory.py +++ b/sfa/trust/credential_factory.py @@ -34,6 +34,7 @@ import re # Specifically, this factory can create standard SFA credentials # and ABAC credentials from XML strings based on their identifying content + class CredentialFactory: UNKNOWN_CREDENTIAL_TYPE = 'geni_unknown' @@ -58,12 +59,14 @@ class CredentialFactory: @staticmethod def createCred(credString=None, credFile=None): if not credString and not credFile: - raise Exception("CredentialFactory.createCred called with no argument") + raise Exception( + "CredentialFactory.createCred called with no argument") if credFile: try: credString = open(credFile).read() except Exception as e: - logger.info("Error opening credential file %s: %s" % credFile, e) + logger.info("Error opening credential file %s: %s" % + credFile, e) return None # Try to treat the file as JSON, getting the cred_type from the struct @@ -73,7 +76,8 @@ class CredentialFactory: cred_type = credO['geni_type'] credString = credO['geni_value'] except Exception as e: - # It wasn't a struct. So the credString is XML. Pull the type directly from the string + # It wasn't a struct. So the credString is XML. Pull the type + # directly from the string logger.debug("Credential string not JSON: %s" % e) cred_type = CredentialFactory.getType(credString) @@ -84,9 +88,11 @@ class CredentialFactory: except Exception as e: if credFile: msg = "credString started: %s" % credString[:50] - raise Exception("%s not a parsable SFA credential: %s. " % (credFile, e) + msg) + raise Exception( + "%s not a parsable SFA credential: %s. " % (credFile, e) + msg) else: - raise Exception("SFA Credential not parsable: %s. Cred start: %s..." % (e, credString[:50])) + raise Exception( + "SFA Credential not parsable: %s. Cred start: %s..." % (e, credString[:50])) elif cred_type == ABACCredential.ABAC_CREDENTIAL_TYPE: try: @@ -94,9 +100,11 @@ class CredentialFactory: return cred except Exception as e: if credFile: - raise Exception("%s not a parsable ABAC credential: %s" % (credFile, e)) + raise Exception( + "%s not a parsable ABAC credential: %s" % (credFile, e)) else: - raise Exception("ABAC Credential not parsable: %s. Cred start: %s..." % (e, credString[:50])) + raise Exception( + "ABAC Credential not parsable: %s. Cred start: %s..." % (e, credString[:50])) else: raise Exception("Unknown credential type '%s'" % cred_type) diff --git a/sfa/trust/gid.py b/sfa/trust/gid.py index 3f903d98..b4900603 100644 --- a/sfa/trust/gid.py +++ b/sfa/trust/gid.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- ## @@ -39,6 +39,7 @@ from sfa.util.py23 import xmlrpc_client ## # Create a new uuid. Returns the UUID as a string. + def create_uuid(): return str(uuid.uuid4().int) @@ -54,7 +55,7 @@ def create_uuid(): # # URN is a human readable identifier of form: # "urn:publicid:IDN+toplevelauthority[:sub-auth.]*[\res. type]\ +object name" -# For example, urn:publicid:IDN+planetlab:us:arizona+user+bakers +# For example, urn:publicid:IDN+planetlab:us:arizona+user+bakers # # PUBLIC_KEY is the public key of the principal identified by the UUID/HRN. # It is a Keypair object as defined in the cert.py module. @@ -83,7 +84,7 @@ class GID(Certificate): self.uuid = None self.hrn = None self.urn = None - self.email = None # for adding to the SubjectAltName + self.email = None # for adding to the SubjectAltName Certificate.__init__(self, lifeDays, create, subject, string, filename) if subject: @@ -123,11 +124,11 @@ class GID(Certificate): def set_urn(self, urn): self.urn = urn self.hrn, type = urn_to_hrn(urn) - + def get_urn(self): if not self.urn: self.decode() - return self.urn + return self.urn # Will be stuffed into subjectAltName def set_email(self, email): @@ -143,7 +144,7 @@ class GID(Certificate): self.decode() _, t = urn_to_hrn(self.urn) return t - + ## # Encode the GID fields and package them into the subject-alt-name field # of the X509 certificate. This must be called prior to signing the @@ -154,18 +155,17 @@ class GID(Certificate): urn = self.urn else: urn = hrn_to_urn(self.hrn, None) - + str = "URI:" + urn if self.uuid: str += ", " + "URI:" + uuid.UUID(int=self.uuid).urn - + if self.email: str += ", " + "email:" + self.email self.set_data(str, 'subjectAltName') - ## # Decode the subject-alt-name field of the X509 certificate into the # fields of the GID. This is automatically called by the various get_*() @@ -188,7 +188,7 @@ class GID(Certificate): # FIXME: Ensure there isn't cruft in that address... # EG look for email:copy,.... dict['email'] = val[6:] - + self.uuid = dict.get("uuid", None) self.urn = dict.get("urn", None) self.hrn = dict.get("hrn", None) @@ -203,21 +203,22 @@ class GID(Certificate): # @param dump_parents If true, also dump the parents of the GID def dump(self, *args, **kwargs): - print(self.dump_string(*args,**kwargs)) + print(self.dump_string(*args, **kwargs)) def dump_string(self, indent=0, dump_parents=False): - result=" "*(indent-2) + "GID\n" - result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n" - result += " "*indent + "urn:" + str(self.get_urn()) +"\n" - result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n" + result = " " * (indent - 2) + "GID\n" + result += " " * indent + "hrn:" + str(self.get_hrn()) + "\n" + result += " " * indent + "urn:" + str(self.get_urn()) + "\n" + result += " " * indent + "uuid:" + str(self.get_uuid()) + "\n" if self.get_email() is not None: - result += " "*indent + "email:" + str(self.get_email()) + "\n" - filename=self.get_filename() - if filename: result += "Filename %s\n"%filename + result += " " * indent + "email:" + str(self.get_email()) + "\n" + filename = self.get_filename() + if filename: + result += "Filename %s\n" % filename if self.parent and dump_parents: - result += " "*indent + "parent:\n" - result += self.parent.dump_string(indent+4, dump_parents) + result += " " * indent + "parent:\n" + result += self.parent.dump_string(indent + 4, dump_parents) return result ## @@ -230,10 +231,10 @@ class GID(Certificate): # for a principal that is not a member of that authority. For example, # planetlab.us.arizona cannot sign a GID for planetlab.us.princeton.foo. - def verify_chain(self, trusted_certs = None): + def verify_chain(self, trusted_certs=None): # do the normal certificate verification stuff - trusted_root = Certificate.verify_chain(self, trusted_certs) - + trusted_root = Certificate.verify_chain(self, trusted_certs) + if self.parent: # make sure the parent's hrn is a prefix of the child's hrn if not hrn_authfor_hrn(self.parent.get_hrn(), self.get_hrn()): @@ -256,7 +257,7 @@ class GID(Certificate): trusted_gid = GID(string=trusted_root.save_to_string()) trusted_type = trusted_gid.get_type() trusted_hrn = trusted_gid.get_hrn() - #if trusted_type == 'authority': + # if trusted_type == 'authority': # trusted_hrn = trusted_hrn[:trusted_hrn.rindex('.')] cur_hrn = self.get_hrn() if not hrn_authfor_hrn(trusted_hrn, cur_hrn): diff --git a/sfa/trust/hierarchy.py b/sfa/trust/hierarchy.py index 5e76dbf3..8e126a73 100644 --- a/sfa/trust/hierarchy.py +++ b/sfa/trust/hierarchy.py @@ -26,6 +26,7 @@ from sfa.trust.sfaticket import SfaTicket # The AuthInfo class contains the information for an authority. This information # includes the GID, private key, and database connection information. + class AuthInfo: hrn = None gid_object = None @@ -64,14 +65,14 @@ class AuthInfo: def get_gid_object(self): if not self.gid_object: - self.gid_object = GID(filename = self.gid_filename) + self.gid_object = GID(filename=self.gid_filename) return self.gid_object ## # Get the private key in the form of a Keypair object def get_pkey_object(self): - return Keypair(filename = self.privkey_filename) + return Keypair(filename=self.privkey_filename) ## # Replace the GID with a new one. The file specified by gid_filename is @@ -92,13 +93,14 @@ class AuthInfo: # contains the GID and pkey files for that authority (as well as # subdirectories for each sub-authority) + class Hierarchy: ## # Create the hierarchy object. # # @param basedir the base directory to store the hierarchy in - def __init__(self, basedir = None): + def __init__(self, basedir=None): self.config = Config() if not basedir: basedir = os.path.join(self.config.SFA_DATA_DIR, "authorities") @@ -119,8 +121,8 @@ class Hierarchy: parent_hrn = get_authority(hrn) directory = os.path.join(self.basedir, hrn.replace(".", "/")) - gid_filename = os.path.join(directory, leaf+".gid") - privkey_filename = os.path.join(directory, leaf+".pkey") + gid_filename = os.path.join(directory, leaf + ".gid") + privkey_filename = os.path.join(directory, leaf + ".pkey") return (directory, gid_filename, privkey_filename) @@ -131,30 +133,30 @@ class Hierarchy: # @param the human readable name of the authority to check def auth_exists(self, xrn): - hrn, type = urn_to_hrn(xrn) + hrn, type = urn_to_hrn(xrn) (directory, gid_filename, privkey_filename) = \ self.get_auth_filenames(hrn) - - return os.path.exists(gid_filename) and os.path.exists(privkey_filename) + + return os.path.exists(gid_filename) and os.path.exists(privkey_filename) ## # Create an authority. A private key for the authority and the associated # GID are created and signed by the parent authority. # - # @param xrn the human readable name of the authority to create (urn will be converted to hrn) + # @param xrn the human readable name of the authority to create (urn will be converted to hrn) # @param create_parents if true, also create the parents if they do not exist def create_auth(self, xrn, create_parents=False): hrn, type = urn_to_hrn(str(xrn)) - logger.debug("Hierarchy: creating authority: %s"% hrn) + logger.debug("Hierarchy: creating authority: %s" % hrn) # create the parent authority if necessary parent_hrn = get_authority(hrn) parent_urn = hrn_to_urn(parent_hrn, 'authority') if (parent_hrn) and (not self.auth_exists(parent_urn)) and (create_parents): self.create_auth(parent_urn, create_parents) - (directory, gid_filename, privkey_filename,) = \ - self.get_auth_filenames(hrn) + directory, gid_filename, privkey_filename = self.get_auth_filenames( + hrn) # create the directory to hold the files try: @@ -166,10 +168,11 @@ class Hierarchy: pass if os.path.exists(privkey_filename): - logger.debug("using existing key %r for authority %r"%(privkey_filename,hrn)) - pkey = Keypair(filename = privkey_filename) + logger.debug("using existing key %r for authority %r" + % (privkey_filename, hrn)) + pkey = Keypair(filename=privkey_filename) else: - pkey = Keypair(create = True) + pkey = Keypair(create=True) pkey.save_to_file(privkey_filename) gid = self.create_gid(xrn, create_uuid(), pkey) @@ -182,13 +185,12 @@ class Hierarchy: # create the authority if it doesnt alrady exist if not self.auth_exists(hrn): self.create_auth(hrn, create_parents=True) - - + def get_interface_auth_info(self, create=True): hrn = self.config.SFA_INTERFACE_HRN if not self.auth_exists(hrn): - if create==True: - self.create_top_level_auth(hrn) + if create == True: + self.create_top_level_auth(hrn) else: raise MissingAuthority(hrn) return self.get_auth_info(hrn) @@ -202,7 +204,8 @@ class Hierarchy: def get_auth_info(self, xrn): hrn, type = urn_to_hrn(xrn) if not self.auth_exists(hrn): - logger.warning("Hierarchy: missing authority - xrn=%s, hrn=%s"%(xrn,hrn)) + logger.warning( + "Hierarchy: missing authority - xrn=%s, hrn=%s" % (xrn, hrn)) raise MissingAuthority(hrn) (directory, gid_filename, privkey_filename, ) = \ @@ -234,12 +237,12 @@ class Hierarchy: parent_hrn = force_parent if force_parent else get_authority(hrn) # Using hrn_to_urn() here to make sure the urn is in the right format # If xrn was a hrn instead of a urn, then the gid's urn will be - # of type None + # of type None urn = hrn_to_urn(hrn, type) gid = GID(subject=hrn, uuid=uuid, hrn=hrn, urn=urn, email=email) # is this a CA cert if hrn == self.config.SFA_INTERFACE_HRN or not parent_hrn: - # root or sub authority + # root or sub authority gid.set_intermediate_ca(True) elif type and 'authority' in type: # authority type @@ -257,7 +260,8 @@ class Hierarchy: else: # we need the parent's private key in order to sign this GID parent_auth_info = self.get_auth_info(parent_hrn) - gid.set_issuer(parent_auth_info.get_pkey_object(), parent_auth_info.hrn) + gid.set_issuer(parent_auth_info.get_pkey_object(), + parent_auth_info.hrn) gid.set_parent(parent_auth_info.get_gid_object()) gid.set_pubkey(pkey) @@ -282,7 +286,7 @@ class Hierarchy: # update the gid if we need to if gid_is_expired or xrn or uuid or pubkey: - + if not xrn: xrn = gid.get_urn() if not uuid: @@ -303,7 +307,7 @@ class Hierarchy: # @param authority type of credential to return (authority | sa | ma) def get_auth_cred(self, xrn, kind="authority"): - hrn, type = urn_to_hrn(xrn) + hrn, type = urn_to_hrn(xrn) auth_info = self.get_auth_info(hrn) gid = auth_info.get_gid_object() @@ -312,19 +316,20 @@ class Hierarchy: cred.set_gid_object(gid) cred.set_privileges(kind) cred.get_privileges().delegate_all_privileges(True) - #cred.set_pubkey(auth_info.get_gid_object().get_pubkey()) + # cred.set_pubkey(auth_info.get_gid_object().get_pubkey()) parent_hrn = get_authority(hrn) if not parent_hrn or hrn == self.config.SFA_INTERFACE_HRN: # if there is no parent hrn, then it must be self-signed. this # is where we terminate the recursion - cred.set_issuer_keys(auth_info.get_privkey_filename(), auth_info.get_gid_filename()) + cred.set_issuer_keys( + auth_info.get_privkey_filename(), auth_info.get_gid_filename()) else: # we need the parent's private key in order to sign this GID parent_auth_info = self.get_auth_info(parent_hrn) - cred.set_issuer_keys(parent_auth_info.get_privkey_filename(), parent_auth_info.get_gid_filename()) + cred.set_issuer_keys(parent_auth_info.get_privkey_filename( + ), parent_auth_info.get_gid_filename()) - cred.set_parent(self.get_auth_cred(parent_hrn, kind)) cred.encode() @@ -361,11 +366,11 @@ class Hierarchy: else: # we need the parent's private key in order to sign this GID parent_auth_info = self.get_auth_info(parent_hrn) - ticket.set_issuer(parent_auth_info.get_pkey_object(), parent_auth_info.hrn) + ticket.set_issuer( + parent_auth_info.get_pkey_object(), parent_auth_info.hrn) ticket.set_parent(self.get_auth_cred(parent_hrn)) ticket.encode() ticket.sign() return ticket - diff --git a/sfa/trust/rights.py b/sfa/trust/rights.py index 28e11b35..0d329e57 100644 --- a/sfa/trust/rights.py +++ b/sfa/trust/rights.py @@ -33,7 +33,6 @@ ## - ## # privilege_table is a list of priviliges and what operations are allowed # per privilege. @@ -46,16 +45,15 @@ privilege_table = {"authority": ["register", "remove", "update", "resolve", "lis "sa": ["getticket", "redeemslice", "redeemticket", "createslice", "createsliver", "deleteslice", "deletesliver", "updateslice", "getsliceresources", "getticket", "loanresources", "stopslice", "startslice", "renewsliver", "deleteslice", "deletesliver", "resetslice", "listslices", "listnodes", "getpolicy", "sliverstatus"], - "embed": ["getticket", "redeemslice", "redeemticket", "createslice", "createsliver", "renewsliver", "deleteslice", + "embed": ["getticket", "redeemslice", "redeemticket", "createslice", "createsliver", "renewsliver", "deleteslice", "deletesliver", "updateslice", "sliverstatus", "getsliceresources", "shutdown"], "bind": ["getticket", "loanresources", "redeemticket"], - "control": ["updateslice", "createslice", "createsliver", "renewsliver", "sliverstatus", "stopslice", "startslice", + "control": ["updateslice", "createslice", "createsliver", "renewsliver", "sliverstatus", "stopslice", "startslice", "deleteslice", "deletesliver", "resetslice", "getsliceresources", "getgids"], "info": ["listslices", "listnodes", "getpolicy"], "ma": ["setbootstate", "getbootstate", "reboot", "getgids", "gettrustedcerts"], - "operator": ["gettrustedcerts", "getgids"], - "*": ["createsliver", "deletesliver", "sliverstatus", "renewsliver", "shutdown"]} - + "operator": ["gettrustedcerts", "getgids"], + "*": ["createsliver", "deletesliver", "sliverstatus", "renewsliver", "shutdown"]} ## @@ -105,7 +103,6 @@ def determine_rights(type, name): # The Right class represents a single privilege. - class Right: ## # Create a new right. @@ -116,7 +113,7 @@ class Right: self.kind = kind self.delegate = delegate - def __repr__ (self): return ""%self.kind + def __repr__(self): return "" % self.kind ## # Test to see if this right object is allowed to perform an operation. @@ -161,6 +158,7 @@ class Right: ## # A Rights object represents a list of privileges. + class Rights: ## # Create a new rightlist object, containing no rights. @@ -172,7 +170,8 @@ class Rights: if string: self.load_from_string(string) - def __repr__ (self): return "[" + " ".join( ["%s"%r for r in self.rights]) + "]" + def __repr__(self): return "[" + \ + " ".join(["%s" % r for r in self.rights]) + "]" def is_empty(self): return self.rights == [] @@ -227,7 +226,7 @@ class Rights: # @param op_name is an operation to check, for example "listslices" def can_perform(self, op_name): - + for right in self.rights: if right.can_perform(op_name): return True @@ -252,7 +251,6 @@ class Rights: return False return True - ## # set the delegate bit to 'delegate' on # all privileges diff --git a/sfa/trust/sfaticket.py b/sfa/trust/sfaticket.py index 6d4a009a..2c1e38ba 100644 --- a/sfa/trust/sfaticket.py +++ b/sfa/trust/sfaticket.py @@ -39,6 +39,7 @@ from sfa.util.py23 import xmlrpc_client # attributes = slice attributes (keys, vref, instantiation, etc) # rspec = resources + class SfaTicket(Certificate): gidCaller = None gidObject = None @@ -96,9 +97,11 @@ class SfaTicket(Certificate): "rspec": self.rspec, "delegate": self.delegate} if self.gidCaller: - dict["gidCaller"] = self.gidCaller.save_to_string(save_parents=True) + dict["gidCaller"] = self.gidCaller.save_to_string( + save_parents=True) if self.gidObject: - dict["gidObject"] = self.gidObject.save_to_string(save_parents=True) + dict["gidObject"] = self.gidObject.save_to_string( + save_parents=True) str = "URI:" + xmlrpc_client.dumps((dict,), allow_none=True) self.set_data(str) @@ -146,5 +149,5 @@ class SfaTicket(Certificate): print(" ", self.get_rspec()) if self.parent and dump_parents: - print("PARENT", end=' ') - self.parent.dump(dump_parents) + print("PARENT", end=' ') + self.parent.dump(dump_parents) diff --git a/sfa/trust/speaksfor_util.py b/sfa/trust/speaksfor_util.py index 640d5125..12ee8d04 100644 --- a/sfa/trust/speaksfor_util.py +++ b/sfa/trust/speaksfor_util.py @@ -42,7 +42,7 @@ from sfa.trust.gid import GID from sfa.util.sfalogging import logger from sfa.util.py23 import StringIO -# Routine to validate that a speaks-for credential +# Routine to validate that a speaks-for credential # says what it claims to say: # It is a signed credential wherein the signer S is attesting to the # ABAC statement: @@ -54,6 +54,8 @@ from sfa.util.py23 import StringIO # Simple XML helper functions # Find the text associated with first child text node + + def findTextChildValue(root): child = findChildNamed(root, '#text') if child: @@ -61,6 +63,8 @@ def findTextChildValue(root): return None # Find first child with given name + + def findChildNamed(root, name): for child in root.childNodes: if child.nodeName == name: @@ -68,6 +72,8 @@ def findChildNamed(root, name): return None # Write a string to a tempfile, returning name of tempfile + + def write_to_tempfile(str): str_fd, str_file = tempfile.mkstemp() if str: @@ -76,6 +82,8 @@ def write_to_tempfile(str): return str_file # Run a subprocess and return output + + def run_subprocess(cmd, stdout, stderr): try: proc = subprocess.Popen(cmd, stdout=stdout, stderr=stderr) @@ -86,7 +94,9 @@ def run_subprocess(cmd, stdout, stderr): output = proc.returncode return output except Exception as e: - raise Exception("Failed call to subprocess '{}': {}".format(" ".join(cmd), e)) + raise Exception( + "Failed call to subprocess '{}': {}".format(" ".join(cmd), e)) + def get_cert_keyid(gid): """Extract the subject key identifier from the given certificate. @@ -102,6 +112,8 @@ def get_cert_keyid(gid): return keyid # Pull the cert out of a list of certs in a PEM formatted cert string + + def grab_toplevel_cert(cert): start_label = '-----BEGIN CERTIFICATE-----' if cert.find(start_label) > -1: @@ -118,9 +130,9 @@ def grab_toplevel_cert(cert): # Validate that the given speaks-for credential represents the # statement User.speaks_for(User)<-Tool for the given user and tool certs # and was signed by the user -# Return: -# Boolean indicating whether the given credential -# is not expired +# Return: +# Boolean indicating whether the given credential +# is not expired # is an ABAC credential # was signed by the user associated with the speaking_for_urn # is verified by xmlsec1 @@ -130,6 +142,8 @@ def grab_toplevel_cert(cert): # String user certificate of speaking_for user if the above tests succeed # (None otherwise) # Error message indicating why the speaks_for call failed ("" otherwise) + + def verify_speaks_for(cred, tool_gid, speaking_for_urn, trusted_roots, schema=None, logger=None): @@ -154,7 +168,7 @@ def verify_speaks_for(cred, tool_gid, speaking_for_urn, .format(user_urn, speaking_for_urn, cred.pretty_cred()) tails = cred.get_tails() - if len(tails) != 1: + if len(tails) != 1: return False, None, "Invalid ABAC-SF credential: Need exactly 1 tail element, got {} ({})"\ .format(len(tails), cred.pretty_cred()) @@ -172,11 +186,12 @@ def verify_speaks_for(cred, tool_gid, speaking_for_urn, if trusted_roots: for x in trusted_roots: cert_args += ['--trusted-pem', x.filename] - # FIXME: Why do we not need to specify the --node-id option as credential.py does? + # FIXME: Why do we not need to specify the --node-id option as + # credential.py does? xmlsec1 = cred.get_xmlsec1_path() if not xmlsec1: raise Exception("Could not locate required 'xmlsec1' program") - xmlsec1_args = [xmlsec1, '--verify'] + cert_args + [ cred_file] + xmlsec1_args = [xmlsec1, '--verify'] + cert_args + [cred_file] output = run_subprocess(xmlsec1_args, stdout=None, stderr=subprocess.PIPE) os.unlink(cred_file) if output != 0: @@ -208,7 +223,8 @@ def verify_speaks_for(cred, tool_gid, speaking_for_urn, xmlschema = etree.XMLSchema(schema_doc) if not xmlschema.validate(tree): error = xmlschema.error_log.last_error - message = "{}: {} (line {})".format(cred.pretty_cred(), error.message, error.line) + message = "{}: {} (line {})".format( + cred.pretty_cred(), error.message, error.line) return False, None, ("XML Credential schema invalid: {}".format(message)) if trusted_roots: @@ -233,20 +249,23 @@ def verify_speaks_for(cred, tool_gid, speaking_for_urn, # # credentials is a list of GENI-style credentials: # Either a cred string xml string, or Credential object of a tuple -# [{'geni_type' : geni_type, 'geni_value : cred_value, +# [{'geni_type' : geni_type, 'geni_value : cred_value, # 'geni_version' : version}] # caller_gid is the raw X509 cert gid # options is the dictionary of API-provided options # trusted_roots is a list of Certificate objects from the system # trusted_root directory # Optionally, provide an XML schema against which to validate the credential + + def determine_speaks_for(logger, credentials, caller_gid, speaking_for_xrn, trusted_roots, schema=None): if speaking_for_xrn: - speaking_for_urn = Xrn (speaking_for_xrn.strip()).get_urn() + speaking_for_urn = Xrn(speaking_for_xrn.strip()).get_urn() for cred in credentials: # Skip things that aren't ABAC credentials if type(cred) == dict: - if cred['geni_type'] != ABACCredential.ABAC_CREDENTIAL_TYPE: continue + if cred['geni_type'] != ABACCredential.ABAC_CREDENTIAL_TYPE: + continue cred_value = cred['geni_value'] elif isinstance(cred, Credential): if not isinstance(cred, ABACCredential): @@ -254,7 +273,8 @@ def determine_speaks_for(logger, credentials, caller_gid, speaking_for_xrn, trus else: cred_value = cred else: - if CredentialFactory.getType(cred) != ABACCredential.ABAC_CREDENTIAL_TYPE: continue + if CredentialFactory.getType(cred) != ABACCredential.ABAC_CREDENTIAL_TYPE: + continue cred_value = cred # If the cred_value is xml, create the object @@ -271,13 +291,16 @@ def determine_speaks_for(logger, credentials, caller_gid, speaking_for_xrn, trus trusted_roots, schema, logger=logger) logger.info(msg) if is_valid_speaks_for: - return user_gid # speaks-for + return user_gid # speaks-for else: logger.info("Got speaks-for option but not a valid speaks_for with this credential: {}" .format(msg)) - return caller_gid # Not speaks-for + return caller_gid # Not speaks-for + +# Create an ABAC Speaks For credential using the ABACCredential object and +# it's encode&sign methods + -# Create an ABAC Speaks For credential using the ABACCredential object and it's encode&sign methods def create_sign_abaccred(tool_gid, user_gid, ma_gid, user_key_file, cred_filename, dur_days=365): logger.info("Creating ABAC SpeaksFor using ABACCredential...\n") # Write out the user cert @@ -298,9 +321,11 @@ def create_sign_abaccred(tool_gid, user_gid, ma_gid, user_key_file, cred_filenam user_urn = user_gid.get_urn() user_keyid = get_cert_keyid(user_gid) tool_keyid = get_cert_keyid(tool_gid) - cred.head = ABACElement(user_keyid, user_urn, "speaks_for_{}".format(user_keyid)) + cred.head = ABACElement(user_keyid, user_urn, + "speaks_for_{}".format(user_keyid)) cred.tails.append(ABACElement(tool_keyid, tool_urn)) - cred.set_expiration(datetime.datetime.utcnow() + datetime.timedelta(days=dur_days)) + cred.set_expiration(datetime.datetime.utcnow() + + datetime.timedelta(days=dur_days)) cred.expiration = cred.expiration.replace(microsecond=0) # Produce the cred XML @@ -314,7 +339,9 @@ def create_sign_abaccred(tool_gid, user_gid, ma_gid, user_key_file, cred_filenam .format(cred.pretty_cred(), cred_filename)) # FIXME: Assumes signer is itself signed by an 'ma_gid' that can be trusted -def create_speaks_for(tool_gid, user_gid, ma_gid, + + +def create_speaks_for(tool_gid, user_gid, ma_gid, user_key_file, cred_filename, dur_days=365): tool_urn = tool_gid.get_urn() user_urn = user_gid.get_urn() @@ -363,14 +390,14 @@ def create_speaks_for(tool_gid, user_gid, ma_gid, unsigned_cred_filename = write_to_tempfile(unsigned_cred) # Now sign the file with xmlsec1 - # xmlsec1 --sign --privkey-pem privkey.pem,cert.pem + # xmlsec1 --sign --privkey-pem privkey.pem,cert.pem # --output signed.xml tosign.xml pems = "{},{},{}".format(user_key_file, user_gid.get_filename(), ma_gid.get_filename()) xmlsec1 = Credential.get_xmlsec1_path() if not xmlsec1: raise Exception("Could not locate required 'xmlsec1' program") - cmd = [ xmlsec1, '--sign', '--privkey-pem', pems, + cmd = [xmlsec1, '--sign', '--privkey-pem', pems, '--output', cred_filename, unsigned_cred_filename] # print(" ".join(cmd)) @@ -387,19 +414,19 @@ def create_speaks_for(tool_gid, user_gid, ma_gid, if __name__ == "__main__": parser = optparse.OptionParser() - parser.add_option('--cred_file', + parser.add_option('--cred_file', help='Name of credential file') - parser.add_option('--tool_cert_file', + parser.add_option('--tool_cert_file', help='Name of file containing tool certificate') - parser.add_option('--user_urn', + parser.add_option('--user_urn', help='URN of speaks-for user') - parser.add_option('--user_cert_file', + parser.add_option('--user_cert_file', help="filename of x509 certificate of signing user") - parser.add_option('--ma_cert_file', + parser.add_option('--ma_cert_file', help="filename of x509 cert of MA that signed user cert") - parser.add_option('--user_key_file', + parser.add_option('--user_key_file', help="filename of private key of signing user") - parser.add_option('--trusted_roots_directory', + parser.add_option('--trusted_roots_directory', help='Directory of trusted root certs') parser.add_option('--create', help="name of file of ABAC speaksfor cred to create") @@ -412,7 +439,7 @@ if __name__ == "__main__": if options.create: if options.user_cert_file and options.user_key_file \ - and options.ma_cert_file: + and options.ma_cert_file: user_gid = GID(filename=options.user_cert_file) ma_gid = GID(filename=options.ma_cert_file) if options.useObject: @@ -424,8 +451,8 @@ if __name__ == "__main__": options.user_key_file, options.create) else: - print("Usage: --create cred_file " + - "--user_cert_file user_cert_file" + + print("Usage: --create cred_file " + + "--user_cert_file user_cert_file" + " --user_key_file user_key_file --ma_cert_file ma_cert_file") sys.exit() @@ -437,18 +464,17 @@ if __name__ == "__main__": trusted_roots_directory = options.trusted_roots_directory trusted_roots = \ - [Certificate(filename=os.path.join(trusted_roots_directory, file)) - for file in os.listdir(trusted_roots_directory) - if file.endswith('.pem') and file != 'CATedCACerts.pem'] + [Certificate(filename=os.path.join(trusted_roots_directory, file)) + for file in os.listdir(trusted_roots_directory) + if file.endswith('.pem') and file != 'CATedCACerts.pem'] cred = open(options.cred_file).read() - creds = [{'geni_type' : ABACCredential.ABAC_CREDENTIAL_TYPE, 'geni_value' : cred, - 'geni_version' : '1'}] + creds = [{'geni_type': ABACCredential.ABAC_CREDENTIAL_TYPE, 'geni_value': cred, + 'geni_version': '1'}] gid = determine_speaks_for(None, creds, tool_gid, - {'geni_speaking_for' : user_urn}, + {'geni_speaking_for': user_urn}, trusted_roots) - print('SPEAKS_FOR = {}'.format(gid != tool_gid)) print("CERT URN = {}".format(gid.get_urn())) diff --git a/sfa/trust/trustedroots.py b/sfa/trust/trustedroots.py index fb6b6425..5f8dcd3c 100644 --- a/sfa/trust/trustedroots.py +++ b/sfa/trust/trustedroots.py @@ -4,17 +4,18 @@ import glob from sfa.trust.gid import GID from sfa.util.sfalogging import logger + class TrustedRoots: - + # we want to avoid reading all files in the directory # this is because it's common to have backups of all kinds # e.g. *~, *.hide, *-00, *.bak and the like - supported_extensions= [ 'gid', 'cert', 'pem' ] + supported_extensions = ['gid', 'cert', 'pem'] def __init__(self, dir): self.basedir = dir # create the directory to hold the files, if not existing - if not os.path.isdir (self.basedir): + if not os.path.isdir(self.basedir): os.makedirs(self.basedir) def add_gid(self, gid): @@ -22,22 +23,23 @@ class TrustedRoots: gid.save_to_file(fn) def get_list(self): - gid_list = [GID(filename=cert_file) for cert_file in self.get_file_list()] + gid_list = [GID(filename=cert_file) + for cert_file in self.get_file_list()] return gid_list def get_file_list(self): - file_list = [] - pattern = os.path.join(self.basedir,"*") + file_list = [] + pattern = os.path.join(self.basedir, "*") for cert_file in glob.glob(pattern): if os.path.isfile(cert_file): if self.has_supported_extension(cert_file): - file_list.append(cert_file) + file_list.append(cert_file) else: logger.warning("File {} ignored - supported extensions are {}" .format(cert_file, TrustedRoots.supported_extensions)) return file_list - def has_supported_extension (self,path): - (_,ext)=os.path.splitext(path) - ext=ext.replace('.','').lower() + def has_supported_extension(self, path): + _, ext = os.path.splitext(path) + ext = ext.replace('.', '').lower() return ext in TrustedRoots.supported_extensions diff --git a/sfa/util/cache.py b/sfa/util/cache.py index 75b22e49..175968b5 100644 --- a/sfa/util/cache.py +++ b/sfa/util/cache.py @@ -7,9 +7,10 @@ import threading import pickle from datetime import datetime -# maximum lifetime of cached data (in seconds) +# maximum lifetime of cached data (in seconds) DEFAULT_CACHE_TTL = 60 * 60 + class CacheData: data = None @@ -17,7 +18,7 @@ class CacheData: expires = None lock = None - def __init__(self, data, ttl = DEFAULT_CACHE_TTL): + def __init__(self, data, ttl=DEFAULT_CACHE_TTL): self.lock = threading.RLock() self.data = data self.renew(ttl) @@ -31,26 +32,25 @@ class CacheData: def get_expires_date(self): return str(datetime.fromtimestamp(self.expires)) - def renew(self, ttl = DEFAULT_CACHE_TTL): + def renew(self, ttl=DEFAULT_CACHE_TTL): self.created = time.time() - self.expires = self.created + ttl - - def set_data(self, data, renew=True, ttl = DEFAULT_CACHE_TTL): - with self.lock: + self.expires = self.created + ttl + + def set_data(self, data, renew=True, ttl=DEFAULT_CACHE_TTL): + with self.lock: self.data = data if renew: self.renew(ttl) - + def get_data(self): return self.data - def dump(self): return self.__dict__ def __str__(self): - return str(self.dump()) - + return str(self.dump()) + def tostring(self): return self.__str__() @@ -62,38 +62,38 @@ class CacheData: def __setstate__(self, d): self.__dict__.update(d) self.lock = threading.RLock() - + class Cache: - cache = {} + cache = {} lock = threading.RLock() def __init__(self, filename=None): if filename: self.load_from_file(filename) - - def add(self, key, value, ttl = DEFAULT_CACHE_TTL): + + def add(self, key, value, ttl=DEFAULT_CACHE_TTL): with self.lock: if key in self.cache: self.cache[key].set_data(value, ttl=ttl) else: self.cache[key] = CacheData(value, ttl) - + def get(self, key): data = self.cache.get(key) - if not data: + if not data: data = None elif data.is_expired(): self.pop(key) - data = None + data = None else: data = data.get_data() return data def pop(self, key): if key in self.cache: - self.cache.pop(key) + self.cache.pop(key) def dump(self): result = {} @@ -102,10 +102,10 @@ class Cache: return result def __str__(self): - return str(self.dump()) - + return str(self.dump()) + def tostring(self): - return self.__str() + return self.__str() def save_to_file(self, filename): f = open(filename, 'wb') diff --git a/sfa/util/callids.py b/sfa/util/callids.py index 67485467..35c67d4d 100644 --- a/sfa/util/callids.py +++ b/sfa/util/callids.py @@ -11,62 +11,74 @@ memory-only for now - thread-safe implemented as a (singleton) hash 'callid'->timestamp """ -debug=False +debug = False + class _call_ids_impl (dict): _instance = None # 5 minutes sounds amply enough - purge_timeout=5*60 + purge_timeout = 5 * 60 # when trying to get a lock - retries=10 + retries = 10 # in ms - wait_ms=100 + wait_ms = 100 - def __init__(self): - self._lock=threading.Lock() + def __init__(self): + self._lock = threading.Lock() # the only primitive # return True if the callid is unknown, False otherwise - def already_handled (self,call_id): + def already_handled(self, call_id): # if not provided in the call... - if not call_id: return False - has_lock=False + if not call_id: + return False + has_lock = False for attempt in range(_call_ids_impl.retries): - if debug: logger.debug("Waiting for lock (%d)"%attempt) - if self._lock.acquire(False): - has_lock=True - if debug: logger.debug("got lock (%d)"%attempt) + if debug: + logger.debug("Waiting for lock (%d)" % attempt) + if self._lock.acquire(False): + has_lock = True + if debug: + logger.debug("got lock (%d)" % attempt) break - time.sleep(float(_call_ids_impl.wait_ms)/1000) + time.sleep(float(_call_ids_impl.wait_ms) / 1000) # in the unlikely event where we can't get the lock if not has_lock: - logger.warning("_call_ids_impl.should_handle_call_id: could not acquire lock") + logger.warning( + "_call_ids_impl.should_handle_call_id: could not acquire lock") return False # we're good to go if call_id in self: self._purge() self._lock.release() return True - self[call_id]=time.time() + self[call_id] = time.time() self._purge() self._lock.release() - if debug: logger.debug("released lock") + if debug: + logger.debug("released lock") return False - + def _purge(self): - now=time.time() - o_keys=[] - for (k,v) in self.iteritems(): - if (now-v) >= _call_ids_impl.purge_timeout: o_keys.append(k) - for k in o_keys: - if debug: logger.debug("Purging call_id %r (%s)"%(k,time.strftime("%H:%M:%S",time.localtime(self[k])))) + now = time.time() + o_keys = [] + for (k, v) in self.iteritems(): + if (now - v) >= _call_ids_impl.purge_timeout: + o_keys.append(k) + for k in o_keys: + if debug: + logger.debug("Purging call_id %r (%s)" % ( + k, time.strftime("%H:%M:%S", time.localtime(self[k])))) del self[k] if debug: logger.debug("AFTER PURGE") - for (k,v) in self.iteritems(): logger.debug("%s -> %s"%(k,time.strftime("%H:%M:%S",time.localtime(v)))) - -def Callids (): + for (k, v) in self.iteritems(): + logger.debug("%s -> %s" % + (k, time.strftime("%H:%M:%S", time.localtime(v)))) + + +def Callids(): if not _call_ids_impl._instance: _call_ids_impl._instance = _call_ids_impl() return _call_ids_impl._instance diff --git a/sfa/util/config.py b/sfa/util/config.py index e5435a7b..89b63fa4 100644 --- a/sfa/util/config.py +++ b/sfa/util/config.py @@ -9,26 +9,28 @@ from sfa.util.py23 import StringIO from sfa.util.py23 import ConfigParser default_config = \ + """ """ -""" + def isbool(v): return v.lower() in ("true", "false") + def str2bool(v): - return v.lower() in ("true", "1") + return v.lower() in ("true", "1") + class Config: - + def __init__(self, config_file='/etc/sfa/sfa_config'): self._files = [] self.config_path = os.path.dirname(config_file) - self.config = ConfigParser.ConfigParser() + self.config = ConfigParser.ConfigParser() self.filename = config_file if not os.path.isfile(self.filename): self.create(self.filename) self.load(self.filename) - def _header(self): header = """ @@ -47,7 +49,6 @@ DO NOT EDIT. This file was automatically generated at configfile = open(filename, 'w') configfile.write(default_config) configfile.close() - def load(self, filename): if filename: @@ -75,7 +76,7 @@ DO NOT EDIT. This file was automatically generated at if not value: value = "" self.config.set(section_name, option_name, value) - + def load_shell(self, filename): f = open(filename, 'r') for line in f: @@ -86,13 +87,13 @@ DO NOT EDIT. This file was automatically generated at if len(parts) < 2: continue option = parts[0] - value = parts[1].replace('"', '').replace("'","") + value = parts[1].replace('"', '').replace("'", "") section, var = self.locate_varname(option, strict=False) if section and var: self.set(section, var, value) except: pass - f.close() + f.close() def locate_varname(self, varname, strict=True): varname = varname.lower() @@ -105,7 +106,7 @@ DO NOT EDIT. This file was automatically generated at var_name = varname.replace(section_name, "")[1:] if strict and not self.config.has_option(section_name, var_name): raise ConfigParser.NoOptionError(var_name, section_name) - return (section_name, var_name) + return (section_name, var_name) def set_attributes(self): sections = self.config.sections() @@ -116,7 +117,7 @@ DO NOT EDIT. This file was automatically generated at if isbool(value): value = str2bool(value) elif value.isdigit(): - value = int(value) + value = int(value) setattr(self, name, value) setattr(self, name.upper(), value) @@ -150,7 +151,7 @@ DO NOT EDIT. This file was automatically generated at } variable_list = {} for item in self.config.items(section): - var_name = item[0] + var_name = item[0] name = "%s_%s" % (section, var_name) value = item[1] if isbool(value): @@ -168,7 +169,7 @@ DO NOT EDIT. This file was automatically generated at } variable_list[name] = variable variables[section] = (category, variable_list) - return variables + return variables def verify(self, config1, config2, validate_method): return True @@ -180,7 +181,7 @@ DO NOT EDIT. This file was automatically generated at def is_xml(config_file): try: x = Xml(config_file) - return True + return True except: return False @@ -193,23 +194,23 @@ DO NOT EDIT. This file was automatically generated at except ConfigParser.MissingSectionHeaderError: return False - def dump(self, sections=None): - if sections is None: sections=[] + if sections is None: + sections = [] sys.stdout.write(output_python()) - def output_python(self, encoding = "utf-8"): + def output_python(self, encoding="utf-8"): buf = codecs.lookup(encoding)[3](StringIO()) - buf.writelines(["# " + line + os.linesep for line in self._header()]) - + buf.writelines(["# " + line + os.linesep for line in self._header()]) + for section in self.sections(): buf.write("[%s]%s" % (section, os.linesep)) - for (name,value) in self.items(section): - buf.write("%s=%s%s" % (name,value,os.linesep)) + for (name, value) in self.items(section): + buf.write("%s=%s%s" % (name, value, os.linesep)) buf.write(os.linesep) return buf.getvalue() - - def output_shell(self, show_comments = True, encoding = "utf-8"): + + def output_shell(self, show_comments=True, encoding="utf-8"): """ Return variables as a shell script. """ @@ -218,18 +219,18 @@ DO NOT EDIT. This file was automatically generated at buf.writelines(["# " + line + os.linesep for line in self._header()]) for section in self.sections(): - for (name,value) in self.items(section): + for (name, value) in self.items(section): # bash does not have the concept of NULL if value: option = "%s_%s" % (section.upper(), name.upper()) if isbool(value): value = str(str2bool(value)) elif not value.isdigit(): - value = '"%s"' % value + value = '"%s"' % value buf.write(option + "=" + value + os.linesep) - return buf.getvalue() + return buf.getvalue() - def output_php(selfi, encoding = "utf-8"): + def output_php(selfi, encoding="utf-8"): """ Return variables as a PHP script. """ @@ -239,7 +240,7 @@ DO NOT EDIT. This file was automatically generated at buf.writelines(["// " + line + os.linesep for line in self._header()]) for section in self.sections(): - for (name,value) in self.items(section): + for (name, value) in self.items(section): option = "%s_%s" % (section, name) buf.write(os.linesep) buf.write("// " + option + os.linesep) @@ -249,9 +250,9 @@ DO NOT EDIT. This file was automatically generated at buf.write("?>" + os.linesep) - return buf.getvalue() + return buf.getvalue() - def output_xml(self, encoding = "utf-8"): + def output_xml(self, encoding="utf-8"): pass def output_variables(self, encoding="utf-8"): @@ -261,39 +262,38 @@ DO NOT EDIT. This file was automatically generated at buf = codecs.lookup(encoding)[3](StringIO()) for section in self.sections(): - for (name,value) in self.items(section): - option = "%s_%s" % (section,name) + for (name, value) in self.items(section): + option = "%s_%s" % (section, name) buf.write(option + os.linesep) return buf.getvalue() - pass - + pass + def write(self, filename=None): if not filename: filename = self.filename - configfile = open(filename, 'w') + configfile = open(filename, 'w') self.config.write(configfile) - + def save(self, filename=None): self.write(filename) - def get_trustedroots_dir(self): return self.config_path + os.sep + 'trusted_roots' def get_openflow_aggrMgr_info(self): aggr_mgr_ip = 'localhost' - if (hasattr(self,'openflow_aggregate_manager_ip')): + if (hasattr(self, 'openflow_aggregate_manager_ip')): aggr_mgr_ip = self.OPENFLOW_AGGREGATE_MANAGER_IP aggr_mgr_port = 2603 - if (hasattr(self,'openflow_aggregate_manager_port')): + if (hasattr(self, 'openflow_aggregate_manager_port')): aggr_mgr_port = self.OPENFLOW_AGGREGATE_MANAGER_PORT - return (aggr_mgr_ip,aggr_mgr_port) + return (aggr_mgr_ip, aggr_mgr_port) def get_interface_hrn(self): - if (hasattr(self,'sfa_interface_hrn')): + if (hasattr(self, 'sfa_interface_hrn')): return self.SFA_INTERFACE_HRN else: return "plc" @@ -306,7 +306,6 @@ if __name__ == '__main__': if len(sys.argv) > 1: filename = sys.argv[1] config = Config(filename) - else: + else: config = Config() config.dump() - diff --git a/sfa/util/defaultdict.py b/sfa/util/defaultdict.py index e0dd1450..513eb81f 100644 --- a/sfa/util/defaultdict.py +++ b/sfa/util/defaultdict.py @@ -3,36 +3,44 @@ try: from collections import defaultdict except: class defaultdict(dict): + def __init__(self, default_factory=None, *a, **kw): if (default_factory is not None and - not hasattr(default_factory, '__call__')): + not hasattr(default_factory, '__call__')): raise TypeError('first argument must be callable') dict.__init__(self, *a, **kw) self.default_factory = default_factory + def __getitem__(self, key): try: return dict.__getitem__(self, key) except KeyError: return self.__missing__(key) + def __missing__(self, key): if self.default_factory is None: raise KeyError(key) self[key] = value = self.default_factory() return value + def __reduce__(self): if self.default_factory is None: args = tuple() else: args = self.default_factory, return type(self), args, None, None, self.items() + def copy(self): return self.__copy__() + def __copy__(self): return type(self)(self.default_factory, self) + def __deepcopy__(self, memo): import copy return type(self)(self.default_factory, copy.deepcopy(self.items())) + def __repr__(self): return 'defaultdict(%s, %s)' % (self.default_factory, dict.__repr__(self)) diff --git a/sfa/util/enumeration.py b/sfa/util/enumeration.py index b65508f8..8ddfcc16 100644 --- a/sfa/util/enumeration.py +++ b/sfa/util/enumeration.py @@ -11,17 +11,19 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- + class Enum(set): + def __init__(self, *args, **kwds): set.__init__(self) enums = dict(zip(args, [object() for i in range(len(args))]), **kwds) @@ -30,6 +32,6 @@ class Enum(set): self.add(eval('self.%s' % key)) -#def Enum2(*args, **kwds): +# def Enum2(*args, **kwds): # enums = dict(zip(sequential, range(len(sequential))), **named) # return type('Enum', (), enums) diff --git a/sfa/util/faults.py b/sfa/util/faults.py index 702b685d..ea62bad2 100644 --- a/sfa/util/faults.py +++ b/sfa/util/faults.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- # @@ -27,37 +27,49 @@ from sfa.util.genicode import GENICODE from sfa.util.py23 import xmlrpc_client + class SfaFault(xmlrpc_client.Fault): - def __init__(self, faultCode, faultString, extra = None): + + def __init__(self, faultCode, faultString, extra=None): if extra: faultString += ": " + str(extra) xmlrpc_client.Fault.__init__(self, faultCode, faultString) + class Forbidden(SfaFault): - def __init__(self, extra = None): - faultString = "FORBIDDEN" - SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + + def __init__(self, extra=None): + faultString = "FORBIDDEN" + SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + class BadArgs(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "BADARGS" SfaFault.__init__(self, GENICODE.BADARGS, faultString, extra) class CredentialMismatch(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Credential mismatch" - SfaFault.__init__(self, GENICODE.CREDENTIAL_MISMATCH, faultString, extra) + SfaFault.__init__(self, GENICODE.CREDENTIAL_MISMATCH, + faultString, extra) + class SfaInvalidAPIMethod(SfaFault): - def __init__(self, method, interface = None, extra = None): + + def __init__(self, method, interface=None, extra=None): faultString = "Invalid method " + method if interface: faultString += " for interface " + interface SfaFault.__init__(self, GENICODE.UNSUPPORTED, faultString, extra) + class SfaInvalidArgumentCount(SfaFault): - def __init__(self, got, min, max = min, extra = None): + + def __init__(self, got, min, max=min, extra=None): if min != max: expected = "%d-%d" % (min, max) else: @@ -66,306 +78,416 @@ class SfaInvalidArgumentCount(SfaFault): (expected, got) SfaFault.__init__(self, GENICODE.BADARGS, faultString, extra) + class SfaInvalidArgument(SfaFault): - def __init__(self, extra = None, name = None): + + def __init__(self, extra=None, name=None): if name is not None: faultString = "Invalid %s value" % name else: faultString = "Invalid argument" SfaFault.__init__(self, GENICODE.BADARGS, faultString, extra) + class SfaAuthenticationFailure(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Failed to authenticate call" SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + class SfaDBError(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Database error" SfaFault.__init__(self, GENICODE.DBERROR, faultString, extra) + class SfaPermissionDenied(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Permission denied" SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + class SfaNotImplemented(SfaFault): - def __init__(self, interface=None, extra = None): + + def __init__(self, interface=None, extra=None): faultString = "Not implemented" if interface: - faultString += " at interface " + interface + faultString += " at interface " + interface SfaFault.__init__(self, GENICODE.UNSUPPORTED, faultString, extra) + class SfaAPIError(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Internal SFA API error" SfaFault.__init__(self, GENICODE.SERVERERROR, faultString, extra) + class MalformedHrnException(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Malformed HRN: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, extra) + def __str__(self): return repr(self.value) + class TreeException(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Tree Exception: %(value)s, " % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class SearchFailed(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "%s does not exist here " % self.value SfaFault.__init__(self, GENICODE.SEARCHFAILED, faultString, extra) + def __str__(self): return repr(self.value) + class NonExistingRecord(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Non exsiting record %(value)s, " % locals() SfaFault.__init__(self, GENICODE.SEARCHFAILED, faultString, extra) + def __str__(self): return repr(self.value) + class ExistingRecord(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Existing record: %(value)s, " % locals() SfaFault.__init__(self, GENICODE.REFUSED, faultString, extra) + def __str__(self): return repr(self.value) - + class InvalidRPCParams(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Invalid RPC Params: %(value)s, " % locals() SfaFault.__init__(self, GENICODE.RPCERROR, faultString, extra) + def __str__(self): return repr(self.value) # SMBAKER exceptions follow + class ConnectionKeyGIDMismatch(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Connection Key GID mismatch: %(value)s" % locals() - SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class MissingCallerGID(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Missing Caller GID: %(value)s" % locals() - SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class RecordNotFound(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Record not found: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class UnknownSfaType(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Unknown SFA Type: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class MissingAuthority(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Missing authority: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class PlanetLabRecordDoesNotExist(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "PlanetLab record does not exist : %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class PermissionError(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Permission error: %(value)s" % locals() SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + def __str__(self): return repr(self.value) + class InsufficientRights(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Insufficient rights: %(value)s" % locals() SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + def __str__(self): return repr(self.value) + class MissingDelegateBit(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Missing delegate bit: %(value)s" % locals() SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + def __str__(self): return repr(self.value) + class ChildRightsNotSubsetOfParent(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Child rights not subset of parent: %(value)s" % locals() SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + def __str__(self): return repr(self.value) + class CertMissingParent(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Cert missing parent: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class CertNotSignedByParent(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Cert not signed by parent: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) - + + class GidParentHrn(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value - faultString = "Cert URN is not an extension of its parent: %(value)s" % locals() + faultString = "Cert URN is not an extension of its parent: %(value)s" % locals( + ) SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) - + + class GidInvalidParentHrn(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "GID invalid parent hrn: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class SliverDoesNotExist(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Sliver does not exist : %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class BadRequestHash(xmlrpc_client.Fault): - def __init__(self, hash = None, extra = None): + + def __init__(self, hash=None, extra=None): faultString = "bad request hash: " + str(hash) xmlrpc_client.Fault.__init__(self, GENICODE.ERROR, faultString) + class MissingTrustedRoots(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value - faultString = "Trusted root directory does not exist: %(value)s" % locals() - SfaFault.__init__(self, GENICODE.SERVERERROR, faultString, extra) + faultString = "Trusted root directory does not exist: %(value)s" % locals( + ) + SfaFault.__init__(self, GENICODE.SERVERERROR, faultString, extra) + def __str__(self): return repr(self.value) + class MissingSfaInfo(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Missing information: %(value)s" % locals() - SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class InvalidRSpec(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Invalid RSpec: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class InvalidRSpecVersion(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Invalid RSpec version: %(value)s" % locals() SfaFault.__init__(self, GENICODE.BADVERSION, faultString, extra) + def __str__(self): return repr(self.value) + class UnsupportedRSpecVersion(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Unsupported RSpec version: %(value)s" % locals() SfaFault.__init__(self, GENICODE.UNSUPPORTED, faultString, extra) + def __str__(self): return repr(self.value) + class InvalidRSpecElement(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Invalid RSpec Element: %(value)s" % locals() SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class InvalidXML(SfaFault): - def __init__(self, value, extra = None): + + def __init__(self, value, extra=None): self.value = value faultString = "Invalid XML Document: %(value)s" % locals() SfaFault.__init__(self, GENICODE.BADARGS, faultString, extra) + def __str__(self): return repr(self.value) + class AccountNotEnabled(SfaFault): - def __init__(self, extra = None): + + def __init__(self, extra=None): faultString = "Account Disabled" SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) + def __str__(self): return repr(self.value) + class CredentialNotVerifiable(SfaFault): - def __init__(self, value=None, extra = None): + + def __init__(self, value=None, extra=None): self.value = value - faultString = "Unable to verify credential" %locals() + faultString = "Unable to verify credential" % locals() if value: faultString += ": %s" % value - faultString += ", " + faultString += ", " SfaFault.__init__(self, GENICODE.BADARGS, faultString, extra) + def __str__(self): return repr(self.value) + class CertExpired(SfaFault): + def __init__(self, value, extra=None): self.value = value faultString = "%s cert is expired" % value SfaFault.__init__(self, GENICODE.ERROR, faultString, extra) - + + class SfatablesRejected(SfaFault): + def __init__(self, value, extra=None): - self.value =value + self.value = value faultString = "%s rejected by sfatables" - SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + SfaFault.__init__(self, GENICODE.FORBIDDEN, faultString, extra) + class UnsupportedOperation(SfaFault): + def __init__(self, value, extra=None): self.value = value faultString = "Unsupported operation: %s" % value - SfaFault.__init__(self, GENICODE.UNSUPPORTED, faultString, extra) - + SfaFault.__init__(self, GENICODE.UNSUPPORTED, faultString, extra) diff --git a/sfa/util/genicode.py b/sfa/util/genicode.py index 2ebac476..9b30eb7a 100644 --- a/sfa/util/genicode.py +++ b/sfa/util/genicode.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- @@ -42,5 +42,5 @@ GENICODE = Enum( EXPIRED=15, INPORGRESS=16, ALREADYEXISTS=17, - CREDENTIAL_MISMATCH=22 -) + CREDENTIAL_MISMATCH=22 +) diff --git a/sfa/util/method.py b/sfa/util/method.py index 009220c9..44c368a0 100644 --- a/sfa/util/method.py +++ b/sfa/util/method.py @@ -13,6 +13,7 @@ from sfa.util.faults import SfaFault, SfaInvalidAPIMethod, SfaInvalidArgumentCou from sfa.storage.parameter import Parameter, Mixed, python_type, xmlrpc_type + class Method: """ Base class for all SfaAPI functions. At a minimum, all SfaAPI @@ -54,7 +55,7 @@ class Method: # API may set this to a (addr, port) tuple if known self.source = None - + def __call__(self, *args, **kwds): """ Main entry point for all SFA API functions. Type checks @@ -65,23 +66,25 @@ class Method: start = time.time() methodname = self.name if not self.api.interface or self.api.interface not in self.interfaces: - raise SfaInvalidAPIMethod(methodname, self.api.interface) + raise SfaInvalidAPIMethod(methodname, self.api.interface) (min_args, max_args, defaults) = self.args() - + # Check that the right number of arguments were passed in if len(args) < len(min_args) or len(args) > len(max_args): - raise SfaInvalidArgumentCount(len(args), len(min_args), len(max_args)) + raise SfaInvalidArgumentCount( + len(args), len(min_args), len(max_args)) for name, value, expected in zip(max_args, args, self.accepts): self.type_check(name, value, expected, args) - logger.debug("method.__call__ [%s] : BEG %s"%(self.api.interface,methodname)) + logger.debug("method.__call__ [%s] : BEG %s" % ( + self.api.interface, methodname)) result = self.call(*args, **kwds) runtime = time.time() - start - logger.debug("method.__call__ [%s] : END %s in %02f s (%s)"%\ - (self.api.interface,methodname,runtime,getattr(self,'message',"[no-msg]"))) + logger.debug("method.__call__ [%s] : END %s in %02f s (%s)" % + (self.api.interface, methodname, runtime, getattr(self, 'message', "[no-msg]"))) return result @@ -90,23 +93,24 @@ class Method: caller = "" # Prepend caller and method name to expected faults - fault.faultString = caller + ": " + self.name + ": " + fault.faultString + fault.faultString = caller + ": " + self.name + ": " + fault.faultString runtime = time.time() - start - logger.log_exc("Method %s raised an exception"%self.name) + logger.log_exc("Method %s raised an exception" % self.name) raise fault - - def help(self, indent = " "): + def help(self, indent=" "): """ Text documentation for the method. """ (min_args, max_args, defaults) = self.args() - text = "%s(%s) -> %s\n\n" % (self.name, ", ".join(max_args), xmlrpc_type(self.returns)) + text = "%s(%s) -> %s\n\n" % (self.name, + ", ".join(max_args), xmlrpc_type(self.returns)) text += "Description:\n\n" - lines = [indent + line.strip() for line in self.__doc__.strip().split("\n")] + lines = [indent + line.strip() + for line in self.__doc__.strip().split("\n")] text += "\n".join(lines) + "\n\n" def param_text(name, param, indent, step): @@ -129,9 +133,9 @@ class Method: # Print parameter documentation right below type if isinstance(param, Parameter): - wrapper = textwrap.TextWrapper(width = 70, - initial_indent = " " * param_offset, - subsequent_indent = " " * param_offset) + wrapper = textwrap.TextWrapper(width=70, + initial_indent=" " * param_offset, + subsequent_indent=" " * param_offset) text += "\n".join(wrapper.wrap(param.doc)) + "\n" param = param.type @@ -170,16 +174,17 @@ class Method: That represents the minimum and maximum sets of arguments that this function accepts and the defaults for the optional arguments. """ - + # Inspect call. Remove self from the argument list. - max_args = self.call.func_code.co_varnames[1:self.call.func_code.co_argcount] + max_args = self.call.func_code.co_varnames[ + 1:self.call.func_code.co_argcount] defaults = self.call.func_defaults if defaults is None: defaults = () min_args = max_args[0:len(max_args) - len(defaults)] defaults = tuple([None for arg in min_args]) + defaults - + return (min_args, max_args, defaults) def type_check(self, name, value, expected, args): @@ -188,7 +193,7 @@ class Method: which may be a Python type, a typed value, a Parameter, a Mixed type, or a list or dictionary of possibly mixed types, values, Parameters, or Mixed types. - + Extraneous members of lists must be of the same type as the last specified type. For example, if the expected argument type is [int, bool], then [1, False] and [14, True, False, @@ -210,9 +215,9 @@ class Method: # If an authentication structure is expected, save it and # authenticate after basic type checking is done. - #if isinstance(expected, Auth): + # if isinstance(expected, Auth): # auth = expected - #else: + # else: # auth = None # Get actual expected type from within the Parameter structure @@ -243,23 +248,28 @@ class Method: pass elif not isinstance(value, expected_type): - raise SfaInvalidArgument("expected %s, got %s" % \ - (xmlrpc_type(expected_type), xmlrpc_type(type(value))), + raise SfaInvalidArgument("expected %s, got %s" % + (xmlrpc_type(expected_type), + xmlrpc_type(type(value))), name) # If a minimum or maximum (length, value) has been specified if issubclass(expected_type, StringType): if min is not None and \ len(value.encode(self.api.encoding)) < min: - raise SfaInvalidArgument("%s must be at least %d bytes long" % (name, min)) + raise SfaInvalidArgument( + "%s must be at least %d bytes long" % (name, min)) if max is not None and \ len(value.encode(self.api.encoding)) > max: - raise SfaInvalidArgument("%s must be at most %d bytes long" % (name, max)) + raise SfaInvalidArgument( + "%s must be at most %d bytes long" % (name, max)) elif expected_type in (list, tuple, set): if min is not None and len(value) < min: - raise SfaInvalidArgument("%s must contain at least %d items" % (name, min)) + raise SfaInvalidArgument( + "%s must contain at least %d items" % (name, min)) if max is not None and len(value) > max: - raise SfaInvalidArgument("%s must contain at most %d items" % (name, max)) + raise SfaInvalidArgument( + "%s must contain at most %d items" % (name, max)) else: if min is not None and value < min: raise SfaInvalidArgument("%s must be > %s" % (name, str(min))) @@ -280,12 +290,13 @@ class Method: elif isinstance(expected, dict): for key in value.keys(): if key in expected: - self.type_check(name + "['%s']" % key, value[key], expected[key], args) + self.type_check(name + "['%s']" % + key, value[key], expected[key], args) for key, subparam in expected.iteritems(): if isinstance(subparam, Parameter) and \ subparam.optional is not None and \ not subparam.optional and key not in value.keys(): raise SfaInvalidArgument("'%s' not specified" % key, name) - #if auth is not None: + # if auth is not None: # auth.check(self, *args) diff --git a/sfa/util/policy.py b/sfa/util/policy.py index 5e43be55..4340f00c 100644 --- a/sfa/util/policy.py +++ b/sfa/util/policy.py @@ -2,18 +2,18 @@ import os from sfa.util.storage import SimpleStorage + class Policy(SimpleStorage): def __init__(self, api): self.api = api path = self.api.config.SFA_CONFIG_DIR - filename = ".".join([self.api.interface, self.api.hrn, "policy"]) + filename = ".".join([self.api.interface, self.api.hrn, "policy"]) filepath = path + os.sep + filename self.policy_file = filepath default_policy = {'slice_whitelist': [], 'slice_blacklist': [], 'node_whitelist': [], - 'node_blacklist': []} + 'node_blacklist': []} SimpleStorage.__init__(self, self.policy_file, default_policy) - self.load() - + self.load() diff --git a/sfa/util/prefixTree.py b/sfa/util/prefixTree.py index 0d7a557e..2f5b5f0d 100755 --- a/sfa/util/prefixTree.py +++ b/sfa/util/prefixTree.py @@ -1,18 +1,19 @@ from __future__ import print_function + class prefixNode: def __init__(self, prefix): self.prefix = prefix self.children = [] - + class prefixTree: - + def __init__(self): self.root = prefixNode("") - def insert(self, prefix, node = None): + def insert(self, prefix, node=None): """ insert a prefix into the tree """ @@ -27,20 +28,20 @@ class prefixTree: name = ".".join(parts[:i]) if not self.exists(name) and not name == prefix: self.insert(name) - + if prefix.startswith(node.prefix): if prefix == node.prefix: pass elif not node.children: node.children.append(prefixNode(prefix)) else: - inserted = False + inserted = False for child in node.children: if prefix.startswith(child.prefix): self.insert(prefix, child) inserted = True if not inserted: - node.children.append(prefixNode(prefix)) + node.children.append(prefixNode(prefix)) def load(self, prefix_list): """ @@ -49,7 +50,7 @@ class prefixTree: for prefix in prefix_list: self.insert(prefix) - def exists(self, prefix, node = None): + def exists(self, prefix, node=None): """ returns true if the specified prefix exists anywhere in the tree, false if it doesnt. @@ -68,14 +69,14 @@ class prefixTree: if prefix.startswith(child.prefix): return self.exists(prefix, child) - def best_match(self, prefix, node = None): + def best_match(self, prefix, node=None): """ searches the tree and returns the prefix that best matches the specified prefix """ if not node: node = self.root - + if prefix.startswith(node.prefix): if not node.children: return node.prefix @@ -83,8 +84,8 @@ class prefixTree: if prefix.startswith(child.prefix): return self.best_match(prefix, child) return node.prefix - - def dump(self, node = None): + + def dump(self, node=None): """ print the tree """ @@ -93,7 +94,7 @@ class prefixTree: print(node.prefix) for child in node.children: - print(child.prefix, end=' ') - + print(child.prefix, end=' ') + for child in node.children: self.dump(child) diff --git a/sfa/util/printable.py b/sfa/util/printable.py index f18c274b..c5a2da6c 100644 --- a/sfa/util/printable.py +++ b/sfa/util/printable.py @@ -1,15 +1,17 @@ # yet another way to display records... -def beginning (foo,size=15): - full="%s"%foo - if len(full)<=size: return full - return full[:size-3]+'...' +def beginning(foo, size=15): + full = "%s" % foo + if len(full) <= size: + return full + return full[:size - 3] + '...' -def printable (record_s): + +def printable(record_s): # a list of records : - if isinstance (record_s,list): - return "[" + "\n".join( [ printable(r) for r in record_s ]) + "]" - if isinstance (record_s, dict): - return "{" + " , ".join( [ "%s:%s"%(k,beginning(v)) for k,v in record_s.iteritems() ] ) + "}" - if isinstance (record_s, str): + if isinstance(record_s, list): + return "[" + "\n".join([printable(r) for r in record_s]) + "]" + if isinstance(record_s, dict): + return "{" + " , ".join(["%s:%s" % (k, beginning(v)) for k, v in record_s.iteritems()]) + "}" + if isinstance(record_s, str): return record_s - return "unprintable [[%s]]"%record_s + return "unprintable [[%s]]" % record_s diff --git a/sfa/util/py23.py b/sfa/util/py23.py index d26ebd10..7cd55cd2 100644 --- a/sfa/util/py23.py +++ b/sfa/util/py23.py @@ -21,7 +21,7 @@ try: import httplib as http_client except: from http import client as http_client - + try: import ConfigParser except: diff --git a/sfa/util/sfalogging.py b/sfa/util/sfalogging.py index 2b73b48d..2b7d7823 100644 --- a/sfa/util/sfalogging.py +++ b/sfa/util/sfalogging.py @@ -13,46 +13,54 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- from __future__ import print_function -import os, sys +import os +import sys import traceback -import logging, logging.handlers +import logging +import logging.handlers + +CRITICAL = logging.CRITICAL +ERROR = logging.ERROR +WARNING = logging.WARNING +INFO = logging.INFO +DEBUG = logging.DEBUG + +# a logger that can handle tracebacks -CRITICAL=logging.CRITICAL -ERROR=logging.ERROR -WARNING=logging.WARNING -INFO=logging.INFO -DEBUG=logging.DEBUG -# a logger that can handle tracebacks class _SfaLogger: - def __init__ (self,logfile=None,loggername=None,level=logging.INFO): + + def __init__(self, logfile=None, loggername=None, level=logging.INFO): # default is to locate loggername from the logfile if avail. if not logfile: - #loggername='console' - #handler=logging.StreamHandler() + # loggername='console' + # handler=logging.StreamHandler() #handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) logfile = "/var/log/sfa.log" if not loggername: - loggername=os.path.basename(logfile) + loggername = os.path.basename(logfile) try: - handler=logging.handlers.RotatingFileHandler(logfile,maxBytes=1000000, backupCount=5) + handler = logging.handlers.RotatingFileHandler( + logfile, maxBytes=1000000, backupCount=5) except IOError: # This is usually a permissions error because the file is # owned by root, but httpd is trying to access it. - tmplogfile=os.path.join(os.getenv("TMPDIR", os.getenv("TMP", os.path.normpath("/tmp"))), os.path.basename(logfile)) + tmplogfile = os.path.join(os.getenv("TMPDIR", + os.getenv("TMP", os.path.normpath("/tmp"))), + os.path.basename(logfile)) tmplogfile = os.path.normpath(tmplogfile) tmpdir = os.path.dirname(tmplogfile) @@ -64,51 +72,55 @@ class _SfaLogger: # We could (a) rename the tmplogfile, or (b) # just log to the console in that case. # Here we default to the console. - if os.path.exists(tmplogfile) and not os.access(tmplogfile,os.W_OK): + if os.path.exists(tmplogfile) and not os.access(tmplogfile, os.W_OK): loggername = loggername + "-console" handler = logging.StreamHandler() else: - handler=logging.handlers.RotatingFileHandler(tmplogfile,maxBytes=1000000, backupCount=5) - handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) - self.logger=logging.getLogger(loggername) + handler = logging.handlers.RotatingFileHandler( + tmplogfile, maxBytes=1000000, backupCount=5) + handler.setFormatter(logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s")) + self.logger = logging.getLogger(loggername) self.logger.setLevel(level) # check if logger already has the handler we're about to add handler_exists = False for l_handler in self.logger.handlers: if l_handler.baseFilename == handler.baseFilename and \ l_handler.level == handler.level: - handler_exists = True + handler_exists = True if not handler_exists: self.logger.addHandler(handler) - self.loggername=loggername + self.loggername = loggername - def setLevel(self,level): + def setLevel(self, level): self.logger.setLevel(level) # shorthand to avoid having to import logging all over the place def setLevelDebug(self): self.logger.setLevel(logging.DEBUG) - def debugEnabled (self): + def debugEnabled(self): return self.logger.getEffectiveLevel() == logging.DEBUG # define a verbose option with s/t like # parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0) # and pass the coresponding options.verbose to this method to adjust level - def setLevelFromOptVerbose(self,verbose): - if verbose==0: + def setLevelFromOptVerbose(self, verbose): + if verbose == 0: self.logger.setLevel(logging.WARNING) - elif verbose==1: + elif verbose == 1: self.logger.setLevel(logging.INFO) - elif verbose>=2: + elif verbose >= 2: self.logger.setLevel(logging.DEBUG) # in case some other code needs a boolean - def getBoolVerboseFromOpt(self,verbose): - return verbose>=1 - def getBoolDebugFromOpt(self,verbose): - return verbose>=2 + + def getBoolVerboseFromOpt(self, verbose): + return verbose >= 1 + + def getBoolDebugFromOpt(self, verbose): + return verbose >= 2 #################### def info(self, msg): @@ -116,34 +128,36 @@ class _SfaLogger: def debug(self, msg): self.logger.debug(msg) - + def warn(self, msg): self.logger.warn(msg) # some code is using logger.warn(), some is using logger.warning() def warning(self, msg): self.logger.warning(msg) - + def error(self, msg): - self.logger.error(msg) - + self.logger.error(msg) + def critical(self, msg): self.logger.critical(msg) # logs an exception - use in an except statement - def log_exc(self,message): - self.error("%s BEG TRACEBACK"%message+"\n"+traceback.format_exc().strip("\n")) - self.error("%s END TRACEBACK"%message) - - def log_exc_critical(self,message): - self.critical("%s BEG TRACEBACK"%message+"\n"+traceback.format_exc().strip("\n")) - self.critical("%s END TRACEBACK"%message) - + def log_exc(self, message): + self.error("%s BEG TRACEBACK" % message + "\n" + + traceback.format_exc().strip("\n")) + self.error("%s END TRACEBACK" % message) + + def log_exc_critical(self, message): + self.critical("%s BEG TRACEBACK" % message + "\n" + + traceback.format_exc().strip("\n")) + self.critical("%s END TRACEBACK" % message) + # for investigation purposes, can be placed anywhere - def log_stack(self,message): - to_log="".join(traceback.format_stack()) - self.info("%s BEG STACK"%message+"\n"+to_log) - self.info("%s END STACK"%message) + def log_stack(self, message): + to_log = "".join(traceback.format_stack()) + self.info("%s BEG STACK" % message + "\n" + to_log) + self.info("%s END STACK" % message) def enable_console(self, stream=sys.stdout): formatter = logging.Formatter("%(message)s") @@ -158,14 +172,16 @@ warn_logger = _SfaLogger(loggername='warning', level=logging.WARNING) error_logger = _SfaLogger(loggername='error', level=logging.ERROR) critical_logger = _SfaLogger(loggername='critical', level=logging.CRITICAL) logger = info_logger -sfi_logger = _SfaLogger(logfile=os.path.expanduser("~/.sfi/")+'sfi.log',loggername='sfilog', level=logging.DEBUG) +sfi_logger = _SfaLogger(logfile=os.path.expanduser("~/.sfi/") + 'sfi.log', + loggername='sfilog', level=logging.DEBUG) ######################################## import time + def profile(logger): """ Prints the runtime of the specified callable. Use as a decorator, e.g., - + @profile(logger) def foo(...): ... @@ -176,24 +192,26 @@ def profile(logger): result = callable(*args, **kwds) end = time.time() args = map(str, args) - args += ["%s = %s" % (name, str(value)) for (name, value) in kwds.iteritems()] + args += ["%s = %s" % (name, str(value)) + for (name, value) in kwds.iteritems()] # should probably use debug, but then debug is not always enabled - logger.info("PROFILED %s (%s): %.02f s" % (callable.__name__, ", ".join(args), end - start)) + logger.info("PROFILED %s (%s): %.02f s" % + (callable.__name__, ", ".join(args), end - start)) return result return wrapper return logger_profile -if __name__ == '__main__': +if __name__ == '__main__': print('testing sfalogging into logger.log') - logger1=_SfaLogger('logger.log', loggername='std(info)') - logger2=_SfaLogger('logger.log', loggername='error', level=logging.ERROR) - logger3=_SfaLogger('logger.log', loggername='debug', level=logging.DEBUG) - - for (logger,msg) in [ (logger1,"std(info)"),(logger2,"error"),(logger3,"debug")]: - - print("====================",msg, logger.logger.handlers) - + logger1 = _SfaLogger('logger.log', loggername='std(info)') + logger2 = _SfaLogger('logger.log', loggername='error', level=logging.ERROR) + logger3 = _SfaLogger('logger.log', loggername='debug', level=logging.DEBUG) + + for logger, msg in ((logger1, "std(info)"), (logger2, "error"), (logger3, "debug")): + + print("====================", msg, logger.logger.handlers) + logger.enable_console() logger.critical("logger.critical") logger.error("logger.error") @@ -202,13 +220,12 @@ if __name__ == '__main__': logger.debug("logger.debug") logger.setLevel(logging.DEBUG) logger.debug("logger.debug again") - + @profile(logger) - def sleep(seconds = 1): + def sleep(seconds=1): time.sleep(seconds) logger.info('console.info') sleep(0.5) logger.setLevel(logging.DEBUG) sleep(0.25) - diff --git a/sfa/util/sfatablesRuntime.py b/sfa/util/sfatablesRuntime.py index 0bc88f6c..2266b311 100644 --- a/sfa/util/sfatablesRuntime.py +++ b/sfa/util/sfatablesRuntime.py @@ -11,10 +11,11 @@ try: information that sfatables is requesting. But for now, we just return the basic information needed in a dict. """ - base_context = {'sfa':{'user':{'hrn':user_hrn}, 'slice':{'hrn':slice_hrn}}} + base_context = { + 'sfa': {'user': {'hrn': user_hrn}, 'slice': {'hrn': slice_hrn}}} return base_context - def run_sfatables(chain, hrn, origin_hrn, rspec, context_callback = None ): + def run_sfatables(chain, hrn, origin_hrn, rspec, context_callback=None): """ Run the rspec through sfatables @param chain Name of rule chain @@ -22,7 +23,7 @@ try: @param origin_hrn Original caller's hrn @param rspec Incoming rspec @param context_callback Callback used to generate the request context - + @return rspec """ if not context_callback: @@ -40,8 +41,10 @@ try: return newrspec except: - + from sfa.util.sfalogging import logger - def run_sfatables (_,__,___, rspec, ____=None): - logger.warning("Cannot import sfatables.runtime, please install package sfa-sfatables") + + def run_sfatables(_, __, ___, rspec, ____=None): + logger.warning( + "Cannot import sfatables.runtime, please install package sfa-sfatables") return rspec diff --git a/sfa/util/sfatime.py b/sfa/util/sfatime.py index 74356673..9752d325 100644 --- a/sfa/util/sfatime.py +++ b/sfa/util/sfatime.py @@ -33,6 +33,7 @@ from sfa.util.py23 import StringType SFATIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" + def utcparse(input): """ Translate a string into a time using dateutil.parser.parse but make sure it's in UTC time and strip the timezone, so that it's compatible with normal datetime.datetime objects. @@ -40,18 +41,21 @@ the timezone, so that it's compatible with normal datetime.datetime objects. For safety this can also handle inputs that are either timestamps, or datetimes """ - def handle_shorthands (input): + def handle_shorthands(input): """recognize string like +5d or +3w or +2m as 2 days, 3 weeks or 2 months from now""" if input.startswith('+'): - match=re.match (r"([0-9]+)([dwm])",input[1:]) + match = re.match(r"([0-9]+)([dwm])", input[1:]) if match: - how_many=int(match.group(1)) - what=match.group(2) - if what == 'd': d=datetime.timedelta(days=how_many) - elif what == 'w': d=datetime.timedelta(weeks=how_many) - elif what == 'm': d=datetime.timedelta(weeks=4*how_many) - return datetime.datetime.utcnow()+d + how_many = int(match.group(1)) + what = match.group(2) + if what == 'd': + d = datetime.timedelta(days=how_many) + elif what == 'w': + d = datetime.timedelta(weeks=how_many) + elif what == 'm': + d = datetime.timedelta(weeks=4 * how_many) + return datetime.datetime.utcnow() + d # prepare the input for the checks below by # casting strings ('1327098335') to ints @@ -60,37 +64,43 @@ For safety this can also handle inputs that are either timestamps, or datetimes input = int(input) except ValueError: try: - new_input=handle_shorthands(input) - if new_input is not None: input=new_input + new_input = handle_shorthands(input) + if new_input is not None: + input = new_input except: import traceback traceback.print_exc() - #################### here we go - if isinstance (input, datetime.datetime): + # here we go + if isinstance(input, datetime.datetime): #logger.info ("argument to utcparse already a datetime - doing nothing") return input - elif isinstance (input, StringType): + elif isinstance(input, StringType): t = dateutil.parser.parse(input) if t.utcoffset() is not None: t = t.utcoffset() + t.replace(tzinfo=None) return t - elif isinstance (input, (int,float,long)): + elif isinstance(input, (int, float, long)): return datetime.datetime.fromtimestamp(input) else: - logger.error("Unexpected type in utcparse [%s]"%type(input)) + logger.error("Unexpected type in utcparse [%s]" % type(input)) + def datetime_to_string(dt): return datetime.datetime.strftime(dt, SFATIME_FORMAT) + def datetime_to_utc(dt): return time.gmtime(datetime_to_epoch(dt)) -# see https://docs.python.org/2/library/time.html +# see https://docs.python.org/2/library/time.html # all timestamps are in UTC so time.mktime() would be *wrong* + + def datetime_to_epoch(dt): return int(calendar.timegm(dt.timetuple())) + def add_datetime(input, days=0, hours=0, minutes=0, seconds=0): """ Adjust the input date by the specified delta (in seconds). @@ -100,9 +110,10 @@ def add_datetime(input, days=0, hours=0, minutes=0, seconds=0): if __name__ == '__main__': # checking consistency - print(20*'X') - print(("Should be close to zero: %s"%(datetime_to_epoch(datetime.datetime.utcnow())-time.time()))) - print(20*'X') + print(20 * 'X') + print(("Should be close to zero: %s" % + (datetime_to_epoch(datetime.datetime.utcnow()) - time.time()))) + print(20 * 'X') for input in [ '+2d', '+3w', @@ -114,4 +125,5 @@ if __name__ == '__main__': '2014-05-28T15:18', '2014-05-28T15:18:30', ]: - print("input=%20s -> parsed %s"%(input,datetime_to_string(utcparse(input)))) + print("input=%20s -> parsed %s" % + (input, datetime_to_string(utcparse(input)))) diff --git a/sfa/util/storage.py b/sfa/util/storage.py index 9033434f..9b682064 100644 --- a/sfa/util/storage.py +++ b/sfa/util/storage.py @@ -1,6 +1,7 @@ import os from sfa.util.xml import XML + class SimpleStorage(dict): """ Handles storing and loading python dictionaries. The storage file created @@ -8,49 +9,51 @@ class SimpleStorage(dict): """ db_filename = None type = 'dict' - - def __init__(self, db_filename, db = None): - if db is None: db={} + + def __init__(self, db_filename, db=None): + if db is None: + db = {} dict.__init__(self, db) self.db_filename = db_filename - + def load(self): if os.path.exists(self.db_filename) and os.path.isfile(self.db_filename): db_file = open(self.db_filename, 'r') dict.__init__(self, eval(db_file.read())) elif os.path.exists(self.db_filename) and not os.path.isfile(self.db_filename): - raise IOError('%s exists but is not a file. please remove it and try again' \ - % self.db_filename) + raise IOError('%s exists but is not a file. please remove it and try again' + % self.db_filename) else: self.write() self.load() - + def write(self): - db_file = open(self.db_filename, 'w') + db_file = open(self.db_filename, 'w') db_file.write(str(self)) db_file.close() - + def sync(self): self.write() + class XmlStorage(SimpleStorage): """ Handles storing and loading python dictionaries. The storage file created is a xml representation of the python dictionary. - """ + """ db_filename = None type = 'xml' def load(self): """ Parse an xml file and store it as a dict - """ + """ if os.path.exists(self.db_filename) and os.path.isfile(self.db_filename): xml = XML(self.db_filename) dict.__init__(self, xml.todict()) elif os.path.exists(self.db_filename) and not os.path.isfile(self.db_filename): - raise IOError('%s exists but is not a file. please remove it and try again' \ - % self.db_filename) + raise IOError('%s exists but is not a file. please remove it and try again' + % self.db_filename) else: self.write() self.load() @@ -64,5 +67,3 @@ class XmlStorage(SimpleStorage): def sync(self): self.write() - - diff --git a/sfa/util/xml.py b/sfa/util/xml.py index 3a38ecc3..82cee736 100755 --- a/sfa/util/xml.py +++ b/sfa/util/xml.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python from lxml import etree from sfa.util.faults import InvalidXML from sfa.rspecs.elements.element import Element @@ -7,22 +7,25 @@ from sfa.util.py23 import StringType from sfa.util.py23 import StringIO # helper functions to help build xpaths + + class XpathFilter: - @staticmethod + @staticmethod def filter_value(key, value): - xpath = "" + xpath = "" if isinstance(value, str): if '*' in value: value = value.replace('*', '') xpath = 'contains(%s, "%s")' % (key, value) else: - xpath = '%s="%s"' % (key, value) + xpath = '%s="%s"' % (key, value) return xpath @staticmethod def xpath(filter=None): - if filter is None: filter={} + if filter is None: + filter = {} xpath = "" if filter: filter_list = [] @@ -30,12 +33,13 @@ class XpathFilter: if key == 'text': key = 'text()' else: - key = '@'+key + key = '@' + key if isinstance(value, str): filter_list.append(XpathFilter.filter_value(key, value)) elif isinstance(value, list): - stmt = ' or '.join([XpathFilter.filter_value(key, str(val)) for val in value]) - filter_list.append(stmt) + stmt = ' or '.join( + [XpathFilter.filter_value(key, str(val)) for val in value]) + filter_list.append(stmt) if filter_list: xpath = ' and '.join(filter_list) xpath = '[' + xpath + ']' @@ -48,22 +52,25 @@ class XpathFilter: # a default namespace defined (xmlns="http://default.com/") and specific prefixes defined # (xmlns:foo="http://foo.com") # according to the documentation instead of writing -# element.xpath ( "//node/foo:subnode" ) +# element.xpath ( "//node/foo:subnode" ) # we'd then need to write xpaths like -# element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" ) +# element.xpath ( "//{http://default.com/}node/{http://foo.com}subnode" ) # which is a real pain.. # So just so we can keep some reasonable programming style we need to manage the -# namespace map that goes with the _Element (its internal .nsmap being unmutable) +# namespace map that goes with the _Element (its internal .nsmap being +# unmutable) + class XmlElement: + def __init__(self, element, namespaces): self.element = element self.namespaces = namespaces - + # redefine as few methods as possible def xpath(self, xpath, namespaces=None): if not namespaces: - namespaces = self.namespaces + namespaces = self.namespaces elems = self.element.xpath(xpath, namespaces=namespaces) return [XmlElement(elem, namespaces) for elem in elems] @@ -85,7 +92,8 @@ class XmlElement: Returns an instance (dict) of this xml element. The instance holds a reference to this xml element. """ - if fields is None: fields=[] + if fields is None: + fields = [] if not instance_class: instance_class = Element if not fields and hasattr(instance_class, 'fields'): @@ -97,32 +105,33 @@ class XmlElement: instance = instance_class({}, self) for field in fields: if field in self.attrib: - instance[field] = self.attrib[field] - return instance + instance[field] = self.attrib[field] + return instance def add_instance(self, name, instance, fields=None): """ Adds the specifed instance(s) as a child element of this xml element. """ - if fields is None: fields=[] + if fields is None: + fields = [] if not fields and hasattr(instance, 'keys'): fields = instance.keys() elem = self.add_element(name) for field in fields: if field in instance and instance[field]: elem.set(field, unicode(instance[field])) - return elem + return elem def remove_elements(self, name): """ Removes all occurences of an element from the tree. Start at specified root_node if specified, otherwise start at tree's root. """ - + if not element_name.startswith('//'): element_name = '//' + element_name - elements = self.element.xpath('%s ' % name, namespaces=self.namespaces) + elements = self.element.xpath('%s ' % name, namespaces=self.namespaces) for element in elements: parent = element.getparent() parent.remove(element) @@ -139,26 +148,27 @@ class XmlElement: def set_text(self, text): self.element.text = text - + # Element does not have unset ?!? def unset(self, key): del self.element.attrib[key] - + def toxml(self): - return etree.tostring(self.element, encoding='UTF-8', pretty_print=True) + return etree.tostring(self.element, encoding='UTF-8', pretty_print=True) def __str__(self): return self.toxml() - ### other method calls or attribute access like .text or .tag or .get + # other method calls or attribute access like .text or .tag or .get # are redirected on self.element - def __getattr__ (self, name): + def __getattr__(self, name): if not hasattr(self.element, name): raise AttributeError(name) return getattr(self.element, name) + class XML: - + def __init__(self, xml=None, namespaces=None): self.root = None self.namespaces = namespaces @@ -188,29 +198,30 @@ class XML: root = tree.getroot() self.namespaces = dict(root.nsmap) # set namespaces map - if 'default' not in self.namespaces and None in self.namespaces: - # If the 'None' exist, then it's pointing to the default namespace. This makes - # it hard for us to write xpath queries for the default naemspace because lxml - # wont understand a None prefix. We will just associate the default namespeace - # with a key named 'default'. + if 'default' not in self.namespaces and None in self.namespaces: + # If the 'None' exist, then it's pointing to the default namespace. This makes + # it hard for us to write xpath queries for the default naemspace because lxml + # wont understand a None prefix. We will just associate the default namespeace + # with a key named 'default'. self.namespaces['default'] = self.namespaces.pop(None) - + else: - self.namespaces['default'] = 'default' + self.namespaces['default'] = 'default' self.root = XmlElement(root, self.namespaces) # set schema for key in self.root.attrib.keys(): if key.endswith('schemaLocation'): # schemaLocation should be at the end of the list. - # Use list comprehension to filter out empty strings - schema_parts = [x for x in self.root.attrib[key].split(' ') if x] - self.schema = schema_parts[1] - namespace, schema = schema_parts[0], schema_parts[1] + # Use list comprehension to filter out empty strings + schema_parts = [ + x for x in self.root.attrib[key].split(' ') if x] + self.schema = schema_parts[1] + namespace, schema = schema_parts[0], schema_parts[1] break - def parse_dict(self, d, root_tag_name='xml', element = None): - if element is None: + def parse_dict(self, d, root_tag_name='xml', element=None): + if element is None: if self.root is None: self.parse_xml('<%s/>' % root_tag_name) element = self.root.element @@ -228,7 +239,8 @@ class XML: child_element = etree.SubElement(element, key) self.parse_dict(val, key, child_element) elif isinstance(val, StringType): - child_element = etree.SubElement(element, key).text = val + child_element = etree.SubElement( + element, key).text = val elif isinstance(value, int): d[key] = unicode(d[key]) @@ -237,7 +249,7 @@ class XML: # element.attrib.update will explode if DateTimes are in the # dcitionary. - d=d.copy() + d = d.copy() # looks like iteritems won't stand side-effects for k in d.keys(): if not isinstance(d[k], StringType): @@ -268,7 +280,7 @@ class XML: def remove_attribute(self, name, element=None): if not element: element = self.root - element.remove_attribute(name) + element.remove_attribute(name) def add_element(self, *args, **kwds): """ @@ -278,7 +290,7 @@ class XML: """ return self.root.add_element(*args, **kwds) - def remove_elements(self, name, element = None): + def remove_elements(self, name, element=None): """ Removes all occurences of an element from the tree. Start at specified root_node if specified, otherwise start at tree's root. @@ -308,9 +320,11 @@ class XML: for child_elem in list(elem): key = str(child_elem.tag) if key not in attrs: - attrs[key] = [self.get_element_attributes(child_elem, depth-1)] + attrs[key] = [self.get_element_attributes( + child_elem, depth - 1)] else: - attrs[key].append(self.get_element_attributes(child_elem, depth-1)) + attrs[key].append( + self.get_element_attributes(child_elem, depth - 1)) else: attrs['child_nodes'] = list(elem) return attrs @@ -319,7 +333,7 @@ class XML: return self.root.append(elem) def iterchildren(self): - return self.root.iterchildren() + return self.root.iterchildren() def merge(self, in_xml): pass @@ -328,8 +342,8 @@ class XML: return self.toxml() def toxml(self): - return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True) - + return etree.tostring(self.root.element, encoding='UTF-8', pretty_print=True) + # XXX smbaker, for record.load_from_string def todict(self, elem=None): if elem is None: @@ -342,18 +356,17 @@ class XML: d[child.tag] = [] d[child.tag].append(self.todict(child)) - if len(d)==1 and ("text" in d): + if len(d) == 1 and ("text" in d): d = d["text"] return d - + def save(self, filename): f = open(filename, 'w') f.write(self.toxml()) f.close() -# no RSpec in scope -#if __name__ == '__main__': +# no RSpec in scope +# if __name__ == '__main__': # rspec = RSpec('/tmp/resources.rspec') # print rspec - diff --git a/sfa/util/xrn.py b/sfa/util/xrn.py index bcf541ad..081c71c2 100644 --- a/sfa/util/xrn.py +++ b/sfa/util/xrn.py @@ -11,13 +11,13 @@ # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Work. # -# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS +# THE WORK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE WORK OR THE USE OR OTHER DEALINGS # IN THE WORK. #---------------------------------------------------------------------- @@ -25,22 +25,36 @@ import re from sfa.util.faults import SfaAPIError -# for convenience and smoother translation - we should get rid of these functions eventually +# for convenience and smoother translation - we should get rid of these +# functions eventually + + def get_leaf(hrn): return Xrn(hrn).get_leaf() + + def get_authority(hrn): return Xrn(hrn).get_authority_hrn() -def urn_to_hrn(urn): xrn=Xrn(urn); return (xrn.hrn, xrn.type) -def hrn_to_urn(hrn,type): return Xrn(hrn, type=type).urn -def hrn_authfor_hrn(parenthrn, hrn): return Xrn.hrn_is_auth_for_hrn(parenthrn, hrn) + + +def urn_to_hrn(urn): xrn = Xrn(urn) +return (xrn.hrn, xrn.type) + + +def hrn_to_urn(hrn, type): return Xrn(hrn, type=type).urn + + +def hrn_authfor_hrn( + parenthrn, hrn): return Xrn.hrn_is_auth_for_hrn(parenthrn, hrn) + class Xrn: - ########## basic tools on HRNs + # basic tools on HRNs # split a HRN-like string into pieces # this is like split('.') except for escaped (backslashed) dots # e.g. hrn_split ('a\.b.c.d') -> [ 'a\.b','c','d'] @staticmethod def hrn_split(hrn): - return [ x.replace('--sep--','\\.') for x in hrn.replace('\\.','--sep--').split('.') ] + return [x.replace('--sep--', '\\.') for x in hrn.replace('\\.', '--sep--').split('.')] # e.g. hrn_leaf ('a\.b.c.d') -> 'd' @staticmethod @@ -49,18 +63,18 @@ class Xrn: # e.g. hrn_auth_list ('a\.b.c.d') -> ['a\.b', 'c'] @staticmethod def hrn_auth_list(hrn): return Xrn.hrn_split(hrn)[0:-1] - + # e.g. hrn_auth ('a\.b.c.d') -> 'a\.b.c' @staticmethod def hrn_auth(hrn): return '.'.join(Xrn.hrn_auth_list(hrn)) - + # e.g. escape ('a.b') -> 'a\.b' @staticmethod def escape(token): return re.sub(r'([^\\])\.', r'\1\.', token) # e.g. unescape ('a\.b') -> 'a.b' @staticmethod - def unescape(token): return token.replace('\\.','.') + def unescape(token): return token.replace('\\.', '.') # Return the HRN authority chain from top to bottom. # e.g. hrn_auth_chain('a\.b.c.d') -> ['a\.b', 'a\.b.c'] @@ -69,9 +83,9 @@ class Xrn: parts = Xrn.hrn_auth_list(hrn) chain = [] for i in range(len(parts)): - chain.append('.'.join(parts[:i+1])) + chain.append('.'.join(parts[:i + 1])) # Include the HRN itself? - #chain.append(hrn) + # chain.append(hrn) return chain # Is the given HRN a true authority over the namespace of the other @@ -89,29 +103,36 @@ class Xrn: return True return False - ########## basic tools on URNs + # basic tools on URNs URN_PREFIX = "urn:publicid:IDN" URN_PREFIX_lower = "urn:publicid:idn" @staticmethod - def is_urn (text): + def is_urn(text): return text.lower().startswith(Xrn.URN_PREFIX_lower) @staticmethod - def urn_full (urn): - if Xrn.is_urn(urn): return urn - else: return Xrn.URN_PREFIX+urn + def urn_full(urn): + if Xrn.is_urn(urn): + return urn + else: + return Xrn.URN_PREFIX + urn + @staticmethod - def urn_meaningful (urn): - if Xrn.is_urn(urn): return urn[len(Xrn.URN_PREFIX):] - else: return urn + def urn_meaningful(urn): + if Xrn.is_urn(urn): + return urn[len(Xrn.URN_PREFIX):] + else: + return urn + @staticmethod - def urn_split (urn): + def urn_split(urn): return Xrn.urn_meaningful(urn).split('+') @staticmethod def filter_type(urns=None, type=None): - if urns is None: urns=[] + if urns is None: + urns = [] urn_list = [] if not type: return urns @@ -119,7 +140,7 @@ class Xrn: for urn in urns: xrn = Xrn(xrn=urn) if (xrn.type == type): - # Xrn is probably a urn so we can just compare types + # Xrn is probably a urn so we can just compare types urn_list.append(urn) return urn_list #################### @@ -129,45 +150,53 @@ class Xrn: # self.type # self.path # provide either urn, or (hrn + type) - def __init__ (self, xrn="", type=None, id=None): - if not xrn: xrn = "" + + def __init__(self, xrn="", type=None, id=None): + if not xrn: + xrn = "" # user has specified xrn : guess if urn or hrn self.id = id if Xrn.is_urn(xrn): - self.hrn=None - self.urn=xrn + self.hrn = None + self.urn = xrn if id: self.urn = "%s:%s" % (self.urn, str(id)) self.urn_to_hrn() else: - self.urn=None - self.hrn=xrn - self.type=type + self.urn = None + self.hrn = xrn + self.type = type self.hrn_to_urn() self._normalize() # happens all the time .. # if not type: # debug_logger.debug("type-less Xrn's are not safe") - def __repr__ (self): - result=" 1: self.id = ":".join(parts[1:]) - name = parts[0] - hrn += '.%s' % Xrn.escape(name) + name = parts[0] + hrn += '.%s' % Xrn.escape(name) + + self.hrn = str(hrn) + self.type = str(type) - self.hrn=str(hrn) - self.type=str(type) - def hrn_to_urn(self): """ compute urn from (hrn, type) @@ -249,19 +278,20 @@ class Xrn: # if not self.hrn or self.hrn.startswith(Xrn.URN_PREFIX): if Xrn.is_urn(self.hrn): - raise SfaAPIError("Xrn.hrn_to_urn, hrn=%s"%self.hrn) + raise SfaAPIError("Xrn.hrn_to_urn, hrn=%s" % self.hrn) if self.type and self.type.startswith('authority'): self.authority = Xrn.hrn_auth_list(self.hrn) leaf = self.get_leaf() - #if not self.authority: + # if not self.authority: # self.authority = [self.hrn] type_parts = self.type.split("+") self.type = type_parts[0] name = 'sa' if len(type_parts) > 1: name = type_parts[1] - auth_parts = [part for part in [self.get_authority_urn(), leaf] if part] + auth_parts = [part for part in [ + self.get_authority_urn(), leaf] if part] authority_string = ":".join(auth_parts) else: self.authority = Xrn.hrn_auth_list(self.hrn) @@ -269,22 +299,22 @@ class Xrn: authority_string = self.get_authority_urn() if self.type == None: - urn = "+".join(['',authority_string,Xrn.unescape(name)]) + urn = "+".join(['', authority_string, Xrn.unescape(name)]) else: - urn = "+".join(['',authority_string,self.type,Xrn.unescape(name)]) + urn = "+".join(['', authority_string, + self.type, Xrn.unescape(name)]) if hasattr(self, 'id') and self.id: - urn = "%s:%s" % (urn, self.id) + urn = "%s:%s" % (urn, self.id) self.urn = Xrn.URN_PREFIX + urn def dump_string(self): - result="-------------------- XRN\n" - result += "URN=%s\n"%self.urn - result += "HRN=%s\n"%self.hrn - result += "TYPE=%s\n"%self.type - result += "LEAF=%s\n"%self.get_leaf() - result += "AUTH(hrn format)=%s\n"%self.get_authority_hrn() - result += "AUTH(urn format)=%s\n"%self.get_authority_urn() + result = "-------------------- XRN\n" + result += "URN=%s\n" % self.urn + result += "HRN=%s\n" % self.hrn + result += "TYPE=%s\n" % self.type + result += "LEAF=%s\n" % self.get_leaf() + result += "AUTH(hrn format)=%s\n" % self.get_authority_hrn() + result += "AUTH(urn format)=%s\n" % self.get_authority_urn() return result - diff --git a/tools/depgraph2dot.py b/tools/depgraph2dot.py index ab07a313..97730aa0 100755 --- a/tools/depgraph2dot.py +++ b/tools/depgraph2dot.py @@ -21,33 +21,38 @@ # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -import sys, getopt, colorsys, imp, hashlib +import sys +import getopt +import colorsys +import imp +import hashlib + class pydepgraphdot: - def main(self,argv): - opts,args = getopt.getopt(argv,'',['mono']) + def main(self, argv): + opts, args = getopt.getopt(argv, '', ['mono']) self.colored = 1 - for o,v in opts: - if o=='--mono': + for o, v in opts: + if o == '--mono': self.colored = 0 self.render() - def fix(self,s): + def fix(self, s): # Convert a module name to a syntactically correct node name - return s.replace('.','_') - + return s.replace('.', '_') + def render(self): - p,t = self.get_data() + p, t = self.get_data() # normalise our input data - for k,d in p.items(): + for k, d in p.items(): for v in d.keys(): if v not in p: p[v] = {} - - f = self.get_output_file() - + + f = self.get_output_file() + f.write('digraph G {\n') #f.write('concentrate = true;\n') #f.write('ordering = out;\n') @@ -55,90 +60,91 @@ class pydepgraphdot: f.write('node [style=filled,fontname=Helvetica,fontsize=10];\n') allkd = p.items() allkd.sort() - for k,d in allkd: + for k, d in allkd: tk = t.get(k) - if self.use(k,tk): + if self.use(k, tk): allv = d.keys() allv.sort() for v in allv: tv = t.get(v) - if self.use(v,tv) and not self.toocommon(v,tv): - f.write('%s -> %s' % ( self.fix(k),self.fix(v) ) ) - self.write_attributes(f,self.edge_attributes(k,v)) + if self.use(v, tv) and not self.toocommon(v, tv): + f.write('%s -> %s' % (self.fix(k), self.fix(v))) + self.write_attributes(f, self.edge_attributes(k, v)) f.write(';\n') f.write(self.fix(k)) - self.write_attributes(f,self.node_attributes(k,tk)) + self.write_attributes(f, self.node_attributes(k, tk)) f.write(';\n') f.write('}\n') - def write_attributes(self,f,a): + def write_attributes(self, f, a): if a: f.write(' [') f.write(','.join(a)) f.write(']') - def node_attributes(self,k,type): + def node_attributes(self, k, type): a = [] a.append('label="%s"' % self.label(k)) if self.colored: - a.append('fillcolor="%s"' % self.color(k,type)) + a.append('fillcolor="%s"' % self.color(k, type)) else: a.append('fillcolor=white') - if self.toocommon(k,type): + if self.toocommon(k, type): a.append('peripheries=2') return a - - def edge_attributes(self,k,v): + + def edge_attributes(self, k, v): a = [] - weight = self.weight(k,v) - if weight!=1: + weight = self.weight(k, v) + if weight != 1: a.append('weight=%d' % weight) - length = self.alien(k,v) + length = self.alien(k, v) if length: a.append('minlen=%d' % length) return a - + def get_data(self): t = eval(sys.stdin.read()) - return t['depgraph'],t['types'] - + return t['depgraph'], t['types'] + def get_output_file(self): return sys.stdout - def use(self,s,type): + def use(self, s, type): # Return true if this module is interesting and should be drawn. Return false - # if it should be completely omitted. This is a default policy - please override. - if s in ('os','sys','qt','time','__future__','types','re','string'): + # if it should be completely omitted. This is a default policy - please + # override. + if s in ('os', 'sys', 'qt', 'time', '__future__', 'types', 're', 'string'): # nearly all modules use all of these... more or less. They add nothing to # our diagram. return 0 if s.startswith('encodings.'): return 0 - if s=='__main__': + if s == '__main__': return 1 - if self.toocommon(s,type): + if self.toocommon(s, type): # A module where we dont want to draw references _to_. Dot doesnt handle these # well, so it is probably best to not draw them at all. return 0 return 1 - def toocommon(self,s,type): + def toocommon(self, s, type): # Return true if references to this module are uninteresting. Such references # do not get drawn. This is a default policy - please override. # - if s=='__main__': + if s == '__main__': # references *to* __main__ are never interesting. omitting them means # that main floats to the top of the page return 1 - if type==imp.PKG_DIRECTORY: + if type == imp.PKG_DIRECTORY: # dont draw references to packages. return 1 return 0 - - def weight(self,a,b): + + def weight(self, a, b): # Return the weight of the dependency from a to b. Higher weights # usually have shorter straighter edges. Return 1 if it has normal weight. - # A value of 4 is usually good for ensuring that a related pair of modules + # A value of 4 is usually good for ensuring that a related pair of modules # are drawn next to each other. This is a default policy - please override. # if b.split('.')[-1].startswith('_'): @@ -147,51 +153,48 @@ class pydepgraphdot: # together return 4 return 1 - - def alien(self,a,b): + + def alien(self, a, b): # Return non-zero if references to this module are strange, and should be drawn # extra-long. the value defines the length, in rank. This is also good for putting some # vertical space between seperate subsystems. This is a default policy - please override. # return 0 - def label(self,s): + def label(self, s): # Convert a module name to a formatted node label. This is a default policy - please override. # return '\\.\\n'.join(s.split('.')) - def color(self,s,type): + def color(self, s, type): # Return the node color for this module name. This is a default policy - please override. # # Calculate a color systematically based on the hash of the module name. Modules in the # same package have the same color. Unpackaged modules are grey - t = self.normalise_module_name_for_hash_coloring(s,type) + t = self.normalise_module_name_for_hash_coloring(s, type) return self.color_from_name(t) - - def normalise_module_name_for_hash_coloring(self,s,type): - if type==imp.PKG_DIRECTORY: + + def normalise_module_name_for_hash_coloring(self, s, type): + if type == imp.PKG_DIRECTORY: return s else: i = s.rfind('.') - if i<0: + if i < 0: return '' else: return s[:i] - - def color_from_name(self,name): + + def color_from_name(self, name): n = hashlib.md5(name).digest() - hf = float(ord(n[0])+ord(n[1])*0xff)/0xffff - sf = float(ord(n[2]))/0xff - vf = float(ord(n[3]))/0xff - r,g,b = colorsys.hsv_to_rgb(hf, 0.3+0.6*sf, 0.8+0.2*vf) - return '#%02x%02x%02x' % (r*256,g*256,b*256) + hf = float(ord(n[0]) + ord(n[1]) * 0xff) / 0xffff + sf = float(ord(n[2])) / 0xff + vf = float(ord(n[3])) / 0xff + r, g, b = colorsys.hsv_to_rgb(hf, 0.3 + 0.6 * sf, 0.8 + 0.2 * vf) + return '#%02x%02x%02x' % (r * 256, g * 256, b * 256) def main(): pydepgraphdot().main(sys.argv[1:]) -if __name__=='__main__': +if __name__ == '__main__': main() - - - diff --git a/tools/py2depgraph.py b/tools/py2depgraph.py index ef3b6f84..5c03f7f2 100755 --- a/tools/py2depgraph.py +++ b/tools/py2depgraph.py @@ -20,52 +20,59 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -import sys, pprint +import sys +import pprint import modulefinder -focus = [ 'sfa' , 'OpenSSL', 'M2Crypto', 'xmlrpclib', 'threading' ] +focus = ['sfa', 'OpenSSL', 'M2Crypto', 'xmlrpclib', 'threading'] + class mymf(modulefinder.ModuleFinder): - def __init__(self,*args,**kwargs): + + def __init__(self, *args, **kwargs): self._depgraph = {} self._types = {} self._last_caller = None - modulefinder.ModuleFinder.__init__(self,*args,**kwargs) - + modulefinder.ModuleFinder.__init__(self, *args, **kwargs) + def import_hook(self, name, caller=None, fromlist=None, level=None): old_last_caller = self._last_caller try: self._last_caller = caller - return modulefinder.ModuleFinder.import_hook(self,name,caller,fromlist) + return modulefinder.ModuleFinder.import_hook(self, name, caller, fromlist) finally: self._last_caller = old_last_caller - - def import_module(self,partnam,fqname,parent): - keep=False + + def import_module(self, partnam, fqname, parent): + keep = False for start in focus: - if fqname.startswith(start): keep=True + if fqname.startswith(start): + keep = True if not keep: - print >> sys.stderr, "Trimmed fqname",fqname + print >> sys.stderr, "Trimmed fqname", fqname return - r = modulefinder.ModuleFinder.import_module(self,partnam,fqname,parent) + r = modulefinder.ModuleFinder.import_module( + self, partnam, fqname, parent) if r is not None: - self._depgraph.setdefault(self._last_caller.__name__,{})[r.__name__] = 1 + self._depgraph.setdefault(self._last_caller.__name__, {})[ + r.__name__] = 1 return r - + def load_module(self, fqname, fp, pathname, (suffix, mode, type)): - r = modulefinder.ModuleFinder.load_module(self, fqname, fp, pathname, (suffix, mode, type)) + r = modulefinder.ModuleFinder.load_module( + self, fqname, fp, pathname, (suffix, mode, type)) if r is not None: self._types[r.__name__] = type return r - - -def main(argv): + + +def main(argv): path = sys.path[:] debug = 0 exclude = [] - mf = mymf(path,debug,exclude) + mf = mymf(path, debug, exclude) mf.run_script(argv[0]) - pprint.pprint({'depgraph':mf._depgraph,'types':mf._types}) - -if __name__=='__main__': + pprint.pprint({'depgraph': mf._depgraph, 'types': mf._types}) + +if __name__ == '__main__': main(sys.argv[1:]) diff --git a/tools/reset_gids.py b/tools/reset_gids.py index e30ed329..3922034e 100755 --- a/tools/reset_gids.py +++ b/tools/reset_gids.py @@ -8,8 +8,9 @@ from sfa.trust.hierarchy import Hierarchy from sfa.util.xrn import Xrn from sfa.trust.certificate import Certificate, Keypair, convert_public_key + def fix_users(): - s=global_dbsession + s = global_dbsession hierarchy = Hierarchy() users = s.query(RegRecord).filter_by(type="user") for record in users: @@ -17,15 +18,16 @@ def fix_users(): if not record.gid: uuid = create_uuid() pkey = Keypair(create=True) - pub_key=getattr(record,'reg_keys',None) + pub_key = getattr(record, 'reg_keys', None) if len(pub_key) > 0: # use only first key in record - if pub_key and isinstance(pub_key, list): pub_key = pub_key[0] + if pub_key and isinstance(pub_key, list): + pub_key = pub_key[0] pub_key = pub_key.key pkey = convert_public_key(pub_key) - urn = Xrn (xrn=record.hrn, type='user').get_urn() - email=getattr(record,'email',None) - gid_object = hierarchy.create_gid(urn, uuid, pkey, email = email) + urn = Xrn(xrn=record.hrn, type='user').get_urn() + email = getattr(record, 'email', None) + gid_object = hierarchy.create_gid(urn, uuid, pkey, email=email) gid = gid_object.save_to_string(save_parents=True) record.gid = gid s.commit()