Fixing wrong license
[nepi.git] / test / util / sshfuncs.py
index 3342157..99f44b0 100755 (executable)
@@ -4,9 +4,8 @@
 #    Copyright (C) 2013 INRIA
 #
 #    This program is free software: you can redistribute it and/or modify
-#    it under the terms of the GNU General Public License as published by
-#    the Free Software Foundation, either version 3 of the License, or
-#    (at your option) any later version.
+#    it under the terms of the GNU General Public License version 2 as
+#    published by the Free Software Foundation;
 #
 #    This program is distributed in the hope that it will be useful,
 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -19,8 +18,8 @@
 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
 
 
-from nepi.util.sshfuncs import rexec, rcopy, rspawn, rcheckpid, rstatus, rkill,\
-        RUNNING, FINISHED 
+from nepi.util.sshfuncs import rexec, rcopy, rspawn, rgetpid, rstatus, rkill,\
+        ProcStatus
 
 import getpass
 import unittest
@@ -184,7 +183,7 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         self.assertEquals(outlocal, outremote)
 
-    def test_rcopy(self):
+    def test_rcopy_list(self):
         env = test_environment()
         user = getpass.getuser()
         host = "localhost"
@@ -212,6 +211,38 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         self.assertEquals(sorted(origfiles), sorted(files))
 
+        os.remove(f1.name)
+        shutil.rmtree(dirpath)
+
+    def test_rcopy_list(self):
+        env = test_environment()
+        user = getpass.getuser()
+        host = "localhost"
+
+        # create some temp files and directories to copy
+        dirpath = tempfile.mkdtemp()
+        f = tempfile.NamedTemporaryFile(dir=dirpath, delete=False)
+        f.close()
+      
+        f1 = tempfile.NamedTemporaryFile(delete=False)
+        f1.close()
+        f1.name
+
+        # Copy a list of files
+        source = [dirpath, f1.name]
+        destdir = tempfile.mkdtemp()
+        dest = "%s@%s:%s" % (user, host, destdir)
+        ((out, err), proc) = rcopy(source, dest, port = env.port, agent = True, recursive = True)
+
+        files = []
+        def recls(files, dirname, names):
+            files.extend(names)
+        os.path.walk(destdir, recls, files)
+       
+        origfiles = map(lambda s: os.path.basename(s), [dirpath, f.name, f1.name])
+
+        self.assertEquals(sorted(origfiles), sorted(files))
+
     def test_rproc_manage(self):
         env = test_environment()
         user = getpass.getuser()
@@ -231,7 +262,7 @@ class SSHfuncsTestCase(unittest.TestCase):
 
         time.sleep(2)
 
-        (pid, ppid) = rcheckpid(pidfile,
+        (pid, ppid) = rgetpid(pidfile,
                 host = host,
                 user = user,
                 port = env.port,
@@ -243,7 +274,7 @@ class SSHfuncsTestCase(unittest.TestCase):
                 port = env.port, 
                 agent = True)
 
-        self.assertEquals(status, RUNNING)
+        self.assertEquals(status, ProcStatus.RUNNING)
 
         rkill(pid, ppid,
                 host = host,
@@ -257,7 +288,7 @@ class SSHfuncsTestCase(unittest.TestCase):
                 port = env.port, 
                 agent = True)
         
-        self.assertEquals(status, FINISHED)
+        self.assertEquals(status, ProcStatus.FINISHED)
 
 
 if __name__ == '__main__':