1f70b4050414f482b7f0f3004ddf486aa269b02a
[tests.git] / system / TestSsh.py
1 # Thierry Parmentelat <thierry.parmentelat@inria.fr>
2 # Copyright (C) 2010 INRIA 
3 #
4 # class for issuing commands on a box, either local or remote
5 #
6 # the notion of 'buildname' is for providing each test run with a dir of its own
7 # buildname is generally the name of the build being tested, and can be considered unique
8 #
9 # thus 'run_in_buildname' mostly :
10 # (*) either runs locally in . - as on a local node we are already in a dedicated directory
11 # (*) or makes sure that there's a remote dir called 'buildname' and runs in it
12 #
13 # also, the copy operations
14 # (*) either do nothing when run locally
15 # (*) or copy a local file into the remote 'buildname' 
16
17
18 import sys
19 import os.path
20 import utils
21 import shutil
22
class TestSsh:
    """Issue commands on a box, either the local host or a remote one.

    When the target hostname resolves to the local machine, commands are run
    as-is in the current directory and copy operations are no-ops. Otherwise
    commands are wrapped in an ``ssh`` invocation, and files are pushed with
    ``scp`` into a remote directory named after ``buildname``.
    """

    # inserts a backslash before each occurrence of the following chars
    # \ " ' < > & | ; ( ) $ * ~
    @staticmethod
    def backslash_shell_specials(command):
        """Return ``command`` with a backslash inserted before each shell special.

        Used by actual_command before the command string is appended to the
        ssh command line (presumably so the local shell that runs that line
        hands the specials over to ssh untouched - TODO confirm against
        utils.system).
        """
        specials = "\\\"'<>&|;()$*~"
        return "".join("\\" + char if char in specials else char
                       for char in command)

    # check main IP address against the provided hostname
    @staticmethod
    def is_local_hostname(hostname):
        """Tell whether ``hostname`` designates the box we are running on.

        'localhost' is accepted directly. Otherwise the address ``hostname``
        resolves to is compared with the address of the local host name.
        Any failure during resolution is reported and treated as "not local".
        """
        if hostname == "localhost":
            return True
        import socket
        try:
            local_ip = socket.gethostbyname(socket.gethostname())
            remote_ip = socket.gethostbyname(hostname)
        except Exception:
            # deliberate best effort: an unresolvable name counts as remote,
            # but unlike a bare except this lets KeyboardInterrupt through
            utils.header("WARNING : something wrong in is_local_hostname with hostname=%s"%hostname)
            return False
        return local_ip == remote_ip

    def __init__(self, hostname, buildname=None, key=None, username=None, unknown_host=True):
        """Describe the target box.

        hostname     -- box to run on; may resolve to the local machine
        buildname    -- name of the per-run remote directory
        key          -- optional private key file passed with ``-i``
        username     -- optional remote login (``user@host``)
        unknown_host -- when true, ssh/scp use /dev/null as known_hosts file
        """
        self.hostname = hostname
        self.buildname = buildname
        self.key = key
        self.username = username
        self.unknown_host = unknown_host

    def is_local(self):
        """Tell whether this object targets the local box."""
        return TestSsh.is_local_hostname(self.hostname)

    std_options = "-o BatchMode=yes -o StrictHostKeyChecking=no -o CheckHostIP=no -o ConnectTimeout=5 "
    unknown_option = "-o UserKnownHostsFile=/dev/null "

    def key_part(self):
        """Return the ``-i <key> `` option, or an empty string if no key is set."""
        if not self.key:
            return ""
        return "-i %s " % self.key

    def hostname_part(self):
        """Return ``hostname`` or ``username@hostname`` for ssh/scp."""
        if not self.username:
            return self.hostname
        return "%s@%s" % (self.username, self.hostname)

    def _scp_prefix(self, recursive):
        """Return the common ``scp <options> `` prefix shared by the copy methods."""
        # honour unknown_host exactly like actual_command/actual_argv do for ssh
        command = "scp " + TestSsh.std_options
        if self.unknown_host:
            command += TestSsh.unknown_option
        if recursive:
            command += "-r "
        command += self.key_part()
        return command

    # command gets run on the right box
    def actual_command(self, command, keep_stdin=False):
        """Return the shell command line that runs ``command`` on the target box.

        Locally this is ``command`` itself. Remotely it is an ssh invocation
        (with ``-n`` unless ``keep_stdin``) followed by the escaped command.
        """
        if self.is_local():
            return command
        ssh_command = "ssh "
        if not keep_stdin:
            ssh_command += "-n "
        ssh_command += TestSsh.std_options
        if self.unknown_host:
            ssh_command += TestSsh.unknown_option
        ssh_command += self.key_part()
        ssh_command += "%s %s" % (self.hostname_part(), TestSsh.backslash_shell_specials(command))
        return ssh_command

    # same in argv form
    def actual_argv(self, argv, keep_stdin=False):
        """Return an argv list that runs ``argv`` on the target box (no escaping)."""
        if self.is_local():
            return argv
        ssh_argv = ['ssh']
        if not keep_stdin:
            ssh_argv.append('-n')
        ssh_argv += TestSsh.std_options.split()
        if self.unknown_host:
            ssh_argv += TestSsh.unknown_option.split()
        ssh_argv += self.key_part().split()
        ssh_argv.append(self.hostname_part())
        ssh_argv += argv
        return ssh_argv

    def header(self, message):
        """Print a banner line for ``message``; do nothing if it is empty."""
        if not message:
            return
        # same output as the former py2 `print "===============",message`,
        # but valid syntax in both python 2 and python 3
        sys.stdout.write("=============== %s\n" % (message,))
        sys.stdout.flush()

    def run(self, command, message=None, background=False, dry_run=False):
        """Run ``command`` on the target box and return utils.system's result.

        With ``dry_run`` the command is only displayed and 0 is returned.
        """
        local_command = self.actual_command(command)
        if dry_run:
            utils.header("DRY RUN " + local_command)
            return 0
        self.header(message)
        return utils.system(local_command, background)

    def clean_dir(self, dirname):
        """Remove ``dirname`` on a remote box; no-op (returns 0) locally."""
        if self.is_local():
            return 0
        return self.run("rm -rf %s" % dirname)

    def mkdir(self, dirname=None, abs=False):
        """Create a directory, with ``mkdir -p`` semantics, and return 0 on success.

        Locally, ``dirname`` (if given) is created as-is, together with any
        missing parents. Remotely, it is taken relative to ``buildname``
        unless ``abs`` is set; with no ``dirname`` the buildname dir itself
        is created.
        """
        if self.is_local():
            if dirname and not os.path.isdir(dirname):
                os.makedirs(dirname)
            return 0
        if not abs:
            if dirname:
                dirname = "%s/%s" % (self.buildname, dirname)
            else:
                dirname = self.buildname
        if dirname == '.':
            return 0
        return self.run("mkdir -p %s" % dirname)

    def rmdir(self, dirname=None):
        """Remove a directory tree, with ``rm -rf`` semantics, and return 0 on success.

        Locally, ``dirname`` (if given) is removed; a missing one is not an
        error. Remotely, it is taken relative to ``buildname``, and with no
        ``dirname`` the buildname dir itself is removed.
        """
        if self.is_local():
            if dirname and os.path.lexists(dirname):
                shutil.rmtree(dirname)
            return 0
        if dirname:
            dirname = "%s/%s" % (self.buildname, dirname)
        else:
            dirname = self.buildname
        return self.run("rm -rf %s" % dirname)

    def create_buildname_once(self):
        """Create the remote buildname directory the first time it is needed."""
        if self.is_local():
            return
        if not getattr(self, 'buildname_created', False):
            self.mkdir()
            self.buildname_created = True

    def run_in_buildname(self, command, background=False):
        """Run ``command`` from within the buildname directory of the target box."""
        if self.is_local():
            return utils.system(command, background)
        self.create_buildname_once()
        # pass background by keyword: positionally it would land in `message`
        return self.run("cd %s ; %s" % (self.buildname, command), background=background)

    def copy(self, local_file, recursive=False):
        """scp ``local_file`` into the remote buildname dir; no-op (0) locally."""
        if self.is_local():
            return 0
        self.create_buildname_once()
        scp_command = self._scp_prefix(recursive)
        # a trailing '/' leaves basename empty: copy into the buildname dir
        scp_command += "%s %s:%s/%s" % (local_file, self.hostname_part(),
                                        self.buildname, os.path.basename(local_file) or ".")
        return utils.system(scp_command)

    def copy_abs(self, local_file, remote_file, recursive=False):
        """scp ``local_file`` to ``remote_file`` as given (local box: plain local scp)."""
        if self.is_local():
            dest = ""
        else:
            dest = "%s:" % self.hostname_part()
        scp_command = self._scp_prefix(recursive)
        scp_command += "%s %s%s" % (local_file, dest, remote_file)
        return utils.system(scp_command)

    def copy_home(self, local_file, recursive=False):
        """Copy ``local_file`` under its basename into the remote login directory."""
        return self.copy_abs(local_file, os.path.basename(local_file), recursive)

    def fetch(self, remote_file, local_file, recursive=False):
        """Bring ``remote_file`` back as ``local_file``.

        Locally this is a ``cp``. Remotely ``remote_file`` is relative to
        buildname unless it is an absolute path.
        """
        if self.is_local():
            command = "cp "
            if recursive:
                command += "-r "
            command += "%s %s" % (remote_file, local_file)
        else:
            command = self._scp_prefix(recursive)
            # absolute path - do not prepend buildname
            if remote_file.startswith("/"):
                remote_path = remote_file
            else:
                remote_path = "%s/%s" % (self.buildname, remote_file)
            command += "%s:%s %s" % (self.hostname_part(), remote_path, local_file)
        return utils.system(command)

    # this is only to avoid harmless message when host cannot be identified
    # convenience only
    # the only place where this is needed is when trying to reach a slice in a node,
    # which is done from the test master box
    def clear_known_hosts(self):
        """Delete lines starting with this hostname from ~/.ssh/known_hosts."""
        # expanduser falls back sensibly when $HOME is unset (getenv gave None)
        known_hosts = os.path.expanduser("~/.ssh/known_hosts")
        utils.header("Clearing entry for %s in %s" % (self.hostname, known_hosts))
        return utils.system("sed -i -e /^%s/d %s" % (self.hostname, known_hosts))
214