Merge branch 'master' of ssh://bakers@git.planet-lab.org/git/sfa
authorsmbaker <smbaker@fc8clean.lan>
Thu, 7 Jul 2011 18:04:25 +0000 (11:04 -0700)
committersmbaker <smbaker@fc8clean.lan>
Thu, 7 Jul 2011 18:04:25 +0000 (11:04 -0700)
14 files changed:
sfa/managers/aggregate_manager_eucalyptus.py
sfa/managers/aggregate_manager_pl.py
sfa/managers/eucalyptus/eucalyptus.rnc
sfa/managers/eucalyptus/eucalyptus.rng
sfa/managers/slice_manager_pl.py
sfa/plc/aggregate.py
sfa/plc/sfa-import-plc.py
sfa/plc/sfaImport.py
sfa/rspecs/pg_rspec.py
sfa/trust/credential.py
sfa/trust/gid.py
sfa/util/sfalogging.py
sfa/util/sfatime.py
sfa/util/threadmanager.py

index 1b6d143..c5d2ba3 100644 (file)
@@ -12,13 +12,16 @@ from lxml import etree as ET
 from sqlobject import *
 
 from sfa.util.faults import *
-from sfa.util.xrn import urn_to_hrn
+from sfa.util.xrn import urn_to_hrn, Xrn
 from sfa.util.rspec import RSpec
 from sfa.server.registry import Registries
 from sfa.trust.credential import Credential
 from sfa.plc.api import SfaAPI
 from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
 from sfa.util.callids import Callids
+from sfa.util.sfalogging import sfa_logger
+from sfa.rspecs.sfa_rspec import sfa_rspec_version
+from sfa.util.version import version_core
 
 ##
 # The data structure used to represent a cloud.
@@ -404,7 +407,8 @@ def ListResources(api, creds, options, call_id):
     # get hrn of the original caller
     origin_hrn = options.get('origin_hrn', None)
     if not origin_hrn:
-        origin_hrn = Credential(string=creds[0]).get_gid_caller().get_hrn()
+        origin_hrn = Credential(string=creds).get_gid_caller().get_hrn()
+        # origin_hrn = Credential(string=creds[0]).get_gid_caller().get_hrn()
 
     conn = getEucaConnection()
 
@@ -505,6 +509,11 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     schemaXML = ET.parse(EUCALYPTUS_RSPEC_SCHEMA)
     rspecValidator = ET.RelaxNG(schemaXML)
     rspecXML = ET.XML(xml)
+    for network in rspecXML.iterfind("./network"):
+        if network.get('id') != cloud['name']:
+            # Throw away everything except my own RSpec
+            # sfa_logger().error("CreateSliver: deleting %s from rspec"%network.get('id'))
+            network.getparent().remove(network)
     if not rspecValidator(rspecXML):
         error = rspecValidator.error_log.last_error
         message = '%s (line %s)' % (error.message, error.line) 
@@ -521,7 +530,7 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     pendingRmInst = []
     for sliceInst in s.instances:
         pendingRmInst.append(sliceInst.instance_id)
-    existingInstGroup = rspecXML.findall('.//euca_instances')
+    existingInstGroup = rspecXML.findall(".//euca_instances")
     for instGroup in existingInstGroup:
         for existingInst in instGroup:
             if existingInst.get('id') in pendingRmInst:
@@ -533,7 +542,7 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     conn.terminate_instances(pendingRmInst)
 
     # Process new instance requests
-    requests = rspecXML.findall('.//request')
+    requests = rspecXML.findall(".//request")
     if requests:
         # Get all the public keys associate with slice.
         pubKeys = getKeysForSlice(s.slice_hrn)
@@ -567,6 +576,19 @@ def CreateSliver(api, xrn, creds, xml, users, call_id):
     # with enough data for the client to understand what's happened
     return xml
 
+def GetVersion(api):
+    xrn=Xrn(api.hrn)
+    request_rspec_versions = [dict(sfa_rspec_version)]
+    ad_rspec_versions = [dict(sfa_rspec_version)]
+    version_more = {'interface':'aggregate',
+                    'testbed':'myplc',
+                    'hrn':xrn.get_hrn(),
+                    'request_rspec_versions': request_rspec_versions,
+                    'ad_rspec_versions': ad_rspec_versions,
+                    'default_ad_rspec': dict(sfa_rspec_version)
+                    }
+    return version_core(version_more)
+
 def main():
     init_server()
 
index 5fcd6ef..428b7a0 100644 (file)
@@ -79,7 +79,8 @@ def __get_registry_objects(slice_xrn, creds, users):
 
         slice = {}
         
-        extime = utcparse(Credential(string=creds[0]).get_expiration())
+        # get_expiration always returns a normalized datetime - no need to utcparse
+        extime = Credential(string=creds[0]).get_expiration()
         # If the expiration time is > 60 days from now, set the expiration time to 60 days from now
         if extime > datetime.datetime.utcnow() + datetime.timedelta(days=60):
             extime = datetime.datetime.utcnow() + datetime.timedelta(days=60)
@@ -211,9 +212,6 @@ def CreateSliver(api, slice_xrn, creds, rspec_string, users, call_id):
         else:
             existing_slice_attributes.append(slice_tag)  
          
-    #api.logger.debug("requested slice attributes: %s" % str(requested_slice_attributes))
-    #api.logger.debug("removed slice attributes: %s" % str(removed_slice_attributes))
-    #api.logger.debug("existing slice attributes: %s" % str(existing_slice_attributes))
     try:
         if peer:
             api.plshell.UnBindObjectFromPeer(api.plauth, 'slice', slice['slice_id'], peer)
index bc61ce3..670a01e 100644 (file)
@@ -1,5 +1,7 @@
 start = RSpec
 RSpec = element RSpec {
+    attribute expires { xsd:NMTOKEN },
+    attribute generated { xsd:NMTOKEN },
     attribute type { xsd:NMTOKEN },
     network
 }
index d7a85b4..d8be05e 100644 (file)
@@ -5,10 +5,18 @@
   </start>
   <define name="RSpec">
     <element name="RSpec">
+      <attribute name="expires">
+        <data type="NMTOKEN"/>
+      </attribute>
+      <attribute name="generated">
+        <data type="NMTOKEN"/>
+      </attribute>
       <attribute name="type">
         <data type="NMTOKEN"/>
       </attribute>
-      <ref name="network"/>
+      <oneOrMore>
+        <ref name="network"/>
+      </oneOrMore>
     </element>
   </define>
   <define name="network">
index 4ac41a4..995b837 100644 (file)
@@ -44,7 +44,7 @@ def _call_id_supported(api, server):
         code_tag_parts = code_tag.split("-")
 
         version_parts = code_tag_parts[0].split(".")
-        major, minor = version_parts[0], version_parts[1]
+        major, minor = version_parts[0:2]
         rev = code_tag_parts[1]
         if int(major) > 1:
             if int(minor) > 0 or int(rev) > 20:
@@ -128,6 +128,7 @@ def ListResources(api, creds, options, call_id):
         # unless the caller is the aggregate's SM
         if caller_hrn == aggregate and aggregate != api.hrn:
             continue
+
         # get the rspec from the aggregate
         server = api.aggregates[aggregate]
         #threads.run(server.ListResources, credentials, my_opts, call_id)
@@ -155,20 +156,20 @@ def ListResources(api, creds, options, call_id):
 def CreateSliver(api, xrn, creds, rspec_str, users, call_id):
 
     def _CreateSliver(server, xrn, credential, rspec, users, call_id):
+        try:
             # Need to call GetVersion at an aggregate to determine the supported 
             # rspec type/format beofre calling CreateSliver at an Aggregate. 
-            server_version = _get_server_version(api, server)    
-            if 'sfa' not in aggregate_version and 'geni_api' in aggregate_version:
+            server_version = api.get_cached_server_version(server)    
+            if 'sfa' not in server_version and 'geni_api' in server_version:
                 # sfa aggregtes support both sfa and pg rspecs, no need to convert
-                # if aggregate supports sfa rspecs. othewise convert to pg rspec
+                # if aggregate supports sfa rspecs. otherwise convert to pg rspec
                 rspec = RSpecConverter.to_pg_rspec(rspec)
             args = [xrn, credential, rspec, users]
             if _call_id_supported(api, server):
                 args.append(call_id)
-            try:
-                return server.CreateSliver(*args)
-            except Exception, e:
-                api.logger.warn("CreateSliver failed at %s: %s" %(server.url, str(e)))
+            return server.CreateSliver(*args)
+        except: 
+            logger.log_exc('Something wrong in _CreateSliver with URL %s'%server.url)
 
     if Callids().already_handled(call_id): return ""
     # Validate the RSpec against PlanetLab's schema --disabled for now
index 76a93ed..1cd1933 100644 (file)
@@ -99,19 +99,22 @@ class Aggregate:
                 slivers = []
                 tags = self.api.plshell.GetSliceTags(self.api.plauth, slice['slice_tag_ids'])
                 for node_id in slice['node_ids']:
-                    sliver = {}
-                    sliver['hostname'] = self.nodes[node_id]['hostname']
-                    sliver['tags'] = []
-                    slivers.append(sliver)
-                    for tag in tags:
-                        # if tag isn't bound to a node then it applies to all slivers
-                        # and belongs in the <sliver_defaults> tag
-                        if not tag['node_id']:
-                            rspec.add_default_sliver_attribute(tag['tagname'], tag['value'], self.api.hrn)
-                        else:
-                            tag_host = self.nodes[tag['node_id']]['hostname']
-                            if tag_host == sliver['hostname']:
-                                sliver['tags'].append(tag)
+                    try:
+                        sliver = {}
+                        sliver['hostname'] = self.nodes[node_id]['hostname']
+                        sliver['tags'] = []
+                        slivers.append(sliver)
+                        for tag in tags:
+                            # if tag isn't bound to a node then it applies to all slivers
+                            # and belongs in the <sliver_defaults> tag
+                            if not tag['node_id']:
+                                rspec.add_default_sliver_attribute(tag['tagname'], tag['value'], self.api.hrn)
+                            else:
+                                tag_host = self.nodes[tag['node_id']]['hostname']
+                                if tag_host == sliver['hostname']:
+                                    sliver['tags'].append(tag)
+                    except:
+                        self.api.logger.log_exc('unable to add sliver %s to node %s' % (slice['name'], node_id)) 
                 rspec.add_slivers(slivers, sliver_urn=slice_xrn)
 
         return rspec.toxml(cleanup=True)          
index 45386ea..3dd12d1 100755 (executable)
@@ -28,7 +28,6 @@ from sfa.trust.trustedroot import *
 from sfa.trust.hierarchy import *
 from sfa.util.xrn import Xrn
 from sfa.plc.api import *
-from sfa.util.sfalogging import logger
 from sfa.trust.gid import create_uuid
 from sfa.plc.sfaImport import sfaImport
 
@@ -147,7 +146,7 @@ def main():
     # start importing 
     for site in sites:
         site_hrn = interface_hrn + "." + site['login_base']
-        logger.info("Importing site: %s" % site_hrn)
+        sfaImporter.logger.info("Importing site: %s" % site_hrn)
 
         # import if hrn is not in list of existing hrns or if the hrn exists
         # but its not a site record
index ee34833..e9d0940 100644 (file)
@@ -135,12 +135,16 @@ class sfaImport:
             # to planetlab
             keys = self.shell.GetKeys(self.plc_auth, key_ids)
             key = keys[0]['key']
-            pkey = convert_public_key(key)
+            pkey = None
+            try:
+                pkey = convert_public_key(key)
+            except:
+                self.logger.warn('unable to convert public key for %s' % hrn) 
             if not pkey:
                 pkey = Keypair(create=True)
         else:
             # the user has no keys
-            self.logger.warning("Import: person %s does not have a PL public key"%hrn)
+            self.logger.warn("Import: person %s does not have a PL public key"%hrn)
             # if a key is unavailable, then we still need to put something in the
             # user's GID. So make one up.
             pkey = Keypair(create=True)
index 6ee30db..e86e13c 100755 (executable)
@@ -74,7 +74,14 @@ class PGRSpec(RSpec):
         networks = self.xml.xpath('//rspecv2:node[@component_manager_uuid]/@component_manager_uuid', namespaces=self.namespaces)
         return set(networks)
 
-    def get_node_elements(self):
+    def get_node_element(self, hostname, network=None):
+        nodes = self.xml.xpath('//rspecv2:node[@component_id[contains(., "%s")]]' % hostname, namespaces=self.namespaces)
+        if isinstance(nodes,list) and nodes:
+            return nodes[0]
+        else:
+            return None
+
+    def get_node_elements(self, network=None):
         nodes = self.xml.xpath('//rspecv2:node | //node', namespaces=self.namespaces)
         return nodes
 
@@ -94,15 +101,41 @@ class PGRSpec(RSpec):
 
     def get_nodes_without_slivers(self, network=None):
         return []
+
+    def get_sliver_attributes(self, hostname, network=None):
+        node = self.get_node_element(hostname, network)
+        sliver = node.xpath('//rspecv2:sliver_type', namespaces=self.namespaces)
+        if sliver is not None and isinstance(sliver, list):
+            sliver = sliver[0]
+        return self.attributes_list(sliver)
    
     def get_slice_attributes(self, network=None):
-        return []
+        slice_attributes = []
+        nodes_with_slivers = self.get_nodes_with_slivers(network)
+        from sfa.util.sfalogging import logger 
+        # TODO: default sliver attributes in the PG rspec?
+        default_ns_prefix = self.namespaces['rspecv2']
+        for node in nodes_with_slivers:
+            sliver_attributes = self.get_sliver_attributes(node, network)
+            for sliver_attribute in sliver_attributes:
+                name=str(sliver_attribute[0]) 
+                value=str(sliver_attribute[1])
+                # we currently only suppor the <initscript> and <flack> attributes 
+                if  'info' in name:
+                    value = ",".join(["%s=%s" %(a,b) for (a,b) in sliver_attribute[2].items()])
+                    attribute = {'name': 'flack_info', 'value': value, 'node_id': node}
+                    slice_attributes.append(attribute) 
+                elif 'initscript' in name: 
+                    attribute = {'name': 'initscript', 'value': value, 'node_id': node}
+                    slice_attributes.append(attribute) 
+
+        return slice_attributes
 
     def attributes_list(self, elem):
         opts = []
         if elem is not None:
             for e in elem:
-                opts.append((e.tag, e.text))
+                opts.append((e.tag, e.text, e.attrib))
         return opts
 
     def get_default_sliver_attributes(self, network=None):
index fad4668..ceed5af 100644 (file)
 # Credentials are signed XML files that assign a subject gid privileges to an object gid
 ##
 
-### $Id$
-### $URL$
-
 import os
+from types import StringTypes
 import datetime
-from sfa.util.sfatime import utcparse
+from StringIO import StringIO
 from tempfile import mkstemp
 from xml.dom.minidom import Document, parseString
 from lxml import etree
-from dateutil.parser import parse
-from StringIO import StringIO
+
 from sfa.util.faults import *
 from sfa.util.sfalogging import logger
+from sfa.util.sfatime import utcparse
 from sfa.trust.certificate import Keypair
 from sfa.trust.credential_legacy import CredentialLegacy
 from sfa.trust.rights import Right, Rights
@@ -347,22 +345,26 @@ class Credential(object):
 
             
     ##
-    # Expiration: an absolute UTC time of expiration (as either an int or datetime)
+    # Expiration: an absolute UTC time of expiration (as either an int or string or datetime)
     # 
     def set_expiration(self, expiration):
-        if isinstance(expiration, int):
+        if isinstance(expiration, (int,float)):
             self.expiration = datetime.datetime.fromtimestamp(expiration)
-        else:
+        elif isinstance (expiration, datetime.datetime):
             self.expiration = expiration
-            
+        elif isinstance (expiration, StringTypes):
+            self.expiration = utcparse (expiration)
+        else:
+            logger.error ("unexpected input type in Credential.set_expiration")
 
     ##
-    # get the lifetime of the credential (in datetime format)
-
+    # get the lifetime of the credential (always in datetime format)
+    #
     def get_expiration(self):
         if not self.expiration:
             self.decode()
-        return utcparse(self.expiration)
+        # at this point self.expiration is normalized as a datetime - DON'T call utcparse again
+        return self.expiration
 
     ##
     # For legacy sake
@@ -645,7 +647,7 @@ class Credential(object):
         
 
         self.set_refid(cred.getAttribute("xml:id"))
-        self.set_expiration(parse(getTextNode(cred, "expires")))
+        self.set_expiration(utcparse(getTextNode(cred, "expires")))
         self.gidCaller = GID(string=getTextNode(cred, "owner_gid"))
         self.gidObject = GID(string=getTextNode(cred, "target_gid"))   
 
@@ -963,6 +965,6 @@ class Credential(object):
 
         if self.parent and dump_parents:
             result += "\nPARENT"
-            result += self.parent.dump(True)
+            result += self.parent.dump_string(True)
 
         return result
index 7650d11..b881a1f 100644 (file)
@@ -180,7 +180,7 @@ class GID(Certificate):
         print self.dump_string(*args,**kwargs)
 
     def dump_string(self, indent=0, dump_parents=False):
-        result="GID\n"
+        result=" "*(indent-2) + "GID\n"
         result += " "*indent + "hrn:" + str(self.get_hrn()) +"\n"
         result += " "*indent + "urn:" + str(self.get_urn()) +"\n"
         result += " "*indent + "uuid:" + str(self.get_uuid()) + "\n"
index 93e3073..3ec0350 100644 (file)
@@ -28,6 +28,14 @@ class _SfaLogger:
             # This is usually a permissions error becaue the file is
             # owned by root, but httpd is trying to access it.
             tmplogfile=os.getenv("TMPDIR", "/tmp") + os.path.sep + os.path.basename(logfile)
+            # In strange uses, 2 users on same machine might use same code,
+            # meaning they would clobber each others files
+            # We could (a) rename the tmplogfile, or (b)
+            # just log to the console in that case.
+            # Here we default to the console.
+            if os.path.exists(tmplogfile) and not os.access(tmplogfile,os.W_OK):
+                loggername = loggername + "-console"
+                handler = logging.StreamHandler()
             handler=logging.handlers.RotatingFileHandler(tmplogfile,maxBytes=1000000, backupCount=5) 
         handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
         self.logger=logging.getLogger(loggername)
index 901b4e0..11cc566 100644 (file)
@@ -1,10 +1,26 @@
+from types import StringTypes
 import dateutil.parser
+import datetime
 
-def utcparse(str):
+from sfa.util.sfalogging import logger
+
+def utcparse(input):
     """ Translate a string into a time using dateutil.parser.parse but make sure it's in UTC time and strip
-    the timezone, so that it's compatible with normal datetime.datetime objects"""
+the timezone, so that it's compatible with normal datetime.datetime objects.
+
+For safety this can also handle inputs that are either timestamps, or datetimes
+"""
     
-    t = dateutil.parser.parse(str)
-    if not t.utcoffset() is None:
-        t = t.utcoffset() + t.replace(tzinfo=None)
-    return t
+    if isinstance (input, datetime.datetime):
+        logger.warn ("argument to utcparse already a datetime - doing nothing")
+        return input
+    elif isinstance (input, StringTypes):
+        t = dateutil.parser.parse(input)
+        if t.utcoffset() is not None:
+            t = t.utcoffset() + t.replace(tzinfo=None)
+        return t
+    elif isinstance (input, (int,float)):
+        return datetime.datetime.fromtimestamp(input)
+    else:
+        logger.error("Unexpected type in utcparse [%s]"%type(input))
+
index 331f847..4ce578f 100755 (executable)
@@ -2,6 +2,7 @@ import threading
 import traceback
 import time
 from Queue import Queue
+from sfa.util.sfa.logging import logger
 
 def ThreadedMethod(callable, results, errors):
     """
@@ -15,6 +16,7 @@ def ThreadedMethod(callable, results, errors):
                 try:
                     results.put(callable(*args, **kwds))
                 except Exception, e:
+                    logger.log_exc('ThreadManager: Error in thread: ')
                     errors.put(traceback.format_exc())
                     
         thread = ThreadInstance()