added InvalidXML and InvalidXMLElement faults
[sfa.git] / sfa / client / sfi.py
index 750d873..83a66f9 100755 (executable)
@@ -6,26 +6,28 @@ import sys
 sys.path.append('.')
 import os, os.path
 import tempfile
 sys.path.append('.')
 import os, os.path
 import tempfile
-import traceback
 import socket
 import socket
-import random
 import datetime
 import datetime
-import zlib
+import codecs
+import pickle
 from lxml import etree
 from StringIO import StringIO
 from lxml import etree
 from StringIO import StringIO
-from types import StringTypes, ListType
 from optparse import OptionParser
 from optparse import OptionParser
-from sfa.util.sfalogging import info_logger
+from sfa.client.client_helper import pg_users_arg, sfa_users_arg
+from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.trust.credential import Credential
 from sfa.util.sfaticket import SfaTicket
 from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.trust.credential import Credential
 from sfa.util.sfaticket import SfaTicket
 from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord
-from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
+from sfa.rspecs.rspec import RSpec
+from sfa.rspecs.rspec_converter import RSpecConverter
+from sfa.util.xrn import get_leaf, get_authority, hrn_to_urn
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 from sfa.util.version import version_core
 from sfa.util.cache import Cache
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 from sfa.util.version import version_core
 from sfa.util.cache import Cache
+from sfa.rspecs.version_manager import VersionManager
 
 AGGREGATE_PORT=12346
 CM_PORT=12346
 
 AGGREGATE_PORT=12346
 CM_PORT=12346
@@ -76,23 +78,51 @@ def filter_records(type, records):
 
 
 # save methods
 
 
 # save methods
+def save_variable_to_file(var, filename, format="text"):
+    f = open(filename, "w")
+    if format == "text":
+        f.write(str(var))
+    elif format == "pickled":
+        f.write(pickle.dumps(var))
+    else:
+        # this should never happen
+        print "unknown output format", format
+
+
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
-
     f = open(filename, 'w')
     f.write(rspec)
     f.close()
     return
 
     f = open(filename, 'w')
     f.write(rspec)
     f.close()
     return
 
-def save_records_to_file(filename, recordList):
-    index = 0
-    for record in recordList:
-        if index > 0:
-            save_record_to_file(filename + "." + str(index), record)
-        else:
-            save_record_to_file(filename, record)
-        index = index + 1
+def save_records_to_file(filename, recordList, format="xml"):
+    if format == "xml":
+        index = 0
+        for record in recordList:
+            if index > 0:
+                save_record_to_file(filename + "." + str(index), record)
+            else:
+                save_record_to_file(filename, record)
+            index = index + 1
+    elif format == "xmllist":
+        f = open(filename, "w")
+        f.write("<recordlist>\n")
+        for record in recordList:
+            record = SfaRecord(dict=record)
+            f.write('<record hrn="' + record.get_name() + '" type="' + record.get_type() + '" />\n')
+        f.write("</recordlist>\n")
+        f.close()
+    elif format == "hrnlist":
+        f = open(filename, "w")
+        for record in recordList:
+            record = SfaRecord(dict=record)
+            f.write(record.get_name() + "\n")
+        f.close()
+    else:
+        # this should never happen
+        print "unknown output format", format
 
 def save_record_to_file(filename, record):
     if record['type'] in ['user']:
 
 def save_record_to_file(filename, record):
     if record['type'] in ['user']:
@@ -106,13 +136,17 @@ def save_record_to_file(filename, record):
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
-    file(filename, "w").write(str)
+    f=codecs.open(filename, encoding='utf-8',mode="w")
+    f.write(str)
+    f.close()
     return
 
 
 # load methods
 def load_record_from_file(filename):
     return
 
 
 # load methods
 def load_record_from_file(filename):
-    str = file(filename, "r").read()
+    f=codecs.open(filename, encoding="utf-8", mode="r")
+    str = f.read()
+    f.close()
     record = SfaRecord(string=str)
     return record
 
     record = SfaRecord(string=str)
     return record
 
@@ -134,6 +168,8 @@ class Sfi:
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+        # xxx oops, this is dangerous, sounds like ww sometimes have discrepency
+        # would be safer to remove self.sfi_dir altogether
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
@@ -141,7 +177,8 @@ class Sfi:
         self.user = None
         self.authority = None
         self.hashrequest = False
         self.user = None
         self.authority = None
         self.hashrequest = False
-        self.logger = info_logger
+        self.logger = sfi_logger
+        self.logger.enable_console()
    
     def create_cmd_parser(self, command, additional_cmdargs=None):
         cmdargs = {"list": "authority",
    
     def create_cmd_parser(self, command, additional_cmdargs=None):
         cmdargs = {"list": "authority",
@@ -151,6 +188,7 @@ class Sfi:
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
+                  "create_gid": "[name]",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
@@ -214,15 +252,27 @@ class Sfi:
                                 help="optional component information", default=None)
 
 
                                 help="optional component information", default=None)
 
 
-        if command in ("resources", "show", "list"):
+        # 'create' does return the new rspec, makes sense to save that too
+        if command in ("resources", "show", "list", "create_gid", 'create'):
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
-        
+
         if command in ("show", "list"):
            parser.add_option("-f", "--format", dest="format", type="choice",
                              help="display format ([text]|xml)", default="text",
                              choices=("text", "xml"))
 
         if command in ("show", "list"):
            parser.add_option("-f", "--format", dest="format", type="choice",
                              help="display format ([text]|xml)", default="text",
                              choices=("text", "xml"))
 
+           parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+                             help="output file format ([xml]|xmllist|hrnlist)", default="xml",
+                             choices=("xml", "xmllist", "hrnlist"))
+
+        if command in ("status", "version"):
+           parser.add_option("-o", "--output", dest="file",
+                            help="output dictionary to file", metavar="FILE", default=None)
+           parser.add_option("-F", "--fileformat", dest="fileformat", type="choice",
+                             help="output file format ([text]|pickled)", default="text",
+                             choices=("text","pickled"))
+
         if command in ("delegate"):
            parser.add_option("-u", "--user",
                             action="store_true", dest="delegate_user", default=False,
         if command in ("delegate"):
            parser.add_option("-u", "--user",
                             action="store_true", dest="delegate_user", default=False,
@@ -272,13 +322,15 @@ class Sfi:
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
+        parser.add_option("-t", "--timeout", dest="timeout", default=None,
+                         help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
         return parser
         
 
     def read_config(self):
         parser.disable_interspersed_args()
 
         return parser
         
 
     def read_config(self):
-       config_file = self.options.sfi_dir + os.sep + "sfi_config"
+       config_file = os.path.join(self.options.sfi_dir,"sfi_config")
        try:
           config = Config (config_file)
        except:
        try:
           config = Config (config_file)
        except:
@@ -346,16 +398,16 @@ class Sfi:
        self.cert_file = cert_file
        self.cert = GID(filename=cert_file)
        self.logger.info("Contacting Registry at: %s"%self.reg_url)
        self.cert_file = cert_file
        self.cert = GID(filename=cert_file)
        self.logger.info("Contacting Registry at: %s"%self.reg_url)
-       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)  
+       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)  
        self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
        self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
-       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, self.options)
+       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
        return
 
     def get_cached_server_version(self, server):
         # check local cache first
         cache = None
         version = None 
        return
 
     def get_cached_server_version(self, server):
         # check local cache first
         cache = None
         version = None 
-        cache_file = self.sfi_dir + os.path.sep + 'sfi_cache.dat'
+        cache_file = os.path.join(self.options.sfi_dir,'sfi_cache.dat')
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
@@ -370,6 +422,8 @@ class Sfi:
             version = server.GetVersion()
             # cache version for 24 hours
             cache.add(cache_key, version, ttl= 60*60*24)
             version = server.GetVersion()
             # cache version for 24 hours
             cache.add(cache_key, version, ttl= 60*60*24)
+            self.logger.info("Updating cache file %s" % cache_file)
+            cache.save_to_file(cache_file)
 
 
         return version   
 
 
         return version   
@@ -380,7 +434,7 @@ class Sfi:
         Returns true if server support the optional call_id arg, false otherwise. 
         """
         server_version = self.get_cached_server_version(server)
         Returns true if server support the optional call_id arg, false otherwise. 
         """
         server_version = self.get_cached_server_version(server)
-        if 'sfa' in server_version:
+        if 'sfa' in server_version and 'code_tag' in server_version:
             code_tag = server_version['code_tag']
             code_tag_parts = code_tag.split("-")
             
             code_tag = server_version['code_tag']
             code_tag_parts = code_tag.split("-")
             
@@ -439,7 +493,7 @@ class Sfi:
             self.logger.info("Getting Registry issued cert")
             self.read_config()
             # *hack.  need to set registyr before _get_gid() is called 
             self.logger.info("Getting Registry issued cert")
             self.read_config()
             # *hack.  need to set registyr before _get_gid() is called 
-            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
             gid = self._get_gid(type='user')
             self.registry = None 
             self.logger.info("Writing certificate to %s"%cert_file)
             gid = self._get_gid(type='user')
             self.registry = None 
             self.logger.info("Writing certificate to %s"%cert_file)
@@ -479,7 +533,6 @@ class Sfi:
             hrn = self.user
  
         gidfile = os.path.join(self.options.sfi_dir, hrn + ".gid")
             hrn = self.user
  
         gidfile = os.path.join(self.options.sfi_dir, hrn + ".gid")
-        print gidfile
         gid = self.get_cached_gid(gidfile)
         if not gid:
             user_cred = self.get_user_cred()
         gid = self.get_cached_gid(gidfile)
         if not gid:
             user_cred = self.get_user_cred()
@@ -613,7 +666,7 @@ class Sfi:
         host_parts = host.split('/')
         host_parts[0] = host_parts[0] + ":" + str(port)
         url =  "http://%s" %  "/".join(host_parts)    
         host_parts = host.split('/')
         host_parts[0] = host_parts[0] + ":" + str(port)
         url =  "http://%s" %  "/".join(host_parts)    
-        return xmlrpcprotocol.get_server(url, keyfile, certfile, self.options)
+        return xmlrpcprotocol.get_server(url, keyfile, certfile, timeout=self.options.timeout, verbose=self.options.debug)
 
     # xxx opts could be retrieved in self.options
     def get_server_from_opts(self, opts):
 
     # xxx opts could be retrieved in self.options
     def get_server_from_opts(self, opts):
@@ -638,7 +691,22 @@ class Sfi:
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
+
+    def create_gid(self, opts, args):
+        if len(args) < 1:
+            self.print_help()
+            sys.exit(1)
+        target_hrn = args[0]
+        user_cred = self.get_user_cred().save_to_string(save_parents=True)
+        gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+        if opts.file:
+            filename = opts.file
+        else:
+            filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+        self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+        GID(string=gid).save_to_file(filename)
+         
+     
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
@@ -650,17 +718,14 @@ class Sfi:
             list = self.registry.List(hrn, user_cred)
         except IndexError:
             raise Exception, "Not enough parameters for the 'list' command"
             list = self.registry.List(hrn, user_cred)
         except IndexError:
             raise Exception, "Not enough parameters for the 'list' command"
-          
-        # filter on person, slice, site, node, etc.  
+
+        # filter on person, slice, site, node, etc.
         # THis really should be in the self.filter_records funct def comment...
         list = filter_records(opts.type, list)
         for record in list:
         # THis really should be in the self.filter_records funct def comment...
         list = filter_records(opts.type, list)
         for record in list:
-            print "%s (%s)" % (record['hrn'], record['type'])     
+            print "%s (%s)" % (record['hrn'], record['type'])
         if opts.file:
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, list)
+            save_records_to_file(opts.file, list, opts.fileformat)
         return
     
     # show named registry record
         return
     
     # show named registry record
@@ -689,12 +754,8 @@ class Sfi:
                 record.dump()  
             else:
                 print record.save_to_string() 
                 record.dump()  
             else:
                 print record.save_to_string() 
         if opts.file:
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, records)
+            save_records_to_file(opts.file, records, opts.fileformat)
         return
     
     def delegate(self, opts, args):
         return
     
     def delegate(self, opts, args):
@@ -854,6 +915,8 @@ class Sfi:
             version=server.GetVersion()
         for (k,v) in version.iteritems():
             print "%-20s: %s"%(k,v)
             version=server.GetVersion()
         for (k,v) in version.iteritems():
             print "%-20s: %s"%(k,v)
+        if opts.file:
+            save_variable_to_file(version, opts.file, opts.fileformat)
 
     # list instantiated slices
     def slices(self, opts, args):
 
     # list instantiated slices
     def slices(self, opts, args):
@@ -891,7 +954,15 @@ class Sfi:
             delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
             creds.append(delegated_cred)
         if opts.rspec_version:
             delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
             creds.append(delegated_cred)
         if opts.rspec_version:
-            call_options['rspec_version'] = opts.rspec_version 
+            version_manager = VersionManager()
+            server_version = self.get_cached_server_version(server)
+            if 'sfa' in server_version:
+                # just request the version the client wants 
+                call_options['rspec_version'] = version_manager.get_version(opts.rspec_version).to_dict()
+            else:
+                # this must be a protogeni aggregate. We should request a v2 ad rspec
+                # regardless of what the client user requested 
+                call_options['rspec_version'] = version_manager.get_version('ProtoGENI 2').to_dict()     
         #panos add info options
         if opts.info:
             call_options['info'] = opts.info 
         #panos add info options
         if opts.info:
             call_options['info'] = opts.info 
@@ -900,59 +971,58 @@ class Sfi:
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         result = server.ListResources(*call_args)
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         result = server.ListResources(*call_args)
-        format = opts.format
         if opts.file is None:
         if opts.file is None:
-            display_rspec(result, format)
+            display_rspec(result, opts.format)
         else:
         else:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_rspec_to_file(result, file)
+            save_rspec_to_file(result, opts.file)
         return
         return
-    
+
     # created named slice with given rspec
     def create(self, opts, args):
     # created named slice with given rspec
     def create(self, opts, args):
+        server = self.get_server_from_opts(opts)
+        server_version = self.get_cached_server_version(server)
         slice_hrn = args[0]
         slice_hrn = args[0]
-        slice_urn = hrn_to_urn(slice_hrn, 'slice') 
+        slice_urn = hrn_to_urn(slice_hrn, 'slice')
         user_cred = self.get_user_cred()
         slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
         user_cred = self.get_user_cred()
         slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
-        creds = [slice_cred]
-        if opts.delegate:
-            delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
-            creds.append(delegated_cred)
+        # delegate the cred to the callers root authority
+        delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority)+'.slicemanager')
+        #delegated_cred = self.delegate_cred(slice_cred, get_authority(slice_hrn))
+        #creds.append(delegated_cred)
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
+        # need to pass along user keys to the aggregate.
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
-        #    keys: [<ssh key A>, <ssh key B>] 
+        #    keys: [<ssh key A>, <ssh key B>]
         #  }]
         users = []
         #  }]
         users = []
-        server = self.get_server_from_opts(opts)
-        version = server.GetVersion()
-        if 'sfa' not in version:
-            # need to pass along user keys if this request is going to a ProtoGENI aggregate 
-            # ProtoGeni Aggregates will only install the keys of the user that is issuing the
-            # request. So we will only pass in one user that contains the keys for all
-            # users of the slice 
-            user = {'urn': user_cred.get_gid_caller().get_urn(),
-                    'keys': []}
-            slice_record = self.registry.Resolve(slice_urn, creds)
-            if slice_record and 'researchers' in slice_record:
-                user_hrns = slice_record['researchers']
-                user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] 
-                user_records = self.registry.Resolve(user_urns, creds)
-                for user_record in user_records:
-                    if 'keys' in user_record:
-                        user['keys'].extend(user_record['keys'])
-            users.append(user)
-
+        slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+        if slice_records and 'researcher' in slice_records[0] and slice_records[0]['researcher']!=[]:
+            slice_record = slice_records[0]
+            user_hrns = slice_record['researcher']
+            user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+            user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+
+            if 'sfa' not in server_version:
+                users = pg_users_arg(user_records)
+                rspec = RSpec(rspec)
+                rspec.filter({'component_manager_id': server_version['urn']})
+                rspec = RSpecConverter.to_pg_rspec(rspec.toxml(), content_type='request')
+                creds = [slice_cred]
+            else:
+                users = sfa_users_arg(user_records, slice_record)
+                creds = [slice_cred, delegated_cred]
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
-             
-        result =  server.CreateSliver(*call_args)
-        print result
+           
+        result = server.CreateSliver(*call_args)
+        if opts.file is None:
+            print result
+        else:
+            save_rspec_to_file (result, opts.file)
         return result
 
     # get a ticket for the specified slice
         return result
 
     # get a ticket for the specified slice
@@ -1089,7 +1159,10 @@ class Sfi:
         call_args = [slice_urn, creds]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         call_args = [slice_urn, creds]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
-        print server.SliverStatus(*call_args)
+        result = server.SliverStatus(*call_args)
+        print result
+        if opts.file:
+            save_variable_to_file(result, opts.file, opts.fileformat)
 
 
     def shutdown(self, opts, args):
 
 
     def shutdown(self, opts, args):
@@ -1139,6 +1212,7 @@ class Sfi:
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
+            raise
             sys.exit(1)
     
         return
             sys.exit(1)
     
         return