88cd081f8461493d8bc2dee35013a77a66f22c0d
[bootmanager.git] / source / BootServerRequest.py
1 #!/usr/bin/python
2 #
3 # Copyright (c) 2003 Intel Corporation
4 # All rights reserved.
5 #
6 # Copyright (c) 2004-2006 The Trustees of Princeton University
7 # All rights reserved.
8
9 import os, sys
10 import re
11 import string
12 import urllib
13 import tempfile
14
15 # try to load pycurl
16 try:
17     import pycurl
18     PYCURL_LOADED= 1
19 except:
20     PYCURL_LOADED= 0
21
22
23 # if there is no cStringIO, fall back to the original
24 try:
25     from cStringIO import StringIO
26 except:
27     from StringIO import StringIO
28
29
30
class BootServerRequest:
    """Fetch data from the PlanetLab boot and monitor servers over HTTP(S).

    Requests are made with pycurl when the module-level PYCURL_LOADED flag
    is set, and by running the curl command line tool (CURL_CMD) otherwise.

    Keys read from VARS (the ``vars`` constructor argument):
      BOOTCD_VERSION_FILE, SERVER_CERT_DIR -- format strings with a
          ``%(path)s`` placeholder for the boot cd mount point
      BOOT_SERVER, MONITOR_SERVER, CACERT_NAME -- used when a boot cd is found
      DEFAULT_BOOT_SERVER -- used when no boot cd is found
      PROXY_FILE -- read by CheckProxy()
    """

    VERBOSE = 0

    # all possible places to check the cdrom mount point.
    # /mnt/cdrom is typically after the machine has come up,
    # and /usr is when the boot cd is running
    CDROM_MOUNT_PATH = ("/mnt/cdrom/","/usr/")
    # server name -> CA certificate path ("" when there is none).
    # __init__ gives every instance its own dicts; these class-level
    # ones only keep the attributes defined on the class itself.
    BOOTSERVER_CERTS= {}
    MONITORSERVER_CERTS= {}
    BOOTCD_VERSION=""
    HTTP_SUCCESS=200
    HAS_BOOTCD=0
    USE_PROXY=0
    PROXY=0

    # in seconds, maximum time allowed to establish a connection
    DEFAULT_CURL_CONNECT_TIMEOUT=30
    # in seconds, maximum time allowed for any transfer
    DEFAULT_CURL_MAX_TRANSFER_TIME=3600
    # location of curl executable, used by DownloadFile when pycurl
    # isn't available (mainly the boot cd environment, where pycurl
    # doesn't exist)
    CURL_CMD = 'curl'
    # passed to pycurl.SSLVERSION and as curl --sslv%d.
    # NOTE(review): 3 selects SSLv3 in libcurl, which current TLS servers
    # generally refuse; confirm this is still wanted.
    CURL_SSL_VERSION=3

    def __init__(self, vars, verbose=0):
        """Locate the boot cd (if any) and the CA certs of the servers to use.

        vars    -- configuration dictionary (see the class docstring)
        verbose -- when true, Message() prints progress to stdout
        """
        self.VERBOSE= verbose
        self.VARS=vars

        # Each instance gets its own maps. Filling the class-level dicts
        # (as before) made every instance share, and accumulate, servers.
        self.BOOTSERVER_CERTS= {}
        self.MONITORSERVER_CERTS= {}

        # see if we have a boot cd mounted by checking for the version file
        # if HAS_BOOTCD == 0 then either the machine doesn't have
        # a boot cd, or something else is mounted
        self.HAS_BOOTCD = 0

        for path in self.CDROM_MOUNT_PATH:
            self.Message( "Checking existance of boot cd on %s" % path )

            # best effort: the mount result is ignored, the version file
            # check below decides whether a boot cd is present
            os.system("/bin/mount %s > /dev/null 2>&1" % path )

            version_file= self.VARS['BOOTCD_VERSION_FILE'] % {'path' : path}
            self.Message( "Looking for version file %s" % version_file )

            if not os.access(version_file, os.R_OK):
                self.Message( "No boot cd found." )
            else:
                self.Message( "Found boot cd." )
                self.HAS_BOOTCD=1
                break

        if self.HAS_BOOTCD:
            # `path` and `version_file` still refer to the mount point at
            # which the loop above found the boot cd

            # check the version of the boot cd, and locate the certs
            self.Message( "Getting boot cd version." )

            versionRegExp= re.compile(r"PlanetLab BootCD v(\S+)")

            bootcd_version_f= open(version_file,"r")
            try:
                line= bootcd_version_f.readline().strip()
            finally:
                bootcd_version_f.close()

            match= versionRegExp.findall(line)
            if match:
                self.BOOTCD_VERSION= match[0]

            # right now, all the versions of the bootcd are supported,
            # so no need to check it

            self.Message( "Getting server from configuration" )

            self.BOOTSERVER_CERTS= self._FindServerCerts(
                [ self.VARS['BOOT_SERVER'] ], path )
            self.MONITORSERVER_CERTS= self._FindServerCerts(
                [ self.VARS['MONITOR_SERVER'] ], path )

            self.Message( "Set of servers to contact: %s" %
                          str(self.BOOTSERVER_CERTS) )
            self.Message( "Set of servers to upload to: %s" %
                          str(self.MONITORSERVER_CERTS) )
        else:
            self.Message( "Using default boot server address." )
            self.BOOTSERVER_CERTS[self.VARS['DEFAULT_BOOT_SERVER']]= ""
            self.MONITORSERVER_CERTS[self.VARS['DEFAULT_BOOT_SERVER']]= ""


    def _FindServerCerts( self, servers, path ):
        """Return {server: cacert_path} for the servers with a readable CA cert.

        The cert of a server is looked up at
        "<VARS['SERVER_CERT_DIR'] % {'path': path}>/<server>/<VARS['CACERT_NAME']>";
        server names are whitespace-stripped and servers without a readable
        cert are left out.
        """
        cert_dir= self.VARS['SERVER_CERT_DIR'] % {'path' : path}
        certs= {}
        for server in servers:
            server= server.strip()
            cacert_path= "%s/%s/%s" % (cert_dir, server,
                                       self.VARS['CACERT_NAME'])
            if os.access(cacert_path, os.R_OK):
                certs[server]= cacert_path
        return certs


    def CheckProxy( self ):
        """Set USE_PROXY/PROXY from the first line of VARS['PROXY_FILE'].

        When the file is a readable regular file, USE_PROXY becomes 1 and
        PROXY its whitespace-stripped first line; otherwise USE_PROXY is 0
        and PROXY is left unchanged.
        """
        self.USE_PROXY= 0
        self.Message( "Checking existance of proxy config file..." )

        proxy_file= self.VARS['PROXY_FILE']
        if os.access(proxy_file, os.R_OK) and os.path.isfile(proxy_file):
            # close the file explicitly instead of relying on refcounting
            proxy_f= open(proxy_file,'r')
            try:
                self.PROXY= proxy_f.readline().strip()
            finally:
                proxy_f.close()
            self.USE_PROXY= 1
            self.Message( "Using proxy %s." % self.PROXY)
        else:
            self.Message( "Not using any proxy." )


    def Message( self, Msg ):
        """Print Msg to stdout when VERBOSE is set."""
        if self.VERBOSE:
            print( Msg )


    def Error( self, Msg ):
        """Write Msg followed by a newline to stderr."""
        sys.stderr.write( Msg + "\n" )


    def Warning( self, Msg ):
        """Report Msg; warnings currently go through Error()."""
        self.Error(Msg)


    def MakeRequest( self, PartialPath, GetVars,
                     PostVars, DoSSL, DoCertCheck,
                     ConnectTimeout= DEFAULT_CURL_CONNECT_TIMEOUT,
                     MaxTransferTime= DEFAULT_CURL_MAX_TRANSFER_TIME,
                     FormData= None):
        """Fetch PartialPath and return the response body, or None on failure.

        The arguments are passed unchanged to DownloadFile(), which writes
        into a temporary file; that file is always removed before returning.
        """
        (fd, buffer_name) = tempfile.mkstemp(prefix="MakeRequest-")
        os.close(fd)

        try:
            ok = self.DownloadFile(PartialPath, GetVars, PostVars,
                                   DoSSL, DoCertCheck, buffer_name,
                                   ConnectTimeout,
                                   MaxTransferTime,
                                   FormData)
            if not ok:
                return None

            # Re-open the path instead of reading through a handle opened
            # before the download: DownloadFile removes and recreates the
            # file when it falls back to another server, which would leave
            # an earlier handle pointing at the removed file.
            buffer = open(buffer_name, "rb")
            try:
                return buffer.read()
            finally:
                buffer.close()
        finally:
            try:
                os.unlink(buffer_name)
            except OSError:
                pass


    def DownloadFile(self, PartialPath, GetVars, PostVars,
                     DoSSL, DoCertCheck, DestFilePath,
                     ConnectTimeout= DEFAULT_CURL_CONNECT_TIMEOUT,
                     MaxTransferTime= DEFAULT_CURL_MAX_TRANSFER_TIME,
                     FormData= None):
        """Download PartialPath from the first server that answers.

        Requests with FormData go to MONITORSERVER_CERTS, all others to
        BOOTSERVER_CERTS; the servers are tried in turn until one succeeds.

        PartialPath     -- path on the server, appended after "<server>/"
        GetVars         -- dict urlencoded into the query string, or None
        PostVars        -- dict urlencoded and sent as POST data, or None
        DoSSL           -- use https instead of http
        DoCertCheck     -- verify the server against its CA cert; together
                           with DoSSL this requires a boot cd
        DestFilePath    -- file the response body is written to
        ConnectTimeout  -- seconds allowed to connect; must be > 0
        MaxTransferTime -- seconds allowed for the whole transfer
        FormData        -- multipart/form-data fields to upload, or None

        Returns 1 on success and 0 otherwise. DestFilePath is removed after
        every failed server attempt.
        """
        self.Message( "Attempting to retrieve %s" % PartialPath )

        # we can't do ssl and check the cert if we don't have a bootcd
        if DoSSL and DoCertCheck and not self.HAS_BOOTCD:
            self.Error( "No boot cd exists (needed to use -c and -s.\n" )
            return 0

        if DoSSL and not PYCURL_LOADED:
            self.Warning( "Using SSL without pycurl will by default " \
                          "check at least standard certs." )

        # ConnectTimeout has to be greater than 0
        if ConnectTimeout <= 0:
            self.Error( "Connect timeout must be greater than zero.\n" )
            return 0

        self.CheckProxy()

        # setup the post and get vars for the request
        postdata= None
        if PostVars:
            postdata= urllib.urlencode(PostVars)
            self.Message( "Posting data:\n%s\n" % postdata )

        getstr= ""
        if GetVars:
            getstr= "?" + urllib.urlencode(GetVars)
            self.Message( "Get data:\n%s\n" % getstr )

        # form uploads go to the monitor servers, everything else to the
        # boot servers; try each server in turn
        if FormData:
            cert_list = self.MONITORSERVER_CERTS
        else:
            cert_list = self.BOOTSERVER_CERTS

        for server in cert_list:
            self.Message( "Contacting server %s." % server )

            certpath = cert_list[server]

            # output what we are going to be doing
            self.Message( "Connect timeout is %s seconds" % ConnectTimeout )
            self.Message( "Max transfer time is %s seconds" % MaxTransferTime )

            if DoSSL:
                url = "https://%s/%s%s" % (server,PartialPath,getstr)

                if DoCertCheck and PYCURL_LOADED:
                    self.Message( "Using SSL version %d and verifying peer." %
                                  self.CURL_SSL_VERSION )
                else:
                    self.Message( "Using SSL version %d." %
                                  self.CURL_SSL_VERSION )
            else:
                url = "http://%s/%s%s" % (server,PartialPath,getstr)

            self.Message( "URL: %s" % url )

            if PYCURL_LOADED:
                ok= self._FetchWithPycurl( server, url, certpath, DoSSL,
                                           DoCertCheck, postdata, FormData,
                                           DestFilePath, ConnectTimeout,
                                           MaxTransferTime )
            else:
                ok= self._FetchWithCurlCommand( url, certpath, DoSSL,
                                                DoCertCheck, postdata,
                                                FormData, DestFilePath,
                                                ConnectTimeout,
                                                MaxTransferTime )
            if ok:
                return 1

        self.Error( "Unable to successfully contact any boot servers.\n" )
        return 0


    def _FetchWithPycurl( self, server, url, certpath, DoSSL, DoCertCheck,
                          postdata, FormData, DestFilePath,
                          ConnectTimeout, MaxTransferTime ):
        """Fetch url into DestFilePath with pycurl; return 1 on HTTP 200, else 0.

        A pycurl.error is reported with Error() and counts as failure.
        DestFilePath is removed on any failure (curl error or a non-200
        status), and the curl handle is always closed.
        """
        curl= pycurl.Curl()
        try:
            # don't want curl sending any signals
            curl.setopt(pycurl.NOSIGNAL, 1)

            curl.setopt(pycurl.CONNECTTIMEOUT, ConnectTimeout)
            curl.setopt(pycurl.TIMEOUT, MaxTransferTime)

            # do not follow location when attempting to download a file
            curl.setopt(pycurl.FOLLOWLOCATION, 0)

            if self.USE_PROXY:
                curl.setopt(pycurl.PROXY, self.PROXY )

            if DoSSL:
                curl.setopt(pycurl.SSLVERSION, self.CURL_SSL_VERSION)

                if DoCertCheck:
                    curl.setopt(pycurl.CAINFO, certpath)
                    # SSL_VERIFYPEER is an on/off switch (the former value 2
                    # only worked because it is non-zero); the host name
                    # check is SSL_VERIFYHOST, where 2 requires a match
                    curl.setopt(pycurl.SSL_VERIFYPEER, 1)
                    curl.setopt(pycurl.SSL_VERIFYHOST, 2)
                else:
                    curl.setopt(pycurl.SSL_VERIFYPEER, 0)

            if postdata is not None:
                curl.setopt(pycurl.POSTFIELDS, postdata)

            # setup multipart/form-data upload
            if FormData:
                curl.setopt(pycurl.HTTPPOST, FormData)

            curl.setopt(pycurl.URL, url)

            # setup the output file
            outfile = open(DestFilePath,"wb")
            self.Message( "Opened output file %s" % DestFilePath )

            ok= 0
            try:
                try:
                    curl.setopt(pycurl.WRITEDATA, outfile)

                    self.Message( "Fetching..." )
                    curl.perform()
                    self.Message( "Done." )

                    http_result= curl.getinfo(pycurl.HTTP_CODE)
                    outfile.close()
                    self.Message( "Results saved in %s" % DestFilePath )

                    # only HTTP_SUCCESS counts as a successful download
                    if http_result == self.HTTP_SUCCESS:
                        self.Message( "Successfull!" )
                        ok= 1
                    else:
                        self.Message( "Failure, resultant http code: %d" % \
                                      http_result )
                except pycurl.error as err:
                    curl_errno, errstr= err.args
                    self.Error( "connect to %s failed; curl error %d: '%s'\n" %
                                (server,curl_errno,errstr) )
            finally:
                # close before unlinking so the handle is never leaked;
                # closing an already closed file is a no-op
                outfile.close()
                if not ok:
                    # don't leave a partial download or an HTTP error page
                    # behind; the curl command line path (--fail) removes
                    # the file in the same situation
                    try:
                        os.unlink(DestFilePath)
                    except OSError:
                        pass
            return ok
        finally:
            curl.close()


    def _FetchWithCurlCommand( self, url, certpath, DoSSL, DoCertCheck,
                               postdata, FormData, DestFilePath,
                               ConnectTimeout, MaxTransferTime ):
        """Fetch url into DestFilePath by running CURL_CMD; return 1 on success.

        --fail makes curl exit non-zero on HTTP errors; DestFilePath is
        removed whenever the command fails and 0 is returned.
        """
        # NOTE(review): the command runs through the shell and the
        # interpolated DestFilePath, PROXY, certpath and url are not quoted;
        # only postdata and the FormData fields are wrapped in single quotes.
        cmdline = "%s " \
                  "--connect-timeout %d " \
                  "--max-time %d " \
                  "--header Pragma: " \
                  "--output %s " \
                  "--fail " % \
                  (self.CURL_CMD, ConnectTimeout,
                   MaxTransferTime, DestFilePath)

        if postdata is not None:
            cmdline = cmdline + "--data '" + postdata + "' "

        if FormData:
            # NOTE(review): each FormData item is concatenated as a string
            # (curl --form expects "name=value"), whereas pycurl.HTTPPOST
            # takes (name, value) tuples; confirm what callers pass.
            cmdline = cmdline + "".join(["--form '" + field + "' "
                                         for field in FormData])

        if not self.VERBOSE:
            cmdline = cmdline + "--silent "

        if self.USE_PROXY:
            cmdline = cmdline + "--proxy %s " % self.PROXY

        if DoSSL:
            cmdline = cmdline + "--sslv%d " % self.CURL_SSL_VERSION
            if DoCertCheck:
                cmdline = cmdline + "--cacert %s " % certpath

        cmdline = cmdline + url

        self.Message( "curl command: %s" % cmdline )

        self.Message( "Fetching..." )
        # os.system returns the wait status, not the bare curl exit code
        rc = os.system(cmdline)
        self.Message( "Done." )

        if rc != 0:
            try:
                os.unlink( DestFilePath )
            except OSError:
                pass
            self.Message( "Failure, resultant curl code: %d" % rc )
            self.Message( "Removed %s" % DestFilePath )
            return 0

        self.Message( "Successfull!" )
        return 1
394
395
396
397
def usage():
    """Write the command line help text to stdout."""
    help_text = """
Usage: BootServerRequest.py [options] <partialpath>
Options:
 -c/--checkcert        Check SSL certs. Ignored if -s/--ssl missing.
 -h/--help             This help text
 -o/--output <file>    Write result to file
 -s/--ssl              Make the request over HTTPS
 -v                    Makes the operation more talkative
"""
    # same output as print(help_text): the text plus one trailing newline
    sys.stdout.write(help_text + "\n")
409
410
411
if __name__ == "__main__":
    import getopt

    # defaults, overridden by the command line options below
    ssl= 0
    checkcert= 0
    output_file= None
    verbose= 0

    # Only option-parsing errors are usage errors. The original wrapped
    # everything in a bare except, which also swallowed the SystemExit
    # raised for --help (and the TypeError from calling usage(0)).
    try:
        opt_list, arg_list = getopt.getopt(sys.argv[1:],
                                           "o:vhsc",
                                           [ "output=", "verbose",
                                             "help","ssl","checkcert"])
    except getopt.GetoptError:
        usage()
        sys.exit(2)

    for opt, arg in opt_list:
        if opt in ("-h","--help"):
            usage()
            sys.exit(0)

        if opt in ("-c","--checkcert"):
            checkcert= 1

        if opt in ("-s","--ssl"):
            ssl= 1

        if opt in ("-o","--output"):
            output_file= arg

        # --verbose is accepted by getopt above, so honour it as well
        if opt in ("-v","--verbose"):
            verbose= 1

    # exactly one partial path is required; full URLs are rejected
    if len(arg_list) != 1 or arg_list[0][:4].lower() == "http":
        usage()
        sys.exit(2)
    partialpath= arg_list[0]

    # NOTE(review): BootServerRequest.__init__ takes (vars, verbose) and
    # reads BOOTCD_VERSION_FILE, BOOT_SERVER, MONITOR_SERVER,
    # SERVER_CERT_DIR, CACERT_NAME, DEFAULT_BOOT_SERVER and PROXY_FILE from
    # `vars`. The original passed `verbose` in place of `vars`. This script
    # has no source for that configuration, so construction still fails
    # (KeyError) until it is loaded here. TODO: load the boot manager
    # configuration instead of the empty dict.
    requestor= BootServerRequest({}, verbose)

    if output_file:
        # report a failed download through the exit status, like the
        # MakeRequest branch below
        if not requestor.DownloadFile( partialpath, None, None, ssl,
                                       checkcert, output_file):
            sys.exit(1)
    else:
        result= requestor.MakeRequest( partialpath, None, None, ssl, checkcert)
        if result:
            print(result)
        else:
            sys.exit(1)
465         else:
466             sys.exit(1)