use registry issued cert instead of self signed cert
[sfa.git] / sfa / client / sfi.py
index 1c49dc9..6114c55 100755 (executable)
@@ -23,9 +23,9 @@ from sfa.trust.credential import Credential
 from sfa.util.sfaticket import SfaTicket
 from sfa.util.record import SfaRecord, UserRecord, SliceRecord, NodeRecord, AuthorityRecord
 from sfa.util.xrn import Xrn, get_leaf, get_authority, hrn_to_urn
-from sfa.util.xmlrpcprotocol import ServerException
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
+from sfa.util.version import version_core
 
 AGGREGATE_PORT=12346
 CM_PORT=12346
@@ -117,15 +117,28 @@ def load_record_from_file(filename):
     return record
 
 
+import uuid
+def unique_call_id(): return uuid.uuid4().urn
 
 class Sfi:
+    
+    required_options=['verbose',  'debug',  'registry',  'sm',  'auth',  'user']
+
+    # dummy to meet Sfi's expectations for its 'options' field
+    # i.e. s/t we can do setattr on
+    class DummyOptions:
+        pass
 
-    def __init__ (self):
+    def __init__ (self,options=None):
+        if options is None: options=Sfi.DummyOptions()
+        for opt in Sfi.required_options:
+            if not hasattr(options,opt): setattr(options,opt,None)
+        if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+        self.options = options
         self.slicemgr = None
         self.registry = None
         self.user = None
         self.authority = None
-        self.options = None
         self.hashrequest = False
         sfa_logger_goes_to_console()
         self.logger=sfa_logger()
@@ -191,6 +204,8 @@ class Sfi:
                             default="all")
         # display formats
         if command in ("resources"):
+            parser.add_option("-r", "--rspec-version", dest="rspec_version", default="sfa 1",
+                              help="schema type and version of resulting RSpec")
             parser.add_option("-f", "--format", dest="format", type="choice",
                              help="display format ([xml]|dns|ip)", default="xml",
                              choices=("xml", "dns", "ip"))
@@ -212,9 +227,16 @@ class Sfi:
                             help="delegate slice credential", metavar="HRN", default=None)
         
         if command in ("version"):
+            parser.add_option("-a", "--aggregate", dest="aggregate",
+                             default=None, help="aggregate host")
+            parser.add_option("-p", "--port", dest="port",
+                             default=AGGREGATE_PORT, help="aggregate port")
             parser.add_option("-R","--registry-version",
-                              action="store_true", dest="probe_registry", default=False,
+                              action="store_true", dest="version_registry", default=False,
                               help="probe registry version instead of slicemgr")
+            parser.add_option("-l","--local",
+                              action="store_true", dest="version_local", default=False,
+                              help="display version of the local client")
 
         return parser
 
@@ -251,10 +273,7 @@ class Sfi:
         return parser
         
  
-    #
-    # Establish Connection to SliceMgr and Registry Servers
-    #
-    def set_servers(self):
+    def read_config(self):
        config_file = self.options.sfi_dir + os.sep + "sfi_config"
        try:
           config = Config (config_file)
@@ -270,18 +289,18 @@ class Sfi:
        errors = 0
        # Set SliceMgr URL
        if (self.options.sm is not None):
-          sm_url = self.options.sm
+          self.sm_url = self.options.sm
        elif hasattr(config, "SFI_SM"):
-          sm_url = config.SFI_SM
+          self.sm_url = config.SFI_SM
        else:
           self.logger.error("You need to set e.g. SFI_SM='http://your.slicemanager.url:12347/' in %s" % config_file)
           errors += 1 
     
        # Set Registry URL
        if (self.options.registry is not None):
-          reg_url = self.options.registry
+          self.reg_url = self.options.registry
        elif hasattr(config, "SFI_REGISTRY"):
-          reg_url = config.SFI_REGISTRY
+          self.reg_url = config.SFI_REGISTRY
        else:
           self.logger.errors("You need to set e.g. SFI_REGISTRY='http://your.registry.url:12345/' in %s" % config_file)
           errors += 1 
@@ -307,20 +326,26 @@ class Sfi:
     
        if errors:
           sys.exit(1)
-    
-    
+
+
+    #
+    # Establish Connection to SliceMgr and Registry Servers
+    #
+    def set_servers(self):
+
+       self.read_config() 
        # Get key and certificate
        key_file = self.get_key_file()
        cert_file = self.get_cert_file(key_file)
        self.key = Keypair(filename=key_file) 
        self.key_file = key_file
        self.cert_file = cert_file
-       self.cert = Certificate(filename=cert_file) 
+       self.cert = GID(filename=cert_file) 
        # Establish connection to server(s)
-       self.logger.info("Contacting Registry at: %s"%reg_url)
-       self.registry = xmlrpcprotocol.get_server(reg_url, key_file, cert_file, self.options)  
-       self.logger.info("Contacting Slice Manager at: %s"%sm_url)
-       self.slicemgr = xmlrpcprotocol.get_server(sm_url, key_file, cert_file, self.options)
+       self.logger.info("Contacting Registry at: %s"%self.reg_url)
+       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)  
+       self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
+       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, self.options)
 
        return
     
@@ -349,27 +374,36 @@ class Sfi:
     
     def get_cert_file(self, key_file):
     
-        file = os.path.join(self.options.sfi_dir, self.user.replace(self.authority + '.', '') + ".cert")
-        if (os.path.isfile(file)):
-            # use existing cert if it exists                     
-            return file
-        else:
-            try:
-                # attempt to use gid as the cert.  
-                gid = self._get_gid()
-                self.logger.info("Writing certificate to %s"%file)
-                gid.save_to_file(file) 
-            except:
-                # generate self signed certificate
-                k = Keypair(filename=key_file)
-                cert = Certificate(subject=self.user)
-                cert.set_pubkey(k)
-                cert.set_issuer(k, self.user)
-                cert.sign()
-                self.logger.info("Writing self-signed certificate to %s"%file)
-                cert.save_to_file(file)
-            
-            return file
+        cert_file = os.path.join(self.options.sfi_dir, self.user.replace(self.authority + '.', '') + ".cert")
+        if (os.path.isfile(cert_file)):
+            # we'd perfer to use Registry issued certs instead of self signed certs. 
+            # if this is a Registry cert (GID) then we are done 
+            gid = GID(filename=cert_file)
+            if gid.get_urn():
+                return cert_file
+
+        # generate self signed certificate
+        k = Keypair(filename=key_file)
+        cert = Certificate(subject=self.user)
+        cert.set_pubkey(k)
+        cert.set_issuer(k, self.user)
+        cert.sign()
+        self.logger.info("Writing self-signed certificate to %s"%cert_file)
+        cert.save_to_file(cert_file)
+        # try to get registry issued cert
+        try:
+            self.logger.info("Getting Registry issued cert")
+            self.read_config()
+            # *hack.  need to set registyr before _get_gid() is called 
+            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+            gid = self._get_gid(type='user')
+            self.registry = None 
+            self.logger.info("Writing certificate to %s"%cert_file)
+            gid.save_to_file(cert_file)
+        except:
+            self.logger.info("Failed to download Registry issued cert")
+        return cert_file
 
     def get_cached_gid(self, file):
         """
@@ -392,7 +426,7 @@ class Sfi:
         self.logger.debug("Sfi.get_gid-> %s",gid.save_to_string(save_parents=True))
         return gid
 
-    def _get_gid(self, hrn=None):
+    def _get_gid(self, hrn=None, type=None):
         """
         git_gid helper. Retrive the gid from the registry and save it to file.
         """
@@ -405,7 +439,12 @@ class Sfi:
         if not gid:
             user_cred = self.get_user_cred()
             records = self.registry.Resolve(hrn, user_cred.save_to_string(save_parents=True))
-            if not records:
+            record = None
+            if type:
+                for rec in records:
+                   if type == record['type']:
+                        record = rec 
+            if not record:
                 raise RecordNotFound(args[0])
             gid = GID(string=records[0]['gid'])
             self.logger.info("Writing gid to %s"%gidfile)
@@ -475,7 +514,7 @@ class Sfi:
        if (os.path.isfile(file)):
           return file
        else:
-          self.logger.critical("No such rspec file"%rspec)
+          self.logger.critical("No such rspec file %s"%rspec)
           sys.exit(1)
     
     def get_record_file(self, record):
@@ -554,7 +593,7 @@ class Sfi:
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
-            self.parser.print_help()
+            self.print_help()
             sys.exit(1)
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
@@ -578,7 +617,7 @@ class Sfi:
     # show named registry record
     def show(self, opts, args):
         if len(args)!= 1:
-            self.parser.print_help()
+            self.print_help()
             sys.exit(1)
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
@@ -662,7 +701,7 @@ class Sfi:
     def remove(self, opts, args):
         auth_cred = self.get_auth_cred().save_to_string(save_parents=True)
         if len(args)!=1:
-            self.parser.print_help()
+            self.print_help()
             sys.exit(1)
         hrn = args[0]
         type = opts.type 
@@ -674,7 +713,7 @@ class Sfi:
     def add(self, opts, args):
         auth_cred = self.get_auth_cred().save_to_string(save_parents=True)
         if len(args)!=1:
-            self.parser.print_help()
+            self.print_help()
             sys.exit(1)
         record_filepath = args[0]
         rec_file = self.get_record_file(record_filepath)
@@ -685,7 +724,7 @@ class Sfi:
     def update(self, opts, args):
         user_cred = self.get_user_cred()
         if len(args)!=1:
-            self.parser.print_help()
+            self.print_help()
             sys.exit(1)
         rec_file = self.get_record_file(args[0])
         record = load_record_from_file(rec_file)
@@ -697,7 +736,7 @@ class Sfi:
         elif record['type'] in ["slice"]:
             try:
                 cred = self.get_slice_cred(record.get_name()).save_to_string(save_parents=True)
-            except ServerException, e:
+            except xmlrpcprotocol.ServerException, e:
                # XXX smbaker -- once we have better error return codes, update this
                # to do something better than a string compare
                if "Permission error" in e.args[0]:
@@ -755,11 +794,15 @@ class Sfi:
     
 
     def version(self, opts, args):
-        if opts.probe_registry:
-            server=self.registry
+        if opts.version_local:
+            version=version_core()
         else:
-            server = self.get_server_from_opts(opts)
-        for (k,v) in server.GetVersion().items():
+            if opts.version_registry:
+                server=self.registry
+            else:
+                server = self.get_server_from_opts(opts)
+            version=server.GetVersion()
+        for (k,v) in version.iteritems():
             print "%-20s: %s"%(k,v)
 
     # list instantiated slices
@@ -773,6 +816,7 @@ class Sfi:
             delegated_cred = self.delegate_cred(user_cred, get_authority(self.authority))
             creds.append(delegated_cred)  
         server = self.get_server_from_opts(opts)
+        #results = server.ListSlices(creds, unique_call_id())
         results = server.ListSlices(creds)
         display_list(results)
         return
@@ -795,11 +839,14 @@ class Sfi:
         creds = [cred]
         if opts.delegate:
             delegated_cred = self.delegate_cred(cred, get_authority(self.authority))
-            creds.append(delegated_cred) 
-        result = server.ListResources(creds, call_options)
+            creds.append(delegated_cred)
+        if opts.rspec_version:
+            call_options['rspec_version'] = opts.rspec_version 
+        result = server.ListResources(creds, call_options,unique_call_id())
         format = opts.format
-        display_rspec(result, format)
-        if (opts.file is not None):
+        if opts.file is None:
+            display_rspec(result, format)
+        else:
             file = opts.file
             if not file.startswith(os.sep):
                 file = os.path.join(self.options.sfi_dir, file)
@@ -818,8 +865,17 @@ class Sfi:
             creds.append(delegated_cred)
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
+
+        # TODO: need to determine if this request is going to a ProtoGENI aggregate. If so
+        # we need to obtain the keys for all users in the slice  
+        # e.g. 
+        # users = [
+        #  { urn: urn:publicid:IDN+emulab.net+user+alice
+        #    keys: [<ssh key A>, <ssh key B>] 
+        #  }]
+        users = []
         server = self.get_server_from_opts(opts)
-        result =  server.CreateSliver(slice_urn, creds, rspec, [])
+        result =  server.CreateSliver(slice_urn, creds, rspec, users, unique_call_id())
         print result
         return result
 
@@ -886,7 +942,7 @@ class Sfi:
             delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
             creds.append(delegated_cred)
         server = self.get_server_from_opts(opts)
-        return server.DeleteSliver(slice_urn, creds)
+        return server.DeleteSliver(slice_urn, creds, unique_call_id())
     
     # start named slice
     def start(self, opts, args):
@@ -934,7 +990,7 @@ class Sfi:
             delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
             creds.append(delegated_cred)
         time = args[1]
-        return server.RenewSliver(slice_urn, creds, time)
+        return server.RenewSliver(slice_urn, creds, time, unique_call_id())
 
 
     def status(self, opts, args):
@@ -946,7 +1002,7 @@ class Sfi:
             delegated_cred = self.delegate_cred(slice_cred, get_authority(self.authority))
             creds.append(delegated_cred)
         server = self.get_server_from_opts(opts)
-        print server.SliverStatus(slice_urn, creds)
+        print server.SliverStatus(slice_urn, creds, unique_call_id())
 
 
     def shutdown(self, opts, args):
@@ -960,13 +1016,16 @@ class Sfi:
         server = self.get_server_from_opts(opts)
         return server.Shutdown(slice_urn, creds)         
     
+    def print_help (self):
+        self.sfi_parser.print_help()
+        self.cmd_parser.print_help()
 
     #
     # Main: parse arguments and dispatch to command
     #
     def main(self):
-        parser = self.create_parser()
-        (options, args) = parser.parse_args()
+        self.sfi_parser = self.create_parser()
+        (options, args) = self.sfi_parser.parse_args()
         self.options = options
 
         self.logger.setLevelFromOptVerbose(self.options.verbose)
@@ -978,8 +1037,8 @@ class Sfi:
             return -1
     
         command = args[0]
-        self.parser = self.create_cmd_parser(command)
-        (cmd_opts, cmd_args) = self.parser.parse_args(args[1:])
+        self.cmd_parser = self.create_cmd_parser(command)
+        (cmd_opts, cmd_args) = self.cmd_parser.parse_args(args[1:])
 
         self.set_servers()