adding back use of sudo in sshfuncs.rexec
[nepi.git] / src / nepi / resources / linux / node.py
1 #
2 #    NEPI, a framework to manage network experiments
3 #    Copyright (C) 2013 INRIA
4 #
5 #    This program is free software: you can redistribute it and/or modify
6 #    it under the terms of the GNU General Public License as published by
7 #    the Free Software Foundation, either version 3 of the License, or
8 #    (at your option) any later version.
9 #
10 #    This program is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 #    GNU General Public License for more details.
14 #
15 #    You should have received a copy of the GNU General Public License
16 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 #
18 # Author: Alina Quereilhac <alina.quereilhac@inria.fr>
19
20 from nepi.execution.attribute import Attribute, Flags
21 from nepi.execution.resource import ResourceManager, clsinit, ResourceState
22 from nepi.resources.linux import rpmfuncs, debfuncs 
23 from nepi.util import sshfuncs, execfuncs
24 from nepi.util.sshfuncs import ProcStatus
25
26 import collections
27 import os
28 import random
29 import re
30 import tempfile
31 import time
32 import threading
33
34 # TODO: Verify files and dirs exists already
35 # TODO: Blacklist nodes!
36 # TODO: Unify delays!!
37 # TODO: Validate outcome of uploads!! 
38
39 reschedule_delay = "0.5s"
40
41 class ExitCode:
42     """
43     Error codes that the rexitcode function can return if unable to
44     check the exit code of a spawned process
45     """
46     FILENOTFOUND = -1
47     CORRUPTFILE = -2
48     ERROR = -3
49     OK = 0
50
51 class OSType:
52     """
53     Supported flavors of Linux OS
54     """
55     FEDORA_12 = "f12"
56     FEDORA_14 = "f14"
57     FEDORA = "fedora"
58     UBUNTU = "ubuntu"
59     DEBIAN = "debian"
60
61 @clsinit
62 class LinuxNode(ResourceManager):
63     _rtype = "LinuxNode"
64
65     @classmethod
66     def _register_attributes(cls):
67         hostname = Attribute("hostname", "Hostname of the machine",
68                 flags = Flags.ExecReadOnly)
69
70         username = Attribute("username", "Local account username", 
71                 flags = Flags.Credential)
72
73         port = Attribute("port", "SSH port", flags = Flags.ExecReadOnly)
74         
75         home = Attribute("home",
76                 "Experiment home directory to store all experiment related files",
77                 flags = Flags.ExecReadOnly)
78         
79         identity = Attribute("identity", "SSH identity file",
80                 flags = Flags.Credential)
81         
82         server_key = Attribute("serverKey", "Server public key", 
83                 flags = Flags.ExecReadOnly)
84         
85         clean_home = Attribute("cleanHome", "Remove all files and directories " + \
86                 " from home folder before starting experiment", 
87                 flags = Flags.ExecReadOnly)
88         
89         clean_processes = Attribute("cleanProcesses", 
90                 "Kill all running processes before starting experiment",
91                 flags = Flags.ExecReadOnly)
92         
93         tear_down = Attribute("tearDown", "Bash script to be executed before " + \
94                 "releasing the resource",
95                 flags = Flags.ExecReadOnly)
96
97         cls._register_attribute(hostname)
98         cls._register_attribute(username)
99         cls._register_attribute(port)
100         cls._register_attribute(home)
101         cls._register_attribute(identity)
102         cls._register_attribute(server_key)
103         cls._register_attribute(clean_home)
104         cls._register_attribute(clean_processes)
105         cls._register_attribute(tear_down)
106
107     def __init__(self, ec, guid):
108         super(LinuxNode, self).__init__(ec, guid)
109         self._os = None
110         
111         # lock to avoid concurrency issues on methods used by applications 
112         self._lock = threading.Lock()
113     
114     def log_message(self, msg):
115         return " guid %d - host %s - %s " % (self.guid, 
116                 self.get("hostname"), msg)
117
118     @property
119     def home(self):
120         return self.get("home") or ""
121
122     @property
123     def exp_home(self):
124         return os.path.join(self.home, self.ec.exp_id)
125
126     @property
127     def node_home(self):
128         node_home = "node-%d" % self.guid
129         return os.path.join(self.exp_home, node_home)
130
131     @property
132     def os(self):
133         if self._os:
134             return self._os
135
136         if (not self.get("hostname") or not self.get("username")):
137             msg = "Can't resolve OS, insufficient data "
138             self.error(msg)
139             raise RuntimeError, msg
140
141         (out, err), proc = self.execute("cat /etc/issue", with_lock = True)
142
143         if err and proc.poll():
144             msg = "Error detecting OS "
145             self.error(msg, out, err)
146             raise RuntimeError, "%s - %s - %s" %( msg, out, err )
147
148         if out.find("Fedora release 12") == 0:
149             self._os = OSType.FEDORA_12
150         elif out.find("Fedora release 14") == 0:
151             self._os = OSType.FEDORA_14
152         elif out.find("Debian") == 0: 
153             self._os = OSType.DEBIAN
154         elif out.find("Ubuntu") ==0:
155             self._os = OSType.UBUNTU
156         else:
157             msg = "Unsupported OS"
158             self.error(msg, out)
159             raise RuntimeError, "%s - %s " %( msg, out )
160
161         return self._os
162
163     @property
164     def localhost(self):
165         return self.get("hostname") in ['localhost', '127.0.0.7', '::1']
166
167     def provision(self):
168         if not self.is_alive():
169             self._state = ResourceState.FAILED
170             msg = "Deploy failed. Unresponsive node %s" % self.get("hostname")
171             self.error(msg)
172             raise RuntimeError, msg
173
174         if self.get("cleanProcesses"):
175             self.clean_processes()
176
177         if self.get("cleanHome"):
178             self.clean_home()
179        
180         self.mkdir(self.node_home)
181
182         super(LinuxNode, self).provision()
183
184     def deploy(self):
185         if self.state == ResourceState.NEW:
186             try:
187                 self.discover()
188                 self.provision()
189             except:
190                 self._state = ResourceState.FAILED
191                 raise
192
193         # Node needs to wait until all associated interfaces are 
194         # ready before it can finalize deployment
195         from nepi.resources.linux.interface import LinuxInterface
196         ifaces = self.get_connected(LinuxInterface.rtype())
197         for iface in ifaces:
198             if iface.state < ResourceState.READY:
199                 self.ec.schedule(reschedule_delay, self.deploy)
200                 return 
201
202         super(LinuxNode, self).deploy()
203
204     def release(self):
205         tear_down = self.get("tearDown")
206         if tear_down:
207             self.execute(tear_down)
208
209         super(LinuxNode, self).release()
210
211     def valid_connection(self, guid):
212         # TODO: Validate!
213         return True
214
215     def clean_processes(self, killer = False):
216         self.info("Cleaning up processes")
217         
218         if killer:
219             # Hardcore kill
220             cmd = ("sudo -S killall python tcpdump || /bin/true ; " +
221                 "sudo -S killall python tcpdump || /bin/true ; " +
222                 "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ; " +
223                 "sudo -S killall -u root || /bin/true ; " +
224                 "sudo -S killall -u root || /bin/true ; ")
225         else:
226             # Be gentler...
227             cmd = ("sudo -S killall tcpdump || /bin/true ; " +
228                 "sudo -S killall tcpdump || /bin/true ; " +
229                 "sudo -S killall -u %s || /bin/true ; " % self.get("username") +
230                 "sudo -S killall -u %s || /bin/true ; " % self.get("username"))
231
232         out = err = ""
233         (out, err), proc = self.execute(cmd, retry = 1, with_lock = True) 
234             
235     def clean_home(self):
236         self.info("Cleaning up home")
237         
238         cmd = (
239             # "find . -maxdepth 1  \( -name '.cache' -o -name '.local' -o -name '.config' -o -name 'nepi-*' \)" +
240             "find . -maxdepth 1 -name 'nepi-*' " +
241             " -execdir rm -rf {} + "
242             )
243             
244         if self.home:
245             cmd = "cd %s ; " % self.home + cmd
246
247         out = err = ""
248         (out, err), proc = self.execute(cmd, with_lock = True)
249
250     def upload(self, src, dst, text = False):
251         """ Copy content to destination
252
253            src  content to copy. Can be a local file, directory or a list of files
254
255            dst  destination path on the remote host (remote is always self.host)
256
257            text src is text input, it must be stored into a temp file before uploading
258         """
259         # If source is a string input 
260         f = None
261         if text and not os.path.isfile(src):
262             # src is text input that should be uploaded as file
263             # create a temporal file with the content to upload
264             f = tempfile.NamedTemporaryFile(delete=False)
265             f.write(src)
266             f.close()
267             src = f.name
268
269         if not self.localhost:
270             # Build destination as <user>@<server>:<path>
271             dst = "%s@%s:%s" % (self.get("username"), self.get("hostname"), dst)
272
273         result = self.copy(src, dst)
274
275         # clean up temp file
276         if f:
277             os.remove(f.name)
278
279         return result
280
281     def download(self, src, dst):
282         if not self.localhost:
283             # Build destination as <user>@<server>:<path>
284             src = "%s@%s:%s" % (self.get("username"), self.get("hostname"), src)
285         return self.copy(src, dst)
286
287     def install_packages(self, packages, home):
288         command = ""
289         if self.os in [OSType.FEDORA_12, OSType.FEDORA_14, OSType.FEDORA]:
290             command = rpmfuncs.install_packages_command(self.os, packages)
291         elif self.os in [OSType.DEBIAN, OSType.UBUNTU]:
292             command = debfuncs.install_packages_command(self.os, packages)
293         else:
294             msg = "Error installing packages ( OS not known ) "
295             self.error(msg, self.os)
296             raise RuntimeError, msg
297
298         out = err = ""
299         (out, err), proc = self.run_and_wait(command, home, 
300             shfile = "instpkg.sh",
301             pidfile = "instpkg_pidfile",
302             ecodefile = "instpkg_exitcode",
303             stdout = "instpkg_stdout", 
304             stderr = "instpkg_stderr",
305             raise_on_error = True)
306
307         return (out, err), proc 
308
309     def remove_packages(self, packages, home):
310         command = ""
311         if self.os in [OSType.FEDORA_12, OSType.FEDORA_14, OSType.FEDORA]:
312             command = rpmfuncs.remove_packages_command(self.os, packages)
313         elif self.os in [OSType.DEBIAN, OSType.UBUNTU]:
314             command = debfuncs.remove_packages_command(self.os, packages)
315         else:
316             msg = "Error removing packages ( OS not known ) "
317             self.error(msg)
318             raise RuntimeError, msg
319
320         out = err = ""
321         (out, err), proc = self.run_and_wait(command, home, 
322             shfile = "rmpkg.sh",
323             pidfile = "rmpkg_pidfile",
324             ecodefile = "rmpkg_exitcode",
325             stdout = "rmpkg_stdout", 
326             stderr = "rmpkg_stderr",
327             raise_on_error = True)
328          
329         return (out, err), proc 
330
331     def mkdir(self, path, clean = False):
332         if clean:
333             self.rmdir(path)
334
335         return self.execute("mkdir -p %s" % path, with_lock = True)
336
337     def rmdir(self, path):
338         return self.execute("rm -rf %s" % path, with_lock = True)
339         
340     def run_and_wait(self, command, home, 
341             shfile = "cmd.sh",
342             env = None,
343             pidfile = "pidfile", 
344             ecodefile = "exitcode", 
345             stdin = None, 
346             stdout = "stdout", 
347             stderr = "stderr", 
348             sudo = False,
349             tty = False,
350             raise_on_error = False):
351         """ 
352         runs a command in background on the remote host, busy-waiting
353         until the command finishes execution.
354         This is more robust than doing a simple synchronized 'execute',
355         since in the remote host the command can continue to run detached
356         even if network disconnections occur
357         """
358         self.upload_command(command, home, 
359             shfile = shfile, 
360             ecodefile = ecodefile, 
361             env = env)
362
363         command = "bash ./%s" % shfile
364         # run command in background in remote host
365         (out, err), proc = self.run(command, home, 
366                 pidfile = pidfile,
367                 stdin = stdin, 
368                 stdout = stdout, 
369                 stderr = stderr, 
370                 sudo = sudo,
371                 tty = tty)
372
373         # check no errors occurred
374         if proc.poll() and err:
375             msg = " Failed to run command '%s' " % command
376             self.error(msg, out, err)
377             if raise_on_error:
378                 raise RuntimeError, msg
379
380         # Wait for pid file to be generated
381         pid, ppid = self.wait_pid(
382                 home = home, 
383                 pidfile = pidfile, 
384                 raise_on_error = raise_on_error)
385
386         # wait until command finishes to execute
387         self.wait_run(pid, ppid)
388       
389         (out, err), proc = self.check_errors(home, ecodefile, stderr)
390
391         # Out is what was written in the stderr file
392         if out or err:
393             msg = " Failed to run command '%s' " % command
394             self.error(msg, out, err)
395
396             if raise_on_error:
397                 raise RuntimeError, msg
398         
399         return (out, err), proc
400
401     def exitcode(self, home, ecodefile = "exitcode"):
402         """
403         Get the exit code of an application.
404         Returns an integer value with the exit code 
405         """
406         (out, err), proc = self.check_output(home, ecodefile)
407
408         # Succeeded to open file, return exit code in the file
409         if proc.wait() == 0:
410             try:
411                 return int(out.strip())
412             except:
413                 # Error in the content of the file!
414                 return ExitCode.CORRUPTFILE
415
416         # No such file or directory
417         if proc.returncode == 1:
418             return ExitCode.FILENOTFOUND
419         
420         # Other error from 'cat'
421         return ExitCode.ERROR
422
423     def upload_command(self, command, home, 
424             shfile = "cmd.sh",
425             ecodefile = "exitcode",
426             env = None):
427         """ Saves the command as a bash script file in the remote host, and
428         forces to save the exit code of the command execution to the ecodefile
429         """
430       
431         # The exit code of the command will be stored in ecodefile
432         command = " %(command)s ; echo $? > %(ecodefile)s ;" % {
433                 'command': command,
434                 'ecodefile': ecodefile,
435                 } 
436
437         # Export environment
438         environ = "\n".join(map(lambda e: "export %s" % e, env.split(" "))) + "\n" \
439             if env else ""
440
441         # Add environ to command
442         command = environ + command
443
444         dst = os.path.join(home, shfile)
445         return self.upload(command, dst, text = True)
446
447     def check_errors(self, home, 
448             ecodefile = "exitcode", 
449             stderr = "stderr"):
450         """
451         Checks whether errors occurred while running a command.
452         It first checks the exit code for the command, and only if the
453         exit code is an error one it returns the error output.
454         """
455         out = err = ""
456         proc = None
457
458         # get Exit code
459         ecode = self.exitcode(home, ecodefile)
460
461         if ecode in [ ExitCode.CORRUPTFILE, ExitCode.ERROR ]:
462             err = "Error retrieving exit code status from file %s/%s" % (home, ecodefile)
463         elif ecode > 0 or ecode == ExitCode.FILENOTFOUND:
464             # The process returned an error code or didn't exist. 
465             # Check standard error.
466             (out, err), proc = self.check_output(home, stderr)
467             
468             # If the stderr file was not found, assume nothing happened.
469             # We just ignore the error.
470             # (cat returns 1 for error "No such file or directory")
471             if ecode == ExitCode.FILENOTFOUND and proc.poll() == 1: 
472                 out = err = ""
473        
474         return (out, err), proc
475  
476     def wait_pid(self, home, pidfile = "pidfile", raise_on_error = False):
477         """ Waits until the pid file for the command is generated, 
478             and returns the pid and ppid of the process """
479         pid = ppid = None
480         delay = 1.0
481
482         for i in xrange(4):
483             pidtuple = self.getpid(home = home, pidfile = pidfile)
484             
485             if pidtuple:
486                 pid, ppid = pidtuple
487                 break
488             else:
489                 time.sleep(delay)
490                 delay = delay * 1.5
491         else:
492             msg = " Failed to get pid for pidfile %s/%s " % (
493                     home, pidfile )
494             self.error(msg)
495             
496             if raise_on_error:
497                 raise RuntimeError, msg
498
499         return pid, ppid
500
501     def wait_run(self, pid, ppid, trial = 0):
502         """ wait for a remote process to finish execution """
503         start_delay = 1.0
504
505         while True:
506             status = self.status(pid, ppid)
507             
508             if status is ProcStatus.FINISHED:
509                 break
510             elif status is not ProcStatus.RUNNING:
511                 delay = delay * 1.5
512                 time.sleep(delay)
513                 # If it takes more than 20 seconds to start, then
514                 # asume something went wrong
515                 if delay > 20:
516                     break
517             else:
518                 # The app is running, just wait...
519                 time.sleep(0.5)
520
521     def check_output(self, home, filename):
522         """ Retrives content of file """
523         (out, err), proc = self.execute("cat %s" % 
524             os.path.join(home, filename), retry = 1, with_lock = True)
525         return (out, err), proc
526
527     def is_alive(self):
528         if self.localhost:
529             return True
530
531         out = err = ""
532         try:
533             # TODO: FIX NOT ALIVE!!!!
534             (out, err), proc = self.execute("echo 'ALIVE' || (echo 'NOTALIVE') >&2", retry = 5, 
535                     with_lock = True)
536         except:
537             import traceback
538             trace = traceback.format_exc()
539             msg = "Unresponsive host  %s " % err
540             self.error(msg, out, trace)
541             return False
542
543         if out.strip().startswith('ALIVE'):
544             return True
545         else:
546             msg = "Unresponsive host "
547             self.error(msg, out, err)
548             return False
549
550     def copy(self, src, dst):
551         if self.localhost:
552             (out, err), proc = execfuncs.lcopy(source, dest, 
553                     recursive = True,
554                     strict_host_checking = False)
555         else:
556             with self._lock:
557                 (out, err), proc = sshfuncs.rcopy(
558                     src, dst, 
559                     port = self.get("port"),
560                     identity = self.get("identity"),
561                     server_key = self.get("serverKey"),
562                     recursive = True,
563                     strict_host_checking = False)
564
565         return (out, err), proc
566
567     def execute(self, command,
568             sudo = False,
569             stdin = None, 
570             env = None,
571             tty = False,
572             forward_x11 = False,
573             timeout = None,
574             retry = 3,
575             err_on_timeout = True,
576             connect_timeout = 30,
577             strict_host_checking = False,
578             persistent = True,
579             blocking = True,
580             with_lock = False
581             ):
582         """ Notice that this invocation will block until the
583         execution finishes. If this is not the desired behavior,
584         use 'run' instead."""
585
586         if self.localhost:
587             (out, err), proc = execfuncs.lexec(command, 
588                     user = user,
589                     sudo = sudo,
590                     stdin = stdin,
591                     env = env)
592         else:
593             if with_lock:
594                 with self._lock:
595                     (out, err), proc = sshfuncs.rexec(
596                         command, 
597                         host = self.get("hostname"),
598                         user = self.get("username"),
599                         port = self.get("port"),
600                         agent = True,
601                         sudo = sudo,
602                         stdin = stdin,
603                         identity = self.get("identity"),
604                         server_key = self.get("serverKey"),
605                         env = env,
606                         tty = tty,
607                         forward_x11 = forward_x11,
608                         timeout = timeout,
609                         retry = retry,
610                         err_on_timeout = err_on_timeout,
611                         connect_timeout = connect_timeout,
612                         persistent = persistent,
613                         blocking = blocking, 
614                         strict_host_checking = strict_host_checking
615                         )
616             else:
617                 (out, err), proc = sshfuncs.rexec(
618                     command, 
619                     host = self.get("hostname"),
620                     user = self.get("username"),
621                     port = self.get("port"),
622                     agent = True,
623                     sudo = sudo,
624                     stdin = stdin,
625                     identity = self.get("identity"),
626                     server_key = self.get("serverKey"),
627                     env = env,
628                     tty = tty,
629                     forward_x11 = forward_x11,
630                     timeout = timeout,
631                     retry = retry,
632                     err_on_timeout = err_on_timeout,
633                     connect_timeout = connect_timeout,
634                     persistent = persistent,
635                     blocking = blocking, 
636                     strict_host_checking = strict_host_checking
637                     )
638
639         return (out, err), proc
640
641     def run(self, command, home,
642             create_home = False,
643             pidfile = 'pidfile',
644             stdin = None, 
645             stdout = 'stdout', 
646             stderr = 'stderr', 
647             sudo = False,
648             tty = False):
649         
650         self.debug("Running command '%s'" % command)
651         
652         if self.localhost:
653             (out, err), proc = execfuncs.lspawn(command, pidfile, 
654                     stdout = stdout, 
655                     stderr = stderr, 
656                     stdin = stdin, 
657                     home = home, 
658                     create_home = create_home, 
659                     sudo = sudo,
660                     user = user) 
661         else:
662             with self._lock:
663                 (out, err), proc = sshfuncs.rspawn(
664                     command,
665                     pidfile = pidfile,
666                     home = home,
667                     create_home = create_home,
668                     stdin = stdin if stdin is not None else '/dev/null',
669                     stdout = stdout if stdout else '/dev/null',
670                     stderr = stderr if stderr else '/dev/null',
671                     sudo = sudo,
672                     host = self.get("hostname"),
673                     user = self.get("username"),
674                     port = self.get("port"),
675                     agent = True,
676                     identity = self.get("identity"),
677                     server_key = self.get("serverKey"),
678                     tty = tty
679                     )
680
681         return (out, err), proc
682
683     def getpid(self, home, pidfile = "pidfile"):
684         if self.localhost:
685             pidtuple =  execfuncs.lgetpid(os.path.join(home, pidfile))
686         else:
687             with self._lock:
688                 pidtuple = sshfuncs.rgetpid(
689                     os.path.join(home, pidfile),
690                     host = self.get("hostname"),
691                     user = self.get("username"),
692                     port = self.get("port"),
693                     agent = True,
694                     identity = self.get("identity"),
695                     server_key = self.get("serverKey")
696                     )
697         
698         return pidtuple
699
700     def status(self, pid, ppid):
701         if self.localhost:
702             status = execfuncs.lstatus(pid, ppid)
703         else:
704             with self._lock:
705                 status = sshfuncs.rstatus(
706                         pid, ppid,
707                         host = self.get("hostname"),
708                         user = self.get("username"),
709                         port = self.get("port"),
710                         agent = True,
711                         identity = self.get("identity"),
712                         server_key = self.get("serverKey")
713                         )
714            
715         return status
716     
717     def kill(self, pid, ppid, sudo = False):
718         out = err = ""
719         proc = None
720         status = self.status(pid, ppid)
721
722         if status == sshfuncs.ProcStatus.RUNNING:
723             if self.localhost:
724                 (out, err), proc = execfuncs.lkill(pid, ppid, sudo)
725             else:
726                 with self._lock:
727                     (out, err), proc = sshfuncs.rkill(
728                         pid, ppid,
729                         host = self.get("hostname"),
730                         user = self.get("username"),
731                         port = self.get("port"),
732                         agent = True,
733                         sudo = sudo,
734                         identity = self.get("identity"),
735                         server_key = self.get("serverKey")
736                         )
737
738         return (out, err), proc
739