still making both branches closer
[nepi.git] / src / nepi / resources / linux / node.py
index ab10bae..a411c64 100644 (file)
@@ -297,7 +297,7 @@ class LinuxNode(ResourceManager):
         if not self.localhost and not self.get("username"):
             msg = "Can't resolve OS, insufficient data "
             self.error(msg)
-            raise RuntimeError, msg
+            raise RuntimeError(msg)
 
         out = self.get_os()
 
@@ -355,7 +355,7 @@ class LinuxNode(ResourceManager):
             trace = traceback.format_exc()
             msg = "Deploy failed. Unresponsive node {} -- traceback {}".format(self.get("hostname"), trace)
             self.error(msg)
-            raise RuntimeError, msg
+            raise RuntimeError(msg)
 
         self.find_home()
 
@@ -440,8 +440,45 @@ class LinuxNode(ResourceManager):
                    "sudo -S killall -u {} || /bin/true ; ".format(self.get("username")))
         else:
             if self.state >= ResourceState.READY:
+                ########################
+                #Collect all process (must change for a more intelligent way)
+                ppid = []
+                pids = []
+                avoid_pids = "ps axjf | awk '{print $1,$2}'"
+                (out, err), proc = self.execute(avoid_pids)
+                if len(out) != 0:
+                    for line in out.strip().split("\n"):
+                        parts = line.strip().split(" ")
+                        ppid.append(parts[0])
+                        pids.append(parts[1])
+
+                #Collect all process below ssh -D
+                tree_owner = 0
+                ssh_pids = []
+                sshs = "ps aux | grep 'sshd' | awk '{print $2,$12}'"
+                (out, err), proc = self.execute(sshs)
+                if len(out) != 0:
+                    for line in out.strip().split("\n"):
+                        parts = line.strip().split(" ")
+                        if parts[1].startswith('root@pts'):
+                            ssh_pids.append(parts[0])
+                        elif parts[1] == "-D":
+                            tree_owner = parts[0]
+
+                avoid_kill = []
+                temp = []
+                #Search for the child process of the pid's collected at the first block.
+                for process in ssh_pids:
+                    temp = self.search_for_child(process, pids, ppid)
+                    avoid_kill = list(set(temp))
+                
+                if len(avoid_kill) > 0:
+                    avoid_kill.append(tree_owner) 
+                ########################
+
                 import pickle
-                pids = pickle.load(open("/tmp/save.proc", "rb"))
+                with open("/tmp/save.proc", "rb") as pickle_file:
+                    pids = pickle.load(pickle_file)
                 pids_temp = dict()
                 ps_aux = "ps aux | awk '{print $2,$11}'"
                 (out, err), proc = self.execute(ps_aux)
@@ -449,8 +486,17 @@ class LinuxNode(ResourceManager):
                     for line in out.strip().split("\n"):
                         parts = line.strip().split(" ")
                         pids_temp[parts[0]] = parts[1]
+                    # creates the difference between the machine pids freezed (pickle) and the actual
+                    # adding the avoided pids filtered above (avoid_kill) to allow users keep process
+                    # alive when using besides ssh connections  
                     kill_pids = set(pids_temp.items()) - set(pids.items())
-                    kill_pids = ' '.join(dict(kill_pids).keys())
+                    # py2/py3 : keep it simple
+                    kill_pids = ' '.join(kill_pids)
+
+                    # removing pids from beside connections and its process
+                    kill_pids = kill_pids.split(' ')
+                    kill_pids = list(set(kill_pids) - set(avoid_kill))
+                    kill_pids = ' '.join(kill_pids)
 
                     cmd = ("killall tcpdump || /bin/true ; " +
                            "kill $(ps aux | grep '[.]nepi' | awk '{print $2}') || /bin/true ; " +
@@ -464,6 +510,16 @@ class LinuxNode(ResourceManager):
 
         (out, err), proc = self.execute(cmd, retry = 1, with_lock = True)
 
+    def search_for_child(self, pid, pids, ppid, family=[]):
+        """ Recursive function to search for child. List A contains the pids and list B the parents (ppid)
+        """
+        family.append(pid)
+        for key, value in enumerate(ppid):
+            if value == pid:
+                child = pids[key]
+                self.search_for_child(child, pids, ppid)
+        return family
+        
     def clean_home(self):
         """ Cleans all NEPI related folders in the Linux host
         """
@@ -711,15 +767,17 @@ class LinuxNode(ResourceManager):
         if text and not os.path.isfile(src):
             # src is text input that should be uploaded as file
             # create a temporal file with the content to upload
-            f = tempfile.NamedTemporaryFile(delete=False)
+            # in python3 we need to open in binary mode if str is bytes
+            mode = 'w' if isinstance(src, str) else 'wb'
+            f = tempfile.NamedTemporaryFile(mode=mode, delete=False)
             f.write(src)
             f.close()
             src = f.name
 
         # If dst files should not be overwritten, check that the files do not
-        # exits already
+        # exist already
         if isinstance(src, str):
-            src = map(str.strip, src.split(";"))
+            src = [s.strip() for s in src.split(";")]
     
         if overwrite == False:
             src = self.filter_existing_files(src, dst)
@@ -742,7 +800,7 @@ class LinuxNode(ResourceManager):
             
             msg = "{} out: {} err: {}".format(msg, out, err)
             if raise_on_error:
-                raise RuntimeError, msg
+                raise RuntimeError(msg)
 
         return ((out, err), proc)
 
@@ -758,7 +816,7 @@ class LinuxNode(ResourceManager):
             self.error(msg, out, err)
 
             if raise_on_error:
-                raise RuntimeError, msg
+                raise RuntimeError(msg)
 
         return ((out, err), proc)
 
@@ -771,7 +829,7 @@ class LinuxNode(ResourceManager):
         else:
             msg = "Error installing packages ( OS not known ) "
             self.error(msg, self.os)
-            raise RuntimeError, msg
+            raise RuntimeError(msg)
 
         return command
 
@@ -812,7 +870,7 @@ class LinuxNode(ResourceManager):
         else:
             msg = "Error removing packages ( OS not known ) "
             self.error(msg)
-            raise RuntimeError, msg
+            raise RuntimeError(msg)
 
         run_home = run_home or home
 
@@ -849,7 +907,7 @@ class LinuxNode(ResourceManager):
         if isinstance(paths, str):
             paths = [paths]
 
-        cmd = " ; ".join(map(lambda path: "rm -rf {}".format(path), paths))
+        cmd = " ; ".join(["rm -rf {}".format(path) for path in paths])
 
         return self.execute(cmd, with_lock = True)
     
@@ -896,7 +954,7 @@ class LinuxNode(ResourceManager):
             msg = " Failed to run command '{}' ".format(command)
             self.error(msg, out, err)
             if raise_on_error:
-                raise RuntimeError, msg
+                raise RuntimeError(msg)
 
         # Wait for pid file to be generated
         pid, ppid = self.wait_pid(
@@ -918,7 +976,7 @@ class LinuxNode(ResourceManager):
                 self.error(msg, eout, err)
 
                 if raise_on_error:
-                    raise RuntimeError, msg
+                    raise RuntimeError(msg)
 
         (out, oerr), proc = self.check_output(home, stdout)
         
@@ -1019,7 +1077,7 @@ class LinuxNode(ResourceManager):
         pid = ppid = None
         delay = 1.0
 
-        for i in xrange(2):
+        for i in range(2):
             pidtuple = self.getpid(home = home, pidfile = pidfile)
             
             if pidtuple:
@@ -1033,7 +1091,7 @@ class LinuxNode(ResourceManager):
             self.error(msg)
     
             if raise_on_error:
-                raise RuntimeError, msg
+                raise RuntimeError(msg)
 
         return pid, ppid
 
@@ -1112,7 +1170,7 @@ class LinuxNode(ResourceManager):
 
         if not self._home_dir:
             self.error(msg)
-            raise RuntimeError, msg
+            raise RuntimeError(msg)
 
     def filter_existing_files(self, src, dst):
         """ Removes files that already exist in the Linux host from src list
@@ -1122,14 +1180,14 @@ class LinuxNode(ResourceManager):
                 if len(src) > 1 else {dst: src[0]}
 
         command = []
-        for d in dests.keys():
+        for d in dests:
             command.append(" [ -f {dst} ] && echo '{dst}' ".format(dst=d) )
 
         command = ";".join(command)
 
         (out, err), proc = self.execute(command, retry = 1, with_lock = True)
         
-        for d in dests.keys():
+        for d in dests:
             if out.find(d) > -1:
                 del dests[d]