fixes to make them more stand-alone and general.
[monitor.git] / getsshkeys.py
1 #!/usr/bin/python
2
3 import os
4 import sys
5 import string
6 import time
7 import plc
8
# Default configuration for SSHKnownHosts: location of the user's OpenSSH
# known_hosts file.  os.path.expanduser('~') falls back to the password
# database when $HOME is unset, where os.environ['HOME'] would raise KeyError.
args = {
    'known_hosts': os.path.join(os.path.expanduser('~'), '.ssh', 'known_hosts'),
}
11
class SSHKnownHosts:
    """Read, refresh and rewrite an OpenSSH known_hosts file.

    Entries are held in two dicts that map the host field of a known_hosts
    line ("hostname,ip", or the bare first field when it is not exactly one
    "name,ip" pair) to the rest of the line ("<keytype> <key> <comment>"):

    * ``pl_keys``    -- lines whose comment contains 'PlanetLab' (the comment
                        this class generates for keys obtained from PLC);
    * ``other_keys`` -- every other line, including keys fetched directly
                        with ssh-keyscan (comment "DIRECT").

    All changes stay in memory until write() is called.
    """

    def __init__(self, args=args):
        """Load the known_hosts file named by args['known_hosts'].

        args defaults to the module-level ``args`` dict.
        """
        self.args = args
        self.read_knownhosts()

    def _split_kh_entry(self, line):
        """Split one known_hosts line into (host, ip, key, comment).

        When the first field is exactly "host,ip" it is split in two;
        otherwise host is "" and ip is the whole first field (this also
        covers hashed "|1|..." names and bare addresses).  key is
        "<keytype> <base64>" and comment is the remainder ("" if absent).
        """
        # maxsplit=3 keeps any internal whitespace of the comment intact.
        fields = line.strip().split(None, 3)
        first = fields[0] if fields else ""
        names = first.split(',')
        if len(names) == 2:
            host, ip = names
        else:
            host, ip = "", first
        key = ' '.join(fields[1:3])
        comment = fields[3] if len(fields) > 3 else ""
        return (host, ip, key, comment)

    def _get_index(self, host, ip):
        """Build the host field used as dict key: "host,ip", or whichever is set.

        A missing ip (None or "") yields just the host name, so no literal
        "None" address ends up in known_hosts.
        """
        if host and ip:
            return "%s,%s" % (host, ip)
        return host or ip

    def _remove_index(self, index):
        """Drop the entry stored under index from both key tables, if present."""
        self.pl_keys.pop(index, None)
        self.other_keys.pop(index, None)

    def read_knownhosts(self):
        """(Re)load pl_keys and other_keys from the known_hosts file.

        A missing file is treated as empty (write() will create it); any
        other I/O error is re-raised.  Blank lines are skipped.
        """
        import errno

        self.pl_keys = {}
        self.other_keys = {}
        try:
            kh_file = open(self.args["known_hosts"], 'r')
        except IOError as err:
            if err.errno != errno.ENOENT:
                raise
            return

        with kh_file:
            for line in kh_file:
                if not line.strip():
                    continue
                (host, ip, key, comment) = self._split_kh_entry(line)
                value = "%s %s" % (key, comment)
                if 'PlanetLab' in comment:
                    self.pl_keys[self._get_index(host, ip)] = value
                else:
                    self.other_keys[self._get_index(host, ip)] = value

    def write(self):
        """Persist the current entries; alias for write_knownhosts()."""
        self.write_knownhosts()

    def write_knownhosts(self):
        """Overwrite the known_hosts file: PLC entries first, then all others."""
        with open(self.args['known_hosts'], 'w') as kh_file:
            for table in (self.pl_keys, self.other_keys):
                for index in table:
                    kh_file.write("%s %s\n" % (index, table[index]))

    def updateAll(self):
        """Refresh pl_keys from every node PLC knows about.

        Nodes without a network or without a registered key are not written
        (previously they produced bogus "host,ip None None" lines).

        Returns:
            list of PLC node records that have a network but no ssh_rsa_key.
        """
        nokey_list = []
        # Keyed by hostname so each node is processed once (last one wins).
        d_nodes = dict((node['hostname'], node) for node in plc.getNodes())
        for node in d_nodes.values():
            (host, ip, key, comment) = self._record_from_node(node, nokey_list)
            if key is None:
                continue
            self.pl_keys[self._get_index(host, ip)] = "%s %s" % (key, comment)
        return nokey_list

    def delete(self, host):
        """Remove the entry for PLC node host from both key tables."""
        node = plc.getNodes(host)
        (host, ip, _, _) = self._record_from_node(node[0])
        self._remove_index(self._get_index(host, ip))

    def _scan_key(self, host):
        """Fetch host's RSA key with ssh-keyscan; return "<keytype> <base64>" or "".

        The hostname is passed as its own argv element (no shell), so shell
        metacharacters in it cannot inject commands.  A missing or failing
        ssh-keyscan yields "".
        """
        import subprocess

        cmd = ["/usr/bin/ssh-keyscan", "-t", "rsa", host]
        try:
            with open(os.devnull, 'wb') as devnull:
                proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=devnull)
                output = proc.communicate()[0]
        except OSError:
            return ""
        if not isinstance(output, str):  # bytes on Python 3
            output = output.decode('ascii', 'replace')
        for line in output.splitlines():
            if line.strip() and not line.startswith('#'):
                return self._split_kh_entry(line)[2]
        return ""

    def updateDirect(self, host):
        """Replace host's entry with a key fetched directly from the host.

        The new entry goes into other_keys with comment "DIRECT".  If the scan
        returns no key, the existing entry is left untouched.
        """
        key = self._scan_key(host)
        if not key:
            print("ssh-keyscan returned no key for %s; entry left unchanged" % host)
            return
        node = plc.getNodes(host)
        (host2, ip2, _, _) = self._record_from_node(node[0])
        index = self._get_index(host2, ip2)
        self._remove_index(index)
        self.other_keys[index] = "%s %s" % (key, "DIRECT")

    def update(self, host):
        """Refresh host's entry from PLC, scanning the host directly as a fallback.

        The direct scan is used when PLC has no network address or no key
        for the node, instead of storing a "None None" record.
        """
        node = plc.getNodes(host)
        (host, ip, key, comment) = self._record_from_node(node[0])
        if ip is None or key is None:
            self.updateDirect(host)
        else:
            self.pl_keys[self._get_index(host, ip)] = "%s %s" % (key, comment)

    def _record_from_node(self, node, nokey_list=None):
        """Build (host, ip, key, comment) for a PLC node record.

        Returns (host, None, None, None) when PLC lists no network for the
        node, and (host, ip, None, None) when the node has no (or an empty)
        ssh_rsa_key; in the latter case node is appended to nokey_list if
        one is given.  Otherwise comment is "PlanetLab_<UTC timestamp>".
        """
        host = node['hostname']
        key = (node['ssh_rsa_key'] or "").strip()

        l_nw = plc.getNodeNetworks({'nodenetwork_id': node['nodenetwork_ids']})
        if not l_nw:
            # No network for this node, so there is no address to record.
            return (host, None, None, None)

        ip = l_nw[0]['ip']

        if not key:
            if nokey_list is not None:
                nokey_list.append(node)
            return (host, ip, None, None)

        # TODO: check for '==' at end of key.
        if not key.endswith('='):
            print("Host with corrupt key! for %s %s" % (node['boot_state'], node['hostname']))

        s_date = time.strftime("%Y/%m/%d_%H:%M:%S", time.gmtime())
        return (host, ip, key, "PlanetLab_%s" % s_date)
136
137
def main():
    """Refresh every PLC node's key in memory and report nodes lacking a key.

    Each node PLC reports without an ssh_rsa_key is printed as
    "<boot_state> <hostname>" (boot_state right-aligned to 5 columns).

    NOTE(review): the refreshed keys are never persisted (no write() call);
    confirm whether this is meant to be a report only.
    """
    known_hosts = SSHKnownHosts()
    missing_keys = known_hosts.updateAll()
    for missing in missing_keys:
        print("%5s %s" % (missing['boot_state'], missing['hostname']))
144         
if __name__ == '__main__':
    # Usage: getsshkeys.py HOSTNAME [HOSTNAME ...]
    # Refresh the known_hosts entry of each named node and save the file.
    # (main() refreshes all PLC nodes instead.)
    if len(sys.argv) < 2:
        sys.stderr.write("usage: %s hostname [hostname ...]\n" % sys.argv[0])
        sys.exit(2)
    k = SSHKnownHosts()
    for hostname in sys.argv[1:]:
        k.update(hostname)
    k.write()