controllers should allow refreshes while findall is running.
[monitor.git] / monitor / getsshkeys.py
1 #!/usr/bin/python
2
3 import os
4 import sys
5 import string
6 import time
7 import xml, xmlrpclib
# Credentials for the PLC XML-RPC API.  When the monitor configuration cannot
# be imported (or lacks the API_AUTH_* settings) fall back to anonymous access.
try:
    from monitor import config
    auth = {'Username'   : config.API_AUTH_USER,
            'AuthMethod' : "password",
            'AuthString' : config.API_AUTH_PASSWORD}
except Exception:
    import traceback
    # print_exc() writes the traceback to stderr itself and returns None,
    # so it must not be wrapped in a print.
    traceback.print_exc()
    auth = {'AuthMethod' : "anonymous"}
17
18
class SSHKnownHosts:
    """Maintain an OpenSSH known_hosts file from PLC node records.

    Entries whose comment contains 'PlanetLab' are kept in ``pl_keys``;
    every other entry (including keys fetched directly with ssh-keyscan,
    which are tagged 'DIRECT') is kept in ``other_keys``.  Both dicts map
    an index ("host,ip", or a bare host/ip) to the "keytype key comment"
    remainder of the known_hosts line.
    """

    def __init__(self, args = None):
        """Load the known_hosts file and connect to the PLC XML-RPC API.

        args -- optional dict of settings, updated in place with the values
                actually used.  'known_hosts' defaults to
                ~/.ssh/known_hosts; 'XMLRPC_SERVER' defaults to
                config.API_SERVER, or to the public PlanetLab API when the
                monitor config cannot be loaded.
        """
        if not args:
            args = {}
        if 'known_hosts' not in args:
            args['known_hosts'] = os.path.expanduser(
                os.path.join("~", ".ssh", "known_hosts"))
        if 'XMLRPC_SERVER' not in args:
            try:
                from monitor import config
                args['XMLRPC_SERVER'] = config.API_SERVER
            except Exception:
                args['XMLRPC_SERVER'] = 'https://boot.planet-lab.org/PLCAPI/'
                print("Using default API server %s" % args['XMLRPC_SERVER'])
        self.args = args
        self.read_knownhosts()
        self.auth = auth
        self.api = xmlrpclib.Server(args['XMLRPC_SERVER'], verbose=False, allow_none=True)
        # interface_id -> interface record; filled by getNodes().
        self.nodenetworks = {}

    def _split_kh_entry(self, line):
        """Split one known_hosts line into (host, ip, key, comment).

        The first field is expected to be "host,ip"; any other form (a bare
        address, a hashed entry, more than one comma) is returned whole as
        ``ip`` with ``host`` set to "".  ``key`` is the next two fields
        ("keytype base64") and ``comment`` is everything after them.  A
        blank line yields four empty strings.
        """
        fields = line.split()
        if not fields:
            return ("", "", "", "")
        try:
            (host, ip) = fields[0].split(',')
        except ValueError:
            (host, ip) = ("", fields[0])
        key = ' '.join(fields[1:3])
        comment = ' '.join(fields[3:])
        return (host, ip, key, comment)

    def _get_index(self, host, ip):
        """Return the known_hosts index for a host/ip pair.

        "host,ip" when both are present, otherwise whichever one is.
        """
        if host and ip:
            return "%s,%s" % (host, ip)
        return host or ip

    def _discard(self, index):
        """Remove ``index`` from both key tables if it is present."""
        self.pl_keys.pop(index, None)
        self.other_keys.pop(index, None)

    def read_knownhosts(self):
        """(Re)load self.pl_keys and self.other_keys from the known_hosts file.

        A missing file is treated as empty and blank lines are skipped.
        """
        import errno
        self.pl_keys = {}
        self.other_keys = {}
        try:
            kh_file = open(self.args["known_hosts"], 'r')
        except IOError as e:
            if e.errno != errno.ENOENT:
                raise
            return
        with kh_file:
            for line in kh_file:
                if not line.strip():
                    continue
                # split() drops the trailing newline without eating the last
                # character of a final line that has no newline.
                (host, ip, key, comment) = self._split_kh_entry(line)
                if 'PlanetLab' in comment:
                    table = self.pl_keys
                else:
                    table = self.other_keys
                table[self._get_index(host, ip)] = "%s %s" % (key, comment)

    def write(self):
        """Persist the current key tables (same as write_knownhosts())."""
        self.write_knownhosts()

    def write_knownhosts(self):
        """Write pl_keys, then other_keys, to the known_hosts file."""
        with open(self.args['known_hosts'], 'w') as kh_file:
            for table in (self.pl_keys, self.other_keys):
                for index in table:
                    kh_file.write("%s %s\n" % (index, table[index]))

    def updateAll(self):
        """Refresh pl_keys from every node PLC returns.

        Nodes without a usable interface address or without an ssh_rsa_key
        are not written, since there is no key to record.  Returns the list
        of node records that had no key.
        """
        nokey_list = []
        # Deduplicate by hostname; a later record for the same host wins.
        d_nodes = dict((node['hostname'], node) for node in self.getNodes())
        for node in d_nodes.values():
            (host, ip, key, comment) = self._record_from_node(node, nokey_list)
            if ip is None or key is None:
                continue
            self.pl_keys[self._get_index(host, ip)] = "%s %s" % (key, comment)
        return nokey_list

    def delete(self, host):
        """Remove the entry PLC's record for ``host`` maps to.

        Returns the (possibly empty) list of node records from getNodes().
        """
        nodes = self.getNodes(host)
        if nodes:
            (hostname, ip, _, _) = self._record_from_node(nodes[0])
            self._discard(self._get_index(hostname, ip))
        return nodes

    def _keyscan(self, host):
        """Return the first key line ssh-keyscan prints for ``host``, or ""."""
        import subprocess
        try:
            # Argument list, no shell: the hostname is never shell-interpreted.
            proc = subprocess.Popen(["/usr/bin/ssh-keyscan", "-t", "rsa", host],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    universal_newlines=True)
        except OSError:
            return ""
        out = proc.communicate()[0]
        for line in out.splitlines():
            if line.strip() and not line.startswith('#'):
                return line
        return ""

    def updateDirect(self, host):
        """Fetch ``host``'s RSA key with ssh-keyscan and store it as DIRECT.

        Any existing entry for the host is replaced.  If ssh-keyscan yields
        no key, or PLC does not know the host, the tables are left unchanged
        and a message is printed.
        """
        (_, _, key, _) = self._split_kh_entry(self._keyscan(host))
        if not key:
            print("ssh-keyscan returned no key for %s; entry left unchanged" % host)
            return
        nodes = self.getNodes(host)
        if not nodes:
            print("Host %s not found in PLC; entry left unchanged" % host)
            return
        (hostname, ip, _, _) = self._record_from_node(nodes[0])
        index = self._get_index(hostname, ip)
        self._discard(index)
        self.other_keys[index] = "%s %s" % (key, "DIRECT")

    def update(self, host):
        """Replace ``host``'s entry using the key PLC reports for it.

        Falls back to updateDirect() when PLC has no interface address for
        the node.  Does nothing if PLC does not know the host; if PLC has no
        key for it, the old entry is removed and none is written.
        """
        nodes = self.delete(host)
        if not nodes:
            return
        (hostname, ip, key, comment) = self._record_from_node(nodes[0])
        if ip is None:
            self.updateDirect(hostname)
        elif key is not None:
            self.pl_keys[self._get_index(hostname, ip)] = "%s %s" % (key, comment)

    def getNodes(self, host=None):
        """Return PLC node records with hostname, ssh_rsa_key, interface_ids.

        host -- a hostname, a list of hostnames, or None for all nodes.
        The interfaces of the returned nodes are cached in
        self.nodenetworks, keyed by interface_id, for _record_from_node().
        """
        if isinstance(host, str):
            host = [host]

        nodes = self.api.GetNodes(self.auth, host, ["hostname", "ssh_rsa_key", "interface_ids"])

        interface_ids = [iid for node in nodes for iid in node["interface_ids"]]
        if interface_ids:
            for iface in self.api.GetInterfaces(self.auth, interface_ids, ["interface_id", "ip"]):
                self.nodenetworks[iface["interface_id"]] = iface
        return nodes

    def _record_from_node(self, node, nokey_list=None):
        """Build a (host, ip, key, comment) tuple from a PLC node record.

        Only the node's first interface is used, so multihomed nodes are not
        handled.  ip/key/comment are None when the node has no interfaces
        or its first interface is not cached in self.nodenetworks.  key and
        comment are None when the node has no ssh_rsa_key; the node is then
        appended to ``nokey_list`` if one is given.  Otherwise the comment
        is "PlanetLab_" followed by the current UTC time.
        """
        host = node['hostname']
        key = node['ssh_rsa_key']

        interface_ids = node['interface_ids']
        if not interface_ids:
            return (host, None, None, None)
        l_nw = self.nodenetworks.get(interface_ids[0])
        if l_nw is None:
            return (host, None, None, None)
        ip = l_nw['ip']

        if key is None:
            if nokey_list is not None:
                nokey_list.append(node)
            return (host, ip, None, None)

        key = key.strip()
        # Heuristic corruption check: expects base64 '=' padding at the end.
        # NOTE(review): not every valid base64 key is padded, so this may
        # warn spuriously.  boot_state is not among the fields getNodes()
        # requests, hence .get() rather than indexing.
        if key and not key.endswith('='):
            print("Host with corrupt key! for %s %s" % (node.get('boot_state'), host))

        s_date = time.strftime("%Y/%m/%d_%H:%M:%S", time.gmtime())
        return (host, ip, key, "PlanetLab_%s" % s_date)
177
178
def main(hosts):
    """Refresh the known_hosts file and write it back.

    hosts -- hostnames to re-key directly with ssh-keyscan; when the list
             is empty, every node PLC returns is refreshed from the API.
    """
    known = SSHKnownHosts()
    if len(hosts) == 0:
        known.updateAll()
    else:
        for hostname in hosts:
            known.updateDirect(hostname)
    known.write()
187
if __name__ == '__main__':
    # Every command-line argument is taken as a hostname to refresh.
    cli_hosts = sys.argv[1:]
    main(cli_hosts)