convert GetNodeNetworks to GetInterfaces and nodenetwork_ids to interface_ids
[monitor.git] / getsshkeys.py
index d0084db..7b74759 100755 (executable)
@@ -5,17 +5,30 @@ import sys
 import string
 import time
 import xml, xmlrpclib
+try:
+       from monitor import config
+       auth = {'Username'   : config.API_AUTH_USER,
+               'AuthMethod' : "password",
+                       'AuthString' : config.API_AUTH_PASSWORD}
+except:
+       import traceback
+       print traceback.print_exc()
+       auth = {'AuthMethod' : "anonymous"}
 
 args = {}
 args['known_hosts'] =  os.environ['HOME'] + os.sep + ".ssh" + os.sep + "known_hosts"
-args['XMLRPC_SERVER'] = 'https://www.planet-lab.org/PLCAPI/'
+try:
+       from monitor import config
+       args['XMLRPC_SERVER'] = config.API_SERVER
+except:
+       args['XMLRPC_SERVER'] = 'https://boot.planet-lab.org/PLCAPI/'
+       print "Using default API server %s" %  args['XMLRPC_SERVER']
 
 class SSHKnownHosts:
        def __init__(self, args = args):
                self.args = args
                self.read_knownhosts()
-               self.auth = {}
-               self.auth['AuthMethod'] = "anonymous"
+               self.auth = auth
                self.api = xmlrpclib.Server(args['XMLRPC_SERVER'], verbose=False, allow_none=True)
                self.nodenetworks = {}
 
@@ -86,12 +99,14 @@ class SSHKnownHosts:
 
        def delete(self, host):
                node = self.getNodes(host) 
-               (host, ip, _, _) = self._record_from_node(node[0])
-               index = "%s,%s" % (host,ip)
-               if index in self.pl_keys:
-                       del self.pl_keys[index]
-               if index in self.other_keys:
-                       del self.other_keys[index]
+               if len(node) > 0:
+                       (host, ip, _, _) = self._record_from_node(node[0])
+                       index = "%s,%s" % (host,ip)
+                       if index in self.pl_keys:
+                               del self.pl_keys[index]
+                       if index in self.other_keys:
+                               del self.other_keys[index]
+               return node
 
        def updateDirect(self, host):
                cmd = os.popen("/usr/bin/ssh-keyscan -t rsa %s 2>/dev/null" % host)
@@ -105,28 +120,30 @@ class SSHKnownHosts:
                self.other_keys.update(rec)
 
        def update(self, host):
-               node = self.getNodes(host) 
-               ret = self._record_from_node(node[0])
-               (host, ip, key, comment)  = ret
-               if ip == None:
-                       self.updateDirect(host)
-               else:
-                       rec = { "%s,%s" % (host,ip) : "%s %s" % (key, comment) }
-                       self.pl_keys.update(rec)
+               node = self.delete(host)
+               #node = self.getNodes(host) 
+               if node is not []:
+                       ret = self._record_from_node(node[0])
+                       (host, ip, key, comment)  = ret
+                       if ip == None:
+                               self.updateDirect(host)
+                       else:
+                               rec = { "%s,%s" % (host,ip) : "%s %s" % (key, comment) }
+                               self.pl_keys.update(rec)
 
        def getNodes(self, host=None):
                if type(host) == type(""): host = [host]
 
                # get the node(s) info
-               nodes = self.api.GetNodes(self.auth,host,["hostname","ssh_rsa_key","nodenetwork_ids"])
+               nodes = self.api.GetNodes(self.auth,host,["hostname","ssh_rsa_key","interface_ids"])
 
                # for each node's node network, update the self.nodenetworks cache
                nodenetworks = []
                for node in nodes:
-                       for net in node["nodenetwork_ids"]:
+                       for net in node["interface_ids"]:
                                nodenetworks.append(net)
 
-               plcnodenetworks = self.api.GetNodeNetworks(self.auth,nodenetworks,["nodenetwork_id","ip"])
+               plcnodenetworks = self.api.GetInterfaces(self.auth,nodenetworks,["nodenetwork_id","ip"])
                for n in plcnodenetworks:
                        self.nodenetworks[n["nodenetwork_id"]]=n
                return nodes
@@ -135,10 +152,10 @@ class SSHKnownHosts:
                host = node['hostname']
                key = node['ssh_rsa_key']
 
-               nodenetworks = node['nodenetwork_ids']
+               nodenetworks = node['interface_ids']
                if len(nodenetworks)==0: return (host, None, None, None)
 
-               # the [0] subscript to node['nodenetwork_ids'] means
+               # the [0] subscript to node['interface_ids'] means
                # that this function wont work with multihomed nodes
                l_nw = self.nodenetworks.get(nodenetworks[0],None)
                if l_nw is None: return (host, None, None, None)
@@ -150,7 +167,7 @@ class SSHKnownHosts:
 
                key = key.strip()
                # TODO: check for '==' at end of key.
-               if key[-1] != '=':
+               if len(key) > 0 and key[-1] != '=':
                        print "Host with corrupt key! for %s %s" % (node['boot_state'], node['hostname'])
 
                s_date = time.strftime("%Y/%m/%d_%H:%M:%S",time.gmtime(time.time()))
@@ -163,7 +180,7 @@ def main(hosts):
        k = SSHKnownHosts()
        if len (hosts) > 0:
                for host in hosts:
-                       k.update(host)
+                       k.updateDirect(host)
        else:
                k.updateAll()
        k.write()