cleaned up saving-to-file functions in sfi - now always displays a message when writi...
[sfa.git] / sfa / client / sfi.py
index 2a3e21f..dce1ac2 100644 (file)
@@ -112,72 +112,69 @@ def show_credentials (cred_s):
     for cred in cred_s:
         print "Using Credential {}".format(credential_printable(cred))
 
-# save methods
-def save_raw_to_file(var, filename, format="text", banner=None):
-    if filename == "-":
-        # if filename is "-", send it to stdout
-        f = sys.stdout
+########## save methods
+
+### raw
+def save_raw_to_file(var, filename, format='text', banner=None):
+    if filename == '-':
+        _save_raw_to_file(var, sys.stdout, format, banner)
     else:
-        f = open(filename, "w")
-    if banner:
-        f.write(banner+"\n")
+        with open(filename, w) as fileobj:
+            _save_raw_to_file(var, fileobj, format, banner)
+        print "(Over)wrote {}".format(filename)
+
+def _save_raw_to_file(var, f, format, banner):
     if format == "text":
-        f.write(str(var))
+        if banner: f.write(banner+"\n")
+        f.write("{}".format(var))
+        if banner: f.write('\n'+banner+"\n")
     elif format == "pickled":
         f.write(pickle.dumps(var))
     elif format == "json":
-        if hasattr(json, "dumps"):
-            f.write(json.dumps(var))   # python 2.6
-        else:
-            f.write(json.write(var))   # python 2.5
+        f.write(json.dumps(var))   # python 2.6
     else:
         # this should never happen
         print "unknown output format", format
-    if banner:
-        f.write('\n'+banner+"\n")
 
+### 
 def save_rspec_to_file(rspec, filename):
     if not filename.endswith(".rspec"):
         filename = filename + ".rspec"
     with open(filename, 'w') as f:
         f.write("{}".format(rspec))
-    return
+    print "(Over)wrote {}".format(filename)
+
+def save_record_to_file(filename, record_dict):
+    record = Record(dict=record_dict)
+    xml = record.save_as_xml()
+    with codecs.open(filename, encoding='utf-8',mode="w") as f:
+        f.write(xml)
+    print "(Over)wrote {}".format(filename)
 
 def save_records_to_file(filename, record_dicts, format="xml"):
     if format == "xml":
-        index = 0
-        for record_dict in record_dicts:
-            if index > 0:
-                save_record_to_file(filename + "." + str(index), record_dict)
-            else:
-                save_record_to_file(filename, record_dict)
-            index = index + 1
+        for index, record_dict in enumerate(record_dicts):
+            save_record_to_file(filename + "." + str(index), record_dict)
     elif format == "xmllist":
-        f = open(filename, "w")
-        f.write("<recordlist>\n")
-        for record_dict in record_dicts:
-            record_obj=Record(dict=record_dict)
-            f.write('<record hrn="' + record_obj.hrn + '" type="' + record_obj.type + '" />\n')
-        f.write("</recordlist>\n")
-        f.close()
+        with open(filename, "w") as f:
+            f.write("<recordlist>\n")
+            for record_dict in record_dicts:
+                record_obj = Record(dict=record_dict)
+                f.write('<record hrn="' + record_obj.hrn + '" type="' + record_obj.type + '" />\n')
+            f.write("</recordlist>\n")
+            print "(Over)wrote {}".format(filename)
+
     elif format == "hrnlist":
-        f = open(filename, "w")
-        for record_dict in record_dicts:
-            record_obj=Record(dict=record_dict)
-            f.write(record_obj.hrn + "\n")
-        f.close()
+        with open(filename, "w") as f:
+            for record_dict in record_dicts:
+                record_obj = Record(dict=record_dict)
+                f.write(record_obj.hrn + "\n")
+            print "(Over)wrote {}".format(filename)
+
     else:
         # this should never happen
         print "unknown output format", format
 
-def save_record_to_file(filename, record_dict):
-    record = Record(dict=record_dict)
-    xml = record.save_as_xml()
-    f=codecs.open(filename, encoding='utf-8',mode="w")
-    f.write(xml)
-    f.close()
-    return
-
 # minimally check a key argument
 def check_ssh_key (key):
     good_ssh_key = r'^.*(?:ssh-dss|ssh-rsa)[ ]+[A-Za-z0-9+/=]+(?: .*)?$'
@@ -193,13 +190,16 @@ def normalize_type (type):
         return 'slice'
     elif type.startswith('no'):
         return 'node'
+    elif type.startswith('ag'):
+        return 'aggregate'
+    elif type.startswith('al'):
+        return 'all'
     else:
+        print 'unknown type {} - should start with one of au|us|sl|no|ag|al'.format(type)
         return None
 
 def load_record_from_opts(options):
     record_dict = {}
-    if hasattr(options, 'type'):
-        options.type = normalize_type(options.type)
     if hasattr(options, 'xrn') and options.xrn:
         if hasattr(options, 'type') and options.type:
             xrn = Xrn(options.xrn, options.type)
@@ -439,7 +439,7 @@ class Sfi:
 
         if canonical in ("register", "update"):
             parser.add_option('-x', '--xrn', dest='xrn', metavar='<xrn>', help='object hrn/urn (mandatory)')
-            parser.add_option('-t', '--type', dest='type', metavar='<type>', help='object type', default=None)
+            parser.add_option('-t', '--type', dest='type', metavar='<type>', help='object type (2 first chars is enough)', default=None)
             parser.add_option('-e', '--email', dest='email', default="",  help="email (mandatory for users)") 
             parser.add_option('-n', '--name', dest='name', default="",  help="name (optional for authorities)") 
             parser.add_option('-k', '--key', dest='key', metavar='<key>', help='public key string or file', 
@@ -472,10 +472,9 @@ class Sfi:
                               help="renew as long as possible")
         # registy filter option
         if canonical in ("list", "show", "remove"):
-            parser.add_option("-t", "--type", dest="type", type="choice",
-                            help="type filter ([all]|user|slice|authority|node|aggregate)",
-                            choices=("all", "user", "slice", "authority", "node", "aggregate"),
-                            default="all")
+            parser.add_option("-t", "--type", dest="type", metavar="<type>",
+                              default="all",
+                              help="type filter - 2 first chars is enough ([all]|user|slice|authority|node|aggregate)")
         if canonical in ("show"):
             parser.add_option("-k","--key",dest="keys",action="append",default=[],
                               help="specify specific keys to be displayed from record")
@@ -496,9 +495,9 @@ class Sfi:
             #panos: a new option to define the type of information about resources a user is interested in
             parser.add_option("-i", "--info", dest="info",
                                 help="optional component information", default=None)
-            # a new option to retreive or not reservation-oriented RSpecs (leases)
+            # a new option to retrieve or not reservation-oriented RSpecs (leases)
             parser.add_option("-l", "--list_leases", dest="list_leases", type="choice",
-                                help="Retreive or not reservation-oriented RSpecs ([resources]|leases|all )",
+                                help="Retrieve or not reservation-oriented RSpecs ([resources]|leases|all)",
                                 choices=("all", "resources", "leases"), default="resources")
 
 
@@ -593,6 +592,12 @@ use this if you mean an authority instead""")
             sys.exit(1)
         self.command_options = command_options
 
+        # allow incoming types on 2 characters only
+        if hasattr(command_options, 'type'):
+            command_options.type = normalize_type(command_options.type)
+            if not command_options.type:
+                sys.exit(1)
+        
         self.read_config () 
         self.bootstrap ()
         self.logger.debug("Command={}".format(self.command))
@@ -1280,7 +1285,12 @@ use this if you mean an authority instead""")
         """
         server = self.sliceapi()
         server_version = self.get_cached_server_version(server)
+        if len(args) != 2:
+            self.print_help()
+            sys.exit(1)
         slice_hrn = args[0]
+        rspec_file = self.get_rspec_file(args[1])
+
         slice_urn = Xrn(slice_hrn, type='slice').get_urn()
 
         # credentials
@@ -1300,8 +1310,6 @@ use this if you mean an authority instead""")
             show_credentials(creds)
 
         # rspec
-        rspec_file = self.get_rspec_file(args[1])
-        rspec = open(rspec_file).read()
         api_options = {}
         api_options ['call_id'] = unique_call_id()
         # users
@@ -1320,7 +1328,9 @@ use this if you mean an authority instead""")
         api_options['sfa_users'] = sfa_users
         api_options['geni_users'] = geni_users
 
-        allocate = server.Allocate(slice_urn, creds, rspec, api_options)
+        with open(rspec_file) as rspec:
+            rspec_xml = rspec.read()
+            allocate = server.Allocate(slice_urn, creds, rspec_xml, api_options)
         value = ReturnValue.get_value(allocate)
         if self.options.raw:
             save_raw_to_file(allocate, self.options.raw, self.options.rawformat, self.options.rawbanner)