remove old code that was filling in for when pycurl is not available
[bootmanager.git] / source / BootServerRequest.py
1 #!/usr/bin/python
2 #
3 # Copyright (c) 2003 Intel Corporation
4 # All rights reserved.
5 #
6 # Copyright (c) 2004-2006 The Trustees of Princeton University
7 # All rights reserved.
8
9 from __future__ import print_function
10
11 import os, sys
12 import re
13 import string
14 import urllib
15 import tempfile
16
17 import pycurl
18
19 # if there is no cStringIO, fall back to the original
20 try:
21     from cStringIO import StringIO
22 except:
23     from StringIO import StringIO
24
25
26
class BootServerRequest:
    """Issue HTTP(S) requests to the configured boot / monitor servers.

    The constructor probes the candidate cdrom mount points for a boot cd
    version file.  When one is found, the CA certificates for the servers
    named in the configuration are looked up below that mount point;
    otherwise the default boot server is used with no certificate.

    The configuration mapping passed as ``vars`` is read for these keys:
    'BOOTCD_VERSION_FILE' and 'SERVER_CERT_DIR' (both formatted with a
    ``path`` keyword), 'CACERT_NAME', 'BOOT_SERVER', 'MONITOR_SERVER',
    'DEFAULT_BOOT_SERVER' and 'PROXY_FILE'.
    """

    # all possible places to check the cdrom mount point.
    # /mnt/cdrom is typically after the machine has come up,
    # and /usr is when the boot cd is running
    CDROM_MOUNT_PATH = ("/mnt/cdrom/", "/usr/")
    # Class-level defaults only.  __init__ rebinds both names to fresh
    # per-instance dicts so instances never share or accumulate entries.
    BOOTSERVER_CERTS = {}
    MONITORSERVER_CERTS = {}
    BOOTCD_VERSION = ""
    HTTP_SUCCESS = 200
    HAS_BOOTCD = 0
    USE_PROXY = 0
    PROXY = 0

    # in seconds, maximum time allowed for the connect phase
    DEFAULT_CURL_CONNECT_TIMEOUT = 30
    # in seconds, maximum time allowed for any transfer
    DEFAULT_CURL_MAX_TRANSFER_TIME = 3600
    # name of the curl executable; not referenced by any method of this
    # class, kept so code that reads the attribute keeps working
    CURL_CMD = 'curl'

    # use TLSv1 and not SSLv3 anymore
    CURL_SSL_VERSION = pycurl.SSLVERSION_TLSv1

    def __init__(self, vars, verbose=0):
        """Detect a boot cd and build the server -> CA certificate tables.

        vars    -- configuration mapping (see the class docstring for keys)
        verbose -- when true, Message() prints progress to stdout
        """
        self.VERBOSE = verbose
        self.VARS = vars

        # Fresh dicts per instance: mutating the class-level dicts would
        # leak servers from one instance into every other one.
        self.BOOTSERVER_CERTS = {}
        self.MONITORSERVER_CERTS = {}

        # see if we have a boot cd mounted by checking for the version file
        # if HAS_BOOTCD == 0 then either the machine doesn't have
        # a boot cd, or something else is mounted
        self.HAS_BOOTCD = 0
        bootcd_path = None
        version_file = None

        for path in self.CDROM_MOUNT_PATH:
            self.Message("Checking existance of boot cd on {}".format(path))

            # best effort: the exit status of mount is deliberately ignored
            os.system("/bin/mount {} > /dev/null 2>&1".format(path))

            version_file = self.VARS['BOOTCD_VERSION_FILE'].format(path=path)
            self.Message("Looking for version file {}".format(version_file))

            if os.access(version_file, os.R_OK):
                self.Message("Found boot cd.")
                self.HAS_BOOTCD = 1
                bootcd_path = path
                break
            self.Message("No boot cd found.")

        if not self.HAS_BOOTCD:
            self.Message("Using default boot server address.")
            default_server = self.VARS['DEFAULT_BOOT_SERVER']
            # no certificate is available without a boot cd
            self.BOOTSERVER_CERTS[default_server] = ""
            self.MONITORSERVER_CERTS[default_server] = ""
            return

        # check the version of the boot cd, and locate the certs
        self.Message("Getting boot cd version.")

        with open(version_file, "r") as bootcd_version_f:
            line = bootcd_version_f.readline().strip()

        match = re.search(r"PlanetLab BootCD v(\S+)", line)
        if match:
            self.BOOTCD_VERSION = match.group(1)

        # right now, all the versions of the bootcd are supported,
        # so no need to check it

        self.Message("Getting server from configuration")

        self.BOOTSERVER_CERTS.update(
            self._FindServerCerts([self.VARS['BOOT_SERVER']], bootcd_path))
        self.MONITORSERVER_CERTS.update(
            self._FindServerCerts([self.VARS['MONITOR_SERVER']], bootcd_path))

        self.Message("Set of servers to contact: {}".format(self.BOOTSERVER_CERTS))
        self.Message("Set of servers to upload to: {}".format(self.MONITORSERVER_CERTS))

    def _FindServerCerts(self, servers, mount_path):
        """Return {server: cacert_path} for servers with a readable CA cert.

        servers    -- iterable of server names (surrounding whitespace is
                      stripped before use)
        mount_path -- boot cd mount point substituted into SERVER_CERT_DIR

        Servers whose certificate file is not readable are left out.
        """
        cert_dir = self.VARS['SERVER_CERT_DIR'].format(path=mount_path)
        certs = {}
        for server in servers:
            server = server.strip()
            cacert_path = "{}/{}/{}".format(
                cert_dir, server, self.VARS['CACERT_NAME'])
            if os.access(cacert_path, os.R_OK):
                certs[server] = cacert_path
        return certs

    def CheckProxy(self):
        """Refresh USE_PROXY / PROXY from the file named by VARS['PROXY_FILE'].

        The proxy is taken from the first line of that file when it is a
        readable regular file; otherwise USE_PROXY is left at 0.
        """
        # see if we have any proxy info from the machine
        self.USE_PROXY = 0
        self.Message("Checking existance of proxy config file...")

        proxy_file = self.VARS['PROXY_FILE']
        if os.access(proxy_file, os.R_OK) and os.path.isfile(proxy_file):
            # 'with' guarantees the handle is closed (it used to be leaked)
            with open(proxy_file, 'r') as proxy_f:
                self.PROXY = proxy_f.readline().strip()
            self.USE_PROXY = 1
            self.Message("Using proxy {}.".format(self.PROXY))
        else:
            self.Message("Not using any proxy.")

    def Message(self, Msg):
        """Print Msg to stdout, but only in verbose mode."""
        if self.VERBOSE:
            print(Msg)

    def Error(self, Msg):
        """Write Msg followed by a newline to stderr (always)."""
        sys.stderr.write(Msg + "\n")

    def Warning(self, Msg):
        """Report a warning; currently handled exactly like Error()."""
        self.Error(Msg)

    def MakeRequest(self, PartialPath, GetVars,
                    PostVars, DoSSL, DoCertCheck,
                    ConnectTimeout = DEFAULT_CURL_CONNECT_TIMEOUT,
                    MaxTransferTime = DEFAULT_CURL_MAX_TRANSFER_TIME,
                    FormData = None):
        """Perform a request and return the response body.

        Takes the same arguments as DownloadFile(), minus DestFilePath: the
        body is downloaded into a temporary file, read back and returned.

        Returns the body contents on success, or None when DownloadFile()
        reports failure.  The temporary file is always removed, including
        when DownloadFile() raises.
        """
        fd, buffer_name = tempfile.mkstemp(prefix="MakeRequest-")
        # DownloadFile() reopens the file by name, so only the path is needed
        os.close(fd)

        try:
            ok = self.DownloadFile(PartialPath, GetVars, PostVars,
                                   DoSSL, DoCertCheck, buffer_name,
                                   ConnectTimeout,
                                   MaxTransferTime,
                                   FormData)

            # return the contents only if the download was successful
            if not ok:
                return None

            with open(buffer_name, "rb") as response:
                return response.read()
        finally:
            try:
                os.unlink(buffer_name)
            except OSError:
                # already gone; nothing left to clean up
                pass

    def _ConfigureCurl(self, curl, url, certpath, DoSSL, DoCertCheck,
                       ConnectTimeout, MaxTransferTime, postdata, FormData):
        """Apply every transfer option except the output target to curl.

        postdata is the urlencoded POST body or None; FormData, when true,
        is handed to pycurl.HTTPPOST for a multipart/form-data upload.
        """
        # don't want curl sending any signals
        curl.setopt(pycurl.NOSIGNAL, 1)

        curl.setopt(pycurl.CONNECTTIMEOUT, ConnectTimeout)
        curl.setopt(pycurl.TIMEOUT, MaxTransferTime)

        # do not follow location when attempting to download a file
        curl.setopt(pycurl.FOLLOWLOCATION, 0)

        if self.USE_PROXY:
            curl.setopt(pycurl.PROXY, self.PROXY)

        if DoSSL:
            curl.setopt(pycurl.SSLVERSION, self.CURL_SSL_VERSION)

            if DoCertCheck:
                curl.setopt(pycurl.CAINFO, certpath)
                curl.setopt(pycurl.SSL_VERIFYPEER, 2)
            else:
                curl.setopt(pycurl.SSL_VERIFYPEER, 0)

        if postdata is not None:
            curl.setopt(pycurl.POSTFIELDS, postdata)

        # setup multipart/form-data upload
        if FormData:
            curl.setopt(pycurl.HTTPPOST, FormData)

        curl.setopt(pycurl.URL, url)

    def DownloadFile(self, PartialPath, GetVars, PostVars,
                     DoSSL, DoCertCheck, DestFilePath,
                     ConnectTimeout = DEFAULT_CURL_CONNECT_TIMEOUT,
                     MaxTransferTime = DEFAULT_CURL_MAX_TRANSFER_TIME,
                     FormData = None):
        """Fetch PartialPath from the known servers into DestFilePath.

        PartialPath     -- path appended to "http(s)://<server>/"
        GetVars         -- mapping urlencoded into the query string, or None
        PostVars        -- mapping urlencoded into the POST body, or None
        DoSSL           -- use https instead of http
        DoCertCheck     -- verify the peer against the server's CA cert
                           (requires a boot cd when combined with DoSSL)
        DestFilePath    -- output file; truncated and rewritten for every
                           server that is tried
        ConnectTimeout  -- seconds allowed for connecting, must be > 0
        MaxTransferTime -- seconds allowed for the whole transfer
        FormData        -- multipart form data; when true the monitor
                           servers are contacted instead of the boot servers

        Each server is tried in turn until one answers with HTTP 200.
        Returns 1 on success and 0 otherwise.  pycurl errors raised while
        transferring are reported through Error() and the next server is
        tried.
        """
        self.Message("Attempting to retrieve {}".format(PartialPath))

        # we can't do ssl and check the cert if we don't have a bootcd
        if DoSSL and DoCertCheck and not self.HAS_BOOTCD:
            self.Error("No boot cd exists (needed to use -c and -s.\n")
            return 0

        # ConnectTimeout has to be greater than 0
        if ConnectTimeout <= 0:
            self.Error("Connect timeout must be greater than zero.\n")
            return 0

        self.CheckProxy()

        # setup the post and get vars for the request
        postdata = None
        if PostVars:
            postdata = urllib.urlencode(PostVars)
            self.Message("Posting data:\n{}\n".format(postdata))

        getstr = ""
        if GetVars:
            getstr = "?" + urllib.urlencode(GetVars)
            self.Message("Get data:\n{}\n".format(getstr))

        # now, attempt to make the request, starting at the first
        # server in the list
        if FormData:
            cert_list = self.MONITORSERVER_CERTS
        else:
            cert_list = self.BOOTSERVER_CERTS

        scheme = "https" if DoSSL else "http"

        for server, certpath in cert_list.items():
            self.Message("Contacting server {}.".format(server))

            # output what we are going to be doing
            self.Message("Connect timeout is {} seconds".format(ConnectTimeout))
            self.Message("Max transfer time is {} seconds".format(MaxTransferTime))

            if DoSSL:
                if DoCertCheck:
                    self.Message("Using SSL version {} and verifying peer."
                                 .format(self.CURL_SSL_VERSION))
                else:
                    self.Message("Using SSL version {}."
                                 .format(self.CURL_SSL_VERSION))

            url = "{}://{}/{}{}".format(scheme, server, PartialPath, getstr)
            self.Message("URL: {}".format(url))

            # setup a new pycurl instance
            curl = pycurl.Curl()
            try:
                self._ConfigureCurl(curl, url, certpath, DoSSL, DoCertCheck,
                                    ConnectTimeout, MaxTransferTime,
                                    postdata, FormData)

                try:
                    # setup the output file
                    with open(DestFilePath, "wb") as outfile:
                        self.Message("Opened output file {}".format(DestFilePath))

                        curl.setopt(pycurl.WRITEDATA, outfile)

                        self.Message("Fetching...")
                        curl.perform()
                        self.Message("Done.")

                        http_result = curl.getinfo(pycurl.HTTP_CODE)

                    self.Message("Results saved in {}".format(DestFilePath))

                    # check the code, return 1 if successful
                    if http_result == self.HTTP_SUCCESS:
                        self.Message("Successfull!")
                        return 1

                    self.Message("Failure, resultant http code: {}"
                                 .format(http_result))

                except pycurl.error as err:
                    # read the (code, message) pair through .args rather
                    # than unpacking the exception object itself, which
                    # only works on Python 2
                    args = err.args
                    code = args[0] if args else None
                    errstr = args[1] if len(args) > 1 else ""
                    self.Error("connect to {} failed; curl error {}: '{}'\n"
                               .format(server, code, errstr))
            finally:
                # release the handle on every path (success, failure, error)
                curl.close()

        self.Error("Unable to successfully contact any boot servers.\n")
        return 0
316
317
318
def usage():
    """Print the command-line synopsis and the option summary to stdout."""
    help_lines = (
        "",
        "Usage: BootServerRequest.py [options] <partialpath>",
        "Options:",
        " -c/--checkcert        Check SSL certs. Ignored if -s/--ssl missing.",
        " -h/--help             This help text",
        " -o/--output <file>    Write result to file",
        " -s/--ssl              Make the request over HTTPS",
        " -v                    Makes the operation more talkative",
        "",
    )
    # the leading and trailing empty entries reproduce the blank line
    # before the text and the newline after its last line
    print("\n".join(help_lines))
330
331
332
if __name__ == "__main__":
    # Command-line entry point: fetch <partialpath> from the boot servers
    # and either save it to a file (-o) or print it to stdout.
    #
    # Exit status: 0 on success (and for -h/--help), 1 when the request
    # failed, 2 for a command-line usage error.
    import getopt

    # check out our command line options; only option-parsing errors are
    # caught here, so sys.exit() and real bugs are no longer swallowed
    try:
        opt_list, arg_list = getopt.getopt(sys.argv[1:],
                                           "o:vhsc",
                                           ["output=", "verbose",
                                            "help", "ssl", "checkcert"])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    use_ssl = 0
    checkcert = 0
    output_file = None
    verbose = 0

    for opt, arg in opt_list:
        if opt in ("-h", "--help"):
            usage()
            sys.exit(0)

        if opt in ("-c", "--checkcert"):
            checkcert = 1

        if opt in ("-s", "--ssl"):
            use_ssl = 1

        if opt in ("-o", "--output"):
            output_file = arg

        # --verbose is declared to getopt above, so honour it as well
        if opt in ("-v", "--verbose"):
            verbose = 1

    # exactly one partial path is required, and it must not be a full URL
    if len(arg_list) != 1 or arg_list[0][:4].lower() == "http":
        usage()
        sys.exit(2)

    partialpath = arg_list[0]

    # got the command line args straightened out
    # NOTE(review): BootServerRequest.__init__ is declared as
    # (vars, verbose=0), so this call passes the verbose flag as the
    # configuration mapping.  The mapping is built outside this file, so
    # the call is left unchanged here -- confirm how it should be obtained.
    requestor = BootServerRequest(verbose)

    if output_file:
        ok = requestor.DownloadFile(partialpath, None, None, use_ssl,
                                    checkcert, output_file)
        # reflect a failed download in the exit status
        if not ok:
            sys.exit(1)
    else:
        result = requestor.MakeRequest(partialpath, None, None, use_ssl, checkcert)
        if result:
            print(result)
        else:
            sys.exit(1)