trimmed useless imports, unstarred all imports
[sfa.git] / sfa / server / sfa_component_setup.py
index 227f992..6d74468 100755 (executable)
@@ -3,14 +3,30 @@ import sys
 import os
 import tempfile
 from optparse import OptionParser
 import os
 import tempfile
 from optparse import OptionParser
+
+from sfa.util.faults import ConnectionKeyGIDMismatch
 from sfa.util.config import Config
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
 from sfa.util.config import Config
 import sfa.util.xmlrpcprotocol as xmlrpcprotocol
-from sfa.util.namespace import *
+from sfa.util.plxrn import hrn_to_pl_slicename, slicename_to_hrn
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.gid import GID
 from sfa.trust.hierarchy import Hierarchy
 
 from sfa.trust.certificate import Keypair, Certificate
 from sfa.trust.credential import Credential
 from sfa.trust.gid import GID
 from sfa.trust.hierarchy import Hierarchy
 
+KEYDIR = "/var/lib/sfa/"
+CONFDIR = "/etc/sfa/"
+
+def handle_gid_mismatch_exception(f):
+    def wrapper(*args, **kwds):
+        try: return f(*args, **kwds)
+        except ConnectionKeyGIDMismatch:
+            # clean regen server keypair and try again
+            print "cleaning keys and trying again"
+            clean_key_cred()
+            return f(args, kwds)
+
+    return wrapper
+
 def get_server(url=None, port=None, keyfile=None, certfile=None,verbose=False):
     """
     returns an xmlrpc connection to the service a the specified 
 def get_server(url=None, port=None, keyfile=None, certfile=None,verbose=False):
     """
     returns an xmlrpc connection to the service a the specified 
@@ -45,6 +61,26 @@ def create_default_dirs():
     for dir in all_dirs:
         if not os.path.exists(dir):
             os.makedirs(dir)
     for dir in all_dirs:
         if not os.path.exists(dir):
             os.makedirs(dir)
+
+def has_node_key():
+    key_file = KEYDIR + os.sep + 'server.key'
+    return os.path.exists(key_file) 
+
+def clean_key_cred():
+    """
+    remove the existing keypair and cred  and generate new ones
+    """
+    files = ["server.key", "server.cert", "node.cred"]
+    for f in files:
+        filepath = KEYDIR + os.sep + f
+        if os.path.isfile(filepath):
+            os.unlink(f)
+   
+    # install the new key pair
+    # get_credential will take care of generating the new keypair
+    # and credential 
+    get_credential()
+    
              
 def get_node_key(registry=None, verbose=False):
     # this call requires no authentication, 
              
 def get_node_key(registry=None, verbose=False):
     # this call requires no authentication, 
@@ -74,7 +110,8 @@ def create_server_keypair(keyfile=None, certfile=None, hrn="component", verbose=
     cert.set_pubkey(key)
     cert.sign()
     cert.save_to_file(certfile, save_parents=True)       
     cert.set_pubkey(key)
     cert.sign()
     cert.save_to_file(certfile, save_parents=True)       
-        
+
+@handle_gid_mismatch_exception
 def get_credential(registry=None, force=False, verbose=False):
     config = Config()
     hierarchy = Hierarchy()
 def get_credential(registry=None, force=False, verbose=False):
     config = Config()
     hierarchy = Hierarchy()
@@ -110,11 +147,12 @@ def get_credential(registry=None, force=False, verbose=False):
         registry = get_server(url=registry, keyfile=keyfile, certfile=certfile)
         cert = Certificate(filename=certfile)
         cert_str = cert.save_to_string(save_parents=True)
         registry = get_server(url=registry, keyfile=keyfile, certfile=certfile)
         cert = Certificate(filename=certfile)
         cert_str = cert.save_to_string(save_parents=True)
-        cred = registry.get_self_credential(cert_str, 'node', hrn)    
+        cred = registry.GetSelfCredential(cert_str, 'node', hrn)
         Credential(string=cred).save_to_file(credfile, save_parents=True)
     
     return cred
 
         Credential(string=cred).save_to_file(credfile, save_parents=True)
     
     return cred
 
+@handle_gid_mismatch_exception
 def get_trusted_certs(registry=None, verbose=False):
     """
     refresh our list of trusted certs.
 def get_trusted_certs(registry=None, verbose=False):
     """
     refresh our list of trusted certs.
@@ -157,6 +195,7 @@ def get_trusted_certs(registry=None, verbose=False):
                 print "Removing old gid ", gid_name
             os.unlink(trusted_certs_dir + os.sep + gid_name)                     
 
                 print "Removing old gid ", gid_name
             os.unlink(trusted_certs_dir + os.sep + gid_name)                     
 
+@handle_gid_mismatch_exception
 def get_gids(registry=None, verbose=False):
     """
     Get the gid for all instantiated slices on this node and store it
 def get_gids(registry=None, verbose=False):
     """
     Get the gid for all instantiated slices on this node and store it
@@ -186,15 +225,28 @@ def get_gids(registry=None, verbose=False):
     api = ComponentAPI()
     xids_tuple = api.nodemanager.GetXIDs()
     slices = eval(xids_tuple[1])
     api = ComponentAPI()
     xids_tuple = api.nodemanager.GetXIDs()
     slices = eval(xids_tuple[1])
-    slicenames = slices.keys()   
-    hrns = [slicename_to_hrn(interface_hrn, slicename) for slicename in slicenames]
-        
+    slicenames = slices.keys()
 
 
+    # generate a list of slices that dont have gids installed
+    slices_without_gids = []
+    for slicename in slicenames:
+        if not os.path.isfile("/vservers/%s/etc/slice.gid" % slicename) \
+        or not os.path.isfile("/vservers/%s/etc/node.gid" % slicename):
+            slices_without_gids.append(slicename) 
+    
+    # convert slicenames to hrns
+    hrns = [slicename_to_hrn(interface_hrn, slicename) \
+            for slicename in slices_without_gids]
+    
+    # exit if there are no gids to install
+    if not hrns:
+        return
+        
     if verbose:
         print "Getting gids for slices on this node from registry"  
     # get the gids
     # and save them in the right palce
     if verbose:
         print "Getting gids for slices on this node from registry"  
     # get the gids
     # and save them in the right palce
-    records = registry.get_gids(cred, hrns)
+    records = registry.GetGids(hrns, cred)
     for record in records:
         # if this isnt a slice record skip it
         if not record['type'] == 'slice':
     for record in records:
         # if this isnt a slice record skip it
         if not record['type'] == 'slice':