56eb6e2ad7e8548fd4fe2e6fcea60c6c3d147ef4
[nepi.git] / src / neco / resources / linux / node.py
1 from neco.execution.attribute import Attribute, Flags
2 from neco.execution.resource import ResourceManager, clsinit, ResourceState
3 from neco.resources.linux import rpmfuncs, debfuncs 
4 from neco.util import sshfuncs, execfuncs 
5
6 import collections
7 import logging
8 import os
9 import random
10 import re
11 import tempfile
12 import time
13 import threading
14
15 # TODO: Verify files and dirs exists already
16
17 @clsinit
18 class LinuxNode(ResourceManager):
19     _rtype = "LinuxNode"
20
21     @classmethod
22     def _register_attributes(cls):
23         hostname = Attribute("hostname", "Hostname of the machine")
24
25         username = Attribute("username", "Local account username", 
26                 flags = Flags.Credential)
27
28         port = Attribute("port", "SSH port", flags = Flags.Credential)
29         
30         home = Attribute("home", 
31                 "Experiment home directory to store all experiment related files")
32         
33         identity = Attribute("identity", "SSH identity file",
34                 flags = Flags.Credential)
35         
36         server_key = Attribute("serverKey", "Server public key", 
37                 flags = Flags.Credential)
38         
39         clean_home = Attribute("cleanHome", "Remove all files and directories " + \
40                 " from home folder before starting experiment", 
41                 flags = Flags.ReadOnly)
42         
43         clean_processes = Attribute("cleanProcesses", 
44                 "Kill all running processes before starting experiment", 
45                 flags = Flags.ReadOnly)
46         
47         tear_down = Attribute("tearDown", "Bash script to be executed before " + \
48                 "releasing the resource", flags = Flags.ReadOnly)
49
50         cls._register_attribute(hostname)
51         cls._register_attribute(username)
52         cls._register_attribute(port)
53         cls._register_attribute(home)
54         cls._register_attribute(identity)
55         cls._register_attribute(server_key)
56         cls._register_attribute(clean_home)
57         cls._register_attribute(clean_processes)
58         cls._register_attribute(tear_down)
59
60     def __init__(self, ec, guid):
61         super(LinuxNode, self).__init__(ec, guid)
62         self._os = None
63         self._home = "nepi-exp-%s" % os.urandom(8).encode('hex')
64         
65         # lock to avoid concurrency issues on methods used by applications 
66         self._lock = threading.Lock()
67
68         self._logger = logging.getLogger("neco.linux.Node.%d " % self.guid)
69
70     @property
71     def home(self):
72         home = self.get("home")
73         if home and not home.startswith("nepi-"):
74             home = "nepi-" + home
75         return home or self._home
76
77     @property
78     def os(self):
79         if self._os:
80             return self._os
81
82         if (not self.get("hostname") or not self.get("username")):
83             msg = "Can't resolve OS for guid %d. Insufficient data." % self.guid
84             self.logger.error(msg)
85             raise RuntimeError, msg
86
87         (out, err), proc = self.execute("cat /etc/issue")
88
89         if err and proc.poll():
90             msg = "Error detecting OS for host %s. err: %s " % (self.get("hostname"), err)
91             self.logger.error(msg)
92             raise RuntimeError, msg
93
94         if out.find("Fedora release 12") == 0:
95             self._os = "f12"
96         elif out.find("Fedora release 14") == 0:
97             self._os = "f14"
98         elif out.find("Debian") == 0: 
99             self._os = "debian"
100         elif out.find("Ubuntu") ==0:
101             self._os = "ubuntu"
102         else:
103             msg = "Unsupported OS %s for host %s" % (out, self.get("hostname"))
104             self.logger.error(msg)
105             raise RuntimeError, msg
106
107         return self._os
108
109     @property
110     def localhost(self):
111         return self.get("hostname") in ['localhost', '127.0.0.7', '::1']
112
113     def provision(self, filters = None):
114         if not self.is_alive():
115             self._state = ResourceState.FAILED
116             self.logger.error("Deploy failed. Unresponsive node")
117             return
118
119     def deploy(self):
120         self.provision()
121
122         if self.get("cleanProcesses"):
123             self.clean_processes()
124
125         if self.get("cleanHome"):
126             # self.clean_home() -> this is dangerous
127             pass
128
129         self.mkdir(self.home)
130
131         super(LinuxNode, self).deploy()
132
133     def release(self):
134         tear_down = self.get("tearDown")
135         if tear_down:
136             self.execute(tear_down)
137
138         super(LinuxNode, self).release()
139
140     def validate_connection(self, guid):
141         # TODO: Validate!
142         return True
143
144     def clean_processes(self):
145         self.logger.info("Cleaning up processes")
146         
147         cmd = ("sudo -S killall python tcpdump || /bin/true ; " +
148             "sudo -S killall python tcpdump || /bin/true ; " +
149             "sudo -S kill $(ps -N -T -o pid --no-heading | grep -v $PPID | sort) || /bin/true ; " +
150             "sudo -S killall -u root || /bin/true ; " +
151             "sudo -S killall -u root || /bin/true ; ")
152
153         out = err = ""
154         with self._lock:
155            (out, err), proc = self.run_and_wait(cmd, self.home, 
156                 pidfile = "cppid",
157                 stdout = "cplog", 
158                 stderr = "cperr", 
159                 raise_on_error = True)
160
161         return (out, err)   
162             
163     def clean_home(self):
164         self.logger.info("Cleaning up home")
165
166         cmd = "find . -maxdepth 1  \( -name '.cache' -o -name '.local' -o -name '.config' -o -name 'nepi-*' \) -execdir rm -rf {} + "
167
168         out = err = ""
169         with self._lock:
170             (out, err), proc = self.run_and_wait(cmd, self.home,
171                 pidfile = "chpid",
172                 stdout = "chlog", 
173                 stderr = "cherr", 
174                 raise_on_error = True)
175         
176         return (out, err)   
177
178     def upload(self, src, dst):
179         """ Copy content to destination
180
181            src  content to copy. Can be a local file, directory or text input
182
183            dst  destination path on the remote host (remote is always self.host)
184         """
185         # If source is a string input 
186         if not os.path.isfile(src):
187             # src is text input that should be uploaded as file
188             # create a temporal file with the content to upload
189             f = tempfile.NamedTemporaryFile(delete=False)
190             f.write(src)
191             f.close()
192             src = f.name
193
194         if not self.localhost:
195             # Build destination as <user>@<server>:<path>
196             dst = "%s@%s:%s" % (self.get("username"), self.get("hostname"), dst)
197
198         return self.copy(src, dst)
199
200     def download(self, src, dst):
201         if not self.localhost:
202             # Build destination as <user>@<server>:<path>
203             src = "%s@%s:%s" % (self.get("username"), self.get("hostname"), src)
204         return self.copy(src, dst)
205
206     def install_packages(self, packages):
207         cmd = ""
208         if self.os in ["f12", "f14"]:
209             cmd = rpmfuncs.install_packages_command(self.os, packages)
210         elif self.os in ["debian", "ubuntu"]:
211             cmd = debfuncs.install_packages_command(self.os, packages)
212         else:
213             msg = "Error installing packages. OS not known for host %s " % (
214                     self.get("hostname"))
215             self.logger.error(msg)
216             raise RuntimeError, msg
217
218         out = err = ""
219         with self._lock:
220             (out, err), proc = self.run_and_wait(cmd, self.home, 
221                 pidfile = "instpkgpid",
222                 stdout = "instpkglog", 
223                 stderr = "instpkgerr", 
224                 raise_on_error = True)
225
226         return (out, err), proc 
227
228     def remove_packages(self, packages):
229         cmd = ""
230         if self.os in ["f12", "f14"]:
231             cmd = rpmfuncs.remove_packages_command(self.os, packages)
232         elif self.os in ["debian", "ubuntu"]:
233             cmd = debfuncs.remove_packages_command(self.os, packages)
234         else:
235             msg = "Error removing packages. OS not known for host %s " % (
236                     self.get("hostname"))
237             self.logger.error(msg)
238             raise RuntimeError, msg
239
240         out = err = ""
241         with self._lock:
242             (out, err), proc = self.run_and_wait(cmd, self.home, 
243                 pidfile = "rmpkgpid",
244                 stdout = "rmpkglog", 
245                 stderr = "rmpkgerr", 
246                 raise_on_error = True)
247          
248         return (out, err), proc 
249
250     def mkdir(self, path, clean = False):
251         if clean:
252             self.rmdir(path)
253
254         return self.execute("mkdir -p %s" % path)
255
256     def rmdir(self, path):
257         return self.execute("rm -rf %s" % path)
258
259     def run_and_wait(self, command, 
260             home = ".", 
261             pidfile = "pid", 
262             stdin = None, 
263             stdout = 'stdout', 
264             stderr = 'stderr', 
265             sudo = False,
266             raise_on_error = False):
267
268         (out, err), proc = self.run(command, home, 
269                 pidfile = pidfile,
270                 stdin = stdin, 
271                 stdout = stdout, 
272                 stderr = stderr, 
273                 sudo = sudo)
274
275         if proc.poll() and err:
276             msg = " Failed to run command %s on host %s" % (
277                     command, self.get("hostname"))
278             self.logger.error(msg)
279             if raise_on_error:
280                 raise RuntimeError, msg
281         
282         pid, ppid = self.wait_pid(
283                 home = home, 
284                 pidfile = pidfile, 
285                 raise_on_error = raise_on_error)
286
287         self.wait_run(pid, ppid)
288         
289         (out, err), proc = self.check_run_error(home, stderr)
290
291         if err or out:
292             msg = "Error while running command %s on host %s. error output: %s" % (
293                     command, self.get("hostname"), out)
294             if err:
295                 msg += " . err: %s" % err
296
297             self.logger.error(msg)
298             if raise_on_error:
299                 raise RuntimeError, msg
300         
301         return (out, err), proc
302  
303     def wait_pid(self, home = ".", pidfile = "pid", raise_on_error = False):
304         pid = ppid = None
305         delay = 1.0
306         for i in xrange(5):
307             pidtuple = self.checkpid(home = home, pidfile = pidfile)
308             
309             if pidtuple:
310                 pid, ppid = pidtuple
311                 break
312             else:
313                 time.sleep(delay)
314                 delay = min(30,delay*1.2)
315         else:
316             msg = " Failed to get pid for pidfile %s/%s on host %s" % (
317                     home, pidfile, self.get("hostname"))
318             self.logger.error(msg)
319             if raise_on_error:
320                 raise RuntimeError, msg
321
322         return pid, ppid
323
324     def wait_run(self, pid, ppid, trial = 0):
325         delay = 1.0
326         first = True
327         bustspin = 0
328
329         while True:
330             status = self.status(pid, ppid)
331             
332             if status is sshfuncs.FINISHED:
333                 break
334             elif status is not sshfuncs.RUNNING:
335                 bustspin += 1
336                 time.sleep(delay*(5.5+random.random()))
337                 if bustspin > 12:
338                     break
339             else:
340                 if first:
341                     first = False
342
343                 time.sleep(delay*(0.5+random.random()))
344                 delay = min(30,delay*1.2)
345                 bustspin = 0
346
347     def check_run_error(self, home, stderr = 'stderr'):
348         (out, err), proc = self.execute("cat %s" % 
349                 os.path.join(home, stderr))
350         return (out, err), proc
351
352     def check_run_output(self, home, stdout = 'stdout'):
353         (out, err), proc = self.execute("cat %s" % 
354                 os.path.join(home, stdout))
355         return (out, err), proc
356
357     def is_alive(self):
358         if self.localhost:
359             return True
360
361         out = err = ""
362         try:
363             (out, err), proc = self.execute("echo 'ALIVE'")
364         except:
365             import traceback
366             trace = traceback.format_exc()
367             self.logger.warn("Unresponsive host %s. got:\n out: %s err: %s\n traceback: %s", 
368                     self.get("hostname"), out, err, trace)
369             return False
370
371         if out.strip().startswith('ALIVE'):
372             return True
373         else:
374             self.logger.warn("Unresponsive host %s. got:\n%s%s", 
375                     self.get("hostname"), out, err)
376             return False
377
378             # TODO!
379             #if self.check_bad_host(out,err):
380             #    self.blacklist()
381
382     def copy(self, src, dst):
383         if self.localhost:
384             (out, err), proc =  execfuncs.lcopy(source, dest, 
385                     recursive = True)
386         else:
387             (out, err), proc = self.safe_retry(sshfuncs.rcopy)(
388                 src, dst, 
389                 port = self.get("port"),
390                 identity = self.get("identity"),
391                 server_key = self.get("serverKey"),
392                 recursive = True)
393
394         return (out, err), proc
395
396     def execute(self, command,
397             sudo = False,
398             stdin = None, 
399             env = None,
400             tty = False,
401             forward_x11 = False,
402             timeout = None,
403             retry = 0,
404             err_on_timeout = True,
405             connect_timeout = 30,
406             persistent = True
407             ):
408         """ Notice that this invocation will block until the
409         execution finishes. If this is not the desired behavior,
410         use 'run' instead."""
411
412         if self.localhost:
413             (out, err), proc = execfuncs.lexec(command, 
414                     user = user,
415                     sudo = sudo,
416                     stdin = stdin,
417                     env = env)
418         else:
419             (out, err), proc = self.safe_retry(sshfuncs.rexec)(
420                     command, 
421                     host = self.get("hostname"),
422                     user = self.get("username"),
423                     port = self.get("port"),
424                     agent = True,
425                     sudo = sudo,
426                     stdin = stdin,
427                     identity = self.get("identity"),
428                     server_key = self.get("serverKey"),
429                     env = env,
430                     tty = tty,
431                     forward_x11 = forward_x11,
432                     timeout = timeout,
433                     retry = retry,
434                     err_on_timeout = err_on_timeout,
435                     connect_timeout = connect_timeout,
436                     persistent = persistent
437                     )
438
439         return (out, err), proc
440
441     def run(self, command, 
442             home = None,
443             create_home = True,
444             pidfile = "pid",
445             stdin = None, 
446             stdout = 'stdout', 
447             stderr = 'stderr', 
448             sudo = False):
449
450         self.logger.info("Running %s", command)
451         
452         if self.localhost:
453             (out, err), proc = execfuncs.lspawn(command, pidfile, 
454                     stdout = stdout, 
455                     stderr = stderr, 
456                     stdin = stdin, 
457                     home = home, 
458                     create_home = create_home, 
459                     sudo = sudo,
460                     user = user) 
461         else:
462             # Start process in a "daemonized" way, using nohup and heavy
463             # stdin/out redirection to avoid connection issues
464             (out,err), proc = self.safe_retry(sshfuncs.rspawn)(
465                 command,
466                 pidfile = pidfile,
467                 home = home,
468                 create_home = create_home,
469                 stdin = stdin if stdin is not None else '/dev/null',
470                 stdout = stdout if stdout else '/dev/null',
471                 stderr = stderr if stderr else '/dev/null',
472                 sudo = sudo,
473                 host = self.get("hostname"),
474                 user = self.get("username"),
475                 port = self.get("port"),
476                 agent = True,
477                 identity = self.get("identity"),
478                 server_key = self.get("serverKey")
479                 )
480
481         return (out, err), proc
482
483     def checkpid(self, home = ".", pidfile = "pid"):
484         if self.localhost:
485             pidtuple =  execfuncs.lcheckpid(os.path.join(home, pidfile))
486         else:
487             pidtuple = sshfuncs.rcheckpid(
488                 os.path.join(home, pidfile),
489                 host = self.get("hostname"),
490                 user = self.get("username"),
491                 port = self.get("port"),
492                 agent = True,
493                 identity = self.get("identity"),
494                 server_key = self.get("serverKey")
495                 )
496         
497         return pidtuple
498     
499     def status(self, pid, ppid):
500         if self.localhost:
501             status = execfuncs.lstatus(pid, ppid)
502         else:
503             status = sshfuncs.rstatus(
504                     pid, ppid,
505                     host = self.get("hostname"),
506                     user = self.get("username"),
507                     port = self.get("port"),
508                     agent = True,
509                     identity = self.get("identity"),
510                     server_key = self.get("serverKey")
511                     )
512            
513         return status
514     
515     def kill(self, pid, ppid, sudo = False):
516         out = err = ""
517         proc = None
518         status = self.status(pid, ppid)
519
520         if status == sshfuncs.RUNNING:
521             if self.localhost:
522                 (out, err), proc = execfuncs.lkill(pid, ppid, sudo)
523             else:
524                 (out, err), proc = self.safe_retry(sshfuncs.rkill)(
525                     pid, ppid,
526                     host = self.get("hostname"),
527                     user = self.get("username"),
528                     port = self.get("port"),
529                     agent = True,
530                     sudo = sudo,
531                     identity = self.get("identity"),
532                     server_key = self.get("serverKey")
533                     )
534         return (out, err), proc
535
536     def check_bad_host(self, out, err):
537         badre = re.compile(r'(?:'
538                            r'|Error: disk I/O error'
539                            r')', 
540                            re.I)
541         return badre.search(out) or badre.search(err)
542
543     def blacklist(self):
544         # TODO!!!!
545         self.logger.warn("Blacklisting malfunctioning node %s", self.hostname)
546         #import util
547         #util.appendBlacklist(self.hostname)
548
549     def safe_retry(self, func):
550         """Retries a function invocation using a lock"""
551         import functools
552         @functools.wraps(func)
553         def rv(*p, **kw):
554             fail_msg = " Failed to execute function %s(%s, %s) at host %s" % (
555                 func.__name__, p, kw, self.get("hostname"))
556             retry = kw.pop("_retry", False)
557             wlock = kw.pop("_with_lock", False)
558
559             out = err = ""
560             proc = None
561             for i in xrange(0 if retry else 4):
562                 try:
563                     if wlock:
564                         with self._lock:
565                             (out, err), proc = func(*p, **kw)
566                     else:
567                         (out, err), proc = func(*p, **kw)
568                         
569                     if proc.poll():
570                         if retry:
571                             time.sleep(i*15)
572                             continue
573                         else:
574                             self.logger.error("%s. out: %s error: %s", fail_msg, out, err)
575                     break
576                 except RuntimeError, e:
577                     if x >= 3:
578                         self.logger.error("%s. error: %s", fail_msg, e.args)
579             return (out, err), proc
580
581         return rv
582