7b6531da33549dc6b147f10bf92045c3cd866c85
[bootmanager.git] / source / BootServerRequest.py
1 #!/usr/bin/python
2 #
3 # Copyright (c) 2003 Intel Corporation
4 # All rights reserved.
5 #
6 # Copyright (c) 2004-2006 The Trustees of Princeton University
7 # All rights reserved.
8
9 from __future__ import print_function
10
11 import os, sys
12 import re
13 import string
14 import urllib
15 import tempfile
16
17 # try to load pycurl
18 try:
19     import pycurl
20     PYCURL_LOADED = 1
21 except:
22     PYCURL_LOADED = 0
23
24
25 # if there is no cStringIO, fall back to the original
26 try:
27     from cStringIO import StringIO
28 except:
29     from StringIO import StringIO
30
31
32
class BootServerRequest:
    """Download files from, or upload form data to, a boot server.

    Requests go through pycurl when it could be imported (module-level
    PYCURL_LOADED) and otherwise through the external curl executable
    named by CURL_CMD.

    Keys read from the ``vars`` mapping passed to the constructor:
    BOOTCD_VERSION_FILE and SERVER_CERT_DIR (templates formatted with
    ``path=<mount point>``), CACERT_NAME, BOOT_SERVER, MONITOR_SERVER,
    DEFAULT_BOOT_SERVER and PROXY_FILE.
    """

    VERBOSE = 0

    # all possible places to check the cdrom mount point.
    # /mnt/cdrom is typically after the machine has come up,
    # and /usr is when the boot cd is running
    CDROM_MOUNT_PATH = ("/mnt/cdrom/", "/usr/")
    # Class-level defaults only: __init__ gives every instance its own
    # dicts so servers recorded by one instance never leak into another.
    BOOTSERVER_CERTS = {}
    MONITORSERVER_CERTS = {}
    BOOTCD_VERSION = ""
    HTTP_SUCCESS = 200
    HAS_BOOTCD = 0
    USE_PROXY = 0
    PROXY = 0

    # in seconds, maximum time allowed for connect
    DEFAULT_CURL_CONNECT_TIMEOUT = 30
    # in seconds, maximum time allowed for any transfer
    DEFAULT_CURL_MAX_TRANSFER_TIME = 3600
    # location of curl executable, if pycurl isn't available
    # and the DownloadFile method is called (backup, only
    # really needed for the boot cd environment where pycurl
    # doesn't exist)
    CURL_CMD = 'curl'
    # NOTE(review): 3 is passed to pycurl.SSLVERSION and becomes `--sslv3`
    # for the curl executable, i.e. SSLv3, which current curl builds
    # reject -- confirm the boot servers still need it.
    CURL_SSL_VERSION = 3

    def __init__(self, vars, verbose=0):
        """Look for a boot cd and record which servers (and CA certs) to use.

        vars    -- configuration mapping (see the class docstring for keys)
        verbose -- when true, Message() prints progress to stdout

        Each CDROM_MOUNT_PATH is mounted (best effort) until a readable
        BOOTCD_VERSION_FILE is found.  With a boot cd, its version is read
        and the CA certs of BOOT_SERVER / MONITOR_SERVER are located under
        SERVER_CERT_DIR; without one, DEFAULT_BOOT_SERVER is used for both
        with an empty cert path.
        """
        self.VERBOSE = verbose
        self.VARS = vars

        # Fresh per-instance dicts; the class-level ones would otherwise be
        # shared and mutated by every instance ever created.
        self.BOOTSERVER_CERTS = {}
        self.MONITORSERVER_CERTS = {}

        # see if we have a boot cd mounted by checking for the version file
        # if HAS_BOOTCD == 0 then either the machine doesn't have
        # a boot cd, or something else is mounted
        self.HAS_BOOTCD = 0

        for path in self.CDROM_MOUNT_PATH:
            self.Message("Checking existance of boot cd on {}".format(path))

            # output and failures are discarded; the version file check
            # below decides whether a boot cd is present
            os.system("/bin/mount {} > /dev/null 2>&1".format(path))

            version_file = self.VARS['BOOTCD_VERSION_FILE'].format(path=path)
            self.Message("Looking for version file {}".format(version_file))

            if os.access(version_file, os.R_OK):
                self.Message("Found boot cd.")
                self.HAS_BOOTCD = 1
                break
            self.Message("No boot cd found.")

        if self.HAS_BOOTCD:
            # check the version of the boot cd, and locate the certs
            self.Message("Getting boot cd version.")

            with open(version_file, "r") as version_f:
                line = version_f.readline().strip()

            match = re.findall(r"PlanetLab BootCD v(\S+)", line)
            if match:
                self.BOOTCD_VERSION = match[0]

            # right now, all the versions of the bootcd are supported,
            # so no need to check it

            self.Message("Getting server from configuration")

            # `path` is the mount point on which the boot cd was found
            cert_dir = self.VARS['SERVER_CERT_DIR'].format(path=path)
            self._AddServerCerts(self.BOOTSERVER_CERTS,
                                 [self.VARS['BOOT_SERVER']], cert_dir)
            self._AddServerCerts(self.MONITORSERVER_CERTS,
                                 [self.VARS['MONITOR_SERVER']], cert_dir)

            self.Message("Set of servers to contact: {}".format(self.BOOTSERVER_CERTS))
            self.Message("Set of servers to upload to: {}".format(self.MONITORSERVER_CERTS))
        else:
            self.Message("Using default boot server address.")
            self.BOOTSERVER_CERTS[self.VARS['DEFAULT_BOOT_SERVER']] = ""
            self.MONITORSERVER_CERTS[self.VARS['DEFAULT_BOOT_SERVER']] = ""

    def _AddServerCerts(self, certs, servers, cert_dir):
        """Map each server name to <cert_dir>/<server>/<CACERT_NAME> in certs.

        Servers whose CA cert file is not readable are left out.
        """
        for server in servers:
            server = server.strip()
            cacert_path = "{}/{}/{}".format(cert_dir, server,
                                            self.VARS['CACERT_NAME'])
            if os.access(cacert_path, os.R_OK):
                certs[server] = cacert_path

    def CheckProxy(self):
        """Load proxy settings from VARS['PROXY_FILE'].

        When that path is a readable regular file, USE_PROXY becomes 1 and
        PROXY its first line (whitespace-stripped); otherwise USE_PROXY is 0.
        """
        self.USE_PROXY = 0
        self.Message("Checking existance of proxy config file...")

        proxy_file = self.VARS['PROXY_FILE']
        if os.access(proxy_file, os.R_OK) and os.path.isfile(proxy_file):
            # the with-block closes the handle the old one-liner leaked
            with open(proxy_file, 'r') as proxy_f:
                self.PROXY = proxy_f.readline().strip()
            self.USE_PROXY = 1
            self.Message("Using proxy {}.".format(self.PROXY))
        else:
            self.Message("Not using any proxy.")

    def Message(self, Msg):
        """Print Msg to stdout when verbose output is enabled."""
        if self.VERBOSE:
            print(Msg)

    def Error(self, Msg):
        """Write Msg followed by a newline to stderr."""
        sys.stderr.write(Msg + "\n")

    def Warning(self, Msg):
        """Report a warning; currently identical to Error()."""
        self.Error(Msg)

    def _RemoveFile(self, path):
        """Delete path, ignoring OSError (e.g. when it is already gone)."""
        try:
            os.unlink(path)
        except OSError:
            pass

    def MakeRequest(self, PartialPath, GetVars,
                    PostVars, DoSSL, DoCertCheck,
                    ConnectTimeout = DEFAULT_CURL_CONNECT_TIMEOUT,
                    MaxTransferTime = DEFAULT_CURL_MAX_TRANSFER_TIME,
                    FormData = None):
        """Perform a request and return the response body, or None on failure.

        Arguments are those of DownloadFile(); the body is downloaded into
        a private temporary file that is always removed before returning.
        """
        fd, buffer_name = tempfile.mkstemp(prefix="MakeRequest-")
        os.close(fd)

        try:
            ok = self.DownloadFile(PartialPath, GetVars, PostVars,
                                   DoSSL, DoCertCheck, buffer_name,
                                   ConnectTimeout,
                                   MaxTransferTime,
                                   FormData)

            # return the body only if the download was successful
            if not ok:
                return None

            # Open only now: DownloadFile deletes and recreates buffer_name
            # when it moves on to another server, so a handle opened earlier
            # could point at a stale, empty file.
            try:
                with open(buffer_name, "rb") as buffer:
                    return buffer.read()
            except IOError:
                # success but no file left behind -- presumably an empty
                # body; report it as empty rather than crashing
                return b""
        finally:
            # DownloadFile removes the file on failure; this covers success
            # and any exception raised along the way
            self._RemoveFile(buffer_name)

    def DownloadFile(self, PartialPath, GetVars, PostVars,
                     DoSSL, DoCertCheck, DestFilePath,
                     ConnectTimeout = DEFAULT_CURL_CONNECT_TIMEOUT,
                     MaxTransferTime = DEFAULT_CURL_MAX_TRANSFER_TIME,
                     FormData = None):
        """Fetch PartialPath from the first server that succeeds into DestFilePath.

        PartialPath     -- path appended to "http(s)://<server>/"
        GetVars         -- urlencoded into the query string, if truthy
        PostVars        -- urlencoded as the POST body, if truthy
        DoSSL           -- use https instead of http
        DoCertCheck     -- verify the server's CA cert (with DoSSL this
                           requires a boot cd)
        DestFilePath    -- file the response body is written to
        ConnectTimeout  -- seconds allowed for connecting; must be > 0
        MaxTransferTime -- seconds allowed for the whole transfer
        FormData        -- multipart form data; when given, the monitor
                           servers are contacted instead of the boot servers

        Returns 1 on success (HTTP 200 with pycurl, curl exit status 0
        otherwise) and 0 on failure.  Each failed server attempt removes
        DestFilePath.
        """
        self.Message("Attempting to retrieve {}".format(PartialPath))

        # we can't do ssl and check the cert if we don't have a bootcd
        if DoSSL and DoCertCheck and not self.HAS_BOOTCD:
            self.Error("No boot cd exists (needed to use -c and -s.\n")
            return 0

        if DoSSL and not PYCURL_LOADED:
            self.Warning("Using SSL without pycurl will by default "
                         "check at least standard certs.")

        # ConnectTimeout has to be greater than 0
        if ConnectTimeout <= 0:
            self.Error("Connect timeout must be greater than zero.\n")
            return 0

        self.CheckProxy()

        try:
            from urllib import urlencode          # Python 2
        except ImportError:
            from urllib.parse import urlencode    # Python 3

        # setup the post and get vars for the request
        postdata = None
        if PostVars:
            postdata = urlencode(PostVars)
            self.Message("Posting data:\n{}\n".format(postdata))

        getstr = ""
        if GetVars:
            getstr = "?" + urlencode(GetVars)
            self.Message("Get data:\n{}\n".format(getstr))

        # form uploads go to the monitor servers, everything else to the
        # boot servers
        if FormData:
            cert_list = self.MONITORSERVER_CERTS
        else:
            cert_list = self.BOOTSERVER_CERTS

        # try each server in turn until one succeeds
        for server, certpath in cert_list.items():
            self.Message("Contacting server {}.".format(server))

            # output what we are going to be doing
            self.Message("Connect timeout is {} seconds".format(ConnectTimeout))
            self.Message("Max transfer time is {} seconds".format(MaxTransferTime))

            if DoSSL:
                url = "https://{}/{}{}".format(server, PartialPath, getstr)

                if DoCertCheck and PYCURL_LOADED:
                    self.Message("Using SSL version {} and verifying peer."
                                 .format(self.CURL_SSL_VERSION))
                else:
                    self.Message("Using SSL version {}."
                                 .format(self.CURL_SSL_VERSION))
            else:
                url = "http://{}/{}{}".format(server, PartialPath, getstr)

            self.Message("URL: {}".format(url))

            if PYCURL_LOADED:
                ok = self._FetchWithPycurl(url, server, certpath, DoSSL,
                                           DoCertCheck, DestFilePath,
                                           ConnectTimeout, MaxTransferTime,
                                           postdata, FormData)
            else:
                ok = self._FetchWithCurlCmd(url, certpath, DoSSL,
                                            DoCertCheck, DestFilePath,
                                            ConnectTimeout, MaxTransferTime,
                                            postdata, FormData)
            if ok:
                return 1

        self.Error("Unable to successfully contact any boot servers.\n")
        return 0

    def _FetchWithPycurl(self, url, server, certpath, DoSSL, DoCertCheck,
                         DestFilePath, ConnectTimeout, MaxTransferTime,
                         postdata, FormData):
        """Fetch url into DestFilePath with pycurl; return 1 on HTTP 200, else 0.

        A curl error or a non-200 status removes DestFilePath.  Errors
        raised while configuring the handle or opening the file propagate.
        """
        curl = pycurl.Curl()
        try:
            # don't want curl sending any signals
            curl.setopt(pycurl.NOSIGNAL, 1)

            curl.setopt(pycurl.CONNECTTIMEOUT, ConnectTimeout)
            curl.setopt(pycurl.TIMEOUT, MaxTransferTime)

            # do not follow location when attempting to download a file
            curl.setopt(pycurl.FOLLOWLOCATION, 0)

            if self.USE_PROXY:
                curl.setopt(pycurl.PROXY, self.PROXY)

            if DoSSL:
                curl.setopt(pycurl.SSLVERSION, self.CURL_SSL_VERSION)

                if DoCertCheck:
                    curl.setopt(pycurl.CAINFO, certpath)
                    # SSL_VERIFYPEER is an on/off switch; 2 is the value
                    # that belongs to SSL_VERIFYHOST (check the host name)
                    curl.setopt(pycurl.SSL_VERIFYPEER, 1)
                    curl.setopt(pycurl.SSL_VERIFYHOST, 2)
                else:
                    curl.setopt(pycurl.SSL_VERIFYPEER, 0)

            if postdata is not None:
                curl.setopt(pycurl.POSTFIELDS, postdata)

            # setup multipart/form-data upload
            if FormData:
                curl.setopt(pycurl.HTTPPOST, FormData)

            curl.setopt(pycurl.URL, url)

            try:
                # setup the output file
                with open(DestFilePath, "wb") as outfile:
                    self.Message("Opened output file {}".format(DestFilePath))

                    curl.setopt(pycurl.WRITEDATA, outfile)

                    self.Message("Fetching...")
                    curl.perform()
                    self.Message("Done.")

                http_result = curl.getinfo(pycurl.HTTP_CODE)
            except pycurl.error as err:
                code, errstr = err.args
                self.Error("connect to {} failed; curl error {}: '{}'\n"
                           .format(server, code, errstr))
                self._RemoveFile(DestFilePath)
                return 0
        finally:
            # release the handle on every path, including errors
            curl.close()

        self.Message("Results saved in {}".format(DestFilePath))

        # check the code, return 1 if successful
        if http_result == self.HTTP_SUCCESS:
            self.Message("Successfull!")
            return 1

        self.Message("Failure, resultant http code: {}".format(http_result))
        # the body of an error response is not the requested file
        self._RemoveFile(DestFilePath)
        return 0

    def _CurlCmdArgv(self, url, certpath, DoSSL, DoCertCheck, DestFilePath,
                     ConnectTimeout, MaxTransferTime, postdata, FormData):
        """Return the argument list used to run the curl executable.

        Every value is its own argv element, so characters such as '&' in
        the query string or quotes in form fields reach curl verbatim instead
        of being interpreted by a shell.
        """
        argv = [self.CURL_CMD,
                "--connect-timeout", str(ConnectTimeout),
                "--max-time", str(MaxTransferTime),
                # an empty Pragma header tells curl not to send one
                "--header", "Pragma:",
                "--output", DestFilePath,
                "--fail"]

        if postdata is not None:
            argv += ["--data", postdata]

        if FormData:
            # NOTE(review): assumes string fields like "name=value"; the
            # pycurl branch hands the same FormData to HTTPPOST -- confirm
            # callers use a format both branches accept.
            for field in FormData:
                argv += ["--form", field]

        if not self.VERBOSE:
            argv.append("--silent")

        if self.USE_PROXY:
            argv += ["--proxy", str(self.PROXY)]

        if DoSSL:
            argv.append("--sslv{}".format(self.CURL_SSL_VERSION))
            if DoCertCheck:
                argv += ["--cacert", certpath]

        argv.append(url)
        return argv

    def _FetchWithCurlCmd(self, url, certpath, DoSSL, DoCertCheck,
                          DestFilePath, ConnectTimeout, MaxTransferTime,
                          postdata, FormData):
        """Fetch url into DestFilePath by running curl; return 1 on exit status 0.

        On failure DestFilePath is removed and 0 is returned.
        """
        import subprocess

        argv = self._CurlCmdArgv(url, certpath, DoSSL, DoCertCheck,
                                 DestFilePath, ConnectTimeout,
                                 MaxTransferTime, postdata, FormData)
        self.Message("curl command: {}".format(" ".join(argv)))

        self.Message("Fetching...")
        try:
            # no shell involved: argv is passed to curl as-is
            rc = subprocess.call(argv)
        except OSError as err:
            # e.g. the curl executable is missing; treat as a failed run
            self.Error("Unable to run {}: {}\n".format(self.CURL_CMD, err))
            rc = -1
        self.Message("Done.")

        if rc != 0:
            self._RemoveFile(DestFilePath)
            self.Message("Failure, resultant curl code: {}".format(rc))
            self.Message("Removed {}".format(DestFilePath))
            return 0

        self.Message("Successfull!")
        return 1
386         return 0
387
388
389
390
def usage():
    """Write the command-line synopsis and option list to stdout."""
    help_text = """
Usage: BootServerRequest.py [options] <partialpath>
Options:
 -c/--checkcert        Check SSL certs. Ignored if -s/--ssl missing.
 -h/--help             This help text
 -o/--output <file>    Write result to file
 -s/--ssl              Make the request over HTTPS
 -v                    Makes the operation more talkative
"""
    print(help_text)
402
403
404
if __name__ == "__main__":
    import getopt

    # check out our command line options; only a getopt failure means
    # "bad usage" -- a bare except here used to swallow SystemExit too
    try:
        opt_list, arg_list = getopt.getopt(sys.argv[1:],
                                           "o:vhsc",
                                           ["output=", "verbose",
                                            "help", "ssl", "checkcert"])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    ssl = 0
    checkcert = 0
    output_file = None
    verbose = 0

    for opt, arg in opt_list:
        if opt in ("-h", "--help"):
            # usage() takes no arguments; help is a successful exit
            usage()
            sys.exit()

        if opt in ("-c", "--checkcert"):
            checkcert = 1

        if opt in ("-s", "--ssl"):
            ssl = 1

        if opt in ("-o", "--output"):
            output_file = arg

        # --verbose is registered with getopt, so honour it as well
        if opt in ("-v", "--verbose"):
            verbose = 1

    # exactly one partial path is required, and it must not be a full URL
    if len(arg_list) != 1 or arg_list[0][:4].lower() == "http":
        usage()
        sys.exit(2)

    partialpath = arg_list[0]

    # got the command line args straightened out
    # NOTE(review): BootServerRequest.__init__(vars, verbose) indexes its
    # first argument as a configuration mapping, so passing `verbose` here
    # cannot work; the CLI needs a source for that mapping -- TODO.
    requestor = BootServerRequest(verbose)

    if output_file:
        ok = requestor.DownloadFile(partialpath, None, None, ssl,
                                    checkcert, output_file)
        # report failure through the exit status, as the stdout path does
        if not ok:
            sys.exit(1)
    else:
        result = requestor.MakeRequest(partialpath, None, None, ssl, checkcert)
        if result:
            print(result)
        else:
            sys.exit(1)