moved utility methods out of Sfi class. added get_ticket method
authorTony Mack <tmack@cs.princeton.edu>
Fri, 20 Nov 2009 01:02:44 +0000 (01:02 +0000)
committerTony Mack <tmack@cs.princeton.edu>
Fri, 20 Nov 2009 01:02:44 +0000 (01:02 +0000)
sfa/client/sfi.py

index abc5bc2..06b8f05 100755 (executable)
@@ -12,12 +12,114 @@ from optparse import OptionParser
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.util.geniclient import GeniClient
+from sfa.util.sfaticket import SfaTicket
 from sfa.util.record import *
+from sfa.util.misc import *
 from sfa.util.rspec import RSpec
 from sfa.util.xmlrpcprotocol import ServerException
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 
+# utility methods here
+# display methods
+def display_rspec(rspec, format = 'rspec'):
+    if format in ['dns']:
+        spec = RSpec()
+        spec.parseString(rspec)
+        hostnames = []
+        nodespecs = spec.getDictsByTagName('NodeSpec')
+        for nodespec in nodespecs:
+            if nodespec.has_key('name') and nodespec['name']:
+                if isinstance(nodespec['name'], ListType):
+                    hostnames.extend(nodespec['name'])
+                elif isinstance(nodespec['name'], StringTypes):
+                     hostnames.append(nodespec['name'])
+        result = hostnames
+    elif format in ['ip']:
+        spec = RSpec()
+        spec.parseString(rspec)
+        ips = []
+        ifspecs = spec.getDictsByTagName('IfSpec')
+        for ifspec in ifspecs:
+            if ifspec.has_key('addr') and ifspec['addr']:
+                ips.append(ifspec['addr'])
+        result = ips
+    else:
+        result = rspec
+
+    print result
+    return
+
+def display_list(results):
+    for result in results:
+        print result
+
+
+def display_records(recordList, dump = False):
+    ''' Print all fields in the record'''
+    for record in recordList:
+        display_record(record, dump)
+
+def display_record(record, dump = False):
+    if dump:
+        record.dump()
+    else:
+        info = record.getdict()
+        print "%s (%s)" % (info['hrn'], info['type'])
+    return
+
+
+def filter_records(type, records):
+    filtered_records = []
+    for record in records:
+        if (record['type'] == type) or (type == "all"):
+            filtered_records.append(record)
+    return filtered_records
+
+
+# save methods
+def save_rspec_to_file(rspec, filename):
+    if not filename.endswith(".rspec"):
+        filename = filename + ".rspec"
+
+    f = open(filename, 'w')
+    f.write(rspec)
+    f.close()
+    return
+
+def save_records_to_file(filename, recordList):
+    index = 0
+    for record in recordList:
+        if index>0:
+            save_record_to_file(filename + "." + str(index), record)
+        else:
+            save_record_to_file(filename, record)
+        index = index + 1
+
+def save_record_to_file(filename, record):
+    if record['type'] in ['user']:
+        record = UserRecord(dict = record)
+    elif record['type'] in ['slice']:
+        record = SliceRecord(dict = record)
+    elif record['type'] in ['node']:
+        record = NodeRecord(dict = record)
+    elif record['type'] in ['authority', 'ma', 'sa']:
+        record = AuthorityRecord(dict = record)
+    else:
+        record = GeniRecord(dict = record)
+    str = record.save_to_string()
+    file(filename, "w").write(str)
+    return
+
+
+# load methods
+def load_record_from_file(filename):
+    str = file(filename, "r").read()
+    record = GeniRecord(string=str)
+    return record
+
+
+
 class Sfi:
     
     slicemgr = None
@@ -39,6 +141,7 @@ class Sfi:
                   "slices": "",
                   "resources": "[name]",
                   "create": "name rspec",
+                  "get_ticket": "name rspec",
                   "delete": "name",
                   "reset": "name",
                   "start": "name",
@@ -67,7 +170,7 @@ class Sfi:
             parser.add_option("-a", "--aggregate", dest="aggregate",
                              default=None, help="aggregate hrn")
 
-        if command in ("create"):
+        if command in ("create", "get_ticket"):
             parser.add_option("-a", "--aggregate", dest="aggregate",default=None,
                              help="aggregate hrn")
 
@@ -213,12 +316,9 @@ class Sfi:
     #   - bootstrap slice credential from user credential
     #
     
-    def get_leaf(self,name):
-       parts = name.split(".")
-       return parts[-1]
     
     def get_key_file(self):
-       file = os.path.join(self.options.sfi_dir, self.get_leaf(self.user) + ".pkey")
+       file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".pkey")
        if (os.path.isfile(file)):
           return file
        else:
@@ -228,7 +328,7 @@ class Sfi:
     
     def get_cert_file(self,key_file):
     
-       file = os.path.join(self.options.sfi_dir, self.get_leaf(self.user) + ".cert")
+       file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".cert")
        if (os.path.isfile(file)):
           return file
        else:
@@ -243,7 +343,7 @@ class Sfi:
           return file
    
     def get_gid(self):
-        file = os.path.join(self.options.sfi_dir, self.get_leaf(self.user) + ".gid")
+        file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".gid")
         if (os.path.isfile(file)):
             gid = GID(filename=file)
             return gid
@@ -260,7 +360,7 @@ class Sfi:
             return gid       
  
     def get_user_cred(self):
-        file = os.path.join(self.options.sfi_dir, self.get_leaf(self.user) + ".cred")
+        file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".cred")
         if (os.path.isfile(file)):
             user_cred = Credential(filename=file)
             return user_cred
@@ -280,13 +380,13 @@ class Sfi:
             else:
                print "Failed to get user credential"
                sys.exit(-1)
-    
+  
     def get_auth_cred(self):
         if not self.authority:
             print "no authority specified. Use -a or set SF_AUTH"
             sys.exit(-1)
     
-        file = os.path.join(self.options.sfi_dir, self.get_leaf("authority") +".cred")
+        file = os.path.join(self.options.sfi_dir, get_leaf("authority") +".cred")
         if (os.path.isfile(file)):
             auth_cred = Credential(filename=file)
             return auth_cred
@@ -308,7 +408,7 @@ class Sfi:
                 sys.exit(-1)
     
     def get_slice_cred(self,name):
-        file = os.path.join(self.options.sfi_dir, "slice_" + self.get_leaf(name) + ".cred")
+        file = os.path.join(self.options.sfi_dir, "slice_" + get_leaf(name) + ".cred")
         if (os.path.isfile(file)):
             slice_cred = Credential(filename=file)
             return slice_cred
@@ -340,7 +440,7 @@ class Sfi:
            
     
         records = self.registry.resolve(cred, hrn)
-        records = self.filter_records(type, records)
+        records = filter_records(type, records)
         
         if not records:
             raise Exception, "Error: Didn't find a %(type)s record for %(hrn)s" % locals()
@@ -431,11 +531,14 @@ class Sfi:
           
         # filter on person, slice, site, node, etc.  
         # THis really should be in the self.filter_records funct def comment...
-        list = self.filter_records(opts.type, list)
+        list = filter_records(opts.type, list)
         for record in list:
             print "%s (%s)" % (record['hrn'], record['type'])     
         if opts.file:
-            self.save_records_to_file(opts.file, list)
+            file = opts.file
+            if not file.startswith(os.sep):
+                file = os.path.join(self.options.sfi_dir, get_leaf(self.user) + ".gid")
+            save_records_to_file(file, list)
         return
     
     # show named registry record
@@ -446,7 +549,7 @@ class Sfi:
         if self.hashrequest:
             request_hash = self.key.compute_hash([user_cred, hrn])    
         records = self.registry.resolve(user_cred, hrn, request_hash)
-        records = self.filter_records(opts.type, records)
+        records = filter_records(opts.type, records)
         if not records:
             print "No record of type", opts.type
         for record in records:
@@ -466,7 +569,10 @@ class Sfi:
                 print record.save_to_string() 
        
         if opts.file:
-            self.save_records_to_file(opts.file, records)
+            file = opts.file
+            if not file.startswith(os.sep):
+                file = os.path.join(self.options.sfi_dir, file)
+            save_records_to_file(file, records)
         return
     
     def delegate(self,opts, args):
@@ -488,7 +594,7 @@ class Sfi:
            return
     
        records = self.registry.resolve(user_cred, args[0])
-       records = self.filter_records("user", records)
+       records = filter_records("user", records)
     
        if not records:
            print "Error: Didn't find a user record for", args[0]
@@ -514,11 +620,11 @@ class Sfi:
        dcred.sign()
     
        if opts.delegate_user:
-           dest_fn = os.path.join(self.options.sfi_dir, self.get_leaf(delegee_hrn) + "_" 
-                                  + self.get_leaf(object_hrn) + ".cred")
+           dest_fn = os.path.join(self.options.sfi_dir, get_leaf(delegee_hrn) + "_" 
+                                  + get_leaf(object_hrn) + ".cred")
        elif opts.delegate_slice:
-           dest_fn = os.path_join(self.options.sfi_dir, self.get_leaf(delegee_hrn) + "_slice_" 
-                                  + self.get_leaf(object_hrn) + ".cred")
+           dest_fn = os.path_join(self.options.sfi_dir, get_leaf(delegee_hrn) + "_slice_" 
+                                  + get_leaf(object_hrn) + ".cred")
     
        dcred.save_to_file(dest_fn, save_parents = True)
     
@@ -543,7 +649,7 @@ class Sfi:
         auth_cred = self.get_auth_cred().save_to_string(save_parents=True)
         record_filepath = args[0]
         rec_file = self.get_record_file(record_filepath)
-        record = self.load_record_from_file(rec_file).as_dict()
+        record = load_record_from_file(rec_file).as_dict()
         request_hash=None
         if self.hashrequest:
             arg_list = [auth_cred]
@@ -554,7 +660,7 @@ class Sfi:
     def update(self,opts, args):
         user_cred = self.get_user_cred()
         rec_file = self.get_record_file(args[0])
-        record = self.load_record_from_file(rec_file)
+        record = load_record_from_file(rec_file)
         if record['type'] == "user":
             if record.get_name() == user_cred.get_gid_object().get_hrn():
                 cred = user_cred.save_to_string(save_parents=True)
@@ -594,7 +700,7 @@ class Sfi:
             arg_list = [user_cred, hrn]  
             request_hash = self.key.compute_hash(arg_list)
         result = self.registry.get_aggregates(user_cred, hrn, request_hash)
-        self.display_list(result)
+        display_list(result)
         return 
 
     def registries(self, opts, args):
@@ -607,7 +713,7 @@ class Sfi:
             arg_list = [user_cred, hrn]  
             request_hash = self.key.compute_hash(arg_list)
         result = self.registry.get_registries(user_cred, hrn, request_hash)
-        self.display_list(result)
+        display_list(result)
         return
  
     #
@@ -624,7 +730,7 @@ class Sfi:
             arg_list = [user_cred]
             request_hash = self.key.compute_hash(arg_list)
         results = self.slicemgr.get_slices(user_cred, request_hash)
-        self.display_list(results)
+        display_list(results)
         return
     
     # show rspec for named slice
@@ -655,9 +761,12 @@ class Sfi:
         result = server.get_resources(cred, hrn, request_hash)
         format = opts.format
        
-        self.display_rspec(result, format)
+        display_rspec(result, format)
         if (opts.file is not None):
-            self.save_rspec_to_file(result, opts.file)
+            file = opts.file
+            if not file.startswith(os.sep):
+                file = os.path.join(self.options.sfi_dir, file)
+            save_rspec_to_file(result, file)
         return
     
     # created named slice with given rspec
@@ -680,7 +789,33 @@ class Sfi:
             arg_list = [slice_cred, slice_hrn, rspec]
             request_hash = self.key.compute_hash(arg_list) 
         return server.create_slice(slice_cred, slice_hrn, rspec, request_hash)
-    
+
+    # get a ticket for the specified slice
+    def get_ticket(self, opts, args):
+        slice_hrn, rspec_path = args[0], args[1]
+        user_cred = self.get_user_cred()
+        slice_cred = self.get_slice_cred(slice_hrn).save_to_string(save_parents=True)
+        rspec_file = self.get_rspec_file(rspec_path) 
+        rspec=open(rspec_file).read()
+        server = self.slicemgr
+        if opts.aggregate:
+            aggregates = self.registry.get_aggregates(user_cred, opts.aggregate)
+            if not aggregates:
+                raise Exception, "No such aggregate %s" % opts.aggregate
+            aggregate = aggregates[0]
+            url = "http://%s:%s" % (aggregate['addr'], aggregate['port'])
+            server = GeniClient(url, self.key_file, self.cert_file, self.options.protocol)
+        request_hash=None
+        if self.hashrequest:
+            arg_list = [slice_cred, slice_hrn, rspec]
+            request_hash = self.key.compute_hash(arg_list)
+        ticket_string = server.get_ticket(slice_cred, slice_hrn, rspec, request_hash)
+        file = os.path.join(self.options.sfi_dir, get_leaf(slice_hrn) + ".ticket")
+        print "writing ticket to ", file        
+        ticket = SfaTicket(string=ticket_string)
+        ticket.save_to_file(filename=file, save_parents=True)
+        print ticket_string  
     # delete named slice
     def delete(self,opts, args):
         slice_hrn = args[0]
@@ -721,107 +856,6 @@ class Sfi:
             request_hash = self.key.compute_hash(arg_list)
         return self.slicemgr.reset_slice(slice_cred, slice_hrn, request_hash)
     
-    #
-    #
-    # Display, Save, and Filter RSpecs and Records
-    #   - to be replace by EMF-generated routines
-    #
-    #
-    
-    def display_rspec(self,rspec, format = 'rspec'):
-        if format in ['dns']:
-            spec = RSpec()
-            spec.parseString(rspec)
-            hostnames = []
-            nodespecs = spec.getDictsByTagName('NodeSpec')
-            for nodespec in nodespecs:
-                if nodespec.has_key('name') and nodespec['name']:
-                    if isinstance(nodespec['name'], ListType):
-                        hostnames.extend(nodespec['name'])
-                    elif isinstance(nodespec['name'], StringTypes):
-                        hostnames.append(nodespec['name'])
-            result = hostnames
-        elif format in ['ip']:
-            spec = RSpec()
-            spec.parseString(rspec)
-            ips = []
-            ifspecs = spec.getDictsByTagName('IfSpec')
-            for ifspec in ifspecs:
-                if ifspec.has_key('addr') and ifspec['addr']:
-                    ips.append(ifspec['addr'])
-            result = ips 
-        else:     
-            result = rspec
-    
-        print result
-        return
-    
-    def display_list(self,results):
-        for result in results:
-            print result
-    
-    def save_rspec_to_file(self,rspec, filename):
-       if not filename.startswith(os.sep):
-           filename = self.options.sfi_dir + filename
-       if not filename.endswith(".rspec"):
-           filename = filename + ".rspec"
-    
-       f = open(filename, 'w')
-       f.write(rspec)
-       f.close()
-       return
-    
-    def display_records(self,recordList, dump = False):
-       ''' Print all fields in the record'''
-       for record in recordList:
-          self.display_record(record, dump)
-    
-    def display_record(self,record, dump = False):
-       if dump:
-           record.dump()
-       else:
-           info = record.getdict()
-           print "%s (%s)" % (info['hrn'], info['type'])
-       return
-    
-    def filter_records(self,type, records):
-       filtered_records = []
-       for record in records:
-           if (record['type'] == type) or (type == "all"):
-               filtered_records.append(record)
-       return filtered_records
-    
-    def save_records_to_file(self,filename, recordList):
-       index = 0
-       for record in recordList:
-           if index>0:
-               self.save_record_to_file(filename + "." + str(index), record)
-           else:
-               self.save_record_to_file(filename, record)
-           index = index + 1
-    
-    def save_record_to_file(self,filename, record):
-       if record['type'] in ['user']:
-           record = UserRecord(dict = record)
-       elif record['type'] in ['slice']:
-           record = SliceRecord(dict = record)
-       elif record['type'] in ['node']:
-           record = NodeRecord(dict = record)
-       elif record['type'] in ['authority', 'ma', 'sa']:
-          record = AuthorityRecord(dict = record)
-       else:
-           record = GeniRecord(dict = record) 
-       if not filename.startswith(os.sep):
-           filename = self.options.sfi_dir + filename
-       str = record.save_to_string()
-       file(filename, "w").write(str)
-       return
-    
-    def load_record_from_file(self,filename):
-       str = file(filename, "r").read()
-       record = GeniRecord(string=str)
-       return record
-    
     #
     # Main: parse arguments and dispatch to command
     #