Merge branch 'master' of ssh://git.planet-lab.org/git/sfa
[sfa.git] / sfa / client / sfi.py
index 6e0c82c..293e5bd 100755 (executable)
@@ -11,10 +11,12 @@ import socket
 import random
 import datetime
 import zlib
 import random
 import datetime
 import zlib
+import codecs
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
 from optparse import OptionParser
 from lxml import etree
 from StringIO import StringIO
 from types import StringTypes, ListType
 from optparse import OptionParser
+
 from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
 from sfa.util.sfalogging import sfi_logger
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.gid import GID
@@ -108,13 +110,17 @@ def save_record_to_file(filename, record):
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
     else:
         record = SfaRecord(dict=record)
     str = record.save_to_string()
-    file(filename, "w").write(str)
+    f=codecs.open(filename, encoding='utf-8',mode="w")
+    f.write(str)
+    f.close()
     return
 
 
 # load methods
 def load_record_from_file(filename):
     return
 
 
 # load methods
 def load_record_from_file(filename):
-    str = file(filename, "r").read()
+    f=codecs.open(filename, encoding="utf-8", mode="r")
+    str = f.read()
+    f.close()
     record = SfaRecord(string=str)
     return record
 
     record = SfaRecord(string=str)
     return record
 
@@ -136,6 +142,8 @@ class Sfi:
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
         for opt in Sfi.required_options:
             if not hasattr(options,opt): setattr(options,opt,None)
         if not hasattr(options,'sfi_dir'): options.sfi_dir=os.path.expanduser("~/.sfi/")
+        # xxx oops, this is dangerous, sounds like ww sometimes have discrepency
+        # would be safer to remove self.sfi_dir altogether
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
         self.sfi_dir = options.sfi_dir
         self.options = options
         self.slicemgr = None
@@ -154,6 +162,7 @@ class Sfi:
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
                   "update": "record",
                   "aggregates": "[name]",
                   "registries": "[name]",
+                  "create_gid": "[name]",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
                   "get_gid": [],  
                   "get_trusted_certs": "cred",
                   "slices": "",
@@ -217,7 +226,7 @@ class Sfi:
                                 help="optional component information", default=None)
 
 
                                 help="optional component information", default=None)
 
 
-        if command in ("resources", "show", "list"):
+        if command in ("resources", "show", "list", "create_gid"):
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
         
            parser.add_option("-o", "--output", dest="file",
                             help="output XML to file", metavar="FILE", default=None)
         
@@ -275,7 +284,7 @@ class Sfi:
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
         parser.add_option("-k", "--hashrequest",
                          action="store_true", dest="hashrequest", default=False,
                          help="Create a hash of the request that will be authenticated on the server")
-        parser.add_option("-t", "--timeout", dest="timeout", default=30,
+        parser.add_option("-t", "--timeout", dest="timeout", default=None,
                          help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
                          help="Amout of time tom wait before timing out the request")
         parser.disable_interspersed_args()
 
@@ -283,7 +292,7 @@ class Sfi:
         
 
     def read_config(self):
         
 
     def read_config(self):
-       config_file = self.options.sfi_dir + os.sep + "sfi_config"
+       config_file = os.path.join(self.options.sfi_dir,"sfi_config")
        try:
           config = Config (config_file)
        except:
        try:
           config = Config (config_file)
        except:
@@ -360,7 +369,7 @@ class Sfi:
         # check local cache first
         cache = None
         version = None 
         # check local cache first
         cache = None
         version = None 
-        cache_file = self.sfi_dir + os.path.sep + 'sfi_cache.dat'
+        cache_file = os.path.join(self.options.sfi_dir,'sfi_cache.dat')
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
         cache_key = server.url + "-version"
         try:
             cache = Cache(cache_file)
@@ -645,7 +654,22 @@ class Sfi:
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
   
     def dispatch(self, command, cmd_opts, cmd_args):
         return getattr(self, command)(cmd_opts, cmd_args)
+
+    def create_gid(self, opts, args):
+        if len(args) < 1:
+            self.print_help()
+            sys.exit(1)
+        target_hrn = args[0]
+        user_cred = self.get_user_cred().save_to_string(save_parents=True)
+        gid = self.registry.CreateGid(user_cred, target_hrn, self.cert.save_to_string())
+        if opts.file:
+            filename = opts.file
+        else:
+            filename = os.sep.join([self.sfi_dir, '%s.gid' % target_hrn])
+        self.logger.info("writing %s gid to %s" % (target_hrn, filename))
+        GID(string=gid).save_to_file(filename)
+         
+     
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
     # list entires in named authority registry
     def list(self, opts, args):
         if len(args)!= 1:
@@ -678,6 +702,7 @@ class Sfi:
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         records = self.registry.Resolve(hrn, user_cred)
         hrn = args[0]
         user_cred = self.get_user_cred().save_to_string(save_parents=True)
         records = self.registry.Resolve(hrn, user_cred)
+        print records
         records = filter_records(opts.type, records)
         if not records:
             print "No record of type", opts.type
         records = filter_records(opts.type, records)
         if not records:
             print "No record of type", opts.type
@@ -696,7 +721,6 @@ class Sfi:
                 record.dump()  
             else:
                 print record.save_to_string() 
                 record.dump()  
             else:
                 print record.save_to_string() 
         if opts.file:
             file = opts.file
             if not file.startswith(os.sep):
         if opts.file:
             file = opts.file
             if not file.startswith(os.sep):
@@ -926,6 +950,8 @@ class Sfi:
     
     # created named slice with given rspec
     def create(self, opts, args):
     
     # created named slice with given rspec
     def create(self, opts, args):
+        server = self.get_server_from_opts(opts)
+        server_version = self.get_cached_server_version(server)
         slice_hrn = args[0]
         slice_urn = hrn_to_urn(slice_hrn, 'slice') 
         user_cred = self.get_user_cred()
         slice_hrn = args[0]
         slice_urn = hrn_to_urn(slice_hrn, 'slice') 
         user_cred = self.get_user_cred()
@@ -937,35 +963,48 @@ class Sfi:
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
         rspec_file = self.get_rspec_file(args[1])
         rspec = open(rspec_file).read()
 
+        # need to pass along user keys to the aggregate.  
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         #    keys: [<ssh key A>, <ssh key B>] 
         #  }]
         users = []
         # users = [
         #  { urn: urn:publicid:IDN+emulab.net+user+alice
         #    keys: [<ssh key A>, <ssh key B>] 
         #  }]
         users = []
-        server = self.get_server_from_opts(opts)
-        version = self.get_cached_server_version(server)
-        if 'sfa' not in version:
-            # need to pass along user keys if this request is going to a ProtoGENI aggregate 
+        all_keys = []
+        all_key_ids = []
+        slice_records = self.registry.Resolve(slice_urn, [user_cred.save_to_string(save_parents=True)])
+        if slice_records and 'researcher' in slice_records[0]:
+            slice_record = slice_records[0]
+            user_hrns = slice_record['researcher']
+            user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns]
+            user_records = self.registry.Resolve(user_urns, [user_cred.save_to_string(save_parents=True)])
+            for user_record in user_records:
+                if user_record['type'] != 'user':
+                    continue
+                #user = {'urn': user_cred.get_gid_caller().get_urn(),'keys': []}
+                user = {'urn': user_cred.get_gid_caller().get_urn(), #
+                        'keys': user_record['keys'],
+                        'email': user_record['email'], #  needed for MyPLC
+                        'person_id': user_record['person_id'], # needed for MyPLC
+                        'first_name': user_record['first_name'], # needed for MyPLC
+                        'last_name': user_record['last_name'], # needed for MyPLC
+                        'slice_record': slice_record, # needed for legacy refresh peer
+                        'key_ids': user_record['key_ids'] # needed for legacy refresh peer
+                } 
+                users.append(user)
+                all_keys.extend(user_record['keys'])
+                all_key_ids.extend(user_record['key_ids'])
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
             # ProtoGeni Aggregates will only install the keys of the user that is issuing the
-            # request. So we will only pass in one user that contains the keys for all
-            # users of the slice 
-            user = {'urn': user_cred.get_gid_caller().get_urn(),
-                    'keys': []}
-            slice_record = self.registry.Resolve(slice_urn, creds)
-            if slice_record and 'researchers' in slice_record:
-                user_hrns = slice_record['researchers']
-                user_urns = [hrn_to_urn(hrn, 'user') for hrn in user_hrns] 
-                user_records = self.registry.Resolve(user_urns, creds)
-                for user_record in user_records:
-                    if 'keys' in user_record:
-                        user['keys'].extend(user_record['keys'])
-            users.append(user)
+            # request. So we will add all to the current caller's list of keys
+            if 'sfa' not in server_version:
+                for user in users:
+                    if user['urn'] == user_cred.get_gid_caller().get_urn():
+                        user['keys'] = all_keys  
 
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
              
 
         call_args = [slice_urn, creds, rspec, users]
         if self.server_supports_call_id_arg(server):
             call_args.append(unique_call_id())
              
-        result =  server.CreateSliver(*call_args)
+        result = server.CreateSliver(*call_args)
         print result
         return result
 
         print result
         return result
 
@@ -1153,6 +1192,7 @@ class Sfi:
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
             self.dispatch(command, cmd_opts, cmd_args)
         except KeyError:
             self.logger.critical ("Unknown command %s"%command)
+            raise
             sys.exit(1)
     
         return
             sys.exit(1)
     
         return