keep the comment as it is
[playground.git] / omf_keys / get_slice_pub_keys.py
1 #!/usr/bin/env python
2 #
3 # plessh will run the given shell command on all Planetlab Europe
4 # nodes.
5 #
6
7 import sys
8 import os
9 import xmlrpclib
10 from threading import Thread
11 from getpass import getpass
12
13 PLC_HOST = "www.planet-lab.eu"
14 SLICE_NAME = ""
15 KEY_FILE = "~/.ssh/id_rsa"
16 BASE_CMD = 'ssh -q %(options)s -i %(key)s -l %(slice_name)s %(host)s "%(cmd)s" 2> /dev/null'
17 THREAD_COUNT = 10
18 SSH_OPTIONS = {
19     "BatchMode": "yes",
20     "StrictHostKeyChecking": "no",
21     "ConnectTimeout": 15,
22     "UserKnownHostsFile": "/dev/null",
23     "CheckHostIP": "no"
24     }
25
26 class API:
27     def __init__(self, hostname, username, password):
28         self.host = hostname
29         self.api = xmlrpclib.ServerProxy("https://%s/PLCAPI/" % self.host, allow_none=True)
30         self.auth = {
31             "Username": username,
32             "AuthString": password,
33             "AuthMethod": "password"
34             }
35         # this will raise an exception if not successful
36         self.api.AuthCheck(self.auth)
37
38     def wrap(self, function):
39         def wrapper(*args):
40             args = (self.auth, ) + args
41             return function(*args)
42         return wrapper
43
44     def __getattr__(self, attr):
45         return self.wrap(getattr(self.api, attr))
46
47
48 def getAPI(hostname, username=None, password=None):
49     if not username: username=raw_input("Please enter your PLC username: ")
50     if not password: password=getpass("Please enter your PLC password: ")
51     while True:
52         try:
53             return API(hostname=hostname, username=username, password=password)
54         except xmlrpclib.Fault,e:
55             print e
56             print "Please try again."
57             username=raw_input("Please enter your PLC username: ")
58             password=getpass("Please enter your PLC password: ")
59             continue
60
61
62 def change_key_comment(output, hostname):
63     fields = output.split()
64     new_fields = []
65
66     in_key = False
67     in_comment = False
68     for f in fields:
69         f = f.strip()
70         if f.startswith("ssh-"):
71             in_key = True
72         elif in_key:
73             in_key = False
74             in_comment = True
75         elif in_comment:
76             f = hostname
77         new_fields.append(f)
78
79     return " ".join(new_fields)
80     
81
82
83 class Command(Thread):
84     def __init__(self, cmd, hosts=[]):
85         Thread.__init__(self)
86         self.hosts = hosts
87         self.cmd = cmd
88         self.results = {}
89
90     def runCmd(self, host):
91         cmd = BASE_CMD % {'cmd': self.cmd,
92                           'host': host,
93                           'key': KEY_FILE,
94                           'slice_name': SLICE_NAME,
95                           'options':  " ".join(["-o %s=%s" % (o,v) for o,v in SSH_OPTIONS.items()])}
96         p = os.popen(cmd)
97         output = p.read().strip()
98         if output:
99             output = output.split('\n')[0] # only get the first line
100 # keep the comment for the keys
101 #            output = change_key_comment(output, host)
102             return output
103         return None
104
105     def run(self):
106         for host in self.hosts:
107             self.results[host] = self.runCmd(host)
108
109
110 def distributeHosts(num_hosts):
111     dist_list = []
112     host_per_thread = num_hosts / THREAD_COUNT
113     rest = num_hosts % THREAD_COUNT
114     for i in range(THREAD_COUNT):
115         c = host_per_thread
116         if rest:
117             c += 1
118             rest -= 1
119         dist_list.append(c)
120     return dist_list
121
122
123 def main(plchost, slice_name, print_report=lambda:None):
124     ple = getAPI(plchost)
125     slice_nodes = []
126     try:
127         node_ids = ple.GetSlices(slice_name)[0]['node_ids']
128         slice_nodes = [n['hostname'] for n in ple.GetNodes(node_ids)]
129     except IndexError:
130         print "Can not get nodes for slice: %s" % slice_name
131         sys.exit(1)
132
133     dist_list = distributeHosts(len(slice_nodes))
134     thread_list = []
135     index = 0
136     cmd = "cat /home/%s/.ssh/id_rsa.pub" % slice_name
137     print "Please wait gathering public keys from %d PlanetLab nodes. This may take a while..." % len(slice_nodes)
138     for i in range(THREAD_COUNT):
139         current = Command(cmd, slice_nodes[index:index+dist_list[i]])
140         index += dist_list[i]
141         thread_list.append(current)
142         current.start()
143     for i in thread_list:
144         i.join()
145
146     results = {}
147     for i in thread_list:
148         results.update(i.results)
149
150     print_report(results)
151
152 if __name__ == "__main__":
153     def save_report(results, fname="slice_keys"):
154         ok = 0
155         f = open(fname, "w")
156         for node in results.keys():
157             if results[node]:
158                 f.write("%s\n" % results[node])
159                 ok += 1
160         f.close()
161         print "Could get keys from %d of %d nodes" % (ok, len(results))
162         print "Please see %s file." % fname
163
164     def get_option(msg, default):
165         ret = raw_input("%s [%s] : " % (msg, default))
166         if not ret:
167             ret = default
168         return ret
169             
170     PLC_HOST=get_option("Please enter your PLC host", PLC_HOST)
171     SLICE_NAME=get_option("Please enter your slice name", SLICE_NAME)
172     KEY_FILE=get_option("Please enter the path of your key", KEY_FILE)
173     main(PLC_HOST, SLICE_NAME, print_report=save_report)
174