137ea68742e3aa1877905f8423e1d787a92b4ca9
[monitor.git] / getsshkeys.py
1 #!/usr/bin/python
2
3 import os
4 import sys
5 import string
6 import time
7 import xml, xmlrpclib
# PLCAPI credentials: use the monitor configuration when it is importable and
# complete, otherwise fall back to anonymous access.
try:
    from monitor import config
    auth = {'Username'   : config.API_AUTH_USER,
            'AuthMethod' : "password",
            'AuthString' : config.API_AUTH_PASSWORD}
except Exception:
    # Report why the config could not be used (ImportError, missing
    # attribute, ...).  print_exc() writes the traceback itself and returns
    # None, so its result must not be printed.
    import traceback
    traceback.print_exc()
    auth = {'AuthMethod' : "anonymous"}
17
# Default settings for SSHKnownHosts: the known_hosts file to maintain and
# the PLCAPI XML-RPC endpoint.  expanduser("~") uses $HOME when it is set and
# falls back to the password database otherwise, so importing this module
# does not fail in environments without HOME.
args = {}
args['known_hosts'] = os.path.join(os.path.expanduser("~"), ".ssh", "known_hosts")
args['XMLRPC_SERVER'] = 'https://boot.planet-lab.org/PLCAPI/'
21
class SSHKnownHosts:
    """Maintain an OpenSSH known_hosts file from host keys published by PLC.

    Entries whose comment contains 'PlanetLab' are kept in ``self.pl_keys``;
    all other entries are kept in ``self.other_keys``.  Both dicts map the
    known_hosts host field ("hostname,ip" or a single address) to the rest
    of the line ("<keytype> <key> <comment>").
    """

    def __init__(self, args=args):
        """Load the file named by args['known_hosts'] and connect to the
        XML-RPC server at args['XMLRPC_SERVER']."""
        self.args = args
        self.read_knownhosts()
        self.auth = auth
        self.api = xmlrpclib.Server(args['XMLRPC_SERVER'], verbose=False, allow_none=True)
        # nodenetwork_id -> {'nodenetwork_id': ..., 'ip': ...}; filled by
        # getNodes() and read by _record_from_node().
        self.nodenetworks = {}

    def _split_kh_entry(self, line):
        """Split one known_hosts line into (host, ip, key, comment).

        If the first field is "hostname,ip" it is split in two; otherwise
        host is "" and ip holds the whole field.  key is the next two
        space-separated fields and comment is everything after them.
        """
        fields = line.split(' ')
        try:
            (host, ip) = fields[0].split(',')
        except ValueError:
            # Not exactly one comma: treat the whole field as the address.
            ip = fields[0]
            host = ""

        key = ' '.join(fields[1:3])
        comment = ' '.join(fields[3:])
        return (host, ip, key, comment)

    def _get_index(self, host, ip):
        """Return the known_hosts host field: "host,ip" when host is
        non-empty, else just ip."""
        if host:
            return "%s,%s" % (host, ip)
        return ip

    def read_knownhosts(self):
        """(Re)load self.pl_keys and self.other_keys from the known_hosts file.

        A missing file is treated as empty so that a first run can create
        it.  Blank lines are skipped.
        """
        self.pl_keys = {}
        self.other_keys = {}
        path = self.args["known_hosts"]
        if not os.path.exists(path):
            return

        kh_read = open(path, 'r')
        try:
            for line in kh_read:
                # Strip only the terminator; the last line may have none, in
                # which case slicing off one character would damage the entry.
                line = line.rstrip('\n')
                if not line.strip():
                    continue
                (host, ip, key, comment) = self._split_kh_entry(line)
                index = self._get_index(host, ip)
                value = "%s %s" % (key, comment)
                if 'PlanetLab' in comment:
                    self.pl_keys[index] = value
                else:
                    self.other_keys[index] = value
        finally:
            kh_read.close()

    def write(self):
        """Persist the current tables to the known_hosts file."""
        self.write_knownhosts()

    def write_knownhosts(self):
        """Overwrite the known_hosts file: PLC entries first, then the rest."""
        f = open(self.args['known_hosts'], 'w')
        try:
            for table in (self.pl_keys, self.other_keys):
                for index in table:
                    f.write("%s %s\n" % (index, table[index]))
        finally:
            f.close()

    def updateAll(self):
        """Refresh the PLC entry of every node returned by getNodes().

        Nodes for which no address or no SSH key is available are skipped
        instead of being written as "None" entries.

        Returns the list of node dicts that had no usable ssh_rsa_key.
        """
        nokey_list = []
        # Index by hostname so each node is processed once.
        d_nodes = {}
        for node in self.getNodes():
            d_nodes[node['hostname']] = node

        for node in d_nodes.values():
            (host, ip, key, comment) = self._record_from_node(node, nokey_list)
            if ip is None or key is None:
                continue
            self.pl_keys["%s,%s" % (host, ip)] = "%s %s" % (key, comment)

        return nokey_list

    def delete(self, host):
        """Remove the "hostname,ip" entry for PLC node `host` from both tables.

        Returns the list produced by getNodes(host), which is empty when PLC
        does not know the host (nothing is removed in that case).
        """
        nodes = self.getNodes(host)
        if nodes:
            (hostname, ip, _, _) = self._record_from_node(nodes[0])
            index = "%s,%s" % (hostname, ip)
            self.pl_keys.pop(index, None)
            self.other_keys.pop(index, None)
        return nodes

    def updateDirect(self, host):
        """Fetch `host`'s RSA key with ssh-keyscan and store it as a DIRECT entry.

        The entry is keyed by "hostname,ip" when PLC knows an address for
        the node, otherwise by the name ssh-keyscan reports.  Any existing
        entry for the PLC record is removed first.  If ssh-keyscan yields no
        key, a message is printed and the tables are left unchanged.
        """
        import subprocess
        # Pass argv as a list: `host` may come from the command line and must
        # never be interpreted by a shell.
        devnull = open(os.devnull, 'w')
        try:
            proc = subprocess.Popen(["/usr/bin/ssh-keyscan", "-t", "rsa", host],
                                    stdout=subprocess.PIPE, stderr=devnull,
                                    universal_newlines=True)
            output = proc.communicate()[0]
        finally:
            devnull.close()

        key_lines = [ln for ln in output.splitlines()
                     if ln.strip() and not ln.startswith('#')]
        if not key_lines:
            print("ssh-keyscan returned no RSA key for %s" % host)
            return

        (h, ip, key, comment) = self._split_kh_entry(key_lines[0])
        index = self._get_index(h, ip)
        nodes = self.getNodes(host)
        if nodes:
            (host2, ip2, _, _) = self._record_from_node(nodes[0])
            if ip2 is not None:
                index = self._get_index(host2, ip2)

        self.delete(host)
        self.other_keys[index] = "%s %s" % (key, "DIRECT")

    def update(self, host):
        """Replace `host`'s entry with the key PLC publishes for it.

        Falls back to updateDirect() when PLC has no address or no key for
        the node.  Does nothing if PLC does not know `host`.
        """
        nodes = self.delete(host)
        if not nodes:
            return
        (hostname, ip, key, comment) = self._record_from_node(nodes[0])
        if ip is None or key is None:
            self.updateDirect(hostname)
        else:
            self.pl_keys["%s,%s" % (hostname, ip)] = "%s %s" % (key, comment)

    def getNodes(self, host=None):
        """Fetch node records from PLC and refresh the nodenetwork cache.

        host may be None (no filter), a single hostname string, or a list
        of hostnames.  Returns the node dicts returned by GetNodes, which
        carry the fields hostname, ssh_rsa_key, nodenetwork_ids and
        boot_state.
        """
        if isinstance(host, str):
            host = [host]

        nodes = self.api.GetNodes(self.auth, host,
                                  ["hostname", "ssh_rsa_key", "nodenetwork_ids", "boot_state"])

        # Cache each node's network records so _record_from_node can find
        # its address; skip the second call when there is nothing to fetch.
        nodenetwork_ids = [nid for node in nodes for nid in node["nodenetwork_ids"]]
        if nodenetwork_ids:
            plcnodenetworks = self.api.GetNodeNetworks(self.auth, nodenetwork_ids,
                                                       ["nodenetwork_id", "ip"])
            for nw in plcnodenetworks:
                self.nodenetworks[nw["nodenetwork_id"]] = nw
        return nodes

    def _record_from_node(self, node, nokey_list=None):
        """Build a (host, ip, key, comment) known_hosts record for a PLC node.

        Returns (host, None, None, None) when the node's address is unknown.
        Returns (host, ip, None, None) when the node has no (or an empty)
        ssh_rsa_key; that node is also appended to nokey_list if one is given.
        Otherwise comment is "PlanetLab_<UTC timestamp>".
        """
        host = node['hostname']
        key = node['ssh_rsa_key']

        nodenetworks = node['nodenetwork_ids']
        if not nodenetworks:
            return (host, None, None, None)

        # Only the first nodenetwork is used, so multihomed nodes are not
        # handled fully.
        l_nw = self.nodenetworks.get(nodenetworks[0])
        if l_nw is None:
            return (host, None, None, None)
        ip = l_nw['ip']

        if key is not None:
            key = key.strip()
        if not key:
            if nokey_list is not None:
                nokey_list.append(node)
            return (host, ip, None, None)

        # TODO: check for '==' at end of key.
        if key[-1] != '=':
            # .get(): a diagnostic must not crash the run if boot_state is absent.
            print("Host with corrupt key! for %s %s" % (node.get('boot_state'), node['hostname']))

        s_date = time.strftime("%Y/%m/%d_%H:%M:%S", time.gmtime())
        return (host, ip, key, "PlanetLab_%s" % s_date)
172
173
def main(hosts):
    """Update the known_hosts file and write it back.

    Every host named in `hosts` is re-keyed via SSHKnownHosts.updateDirect;
    when `hosts` is empty, SSHKnownHosts.updateAll refreshes all PLC nodes.
    """
    known = SSHKnownHosts()
    if not hosts:
        known.updateAll()
    else:
        for name in hosts:
            known.updateDirect(name)
    known.write()


if __name__ == '__main__':
    # Command-line arguments are the host names to re-key.
    main(sys.argv[1:])