c9d1e25a33cab6d3816f0a1194ab4b291539a531
[tests.git] / system / TestSsh.py
1 # Thierry Parmentelat <thierry.parmentelat@inria.fr>
2 # Copyright (C) 2010 INRIA 
3 #
4 # class for issuing commands on a box, either local or remote
5 #
6 # the notion of 'buildname' is for providing each test run with a dir of its own
7 # buildname is generally the name of the build being tested, and can be considered unique
8 #
9 # thus 'run_in_buildname' mostly :
10 # (*) either runs locally in . - as on a local node we are already in a dedicated directory
11 # (*) or makes sure that there's a remote dir called 'buildname' and runs in it
12 #
13 # also, the copy operations
14 # (*) either do nothing if ran locally
15 # (*) or copy a local file into the remote 'buildname' 
16
17
18 import sys
19 import os.path
20 import utils
21 import shutil
22
23 class TestSsh:
24     
25     # inserts a backslash before each occurence of the following chars
26     # \ " ' < > & | ; ( ) $ * ~ 
27     @staticmethod
28     def backslash_shell_specials (command):
29         result=''
30         for char in command:
31             if char in "\\\"'<>&|;()$*~":
32                 result +='\\'+char
33             else:
34                 result +=char
35         return result
36
37     # check main IP address against the provided hostname
38     @staticmethod
39     def is_local_hostname (hostname):
40         if hostname == "localhost":
41             return True
42         import socket
43         try:
44             local_ip = socket.gethostbyname(socket.gethostname())
45             remote_ip = socket.gethostbyname(hostname)
46             return local_ip==remote_ip
47         except:
48             utils.header("WARNING : something wrong in is_local_hostname with hostname=%s"%hostname)
49             return False
50
51     # some boxes have their working space in user's homedir (/root), 
52     # some others in a dedicated area with max. space (/vservers)
53     # when root is not specified we use the homedir
54     def __init__(self,hostname,buildname=None,key=None, username=None,unknown_host=True, root=None):
55         self.hostname=hostname
56         self.buildname=buildname
57         self.key=key
58         self.username=username
59         self.unknown_host=unknown_host
60         self.root=root
61
62     def is_local(self):
63         return TestSsh.is_local_hostname(self.hostname)
64      
65     std_options="-o BatchMode=yes -o StrictHostKeyChecking=no -o CheckHostIP=no -o ConnectTimeout=5 "
66     unknown_option="-o UserKnownHostsFile=/dev/null "
67     
68     def key_part (self):
69         if not self.key:
70             return ""
71         return "-i %s "%self.key
72
73     def hostname_part (self):
74         if not self.username:
75             return self.hostname
76         else:
77             return "%s@%s"%(self.username,self.hostname)
78     
79     # command gets run on the right box
80     def actual_command (self, command, keep_stdin=False, dry_run=False):
81         if self.is_local():
82             return command
83         ssh_command = "ssh "
84         if not dry_run:
85             if not keep_stdin:
86                 ssh_command += "-n "
87             ssh_command += TestSsh.std_options
88             if self.unknown_host: ssh_command += TestSsh.unknown_option
89         ssh_command += self.key_part()
90         ssh_command += "%s %s" %(self.hostname_part(),TestSsh.backslash_shell_specials(command))
91         return ssh_command
92
93     # same in argv form
94     def actual_argv (self, argv,keep_stdin=False, dry_run=False):
95         if self.is_local():
96             return argv
97         ssh_argv=[]
98         ssh_argv.append('ssh')
99         if not dry_run:
100             if not keep_stdin: ssh_argv.append('-n')
101             ssh_argv += TestSsh.std_options.split()
102             if self.unknown_host: ssh_argv += TestSsh.unknown_option.split()
103         ssh_argv += self.key_part().split()
104         ssh_argv.append(self.hostname_part())
105         ssh_argv += argv
106         return ssh_argv
107
108     def header (self,message):
109         if not message: return
110         print "===============",message
111         sys.stdout.flush()
112
113     def run(self, command,message=None,background=False,dry_run=False):
114         local_command = self.actual_command(command, dry_run=dry_run)
115         if dry_run:
116             utils.header("DRY RUN " + local_command)
117             return 0
118         else:
119             self.header(message)
120             return utils.system(local_command,background)
121
122     def run_in_buildname (self,command, background=False, dry_run=False):
123         if self.is_local():
124             return utils.system(command,background)
125         self.create_buildname_once(dry_run)
126         return self.run("cd %s ; %s"%(self.fullname(self.buildname),command),
127                         background=background, dry_run=dry_run)
128
129     def fullname (self,dirname):
130         if self.root==None:     return dirname
131         else:                   return os.path.join(self.root,dirname)
132         
133     def mkdir (self,dirname=None,abs=False,dry_run=False):
134         if self.is_local():
135             if dirname:
136                 return os.path.mkdir(dirname)
137             return 0
138         # ab. paths remain as-is
139         if not abs:
140             if dirname:
141                 dirname="%s/%s"%(self.buildname,dirname)
142             else:
143                 dirname=self.buildname
144             dirname=self.fullname(dirname)
145         if dirname=='.': return
146         return self.run("mkdir -p %s"%dirname,dry_run=dry_run)
147
148     def rmdir (self,dirname=None, dry_run=False):
149         if self.is_local():
150             if dirname:
151                 return shutil.rmtree(dirname)
152             return 0
153         if dirname:
154             dirname="%s/%s"%(self.buildname,dirname)
155         else:
156             dirname=self.buildname
157         dirname=self.fullname(dirname)
158         return self.run("rm -rf %s"%dirname, dry_run=dry_run)
159
160     def create_buildname_once (self, dry_run):
161         if self.is_local():
162             return
163         # create remote buildname on demand
164         try:
165             self.buildname_created
166         except:
167             self.mkdir(dry_run=dry_run)
168             self.buildname_created=True
169
170     def copy (self,local_file,recursive=False,dry_run=False):
171         if self.is_local():
172             return 0
173         self.create_buildname_once(dry_run)
174         scp_command="scp "
175         if not dry_run:
176             scp_command += TestSsh.std_options
177         if recursive: scp_command += "-r "
178         scp_command += self.key_part()
179         scp_command += "%s %s:%s/%s"%(local_file,self.hostname_part(),
180                                       self.fullname(self.buildname),os.path.basename(local_file) or ".")
181         if dry_run:
182             utils.header ("DRY RUN TestSsh.copy %s"%scp_command)
183             return True
184         return utils.system(scp_command)
185
186     def copy_abs (self,local_file,remote_file,recursive=False):
187         if self.is_local():
188             dest=""
189         else:
190             dest= "%s:"%self.hostname_part()
191         scp_command="scp "
192         scp_command += TestSsh.std_options
193         if recursive: scp_command += "-r "
194         scp_command += self.key_part()
195         scp_command += "%s %s%s"%(local_file,dest,remote_file)
196         return utils.system(scp_command)
197
198     def copy_home (self, local_file, recursive=False):
199         return self.copy_abs(local_file,os.path.basename(local_file),recursive)
200
201     def fetch (self, remote_file, local_file, recursive=False, dry_run=False):
202         if self.is_local():
203             command="cp "
204             if recursive: command += "-r "
205             command += "%s %s"%(remote_file,local_file)
206         else:
207             command="scp "
208             if not dry_run:
209                 command += TestSsh.std_options
210             if recursive: command += "-r "
211             command += self.key_part()
212             # absolute path - do not preprend buildname
213             if remote_file.find("/")==0:
214                 remote_path=remote_file
215             else:
216                 remote_path="%s/%s"%(self.buildname,remote_file)
217                 remote_path=self.fullname(remote_path)
218             command += "%s:%s %s"%(self.hostname_part(),remote_path,local_file)
219         return utils.system(command)
220
221     # this is only to avoid harmless message when host cannot be identified
222     # convenience only
223     # the only place where this is needed is when tring to reach a slice in a node,
224     # which is done from the test master box
225     def clear_known_hosts (self):
226         known_hosts = "%s/.ssh/known_hosts"%os.getenv("HOME")
227         utils.header("Clearing entry for %s in %s"%(self.hostname,known_hosts))
228         return utils.system("sed -i -e /^%s/d %s"%(self.hostname,known_hosts))
229