FORGE: Added including script
[myslice.git] / forge / script / scp.py
1 # scp.py
2 # Copyright (C) 2008 James Bardin <jbardin@bu.edu>
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 2 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along
15 # with this program; if not, write to the Free Software Foundation, Inc.,
16 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17
18
19 """
20 Utilities for sending files over ssh using the scp1 protocol.
21 """
22
23 import os
24 from socket import timeout as SocketTimeout
25
26 class SCPClient(object):
27     """
28     An scp1 implementation, compatible with openssh scp.
29     Raises SCPException for all transport related errors. Local filesystem
30     and OS errors pass through. 
31
32     Main public methods are .put and .get 
33     The get method is controlled by the remote scp instance, and behaves 
34     accordingly. This means that symlinks are resolved, and the transfer is
35     halted after too many levels of symlinks are detected.
36     The put method uses os.walk for recursion, and sends files accordingly.
37     Since scp doesn't support symlinks, we send file symlinks as the file
38     (matching scp behaviour), but we make no attempt at symlinked directories.
39
40     Convenience methods:
41         put_r:  put with recursion
42         put_p:  put preserving times
43         put_rp: put with recursion, preserving times
44         get_r:  get with recursion
45         get_p:  get preserving times
46         get_rp: get with recursion, preserving times
47     """
48     def __init__(self, transport, buff_size = 16384, socket_timeout = 5.0,
49                  callback = None):
50         """
51         Create an scp1 client.
52
53         @param transport: an existing paramiko L{Transport}
54         @type transport: L{Transport}
55         @param buff_size: size of the scp send buffer.
56         @type buff_size: int
57         @param socket_timeout: channel socket timeout in seconds
58         @type socket_timeout: float
59         @param callback: callback function for transfer status
60         @type callback: func
61         """
62         self.transport = transport
63         self.buff_size = buff_size
64         self.socket_timeout = socket_timeout
65         self.channel = None
66         self.preserve_times = False
67         self.callback = callback
68         self._recv_dir = ''
69         self._utime = None
70         self._dirtimes = {}
71
72     def put(self, files, remote_path = '.', 
73             recursive = False, preserve_times = False):
74         """
75         Transfer files to remote host.
76
77         @param files: A single path, or a list of paths to be transfered.
78             recursive must be True to transfer directories.
79         @type files: string OR list of strings
80         @param remote_path: path in which to receive the files on the remote
81             host. defaults to '.'
82         @type remote_path: str
83         @param recursive: transfer files and directories recursively
84         @type recursive: bool
85         @param preserve_times: preserve mtime and atime of transfered files
86             and directories.
87         @type preserve_times: bool
88         """
89         self.preserve_times = preserve_times
90         self.channel = self.transport.open_session()
91         self.channel.settimeout(self.socket_timeout)
92         scp_command = ('scp -t %s\n', 'scp -r -t %s\n')[recursive]
93         self.channel.exec_command(scp_command % remote_path)
94         self._recv_confirm()
95        
96         if not isinstance(files, (list, tuple)):
97             files = [files]
98         
99         if recursive:
100             self._send_recursive(files)
101         else:
102             self._send_files(files)
103         
104         if self.channel:
105             self.channel.close()
106     
107     def put_r(self, files, remote_path = '.'):
108         """
109         Convenience function for a recursive put
110
111         @param files: A single path, or a list of paths to be transfered.
112         @type files: str, list
113         @param remote_path: path in which to receive the files on the remote
114             host. defaults to '.'
115         @type remote_path: bool
116         """
117         self.put(files, remote_path, recursive = True)
118
119     def put_p(self, files, remote_path = '.'):
120         """
121         Convenience function to put preserving times.
122
123         @param files: A single path, or a list of paths to be transfered.
124         @type files: str, list
125         @param remote_path: path in which to receive the files on the remote
126             host. defaults to '.'
127         @type remote_path: bool
128         """
129         self.put(files, remote_path, preserve_times = True)
130    
131     def put_rp(self, files, remote_path = '.'):
132         """
133         Convenience function for a recursive put, preserving times.
134
135         @param files: A single path, or a list of paths to be transfered.
136         @type files: str, list
137         @param remote_path: path in which to receive the files on the remote
138             host. defaults to '.'
139         @type remote_path: bool
140         """
141         self.put(files, remote_path, recursive = True, preserve_times = True)
142
143     def get(self, remote_path, local_path = '',
144             recursive = False, preserve_times = False):
145         """
146         Transfer files from remote host to localhost
147
148         @param remote_path: path to retreive from remote host. since this is
149             evaluated by scp on the remote host, shell wildcards and 
150             environment variables may be used.
151         @type remote_path: str
152         @param local_path: path in which to receive files locally
153         @type local_path: str
154         @param recursive: transfer files and directories recursively
155         @type recursive: bool
156         @param preserve_times: preserve mtime and atime of transfered files
157             and directories.
158         @type preserve_times: bool
159         """
160         self._recv_dir = local_path or os.getcwd() 
161         rcsv = ('', ' -r')[recursive]
162         prsv = ('', ' -p')[preserve_times]
163         self.channel = self.transport.open_session()
164         self.channel.settimeout(self.socket_timeout)
165         self.channel.exec_command('scp%s%s -f %s' % (rcsv, prsv, remote_path))
166         self._recv_all()
167         
168         if self.channel:
169             self.channel.close()
170
171     def get_r(self, remote_path, local_path = '.'):
172         """
173         Convenience function for a recursive get
174
175         @param remote_path: path to retrieve from server
176         @type remote_path: str
177         @param local_path: path in which to recieve files. default cwd
178         @type local_path: str
179         """
180         self.get(remote_path, local_path, recursive = True)
181
182     def get_p(self, remote_path, local_path = '.'):
183         """
184         Convenience function for get, preserving times.
185
186         @param remote_path: path to retrieve from server
187         @type remote_path: str
188         @param local_path: path in which to recieve files. default cwd
189         @type local_path: str
190         """
191         self.get(remote_path, local_path, preserve_times = True)
192
193     def get_rp(self, remote_path, local_path = '.'):
194         """
195         Convenience function for a recursive get, preserving times.
196
197         @param remote_path: path to retrieve from server
198         @type remote_path: str
199         @param local_path: path in which to recieve files. default cwd
200         @type local_path: str
201         """
202         self.get(remote_path, local_path, recursive = True, preserve_times = True)
203
204     def _read_stats(self, name):
205         """return just the file stats needed for scp"""
206         stats = os.stat(name)
207         mode = oct(stats.st_mode)[-4:]
208         size = stats.st_size
209         atime = int(stats.st_atime)
210         mtime = int(stats.st_mtime)
211         return (mode, size, mtime, atime)
212
213     def _send_files(self, files): 
214         for name in files:
215             basename = os.path.basename(name)
216             (mode, size, mtime, atime) = self._read_stats(name)
217             if self.preserve_times:
218                 self._send_time(mtime, atime)
219             file_hdl = file(name, 'rb')
220             self.channel.sendall('C%s %d %s\n' % (mode, size, basename))
221             self._recv_confirm()
222             file_pos = 0
223             buff_size = self.buff_size
224             chan = self.channel
225             while file_pos < size:
226                 chan.sendall(file_hdl.read(buff_size))
227                 file_pos = file_hdl.tell()
228                 if self.callback:
229                     self.callback(file_pos, size)
230             chan.sendall('\x00')
231             file_hdl.close()
232
233     def _send_recursive(self, files):
234         for base in files:
235             lastdir = base
236             for root, dirs, fls in os.walk(base):
237                 # pop back out to the next dir in the walk
238                 while lastdir != os.path.commonprefix([lastdir, root]):
239                     self._send_popd()
240                     lastdir = os.path.split(lastdir)[0]
241                 self._send_pushd(root)
242                 lastdir = root
243                 self._send_files([os.path.join(root, f) for f in fls])
244         
245     def _send_pushd(self, directory):
246         (mode, size, mtime, atime) = self._read_stats(directory)
247         basename = os.path.basename(directory)
248         if self.preserve_times:
249             self._send_time(mtime, atime)
250         self.channel.sendall('D%s 0 %s\n' % (mode, basename))
251         self._recv_confirm()
252
253     def _send_popd(self):
254         self.channel.sendall('E\n')
255         self._recv_confirm()
256
257     def _send_time(self, mtime, atime):
258         self.channel.sendall('T%d 0 %d 0\n' % (mtime, atime))
259         self._recv_confirm()
260
261     def _recv_confirm(self):
262         # read scp response
263         msg = ''
264         try:
265             msg = self.channel.recv(512)
266         except SocketTimeout:
267             raise SCPException('Timout waiting for scp response')
268         if msg and msg[0] == '\x00':
269             return
270         elif msg and msg[0] == '\x01':
271             raise SCPException(msg[1:])
272         elif self.channel.recv_stderr_ready():
273             msg = self.channel.recv_stderr(512)
274             raise SCPException(msg)
275         elif not msg:
276             raise SCPException('No response from server')
277         else:
278             raise SCPException('Invalid response from server: ' + msg)
279     
280     def _recv_all(self):
281         # loop over scp commands, and recive as necessary
282         command = {'C': self._recv_file,
283                    'T': self._set_time,
284                    'D': self._recv_pushd,
285                    'E': self._recv_popd}
286         while not self.channel.closed:
287             # wait for command as long as we're open
288             self.channel.sendall('\x00')
289             msg = self.channel.recv(1024)
290             if not msg: # chan closed while recving
291                 break
292             code = msg[0]
293             try:
294                 command[code](msg[1:])
295             except KeyError:
296                 raise SCPException(repr(msg))
297         # directory times can't be set until we're done writing files
298         self._set_dirtimes()
299     
300     def _set_time(self, cmd):
301         try:
302             times = cmd.split()
303             mtime = int(times[0])
304             atime = int(times[2]) or mtime
305         except:
306             self.channel.send('\x01')
307             raise SCPException('Bad time format')
308         # save for later
309         self._utime = (mtime, atime)
310
311     def _recv_file(self, cmd):
312         chan = self.channel
313         parts = cmd.split()
314         try:
315             mode = int(parts[0], 8)
316             size = int(parts[1])
317             path = os.path.join(self._recv_dir, parts[2])
318         except:
319             chan.send('\x01')
320             chan.close()
321             raise SCPException('Bad file format')
322         
323         try:
324             file_hdl = file(path, 'wb')
325         except IOError, e:
326             chan.send('\x01'+e.message)
327             chan.close()
328             raise
329
330         buff_size = self.buff_size
331         pos = 0
332         chan.send('\x00')
333         try:
334             while pos < size:
335                 # we have to make sure we don't read the final byte
336                 if size - pos <= buff_size:
337                     buff_size = size - pos
338                 file_hdl.write(chan.recv(buff_size))
339                 pos = file_hdl.tell()
340                 if self.callback:
341                     self.callback(pos, size)
342             
343             msg = chan.recv(512)
344             if msg and msg[0] != '\x00':
345                 raise SCPException(msg[1:])
346         except SocketTimeout:
347             chan.close()
348             raise SCPException('Error receiving, socket.timeout')
349
350         file_hdl.truncate()
351         try:
352             os.utime(path, self._utime)
353             self._utime = None
354             os.chmod(path, mode)
355             # should we notify the other end?
356         finally:
357             file_hdl.close()
358         # '\x00' confirmation sent in _recv_all
359
360     def _recv_pushd(self, cmd):
361         parts = cmd.split()
362         try:
363             mode = int(parts[0], 8)
364             path = os.path.join(self._recv_dir, parts[2])
365         except:
366             self.channel.send('\x01')
367             raise SCPException('Bad directory format')
368         try:
369             if not os.path.exists(path):
370                 os.mkdir(path, mode)
371             elif os.path.isdir(path):
372                 os.chmod(path, mode)
373             else:
374                 raise SCPException('%s: Not a directory' % path)
375             self._dirtimes[path] = (self._utime)
376             self._utime = None
377             self._recv_dir = path
378         except (OSError, SCPException), e:
379             self.channel.send('\x01'+e.message)
380             raise
381
382     def _recv_popd(self, *cmd):
383         self._recv_dir = os.path.split(self._recv_dir)[0]
384         
385     def _set_dirtimes(self):
386         try:
387             for d in self._dirtimes:
388                 os.utime(d, self._dirtimes[d])
389         finally:
390             self._dirtimes = {}
391
392
393 class SCPException(Exception):
394     """SCP exception class"""
395     pass