e07d97eb156378454f5381e1b102da1d1ad8fe26
[tests.git] / system / TestSsh.py
1 # Thierry Parmentelat <thierry.parmentelat@inria.fr>
2 # Copyright (C) 2010 INRIA 
3 #
4 # class for issuing commands on a box, either local or remote
5 #
6 # the notion of 'buildname' is for providing each test run with a dir of its own
7 # buildname is generally the name of the build being tested, and can be considered unique
8 #
9 # thus 'run_in_buildname' mostly :
10 # (*) either runs locally in . - as on a local node we are already in a dedicated directory
11 # (*) or makes sure that there's a remote dir called 'buildname' and runs in it
12 #
13 # also, the copy operations
14 # (*) either do nothing if ran locally
15 # (*) or copy a local file into the remote 'buildname' 
16
17
18 import sys
19 import os.path
20 import utils
21 import shutil
22
23 class TestSsh:
24     
25     # inserts a backslash before each occurence of the following chars
26     # \ " ' < > & | ; ( ) $ * ~ 
27     @staticmethod
28     def backslash_shell_specials (command):
29         result=''
30         for char in command:
31             if char in "\\\"'<>&|;()$*~":
32                 result +='\\'+char
33             else:
34                 result +=char
35         return result
36
37     # check main IP address against the provided hostname
38     @staticmethod
39     def is_local_hostname (hostname):
40         if hostname == "localhost":
41             return True
42         import socket
43         try:
44             local_ip = socket.gethostbyname(socket.gethostname())
45             remote_ip = socket.gethostbyname(hostname)
46             return local_ip==remote_ip
47         except:
48             utils.header("WARNING : something wrong in is_local_hostname with hostname=%s"%hostname)
49             return False
50
51     # some boxes have their working space in user's homedir (/root), 
52     # some others in a dedicated area with max. space (/vservers)
53     # when root is not specified we use the homedir
54     def __init__(self,hostname,buildname=None,key=None, username=None,unknown_host=True, root=None):
55         self.hostname=hostname
56         self.buildname=buildname
57         self.key=key
58         self.username=username
59         self.unknown_host=unknown_host
60         self.root=root
61
62     def is_local(self):
63         return TestSsh.is_local_hostname(self.hostname)
64      
65     std_options="-o BatchMode=yes -o StrictHostKeyChecking=no -o CheckHostIP=no -o ConnectTimeout=5 "
66     unknown_option="-o UserKnownHostsFile=/dev/null "
67     
68     def key_part (self):
69         if not self.key:
70             return ""
71         return "-i %s "%self.key
72
73     def hostname_part (self):
74         if not self.username:
75             return self.hostname
76         else:
77             return "%s@%s"%(self.username,self.hostname)
78     
79     # command gets run on the right box
80     def actual_command (self, command, keep_stdin=False, dry_run=False,backslash=True):
81         if self.is_local():
82             return command
83         ssh_command = "ssh "
84         if not dry_run:
85             if not keep_stdin:
86                 ssh_command += "-n "
87             ssh_command += TestSsh.std_options
88             if self.unknown_host: ssh_command += TestSsh.unknown_option
89         ssh_command += self.key_part()
90         ssh_command += self.hostname_part() + " "
91         if backslash:
92             ssh_command += TestSsh.backslash_shell_specials(command)
93         else:
94             ssh_command += command
95         return ssh_command
96
97     # same in argv form
98     def actual_argv (self, argv,keep_stdin=False, dry_run=False):
99         if self.is_local():
100             return argv
101         ssh_argv=[]
102         ssh_argv.append('ssh')
103         if not dry_run:
104             if not keep_stdin: ssh_argv.append('-n')
105             ssh_argv += TestSsh.std_options.split()
106             if self.unknown_host: ssh_argv += TestSsh.unknown_option.split()
107         ssh_argv += self.key_part().split()
108         ssh_argv.append(self.hostname_part())
109         ssh_argv += argv
110         return ssh_argv
111
112     def header (self,message):
113         if not message: return
114         print "===============",message
115         sys.stdout.flush()
116
117     def run(self, command,message=None,background=False,dry_run=False):
118         local_command = self.actual_command(command, dry_run=dry_run)
119         if dry_run:
120             utils.header("DRY RUN " + local_command)
121             return 0
122         else:
123             self.header(message)
124             return utils.system(local_command,background)
125
126     def run_in_buildname (self,command, background=False, dry_run=False):
127         if self.is_local():
128             return utils.system(command,background)
129         self.create_buildname_once(dry_run)
130         return self.run("cd %s ; %s"%(self.fullname(self.buildname),command),
131                         background=background, dry_run=dry_run)
132
133     def fullname (self,dirname):
134         if self.root==None:     return dirname
135         else:                   return os.path.join(self.root,dirname)
136         
137     def mkdir (self,dirname=None,abs=False,dry_run=False):
138         if self.is_local():
139             if dirname:
140                 return os.path.mkdir(dirname)
141             return 0
142         # ab. paths remain as-is
143         if not abs:
144             if dirname:
145                 dirname="%s/%s"%(self.buildname,dirname)
146             else:
147                 dirname=self.buildname
148             dirname=self.fullname(dirname)
149         if dirname=='.': return
150         return self.run("mkdir -p %s"%dirname,dry_run=dry_run)
151
152     def rmdir (self,dirname=None, dry_run=False):
153         if self.is_local():
154             if dirname:
155                 return shutil.rmtree(dirname)
156             return 0
157         if dirname:
158             dirname="%s/%s"%(self.buildname,dirname)
159         else:
160             dirname=self.buildname
161         dirname=self.fullname(dirname)
162         return self.run("rm -rf %s"%dirname, dry_run=dry_run)
163
164     def create_buildname_once (self, dry_run):
165         if self.is_local():
166             return
167         # create remote buildname on demand
168         try:
169             self.buildname_created
170         except:
171             self.mkdir(dry_run=dry_run)
172             self.buildname_created=True
173
174     def copy (self,local_file,recursive=False,dry_run=False):
175         if self.is_local():
176             return 0
177         self.create_buildname_once(dry_run)
178         scp_command="scp "
179         if not dry_run:
180             scp_command += TestSsh.std_options
181         if recursive: scp_command += "-r "
182         scp_command += self.key_part()
183         scp_command += "%s %s:%s/%s"%(local_file,self.hostname_part(),
184                                       self.fullname(self.buildname),os.path.basename(local_file) or ".")
185         if dry_run:
186             utils.header ("DRY RUN TestSsh.copy %s"%scp_command)
187             return True
188         return utils.system(scp_command)
189
190     def copy_abs (self,local_file,remote_file,recursive=False):
191         if self.is_local():
192             dest=""
193         else:
194             dest= "%s:"%self.hostname_part()
195         scp_command="scp "
196         scp_command += TestSsh.std_options
197         if recursive: scp_command += "-r "
198         scp_command += self.key_part()
199         scp_command += "%s %s%s"%(local_file,dest,remote_file)
200         return utils.system(scp_command)
201
202     def copy_home (self, local_file, recursive=False):
203         return self.copy_abs(local_file,os.path.basename(local_file),recursive)
204
205     def fetch (self, remote_file, local_file, recursive=False, dry_run=False):
206         if self.is_local():
207             command="cp "
208             if recursive: command += "-r "
209             command += "%s %s"%(remote_file,local_file)
210         else:
211             command="scp "
212             if not dry_run:
213                 command += TestSsh.std_options
214             if recursive: command += "-r "
215             command += self.key_part()
216             # absolute path - do not preprend buildname
217             if remote_file.find("/")==0:
218                 remote_path=remote_file
219             else:
220                 remote_path="%s/%s"%(self.buildname,remote_file)
221                 remote_path=self.fullname(remote_path)
222             command += "%s:%s %s"%(self.hostname_part(),remote_path,local_file)
223         return utils.system(command)
224
225     # this is only to avoid harmless message when host cannot be identified
226     # convenience only
227     # the only place where this is needed is when tring to reach a slice in a node,
228     # which is done from the test master box
229     def clear_known_hosts (self):
230         known_hosts = "%s/.ssh/known_hosts"%os.getenv("HOME")
231         utils.header("Clearing entry for %s in %s"%(self.hostname,known_hosts))
232         return utils.system("sed -i -e /^%s/d %s"%(self.hostname,known_hosts))
233