Merge branch 'master' into eucalyptus-devel
authorMarco Yuen <marcoy@gmail.com>
Wed, 6 Jul 2011 18:20:59 +0000 (14:20 -0400)
committerMarco Yuen <marcoy@gmail.com>
Wed, 6 Jul 2011 18:20:59 +0000 (14:20 -0400)
33 files changed:
TODO
setup.py
sfa.spec
sfa/client/Makefile
sfa/client/sfi.py
sfa/init.d/sfa
sfa/managers/aggregate_manager_eucalyptus.py
sfa/managers/aggregate_manager_pl.py
sfa/managers/eucalyptus/eucalyptus.rnc
sfa/managers/eucalyptus/eucalyptus.rng
sfa/managers/slice_manager_pl.py
sfa/plc/sfaImport.py
sfa/rspecs/elements/PGv1Network.py [new file with mode: 0644]
sfa/rspecs/elements/SFAv1Network.py [new file with mode: 0755]
sfa/rspecs/elements/SFAv1Node.py [new file with mode: 0755]
sfa/rspecs/elements/SFAv1Sliver.py [new file with mode: 0755]
sfa/rspecs/elements/__init__.py [new file with mode: 0644]
sfa/rspecs/elements/element.py [new file with mode: 0644]
sfa/rspecs/elements/link.py [new file with mode: 0644]
sfa/rspecs/elements/network.py [new file with mode: 0644]
sfa/rspecs/elements/node.py [new file with mode: 0644]
sfa/rspecs/elements/sliver.py [new file with mode: 0644]
sfa/rspecs/pg_rspec.py
sfa/rspecs/sfa_rspec.py
sfa/server/interface.py
sfa/server/registry.py
sfa/server/sfa-server.py
sfa/trust/credential.py
sfa/trust/rights.py
sfa/util/api.py
sfa/util/plxrn.py
sfa/util/sfalogging.py
sfa/util/sfatime.py

diff --git a/TODO b/TODO
index ab99b42..6af99bd 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,5 +1,7 @@
 RSpecs
 - CreateSlivers should update SliverTags/attributes 
+- ProtoGENI rspec integration testing
+- initscripts in the rspec
 
 Registry
 - Verify that sub authority certificates still work
index 2800555..e91c5dc 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@ package_dirs = [
     'sfa/util', 
     'sfa/managers',
     'sfa/rspecs',
+    'sfa/rspecs/elements',
     'sfatables',
     'sfatables/commands',
     'sfatables/processors',
index 92d56a3..521bcf2 100644 (file)
--- a/sfa.spec
+++ b/sfa.spec
@@ -193,8 +193,7 @@ if [ "$1" = 0 ] ; then
 fi
 
 %postun cm
-[ "$1" -ge "1" ] && service sfa-cm restart
-
+[ "$1" -ge "1" ] && service sfa-cm restart || :
 
 %changelog
 * Tue Jun 21 2011 Thierry Parmentelat <thierry.parmentelat@sophia.inria.fr> - sfa-1.0-26
index 8f334b5..820106e 100644 (file)
@@ -38,7 +38,8 @@ BUNDLES-LR += http://www.planet-lab.jp:12347/@auto-plj-sa
 BUNDLES-LR += http://www.emanicslab.org:12345/@auto-elc-reg 
 BUNDLES-LR += http://www.emanicslab.org:12347/@auto-elc-sa
 
-EXTENSIONS := png svg
+#EXTENSIONS := png svg
+EXTENSIONS := png
 
 ####################
 ALL += $(foreach bundle,$(BUNDLES),$(word 2,$(subst @, ,$(bundle))))
@@ -49,7 +50,7 @@ all: $(ALL)
 ####################
 define bundle_scan_target
 $(word 2,$(subst @, ,$(1))):
-       ./sfascan.py $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1))).out
+       ./sfascan.py $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1))).$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1))).out
 .PHONY: $(word 2,$(subst @, ,$(1)))
 endef
 
@@ -59,7 +60,7 @@ $(foreach bundle,$(BUNDLES),$(eval $(call bundle_scan_target,$(bundle))))
 #################### same but left-to-right
 define bundle_scan_target_lr
 $(word 2,$(subst @, ,$(1)))-lr:
-       ./sfascan.py -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& $(word 2,$(subst @, ,$(1)))-lr.out
+       ./sfascan.py -l $(foreach extension,$(EXTENSIONS),-o $(word 2,$(subst @, ,$(1)))-lr.$(extension)) $(word 1,$(subst @, ,$(1))) >& .$(word 2,$(subst @, ,$(1)))-lr.out
 .PHONY: $(word 2,$(subst @, ,$(1)))-lr
 endef
 
index 750d873..094c6be 100755 (executable)
@@ -15,7 +15,7 @@ from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
 from optparse import OptionParser
-from sfa.util.sfalogging import info_logger
+from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.trust.credential import Credential
@@ -26,6 +26,8 @@ import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 from sfa.util.version import version_core
 from sfa.util.cache import Cache
+from sfa.rspecs.rspec_version import RSpecVersion
+from sfa.rspecs.pg_rspec import pg_rspec_request_version
 
 AGGREGATE_PORT=12346
 CM_PORT=12346
@@ -141,7 +143,8 @@ class Sfi:
         self.user = None
         self.authority = None
         self.hashrequest = False
-        self.logger = info_logger
+        self.logger = sfi_logger
+        self.logger.enable_console()
    
     def create_cmd_parser(self, command, additional_cmdargs=None):
         cmdargs = {"list": "authority",
@@ -370,6 +373,8 @@ class Sfi:
             version = server.GetVersion()
             # cache version for 24 hours
             cache.add(cache_key, version, ttl= 60*60*24)
+            self.logger.info("Updating cache file %s" % cache_file)
+            cache.save_to_file(cache_file)
 
 
         return version   
@@ -891,7 +896,14 @@ class Sfi:
             delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
             creds.append(delegated_cred)
         if opts.rspec_version:
-            call_options['rspec_version'] = opts.rspec_version 
+            server_version = self.get_cached_server_version(server)
+            if 'sfa' in server_version:
+                # just request the version the client wants 
+                call_options['rspec_version'] = dict(RSpecVersion(opts.rspec_version)) 
+            else:
+                # this must be a protogeni aggregate. We should request a v2 ad rspec
+                # regardless of what the client user requested 
+                call_options['rspec_version'] = dict(pg_rspec_request_version)     
         #panos add info options
         if opts.info:
             call_options['info'] = opts.info 
@@ -929,7 +941,7 @@ class Sfi:
         #  }]
         users = []
         server = self.get_server_from_opts(opts)
-        version = server.GetVersion()
+        version = self.get_cached_server_version(server)
         if 'sfa' not in version:
             # need to pass along user keys if this request is going to a ProtoGENI aggregate 
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
index b039c24..e2fdb10 100755 (executable)
@@ -61,6 +61,9 @@ start() {
     
     reload
 
+    # install peer certs
+    action $"SFA installing peer certs" daemon /usr/bin/sfa-server.py -t -d $OPTIONS 
+
     if [ "$SFA_REGISTRY_ENABLED" -eq 1 ]; then
         action $"SFA Registry" daemon /usr/bin/sfa-server.py -r -d $OPTIONS
     fi
index 26e5742..e9969f9 100644 (file)
@@ -374,7 +374,7 @@ class EucaRSpecBuilder(object):
         xml = self.eucaRSpec
         cloud = self.cloudInfo
         with xml.RSpec(type='eucalyptus'):
-            with xml.cloud(id=cloud['name']):
+            with xml.network(id=cloud['name']):
                 with xml.ipv4:
                     xml << cloud['ip']
                 #self.__keyPairsXML(cloud['keypairs'])
index a4e5651..0dd236c 100644 (file)
@@ -79,6 +79,7 @@ def __get_registry_objects(slice_xrn, creds, users):
 
         slice = {}
         
+        # get_expiration always returns a normalized datetime - no need to utcparse
         extime = Credential(string=creds[0]).get_expiration()
         # If the expiration time is > 60 days from now, set the expiration time to 60 days from now
         if extime > datetime.datetime.utcnow() + datetime.timedelta(days=60):
@@ -118,7 +119,7 @@ def SliverStatus(api, slice_xrn, creds, call_id):
     
     slices = api.plshell.GetSlices(api.plauth, [slicename], ['node_ids','person_ids','name','expires'])
     if len(slices) == 0:        
-        raise Exception("Slice %s not found (used %s as slicename internally)" % slice_xrn, slicename)
+        raise Exception("Slice %s not found (used %s as slicename internally)" % (slice_xrn, slicename))
     slice = slices[0]
     
     # report about the local nodes only
@@ -196,17 +197,46 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
     added_nodes = list(set(requested_slivers).difference(current_slivers))
 
     # get sliver attributes
-    slice_attributes = rspec.get_slice_attributes()
-
+    requested_slice_attributes = rspec.get_slice_attributes()
+    removed_slice_attributes = []
+    existing_slice_attributes = []    
+    for slice_tag in api.plshell.GetSliceTags(api.plauth, {'slice_id': slice['slice_id']}):
+        attribute_found=False
+        for requested_attribute in requested_slice_attributes:
+            if requested_attribute['name'] == slice_tag['tagname'] and \
+               requested_attribute['value'] == slice_tag['value']:
+                attribute_found=True
+
+        if not attribute_found: 
+            removed_slice_attributes.append(slice_tag)
+        else:
+            existing_slice_attributes.append(slice_tag)  
+         
+    #api.logger.debug("requested slice attributes: %s" % str(requested_slice_attributes))
+    #api.logger.debug("removed slice attributes: %s" % str(removed_slice_attributes))
+    #api.logger.debug("existing slice attributes: %s" % str(existing_slice_attributes))
     try:
         if peer:
             api.plshell.UnBindObjectFromPeer(api.plauth, 'slice', slice['slice_id'], peer)
 
         api.plshell.AddSliceToNodes(api.plauth, slice['name'], added_nodes) 
         api.plshell.DeleteSliceFromNodes(api.plauth, slice['name'], deleted_nodes)
-        for attribute in sliver_atrributes:
-            name, value, node_id = attribute['tagname'], attribute['value'], attribute.get('node_id', None)
-            api.plshell.AddSliceTag(api.plauth, slice['name'], name, value, node_id)
+        # remove stale attributes
+        for attribute in removed_slice_attributes:
+            try:
+                api.plshell.DeleteSliceTag(api.plauth, attribute['slice_tag_id'])
+            except Exception, e:
+                api.logger.warn('Failed to remove sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
+
+        # add requested_attributes
+        for attribute in requested_slice_attributes:
+            try:
+                name, value, node_id = attribute['name'], attribute['value'], attribute.get('node_id', None)
+                api.plshell.AddSliceTag(api.plauth, slice['name'], name, value, node_id)
+            except Exception, e:
+                api.logger.warn('Failed to add sliver attribute. name: %s, value: %s, node_id: %s\nCause:%s'\
+                                % (name, value,  node_id, str(e)))
 
     finally:
         if peer:
index ba9758c..bc61ce3 100644 (file)
@@ -1,9 +1,9 @@
 start = RSpec
 RSpec = element RSpec {
     attribute type { xsd:NMTOKEN },
-    cloud
+    network
 }
-cloud = element cloud {
+network = element network {
     attribute id { xsd:NMTOKEN },
     user_info?,
     ipv4,
index 51d23c6..d7a85b4 100644 (file)
@@ -8,11 +8,11 @@
       <attribute name="type">
         <data type="NMTOKEN"/>
       </attribute>
-      <ref name="cloud"/>
+      <ref name="network"/>
     </element>
   </define>
-  <define name="cloud">
-    <element name="cloud">
+  <define name="network">
+    <element name="network">
       <attribute name="id">
         <data type="NMTOKEN"/>
       </attribute>
index 5c64df0..2cbc823 100644 (file)
@@ -37,12 +37,7 @@ def _call_id_supported(api, server):
     """
     Returns true if server support the optional call_id arg, false otherwise.
     """
-    cache_key = server.url + "-version"
-    server_version = api.cache.get(cache_key)
-    if not server_version:
-        server_version = server.GetVersion()
-        # cache version for 24 hours
-        api.cache.add(cache_key, server_version, ttl= 60*60*24)
+    server_version = api.get_cached_server_version(server)
 
     if 'sfa' in server_version:
         code_tag = server_version['code_tag']
@@ -93,7 +88,10 @@ def ListResources(api, creds, options, call_id):
         args = [credential, my_opts]
         if _call_id_supported(api, server):
             args.append(call_id)
-        return server.ListResources(*args)
+        try:
+            return server.ListResources(*args)
+        except Exception, e:
+            api.logger.warn("ListResources failed at %s: %s" %(server.url, str(e)))
 
     if Callids().already_handled(call_id): return ""
 
@@ -102,6 +100,8 @@ def ListResources(api, creds, options, call_id):
     (hrn, type) = urn_to_hrn(xrn)
     my_opts = copy(options)
     my_opts['geni_compressed'] = False
+    if 'rspec_version' in my_opts:
+        del my_opts['rspec_version']
 
     # get the rspec's return format from options
     rspec_version = RSpecVersion(options.get('rspec_version'))
@@ -121,6 +121,7 @@ def ListResources(api, creds, options, call_id):
     credential = api.getDelegatedCredential(creds)
     if not credential:
         credential = api.getCredential()
+    credentials = [credential]
     threads = ThreadManager()
     for aggregate in api.aggregates:
         # prevent infinite loop. Dont send request back to caller
@@ -129,8 +130,8 @@ def ListResources(api, creds, options, call_id):
             continue
         # get the rspec from the aggregate
         server = api.aggregates[aggregate]
-        #threads.run(server.ListResources, credential, my_opts, call_id)
-        threads.run(_ListResources, server, credential, my_opts, call_id)
+        #threads.run(server.ListResources, credentials, my_opts, call_id)
+        threads.run(_ListResources, server, credentials, my_opts, call_id)
 
     results = threads.get_results()
     rspec_version = RSpecVersion(my_opts.get('rspec_version'))
@@ -154,17 +155,23 @@ def ListResources(api, creds, options, call_id):
 def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
 
     def _CreateSliver(server, xrn, credential, rspec, users, call_id):
+        try:
             # Need to call GetVersion at an aggregate to determine the supported 
             # rspec type/format beofre calling CreateSliver at an Aggregate. 
-            server_version = _get_server_version(api, server)    
+            server_version = api.get_cached_server_version(server)    
             if 'sfa' not in aggregate_version and 'geni_api' in aggregate_version:
                 # sfa aggregtes support both sfa and pg rspecs, no need to convert
-                # if aggregate supports sfa rspecs. othewise convert to pg rspec
+                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                 rspec = RSpecConverter.to_pg_rspec(rspec)
             args = [xrn, credential, rspec, users]
             if _call_id_supported(api, server):
                 args.append(call_id)
-            return server.CreateSliver(*args)
+            try:
+                return server.CreateSliver(*args)
+            except Exception, e:
+                api.logger.warn("CreateSliver failed at %s: %s" %(server.url, str(e)))
+        except: 
+            logger.log_exc('Something wrong in _CreateSliver')
 
     if Callids().already_handled(call_id): return ""
     # Validate the RSpec against PlanetLab's schema --disabled for now
@@ -350,23 +357,6 @@ def ListSlices(api, creds, call_id):
     return slices
 
 
-    if rspec_version['type'] == pg_rspec_ad_version['type']:
-        rspec = PGRSpec()
-    else:
-        rspec = SfaRSpec()
-    for result in results:
-        try:
-            rspec.merge(result)
-        except:
-            api.logger.info("SM.ListResources: Failed to merge aggregate rspec")
-
-    # cache the result
-    if caching and api.cache and not xrn:
-        api.cache.add(version_string, rspec.toxml())
-
-    return rspec.toxml()
-
-
 def get_ticket(api, xrn, creds, rspec, users):
     slice_hrn, type = urn_to_hrn(xrn)
     # get the netspecs contained within the clients rspec
index ee34833..e9d0940 100644 (file)
@@ -135,12 +135,16 @@ class sfaImport:
             # to planetlab
             keys = self.shell.GetKeys(self.plc_auth, key_ids)
             key = keys[0]['key']
-            pkey = convert_public_key(key)
+            pkey = None
+            try:
+                pkey = convert_public_key(key)
+            except:
+                self.logger.warn('unable to convert public key for %s' % hrn) 
             if not pkey:
                 pkey = Keypair(create=True)
         else:
             # the user has no keys
-            self.logger.warning("Import: person %s does not have a PL public key"%hrn)
+            self.logger.warn("Import: person %s does not have a PL public key"%hrn)
             # if a key is unavailable, then we still need to put something in the
             # user's GID. So make one up.
             pkey = Keypair(create=True)
diff --git a/sfa/rspecs/elements/PGv1Network.py b/sfa/rspecs/elements/PGv1Network.py
new file mode 100644 (file)
index 0000000..c4f3c87
--- /dev/null
@@ -0,0 +1,7 @@
+from sfa.rspecs.elements.networks import Network
+
+class PGv2Network(Network):
+
+    def get_networks_names(self):
+        networks = self.xml.xpath('//rspecv2:node[@component_manager_id]/@component_manager_id', namespaces=self.namespaces)
+        return list(set(networks))
diff --git a/sfa/rspecs/elements/SFAv1Network.py b/sfa/rspecs/elements/SFAv1Network.py
new file mode 100755 (executable)
index 0000000..4bfa26d
--- /dev/null
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+from sfa.rspecs.elements.network import Network
+
+class SFAv1Network(Network):
+
+    def get_network_elements(self):
+        return self.root_node.xpath('//network')        
+
+    def get_networks(self):
+        network_elems = self.get_network_elements()
+        networks = [self.get_attributes(network_elem) \
+                    for network_elem in network_elems]
+        return networks
+
+    def add_networks(self, networks):
+        if not isinstance(networks, list):
+            networks = [networks]
+        return self.add_element('network', {'id': network}, self.root_node)
+
+
+
+if __name__ == '__main__':
+    import sys
+    from lxml import etree
+    args = sys.argv[1:]
+    filename = args[0]
+
+    root_node = etree.parse(filename)
+    network = SFAv1Network(root_node)
+    print network.get_networks()
+
+    
diff --git a/sfa/rspecs/elements/SFAv1Node.py b/sfa/rspecs/elements/SFAv1Node.py
new file mode 100755 (executable)
index 0000000..4daad4f
--- /dev/null
@@ -0,0 +1,90 @@
+#!/usr/bin/python
+
+from sfa.rspecs.elements.node import Node
+
+class SFAv1Node(Node):
+
+    def get_node_elements(self, network=None, hostnames=None):
+        if network:
+            query = '//network[@name="%s"]//node' % network
+        else:
+            query = '//node'
+
+        if isinstance(hostnames, str):
+            query = query + '/hostname[text() = "%s"]' % hostnames
+        elif isinstance(hostnames, list):
+            query = query + '/hostname[contains( "%s" , text())]' \
+                    %(" ".join(hostnames))
+            
+        return self.xpath(query)
+
+    def get_nodes(self, network=None, hostnames=None):
+        node_elems = self.get_node_elements(network, hostnames)
+        nodes = [self.get_attributes(node_elem, recursive=True) \
+                 for node_elem in node_elems]
+        return nodes
+
+    def add_nodes(self, nodes, network=None, no_dupes=False):
+        if not isinstance(nodes, list):
+            nodes = [nodes]
+        for node in nodes:
+            if no_dupes and \
+              self.get_node_element(node['hostname']):
+                # node already exists
+                continue
+
+            network_tag = self.root_node
+            if 'network' in node:
+                network = node['network']
+                network_tags = self.root_node.xpath('//network[@name="%s"]' % network)
+                if not network_tags:
+                    #network_tag = etree.SubElement(self.root_node, 'network', name=network)
+                    network_tag = self.add_element('network', {'name': network}, self.root_node)
+                else:
+                    network_tag = network_tags[0]
+
+            #node_tag = etree.SubElement(network_tag, 'node')
+            node_tag = self.add_element('node', parent=network_tag)
+            if 'network' in node:
+                node_tag.set('component_manager_id', network)
+            if 'urn' in node:
+                node_tag.set('component_id', node['urn'])
+            if 'site_urn' in node:
+                node_tag.set('site_id', node['site_urn'])
+            if 'node_id' in node:
+                node_tag.set('node_id', 'n'+str(node['node_id']))
+            if 'hostname' in node:
+                #hostname_tag = etree.SubElement(node_tag, 'hostname').text = node['hostname']
+                hostname_tag = self.add_element('hostname', parent=node_tag)
+                hostname_tag.text = node['hostname']
+            if 'interfaces' in node:
+                for interface in node['interfaces']:
+                    if 'bwlimit' in interface and interface['bwlimit']:
+                        #bwlimit = etree.SubElement(node_tag, 'bw_limit', units='kbps').text = str(interface['bwlimit']/1000)
+                        bwlimit_tag = self.add_element('bw_limit', {'units': 'kbps'}, parent=node_tag)
+                        bwlimit_tag.text = str(interface['bwlimit']/1000)
+            if 'tags' in node:
+                for tag in node['tags']:
+                   # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category
+                   if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
+                        #tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])
+                        tag_element = self.add_element(tag['tagname'], parent=node_tag)
+                        tag_element.text = tag['value']
+
+            if 'site' in node:
+                longitude = str(node['site']['longitude'])
+                latitude = str(node['site']['latitude'])
+                #location = etree.SubElement(node_tag, 'location', country='unknown', \
+                #                            longitude=longitude, latitude=latitude) 
+                location_attrs = {'country': 'unknown', 'longitude': longitude, 'latitude': latitude} 
+                self.add_element('location', location_attrs, node_tag) 
+
+if __name__ == '__main__':
+    import sys
+    from lxml import etree
+    args = sys.argv[1:]
+    filename = args[0]
+
+    root_node = etree.parse(filename)
+    network = SFAv1Node(root_node)
+    print network.get_nodes()
diff --git a/sfa/rspecs/elements/SFAv1Sliver.py b/sfa/rspecs/elements/SFAv1Sliver.py
new file mode 100755 (executable)
index 0000000..eea5032
--- /dev/null
@@ -0,0 +1,97 @@
+#!/usr/bin/python
+
+from sfa.rspecs.elements.sliver import Sliver
+from sfa.rspecs.elements.SFAv1Node import SFVv1Node
+
+class SFAv1Sliver(Sliver):
+
+    def get_sliver_elements(self, network=None):
+        if network:
+            slivers = self.root_node.xpath('//network[@name="%s"]//node/sliver' % network)
+        else:
+            slivers = self.root_node.xpath('//node/sliver')
+        return slivers
+
+    def get_slivers(self, network=None):
+        sliver_elems = self.get_sliver_elements(network)
+        slivers = [self.get_attributes(sliver_elem, recursive=True) \
+                 for sliver_elem in sliver_elems]
+        return slivers
+
+    def add_slivers(self, slivers, network=None):
+        if not isinstance(slivers, list):
+            slivers = [slivers]
+        nodes = SfaV1Node(self.root_node) 
+        for sliver in slivers:
+            if isinstance(sliver, basestring):
+                sliver = {'hostname': sliver}
+            if 'hostname' in sliver:
+                node_elem = nodes.get_node_elements(hostnames=sliver['hostname'])
+                if node_elem:
+                    node_elem[0]
+                sliver_elem = self.add_element('sliver', parent=node_elem)
+                if 'tags' in sliver:
+                    for tag in sliver['tags']:
+                        self.add_element(tag['tagname'], parent=sliver_elem, text=tag['value'])
+
+    def remove_slivers(self, slivers, network=node):
+        nodes = SfaV1Node(self.root_node) 
+        for sliver in slivers:
+            if isinstance(sliver, str):
+                hostname = sliver
+            else:
+                hostname = sliver['hostname']
+            node_elem = nodes.get_node_elements(network=network, hostnames=hostname)
+            sliver_elem = node_elem.find('sliver')
+            if sliver_elem != None:
+                node_elem.remove(sliver_elem)    
+                         
+                
+    def get_sliver_defaults(self, network=None):
+        if network:
+            defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)     
+        else:
+            defaults = self.xml.xpath("//network/sliver_defaults" % network)
+        return self.attributes_list(defaults)
+
+    def add_default_sliver_attribute(self, name, value, network=None):
+        if network:
+            defaults = self.xpath("//network[@name='%s']/sliver_defaults" % network)
+        else:
+            defaults = self.xpath("//sliver_defaults" % network)
+        if not defaults:
+            network_tag = self.xpath("//network[@name='%s']" % network)
+            if isinstance(network_tag, list):
+                network_tag = network_tag[0]
+            defaults = self.add_element('sliver_defaults', attrs={}, parent=network_tag)
+        elif isinstance(defaults, list):
+            defaults = defaults[0]
+        self.add_attribute(defaults, name, value)
+
+    def add_sliver_attribute(self, hostname, name, value, network=None):
+        node = self.get_node_elements(network, hostname)
+        sliver = node.find("sliver")
+        self.add_attribute(sliver, name, value)
+    
+    def remove_default_sliver_attribute(self, name, value, network=None):
+        if network:
+            defaults = self.xpath("//network[@name='%s']/sliver_defaults" % network)
+        else:
+            defaults = self.xpath("//sliver_defaults" % network)
+        self.remove_attribute(defaults, name, value)
+    
+    def remove_sliver_attribute(self, hostname, name, value, network=None):
+        node = self.get_node_elements(network, hostname)
+        sliver = node.find("sliver")
+        self.remove_attribute(sliver, name, value)    
+            
+    
+if __name__ == '__main__':
+    import sys
+    from lxml import etree
+    args = sys.argv[1:]
+    filename = args[0]
+
+    root_node = etree.parse(filename)
+    network = SFAv1Node(root_node)
+    print network.get_nodes()
diff --git a/sfa/rspecs/elements/__init__.py b/sfa/rspecs/elements/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/elements/element.py b/sfa/rspecs/elements/element.py
new file mode 100644 (file)
index 0000000..8217c11
--- /dev/null
@@ -0,0 +1,85 @@
+from lxml import etree
+
+class Element:
+    def __init__(self, root_node, namespaces = None):
+        self.root_node = root_node
+        self.namespaces = namespaces
+
+    def xpath(self, xpath):
+        return this.root_node.xpath(xpath, namespaces=self.namespaces) 
+
+    def add_element(self, name, attrs={}, parent=None, text=""):
+        """
+        Generic wrapper around etree.SubElement(). Adds an element to
+        specified parent node. Adds element to root node is parent is
+        not specified.
+        """
+        if parent == None:
+            parent = self.root_node
+        element = etree.SubElement(parent, name)
+        if text:
+            element.text = text
+        if isinstance(attrs, dict):
+            for attr in attrs:
+                element.set(attr, attrs[attr])
+        return element
+
+    def remove_element(self, element_name, root_node = None):
+        """
+        Removes all occurences of an element from the tree. Start at
+        specified root_node if specified, otherwise start at tree's root.
+        """
+        if not root_node:
+            root_node = self.root_node
+
+        if not element_name.startswith('//'):
+            element_name = '//' + element_name
+
+        elements = root_node.xpath('%s ' % element_name, namespaces=self.namespaces)
+        for element in elements:
+            parent = element.getparent()
+            parent.remove(element)
+
+    
+    def add_attribute(self, elem, name, value):
+        """
+        Add attribute to specified etree element
+        """
+        opt = etree.SubElement(elem, name)
+        opt.text = value
+
+    def remove_attribute(self, elem, name, value):
+        """
+        Removes an attribute from an element
+        """
+        if not elem == None:
+            opts = elem.iterfind(name)
+            if opts is not None:
+                for opt in opts:
+                    if opt.text == value:
+                        elem.remove(opt)
+
+    def get_attributes(self, elem=None, depth=None):
+        if elem == None:
+            elem = self.root_node
+        attrs = dict(elem.attrib)
+        attrs['text'] = str(elem.text).strip()
+        if depth is None or isinstance(depth, int) and depth > 0: 
+            for child_elem in list(elem):
+                key = str(child_elem.tag)
+                if key not in attrs:
+                    attrs[key] = [self.get_attributes(child_elem, recursive)]
+                else:
+                    attrs[key].append(self.get_attributes(child_elem, recursive))
+        return attrs
+    
+    def attributes_list(self, elem):
+        # convert a list of attribute tags into list of tuples
+        # (tagnme, text_value)
+        opts = []
+        if not elem == None:
+            for e in elem:
+                opts.append((e.tag, e.text))
+        return opts
+
+    
diff --git a/sfa/rspecs/elements/link.py b/sfa/rspecs/elements/link.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/sfa/rspecs/elements/network.py b/sfa/rspecs/elements/network.py
new file mode 100644 (file)
index 0000000..6a358a4
--- /dev/null
@@ -0,0 +1,11 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Network(Element):
+
+    def get_networks(*args, **kwds):
+        logger.info("sfa.rspecs.networks: get_networks not implemented")
+
+    def add_networks(*args, **kwds):
+        logger.info("sfa.rspecs.networks: add_network not implemented")
+        
diff --git a/sfa/rspecs/elements/node.py b/sfa/rspecs/elements/node.py
new file mode 100644 (file)
index 0000000..db6e119
--- /dev/null
@@ -0,0 +1,13 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.faults import SfaNotImplemented 
+from sfa.util.sfalogging import logger
+class Node(Element):
+
+    def get_nodes(*args):
+        logger.info("sfa.rspecs.nodes: get_nodes not implemented") 
+    
+    def add_nodes(*args):
+        logger.info("sfa.rspecs.nodes: add_nodes not implemented") 
+                
+      
diff --git a/sfa/rspecs/elements/sliver.py b/sfa/rspecs/elements/sliver.py
new file mode 100644 (file)
index 0000000..67105dc
--- /dev/null
@@ -0,0 +1,29 @@
+from sfa.rspecs.elements.element import Element
+from sfa.util.sfalogging import logger
+
+class Slivers(Element):
+
+    def get_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: get_slivers not implemented")
+
+    def add_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_slivers not implemented")
+
+    def remove_slivers(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: remove_slivers not implemented")
+
+    def get_sliver_defaults(*args, **kwds):    
+        logger.debug("sfa.rspecs.slivers: get_sliver_defaults not implemented")
+    
+    def add_default_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_default_sliver_attributes not implemented")
+
+    def add_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: add_sliver_attribute not implemented")
+
+    def remove_default_sliver_attribute(*args, **kwds):
+        logger.debug("sfa.rspecs.slivers: remove_default_sliver_attributes not implemented")
+
+    def remove_sliver_attribute(*args, **kwds):
+        logger.debuv("sfa.rspecs.slivers: remove_sliver_attribute not implemented")
+        
index 801144a..6ee30db 100755 (executable)
@@ -3,7 +3,7 @@ from lxml import etree
 from StringIO import StringIO
 from sfa.rspecs.rspec import RSpec 
 from sfa.util.xrn import *
-from sfa.util.plxrn import hostname_to_urn
+from sfa.util.plxrn import hostname_to_urn, xrn_to_hostname
 from sfa.util.config import Config 
 from sfa.rspecs.rspec_version import RSpecVersion 
 
@@ -79,23 +79,37 @@ class PGRSpec(RSpec):
         return nodes
 
     def get_nodes(self, network=None):
-        xpath = '//rspecv2:node[@component_name]/@component_name | //node[@component_name]/@component_name'
-        return self.xml.xpath(xpath, namespaces=self.namespaces) 
+        xpath = '//rspecv2:node[@component_name]/@component_id | //node[@component_name]/@component_id'
+        nodes = self.xml.xpath(xpath, namespaces=self.namespaces)
+        nodes = [xrn_to_hostname(node) for node in nodes]
+        return nodes 
 
     def get_nodes_with_slivers(self, network=None):
         if network:
-            return self.xml.xpath('//rspecv2:node[@component_manager_id="%s"][sliver_type]/@component_name' % network, namespaces=self.namespaces)
+            nodes = self.xml.xpath('//rspecv2:node[@component_manager_id="%s"][sliver_type]/@component_id' % network, namespaces=self.namespaces)
         else:
-            return self.xml.xpath('//rspecv2:node[rspecv2:sliver_type]/@component_name', namespaces=self.namespaces)
+            nodes = self.xml.xpath('//rspecv2:node[rspecv2:sliver_type]/@component_id', namespaces=self.namespaces)
+        nodes = [xrn_to_hostname(node) for node in nodes]
+        return nodes
 
     def get_nodes_without_slivers(self, network=None):
-        pass
+        return []
    
     def get_slice_attributes(self, network=None):
-        pass
+        return []
+
+    def attributes_list(self, elem):
+        opts = []
+        if elem is not None:
+            for e in elem:
+                opts.append((e.tag, e.text))
+        return opts
 
     def get_default_sliver_attributes(self, network=None):
-        pass 
+        return []
+
+    def add_default_sliver_attribute(self, name, value, network=None):
+        pass
 
     def add_nodes(self, nodes, check_for_dupes=False):
         if not isinstance(nodes, list):
@@ -118,30 +132,36 @@ class PGRSpec(RSpec):
             node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='plab-pc')
             node_type_tag = etree.SubElement(node_tag, 'hardware_type', name='pc')
             available_tag = etree.SubElement(node_tag, 'available', now='true')
-            location_tag = etree.SubElement(node_tag, 'location', country="us")
+            sliver_type_tag = etree.SubElement(node_tag, 'sliver_type', name='plab-vnode')
+            # protogeni uses the <sliver_type> tag to identify the types of
+            # vms available at the node. 
+            # only add location tag if longitude and latitude are not null
             if 'site' in node:
-                if 'longitude' in node['site']:
-                    location_tag.set('longitude', str(node['site']['longitude']))
-                if 'latitude' in node['site']:
-                    location_tag.set('latitude', str(node['site']['latitude']))
-            #if 'interfaces' in node:
-            
+                longitude = node['site'].get('longitude', None)
+                latitude = node['site'].get('latitude', None)
+                if longitude and latitude:
+                    location_tag = etree.SubElement(node_tag, 'location', country="us", \
+                                                    longitude=str(longitude), latitude=str(latitude))
+
 
     def add_slivers(self, slivers, sliver_urn=None, no_dupes=False): 
+
+        # all nodes hould already be present in the rspec. Remove all 
+        # nodes that done have slivers
         slivers = self._process_slivers(slivers)
-        nodes_with_slivers = self.get_nodes_with_slivers()
-        for sliver in slivers:
-            hostname = sliver['hostname']
-            if hostname in nodes_with_slivers:
-                continue
-            nodes = self.xml.xpath('//rspecv2:node[@component_name="%s"] | //node[@component_name="%s"]' % (hostname, hostname), namespaces=self.namespaces)
-            if nodes:
-                node = nodes[0]
+        sliver_hosts = [sliver['hostname'] for sliver in slivers]
+        nodes = self.get_node_elements()
+        for node in nodes:
+            urn = node.get('component_id')
+            hostname = xrn_to_hostname(urn)
+            if hostname not in sliver_hosts:
+                parent = node.getparent()
+                parent.remove(node)
+            else:
                 node.set('client_id', hostname)
                 if sliver_urn:
                     node.set('sliver_id', sliver_urn)
-                etree.SubElement(node, 'sliver_type', name='plab-vnode')
-
+     
     def add_default_sliver_attribute(self, name, value, network=None):
         pass
 
@@ -166,19 +186,6 @@ class PGRSpec(RSpec):
     def cleanup(self):
         # remove unncecessary elements, attributes
         if self.type in ['request', 'manifest']:
-            # remove nodes without slivers
-            nodes = self.get_node_elements()
-            for node in nodes:
-                delete = True
-                hostname = node.get('component_name')
-                parent = node.getparent()
-                children = node.getchildren()
-                for child in children:
-                    if child.tag.endswith('sliver_type'):
-                        delete = False
-                if delete:
-                    parent.remove(node)
-
             # remove 'available' element from remaining node elements
             self.remove_element('//rspecv2:available | //available')
 
index 3f85618..5eb34ed 100755 (executable)
@@ -53,14 +53,18 @@ class SfaRSpec(RSpec):
             nodes = self.xml.xpath('//node/hostname/text()')
         else:
             nodes = self.xml.xpath('//network[@name="%s"]//node/hostname/text()' % network)
+
+        nodes = [node.strip() for node in nodes]
         return nodes
 
     def get_nodes_with_slivers(self, network = None):
         if network:
-            return self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
+            nodes =  self.xml.xpath('//network[@name="%s"]//node[sliver]/hostname/text()' % network)   
         else:
-            return self.xml.xpath('//node[sliver]/hostname/text()')
+            nodes = self.xml.xpath('//node[sliver]/hostname/text()')
 
+        nodes = [node.strip() for node in nodes]
+        return nodes     
     def get_nodes_without_slivers(self, network=None): 
         xpath_nodes_without_slivers = '//node[not(sliver)]/hostname/text()'
         xpath_nodes_without_slivers_in_network = '//network[@name="%s"]//node[not(sliver)]/hostname/text()' 
@@ -83,7 +87,9 @@ class SfaRSpec(RSpec):
         if network:
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)        
         else:
-            defaults = self.xml.xpath("//network/sliver_defaults" % network)
+            defaults = self.xml.xpath("//sliver_defaults")
+        if isinstance(defaults, list) and defaults:
+            defaults = defaults[0]
         return self.attributes_list(defaults)
 
     def get_sliver_attributes(self, hostname, network=None):
@@ -92,8 +98,17 @@ class SfaRSpec(RSpec):
         return self.attributes_list(sliver)
 
     def get_slice_attributes(self, network=None):
-        # TODO: FINISH
-        return []
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        for default_attribute in self.get_default_sliver_attributes(network):
+            attribute = {'name': str(default_attribute[0]), 'value': str(default_attribute[1]), 'node_id': None}
+            slice_attributes.append(attribute)
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                attribute = {'name': str(sliver_attribute[0]), 'value': str(sliver_attribute[1]), 'node_id': node}
+                slice_attributes.append(attribute)    
+        return slice_attributes
 
     def get_site_nodes(self, siteid, network=None):
         if network:
@@ -212,7 +227,7 @@ class SfaRSpec(RSpec):
                 for tag in node['tags']:
                    # expose this hard wired list of tags, plus the ones that are marked 'sfa' in their category 
                    if tag['tagname'] in ['fcdistro', 'arch'] or 'sfa' in tag['category'].split('/'):
-                        tag_element = etree.SubElement(node_tag, tag['tagname'], value=tag['value'])
+                        tag_element = etree.SubElement(node_tag, tag['tagname']).text=tag['value']
 
             if 'site' in node:
                 longitude = str(node['site']['longitude'])
@@ -236,7 +251,7 @@ class SfaRSpec(RSpec):
             sliver_elem = etree.SubElement(node_elem, 'sliver')
             if 'tags' in sliver:
                 for tag in sliver['tags']:
-                    etree.SubElement(sliver_elem, tag['tagname'], value=tag['value'])
+                    etree.SubElement(sliver_elem, tag['tagname']).text = value=tag['value']
 
     def remove_slivers(self, slivers, network=None, no_dupes=False):
         slivers = self._process_slivers(slivers)
@@ -251,9 +266,13 @@ class SfaRSpec(RSpec):
             defaults = self.xml.xpath("//network[@name='%s']/sliver_defaults" % network)
         else:
             defaults = self.xml.xpath("//sliver_defaults" % network)
-        if defaults is None:
+        if not defaults :
             network_tag = self.xml.xpath("//network[@name='%s']" % network)
+            if isinstance(network_tag, list):
+                network_tag = network_tag[0]
             defaults = self.add_element('sliver_defaults', attrs={}, parent=network_tag)
+        elif isinstance(defaults, list):
+            defaults = defaults[0]
         self.add_attribute(defaults, name, value)
 
     def add_sliver_attribute(self, hostname, name, value, network=None):
index 12a0e4f..f37331a 100644 (file)
@@ -80,107 +80,6 @@ class Interfaces(dict):
                 self.interfaces[interface['hrn']] = interface
 
 
-    def sync_interfaces(self):
-        """
-        Install missing trusted gids and db records for our federated
-        interfaces
-        """     
-        # Attempt to get any missing peer gids
-        # There should be a gid file in /etc/sfa/trusted_roots for every
-        # peer registry found in in the registries.xml config file. If there
-        # are any missing gids, request a new one from the peer registry.
-        gids_current = self.api.auth.trusted_cert_list
-        hrns_current = [gid.get_hrn() for gid in gids_current] 
-        hrns_expected = self.interfaces.keys() 
-        new_hrns = set(hrns_expected).difference(hrns_current)
-        gids = self.get_peer_gids(new_hrns) + gids_current
-        # make sure there is a record for every gid
-        self.update_db_records(self.type, gids)
-        
-    def get_peer_gids(self, new_hrns):
-        """
-        Install trusted gids from the specified interfaces.  
-        """
-        peer_gids = []
-        if not new_hrns:
-            return peer_gids
-        trusted_certs_dir = self.api.config.get_trustedroots_dir()
-        for new_hrn in new_hrns:
-            if not new_hrn:
-                continue
-            # the gid for this interface should already be installed  
-            if new_hrn == self.api.config.SFA_INTERFACE_HRN:
-                continue
-            try:
-                # get gid from the registry
-                interface_info =  self.interfaces[new_hrn]
-                interface = self[new_hrn]
-                trusted_gids = interface.get_trusted_certs()
-                if trusted_gids:
-                    # the gid we want shoudl be the first one in the list, 
-                    # but lets make sure
-                    for trusted_gid in trusted_gids:
-                        # default message
-                        message = "interface: %s\t" % (self.api.interface)
-                        message += "unable to install trusted gid for %s" % \
-                                   (new_hrn) 
-                        gid = GID(string=trusted_gids[0])
-                        peer_gids.append(gid) 
-                        if gid.get_hrn() == new_hrn:
-                            gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
-                            gid.save_to_file(gid_filename, save_parents=True)
-                            message = "interface: %s\tinstalled trusted gid for %s" % \
-                                (self.api.interface, new_hrn)
-                        # log the message
-                        self.api.logger.info(message)
-            except:
-                message = "interface: %s\tunable to install trusted gid for %s" % \
-                            (self.api.interface, new_hrn) 
-                self.api.logger.log_exc(message)
-        
-        # reload the trusted certs list
-        self.api.auth.load_trusted_certs()
-        return peer_gids
-
-    def update_db_records(self, type, gids):
-        """
-        Make sure there is a record in the local db for allowed registries
-        defined in the config file (registries.xml). Removes old records from
-        the db.         
-        """
-        # import SfaTable here so this module can be loaded by ComponentAPI 
-        from sfa.util.table import SfaTable
-        if not gids: 
-            return
-        
-        # hrns that should have a record
-        hrns_expected = [gid.get_hrn() for gid in gids]
-
-        # get hrns that actually exist in the db
-        table = SfaTable()
-        records = table.find({'type': type, 'pointer': -1})
-        hrns_found = [record['hrn'] for record in records]
-      
-        # remove old records
-        for record in records:
-            if record['hrn'] not in hrns_expected and \
-                record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
-                table.remove(record)
-
-        # add new records
-        for gid in gids:
-            hrn = gid.get_hrn()
-            if hrn not in hrns_found:
-                record = {
-                    'hrn': hrn,
-                    'type': type,
-                    'pointer': -1, 
-                    'authority': get_authority(hrn),
-                    'gid': gid.save_to_string(save_parents=True),
-                }
-                record = SfaRecord(dict=record)
-                table.insert(record)
-                        
     def get_connections(self):
         """
         read connection details for the trusted peer registries from file return 
index b7bfdd8..da25b2a 100644 (file)
@@ -9,8 +9,6 @@ from sfa.util.server import SfaServer
 from sfa.util.faults import *
 from sfa.util.xrn import hrn_to_urn
 from sfa.server.interface import Interfaces
-import sfa.util.xmlrpcprotocol as xmlrpcprotocol
-import sfa.util.soapprotocol as soapprotocol
  
 
 ##
index c981158..28a5a61 100755 (executable)
@@ -35,6 +35,7 @@ component_port=12346
 import os, os.path
 import traceback
 import sys
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from optparse import OptionParser
 
 from sfa.util.sfalogging import logger
@@ -46,7 +47,8 @@ from sfa.util.config import Config
 from sfa.plc.api import SfaAPI
 from sfa.server.registry import Registries
 from sfa.server.aggregate import Aggregates
-
+from sfa.util.xrn import get_authority, hrn_to_urn
+from sfa.util.sfalogging import logger
 
 # after http://www.erlenstar.demon.co.uk/unix/faq_2.html
 def daemon():
@@ -162,17 +164,101 @@ def init_server(options, config):
         manager_module = manager_base + ".component_manager_%s" % mgr_type
         init_manager(manager_module, manager_base)    
 
-def sync_interfaces(server_key_file, server_cert_file):
+def install_peer_certs(server_key_file, server_cert_file):
     """
     Attempt to install missing trusted gids and db records for 
     our federated interfaces
     """
+    # Attempt to get any missing peer gids
+    # There should be a gid file in /etc/sfa/trusted_roots for every
+    # peer registry found in in the registries.xml config file. If there
+    # are any missing gids, request a new one from the peer registry.
     api = SfaAPI(key_file = server_key_file, cert_file = server_cert_file)
     registries = Registries(api)
     aggregates = Aggregates(api)
-    registries.sync_interfaces()
-    aggregates.sync_interfaces()
+    interfaces = dict(registries.interfaces.items() + aggregates.interfaces.items())
+    gids_current = api.auth.trusted_cert_list
+    hrns_current = [gid.get_hrn() for gid in gids_current]
+    hrns_expected = interfaces.keys()
+    new_hrns = set(hrns_expected).difference(hrns_current)
+    #gids = self.get_peer_gids(new_hrns) + gids_current
+    peer_gids = []
+    if not new_hrns:
+        return 
+
+    trusted_certs_dir = api.config.get_trustedroots_dir()
+    for new_hrn in new_hrns:
+        if not new_hrn: continue
+        # the gid for this interface should already be installed
+        if new_hrn == api.config.SFA_INTERFACE_HRN: continue
+        try:
+            # get gid from the registry
+            url = interfaces[new_hrn]['url']
+            interface = xmlrpcprotocol.get_server(url, server_key_file, server_cert_file)
+            # skip non sfa aggregates
+            server_version = api.get_cached_server_version(interface)
+            if 'sfa' not in server_version:
+                logger.info("get_trusted_certs: skipping non sfa aggregate: %s" % new_hrn)
+                continue
+      
+            trusted_gids = interface.get_trusted_certs()
+            if trusted_gids:
+                # the gid we want should be the first one in the list,
+                # but lets make sure
+                for trusted_gid in trusted_gids:
+                    # default message
+                    message = "interface: %s\t" % (api.interface)
+                    message += "unable to install trusted gid for %s" % \
+                               (new_hrn)
+                    gid = GID(string=trusted_gids[0])
+                    peer_gids.append(gid)
+                    if gid.get_hrn() == new_hrn:
+                        gid_filename = os.path.join(trusted_certs_dir, '%s.gid' % new_hrn)
+                        gid.save_to_file(gid_filename, save_parents=True)
+                        message = "installed trusted cert for %s" % new_hrn
+                    # log the message
+                    api.logger.info(message)
+        except:
+            message = "interface: %s\tunable to install trusted gid for %s" % \
+                        (api.interface, new_hrn)
+            api.logger.log_exc(message)
+    # doesnt matter witch one
+    update_cert_records(peer_gids)
+
+def update_cert_records(gids):
+    """
+    Make sure there is a record in the registry for the specified gids. 
+    Removes old records from the db.
+    """
+    # import SfaTable here so this module can be loaded by ComponentAPI
+    from sfa.util.table import SfaTable
+    if not gids:
+        return
+    table = SfaTable()
+    # get records that actually exist in the db
+    gid_urns = [gid.get_urn() for gid in gids]
+    hrns_expected = [gid.get_hrn() for gid in gids]
+    records_found = table.find({'hrn': hrns_expected, 'pointer': -1}) 
+
+    # remove old records
+    for record in records_found:
+        if record['hrn'] not in hrns_expected and \
+            record['hrn'] != self.api.config.SFA_INTERFACE_HRN:
+            table.remove(record)
 
+    # TODO: store urn in the db so we do this in 1 query 
+    for gid in gids:
+        hrn, type = gid.get_hrn(), gid.get_type()
+        record = table.find({'hrn': hrn, 'type': type, 'pointer': -1})
+        if not record:
+            record = {
+                'hrn': hrn, 'type': type, 'pointer': -1,
+                'authority': get_authority(hrn),
+                'gid': gid.save_to_string(save_parents=True),
+            }
+            record = SfaRecord(dict=record)
+            table.insert(record)
+        
 def main():
     # Generate command line parser
     parser = OptionParser(usage="sfa-server [options]")
@@ -184,12 +270,14 @@ def main():
          help="run aggregate manager", default=False)
     parser.add_option("-c", "--component", dest="cm", action="store_true",
          help="run component server", default=False)
+    parser.add_option("-t", "--trusted-certs", dest="trusted_certs", action="store_true",
+         help="refresh trusted certs", default=False)
     parser.add_option("-v", "--verbose", action="count", dest="verbose", default=0,
          help="verbose mode - cumulative")
     parser.add_option("-d", "--daemon", dest="daemon", action="store_true",
          help="Run as daemon.", default=False)
     (options, args) = parser.parse_args()
-
+    
     config = Config()
     if config.SFA_API_DEBUG: pass
     hierarchy = Hierarchy()
@@ -198,9 +286,12 @@ def main():
 
     init_server_key(server_key_file, server_cert_file, config, hierarchy)
     init_server(options, config)
-    sync_interfaces(server_key_file, server_cert_file)   
  
     if (options.daemon):  daemon()
+    
+    if options.trusted_certs:
+        install_peer_certs(server_key_file, server_cert_file)   
+    
     # start registry server
     if (options.registry):
         from sfa.server.registry import Registry
index 3e1fbcc..bd7e7f1 100644 (file)
 ##
 
 import os
+from types import StringTypes
 import datetime
+from StringIO import StringIO
 from tempfile import mkstemp
-import dateutil.parser
-from StringIO import StringIO 
 from xml.dom.minidom import Document, parseString
 from lxml import etree
 
 from sfa.util.faults import *
 from sfa.util.sfalogging import logger
+from sfa.util.sfatime import utcparse
 from sfa.trust.certificate import Keypair
 from sfa.trust.credential_legacy import CredentialLegacy
 from sfa.trust.rights import Right, Rights
@@ -49,37 +50,66 @@ DEFAULT_CREDENTIAL_LIFETIME = 86400 * 14
 # TODO:
 # . make privs match between PG and PL
 # . Need to add support for other types of credentials, e.g. tickets
-
+# . add namespaces to signed-credential element?
 
 signature_template = \
 '''
 <Signature xml:id="Sig_%s" xmlns="http://www.w3.org/2000/09/xmldsig#">
-    <SignedInfo>
-      <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>
-      <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
-      <Reference URI="#%s">
+  <SignedInfo>
+    <CanonicalizationMethod Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>
+    <SignatureMethod Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
+    <Reference URI="#%s">
       <Transforms>
         <Transform Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />
       </Transforms>
       <DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>
       <DigestValue></DigestValue>
-      </Reference>
-    </SignedInfo>
-    <SignatureValue />
-      <KeyInfo>
-        <X509Data>
-          <X509SubjectName/>
-          <X509IssuerSerial/>
-          <X509Certificate/>
-        </X509Data>
-      <KeyValue />
-      </KeyInfo>
-    </Signature>
+    </Reference>
+  </SignedInfo>
+  <SignatureValue />
+  <KeyInfo>
+    <X509Data>
+      <X509SubjectName/>
+      <X509IssuerSerial/>
+      <X509Certificate/>
+    </X509Data>
+    <KeyValue />
+  </KeyInfo>
+</Signature>
 '''
 
+# PG formats the template (whitespace) slightly differently.
+# Note that they don't include the xmlns in the template, but add it later.
+# Otherwise the two are equivalent.
+#signature_template_as_in_pg = \
+#'''
+#<Signature xml:id="Sig_%s" >
+# <SignedInfo>
+#  <CanonicalizationMethod      Algorithm="http://www.w3.org/TR/2001/REC-xml-c14n-20010315"/>
+#  <SignatureMethod      Algorithm="http://www.w3.org/2000/09/xmldsig#rsa-sha1"/>
+#  <Reference URI="#%s">
+#    <Transforms>
+#      <Transform         Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature" />
+#    </Transforms>
+#    <DigestMethod        Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/>
+#    <DigestValue></DigestValue>
+#    </Reference>
+# </SignedInfo>
+# <SignatureValue />
+# <KeyInfo>
+#  <X509Data >
+#   <X509SubjectName/>
+#   <X509IssuerSerial/>
+#   <X509Certificate/>
+#  </X509Data>
+#  <KeyValue />
+# </KeyInfo>
+#</Signature>
+#'''
+
 ##
 # Convert a string into a bool
-
+# used to convert an xsd:boolean to a Python boolean
 def str2bool(str):
     if str.lower() in ['true','1']:
         return True
@@ -214,7 +244,6 @@ class Credential(object):
                 str = string
             elif filename:
                 str = file(filename).read()
-                self.filename=filename
                 
             if str.strip().startswith("-----"):
                 self.legacy = CredentialLegacy(False,string=str)
@@ -316,21 +345,25 @@ class Credential(object):
 
             
     ##
-    # Expiration: an absolute UTC time of expiration (as either an int or datetime)
+    # Expiration: an absolute UTC time of expiration (as either an int or string or datetime)
     # 
     def set_expiration(self, expiration):
-        if isinstance(expiration, int):
+        if isinstance(expiration, (int,float)):
             self.expiration = datetime.datetime.fromtimestamp(expiration)
-        else:
+        elif isinstance (expiration, datetime.datetime):
             self.expiration = expiration
-            
+        elif isinstance (expiration, StringTypes):
+            self.expiration = utcparse (expiration)
+        else:
+            logger.error ("unexpected input type in Credential.set_expiration")
 
     ##
-    # get the lifetime of the credential (in datetime format)
-
+    # get the lifetime of the credential (always in datetime format)
+    #
     def get_expiration(self):
         if not self.expiration:
             self.decode()
+        # at this point self.expiration is normalized as a datetime - DON'T call utcparse again
         return self.expiration
 
     ##
@@ -385,6 +418,15 @@ class Credential(object):
         # Create the XML document
         doc = Document()
         signed_cred = doc.createElement("signed-credential")
+
+# PG adds these. It would be nice to be consistent.
+# But it's kind of odd for PL to use PG schemas that talk
+# about tickets, and the PG CM policies.
+# Note the careful addition of attributes from the parent below...
+#        signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")
+#        signed_cred.setAttribute("xsinoNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")
+#        signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")
+
         doc.appendChild(signed_cred)  
         
         # Fill in the <credential> bit        
@@ -416,11 +458,34 @@ class Credential(object):
         # Add the parent credential if it exists
         if self.parent:
             sdoc = parseString(self.parent.get_xml())
+            # If the root node is a signed-credential (it should be), then
+            # get all its attributes and attach those to our signed_cred
+            # node.
+            # Specifically, PG adds attributes for namespaces (which is reasonable),
+            # and we need to include those again here or else their signature
+            # no longer matches on the credential.
+            # We expect three of these, but here we copy them all:
+#        signed_cred.setAttribute("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")
+#        signed_cred.setAttribute("xsinoNamespaceSchemaLocation", "http://www.protogeni.net/resources/credential/credential.xsd")
+#        signed_cred.setAttribute("xsi:schemaLocation", "http://www.protogeni.net/resources/credential/ext/policy/1 http://www.protogeni.net/resources/credential/ext/policy/1/policy.xsd")
+            parentRoot = sdoc.documentElement
+            if parentRoot.tagName == "signed-credential" and parentRoot.hasAttributes():
+                for attrIx in range(0, parentRoot.attributes.length):
+                    attr = parentRoot.attributes.item(attrIx)
+                    # returns the old attribute of same name that was
+                    # on the credential
+                    # Below throws InUse exception if we forgot to clone the attribute first
+                    oldAttr = signed_cred.setAttributeNode(attr.cloneNode(True))
+                    if oldAttr and oldAttr.value != attr.value:
+                        msg = "Delegating cred from owner %s to %s over %s replaced attribute %s value %s with %s" % (self.parent.gidCaller.get_urn(), self.gidCaller.get_urn(), self.gidObject.get_urn(), oldAttr.name, oldAttr.value, attr.value)
+                        logger.error(msg)
+                        raise CredentialNotVerifiable("Can't encode new valid delegated credential: %s" % msg)
+
             p_cred = doc.importNode(sdoc.getElementsByTagName("credential")[0], True)
             p = doc.createElement("parent")
             p.appendChild(p_cred)
             cred.appendChild(p)
-
+        # done handling parent credential
 
         # Create the <signatures> tag
         signatures = doc.createElement("signatures")
@@ -452,7 +517,6 @@ class Credential(object):
             f = open(filename, "w")
         f.write(self.xml)
         f.close()
-        self.filename=filename
 
     def save_to_string(self, save_parents=True):
         if not self.xml:
@@ -583,7 +647,7 @@ class Credential(object):
         
 
         self.set_refid(cred.getAttribute("xml:id"))
-        self.set_expiration(dateutil.parser.parse(getTextNode(cred, "expires")))
+        self.set_expiration(utcparse(getTextNode(cred, "expires")))
         self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
         self.gidObject = GID(string=getTextNode(cred, "target_gid"))   
 
@@ -595,10 +659,12 @@ class Credential(object):
             kind = getTextNode(priv, "name")
             deleg = str2bool(getTextNode(priv, "can_delegate"))
             if kind == '*':
-                # Convert * into the default privileges for the credential's type                
+                # Convert * into the default privileges for the credential's type
+                # Each inherits the delegatability from the * above
                 _ , type = urn_to_hrn(self.gidObject.get_urn())
                 rl = rlist.determine_rights(type, self.gidObject.get_urn())
                 for r in rl.rights:
+                    r.delegate = deleg
                     rlist.add(r)
             else:
                 rlist.add(Right(kind.strip(), deleg))
@@ -626,6 +692,10 @@ class Credential(object):
     # Verify
     #   trusted_certs: A list of trusted GID filenames (not GID objects!) 
     #                  Chaining is not supported within the GIDs by xmlsec1.
+    #
+    #   trusted_certs_required: Should usually be true. Set False means an
+    #                 empty list of trusted_certs would still let this method pass.
+    #                 It just skips xmlsec1 verification et al. Only used by some utils
     #    
     # Verify that:
     # . All of the signatures are valid and that the issuers trace back
@@ -648,11 +718,10 @@ class Credential(object):
     #   must be done elsewhere
     #
     # @param trusted_certs: The certificates of trusted CA certificates
-    # @param schema: The RelaxNG schema to validate the credential against 
-    def verify(self, trusted_certs, schema=None):
+    def verify(self, trusted_certs=None, schema=None, trusted_certs_required=True):
         if not self.xml:
-            self.decode()        
-        
+            self.decode()
+
         # validate against RelaxNG schema
         if not self.legacy:
             if schema and os.path.exists(schema):
@@ -662,21 +731,26 @@ class Credential(object):
                 if not xmlschema.validate(tree):
                     error = xmlschema.error_log.last_error
                     message = "%s (line %s)" % (error.message, error.line)
-                    raise CredentialNotVerifiable(message) 
-            
+                    raise CredentialNotVerifiable(message)        
 
-#       trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
+        if trusted_certs_required and trusted_certs is None:
+            trusted_certs = []
+
+#        trusted_cert_objects = [GID(filename=f) for f in trusted_certs]
         trusted_cert_objects = []
         ok_trusted_certs = []
-        for f in trusted_certs:
-            try:
-                # Failures here include unreadable files
-                # or non PEM files
-                trusted_cert_objects.append(GID(filename=f))
-                ok_trusted_certs.append(f)
-            except Exception, exc:
-                logger.error("Failed to load trusted cert from %s: %r"%( f, exc))
-        trusted_certs = ok_trusted_certs
+        # If caller explicitly passed in None that means skip cert chain validation.
+        # Strange and not typical
+        if trusted_certs is not None:
+            for f in trusted_certs:
+                try:
+                    # Failures here include unreadable files
+                    # or non PEM files
+                    trusted_cert_objects.append(GID(filename=f))
+                    ok_trusted_certs.append(f)
+                except Exception, exc:
+                    logger.error("Failed to load trusted cert from %s: %r", f, exc)
+            trusted_certs = ok_trusted_certs
 
         # Use legacy verification if this is a legacy credential
         if self.legacy:
@@ -686,7 +760,6 @@ class Credential(object):
             if self.legacy.object_gid:
                 self.legacy.object_gid.verify_chain(trusted_cert_objects)
             return True
-
         
         # make sure it is not expired
         if self.get_expiration() < datetime.datetime.utcnow():
@@ -694,12 +767,16 @@ class Credential(object):
 
         # Verify the signatures
         filename = self.save_to_random_tmp_file()
-        cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])
+        if trusted_certs is not None:
+            cert_args = " ".join(['--trusted-pem %s' % x for x in trusted_certs])
 
-        # Verify the gids of this cred and of its parents
-        for cur_cred in self.get_credential_list():
-            cur_cred.get_gid_object().verify_chain(trusted_cert_objects)
-            cur_cred.get_gid_caller().verify_chain(trusted_cert_objects) 
+        # If caller explicitly passed in None that means skip cert chain validation.
+        # Strange and not typical
+        if trusted_certs is not None:
+            # Verify the gids of this cred and of its parents
+            for cur_cred in self.get_credential_list():
+                cur_cred.get_gid_object().verify_chain(trusted_cert_objects)
+                cur_cred.get_gid_caller().verify_chain(trusted_cert_objects)
 
         refs = []
         refs.append("Sig_%s" % self.get_refid())
@@ -709,10 +786,24 @@ class Credential(object):
             refs.append("Sig_%s" % ref)
 
         for ref in refs:
+            # If caller explicitly passed in None that means skip xmlsec1 validation.
+            # Strange and not typical
+            if trusted_certs is None:
+                break
+
+#            print "Doing %s --verify --node-id '%s' %s %s 2>&1" % \
+#                (self.xmlsec_path, ref, cert_args, filename)
             verified = os.popen('%s --verify --node-id "%s" %s %s 2>&1' \
                             % (self.xmlsec_path, ref, cert_args, filename)).read()
             if not verified.strip().startswith("OK"):
-                raise CredentialNotVerifiable("xmlsec1 error verifying cert: " + verified)
+                # xmlsec errors have a msg= which is the interesting bit.
+                mstart = verified.find("msg=")
+                msg = ""
+                if mstart > -1 and len(verified) > 4:
+                    mstart = mstart + 4
+                    mend = verified.find('\\', mstart)
+                    msg = verified[mstart:mend]
+                raise CredentialNotVerifiable("xmlsec1 error verifying cred using Signature ID %s: %s %s" % (ref, msg, verified.strip()))
         os.remove(filename)
 
         # Verify the parents (delegation)
@@ -759,7 +850,7 @@ class Credential(object):
         # Maybe should be (hrn, type) = urn_to_hrn(root_cred_signer.get_urn())
         root_cred_signer_type = root_cred_signer.get_type()
         if (root_cred_signer_type == 'authority'):
-            #logger.debug('Cred signer is an authority')
+            #sfa_logger.debug('Cred signer is an authority')
             # signer is an authority, see if target is in authority's domain
             hrn = root_cred_signer.get_hrn()
             if root_target_gid.get_hrn().startswith(hrn):
@@ -787,8 +878,8 @@ class Credential(object):
         # make sure the rights given to the child are a subset of the
         # parents rights (and check delegate bits)
         if not parent_cred.get_privileges().is_superset(self.get_privileges()):
-            raise ChildRightsNotSubsetOfParent(
-                self.parent.get_privileges().save_to_string() + " " +
+            raise ChildRightsNotSubsetOfParent(("Parent cred ref %s rights " % self.parent.get_refid()) + 
+                self.parent.get_privileges().save_to_string() + (" not superset of delegated cred ref %s rights " % self.get_refid()) +
                 self.get_privileges().save_to_string())
 
         # make sure my target gid is the same as the parent's
@@ -838,19 +929,23 @@ class Credential(object):
         dcred.encode()
         dcred.sign()
 
-        return dcred 
+        return dcred
 
     # only informative
     def get_filename(self):
         return getattr(self,'filename',None)
-
+    ##
+    # Dump the contents of a credential to stdout in human-readable format
+    #
     # @param dump_parents If true, also dump the parent certificates
     def dump (self, *args, **kwargs):
         print self.dump_string(*args, **kwargs)
 
+
     def dump_string(self, dump_parents=False):
         result=""
-        result += "CREDENTIAL %s\n" % self.get_subject() 
+        result += "CREDENTIAL %s\n" % self.get_subject()
         filename=self.get_filename()
         if filename: result += "Filename %s\n"%filename
         result += "      privs: %s\n" % self.get_privileges().save_to_string()
@@ -873,4 +968,3 @@ class Credential(object):
             result += self.parent.dump(True)
 
         return result
-
index ff1ac2d..14749cb 100644 (file)
@@ -220,6 +220,7 @@ class Rights:
             for my_right in self.rights:
                 if my_right.is_superset(child_right):
                     allowed = True
+                    break
             if not allowed:
                 return False
         return True
index 67d155e..a733308 100644 (file)
@@ -7,11 +7,13 @@ import os
 import traceback
 import string
 import xmlrpclib
+import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 
 from sfa.util.sfalogging import logger
 from sfa.trust.auth import Auth
 from sfa.util.config import *
 from sfa.util.faults import *
+from sfa.util.cache import Cache
 from sfa.trust.credential import *
 from sfa.trust.certificate import *
 
@@ -113,12 +115,11 @@ class ManagerWrapper:
         
 class BaseAPI:
 
-    cache = None
     protocol = None
   
     def __init__(self, config = "/etc/sfa/sfa_config.py", encoding = "utf-8", 
                  methods='sfa.methods', peer_cert = None, interface = None, 
-                 key_file = None, cert_file = None, cache = cache):
+                 key_file = None, cert_file = None, cache = None):
 
         self.encoding = encoding
         
@@ -129,7 +130,6 @@ class BaseAPI:
         # Better just be documenting the API
         if config is None:
             return
-        
         # Load configuration
         self.config = Config(config)
         self.auth = Auth(peer_cert)
@@ -140,6 +140,8 @@ class BaseAPI:
         self.cert_file = cert_file
         self.cert = Certificate(filename=self.cert_file)
         self.cache = cache
+        if self.cache is None:
+            self.cache = Cache()
         self.credential = None
         self.source = None 
         self.time_format = "%Y-%m-%d %H:%M:%S"
@@ -266,3 +268,14 @@ class BaseAPI:
                 raise result 
             
         return response
+
+    def get_cached_server_version(self, server):
+        cache_key = server.url + "-version"
+        server_version = None
+        if self.cache:
+            server_version = self.cache.get(cache_key)
+        if not server_version:
+            server_version = server.GetVersion()
+            # cache version for 24 hours
+            self.cache.add(cache_key, server_version, ttl= 60*60*24)
+        return server_version
index 5580c44..dacdd51 100644 (file)
@@ -17,7 +17,8 @@ def hrn_to_pl_login_base (hrn):
     return PlXrn(xrn=hrn,type='slice').pl_login_base()
 def hrn_to_pl_authname (hrn):
     return PlXrn(xrn=hrn,type='any').pl_authname()
-
+def xrn_to_hostname(hrn):
+    return Xrn.unescape(PlXrn(xrn=hrn, type='node').get_leaf())
 
 class PlXrn (Xrn):
 
index f812517..3991d33 100755 (executable)
@@ -32,7 +32,16 @@ class _SfaLogger:
         handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
         self.logger=logging.getLogger(loggername)
         self.logger.setLevel(level)
-        self.logger.addHandler(handler)
+        # check if logger already has the handler we're about to add
+        handler_exists = False
+        for l_handler in self.logger.handlers:
+            if l_handler.baseFilename == handler.baseFilename and \
+               l_handler.level == handler.level:
+                handler_exists = True 
+
+        if not handler_exists:
+            self.logger.addHandler(handler)
+
         self.loggername=loggername
 
     def setLevel(self,level):
@@ -84,6 +93,12 @@ class _SfaLogger:
         self.debug("%s BEG STACK"%message+"\n"+to_log)
         self.debug("%s END STACK"%message)
 
+    def enable_console(self, stream=sys.stdout):
+        formatter = logging.Formatter("%(message)s")
+        handler = logging.StreamHandler(stream)
+        handler.setFormatter(formatter)
+        self.logger.addHandler(handler)
+
 
 info_logger = _SfaLogger(loggername='info', level=logging.INFO)
 debug_logger = _SfaLogger(loggername='debug', level=logging.DEBUG)
@@ -91,7 +106,7 @@ warn_logger = _SfaLogger(loggername='warning', level=logging.WARNING)
 error_logger = _SfaLogger(loggername='error', level=logging.ERROR)
 critical_logger = _SfaLogger(loggername='critical', level=logging.CRITICAL)
 logger = info_logger
-
+sfi_logger = _SfaLogger(logfile=os.path.expanduser("~/.sfi/")+'sfi.log',loggername='sfilog', level=logging.DEBUG)
 ########################################
 import time
 
@@ -120,21 +135,25 @@ def profile(logger):
 if __name__ == '__main__': 
     print 'testing sfalogging into logger.log'
     logger=_SfaLogger('logger.log')
+    logger2=_SfaLogger('logger.log', level=logging.DEBUG)
+    logger3=_SfaLogger('logger.log', level=logging.ERROR)
+    print logger.logger.handlers
+   
     logger.critical("logger.critical")
     logger.error("logger.error")
-    logger.warning("logger.warning")
+    logger.warn("logger.warning")
     logger.info("logger.info")
     logger.debug("logger.debug")
     logger.setLevel(logging.DEBUG)
     logger.debug("logger.debug again")
     
 
-    @profile(my_logger)
+    @profile(logger)
     def sleep(seconds = 1):
         time.sleep(seconds)
 
-    my_logger.info('console.info')
+    logger.info('console.info')
     sleep(0.5)
-    my_logger.setLevel(logging.DEBUG)
+    logger.setLevel(logging.DEBUG)
     sleep(0.25)
 
index 901b4e0..11cc566 100644 (file)
@@ -1,10 +1,26 @@
+from types import StringTypes
 import dateutil.parser
+import datetime
 
-def utcparse(str):
+from sfa.util.sfalogging import logger
+
+def utcparse(input):
     """ Translate a string into a time using dateutil.parser.parse but make sure it's in UTC time and strip
-    the timezone, so that it's compatible with normal datetime.datetime objects"""
+the timezone, so that it's compatible with normal datetime.datetime objects.
+
+For safety this can also handle inputs that are either timestamps, or datetimes
+"""
     
-    t = dateutil.parser.parse(str)
-    if not t.utcoffset() is None:
-        t = t.utcoffset() + t.replace(tzinfo=None)
-    return t
+    if isinstance (input, datetime.datetime):
+        logger.warn ("argument to utcparse already a datetime - doing nothing")
+        return input
+    elif isinstance (input, StringTypes):
+        t = dateutil.parser.parse(input)
+        if t.utcoffset() is not None:
+            t = t.utcoffset() + t.replace(tzinfo=None)
+        return t
+    elif isinstance (input, (int,float)):
+        return datetime.datetime.fromtimestamp(input)
+    else:
+        logger.error("Unexpected type in utcparse [%s]"%type(input))
+