don't resolve user records if there aren't any users
[sfa.git] / sfa / client / sfi.py
old mode 100755 (executable)
new mode 100644 (file)
index 094c6be..726138f
@@ -11,10 +11,12 @@ import socket
 import random
 import datetime
 import zlib
 import random
 import datetime
 import zlib
+import codecs
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
 from optparse import OptionParser
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
 from optparse import OptionParser
+
 from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
@@ -81,7 +83,6 @@ def filter_records(type, records):
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
-
     f = open(filename, 'w')
     f.write(rspec)
     f.close()
     f = open(filename, 'w')
     f.write(rspec)
     f.close()
@@ -108,13 +109,17 @@ def save_record_to_file(filename, record):
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
-    file(filename, "w").write(str)
+    f=codecs.open(filename, encoding='utf-8',mode="w")
+    f.write(str)
+    f.close()
     return
 
 
 # load methods
 def load_record_from_file(filename):
     return
 
 
 # load methods
 def load_record_from_file(filename):
-    str = file(filename, "r").read()
+    f=codecs.open(filename, encoding="utf-8", mode="r")
+    str = f.read()
+    f.close()
     record = SfaRecord(string=str)
     return record
 
     record = SfaRecord(string=str)
     return record
 
@@ -136,6 +141,8 @@ class Sfi:
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+        # xxx oops, this is dangerous, sounds like ww sometimes have discrepency
+        # would be safer to remove self.sfi_dir altogether
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
@@ -154,6 +161,7 @@ class Sfi:
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
+                  "create_gid": "[name]",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
@@ -217,7 +225,8 @@ class Sfi:
                                 help="optional component information", default=None)
 
 
                                 help="optional component information", default=None)
 
 
-        if command in ("resources", "show", "list"):
+        # 'create' does return the new rspec, makes sense to save that too
+        if command in ("resources", "show", "list", "create_gid", 'create'):
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
         
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
         
@@ -275,13 +284,15 @@ class Sfi:
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
+        parser.add_option("-t", "--timeout", dest="timeout", default=None,
+                         help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
         return parser
         
 
     def read_config(self):
         parser.disable_interspersed_args()
 
         return parser
         
 
     def read_config(self):
-       config_file = self.options.sfi_dir + os.sep + "sfi_config"
+       config_file = os.path.join(self.options.sfi_dir,"sfi_config")
        try:
           config = Config (config_file)
        except:
        try:
           config = Config (config_file)
        except:
@@ -349,16 +360,16 @@ class Sfi:
        self.cert_file = cert_file
        self.cert = GID(filename=cert_file)
        self.logger.info("Contacting Registry at: %s"%self.reg_url)
        self.cert_file = cert_file
        self.cert = GID(filename=cert_file)
        self.logger.info("Contacting Registry at: %s"%self.reg_url)
-       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)  
+       self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)  
        self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
        self.logger.info("Contacting Slice Manager at: %s"%self.sm_url)
-       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, self.options)
+       self.slicemgr = xmlrpcprotocol.get_server(self.sm_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
        return
 
     def get_cached_server_version(self, server):
         # check local cache first
         cache = None
         version = None 
        return
 
     def get_cached_server_version(self, server):
         # check local cache first
         cache = None
         version = None 
-        cache_file = self.sfi_dir + os.path.sep + 'sfi_cache.dat'
+        cache_file = os.path.join(self.options.sfi_dir,'sfi_cache.dat')
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
@@ -444,7 +455,7 @@ class Sfi:
             self.logger.info("Getting Registry issued cert")
             self.read_config()
             # *hack.  need to set registyr before _get_gid() is called 
             self.logger.info("Getting Registry issued cert")
             self.read_config()
             # *hack.  need to set registyr before _get_gid() is called 
-            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, self.options)
+            self.registry = xmlrpcprotocol.get_server(self.reg_url, key_file, cert_file, timeout=self.options.timeout, verbose=self.options.debug)
             gid = self._get_gid(type='user')
             self.registry = None 
             self.logger.info("Writing certificate to %s"%cert_file)
             gid = self._get_gid(type='user')
             self.registry = None 
             self.logger.info("Writing certificate to %s"%cert_file)
@@ -618,7 +629,7 @@ class Sfi:
         host_parts = host.split('/')
         host_parts[0] = host_parts[0] + ":" + str(port)
         url =  "http://%s" %  "/".join(host_parts)    
         host_parts = host.split('/')
         host_parts[0] = host_parts[0] + ":" + str(port)
         url =  "http://%s" %  "/".join(host_parts)    
-        return xmlrpcprotocol.get_server(url, keyfile, certfile, self.options)
+        return xmlrpcprotocol.get_server(url, keyfile, certfile, timeout=self.options.timeout, verbose=self.options.debug)
 
     # xxx opts could be retrieved in self.options
     def get_server_from_opts(self, opts):
 
     # xxx opts could be retrieved in self.options
     def get_server_from_opts(self, opts):
@@ -643,7 +654,22 @@ class Sfi:
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
+
+    def create_gid(self, opts, args):
+        if len(args) < 1:
+            self.print_help()
+            sys.exit(1)
+        target_hrn = args[0]
+        user_cred = self.get_user_cred().save_to_string(save_parents=True)
+        gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+        if opts.file:
+            filename = opts.file
+        else:
+            filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+        self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+        GID(string=gid).save_to_file(filename)
+         
+     
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
@@ -662,10 +688,7 @@ class Sfi:
         for record in list:
             print "%s (%s)" % (record['hrn'], record['type'])     
         if opts.file:
         for record in list:
             print "%s (%s)" % (record['hrn'], record['type'])     
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, list)
+            save_records_to_file(opts.file, list)
         return
     
     # show named registry record
         return
     
     # show named registry record
@@ -676,6 +699,7 @@ class Sfi:
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         records = self.registry.Resolve(hrn, user_cred)
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         records = self.registry.Resolve(hrn, user_cred)
+        print records
         records = filter_records(opts.type, records)
         if not records:
             print "No record of type", opts.type
         records = filter_records(opts.type, records)
         if not records:
             print "No record of type", opts.type
@@ -694,12 +718,8 @@ class Sfi:
                 record.dump()  
             else:
                 print record.save_to_string() 
                 record.dump()  
             else:
                 print record.save_to_string() 
         if opts.file:
         if opts.file:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_records_to_file(file, records)
+            save_records_to_file(opts.file, records)
         return
     
     def delegate(self, opts, args):
         return
     
     def delegate(self, opts, args):
@@ -912,18 +932,16 @@ class Sfi:
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         result = server.ListResources(*call_args)
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
         result = server.ListResources(*call_args)
-        format = opts.format
         if opts.file is None:
         if opts.file is None:
-            display_rspec(result, format)
+            display_rspec(result, opts.format)
         else:
         else:
-            file = opts.file
-            if not file.startswith(os.sep):
-                file = os.path.join(self.options.sfi_dir, file)
-            save_rspec_to_file(result, file)
+            save_rspec_to_file(result, opts.file)
         return
     
     # created named slice with given rspec
     def create(self, opts, args):
         return
     
     # created named slice with given rspec
     def create(self, opts, args):
+        server = self.get_server_from_opts(opts)
+        server_version = self.get_cached_server_version(server)
         slice_hrn = args[0]
         slice_urn = hrn_to_urn(slice_hrn, 'slice') 
         user_cred = self.get_user_cred()
         slice_hrn = args[0]
         slice_urn = hrn_to_urn(slice_hrn, 'slice') 
         user_cred = self.get_user_cred()
@@ -935,36 +953,52 @@ class Sfi:
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
+        # need to pass along user keys to the aggregate.  
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         #    keys: [<ssh key A>, <ssh key B>] 
         #  }]
         users = []
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         #    keys: [<ssh key A>, <ssh key B>] 
         #  }]
         users = []
-        server = self.get_server_from_opts(opts)
-        version = self.get_cached_server_version(server)
-        if 'sfa' not in version:
-            # need to pass along user keys if this request is going to a ProtoGENI aggregate 
+        all_keys = []
+        all_key_ids = []
+        slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+        if slice_records and 'researcher' in slice_records[0] and slice_records[0]['researcher']!=[]:
+            slice_record = slice_records[0]
+            user_hrns = slice_record['researcher']
+            user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+            user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+            for user_record in user_records:
+                if user_record['type'] != 'user':
+                    continue
+                #user = {'urn': user_cred.get_gid_caller().get_urn(),'keys': []}
+                user = {'urn': user_cred.get_gid_caller().get_urn(), #
+                        'keys': user_record['keys'],
+                        'email': user_record['email'], #  needed for MyPLC
+                        'person_id': user_record['person_id'], # needed for MyPLC
+                        'first_name': user_record['first_name'], # needed for MyPLC
+                        'last_name': user_record['last_name'], # needed for MyPLC
+                        'slice_record': slice_record, # needed for legacy refresh peer
+                        'key_ids': user_record['key_ids'] # needed for legacy refresh peer
+                } 
+                users.append(user)
+                all_keys.extend(user_record['keys'])
+                all_key_ids.extend(user_record['key_ids'])
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
-            # request. So we will only pass in one user that contains the keys for all
-            # users of the slice 
-            user = {'urn': user_cred.get_gid_caller().get_urn(),
-                    'keys': []}
-            slice_record = self.registry.Resolve(slice_urn, creds)
-            if slice_record and 'researchers' in slice_record:
-                user_hrns = slice_record['researchers']
-                user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] 
-                user_records = self.registry.Resolve(user_urns, creds)
-                for user_record in user_records:
-                    if 'keys' in user_record:
-                        user['keys'].extend(user_record['keys'])
-            users.append(user)
+            # request. So we will add all to the current caller's list of keys
+            if 'sfa' not in server_version:
+                for user in users:
+                    if user['urn'] == user_cred.get_gid_caller().get_urn():
+                        user['keys'] = all_keys  
 
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
              
 
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
              
-        result =  server.CreateSliver(*call_args)
-        print result
+        result = server.CreateSliver(*call_args)
+        if opts.file is None:
+            print result
+        else:
+            save_rspec_to_file (result, opts.file)
         return result
 
     # get a ticket for the specified slice
         return result
 
     # get a ticket for the specified slice
@@ -1151,6 +1185,7 @@ class Sfi:
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
+            raise
             sys.exit(1)
     
         return
             sys.exit(1)
     
         return