for access to the www.printbadnodes module
[monitor.git] / getsshkeys.py
1 #!/usr/bin/python
2
3 import os
4 import sys
5 import string
6 import time
7 import xml, xmlrpclib
8
# Default configuration: the OpenSSH known_hosts file to maintain and the
# PlanetLab Central XML-RPC API endpoint to query for node keys.
args = {}
# expanduser() falls back to the password database when $HOME is unset,
# instead of raising KeyError at import time like os.environ['HOME'] did.
args['known_hosts'] = os.path.join(os.path.expanduser("~"), ".ssh", "known_hosts")
args['XMLRPC_SERVER'] = 'https://www.planet-lab.org/PLCAPI/'
12
class SSHKnownHosts:
    """Keep an OpenSSH known_hosts file in sync with host keys stored at PLC.

    Entries are held in two dicts keyed by the known_hosts host field
    ("host,ip", or a single name/address) and mapping to the rest of the
    line (key fields followed by the comment):

    * ``pl_keys``    -- entries whose comment contains "PlanetLab", the tag
      this class gives to keys taken from the PLC API;
    * ``other_keys`` -- every other entry, including keys fetched directly
      with ssh-keyscan (comment "DIRECT").

    Nothing is written back to disk until ``write()`` is called.
    """

    def __init__(self, args=None):
        """Load the known_hosts file and create an anonymous PLC API client.

        ``args`` must provide 'known_hosts' (path of the file to maintain)
        and 'XMLRPC_SERVER' (URL of the PLC XML-RPC API).  When omitted, the
        module-level ``args`` dict is used.
        """
        if args is None:
            # Resolve the default at call time instead of binding a mutable
            # module dict as the default value; the parameter shadows the
            # module-level name, hence the globals() lookup.
            args = globals()["args"]
        self.args = args
        self.read_knownhosts()
        self.auth = {"AuthMethod": "anonymous"}
        self.api = xmlrpclib.Server(args["XMLRPC_SERVER"], verbose=False, allow_none=True)
        # nodenetwork_id -> nodenetwork record, filled in by getNodes().
        self.nodenetworks = {}

    def _split_kh_entry(self, line):
        """Split one known_hosts line into ``(host, ip, key, comment)``.

        The first field is "host,ip" or a single name/address; in the latter
        case it is returned as ``ip`` and ``host`` is "".  ``key`` is the next
        two space-separated fields and ``comment`` is everything after them.
        """
        fields = line.split(' ')
        try:
            host, ip = fields[0].split(',')
        except ValueError:
            # Not exactly one comma: treat the whole field as the address.
            host, ip = "", fields[0]
        key = ' '.join(fields[1:3])
        comment = ' '.join(fields[3:])
        return (host, ip, key, comment)

    def _get_index(self, host, ip):
        """Build the known_hosts host field used as the dict key.

        Returns "host,ip" when both are known, ``ip`` alone when ``host`` is
        empty, and ``host`` alone when ``ip`` is None (never "host,None").
        """
        if not host:
            return ip
        if ip is None:
            return host
        return "%s,%s" % (host, ip)

    def _discard(self, index):
        """Remove ``index`` from both pl_keys and other_keys, if present."""
        self.pl_keys.pop(index, None)
        self.other_keys.pop(index, None)

    def read_knownhosts(self):
        """(Re)load ``self.args['known_hosts']`` into pl_keys and other_keys.

        A missing file is treated as empty so a first run can create it.
        Blank lines are skipped.  When several lines share a host field,
        the last one wins.
        """
        self.pl_keys = {}
        self.other_keys = {}
        path = self.args["known_hosts"]
        if not os.path.exists(path):
            return
        with open(path, 'r') as kh_read:
            for line in kh_read:
                # rstrip rather than [:-1]: the last line may lack a newline.
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                (host, ip, key, comment) = self._split_kh_entry(line)
                rec = {self._get_index(host, ip): "%s %s" % (key, comment)}
                if 'PlanetLab' in comment:
                    self.pl_keys.update(rec)
                else:
                    self.other_keys.update(rec)

    def write(self):
        """Persist the current entries; alias for write_knownhosts()."""
        self.write_knownhosts()

    def write_knownhosts(self):
        """Overwrite the known_hosts file with pl_keys followed by other_keys."""
        with open(self.args['known_hosts'], 'w') as f:
            for keys in (self.pl_keys, self.other_keys):
                for index in keys:
                    f.write("%s %s\n" % (index, keys[index]))

    def updateAll(self):
        """Refresh pl_keys from every node the PLC API returns.

        Nodes without a usable key or without a resolvable IP are not added
        (previously they produced "None None" entries).

        Returns:
            The node records whose ssh_rsa_key is missing or empty.
        """
        nokey_list = []
        # Index by hostname so a host returned more than once is handled once.
        d_nodes = dict((node['hostname'], node) for node in self.getNodes())
        for node in d_nodes.values():
            (host, ip, key, comment) = self._record_from_node(node, nokey_list)
            if ip is None or key is None:
                continue
            self.pl_keys[self._get_index(host, ip)] = "%s %s" % (key, comment)
        return nokey_list

    def delete(self, host):
        """Drop the entry for ``host`` (looked up at PLC) from both dicts."""
        node = self.getNodes(host)
        (host, ip, _, _) = self._record_from_node(node[0])
        self._discard(self._get_index(host, ip))

    def updateDirect(self, host):
        """Fetch ``host``'s RSA key with ssh-keyscan and store it in other_keys.

        The new entry, tagged "DIRECT", replaces any pl_keys/other_keys
        entry for the same host field.  If ssh-keyscan cannot be run or
        returns no key (e.g. the host is unreachable), nothing is changed.
        """
        import subprocess

        devnull = open(os.devnull, 'w')
        try:
            # Argument list, not a shell string: ``host`` is never parsed by a shell.
            proc = subprocess.Popen(["/usr/bin/ssh-keyscan", "-t", "rsa", host],
                                    stdout=subprocess.PIPE, stderr=devnull,
                                    universal_newlines=True)
            output = proc.communicate()[0]
        except OSError:
            output = ""
        finally:
            devnull.close()

        lines = [l for l in output.splitlines() if l.strip() and not l.startswith('#')]
        if not lines:
            print("ssh-keyscan returned no key for %s; entry left unchanged" % host)
            return
        (_, _, key, _) = self._split_kh_entry(lines[0])

        node = self.getNodes(host)
        (host2, ip2, _, _) = self._record_from_node(node[0])
        index = self._get_index(host2, ip2)
        self._discard(index)
        self.other_keys[index] = "%s %s" % (key, "DIRECT")

    def update(self, host):
        """Refresh the entry for one ``host`` from PLC.

        Falls back to updateDirect() when PLC cannot supply a complete
        entry (no IP or no key for the node).
        """
        node = self.getNodes(host)
        (host, ip, key, comment) = self._record_from_node(node[0])
        if ip is None or key is None:
            self.updateDirect(host)
        else:
            self.pl_keys[self._get_index(host, ip)] = "%s %s" % (key, comment)

    def getNodes(self, host=None):
        """Fetch node records from PLC and cache their nodenetworks.

        Args:
            host: a hostname, a list of hostnames, or None (no filter is
                passed to GetNodes).

        Returns:
            The list returned by the API's GetNodes call.
        """
        if isinstance(host, str):
            host = [host]

        nodes = self.api.GetNodes(self.auth, host,
                                  ["hostname", "ssh_rsa_key", "nodenetwork_ids", "boot_state"])

        nodenetwork_ids = [net for node in nodes for net in node["nodenetwork_ids"]]
        if nodenetwork_ids:
            for n in self.api.GetNodeNetworks(self.auth, nodenetwork_ids,
                                              ["nodenetwork_id", "ip"]):
                self.nodenetworks[n["nodenetwork_id"]] = n
        return nodes

    def _record_from_node(self, node, nokey_list=None):
        """Turn a PLC node record into ``(host, ip, key, comment)``.

        ``ip`` is None when the node has no nodenetwork or it is not in the
        ``self.nodenetworks`` cache; ``key`` and ``comment`` are None when
        the IP is unknown or the node has no (non-blank) ssh_rsa_key.  In the
        latter case the node is appended to ``nokey_list`` if one is given.
        """
        host = node['hostname']
        nodenetwork_ids = node['nodenetwork_ids']
        if not nodenetwork_ids:
            return (host, None, None, None)

        # Only the first nodenetwork is used, so this does not handle
        # multihomed nodes.
        nw = self.nodenetworks.get(nodenetwork_ids[0])
        if nw is None:
            return (host, None, None, None)
        ip = nw['ip']

        key = (node['ssh_rsa_key'] or '').strip()
        if not key:
            if nokey_list is not None:
                nokey_list.append(node)
            return (host, ip, None, None)

        # TODO: check for '==' at end of key.
        if not key.endswith('='):
            # .get(): callers may pass records fetched without boot_state.
            print("Host with corrupt key! for %s %s" % (node.get('boot_state'), host))

        s_date = time.strftime("%Y/%m/%d_%H:%M:%S", time.gmtime())
        return (host, ip, key, "PlanetLab_%s" % s_date)
160
161
def main(hosts):
    """Refresh known_hosts entries and write the file back.

    Each hostname in ``hosts`` is refreshed individually via update(); an
    empty ``hosts`` refreshes every node via updateAll().
    """
    known = SSHKnownHosts()
    if not hosts:
        known.updateAll()
    else:
        for name in hosts:
            known.update(name)
    known.write()
170
if __name__ == '__main__':
    # Command-line arguments (if any) are the hostnames to refresh.
    cli_hosts = sys.argv[1:]
    main(cli_hosts)