From 5a9d22d382a20313c622fdc826c3eebb6d3cf586 Mon Sep 17 00:00:00 2001 From: Josh Karlin Date: Fri, 30 Jul 2010 18:43:21 +0000 Subject: [PATCH] merged trunk -r 18510:18539 --- TODO | 23 +-- keyconvert/keyconvert.py | 18 +- sfa/client/sfi.py | 126 ++++++++----- sfa/managers/aggregate_manager_pl.py | 18 +- sfa/managers/slice_manager_pl.py | 273 ++++++++++++--------------- sfa/plc/sfa-import-plc.py | 2 +- sfa/server/interface.py | 19 +- sfa/trust/credential.py | 8 +- sfa/util/api.py | 9 +- sfa/util/rspec.py | 32 ++++ sfa/util/sfaticket.py | 4 +- sfa/util/threadmanager.py | 71 +++++++ 12 files changed, 355 insertions(+), 248 deletions(-) create mode 100755 sfa/util/threadmanager.py diff --git a/TODO b/TODO index 9c30bf51..d863a591 100644 --- a/TODO +++ b/TODO @@ -1,29 +1,16 @@ -- Tutorial - * make a tutorial for sfa - -- Tag +- Build/Tags * test rpm build/install -- Geni Aggregate - * are we going to deploy a geni aggregate - * test - -- Trunk -* use PLC shell instead of xmlrpc when communicating with local plc aggregate - -- Client - * update getNodes to use lxml.etree for parsing the rspec - - Stop invalid users * a recently disabled/deleted user may still have a valid cred. Keep a list of valid/invalid users on the aggregate and check callers against this list - Component manager * GetGids - make this work for peer slices - * GetTicket - must verify_{site,slice,person,keys} on remote aggregate * Redeem ticket - RedeemTicket/AdminTicket not working. Why? * install the slice and node gid when the slice is created (create NM plugin to execute sfa_component_setup.py ?) - Registry +* fix legacy credential support * move db tables into db with less overhead (tokyocabinet?) - GUI/Auth Service @@ -32,13 +19,7 @@ * service manages users key/cert,creds * gui requires user's cred (depends on Auth Service above) -- SM call routing -* sfi -a option should send request to sm with an extra argument to - specify which am to contact instead of connecting directly to the am - (am may not trust client directly) - - Protogeni -* merger josh's branch with trunk * agree on standard set of functon calls * agree on standard set of privs * on permission error, return priv needed to make call diff --git a/keyconvert/keyconvert.py b/keyconvert/keyconvert.py index af12b1f4..de904ee6 100755 --- a/keyconvert/keyconvert.py +++ b/keyconvert/keyconvert.py @@ -4,7 +4,21 @@ import sys import base64 import struct import binascii -from M2Crypto import RSA, DSA +from M2Crypto import RSA, DSA, m2 + + +###### Workaround for bug in m2crypto-0.18 (on Fedora 8) +class RSA_pub_fix(RSA.RSA_pub): + def save_key_bio(self, bio, *args, **kw): + return self.save_pub_key_bio(bio) + +def rsa_new_pub_key((e, n)): + rsa = m2.rsa_new() + m2.rsa_set_e(rsa, e) + m2.rsa_set_n(rsa, n) + return RSA_pub_fix(rsa, 1) +###### +#rsa_new_pub_key = RSA.new_pub_key def decode_key(fname): @@ -78,7 +92,7 @@ def convert(fin, fout): if key_type == "ssh-rsa": e, n = ret[1:] - rsa = RSA.new_pub_key((e, n)) + rsa = rsa_new_pub_key((e, n)) rsa.save_pem(fout) elif key_type == "ssh-dss": diff --git a/sfa/client/sfi.py b/sfa/client/sfi.py index 8592ec1a..53c2227b 100755 --- a/sfa/client/sfi.py +++ b/sfa/client/sfi.py @@ -9,6 +9,7 @@ import tempfile import traceback import socket import random +import datetime from lxml import etree from StringIO import StringIO from types import StringTypes, ListType @@ -23,6 +24,8 @@ import sfa.util.xmlrpcprotocol as xmlrpcprotocol from sfa.util.config import Config import zlib +AGGREGATE_PORT=12346 +CM_PORT=12346 # utility methods here # display methods @@ -171,12 +174,12 @@ class Sfi: parser.add_option("-f", "--format", dest="format", type="choice", help="display format ([xml]|dns|ip)", default="xml", choices=("xml", "dns", "ip")) + + if command in ("resources", "slices", "create", "delete", "start", "stop", "get_ticket"): parser.add_option("-a", "--aggregate", dest="aggregate", - default=None, help="aggregate hrn") - - if command in ("create", "get_ticket"): - parser.add_option("-a", "--aggregate", dest="aggregate", default=None, - help="aggregate hrn") + default=None, help="aggregate host") + parser.add_option("-p", "--port", dest="port", + default=AGGREGATE_PORT, help="aggregate port") if command in ("start", "stop", "reset", "delete", "slices"): parser.add_option("-c", "--component", dest="component", default=None, @@ -203,7 +206,7 @@ class Sfi: help="delegate user credential") parser.add_option("-s", "--slice", dest="delegate_slice", help="delegate slice credential", metavar="HRN", default=None) - + return parser @@ -378,12 +381,25 @@ class Sfi: print "Writing user gid to", file gid.save_to_file(file, save_parents=True) return gid + + def get_cached_credential(self, file): + """ + Return a cached credential only if it hasn't expired. + """ + if (os.path.isfile(file)): + credential = Credential(filename=file) + # make sure it isnt expired + if not credential.get_lifetime or \ + datetime.datetime.today() < credential.get_lifefime(): + return credential + return None def get_user_cred(self): #file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".cred") file = os.path.join(self.options.sfi_dir, self.user.replace(self.authority + '.', '') + ".cred") - if (os.path.isfile(file)): - user_cred = Credential(filename=file) + + user_cred = self.get_cached_credential(file) + if user_cred: return user_cred else: # bootstrap user credential @@ -410,8 +426,8 @@ class Sfi: sys.exit(-1) file = os.path.join(self.options.sfi_dir, get_leaf("authority") + ".cred") - if (os.path.isfile(file)): - auth_cred = Credential(filename=file) + auth_cred = self.get_cached_credential(file) + if auth_cred: return auth_cred else: # bootstrap authority credential from user credential @@ -429,8 +445,8 @@ class Sfi: def get_slice_cred(self, name): file = os.path.join(self.options.sfi_dir, "slice_" + get_leaf(name) + ".cred") - if (os.path.isfile(file)): - slice_cred = Credential(filename=file) + slice_cred = self.get_cached_credential(file) + if slice_cred: return slice_cred else: # bootstrap slice credential from user credential @@ -538,15 +554,22 @@ class Sfi: if not records: print "No such component:", opts.component record = records[0] - cm_port = "12346" - url = "https://%s:%s" % (record['hostname'], cm_port) - return xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug) - - # + + return self.get_server(record['hostname'], CM_PORT, self.key_file, \ + self.cert_file, self.options.debug) + + def get_server(self, host, port, keyfile, certfile, debug): + """ + Return an instnace of an xmlrpc server connection + """ + url = "http://%s:%s" % (host, port) + return xmlrpcprotocol.get_server(url, keyfile, certfile, debug) + + #========================================================================== # Following functions implement the commands # # Registry-related commands - # + #========================================================================== def dispatch(self, command, cmd_opts, cmd_args): getattr(self, command)(cmd_opts, cmd_args) @@ -761,16 +784,21 @@ class Sfi: return - # + # ================================================================== # Slice-related commands - # + # ================================================================== - # list available nodes -- use 'resources' w/ no argument instead # list instantiated slices def slices(self, opts, args): + """ + list instantiated slices + """ user_cred = self.get_user_cred().save_to_string(save_parents=True) server = self.slicemgr + if opts.aggregate: + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) # direct connection to the nodes component manager interface if opts.component: server = self.get_component_server_from_hrn(opts.component) @@ -783,13 +811,8 @@ class Sfi: user_cred = self.get_user_cred().save_to_string(save_parents=True) server = self.slicemgr if opts.aggregate: - agg_hrn = opts.aggregate - aggregates = self.registry.get_aggregates(user_cred, agg_hrn) - if not aggregates: - raise Exception, "No such aggregate %s" % agg_hrn - aggregate = aggregates[0] - url = "http://%s:%s" % (aggregate['addr'], aggregate['port']) - server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug) + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) if args: cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True) hrn = args[0] @@ -816,13 +839,11 @@ class Sfi: rspec_file = self.get_rspec_file(args[1]) rspec = open(rspec_file).read() server = self.slicemgr + if opts.aggregate: - aggregates = self.registry.get_aggregates(user_cred, opts.aggregate) - if not aggregates: - raise Exception, "No such aggregate %s" % opts.aggregate - aggregate = aggregates[0] - url = "http://%s:%s" % (aggregate['addr'], aggregate['port']) - server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug) + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) + return server.create_slice(slice_cred, slice_hrn, rspec) # get a ticket for the specified slice @@ -834,12 +855,8 @@ class Sfi: rspec = open(rspec_file).read() server = self.slicemgr if opts.aggregate: - aggregates = self.registry.get_aggregates(user_cred, opts.aggregate) - if not aggregates: - raise Exception, "No such aggregate %s" % opts.aggregate - aggregate = aggregates[0] - url = "http://%s:%s" % (aggregate['addr'], aggregate['port']) - server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug) + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) ticket_string = server.get_ticket(slice_cred, slice_hrn, rspec) file = os.path.join(self.options.sfi_dir, get_leaf(slice_hrn) + ".ticket") print "writing ticket to ", file @@ -853,7 +870,7 @@ class Sfi: # use this to get the right slice credential ticket = SfaTicket(filename=ticket_file) ticket.decode() - slice_hrn = ticket.gidObject.get_hrn() + slice_hrn = ticket.gidObject.get_hrn() #slice_hrn = ticket.attributes['slivers'][0]['hrn'] user_cred = self.get_user_cred() slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True) @@ -868,28 +885,28 @@ class Sfi: connections = {} for hostname in hostnames: try: - cm_port = "12346" - url = "https://%(hostname)s:%(cm_port)s" % locals() - print "Calling redeem_ticket at %(url)s " % locals(), - cm = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug) - cm.redeem_ticket(slice_cred, ticket.save_to_string(save_parents=True)) + print "Calling redeem_ticket at %(hostname)s " % locals(), + server = self.get_server(hostname, CM_PORT, self.key_file, \ + self.cert_file, self.options.debug) + server.redeem_ticket(slice_cred, ticket.save_to_string(save_parents=True)) print "Success" except socket.gaierror: print "Failed:", print "Componet Manager not accepting requests" except Exception, e: print "Failed:", e.message - return # delete named slice def delete(self, opts, args): slice_hrn = args[0] server = self.slicemgr + if opts.aggregate: + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) # direct connection to the nodes component manager interface if opts.component: server = self.get_component_server_from_hrn(opts.component) - slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True) return server.delete_slice(slice_cred, slice_hrn) @@ -897,10 +914,12 @@ class Sfi: def start(self, opts, args): slice_hrn = args[0] server = self.slicemgr - # direct connection to the nodes component manager interface + # direct connection to an aggregagte + if opts.aggregate: + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) if opts.component: server = self.get_component_server_from_hrn(opts.component) - slice_cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True) return server.start_slice(slice_cred, slice_hrn) @@ -908,10 +927,13 @@ class Sfi: def stop(self, opts, args): slice_hrn = args[0] server = self.slicemgr + # direct connection to an aggregate + if opts.aggregate: + server = self.get_server(opts.aggregate, opts.port, self.key_file, \ + self.cert_file, self.options.debug) # direct connection to the nodes component manager interface if opts.component: server = self.get_component_server_from_hrn(opts.component) - slice_cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True) return server.stop_slice(slice_cred, slice_hrn) @@ -926,7 +948,9 @@ class Sfi: return server.reset_slice(slice_cred, slice_hrn) + # ===================================================================== # GENI AM related calls + # ===================================================================== def GetVersion(self, opts, args): server = self.geni_am diff --git a/sfa/managers/aggregate_manager_pl.py b/sfa/managers/aggregate_manager_pl.py index 76c576bd..e6b2c3de 100644 --- a/sfa/managers/aggregate_manager_pl.py +++ b/sfa/managers/aggregate_manager_pl.py @@ -96,13 +96,25 @@ def create_slice(api, xrn, xml, reg_objects=None): return True -def get_ticket(api, xrn, rspec, origin_hrn=None): +def get_ticket(api, xrn, rspec, origin_hrn=None, reg_objects=None): + slice_hrn, type = urn_to_hrn(xrn) - # the the slice record + slices = Slices(api) + peer = slices.get_peer(slice_hrn) + sfa_peer = slices.get_sfa_peer(slice_hrn) + + # get the slice record registry = api.registries[api.hrn] credential = api.getCredential() records = registry.resolve(credential, xrn) - + + # similar to create_slice, we must verify that the required records exist + # at this aggregate before we can issue a ticket + site_id, remote_site_id = slices.verify_site(registry, credential, slice_hrn, + peer, sfa_peer, reg_objects) + slice = slices.verify_slice(registry, credential, slice_hrn, site_id, + remote_site_id, peer, sfa_peer, reg_objects) + # make sure we get a local slice record record = None for tmp_record in records: diff --git a/sfa/managers/slice_manager_pl.py b/sfa/managers/slice_manager_pl.py index 1606eb37..72227d88 100644 --- a/sfa/managers/slice_manager_pl.py +++ b/sfa/managers/slice_manager_pl.py @@ -9,7 +9,7 @@ from copy import deepcopy from lxml import etree from StringIO import StringIO from types import StringTypes - +from sfa.util.rspec import merge_rspecs from sfa.util.namespace import * from sfa.util.rspec import * from sfa.util.specdict import * @@ -18,21 +18,18 @@ from sfa.util.record import SfaRecord from sfa.util.policy import Policy from sfa.util.prefixTree import prefixTree from sfa.util.sfaticket import * +from sfa.util.threadmanager import ThreadManager +import sfa.util.xmlrpcprotocol as xmlrpcprotocol from sfa.util.debug import log import sfa.plc.peers as peers def delete_slice(api, xrn, origin_hrn=None): credential = api.getCredential() - aggregates = api.aggregates - for aggregate in aggregates: - success = False - # request hash is optional so lets try the call without it - try: - aggregates[aggregate].delete_slice(credential, xrn, origin_hrn) - success = True - except: - print >> log, "%s" % (traceback.format_exc()) - print >> log, "Error calling delete slice at aggregate %s" % aggregate + threads = ThreadManager() + for aggregate in api.aggregates: + server = api.aggregates[aggregate] + threads.run(server.delete_slice, credential, xrn, origin_hrn) + threads.get_results() return 1 def create_slice(api, xrn, rspec, origin_hrn=None): @@ -57,126 +54,97 @@ def create_slice(api, xrn, rspec, origin_hrn=None): message = "%s (line %s)" % (error.message, error.line) raise InvalidRSpec(message) - aggs = api.aggregates - cred = api.getCredential() - for agg in aggs: - if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]: - try: - # Just send entire RSpec to each aggregate - aggs[agg].create_slice(cred, xrn, rspec, origin_hrn) - except: - print >> log, "Error creating slice %s at %s" % (hrn, agg) - traceback.print_exc() - - return True + cred = api.getCredential() + threads = ThreadManager() + for aggregate in api.aggregates: + if aggregate not in [api.auth.client_cred.get_gid_caller().get_hrn()]: + server = api.aggregates[aggregate] + # Just send entire RSpec to each aggregate + threads.run(server.create_slice, cred, xrn, rspec, origin_hrn) + threads.get_results() + return 1 def get_ticket(api, xrn, rspec, origin_hrn=None): slice_hrn, type = urn_to_hrn(xrn) # get the netspecs contained within the clients rspec - client_rspec = RSpec(xml=rspec) - netspecs = client_rspec.getDictsByTagName('NetSpec') + aggregate_rspecs = {} + tree= etree.parse(StringIO(rspec)) + elements = tree.findall('./network') + for element in elements: + aggregate_hrn = element.values()[0] + aggregate_rspecs[aggregate_hrn] = rspec + + # get a ticket from each aggregate + credential = api.getCredential() + threads = ThreadManager() + for aggregate, aggregate_rspec in aggregate_rspecs.items(): + server = None + if aggregate in api.aggregates: + server = api.aggregates[aggregate] + else: + net_urn = hrn_to_urn(aggregate, 'authority') + # we may have a peer that knows about this aggregate + for agg in api.aggregates: + agg_info = api.aggregates[agg].get_aggregates(credential, net_urn) + if agg_info: + # send the request to this address + url = 'http://%s:%s' % (agg_info['addr'], agg_info['port']) + server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file) + break + if server is None: + continue + threads.run(server.get_ticket, credential, xrn, aggregate_rspec, origin_hrn) + results = threads.get_results() - # create an rspec for each individual rspec - rspecs = {} - temp_rspec = RSpec() - for netspec in netspecs: - net_hrn = netspec['name'] - resources = {'start_time': 0, 'end_time': 0 , - 'network': {'NetSpec' : netspec}} - resourceDict = {'RSpec': resources} - temp_rspec.parseDict(resourceDict) - rspecs[net_hrn] = temp_rspec.toxml() + # gather information from each ticket + rspecs = [] + initscripts = [] + slivers = [] + object_gid = None + for result in results: + agg_ticket = SfaTicket(string=result) + attrs = agg_ticket.get_attributes() + if not object_gid: + object_gid = agg_ticket.get_gid_object() + print object_gid + rspecs.append(agg_ticket.get_rspec()) + initscripts.extend(attrs.get('initscripts', [])) + slivers.extend(attrs.get('slivers', [])) - # send the rspec to the appropiate aggregate/sm - aggregates = api.aggregates - credential = api.getCredential() - tickets = {} - for net_hrn in rspecs: - net_urn = urn_to_hrn(net_hrn) - try: - # if we are directly connected to the aggregate then we can just - # send them the request. if not, then we may be connected to an sm - # thats connected to the aggregate - if net_hrn in aggregates: - ticket = aggregates[net_hrn].get_ticket(credential, xrn, \ - rspecs[net_hrn], origin_hrn) - tickets[net_hrn] = ticket - else: - # lets forward this rspec to a sm that knows about the network - for agg in aggregates: - network_found = aggregates[agg].get_aggregates(credential, net_urn) - if network_found: - ticket = aggregates[aggregate].get_ticket(credential, \ - slice_hrn, rspecs[net_hrn], origin_hrn) - tickets[aggregate] = ticket - except: - print >> log, "Error getting ticket for %(slice_hrn)s at aggregate %(net_hrn)s" % \ - locals() - - # create a new ticket - new_ticket = SfaTicket(subject = slice_hrn) - new_ticket.set_gid_caller(api.auth.client_gid) - new_ticket.set_issuer(key=api.key, subject=api.hrn) - - tmp_rspec = RSpec() - networks = [] - valid_data = { - 'timestamp': int(time.time()), - 'initscripts': [], - 'slivers': [] - } - # merge data from aggregate ticket into new ticket - for agg_ticket in tickets.values(): - # get data from this ticket - agg_ticket = SfaTicket(string=agg_ticket) - attributes = agg_ticket.get_attributes() - if attributes.get('initscripts', []) != None: - valid_data['initscripts'].extend(attributes.get('initscripts', [])) - if attributes.get('slivers', []) != None: - valid_data['slivers'].extend(attributes.get('slivers', [])) - - # set the object gid - object_gid = agg_ticket.get_gid_object() - new_ticket.set_gid_object(object_gid) - new_ticket.set_pubkey(object_gid.get_pubkey()) + # merge info + attributes = {'initscripts': initscripts, + 'slivers': slivers} + merged_rspec = merge_rspecs(rspecs) - # build the rspec - tmp_rspec.parseString(agg_ticket.get_rspec()) - networks.extend([{'NetSpec': tmp_rspec.getDictsByTagName('NetSpec')}]) - + # create a new ticket + ticket = SfaTicket(subject = slice_hrn) + ticket.set_gid_caller(api.auth.client_gid) + ticket.set_issuer(key=api.key, subject=api.hrn) + ticket.set_gid_object(object_gid) + ticket.set_pubkey(object_gid.get_pubkey()) #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn)) - new_ticket.set_attributes(valid_data) - resources = {'networks': networks, 'start_time': 0, 'duration': 0} - resourceDict = {'RSpec': resources} - tmp_rspec.parseDict(resourceDict) - new_ticket.set_rspec(tmp_rspec.toxml()) - new_ticket.encode() - new_ticket.sign() - return new_ticket.save_to_string(save_parents=True) + ticket.set_attributes(attributes) + ticket.set_rspec(merged_rspec) + ticket.encode() + ticket.sign() + return ticket.save_to_string(save_parents=True) def start_slice(api, xrn): - hrn, type = urn_to_hrn(xrn) - slicename = hrn_to_pl_slicename(hrn) - slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id']) - if not slices: - raise RecordNotFound(hrn) - slice_id = slices[0] - attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id']) - attribute_id = attreibutes[0]['slice_attribute_id'] - api.plshell.UpdateSliceTag(api.plauth, attribute_id, "1" ) - + credential = api.getCredential() + threads = ThreadManager() + for aggregate in api.aggregates: + server = api.aggregates[aggregate] + threads.run(server.stop_slice, credential, xrn) + threads.get_results() return 1 def stop_slice(api, xrn): - hrn, type = urn_to_hrn(xrn) - slicename = hrn_to_pl_slicename(hrn) - slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id']) - if not slices: - raise RecordNotFound(hrn) - slice_id = slices[0]['slice_id'] - attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id']) - attribute_id = attributes[0]['slice_attribute_id'] - api.plshell.UpdateSliceTag(api.plauth, attribute_id, "0") + credential = api.getCredential() + threads = ThreadManager() + for aggregate in api.aggregates: + server = api.aggregates[aggregate] + threads.run(server.stop_slice, credential, xrn) + threads.get_results() return 1 def reset_slice(api, xrn): @@ -193,14 +161,17 @@ def get_slices(api): # fetch from aggregates slices = [] credential = api.getCredential() + threads = ThreadManager() for aggregate in api.aggregates: - try: - tmp_slices = api.aggregates[aggregate].get_slices(credential) - slices.extend(tmp_slices) - except: - print >> log, "%s" % (traceback.format_exc()) - print >> log, "Error calling slices at aggregate %(aggregate)s" % locals() + server = api.aggregates[aggregate] + threads.run(server.get_slices, credential) + # combime results + results = threads.get_results() + slices = [] + for result in results: + slices.extend(result) + # cache the result if api.cache: api.cache.add('slices', slices) @@ -216,37 +187,35 @@ def get_rspec(api, xrn=None, origin_hrn=None): hrn, type = urn_to_hrn(xrn) rspec = None - aggs = api.aggregates - cred = api.getCredential() - for agg in aggs: - if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]: - try: - # get the rspec from the aggregate - agg_rspec = aggs[agg].get_resources(cred, xrn, origin_hrn) - except: - # XX print out to some error log - print >> log, "Error getting resources at aggregate %s" % agg - traceback.print_exc(log) - print >> log, "%s" % (traceback.format_exc()) - continue - - try: - tree = etree.parse(StringIO(agg_rspec)) - except etree.XMLSyntaxError: - message = agg + ": " + str(sys.exc_info()[1]) - raise InvalidRSpec(message) + cred = api.getCredential() + threads = ThreadManager() + for aggregate in api.aggregates: + if aggregate not in [api.auth.client_cred.get_gid_caller().get_hrn()]: + # get the rspec from the aggregate + server = api.aggregates[aggregate] + threads.run(server.get_resources, cred, xrn, origin_hrn) + + results = threads.get_results() + # combine the rspecs into a single rspec + for agg_rspec in results: + try: + tree = etree.parse(StringIO(agg_rspec)) + except etree.XMLSyntaxError: + message = str(agg_rspec) + ": " + str(sys.exc_info()[1]) + raise InvalidRSpec(message) - root = tree.getroot() - if root.get("type") in ["SFA"]: - if rspec == None: - rspec = root - else: - for network in root.iterfind("./network"): - rspec.append(deepcopy(network)) - for request in root.iterfind("./request"): - rspec.append(deepcopy(request)) + root = tree.getroot() + if root.get("type") in ["SFA"]: + if rspec == None: + rspec = root + else: + for network in root.iterfind("./network"): + rspec.append(deepcopy(network)) + for request in root.iterfind("./request"): + rspec.append(deepcopy(request)) rspec = etree.tostring(rspec, xml_declaration=True, pretty_print=True) + # cache the result if api.cache and not xrn: api.cache.add('nodes', rspec) diff --git a/sfa/plc/sfa-import-plc.py b/sfa/plc/sfa-import-plc.py index 39cb28c0..786d57ac 100755 --- a/sfa/plc/sfa-import-plc.py +++ b/sfa/plc/sfa-import-plc.py @@ -118,7 +118,7 @@ def main(): sites_dict[site['login_base']] = site # Get all plc users - persons = shell.GetPersons(plc_auth, {'peer_id': None}, ['person_id', 'email', 'key_ids', 'site_ids']) + persons = shell.GetPersons(plc_auth, {'peer_id': None, 'enabled': True}, ['person_id', 'email', 'key_ids', 'site_ids']) persons_dict = {} for person in persons: persons_dict[person['person_id']] = person diff --git a/sfa/server/interface.py b/sfa/server/interface.py index 65b8d835..5ccb7ee3 100644 --- a/sfa/server/interface.py +++ b/sfa/server/interface.py @@ -87,8 +87,8 @@ class Interfaces(dict): hrns_current = [gid.get_hrn() for gid in gids_current] hrns_expected = self.interfaces.keys() new_hrns = set(hrns_expected).difference(hrns_current) - gids = self.get_peer_gids(new_hrns) - # update the local db records for these registries + gids = self.get_peer_gids(new_hrns) + gids_current + # make sure there is a record for every gid self.update_db_records(self.type, gids) def get_peer_gids(self, new_hrns): @@ -145,20 +145,19 @@ class Interfaces(dict): """ if not gids: return - # get hrns we expect to find - # ignore records for local interfaces - ignore_interfaces = [self.api.config.SFA_INTERFACE_HRN] - hrns_expected = [gid.get_hrn() for gid in gids \ - if gid.get_hrn() not in ignore_interfaces] + + # hrns that should have a record + hrns_expected = [gid.get_hrn() for gid in gids] # get hrns that actually exist in the db table = SfaTable() - records = table.find({'type': type}) + records = table.find({'type': type, 'pointer': -1}) hrns_found = [record['hrn'] for record in records] - + # remove old records for record in records: - if record['hrn'] not in hrns_expected: + if record['hrn'] not in hrns_expected and \ + record['hrn'] != self.api.config.SFA_INTERFACE_HRN: table.remove(record) # add new records diff --git a/sfa/trust/credential.py b/sfa/trust/credential.py index 453401f6..cfab006f 100644 --- a/sfa/trust/credential.py +++ b/sfa/trust/credential.py @@ -217,6 +217,10 @@ class Credential(object): self.xmlsec_path = path + '/' + 'xmlsec1' break + def get_subject(self): + if not self.gidObject: + self.decode() + return self.gidObject.get_subject() def get_signature(self): if not self.signature: @@ -781,9 +785,7 @@ class Credential(object): # @param dump_parents If true, also dump the parent certificates def dump(self, dump_parents=False): -# FIXME: get_subject doesnt exist -# print "CREDENTIAL", self.get_subject() - print "CREDENTIAL" + print "CREDENTIAL", self.get_subject() print " privs:", self.get_privileges().save_to_string() diff --git a/sfa/util/api.py b/sfa/util/api.py index 5c5813fb..9424b49d 100644 --- a/sfa/util/api.py +++ b/sfa/util/api.py @@ -192,9 +192,12 @@ class BaseAPI: try: result = self.call(source, method, *args) + except SfaFault, fault: + result = fault except Exception, fault: - traceback.print_exc(file = log) - result = fault + #traceback.print_exc(file = log) + result = SfaAPIError(fault) + # Return result response = self.prepare_response(result, method) @@ -206,7 +209,7 @@ class BaseAPI: """ if self.protocol == 'xmlrpclib': - if not isinstance(result, Exception): + if not isinstance(result, SfaFault): result = (result,) response = xmlrpclib.dumps(result, methodresponse = True, encoding = self.encoding, allow_none = 1) elif self.protocol == 'soap': diff --git a/sfa/util/rspec.py b/sfa/util/rspec.py index f8910342..c56d3820 100644 --- a/sfa/util/rspec.py +++ b/sfa/util/rspec.py @@ -7,6 +7,38 @@ import os import httplib from xml.dom import minidom from types import StringTypes, ListType +from lxml import etree +from StringIO import StringIO + +def merge_rspecs(rspecs): + """ + Merge merge a set of RSpecs into 1 RSpec, and return the result. + rspecs must be a valid RSpec string or list of rspec strings. + """ + if not rspecs or not isinstance(rspecs, list): + return rspecs + + rspec = None + for tmp_rspec in rspecs: + try: + tree = etree.parse(StringIO(tmp_rspec)) + except etree.XMLSyntaxError: + # consider failing silently here + message = str(agg_rspec) + ": " + str(sys.exc_info()[1]) + raise InvalidRSpec(message) + + root = tree.getroot() + if root.get("type") in ["SFA"]: + if rspec == None: + rspec = root + else: + for network in root.iterfind("./network"): + rspec.append(deepcopy(network)) + for request in root.iterfind("./request"): + rspec.append(deepcopy(request)) + return etree.tostring(rspec, xml_declaration=True, pretty_print=True) + + class RSpec: diff --git a/sfa/util/sfaticket.py b/sfa/util/sfaticket.py index 15c486e6..e4486d1e 100644 --- a/sfa/util/sfaticket.py +++ b/sfa/util/sfaticket.py @@ -79,13 +79,13 @@ class SfaTicket(Certificate): dict["gidCaller"] = self.gidCaller.save_to_string(save_parents=True) if self.gidObject: dict["gidObject"] = self.gidObject.save_to_string(save_parents=True) - str = xmlrpclib.dumps((dict,), allow_none=True) + str = "URI:" + xmlrpclib.dumps((dict,), allow_none=True) self.set_data(str) def decode(self): data = self.get_data() if data: - dict = xmlrpclib.loads(self.get_data())[0][0] + dict = xmlrpclib.loads(self.get_data()[4:])[0][0] else: dict = {} diff --git a/sfa/util/threadmanager.py b/sfa/util/threadmanager.py new file mode 100755 index 00000000..3d5dd03e --- /dev/null +++ b/sfa/util/threadmanager.py @@ -0,0 +1,71 @@ +import threading +import time +from Queue import Queue + +def ThreadedMethod(callable, queue): + """ + A function decorator that returns a running thread. The thread + runs the specified callable and stores the result in the specified + results queue + """ + def wrapper(args, kwds): + class ThreadInstance(threading.Thread): + def run(self): + try: + queue.put(callable(*args, **kwds)) + except: + # ignore errors + pass + thread = ThreadInstance() + thread.start() + return thread + return wrapper + + + +class ThreadManager: + """ + ThreadManager executes a callable in a thread and stores the result + in a thread safe queue. + """ + queue = Queue() + threads = [] + + def run (self, method, *args, **kwds): + """ + Execute a callable in a separate thread. + """ + method = ThreadedMethod(method, self.queue) + thread = method(args, kwds) + self.threads.append(thread) + + start = run + + def get_results(self): + """ + Return a list of all the results so far. Blocks until + all threads are finished. + """ + for thread in self.threads: + thread.join() + results = [] + while not self.queue.empty(): + results.append(self.queue.get()) + return results + +if __name__ == '__main__': + + def f(name, n, sleep=1): + nums = [] + for i in range(n, n+5): + print "%s: %s" % (name, i) + nums.append(i) + time.sleep(sleep) + return nums + + threads = ThreadManager() + threads.run(f, "Thread1", 10, 2) + threads.run(f, "Thread2", -10, 1) + + results = threads.get_results() + print "Results:", results -- 2.43.0