#!/usr/bin/python
-import os
import sys
-import string
-import time
-import xml, xmlrpclib
-try:
- from monitor import config
- auth = {'Username' : config.API_AUTH_USER,
- 'AuthMethod' : "password",
- 'AuthString' : config.API_AUTH_PASSWORD}
-except:
- import traceback
- print traceback.print_exc()
- auth = {'AuthMethod' : "anonymous"}
-
-args = {}
-args['known_hosts'] = os.environ['HOME'] + os.sep + ".ssh" + os.sep + "known_hosts"
-try:
- from monitor import config
- args['XMLRPC_SERVER'] = config.API_SERVER
-except:
- args['XMLRPC_SERVER'] = 'https://boot.planet-lab.org/PLCAPI/'
- print "Using default API server %s" % args['XMLRPC_SERVER']
-
-class SSHKnownHosts:
- def __init__(self, args = args):
- self.args = args
- self.read_knownhosts()
- self.auth = auth
- self.api = xmlrpclib.Server(args['XMLRPC_SERVER'], verbose=False, allow_none=True)
- self.nodenetworks = {}
-
- def _split_kh_entry(self, line):
- s = line.split(' ')
- try:
- (host,ip) = s[0].split(',')
- except:
- ip = s[0]
- host = ""
-
- key = ' '.join(s[1:3])
- comment = ' '.join(s[3:])
- return (host, ip, key, comment)
-
- def _get_index(self, host, ip):
- index = ""
- if host is not "":
- index = "%s,%s" % (host,ip)
- else:
- index = ip
- return index
-
- def read_knownhosts(self):
- kh_read = open(self.args["known_hosts"], 'r')
- self.pl_keys = {}
- self.other_keys = {}
- for line in kh_read:
- (host, ip, key, comment) = self._split_kh_entry(line[:-1])
- rec = { self._get_index(host, ip) : "%s %s" % (key, comment) }
- if 'PlanetLab' in comment:
- self.pl_keys.update(rec)
- else:
- self.other_keys.update(rec)
-
- #for i in self.pl_keys:
- # print i
- # print self.pl_keys[i]
-
- return
-
- def write(self):
- self.write_knownhosts()
-
- def write_knownhosts(self):
- f = open(self.args['known_hosts'], 'w')
- for index in self.pl_keys:
- print >>f, "%s %s" % (index, self.pl_keys[index])
- for index in self.other_keys:
- print >>f, "%s %s" % (index, self.other_keys[index])
- f.close()
-
- def updateAll(self):
- l_nodes = self.getNodes()
- d_nodes = {}
- nokey_list = []
- for node in l_nodes:
- name = node['hostname']
- d_nodes[name] = node
-
- for host in d_nodes:
- node = d_nodes[host]
- (host, ip, key, comment) = self._record_from_node(node, nokey_list)
- rec = { "%s,%s" % (host,ip) : "%s %s" % (key, comment) }
- self.pl_keys.update(rec)
-
- return nokey_list
-
- def delete(self, host):
- node = self.getNodes(host)
- if len(node) > 0:
- (host, ip, _, _) = self._record_from_node(node[0])
- index = "%s,%s" % (host,ip)
- if index in self.pl_keys:
- del self.pl_keys[index]
- if index in self.other_keys:
- del self.other_keys[index]
- return node
-
- def updateDirect(self, host):
- cmd = os.popen("/usr/bin/ssh-keyscan -t rsa %s 2>/dev/null" % host)
- line = cmd.read()
- (h, ip, key, comment) = self._split_kh_entry(line[:-1])
- node = self.getNodes(host)
- (host2, ip2, x, x) = self._record_from_node(node[0])
- rec = { self._get_index(host2, ip2) : "%s %s" % (key, "DIRECT") }
-
- self.delete(host)
- self.other_keys.update(rec)
-
- def update(self, host):
- node = self.delete(host)
- #node = self.getNodes(host)
- if node is not []:
- ret = self._record_from_node(node[0])
- (host, ip, key, comment) = ret
- if ip == None:
- self.updateDirect(host)
- else:
- rec = { "%s,%s" % (host,ip) : "%s %s" % (key, comment) }
- self.pl_keys.update(rec)
-
- def getNodes(self, host=None):
- if type(host) == type(""): host = [host]
-
- # get the node(s) info
- nodes = self.api.GetNodes(self.auth,host,["hostname","ssh_rsa_key","interface_ids"])
-
- # for each node's node network, update the self.nodenetworks cache
- nodenetworks = []
- for node in nodes:
- for net in node["interface_ids"]:
- nodenetworks.append(net)
-
- plcnodenetworks = self.api.GetInterfaces(self.auth,nodenetworks,["interface_id","ip"])
- for n in plcnodenetworks:
- self.nodenetworks[n["interface_id"]]=n
- return nodes
-
- def _record_from_node(self, node, nokey_list=None):
- host = node['hostname']
- key = node['ssh_rsa_key']
-
- nodenetworks = node['interface_ids']
- if len(nodenetworks)==0: return (host, None, None, None)
-
- # the [0] subscript to node['interface_ids'] means
- # that this function wont work with multihomed nodes
- l_nw = self.nodenetworks.get(nodenetworks[0],None)
- if l_nw is None: return (host, None, None, None)
- ip = l_nw['ip']
-
- if key == None:
- if nokey_list is not None: nokey_list += [node]
- return (host, ip, None, None)
-
- key = key.strip()
- # TODO: check for '==' at end of key.
- if len(key) > 0 and key[-1] != '=':
- print "Host with corrupt key! for %s %s" % (node['boot_state'], node['hostname'])
-
- s_date = time.strftime("%Y/%m/%d_%H:%M:%S",time.gmtime(time.time()))
- #rec = { "%s,%s" % (host,ip) : "%s %s" % (key, "PlanetLab_%s" % (s_date)) }
- #return rec
- return (host, ip, key, "PlanetLab_%s" % s_date)
-
+from monitor.util.sshkeys import SSHKnownHosts
def main(hosts):
k = SSHKnownHosts()