9f7b07fcbc5d3948021d6d771aed41689d774859
[bootmanager.git] / source / BootServerRequest.py
1 #!/usr/bin/python
2 #
3 # Copyright (c) 2003 Intel Corporation
4 # All rights reserved.
5 #
6 # Copyright (c) 2004-2006 The Trustees of Princeton University
7 # All rights reserved.
8
9 from __future__ import print_function
10
11 import sys
12 import os
13 import re
14 import string
15 import urllib
16 import tempfile
17
18 import pycurl
19
20 # if there is no cStringIO, fall back to the original
21 try:
22     from cStringIO import StringIO
23 except:
24     from StringIO import StringIO
25
26
27
class BootServerRequest:
    """Fetch files from, and upload data to, PlanetLab boot/monitor servers.

    On construction the known CD-ROM mount points are probed for a boot CD.
    When one is found, the CA certificates it carries for the configured
    boot and monitor servers are recorded so that requests can be made over
    verified HTTPS; otherwise the configured default boot server is used
    with no certificate.  All transfers go through pycurl.
    """

    # all possible places to check the cdrom mount point.
    # /mnt/cdrom is typically after the machine has come up,
    # and /usr is when the boot cd is running
    CDROM_MOUNT_PATH = ("/mnt/cdrom/", "/usr/")
    # Class-level defaults only: __init__ gives every instance its own maps
    # of server name -> CA cert path ("" meaning "no cert").
    BOOTSERVER_CERTS = {}
    MONITORSERVER_CERTS = {}
    BOOTCD_VERSION = ""
    HTTP_SUCCESS = 200
    HAS_BOOTCD = 0
    USE_PROXY = 0
    PROXY = 0

    # in seconds, maximum time allowed for the connection phase
    DEFAULT_CURL_CONNECT_TIMEOUT = 30
    # in seconds, maximum time allowed for any transfer
    DEFAULT_CURL_MAX_TRANSFER_TIME = 3600
    # Name of a curl executable.  NOTE(review): nothing in this class uses
    # it -- pycurl is imported unconditionally at module level -- so the
    # "fallback when pycurl is missing" it was meant for is not implemented.
    CURL_CMD = 'curl'

    # use TLSv1 and not SSLv3 anymore
    CURL_SSL_VERSION = pycurl.SSLVERSION_TLSv1

    def __init__(self, vars, verbose=0):
        """Probe for a boot CD and collect CA certs for the configured servers.

        vars    -- configuration mapping.  This method reads the keys
                   BOOTCD_VERSION_FILE, SERVER_CERT_DIR, CACERT_NAME,
                   BOOT_SERVER, MONITOR_SERVER and DEFAULT_BOOT_SERVER;
                   BOOTCD_VERSION_FILE and SERVER_CERT_DIR are format
                   templates taking a ``path`` (the mount point) argument.
                   CheckProxy() later reads PROXY_FILE.
        verbose -- when true, progress messages are printed to stdout.
        """
        self.VERBOSE = verbose
        self.VARS = vars

        # Per-instance maps.  These used to be mutated through the class
        # attributes, so every instance shared (and accumulated) entries.
        self.BOOTSERVER_CERTS = {}
        self.MONITORSERVER_CERTS = {}

        # see if we have a boot cd mounted by checking for the version file
        # if HAS_BOOTCD == 0 then either the machine doesn't have
        # a boot cd, or something else is mounted
        self.HAS_BOOTCD = 0

        for path in self.CDROM_MOUNT_PATH:
            self.Message("Checking existance of boot cd on {}".format(path))

            # best effort: the mount result is ignored, the version-file
            # check below decides whether a boot cd is present
            os.system("/bin/mount {} > /dev/null 2>&1".format(path))

            version_file = self.VARS['BOOTCD_VERSION_FILE'].format(path=path)
            self.Message("Looking for version file {}".format(version_file))

            if os.access(version_file, os.R_OK):
                self.Message("Found boot cd.")
                self.HAS_BOOTCD = 1
                break
            self.Message("No boot cd found.")

        if not self.HAS_BOOTCD:
            self.Message("Using default boot server address.")
            default_server = self.VARS['DEFAULT_BOOT_SERVER']
            self.BOOTSERVER_CERTS[default_server] = ""
            self.MONITORSERVER_CERTS[default_server] = ""
            return

        # check the version of the boot cd, and locate the certs
        # (`path` and `version_file` are those of the mount that matched)
        self.Message("Getting boot cd version.")

        versionRegExp = re.compile(r"PlanetLab BootCD v(\S+)")
        with open(version_file, "r") as version_f:
            line = version_f.readline().strip()

        match = versionRegExp.findall(line)
        if match:
            self.BOOTCD_VERSION = match[0]

        # right now, all the versions of the bootcd are supported,
        # so no need to check it

        self.Message("Getting server from configuration")

        cert_dir = self.VARS['SERVER_CERT_DIR'].format(path=path)
        self._add_server_certs(self.BOOTSERVER_CERTS,
                               [self.VARS['BOOT_SERVER']], cert_dir)
        self._add_server_certs(self.MONITORSERVER_CERTS,
                               [self.VARS['MONITOR_SERVER']], cert_dir)

        self.Message("Set of servers to contact: {}".format(self.BOOTSERVER_CERTS))
        self.Message("Set of servers to upload to: {}".format(self.MONITORSERVER_CERTS))

    def _add_server_certs(self, cert_map, servers, cert_dir):
        """Record in cert_map each server whose CA cert under cert_dir is readable."""
        for server in servers:
            server = server.strip()
            cacert_path = "{}/{}/{}".format(cert_dir, server,
                                            self.VARS['CACERT_NAME'])
            if os.access(cacert_path, os.R_OK):
                cert_map[server] = cacert_path

    def CheckProxy(self):
        """Load the proxy setting from VARS['PROXY_FILE'], if present.

        When that path is a readable regular file, USE_PROXY becomes 1 and
        PROXY the file's first line (whitespace-stripped); otherwise
        USE_PROXY is reset to 0.
        """
        self.USE_PROXY = 0
        self.Message("Checking existance of proxy config file...")

        proxy_file = self.VARS['PROXY_FILE']
        if os.access(proxy_file, os.R_OK) and os.path.isfile(proxy_file):
            # context manager: close the handle instead of relying on GC
            with open(proxy_file, 'r') as proxy_f:
                self.PROXY = proxy_f.readline().strip()
            self.USE_PROXY = 1
            self.Message("Using proxy {}.".format(self.PROXY))
        else:
            self.Message("Not using any proxy.")

    def Message(self, Msg):
        """Print Msg to stdout when the instance was created verbose."""
        if self.VERBOSE:
            print(Msg)

    def Error(self, Msg):
        """Write Msg to stderr, followed by a newline."""
        sys.stderr.write(Msg + "\n")

    def Warning(self, Msg):
        """Report Msg exactly like Error()."""
        self.Error(Msg)

    def MakeRequest(self, PartialPath, GetVars,
                    PostVars, DoSSL, DoCertCheck,
                    ConnectTimeout=DEFAULT_CURL_CONNECT_TIMEOUT,
                    MaxTransferTime=DEFAULT_CURL_MAX_TRANSFER_TIME,
                    FormData=None):
        """Like DownloadFile(), but return the response body.

        The body is staged in a private temporary file that is always
        removed before returning.  Returns the body as read from that file
        in binary mode, or None when DownloadFile() reports failure.
        """
        # mkstemp's first positional argument is a *suffix* and Python does
        # not expand "XXXXXX" templates, so pass an explicit prefix instead
        fd, buffer_name = tempfile.mkstemp(prefix="MakeRequest-")
        os.close(fd)

        try:
            ok = self.DownloadFile(PartialPath, GetVars, PostVars,
                                   DoSSL, DoCertCheck, buffer_name,
                                   ConnectTimeout,
                                   MaxTransferTime,
                                   FormData)
            # return the content only if the download was successful
            if not ok:
                return None
            with open(buffer_name, "rb") as result_f:
                return result_f.read()
        finally:
            # DownloadFile() does not delete the file; always clean up here,
            # even if it raised
            try:
                os.unlink(buffer_name)
            except OSError:
                pass

    def DownloadFile(self, PartialPath, GetVars, PostVars,
                     DoSSL, DoCertCheck, DestFilePath,
                     ConnectTimeout = DEFAULT_CURL_CONNECT_TIMEOUT,
                     MaxTransferTime = DEFAULT_CURL_MAX_TRANSFER_TIME,
                     FormData = None):
        """Fetch PartialPath into DestFilePath from the first server that answers 200.

        PartialPath     -- path appended to "http(s)://<server>/".
        GetVars         -- urlencoded into the query string; falsy for none.
        PostVars        -- urlencoded and sent as the POST body; falsy for none.
        DoSSL           -- use https:// instead of http://.
        DoCertCheck     -- with DoSSL, verify the server against the boot cd's
                           CA cert (requires a boot cd).
        DestFilePath    -- file the response body is written to; it is
                           truncated on every server attempt.
        ConnectTimeout  -- curl connect timeout in seconds (must be > 0).
        MaxTransferTime -- curl overall transfer limit in seconds.
        FormData        -- optional pycurl HTTPPOST list; when given, the
                           monitor servers are contacted instead of the
                           boot servers.

        Returns 1 on success, 0 on failure (details go to stderr).
        """
        self.Message("Attempting to retrieve {}".format(PartialPath))

        # we can't do ssl and check the cert if we don't have a bootcd
        if DoSSL and DoCertCheck and not self.HAS_BOOTCD:
            self.Error("No boot cd exists (needed to use -c and -s).\n")
            return 0

        # ConnectTimeout has to be greater than 0
        if ConnectTimeout <= 0:
            self.Error("Connect timeout must be greater than zero.\n")
            return 0

        self.CheckProxy()

        # setup the post and get vars for the request
        postdata = None
        if PostVars:
            postdata = urllib.urlencode(PostVars)
            self.Message("Posting data:\n{}\n".format(postdata))

        getstr = ""
        if GetVars:
            getstr = "?" + urllib.urlencode(GetVars)
            self.Message("Get data:\n{}\n".format(getstr))

        # uploads (form data) go to the monitor servers, all else to the
        # boot servers
        if FormData:
            cert_list = self.MONITORSERVER_CERTS
        else:
            cert_list = self.BOOTSERVER_CERTS

        for server, certpath in cert_list.items():
            self.Message("Contacting server {}.".format(server))

            # output what we are going to be doing
            self.Message("Connect timeout is {} seconds".format(ConnectTimeout))
            self.Message("Max transfer time is {} seconds".format(MaxTransferTime))

            if DoSSL:
                url = "https://{}/{}{}".format(server, PartialPath, getstr)
                if DoCertCheck:
                    self.Message("Using SSL version {} and verifying peer."
                                 .format(self.CURL_SSL_VERSION))
                else:
                    self.Message("Using SSL version {}."
                                 .format(self.CURL_SSL_VERSION))
            else:
                url = "http://{}/{}{}".format(server, PartialPath, getstr)

            self.Message("URL: {}".format(url))

            http_result = None
            curl = pycurl.Curl()
            try:
                self._configure_curl(curl, url, DoSSL, DoCertCheck, certpath,
                                     ConnectTimeout, MaxTransferTime,
                                     postdata, FormData)
                try:
                    http_result = self._fetch_to_file(curl, DestFilePath)
                except pycurl.error as err:
                    # pycurl.error args are (curl error code, message);
                    # use .args rather than unpacking the exception itself
                    args = err.args
                    curl_errno = args[0] if args else None
                    curl_errstr = args[1] if len(args) > 1 else str(err)
                    self.Error("connect to {} failed; curl error {}: '{}'\n"
                               .format(server, curl_errno, curl_errstr))
            finally:
                # previously skipped on errors, leaking one handle per attempt
                curl.close()

            if http_result is None:
                # transfer failed; try the next server
                continue

            self.Message("Results saved in {}".format(DestFilePath))

            # check the code, return 1 if successful
            if http_result == self.HTTP_SUCCESS:
                self.Message("Successfull!")
                return 1
            self.Message("Failure, resultant http code: {}"
                         .format(http_result))

        self.Error("Unable to successfully contact any boot servers.\n")
        return 0

    def _configure_curl(self, curl, url, do_ssl, do_cert_check, certpath,
                        connect_timeout, max_transfer_time, postdata,
                        form_data):
        """Apply timeout, proxy, TLS, POST and form options for one request."""
        # don't want curl sending any signals
        curl.setopt(pycurl.NOSIGNAL, 1)

        curl.setopt(pycurl.CONNECTTIMEOUT, connect_timeout)
        curl.setopt(pycurl.TIMEOUT, max_transfer_time)

        # do not follow location when attempting to download a file
        curl.setopt(pycurl.FOLLOWLOCATION, 0)

        if self.USE_PROXY:
            curl.setopt(pycurl.PROXY, self.PROXY)

        if do_ssl:
            curl.setopt(pycurl.SSLVERSION, self.CURL_SSL_VERSION)

            if do_cert_check:
                curl.setopt(pycurl.CAINFO, certpath)
                # SSL_VERIFYPEER is a boolean (the old value 2 simply meant
                # "on"); the name check is SSL_VERIFYHOST=2, set explicitly
                curl.setopt(pycurl.SSL_VERIFYPEER, 1)
                curl.setopt(pycurl.SSL_VERIFYHOST, 2)
            else:
                curl.setopt(pycurl.SSL_VERIFYPEER, 0)

        if postdata is not None:
            curl.setopt(pycurl.POSTFIELDS, postdata)

        # setup multipart/form-data upload
        if form_data:
            curl.setopt(pycurl.HTTPPOST, form_data)

        curl.setopt(pycurl.URL, url)

    def _fetch_to_file(self, curl, dest_path):
        """Perform the transfer into dest_path and return the HTTP status code."""
        with open(dest_path, "wb") as outfile:
            self.Message("Opened output file {}".format(dest_path))

            curl.setopt(pycurl.WRITEDATA, outfile)

            self.Message("Fetching...")
            curl.perform()
            self.Message("Done.")

            return curl.getinfo(pycurl.HTTP_CODE)
317
318
319
def usage():
    """Print the command-line help for this script to stdout."""
    # (flag, description) pairs; descriptions are aligned at column 23
    option_table = (
        ("-c/--checkcert", "Check SSL certs. Ignored if -s/--ssl missing."),
        ("-h/--help", "This help text"),
        ("-o/--output <file>", "Write result to file"),
        ("-s/--ssl", "Make the request over HTTPS"),
        ("-v", "Makes the operation more talkative"),
    )
    help_lines = ["", "Usage: BootServerRequest.py [options] <partialpath>",
                  "Options:"]
    help_lines.extend(" {:<22}{}".format(flag, description)
                      for flag, description in option_table)
    # trailing empty entry reproduces the blank line before print's newline
    help_lines.append("")
    print("\n".join(help_lines))
331
332
333
if __name__ == "__main__":
    import getopt

    ssl = 0
    checkcert = 0
    output_file = None
    verbose = 0

    # check out our command line options.  Only getopt's own error is
    # caught: the old bare `except:` also swallowed the SystemExit raised
    # by sys.exit() and any TypeError, turning them into usage errors.
    try:
        opt_list, arg_list = getopt.getopt(sys.argv[1:],
                                           "o:vhsc",
                                           ["output=", "verbose",
                                            "help", "ssl", "checkcert"])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    for opt, arg in opt_list:
        if opt in ("-h", "--help"):
            # usage() takes no arguments; the old usage(0) raised TypeError,
            # which made -h exit with status 2 after printing usage twice
            usage()
            sys.exit(0)

        if opt in ("-c", "--checkcert"):
            checkcert = 1

        if opt in ("-s", "--ssl"):
            ssl = 1

        if opt in ("-o", "--output"):
            output_file = arg

        # --verbose is registered with getopt above, so honour it as well
        if opt in ("-v", "--verbose"):
            verbose = 1

    # exactly one partial path is required, and it must not be a full URL
    if len(arg_list) != 1 or arg_list[0][:4].lower() == "http":
        usage()
        sys.exit(2)
    partialpath = arg_list[0]

    # got the command line args straightened out
    # FIXME(review): the constructor is BootServerRequest(vars, verbose=0)
    # and indexes `vars` with configuration keys, but this call passes the
    # verbose flag as `vars`, so construction fails.  The configuration
    # mapping is not available in this script and must be supplied.
    requestor = BootServerRequest(verbose)

    if output_file:
        ok = requestor.DownloadFile(partialpath, None, None, ssl,
                                    checkcert, output_file)
        # report failure through the exit status, like the branch below
        if not ok:
            sys.exit(1)
    else:
        result = requestor.MakeRequest(partialpath, None, None, ssl, checkcert)
        if result:
            print(result)
        else:
            sys.exit(1)