A few changes to improve upon the script:
authorMarc Fiuczynski <mef@cs.princeton.edu>
Tue, 6 May 2008 02:55:18 +0000 (02:55 +0000)
committerMarc Fiuczynski <mef@cs.princeton.edu>
Tue, 6 May 2008 02:55:18 +0000 (02:55 +0000)
- try to make it stand alone python script
  - uses xmlrpc directly; no longer needs to import plc module

- fetches nodenetworks for all hosts and caches it locally
  to avoid having to invoke the API n times (where n is the
  # of nodes at the PLC).

Still needs:

- a proper help/usage message printed

- a way to export full functionality (e.g., delete)

- a way to specify XMLRPC_SERVER as a command line option, as
  now it by default assumes www.planet-lab.org/PLCAPI

getsshkeys.py

index 540ba31..d0084db 100755 (executable)
@@ -4,15 +4,20 @@ import os
 import sys
 import string
 import time
-import plc
+import xml, xmlrpclib
 
 args = {}
-args['known_hosts'] =  os.environ['HOME'] + "/.ssh/known_hosts"
+args['known_hosts'] =  os.environ['HOME'] + os.sep + ".ssh" + os.sep + "known_hosts"
+args['XMLRPC_SERVER'] = 'https://www.planet-lab.org/PLCAPI/'
 
 class SSHKnownHosts:
        def __init__(self, args = args):
                self.args = args
                self.read_knownhosts()
+               self.auth = {}
+               self.auth['AuthMethod'] = "anonymous"
+               self.api = xmlrpclib.Server(args['XMLRPC_SERVER'], verbose=False, allow_none=True)
+               self.nodenetworks = {}
 
        def _split_kh_entry(self, line):
                s = line.split(' ')
@@ -64,7 +69,7 @@ class SSHKnownHosts:
                f.close()
 
        def updateAll(self):
-               l_nodes = plc.getNodes() 
+               l_nodes = self.getNodes() 
                d_nodes = {}
                nokey_list = []
                for node in l_nodes:
@@ -80,7 +85,7 @@ class SSHKnownHosts:
                return nokey_list
 
        def delete(self, host):
-               node = plc.getNodes(host) 
+               node = self.getNodes(host) 
                (host, ip, _, _) = self._record_from_node(node[0])
                index = "%s,%s" % (host,ip)
                if index in self.pl_keys:
@@ -92,7 +97,7 @@ class SSHKnownHosts:
                cmd = os.popen("/usr/bin/ssh-keyscan -t rsa %s 2>/dev/null" % host)
                line = cmd.read()
                (h,  ip,  key,  comment) = self._split_kh_entry(line[:-1])
-               node = plc.getNodes(host)
+               node = self.getNodes(host)
                (host2, ip2, x, x) = self._record_from_node(node[0])
                rec = { self._get_index(host2, ip2) : "%s %s" % (key, "DIRECT") }
 
@@ -100,7 +105,7 @@ class SSHKnownHosts:
                self.other_keys.update(rec)
 
        def update(self, host):
-               node = plc.getNodes(host) 
+               node = self.getNodes(host) 
                ret = self._record_from_node(node[0])
                (host, ip, key, comment)  = ret
                if ip == None:
@@ -109,16 +114,35 @@ class SSHKnownHosts:
                        rec = { "%s,%s" % (host,ip) : "%s %s" % (key, comment) }
                        self.pl_keys.update(rec)
 
+       def getNodes(self, host=None):
+               if type(host) == type(""): host = [host]
+
+               # get the node(s) info
+               nodes = self.api.GetNodes(self.auth,host,["hostname","ssh_rsa_key","nodenetwork_ids"])
+
+               # for each node's node network, update the self.nodenetworks cache
+               nodenetworks = []
+               for node in nodes:
+                       for net in node["nodenetwork_ids"]:
+                               nodenetworks.append(net)
+
+               plcnodenetworks = self.api.GetNodeNetworks(self.auth,nodenetworks,["nodenetwork_id","ip"])
+               for n in plcnodenetworks:
+                       self.nodenetworks[n["nodenetwork_id"]]=n
+               return nodes
+
        def _record_from_node(self, node, nokey_list=None):
                host = node['hostname']
                key = node['ssh_rsa_key']
 
-               l_nw = plc.getNodeNetworks({'nodenetwork_id':node['nodenetwork_ids']})
-               if len(l_nw) == 0:
-                       # No network for this node. So, skip it.
-                       return (host, None, None, None)
+               nodenetworks = node['nodenetwork_ids']
+               if len(nodenetworks)==0: return (host, None, None, None)
 
-               ip = l_nw[0]['ip']
+               # the [0] subscript to node['nodenetwork_ids'] means
+               # that this function wont work with multihomed nodes
+               l_nw = self.nodenetworks.get(nodenetworks[0],None)
+               if l_nw is None: return (host, None, None, None)
+               ip = l_nw['ip']
 
                if key == None:
                        if nokey_list is not None: nokey_list += [node]
@@ -135,19 +159,14 @@ class SSHKnownHosts:
                return (host, ip, key, "PlanetLab_%s" % s_date) 
 
 
-def main():
+def main(hosts):
        k = SSHKnownHosts()
-       nokey_list = k.updateAll()
+       if len (hosts) > 0:
+               for host in hosts:
+                       k.update(host)
+       else:
+               k.updateAll()
+       k.write()
 
-       for node in nokey_list:
-               print "%5s %s" % (node['boot_state'], node['hostname'])
-       
 if __name__ == '__main__':
-       #main()
-       k = SSHKnownHosts()
-       #print "update"
-       #k.update('planetlab-4.cs.princeton.edu')
-       #print "updateDirect"
-       k.update(sys.argv[1])
-       #k.updateDirect(sys.argv[1])
-       k.write()
+       main(sys.argv[1:])