merged trunk -r 18510:18539
authorJosh Karlin <jkarlin@bbn.com>
Fri, 30 Jul 2010 18:43:21 +0000 (18:43 +0000)
committerJosh Karlin <jkarlin@bbn.com>
Fri, 30 Jul 2010 18:43:21 +0000 (18:43 +0000)
12 files changed:
TODO
keyconvert/keyconvert.py
sfa/client/sfi.py
sfa/managers/aggregate_manager_pl.py
sfa/managers/slice_manager_pl.py
sfa/plc/sfa-import-plc.py
sfa/server/interface.py
sfa/trust/credential.py
sfa/util/api.py
sfa/util/rspec.py
sfa/util/sfaticket.py
sfa/util/threadmanager.py [new file with mode: 0755]

diff --git a/TODO b/TODO
index 9c30bf5..d863a59 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,29 +1,16 @@
-- Tutorial
- * make a tutorial for sfa
-- Tag
+- Build/Tags
 * test rpm build/install
 
-- Geni Aggregate
- * are we going to deploy a geni aggregate
- * test
-
-- Trunk
-* use PLC shell instead of xmlrpc when communicating with local plc aggregate
-
-- Client
-  * update getNodes to use lxml.etree for parsing the rspec
-
 - Stop invalid users
 * a recently disabled/deleted user may still have a valid cred. Keep a list of valid/invalid users on the aggregate and check callers against this list
 
 - Component manager
   * GetGids - make this work for peer slices
-  * GetTicket - must verify_{site,slice,person,keys} on remote aggregate 
   * Redeem ticket - RedeemTicket/AdminTicket not working. Why?
   * install the slice and node gid when the slice is created (create NM plugin to execute sfa_component_setup.py ?) 
 
 - Registry
+* fix legacy credential support
 * move db tables into db with less overhead (tokyocabinet?)
 
 - GUI/Auth Service
   * service manages users key/cert,creds
   * gui requires user's cred (depends on Auth Service above)
       
--  SM call routing
-* sfi -a option should send request to sm with an extra argument to 
-  specify which am to contact instead of connecting directly to the am 
-  (am may not trust client directly)
-
 - Protogeni
-* merger josh's branch with trunk
 * agree on standard set of functon calls
 * agree on standard set of privs
 * on permission error, return priv needed to make call
index af12b1f..de904ee 100755 (executable)
@@ -4,7 +4,21 @@ import sys
 import base64
 import struct
 import binascii
-from M2Crypto import RSA, DSA
+from M2Crypto import RSA, DSA, m2
+
+
+###### Workaround for bug in m2crypto-0.18 (on Fedora 8)
+class RSA_pub_fix(RSA.RSA_pub):
+    def save_key_bio(self, bio, *args, **kw):
+        return self.save_pub_key_bio(bio)
+
+def rsa_new_pub_key((e, n)):
+    rsa = m2.rsa_new()
+    m2.rsa_set_e(rsa, e)
+    m2.rsa_set_n(rsa, n)
+    return RSA_pub_fix(rsa, 1)
+######
+#rsa_new_pub_key = RSA.new_pub_key
 
 
 def decode_key(fname):
@@ -78,7 +92,7 @@ def convert(fin, fout):
 
     if key_type == "ssh-rsa":
         e, n = ret[1:]
-        rsa = RSA.new_pub_key((e, n))
+        rsa = rsa_new_pub_key((e, n))
         rsa.save_pem(fout)
 
     elif key_type == "ssh-dss":
index 8592ec1..53c2227 100755 (executable)
@@ -9,6 +9,7 @@ import tempfile
 import traceback
 import socket
 import random
+import datetime
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
@@ -23,6 +24,8 @@ import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 import zlib
 
+AGGREGATE_PORT=12346
+CM_PORT=12346
 
 # utility methods here
 # display methods
@@ -171,12 +174,12 @@ class Sfi:
             parser.add_option("-f", "--format", dest="format", type="choice",
                              help="display format ([xml]|dns|ip)", default="xml",
                              choices=("xml", "dns", "ip"))
+                                
+        if command in ("resources", "slices", "create", "delete", "start", "stop", "get_ticket"):
             parser.add_option("-a", "--aggregate", dest="aggregate",
-                             default=None, help="aggregate hrn")
-
-        if command in ("create", "get_ticket"):
-            parser.add_option("-a", "--aggregate", dest="aggregate", default=None,
-                             help="aggregate hrn")
+                             default=None, help="aggregate host")
+            parser.add_option("-p", "--port", dest="port",
+                             default=AGGREGATE_PORT, help="aggregate port")
 
         if command in ("start", "stop", "reset", "delete", "slices"):
             parser.add_option("-c", "--component", dest="component", default=None,
@@ -203,7 +206,7 @@ class Sfi:
                             help="delegate user credential")
            parser.add_option("-s", "--slice", dest="delegate_slice",
                             help="delegate slice credential", metavar="HRN", default=None)
-
+        
         return parser
 
         
@@ -378,12 +381,25 @@ class Sfi:
                 print "Writing user gid to", file
             gid.save_to_file(file, save_parents=True)
             return gid       
+
+    def get_cached_credential(self, file):
+        """
+        Return a cached credential only if it hasn't expired.
+        """
+        if (os.path.isfile(file)):
+            credential = Credential(filename=file)
+            # make sure it isnt expired 
+            if not credential.get_lifetime or \
+               datetime.datetime.today() < credential.get_lifefime():
+                return credential
+        return None 
  
     def get_user_cred(self):
         #file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".cred")
         file = os.path.join(self.options.sfi_dir, self.user.replace(self.authority + '.', '') + ".cred")
-        if (os.path.isfile(file)):
-            user_cred = Credential(filename=file)
+
+        user_cred = self.get_cached_credential(file)
+        if user_cred:
             return user_cred
         else:
             # bootstrap user credential
@@ -410,8 +426,8 @@ class Sfi:
             sys.exit(-1)
     
         file = os.path.join(self.options.sfi_dir, get_leaf("authority") + ".cred")
-        if (os.path.isfile(file)):
-            auth_cred = Credential(filename=file)
+        auth_cred = self.get_cached_credential(file)
+        if auth_cred:
             return auth_cred
         else:
             # bootstrap authority credential from user credential
@@ -429,8 +445,8 @@ class Sfi:
     
     def get_slice_cred(self, name):
         file = os.path.join(self.options.sfi_dir, "slice_" + get_leaf(name) + ".cred")
-        if (os.path.isfile(file)):
-            slice_cred = Credential(filename=file)
+        slice_cred = self.get_cached_credential(file)
+        if slice_cred:
             return slice_cred
         else:
             # bootstrap slice credential from user credential
@@ -538,15 +554,22 @@ class Sfi:
         if not records:
             print "No such component:", opts.component
         record = records[0]
-        cm_port = "12346"
-        url = "https://%s:%s" % (record['hostname'], cm_port)
-        return xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug)
-    
-    #
+  
+        return self.get_server(record['hostname'], CM_PORT, self.key_file, \
+                               self.cert_file, self.options.debug)
+    def get_server(self, host, port, keyfile, certfile, debug):
+        """
+        Return an instnace of an xmlrpc server connection    
+        """
+        url = "http://%s:%s" % (host, port)
+        return xmlrpcprotocol.get_server(url, keyfile, certfile, debug)
+    #==========================================================================
     # Following functions implement the commands
     #
     # Registry-related commands
-    #
+    #==========================================================================
   
     def dispatch(self, command, cmd_opts, cmd_args):
         getattr(self, command)(cmd_opts, cmd_args)
@@ -761,16 +784,21 @@ class Sfi:
         return
 
  
-    #
+    # ==================================================================
     # Slice-related commands
-    #
+    # ==================================================================
     
-    # list available nodes -- use 'resources' w/ no argument instead
 
     # list instantiated slices
     def slices(self, opts, args):
+        """
+        list instantiated slices
+        """
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         server = self.slicemgr
+        if opts.aggregate:
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         # direct connection to the nodes component manager interface
         if opts.component:
             server = self.get_component_server_from_hrn(opts.component)
@@ -783,13 +811,8 @@ class Sfi:
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         server = self.slicemgr
         if opts.aggregate:
-            agg_hrn = opts.aggregate
-            aggregates = self.registry.get_aggregates(user_cred, agg_hrn)
-            if not aggregates:
-                raise Exception, "No such aggregate %s" % agg_hrn
-            aggregate = aggregates[0]
-            url = "http://%s:%s" % (aggregate['addr'], aggregate['port'])     
-            server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug)
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         if args:
             cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True)
             hrn = args[0]
@@ -816,13 +839,11 @@ class Sfi:
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
         server = self.slicemgr
+
         if opts.aggregate:
-            aggregates = self.registry.get_aggregates(user_cred, opts.aggregate)
-            if not aggregates:
-                raise Exception, "No such aggregate %s" % opts.aggregate
-            aggregate = aggregates[0]
-            url = "http://%s:%s" % (aggregate['addr'], aggregate['port'])
-            server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug)
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
+
         return server.create_slice(slice_cred, slice_hrn, rspec)
 
     # get a ticket for the specified slice
@@ -834,12 +855,8 @@ class Sfi:
         rspec = open(rspec_file).read()
         server = self.slicemgr
         if opts.aggregate:
-            aggregates = self.registry.get_aggregates(user_cred, opts.aggregate)
-            if not aggregates:
-                raise Exception, "No such aggregate %s" % opts.aggregate
-            aggregate = aggregates[0]
-            url = "http://%s:%s" % (aggregate['addr'], aggregate['port'])
-            server = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug)
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         ticket_string = server.get_ticket(slice_cred, slice_hrn, rspec)
         file = os.path.join(self.options.sfi_dir, get_leaf(slice_hrn) + ".ticket")
         print "writing ticket to ", file        
@@ -853,7 +870,7 @@ class Sfi:
         # use this to get the right slice credential 
         ticket = SfaTicket(filename=ticket_file)
         ticket.decode()
-       slice_hrn = ticket.gidObject.get_hrn()
+        slice_hrn = ticket.gidObject.get_hrn()
         #slice_hrn = ticket.attributes['slivers'][0]['hrn']
         user_cred = self.get_user_cred()
         slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
@@ -868,28 +885,28 @@ class Sfi:
         connections = {}
         for hostname in hostnames:
             try:
-                cm_port = "12346" 
-                url = "https://%(hostname)s:%(cm_port)s" % locals() 
-                print "Calling redeem_ticket at %(url)s " % locals(),
-                cm = xmlrpcprotocol.get_server(url, self.key_file, self.cert_file, self.options.debug)
-                cm.redeem_ticket(slice_cred, ticket.save_to_string(save_parents=True))
+                print "Calling redeem_ticket at %(hostname)s " % locals(),
+                server = self.get_server(hostname, CM_PORT, self.key_file, \
+                                         self.cert_file, self.options.debug)
+                server.redeem_ticket(slice_cred, ticket.save_to_string(save_parents=True))
                 print "Success"
             except socket.gaierror:
                 print "Failed:",
                 print "Componet Manager not accepting requests" 
             except Exception, e:
                 print "Failed:", e.message
-             
         return
  
     # delete named slice
     def delete(self, opts, args):
         slice_hrn = args[0]
         server = self.slicemgr
+        if opts.aggregate:
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         # direct connection to the nodes component manager interface
         if opts.component:
             server = self.get_component_server_from_hrn(opts.component)
         slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
         return server.delete_slice(slice_cred, slice_hrn)
     
@@ -897,10 +914,12 @@ class Sfi:
     def start(self, opts, args):
         slice_hrn = args[0]
         server = self.slicemgr
-        # direct connection to the nodes component manager interface
+        # direct connection to an aggregagte
+        if opts.aggregate:
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         if opts.component:
             server = self.get_component_server_from_hrn(opts.component)
         slice_cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True)
         return server.start_slice(slice_cred, slice_hrn)
     
@@ -908,10 +927,13 @@ class Sfi:
     def stop(self, opts, args):
         slice_hrn = args[0]
         server = self.slicemgr
+        # direct connection to an aggregate
+        if opts.aggregate:
+            server = self.get_server(opts.aggregate, opts.port, self.key_file, \
+                                     self.cert_file, self.options.debug)
         # direct connection to the nodes component manager interface
         if opts.component:
             server = self.get_component_server_from_hrn(opts.component)
-
         slice_cred = self.get_slice_cred(args[0]).save_to_string(save_parents=True)
         return server.stop_slice(slice_cred, slice_hrn)
     
@@ -926,7 +948,9 @@ class Sfi:
         return server.reset_slice(slice_cred, slice_hrn)
 
 
+    # =====================================================================
     # GENI AM related calls
+    # =====================================================================
 
     def GetVersion(self, opts, args):
         server = self.geni_am
index 76c576b..e6b2c3d 100644 (file)
@@ -96,13 +96,25 @@ def create_slice(api, xrn, xml, reg_objects=None):
     return True
 
 
-def get_ticket(api, xrn, rspec, origin_hrn=None):
+def get_ticket(api, xrn, rspec, origin_hrn=None, reg_objects=None):
+
     slice_hrn, type = urn_to_hrn(xrn)
-    # the the slice record
+    slices = Slices(api)
+    peer = slices.get_peer(slice_hrn)
+    sfa_peer = slices.get_sfa_peer(slice_hrn)
+    
+    # get the slice record
     registry = api.registries[api.hrn]
     credential = api.getCredential()
     records = registry.resolve(credential, xrn)
-    
+
+    # similar to create_slice, we must verify that the required records exist
+    # at this aggregate before we can issue a ticket   
+    site_id, remote_site_id = slices.verify_site(registry, credential, slice_hrn,
+                                                 peer, sfa_peer, reg_objects)
+    slice = slices.verify_slice(registry, credential, slice_hrn, site_id,
+                                remote_site_id, peer, sfa_peer, reg_objects)
+
     # make sure we get a local slice record
     record = None  
     for tmp_record in records:
index 1606eb3..72227d8 100644 (file)
@@ -9,7 +9,7 @@ from copy import deepcopy
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes
-
+from sfa.util.rspec import merge_rspecs
 from sfa.util.namespace import *
 from sfa.util.rspec import *
 from sfa.util.specdict import *
@@ -18,21 +18,18 @@ from sfa.util.record import SfaRecord
 from sfa.util.policy import Policy
 from sfa.util.prefixTree import prefixTree
 from sfa.util.sfaticket import *
+from sfa.util.threadmanager import ThreadManager
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol     
 from sfa.util.debug import log
 import sfa.plc.peers as peers
 
 def delete_slice(api, xrn, origin_hrn=None):
     credential = api.getCredential()
-    aggregates = api.aggregates
-    for aggregate in aggregates:
-        success = False
-        # request hash is optional so lets try the call without it
-        try:
-            aggregates[aggregate].delete_slice(credential, xrn, origin_hrn)
-            success = True
-        except:
-            print >> log, "%s" % (traceback.format_exc())
-            print >> log, "Error calling delete slice at aggregate %s" % aggregate
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        server = api.aggregates[aggregate] 
+        threads.run(server.delete_slice, credential, xrn, origin_hrn)
+    threads.get_results()
     return 1
 
 def create_slice(api, xrn, rspec, origin_hrn=None):
@@ -57,126 +54,97 @@ def create_slice(api, xrn, rspec, origin_hrn=None):
             message = "%s (line %s)" % (error.message, error.line)
             raise InvalidRSpec(message)
 
-    aggs = api.aggregates
-    cred = api.getCredential()                                                 
-    for agg in aggs:
-        if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]:      
-            try:
-                # Just send entire RSpec to each aggregate
-                aggs[agg].create_slice(cred, xrn, rspec, origin_hrn)
-            except:
-                print >> log, "Error creating slice %s at %s" % (hrn, agg)
-                traceback.print_exc()
-
-    return True
+    cred = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        if aggregate not in [api.auth.client_cred.get_gid_caller().get_hrn()]:
+            server = api.aggregates[aggregate]
+            # Just send entire RSpec to each aggregate
+            threads.run(server.create_slice, cred, xrn, rspec, origin_hrn)
+    threads.get_results() 
+    return 1
 
 def get_ticket(api, xrn, rspec, origin_hrn=None):
     slice_hrn, type = urn_to_hrn(xrn)
     # get the netspecs contained within the clients rspec
-    client_rspec = RSpec(xml=rspec)
-    netspecs = client_rspec.getDictsByTagName('NetSpec')
+    aggregate_rspecs = {}
+    tree= etree.parse(StringIO(rspec))
+    elements = tree.findall('./network')
+    for element in elements:
+        aggregate_hrn = element.values()[0]
+        aggregate_rspecs[aggregate_hrn] = rspec 
+
+    # get a ticket from each aggregate 
+    credential = api.getCredential()
+    threads = ThreadManager()
+    for aggregate, aggregate_rspec in aggregate_rspecs.items():
+        server = None
+        if aggregate in api.aggregates:
+            server = api.aggregates[aggregate]
+        else:
+            net_urn = hrn_to_urn(aggregate, 'authority')     
+            # we may have a peer that knows about this aggregate
+            for agg in api.aggregates:
+                agg_info = api.aggregates[agg].get_aggregates(credential, net_urn)
+                if agg_info:
+                    # send the request to this address 
+                    url = 'http://%s:%s' % (agg_info['addr'], agg_info['port'])
+                    server = xmlrpcprotocol.get_server(url, api.key_file, api.cert_file)
+                    break   
+        if server is None:
+            continue 
+        threads.run(server.get_ticket, credential, xrn, aggregate_rspec, origin_hrn)
+    results = threads.get_results()
     
-    # create an rspec for each individual rspec 
-    rspecs = {}
-    temp_rspec = RSpec()
-    for netspec in netspecs:
-        net_hrn = netspec['name']
-        resources = {'start_time': 0, 'end_time': 0 , 
-                     'network': {'NetSpec' : netspec}}
-        resourceDict = {'RSpec': resources}
-        temp_rspec.parseDict(resourceDict)
-        rspecs[net_hrn] = temp_rspec.toxml() 
+    # gather information from each ticket 
+    rspecs = []
+    initscripts = []
+    slivers = [] 
+    object_gid = None  
+    for result in results:
+        agg_ticket = SfaTicket(string=result)
+        attrs = agg_ticket.get_attributes()
+        if not object_gid:
+            object_gid = agg_ticket.get_gid_object()
+        print object_gid
+        rspecs.append(agg_ticket.get_rspec())
+        initscripts.extend(attrs.get('initscripts', [])) 
+        slivers.extend(attrs.get('slivers', [])) 
     
-    # send the rspec to the appropiate aggregate/sm
-    aggregates = api.aggregates
-    credential = api.getCredential()
-    tickets = {}
-    for net_hrn in rspecs:
-        net_urn = urn_to_hrn(net_hrn)     
-        try:
-            # if we are directly connected to the aggregate then we can just
-            # send them the request. if not, then we may be connected to an sm
-            # thats connected to the aggregate
-            if net_hrn in aggregates:
-                ticket = aggregates[net_hrn].get_ticket(credential, xrn, \
-                            rspecs[net_hrn], origin_hrn)
-                tickets[net_hrn] = ticket
-            else:
-                # lets forward this rspec to a sm that knows about the network
-                for agg in aggregates:
-                    network_found = aggregates[agg].get_aggregates(credential, net_urn)
-                    if network_found:
-                        ticket = aggregates[aggregate].get_ticket(credential, \
-                                        slice_hrn, rspecs[net_hrn], origin_hrn)
-                        tickets[aggregate] = ticket
-        except:
-            print >> log, "Error getting ticket for %(slice_hrn)s at aggregate %(net_hrn)s" % \
-                           locals()
-            
-    # create a new ticket
-    new_ticket = SfaTicket(subject = slice_hrn)
-    new_ticket.set_gid_caller(api.auth.client_gid)
-    new_ticket.set_issuer(key=api.key, subject=api.hrn)
-   
-    tmp_rspec = RSpec()
-    networks = []
-    valid_data = {
-        'timestamp': int(time.time()),
-        'initscripts': [],
-        'slivers': [] 
-    } 
-    # merge data from aggregate ticket into new ticket 
-    for agg_ticket in tickets.values():
-        # get data from this ticket
-        agg_ticket = SfaTicket(string=agg_ticket)
-        attributes = agg_ticket.get_attributes()
-       if attributes.get('initscripts', []) != None:
-            valid_data['initscripts'].extend(attributes.get('initscripts', []))
-       if attributes.get('slivers', []) != None:
-            valid_data['slivers'].extend(attributes.get('slivers', []))
-        # set the object gid
-        object_gid = agg_ticket.get_gid_object()
-        new_ticket.set_gid_object(object_gid)
-        new_ticket.set_pubkey(object_gid.get_pubkey())
+    # merge info
+    attributes = {'initscripts': initscripts,
+                 'slivers': slivers}
+    merged_rspec = merge_rspecs(rspecs) 
 
-        # build the rspec
-        tmp_rspec.parseString(agg_ticket.get_rspec())
-        networks.extend([{'NetSpec': tmp_rspec.getDictsByTagName('NetSpec')}])
-    
+    # create a new ticket
+    ticket = SfaTicket(subject = slice_hrn)
+    ticket.set_gid_caller(api.auth.client_gid)
+    ticket.set_issuer(key=api.key, subject=api.hrn)
+    ticket.set_gid_object(object_gid)
+    ticket.set_pubkey(object_gid.get_pubkey())
     #new_ticket.set_parent(api.auth.hierarchy.get_auth_ticket(auth_hrn))
-    new_ticket.set_attributes(valid_data)
-    resources = {'networks': networks, 'start_time': 0, 'duration': 0}
-    resourceDict = {'RSpec': resources}
-    tmp_rspec.parseDict(resourceDict)
-    new_ticket.set_rspec(tmp_rspec.toxml())
-    new_ticket.encode()
-    new_ticket.sign()          
-    return new_ticket.save_to_string(save_parents=True)
+    ticket.set_attributes(attributes)
+    ticket.set_rspec(merged_rspec)
+    ticket.encode()
+    ticket.sign()          
+    return ticket.save_to_string(save_parents=True)
 
 def start_slice(api, xrn):
-    hrn, type = urn_to_hrn(xrn)
-    slicename = hrn_to_pl_slicename(hrn)
-    slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id'])
-    if not slices:
-        raise RecordNotFound(hrn)
-    slice_id = slices[0]
-    attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
-    attribute_id = attreibutes[0]['slice_attribute_id']
-    api.plshell.UpdateSliceTag(api.plauth, attribute_id, "1" )
-
+    credential = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        server = api.aggregates[aggregate]
+        threads.run(server.stop_slice, credential, xrn)
+    threads.get_results()    
     return 1
  
 def stop_slice(api, xrn):
-    hrn, type = urn_to_hrn(xrn)
-    slicename = hrn_to_pl_slicename(hrn)
-    slices = api.plshell.GetSlices(api.plauth, {'name': slicename}, ['slice_id'])
-    if not slices:
-        raise RecordNotFound(hrn)
-    slice_id = slices[0]['slice_id']
-    attributes = api.plshell.GetSliceTags(api.plauth, {'slice_id': slice_id, 'name': 'enabled'}, ['slice_attribute_id'])
-    attribute_id = attributes[0]['slice_attribute_id']
-    api.plshell.UpdateSliceTag(api.plauth, attribute_id, "0")
+    credential = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        server = api.aggregates[aggregate]
+        threads.run(server.stop_slice, credential, xrn)
+    threads.get_results()    
     return 1
 
 def reset_slice(api, xrn):
@@ -193,14 +161,17 @@ def get_slices(api):
     # fetch from aggregates
     slices = []
     credential = api.getCredential()
+    threads = ThreadManager()
     for aggregate in api.aggregates:
-        try:
-            tmp_slices = api.aggregates[aggregate].get_slices(credential)
-            slices.extend(tmp_slices)
-        except:
-            print >> log, "%s" % (traceback.format_exc())
-            print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
+        server = api.aggregates[aggregate]
+        threads.run(server.get_slices, credential)
 
+    # combime results
+    results = threads.get_results()
+    slices = []
+    for result in results:
+        slices.extend(result)
+    
     # cache the result
     if api.cache:
         api.cache.add('slices', slices)
@@ -216,37 +187,35 @@ def get_rspec(api, xrn=None, origin_hrn=None):
 
     hrn, type = urn_to_hrn(xrn)
     rspec = None
-    aggs = api.aggregates
-    cred = api.getCredential()                                                 
-    for agg in aggs:
-        if agg not in [api.auth.client_cred.get_gid_caller().get_hrn()]:      
-            try:
-                # get the rspec from the aggregate
-                agg_rspec = aggs[agg].get_resources(cred, xrn, origin_hrn)
-            except:
-                # XX print out to some error log
-                print >> log, "Error getting resources at aggregate %s" % agg
-                traceback.print_exc(log)
-                print >> log, "%s" % (traceback.format_exc())
-                continue
-                
-            try:
-                tree = etree.parse(StringIO(agg_rspec))
-            except etree.XMLSyntaxError:
-                message = agg + ": " + str(sys.exc_info()[1])
-                raise InvalidRSpec(message)
+    cred = api.getCredential()
+    threads = ThreadManager()
+    for aggregate in api.aggregates:
+        if aggregate not in [api.auth.client_cred.get_gid_caller().get_hrn()]:      
+            # get the rspec from the aggregate
+            server = api.aggregates[aggregate]
+            threads.run(server.get_resources, cred, xrn, origin_hrn)
+
+    results = threads.get_results()
+    # combine the rspecs into a single rspec 
+    for agg_rspec in results:
+        try:
+            tree = etree.parse(StringIO(agg_rspec))
+        except etree.XMLSyntaxError:
+            message = str(agg_rspec) + ": " + str(sys.exc_info()[1])
+            raise InvalidRSpec(message)
 
-            root = tree.getroot()
-            if root.get("type") in ["SFA"]:
-                if rspec == None:
-                    rspec = root
-                else:
-                    for network in root.iterfind("./network"):
-                        rspec.append(deepcopy(network))
-                    for request in root.iterfind("./request"):
-                        rspec.append(deepcopy(request))
+        root = tree.getroot()
+        if root.get("type") in ["SFA"]:
+            if rspec == None:
+                rspec = root
+            else:
+                for network in root.iterfind("./network"):
+                    rspec.append(deepcopy(network))
+                for request in root.iterfind("./request"):
+                    rspec.append(deepcopy(request))
 
     rspec =  etree.tostring(rspec, xml_declaration=True, pretty_print=True)
+    # cache the result
     if api.cache and not xrn:
         api.cache.add('nodes', rspec)
  
index 39cb28c..786d57a 100755 (executable)
@@ -118,7 +118,7 @@ def main():
         sites_dict[site['login_base']] = site 
     
     # Get all plc users
-    persons = shell.GetPersons(plc_auth, {'peer_id': None}, ['person_id', 'email', 'key_ids', 'site_ids'])
+    persons = shell.GetPersons(plc_auth, {'peer_id': None, 'enabled': True}, ['person_id', 'email', 'key_ids', 'site_ids'])
     persons_dict = {}
     for person in persons:
         persons_dict[person['person_id']] = person
index 65b8d83..5ccb7ee 100644 (file)
@@ -87,8 +87,8 @@ class Interfaces(dict):
         hrns_current = [gid.get_hrn() for gid in gids_current] 
         hrns_expected = self.interfaces.keys() 
         new_hrns = set(hrns_expected).difference(hrns_current)
-        gids = self.get_peer_gids(new_hrns)
-        # update the local db records for these registries
+        gids = self.get_peer_gids(new_hrns) + gids_current
+        # make sure there is a record for every gid
         self.update_db_records(self.type, gids)
         
     def get_peer_gids(self, new_hrns):
@@ -145,20 +145,19 @@ class Interfaces(dict):
         """
         if not gids: 
             return
-        # get hrns we expect to find
-        # ignore records for local interfaces
-        ignore_interfaces = [self.api.config.SFA_INTERFACE_HRN]
-        hrns_expected = [gid.get_hrn() for gid in gids \
-                         if gid.get_hrn() not in ignore_interfaces]
+        
+        # hrns that should have a record
+        hrns_expected = [gid.get_hrn() for gid in gids]
 
         # get hrns that actually exist in the db
         table = SfaTable()
-        records = table.find({'type': type})
+        records = table.find({'type': type, 'pointer': -1})
         hrns_found = [record['hrn'] for record in records]
-       
+      
         # remove old records
         for record in records:
-            if record['hrn'] not in hrns_expected:
+            if record['hrn'] not in hrns_expected and \
+                record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
                 table.remove(record)
 
         # add new records
index 453401f..cfab006 100644 (file)
@@ -217,6 +217,10 @@ class Credential(object):
                 self.xmlsec_path = path + '/' + 'xmlsec1'
                 break
 
+    def get_subject(self):
+        if not self.gidObject:
+            self.decode()
+        return self.gidObject.get_subject()   
 
     def get_signature(self):
         if not self.signature:
@@ -781,9 +785,7 @@ class Credential(object):
     # @param dump_parents If true, also dump the parent certificates
 
     def dump(self, dump_parents=False):
-# FIXME: get_subject doesnt exist
-#        print "CREDENTIAL", self.get_subject()
-        print "CREDENTIAL"
+        print "CREDENTIAL", self.get_subject()
 
         print "      privs:", self.get_privileges().save_to_string()
 
index 5c5813f..9424b49 100644 (file)
@@ -192,9 +192,12 @@ class BaseAPI:
 
         try:
             result = self.call(source, method, *args)
+        except SfaFault, fault:
+            result = fault 
         except Exception, fault:
-            traceback.print_exc(file = log)
-            result = fault
+            #traceback.print_exc(file = log)
+            result = SfaAPIError(fault)
+
 
         # Return result
         response = self.prepare_response(result, method)
@@ -206,7 +209,7 @@ class BaseAPI:
         """   
  
         if self.protocol == 'xmlrpclib':
-            if not isinstance(result, Exception):
+            if not isinstance(result, SfaFault):
                 result = (result,)
             response = xmlrpclib.dumps(result, methodresponse = True, encoding = self.encoding, allow_none = 1)
         elif self.protocol == 'soap':
index f891034..c56d382 100644 (file)
@@ -7,6 +7,38 @@ import os
 import httplib
 from xml.dom import minidom
 from types import StringTypes, ListType
+from lxml import etree
+from StringIO import StringIO
+
+def merge_rspecs(rspecs):
+    """
+    Merge merge a set of RSpecs into 1 RSpec, and return the result.
+    rspecs must be a valid RSpec string or list of rspec strings. 
+    """
+    if not rspecs or not isinstance(rspecs, list):
+        return rspecs
+    
+    rspec = None
+    for tmp_rspec in rspecs:
+        try:
+            tree = etree.parse(StringIO(tmp_rspec))
+        except etree.XMLSyntaxError:
+            # consider failing silently here
+            message = str(agg_rspec) + ": " + str(sys.exc_info()[1])
+            raise InvalidRSpec(message)
+
+        root = tree.getroot()
+        if root.get("type") in ["SFA"]:
+            if rspec == None:
+                rspec = root
+            else:
+                for network in root.iterfind("./network"):
+                    rspec.append(deepcopy(network))
+                for request in root.iterfind("./request"):
+                    rspec.append(deepcopy(request))    
+    return etree.tostring(rspec, xml_declaration=True, pretty_print=True)
+        
+
 
 class RSpec:
 
index 15c486e..e4486d1 100644 (file)
@@ -79,13 +79,13 @@ class SfaTicket(Certificate):
             dict["gidCaller"] = self.gidCaller.save_to_string(save_parents=True)
         if self.gidObject:
             dict["gidObject"] = self.gidObject.save_to_string(save_parents=True)
-        str = xmlrpclib.dumps((dict,), allow_none=True)
+        str = "URI:" + xmlrpclib.dumps((dict,), allow_none=True)
         self.set_data(str)
 
     def decode(self):
         data = self.get_data()
         if data:
-            dict = xmlrpclib.loads(self.get_data())[0][0]
+            dict = xmlrpclib.loads(self.get_data()[4:])[0][0]
         else:
             dict = {}
 
diff --git a/sfa/util/threadmanager.py b/sfa/util/threadmanager.py
new file mode 100755 (executable)
index 0000000..3d5dd03
--- /dev/null
@@ -0,0 +1,71 @@
+import threading
+import time
+from Queue import Queue
+
+def ThreadedMethod(callable, queue):
+    """
+    A function decorator that returns a running thread. The thread
+    runs the specified callable and stores the result in the specified
+    results queue
+    """
+    def wrapper(args, kwds):
+        class ThreadInstance(threading.Thread): 
+            def run(self):
+                try:
+                    queue.put(callable(*args, **kwds))
+                except:
+                    # ignore errors
+                    pass
+        thread = ThreadInstance()
+        thread.start()
+        return thread
+    return wrapper
+
+
+class ThreadManager:
+    """
+    ThreadManager executes a callable in a thread and stores the result
+    in a thread safe queue. 
+    """
+    queue = Queue()
+    threads = []
+
+    def run (self, method, *args, **kwds):
+        """
+        Execute a callable in a separate thread.    
+        """
+        method = ThreadedMethod(method, self.queue)
+        thread = method(args, kwds)
+        self.threads.append(thread)
+
+    start = run
+
+    def get_results(self):
+        """
+        Return a list of all the results so far. Blocks until 
+        all threads are finished. 
+        """
+        for thread in self.threads:
+            thread.join()
+        results = []
+        while not self.queue.empty():
+            results.append(self.queue.get())  
+        return results
+           
+if __name__ == '__main__':
+
+    def f(name, n, sleep=1):
+        nums = []
+        for i in range(n, n+5):
+            print "%s: %s" % (name, i)
+            nums.append(i)
+            time.sleep(sleep)
+        return nums
+
+    threads = ThreadManager()
+    threads.run(f, "Thread1", 10, 2)
+    threads.run(f, "Thread2", -10, 1)
+
+    results = threads.get_results()
+    print "Results:", results