applied the except and raise fixers to the master branch to close the gap with py3
[nepi.git] / test / util / sshfuncs.py
old mode 100644 (file)
new mode 100755 (executable)
index 9b282ce..2c9c30b
@@ -1,7 +1,25 @@
 #!/usr/bin/env python
-
-from neco.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\
-        RUNNING, FINISHED 
+#
+#    NEPI, a framework to manage network experiments
+#    Copyright (C) 2013 INRIA
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# Author: Alina Quereilhac <alina.quereilhac@inria.fr>
+
+
+from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
+        ProcStatus
 
 import getpass
 import unittest
@@ -29,7 +47,7 @@ def find_bin(name, extra_path = None):
             try:
                 os.stat(d + "/" + name)
                 return d + "/" + name
-            except OSError, e:
+            except OSError as e:
                 if e.errno != os.errno.ENOENT:
                     raise
     return None
@@ -50,9 +68,8 @@ def gen_ssh_keypair(filename):
 def add_key_to_agent(filename):
     ssh_add = find_bin_or_die("ssh-add")
     args = [ssh_add, filename]
-    null = file("/dev/null", "w")
-    assert subprocess.Popen(args, stderr = null).wait() == 0
-    null.close()
+    with open("/dev/null", "w") as null:
+        assert subprocess.Popen(args, stderr = null).wait() == 0
 
 def get_free_port():
     s = socket.socket()
@@ -75,10 +92,9 @@ PermitUserEnvironment yes
 """
 
 def gen_sshd_config(filename, port, server_key, auth_keys):
-    conf = open(filename, "w")
-    text = _SSH_CONF % (port, server_key, auth_keys)
-    conf.write(text)
-    conf.close()
+    with open(filename, "w") as conf:
+        text = _SSH_CONF % (port, server_key, auth_keys)
+        conf.write(text)
     return filename
 
 def gen_auth_keys(pubkey, output, environ):
@@ -87,11 +103,11 @@ def gen_auth_keys(pubkey, output, environ):
     for k, v in environ.items():
         opts.append('environment="%s=%s"' % (k, v))
 
-    lines = file(pubkey).readlines()
+    with open(pubkey) as f:
+        lines = f.readlines()
     pubkey = lines[0].split()[0:2]
-    out = file(output, "w")
-    out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
-    out.close()
+    with open(output, "w") as out:
+        out.write("%s %s %s\n" % (",".join(opts), pubkey[0], pubkey[1]))
     return output
 
 def start_ssh_agent():
@@ -113,9 +129,8 @@ def stop_ssh_agent(data):
     # No need to gather the pid, ssh-agent knows how to kill itself; after we
     # had set up the environment
     ssh_agent = find_bin_or_die("ssh-agent")
-    null = file("/dev/null", "w")
-    proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
-    null.close()
+    with open("/dev/null", "w") as null:
+        proc = subprocess.Popen([ssh_agent, "-k"], stdout = null)
     assert proc.wait() == 0
     for k in data:
         del os.environ[k]
@@ -165,7 +180,7 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         self.assertEquals(outlocal, outremote)
 
-    def test_rcopy(self):
+    def test_rcopy_list(self):
         env = test_environment()
         user = getpass.getuser()
         host = "localhost"
@@ -182,7 +197,7 @@ class SSHfuncsTestCase(unittest.TestCase):
         source = [dirpath, f1.name]
         destdir = tempfile.mkdtemp()
         dest = "%s@%s:%s" % (user, host, destdir)
-        rcopy(source, dest, port = env.port, agent = True)
+        rcopy(source, dest, port = env.port, agent = True, recursive = True)
 
         files = []
         def recls(files, dirname, names):
@@ -193,6 +208,38 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         self.assertEquals(sorted(origfiles), sorted(files))
 
+        os.remove(f1.name)
+        shutil.rmtree(dirpath)
+
+    def test_rcopy_list(self):
+        env = test_environment()
+        user = getpass.getuser()
+        host = "localhost"
+
+        # create some temp files and directories to copy
+        dirpath = tempfile.mkdtemp()
+        f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
+        f.close()
+      
+        f1 = tempfile.NamedTemporaryFile(delete=False)
+        f1.close()
+        f1.name
+
+        # Copy a list of files
+        source = [dirpath, f1.name]
+        destdir = tempfile.mkdtemp()
+        dest = "%s@%s:%s" % (user, host, destdir)
+        ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
+
+        files = []
+        def recls(files, dirname, names):
+            files.extend(names)
+        os.path.walk(destdir, recls, files)
+       
+        origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
+
+        self.assertEquals(sorted(origfiles), sorted(files))
+
     def test_rproc_manage(self):
         env = test_environment()
         user = getpass.getuser()
@@ -212,7 +259,7 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         time.sleep(2)
 
-        (pid, ppid) = rcheckpid(pidfile,
+        (pid, ppid) = rgetpid(pidfile,
                 host = host,
                 user = user,
                 port = env.port,
@@ -224,7 +271,7 @@ class SSHfuncsTestCase(unittest.TestCase):
                 port = env.port, 
                 agent = True)
 
-        self.assertEquals(status, RUNNING)
+        self.assertEquals(status, ProcStatus.RUNNING)
 
         rkill(pid, ppid,
                 host = host,
@@ -238,7 +285,7 @@ class SSHfuncsTestCase(unittest.TestCase):
                 port = env.port, 
                 agent = True)
         
-        self.assertEquals(status, FINISHED)
+        self.assertEquals(status, ProcStatus.FINISHED)
 
 
 if __name__ == '__main__':