Merge branch 'master' of ssh://git.onelab.eu/git/tests
[tests.git] / system / TestSsh.py
1 # Thierry Parmentelat <thierry.parmentelat@inria.fr>
2 # Copyright (C) 2010 INRIA 
3 #
4 # class for issuing commands on a box, either local or remote
5 #
6 # the notion of 'buildname' is for providing each test run with a dir of its own
7 # buildname is generally the name of the build being tested, and can be considered unique
8 #
9 # thus 'run_in_buildname' mostly :
10 # (*) either runs locally in . - as on a local node we are already in a dedicated directory
11 # (*) or makes sure that there's a remote dir called 'buildname' and runs in it
12 #
13 # also, the copy operations
14 # (*) either do nothing when run locally
15 # (*) or copy a local file into the remote 'buildname' directory
16
17
18 import os.path
19 import utils
20 import shutil
21
class TestSsh:
    """Issue commands on a box, either the local one or a remote one via ssh.

    Remote operations that work "in buildname" run inside a remote directory
    named after ``buildname``, created on demand.  Local operations run in the
    current directory, which is assumed to already be dedicated to this run.
    Methods that shell out return whatever ``utils.system`` returns --
    presumably the command's exit status (the local no-op branches return 0
    accordingly).
    """

    # inserts a backslash before each occurrence of the following chars
    # \ " ' < > & | ; ( ) $ * ~
    @staticmethod
    def backslash_shell_specials(command):
        """Return ``command`` with every shell special character (listed above)
        preceded by a backslash.

        Used by actual_command so that the characters survive the local
        command line and reach the remote side (presumably utils.system runs
        its argument through a local shell -- confirm).
        """
        specials = "\\\"'<>&|;()$*~"
        return "".join(("\\" + char) if char in specials else char
                       for char in command)

    # check main IP address against the provided hostname
    @staticmethod
    def is_local_hostname(hostname):
        """Return True if ``hostname`` designates this very box.

        'localhost' is accepted without any lookup.  Otherwise both this
        box's hostname and ``hostname`` are resolved and their IPs compared.
        On any resolution failure a warning is emitted and False is returned.
        """
        if hostname == "localhost":
            return True
        import socket
        try:
            local_ip = socket.gethostbyname(socket.gethostname())
            remote_ip = socket.gethostbyname(hostname)
        except Exception:
            # narrower than a bare except: do not swallow KeyboardInterrupt
            utils.header("WARNING : something wrong in is_local_hostname with hostname=%s"%hostname)
            return False
        return local_ip == remote_ip

    def __init__(self, hostname, buildname=None, key=None, username=None, unknown_host=True):
        """
        hostname     -- box to act upon
        buildname    -- name of the remote per-run directory
        key          -- optional private key file, passed as ``-i``
        username     -- optional remote user (``user@host``)
        unknown_host -- when true, ssh/scp use /dev/null as known_hosts file
        """
        self.hostname = hostname
        self.buildname = buildname
        self.key = key
        self.username = username
        self.unknown_host = unknown_host

    def is_local(self):
        """Return True when ``self.hostname`` designates the local box."""
        return TestSsh.is_local_hostname(self.hostname)

    # option strings are space-terminated so they can be concatenated directly
    std_options = "-o BatchMode=yes -o StrictHostKeyChecking=no -o CheckHostIP=no -o ConnectTimeout=5 "
    unknown_option = "-o UserKnownHostsFile=/dev/null "

    def key_part(self):
        """Return the ``-i <key> `` option (space-terminated), or '' without a key."""
        if not self.key:
            return ""
        return "-i %s " % self.key

    def hostname_part(self):
        """Return the ssh destination: ``user@host`` or just ``host``."""
        if not self.username:
            return self.hostname
        return "%s@%s" % (self.username, self.hostname)

    def _scp_command(self, recursive=False):
        """Return the common, space-terminated ``scp`` prefix (options, -r, key).

        Honours ``unknown_host`` exactly like actual_command/actual_argv do.
        """
        command = "scp " + TestSsh.std_options
        if self.unknown_host:
            command += TestSsh.unknown_option
        if recursive:
            command += "-r "
        command += self.key_part()
        return command

    # command gets run on the right box
    def actual_command(self, command, keep_stdin=False):
        """Return a command line that runs ``command`` on the target box.

        Locally ``command`` is returned untouched; remotely it is wrapped in
        an ssh invocation (``-n`` unless keep_stdin) with shell specials escaped.
        """
        if self.is_local():
            return command
        ssh_command = "ssh "
        if not keep_stdin:
            ssh_command += "-n "
        ssh_command += TestSsh.std_options
        if self.unknown_host:
            ssh_command += TestSsh.unknown_option
        ssh_command += self.key_part()
        ssh_command += "%s %s" % (self.hostname_part(), TestSsh.backslash_shell_specials(command))
        return ssh_command

    # same in argv form
    def actual_argv(self, argv, keep_stdin=False):
        """Same as actual_command, but for an argv list (no escaping needed)."""
        if self.is_local():
            return argv
        ssh_argv = ['ssh']
        if not keep_stdin:
            ssh_argv.append('-n')
        ssh_argv += TestSsh.std_options.split()
        if self.unknown_host:
            ssh_argv += TestSsh.unknown_option.split()
        ssh_argv += self.key_part().split()
        ssh_argv.append(self.hostname_part())
        ssh_argv += argv
        return ssh_argv

    def header(self, message):
        """Print ``message`` after a '===============' banner and flush stdout.

        Does nothing when ``message`` is empty or None.
        """
        if not message:
            return
        # sys was used here without ever being imported (NameError)
        import sys
        sys.stdout.write("=============== %s\n" % (message,))
        sys.stdout.flush()

    def run(self, command, message=None, background=False):
        """Run ``command`` on the target box, optionally announcing ``message``."""
        local_command = self.actual_command(command)
        self.header(message)
        return utils.system(local_command, background)

    def clean_dir(self, dirname):
        """Remove ``dirname`` on the remote box; no-op (returns 0) locally."""
        if self.is_local():
            return 0
        return self.run("rm -rf %s" % dirname)

    def mkdir(self, dirname=None):
        """Create a directory, like ``mkdir -p``.

        Remotely, creates ``buildname/dirname`` (or ``buildname`` itself when
        dirname is omitted).  Locally, creates ``dirname`` (relative to the
        current directory) including missing parents; returns 0.
        """
        if self.is_local():
            # used to call the non-existent os.path.mkdir; mirror mkdir -p
            if dirname and not os.path.isdir(dirname):
                os.makedirs(dirname)
            return 0
        if dirname:
            dirname = "%s/%s" % (self.buildname, dirname)
        else:
            dirname = self.buildname
        return self.run("mkdir -p %s" % dirname)

    def rmdir(self, dirname=None):
        """Remove a directory tree, like ``rm -rf``.

        Remotely, removes ``buildname/dirname`` (or ``buildname`` itself).
        Locally, removes ``dirname`` if it exists; returns 0.
        """
        if self.is_local():
            # like rm -rf, a missing directory is not an error
            if dirname and os.path.exists(dirname):
                shutil.rmtree(dirname)
            return 0
        if dirname:
            dirname = "%s/%s" % (self.buildname, dirname)
        else:
            dirname = self.buildname
        return self.run("rm -rf %s" % dirname)

    def create_buildname_once(self):
        """Create the remote buildname directory on first use; no-op locally."""
        if self.is_local():
            return
        # create remote buildname on demand
        if not hasattr(self, "buildname_created"):
            self.mkdir()
            self.buildname_created = True

    def run_in_buildname(self, command, background=False):
        """Run ``command`` inside buildname on the remote box (in . locally)."""
        if self.is_local():
            return utils.system(command, background)
        self.create_buildname_once()
        # by keyword: run's second positional parameter is 'message'
        return self.run("cd %s ; %s" % (self.buildname, command), background=background)

    def copy(self, local_file, recursive=False):
        """scp ``local_file`` into the remote buildname; no-op (0) locally."""
        if self.is_local():
            return 0
        self.create_buildname_once()
        scp_command = self._scp_command(recursive)
        # a trailing '/' yields an empty basename: copy into buildname itself
        scp_command += "%s %s:%s/%s" % (local_file, self.hostname_part(),
                                        self.buildname, os.path.basename(local_file) or ".")
        return utils.system(scp_command)

    def copy_abs(self, local_file, remote_file, recursive=False):
        """scp ``local_file`` to ``remote_file`` as given (no buildname prefix).

        Locally the scp runs with a plain local destination.
        """
        if self.is_local():
            dest = ""
        else:
            dest = "%s:" % self.hostname_part()
        scp_command = self._scp_command(recursive)
        scp_command += "%s %s%s" % (local_file, dest, remote_file)
        return utils.system(scp_command)

    def copy_home(self, local_file, recursive=False):
        """Copy ``local_file`` under its basename, relative to the remote login dir."""
        return self.copy_abs(local_file, os.path.basename(local_file), recursive)

    def fetch(self, remote_file, local_file, recursive=False):
        """Copy ``remote_file`` from the target box into ``local_file``.

        A relative ``remote_file`` is taken relative to buildname remotely
        (and to the current directory locally, where plain cp is used).
        """
        if self.is_local():
            command = "cp "
            if recursive:
                command += "-r "
            command += "%s %s" % (remote_file, local_file)
        else:
            # absolute path - do not prepend buildname
            if remote_file.startswith("/"):
                remote_path = remote_file
            else:
                remote_path = "%s/%s" % (self.buildname, remote_file)
            command = self._scp_command(recursive)
            command += "%s:%s %s" % (self.hostname_part(), remote_path, local_file)
        return utils.system(command)

    # this is only to avoid harmless message when host cannot be identified
    # convenience only
    # the only place where this is needed is when trying to reach a slice in a node,
    # which is done from the test master box
    def clear_known_hosts(self):
        """Delete the plain-text entries for ``self.hostname`` from ~/.ssh/known_hosts."""
        known_hosts = os.path.expanduser("~/.ssh/known_hosts")
        utils.header("Clearing entry for %s in %s" % (self.hostname, known_hosts))
        # known_hosts lines start with 'host[,host...] keytype key': escape
        # dots (regex wildcards) and require a separator after the name so
        # that clearing e.g. node1 does not also wipe node10
        pattern = "^%s[ ,]" % self.hostname.replace(".", "\\.")
        return utils.system("sed -i -e '/%s/d' %s" % (pattern, known_hosts))
207