merged in from trunk -r 17776:17849
authorJosh Karlin <jkarlin@bbn.com>
Thu, 29 Apr 2010 13:58:07 +0000 (13:58 +0000)
committerJosh Karlin <jkarlin@bbn.com>
Thu, 29 Apr 2010 13:58:07 +0000 (13:58 +0000)
sfa/managers/aggregate_manager_pl.py
sfa/managers/slice_manager_pl.py
sfa/methods/get_resources.py
sfa/plc/api.py
sfa/plc/nodes.py [deleted file]
sfa/plc/slices.py
sfa/trust/certificate.py
sfa/util/api.py
sfa/util/cache.py [new file with mode: 0644]
sfa/util/server.py
sfa/util/table.py

index 8b95d25..c49d841 100644 (file)
@@ -46,12 +46,12 @@ def __get_hostnames(nodes):
     return hostnames
     
 def create_slice(api, xrn, xml):
-    hrn, type = urn_to_hrn(xrn)
-    peer = None
-
     """
     Verify HRN and initialize the slice record in PLC if necessary.
     """
+
+    hrn, type = urn_to_hrn(xrn)
+    peer = None
     slices = Slices(api)
     peer = slices.get_peer(hrn)
     sfa_peer = slices.get_sfa_peer(hrn)
@@ -66,6 +66,7 @@ def create_slice(api, xrn, xml):
 
     slice = network.get_slice(api, hrn)
     current = __get_hostnames(slice.get_nodes())
+    
     network.addRSpec(xml, api.config.SFA_AGGREGATE_RSPEC_SCHEMA)
     request = __get_hostnames(network.nodesWithSlivers())
     
@@ -168,22 +169,43 @@ def reset_slice(api, xrn):
     return 1
 
 def get_slices(api):
-    # XX just import the legacy module and excute that until
-    # we transition the code to this module
-    from sfa.plc.slices import Slices
-    slices = Slices(api)
-    slices.refresh()
-    return [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slices['hrn']]
-     
+    # look in cache first
+    if api.cache:
+        slices = api.cache.get('slices')
+        if slices:
+            return slices
+
+    # get data from db 
+    slices = api.plshell.GetSlices(api.plauth, {'peer_id': None}, ['name'])
+    slice_hrns = [slicename_to_hrn(api.hrn, slice['name']) for slice in slices]
+    slice_urns = [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slice_hrns]
+
+    # cache the result
+    if api.cache:
+        api.cache.add('slices', slice_urns) 
+
+    return slice_urns
+    
 def get_rspec(api, xrn=None, origin_hrn=None):
+    # look in cache first
+    if api.cache and not xrn:
+        rspec = api.cache.get('nodes')
+        if rspec:
+            return rspec 
+
     hrn, type = urn_to_hrn(xrn)
     network = Network(api)
     if (hrn):
         if network.get_slice(api, hrn):
             network.addSlice()
 
-    return network.toxml()
+    rspec = network.toxml()
+
+    # cache the result
+    if api.cache and not xrn:
+        api.cache.add('nodes', rspec)
+
+    return rspec
 
 """
 Returns the request context required by sfatables. At some point, this
index 3a15b68..ad1068f 100644 (file)
@@ -17,7 +17,6 @@ from sfa.util.faults import *
 from sfa.util.record import SfaRecord
 from sfa.util.policy import Policy
 from sfa.util.prefixTree import prefixTree
-from sfa.util.rspec import *
 from sfa.util.sfaticket import *
 from sfa.util.debug import log
 from sfa.util.sfalogging import logger
@@ -187,17 +186,38 @@ def reset_slice(api, xrn):
     return 1
 
 def get_slices(api):
-    # XX just import the legacy module and excute that until
-    # we transition the code to this module
-    from sfa.plc.slices import Slices
-    slices = Slices(api)
-    slices.refresh()
-    return [hrn_to_urn(slice_hrn, 'slice') for slice_hrn in slices['hrn']]
-     
+    # look in cache first
+    if api.cache:
+        slices = api.cache.get('slices')
+        if slices:
+            return slices    
+
+    # fetch from aggregates
+    slices = []
+    credential = api.getCredential()
+    for aggregate in api.aggregates:
+        try:
+            tmp_slices = api.aggregates[aggregate].get_slices(credential)
+            slices.extend(tmp_slices)
+        except:
+            print >> log, "%s" % (traceback.format_exc())
+            print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
+
+    # cache the result
+    if api.cache:
+        api.cache.add('slices', slices)
+
+    return slices
 def get_rspec(api, xrn=None, origin_hrn=None):
+    # look in cache first 
+    if api.cache and not xrn:
+        rspec =  api.cache.get('nodes')
+        if rspec:
+            return rspec
+
     hrn, type = urn_to_hrn(xrn)
     rspec = None
-
     aggs = api.aggregates
     cred = api.getCredential()                                                 
 
@@ -231,7 +251,11 @@ def get_rspec(api, xrn=None, origin_hrn=None):
                     for request in root.iterfind("./request"):
                         rspec.append(deepcopy(request))
 
-    return etree.tostring(rspec, xml_declaration=True, pretty_print=True)
+    rspec =  etree.tostring(rspec, xml_declaration=True, pretty_print=True)
+    if api.cache and not xrn:
+        api.cache.add('nodes', rspec)
+    return rspec
 
 """
 Returns the request context required by sfatables. At some point, this
index 4a08c26..1fefac2 100644 (file)
@@ -7,7 +7,6 @@ from sfa.util.method import Method
 from sfa.util.parameter import Parameter, Mixed
 from sfa.trust.auth import Auth
 from sfa.util.config import Config
-from sfa.plc.nodes import Nodes
 # RSpecManager_pl is not used. This line is a check that ensures that everything is in place for the import to work.
 import sfa.rspecs.aggregates.rspec_manager_pl
 from sfa.trust.credential import Credential
index 44c38a9..3f5a034 100644 (file)
@@ -36,11 +36,12 @@ class SfaAPI(BaseAPI):
     import sfa.methods
     methods = sfa.methods.all
     
-    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", methods='sfa.methods', \
-                 peer_cert = None, interface = None, key_file = None, cert_file = None):
+    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", 
+                 methods='sfa.methods', peer_cert = None, interface = None, 
+                key_file = None, cert_file = None, cache = None):
         BaseAPI.__init__(self, config=config, encoding=encoding, methods=methods, \
                          peer_cert=peer_cert, interface=interface, key_file=key_file, \
-                         cert_file=cert_file)
+                         cert_file=cert_file, cache=cache)
  
         self.encoding = encoding
 
diff --git a/sfa/plc/nodes.py b/sfa/plc/nodes.py
deleted file mode 100644 (file)
index 190ee63..0000000
+++ /dev/null
@@ -1,262 +0,0 @@
-### $Id$
-### $URL$
-
-import os
-import time
-import datetime
-import sys
-import traceback
-
-from sfa.util.namespace import *
-from sfa.util.rspec import *
-from sfa.util.specdict import * 
-from sfa.util.faults import *
-from sfa.util.storage import *
-from sfa.util.debug import log
-from sfa.util.rspec import *
-from sfa.util.specdict import * 
-from sfa.util.policy import Policy
-
-class Nodes(SimpleStorage):
-
-    def __init__(self, api, ttl = 1, origin_hrn=None):
-        self.api = api
-        self.ttl = ttl
-        self.threshold = None
-        path = self.api.config.SFA_DATA_DIR
-        filename = ".".join([self.api.interface, self.api.hrn, "nodes"])
-        filepath = path + os.sep + filename
-        self.nodes_file = filepath
-        SimpleStorage.__init__(self, self.nodes_file)
-        self.policy = Policy(api)
-        self.load()
-        self.origin_hrn = origin_hrn
-
-
-    def refresh(self):
-        """
-        Update the cached list of nodes
-        """
-
-        # Reload components list
-        now = datetime.datetime.now()
-        if not self.has_key('threshold') or not self.has_key('timestamp') or \
-           now > datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format))): 
-            if self.api.interface in ['aggregate']:
-                self.refresh_nodes_aggregate()
-            elif self.api.interface in ['slicemgr']:
-                self.refresh_nodes_smgr()
-
-    def refresh_nodes_aggregate(self):
-        rspec = RSpec()
-        rspec.parseString(self.get_rspec())
-        
-        # filter nodes according to policy
-        blist = self.policy['node_blacklist']
-        wlist = self.policy['node_whitelist']
-        rspec.filter('NodeSpec', 'name', blacklist=blist, whitelist=wlist)
-
-        # extract ifspecs from rspec to get ips'
-        ips = []
-        ifspecs = rspec.getDictsByTagName('IfSpec')
-        for ifspec in ifspecs:
-            if ifspec.has_key('addr') and ifspec['addr']:
-                ips.append(ifspec['addr'])
-
-        # extract nodespecs from rspec to get dns names
-        hostnames = []
-        nodespecs = rspec.getDictsByTagName('NodeSpec')
-        for nodespec in nodespecs:
-            if nodespec.has_key('name') and nodespec['name']:
-                hostnames.append(nodespec['name'])
-
-        # update timestamp and threshold
-        timestamp = datetime.datetime.now()
-        hr_timestamp = timestamp.strftime(self.api.time_format)
-        delta = datetime.timedelta(hours=self.ttl)
-        threshold = timestamp + delta
-        hr_threshold = threshold.strftime(self.api.time_format)
-
-        node_details = {}
-        node_details['rspec'] = rspec.toxml()
-        node_details['ip'] = ips
-        node_details['dns'] = hostnames
-        node_details['timestamp'] = hr_timestamp
-        node_details['threshold'] = hr_threshold
-        # save state 
-        self.update(node_details)
-        self.write()       
-    def get_rspec_smgr(self, xrn = None):
-        hrn, type = urn_to_hrn(xrn)
-        # convert and threshold to ints
-        if self.has_key('timestamp') and self['timestamp']:
-            hr_timestamp = self['timestamp']
-            timestamp = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_timestamp, self.api.time_format)))
-            hr_threshold = self['threshold']
-            threshold = datetime.datetime.fromtimestamp(time.mktime(time.strptime(hr_threshold, self.api.time_format)))
-        else:
-            timestamp = datetime.datetime.now()
-            hr_timestamp = timestamp.strftime(self.api.time_format)
-            delta = datetime.timedelta(hours=self.ttl)
-            threshold = timestamp + delta
-            hr_threshold = threshold.strftime(self.api.time_format)
-
-        start_time = int(timestamp.strftime("%s"))
-        end_time = int(threshold.strftime("%s"))
-        duration = end_time - start_time
-
-        aggregates = self.api.aggregates
-        rspecs = {}
-        networks = []
-        rspec = RSpec()
-        credential = self.api.getCredential()
-        origin_hrn = self.origin_hrn
-        for aggregate in aggregates:
-          if aggregate not in [self.api.auth.client_cred.get_gid_caller().get_hrn()]:
-            try:
-                # get the rspec from the aggregate
-                agg_rspec = aggregates[aggregate].get_resources(credential, xrn, origin_hrn)
-                # extract the netspec from each aggregates rspec
-                rspec.parseString(agg_rspec)
-                networks.extend([{'NetSpec': rspec.getDictsByTagName('NetSpec')}])
-            except:
-                # XX print out to some error log
-                print >> log, "Error getting resources at aggregate %s" % aggregate
-                traceback.print_exc(log)
-                print >> log, "%s" % (traceback.format_exc())
-        # create the rspec dict
-        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
-        resourceDict = {'RSpec': resources}
-        # convert rspec dict to xml
-        rspec.parseDict(resourceDict)
-        return rspec.toxml()
-
-    def refresh_nodes_smgr(self):
-
-        rspec = RSpec(xml=self.get_rspec_smgr())        
-        # filter according to policy
-        blist = self.policy['node_blacklist']
-        wlist = self.policy['node_whitelist']    
-        rspec.filter('NodeSpec', 'name', blacklist=blist, whitelist=wlist)
-
-        # update timestamp and threshold
-        timestamp = datetime.datetime.now()
-        hr_timestamp = timestamp.strftime(self.api.time_format)
-        delta = datetime.timedelta(hours=self.ttl)
-        threshold = timestamp + delta
-        hr_threshold = threshold.strftime(self.api.time_format)
-
-        nodedict = {'rspec': rspec.toxml(),
-                    'timestamp': hr_timestamp,
-                    'threshold':  hr_threshold}
-
-        self.update(nodedict)
-        self.write()
-
-    def get_rspec(self, xrn = None):
-
-        if self.api.interface in ['slicemgr']:
-            return self.get_rspec_smgr(xrn)
-        elif self.api.interface in ['aggregate']:
-            return self.get_rspec_aggregate(xrn)     
-
-    def get_rspec_aggregate(self, xrn = None):
-        """
-        Get resource information from PLC
-        """
-        hrn, type = urn_to_hrn(xrn)
-        slicename = None
-        # Get the required nodes
-        if not hrn:
-            nodes = self.api.plshell.GetNodes(self.api.plauth, {'peer_id': None})
-            try:  linkspecs = self.api.plshell.GetLinkSpecs() # if call is supported
-            except:  linkspecs = []
-        else:
-            slicename = hrn_to_pl_slicename(hrn)
-            slices = self.api.plshell.GetSlices(self.api.plauth, [slicename])
-            if not slices:
-                nodes = []
-            else:
-                slice = slices[0]
-                node_ids = slice['node_ids']
-                nodes = self.api.plshell.GetNodes(self.api.plauth, {'peer_id': None, 'node_id': node_ids})
-
-        # Filter out whitelisted nodes
-        public_nodes = lambda n: n.has_key('slice_ids_whitelist') and not n['slice_ids_whitelist']
-            
-        # ...only if they are not already assigned to this slice.
-        if (not slicename):        
-            nodes = filter(public_nodes, nodes)
-
-        # Get all network interfaces
-        interface_ids = []
-        for node in nodes:
-            # The field name has changed in plcapi 4.3
-            if self.api.plshell_version in ['4.2']:
-                interface_ids.extend(node['nodenetwork_ids'])
-            elif self.api.plshell_version in ['4.3']:
-                interface_ids.extend(node['interface_ids'])
-            else:
-                raise SfaAPIError, "Unsupported plcapi version ", \
-                                 self.api.plshell_version
-
-        if self.api.plshell_version in ['4.2']:
-            interfaces = self.api.plshell.GetNodeNetworks(self.api.plauth, interface_ids)
-        elif self.api.plshell_version in ['4.3']:
-            interfaces = self.api.plshell.GetInterfaces(self.api.plauth, interface_ids)
-        else:
-            raise SfaAPIError, "Unsupported plcapi version ", \
-                                self.api.plshell_version 
-        interface_dict = {}
-        for interface in interfaces:
-            if self.api.plshell_version in ['4.2']:
-                interface_dict[interface['nodenetwork_id']] = interface
-            elif self.api.plshell_version in ['4.3']:
-                interface_dict[interface['interface_id']] = interface
-            else:
-                raise SfaAPIError, "Unsupported plcapi version", \
-                                    self.api.plshell_version 
-
-        # join nodes with thier interfaces
-        for node in nodes:
-            node['interfaces'] = []
-            if self.api.plshell_version in ['4.2']:
-                for nodenetwork_id in node['nodenetwork_ids']:
-                    node['interfaces'].append(interface_dict[nodenetwork_id])
-            elif self.api.plshell_version in ['4.3']:
-                for interface_id in node['interface_ids']:
-                    node['interfaces'].append(interface_dict[interface_id])
-            else:
-                raise SfaAPIError, "Unsupported plcapi version", \
-                                    self.api.plshell_version
-
-        # convert and threshold to ints
-        if self.has_key('timestamp') and self['timestamp']:
-            timestamp = datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['timestamp'], self.api.time_format)))
-            threshold = datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format)))
-        else:
-            timestamp = datetime.datetime.now()
-            delta = datetime.timedelta(hours=self.ttl)
-            threshold = timestamp + delta
-
-        start_time = int(timestamp.strftime("%s"))
-        end_time = int(threshold.strftime("%s"))
-        duration = end_time - start_time
-
-        # create the plc dict
-        networks = [{'nodes': nodes,
-                     'name': self.api.hrn,
-                     'start_time': start_time,
-                     'duration': duration}]
-        if not hrn:
-            networks[0]['links'] = linkspecs
-        resources = {'networks': networks, 'start_time': start_time, 'duration': duration}
-
-        # convert the plc dict to an rspec dict
-        resourceDict = RSpecDict(resources)
-        # convert the rspec dict to xml
-        rspec = RSpec()
-        rspec.parseDict(resourceDict)
-        return rspec.toxml()
-        
index 9730e97..3c53504 100644 (file)
@@ -11,7 +11,6 @@ from sfa.util.namespace import *
 from sfa.util.rspec import *
 from sfa.util.specdict import *
 from sfa.util.faults import *
-from sfa.util.storage import *
 from sfa.util.record import SfaRecord
 from sfa.util.policy import Policy
 from sfa.util.prefixTree import prefixTree
@@ -19,21 +18,14 @@ from sfa.util.debug import log
 
 MAXINT =  2L**31-1
 
-class Slices(SimpleStorage):
+class Slices:
 
     rspec_to_slice_tag = {'max_rate':'net_max_rate'}
 
     def __init__(self, api, ttl = .5, origin_hrn=None):
         self.api = api
-        self.ttl = ttl
-        self.threshold = None
-        path = self.api.config.SFA_DATA_DIR
-        filename = ".".join([self.api.interface, self.api.hrn, "slices"])
-        filepath = path + os.sep + filename
-        self.slices_file = filepath
-        SimpleStorage.__init__(self, self.slices_file)
+        #filepath = path + os.sep + filename
         self.policy = Policy(self.api)    
-        self.load()
         self.origin_hrn = origin_hrn
 
     def get_slivers(self, xrn, node=None):
@@ -176,66 +168,6 @@ class Slices(SimpleStorage):
 
         return sfa_peer 
 
-    def refresh(self):
-        """
-        Update the cached list of slices
-        """
-        # Reload components list
-        now = datetime.datetime.now()
-        if not self.has_key('threshold') or not self.has_key('timestamp') or \
-           now > datetime.datetime.fromtimestamp(time.mktime(time.strptime(self['threshold'], self.api.time_format))):
-            if self.api.interface in ['aggregate']:
-                self.refresh_slices_aggregate()
-            elif self.api.interface in ['slicemgr']:
-                self.refresh_slices_smgr()
-
-    def refresh_slices_aggregate(self):
-        slices = self.api.plshell.GetSlices(self.api.plauth, {'peer_id': None}, ['name'])
-        slice_hrns = [slicename_to_hrn(self.api.hrn, slice['name']) for slice in slices]
-
-         # update timestamp and threshold
-        timestamp = datetime.datetime.now()
-        hr_timestamp = timestamp.strftime(self.api.time_format)
-        delta = datetime.timedelta(hours=self.ttl)
-        threshold = timestamp + delta
-        hr_threshold = threshold.strftime(self.api.time_format)
-        
-        slice_details = {'hrn': slice_hrns,
-                         'timestamp': hr_timestamp,
-                         'threshold': hr_threshold
-                        }
-        self.update(slice_details)
-        self.write()     
-        
-
-    def refresh_slices_smgr(self):
-        slice_hrns = []
-        credential = self.api.getCredential()
-        for aggregate in self.api.aggregates:
-            success = False
-            try:
-                slices = self.api.aggregates[aggregate].get_slices(credential)
-                slice_hrns.extend(slices)
-                success = True
-            except:
-                print >> log, "%s" % (traceback.format_exc())
-                print >> log, "Error calling slices at aggregate %(aggregate)s" % locals()
-
-        # update timestamp and threshold
-        timestamp = datetime.datetime.now()
-        hr_timestamp = timestamp.strftime(self.api.time_format)
-        delta = datetime.timedelta(hours=self.ttl)
-        threshold = timestamp + delta
-        hr_threshold = threshold.strftime(self.api.time_format)
-
-        slice_details = {'hrn': slice_hrns,
-                         'timestamp': hr_timestamp,
-                         'threshold': hr_threshold
-                        }
-        self.update(slice_details)
-        self.write()
-
-
     def verify_site(self, registry, credential, slice_hrn, peer, sfa_peer):
         authority = get_authority(slice_hrn)
         authority_urn = hrn_to_urn(authority, 'authority')
index f7dec97..ce8597e 100644 (file)
@@ -551,6 +551,10 @@ class Certificate:
             #print "TRUSTED CERT", trusted_cert.dump()
             #print "Client is signed by Trusted?", self.is_signed_by_cert(trusted_cert)
             if self.is_signed_by_cert(trusted_cert):
+                # make sure sure the trusted cert's hrn is a prefix of the
+                # signed cert's hrn
+                if not self.get_subject().startswith(trusted_cert.get_subject()):
+                    raise GidParentHrn(trusted_cert.get_subject()) 
                 #print self.get_subject(), "is signed by a root"
                 return
 
index 037ff0d..a7a07b6 100644 (file)
@@ -97,14 +97,15 @@ def import_deep(name):
 
 class BaseAPI:
 
-    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", methods='sfa.methods',
-
-                 peer_cert = None, interface = None, key_file = None, cert_file = None):
+    cache = None
+    def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", 
+                 methods='sfa.methods', peer_cert = None, interface = None, 
+                 key_file = None, cert_file = None, cache = cache):
 
         self.encoding = encoding
         
         # flat list of method names
-         
         self.methods_module = methods_module = __import__(methods, fromlist=[methods])
         self.methods = methods_module.all
 
@@ -121,6 +122,7 @@ class BaseAPI:
         self.key = Keypair(filename=self.key_file)
         self.cert_file = cert_file
         self.cert = Certificate(filename=self.cert_file)
+        self.cache = cache
         self.credential = None
         self.source = None 
         self.time_format = "%Y-%m-%d %H:%M:%S"
diff --git a/sfa/util/cache.py b/sfa/util/cache.py
new file mode 100644 (file)
index 0000000..45961fe
--- /dev/null
@@ -0,0 +1,62 @@
+#
+# This module implements general purpose caching system
+#
+from __future__ import with_statement
+import time
+import threading
+from datetime import datetime
+
+# maximum lifetime of cached data (in seconds) 
+MAX_CACHE_TTL = 60 * 60
+
+class CacheData:
+
+    data = None
+    created = None
+    expires = None
+    lock = None
+
+    def __init__(self, data, ttl = MAX_CACHE_TTL):
+        self.lock = threading.RLock()
+        self.data = data
+        self.renew(ttl)
+
+    def is_expired(self):
+        return time.time() > self.expires
+
+    def get_created_date(self):
+        return str(datetime.fromtimestamp(self.created))
+
+    def get_expires_date(self):
+        return str(datetime.fromtimestamp(self.expires))
+
+    def renew(self, ttl = MAX_CACHE_TTL):
+        self.created = time.time()
+        self.expires = self.created + ttl   
+       
+    def set_data(self, data, renew=True, ttl = MAX_CACHE_TTL):
+        with self.lock: 
+            self.data = data
+            if renew:
+                self.renew(ttl)
+    
+    def get_data(self):
+        return self.data
+
+class Cache:
+
+    cache  = {}
+    lock = threading.RLock()
+   
+    def add(self, key, value, ttl = MAX_CACHE_TTL):
+        with self.lock:
+            if self.cache.has_key(key):
+                self.cache[key].set_data(value, ttl=ttl)
+            else:
+                self.cache[key] = CacheData(value, ttl)
+           
+    def get(self, key):
+        data = self.cache.get(key)
+        if not data or data.is_expired():
+            return None 
+        return data.get_data()
index 3dedcf5..e6d3f3b 100644 (file)
@@ -22,7 +22,8 @@ from Queue import Queue
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import *
 from sfa.util.faults import *
-from sfa.plc.api import SfaAPI 
+from sfa.plc.api import SfaAPI
+from sfa.util.cache import Cache 
 from sfa.util.debug import log
 
 ##
@@ -90,7 +91,8 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
     def do_POST(self):
         """Handles the HTTPS POST request.
 
-        It was copied out from SimpleXMLRPCServer.py and modified to shutdown the socket cleanly.
+        It was copied out from SimpleXMLRPCServer.py and modified to shutdown 
+        the socket cleanly.
         """
         try:
             peer_cert = Certificate()
@@ -98,7 +100,8 @@ class SecureXMLRpcRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
             self.api = SfaAPI(peer_cert = peer_cert, 
                               interface = self.server.interface, 
                               key_file = self.server.key_file, 
-                              cert_file = self.server.cert_file)
+                              cert_file = self.server.cert_file,
+                              cache = self.cache)
             # get arguments
             request = self.rfile.read(int(self.headers["content-length"]))
             remote_addr = (remote_ip, remote_port) = self.connection.getpeername()
@@ -137,7 +140,8 @@ class SecureXMLRPCServer(BaseHTTPServer.HTTPServer,SimpleXMLRPCServer.SimpleXMLR
         self.key_file = key_file
         self.cert_file = cert_file
         self.method_map = {}
-
+        # add cache to the request handler
+        HandlerClass.cache = Cache()
         #for compatibility with python 2.4 (centos53)
         if sys.version_info < (2, 5):
             SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self)
index f5af7b9..892b72c 100644 (file)
@@ -98,7 +98,7 @@ class SfaTable(list):
             sql = " DELETE FROM %s WHERE authority = %s" % \
                     (self.tablename, record['hrn'])
             self.db.do(sql)
-            self.db.commit() 
+        self.db.commit() 
 
     def insert(self, record):
         db_fields = self.db_fields(record)