oops, too much messing about
[nodemanager.git] / tools.py
1 # -*- python-indent: 4 -*-
2
3 """A few things that didn't seem to fit anywhere else."""
4
5 import os, os.path
6 import pwd
7 import tempfile
8 import fcntl
9 import errno
10 import threading
11 import subprocess
12 import shutil
13 import sys
14 import signal
15
16 import logger
17
18 PID_FILE = '/var/run/nodemanager.pid'
19
20 ####################
21 def get_default_if():
22     interface = get_if_from_hwaddr(get_hwaddr_from_plnode())
23     if not interface: interface = "eth0"
24     return interface
25
26 def get_hwaddr_from_plnode():
27     try:
28         for line in open("/usr/boot/plnode.txt", 'r').readlines():
29             if line.startswith("NET_DEVICE"):
30                 return line.split("=")[1].strip().strip('"')
31     except:
32         pass
33     return None
34
35 def get_if_from_hwaddr(hwaddr):
36     import sioc
37     devs = sioc.gifconf()
38     for dev in devs:
39         dev_hwaddr = sioc.gifhwaddr(dev)
40         if dev_hwaddr == hwaddr: return dev
41     return None
42
43 ####################
44 # daemonizing
45 def as_daemon_thread(run):
46     """Call function <run> with no arguments in its own thread."""
47     thr = threading.Thread(target=run)
48     thr.setDaemon(True)
49     thr.start()
50
51 def close_nonstandard_fds():
52     """Close all open file descriptors other than 0, 1, and 2."""
53     _SC_OPEN_MAX = 4
54     for fd in range(3, os.sysconf(_SC_OPEN_MAX)):
55         try: os.close(fd)
56         except OSError: pass  # most likely an fd that isn't open
57
58 # after http://www.erlenstar.demon.co.uk/unix/faq_2.html
59 def daemon():
60     """Daemonize the current process."""
61     if os.fork() != 0: os._exit(0)
62     os.setsid()
63     if os.fork() != 0: os._exit(0)
64     os.chdir('/')
65     os.umask(0022)
66     devnull = os.open(os.devnull, os.O_RDWR)
67     os.dup2(devnull, 0)
68     # xxx fixme - this is just to make sure that nothing gets stupidly lost - should use devnull
69     crashlog = os.open('/var/log/nodemanager.daemon', os.O_RDWR | os.O_APPEND | os.O_CREAT, 0644)
70     os.dup2(crashlog, 1)
71     os.dup2(crashlog, 2)
72
73 def fork_as(su, function, *args):
74     """
75 fork(), cd / to avoid keeping unused directories open,
76 close all nonstandard file descriptors (to avoid capturing open sockets),
77 fork() again (to avoid zombies) and call <function>
78 with arguments <args> in the grandchild process.
79 If <su> is not None, set our group and user ids
80  appropriately in the child process.
81     """
82     child_pid = os.fork()
83     if child_pid == 0:
84         try:
85             os.chdir('/')
86             close_nonstandard_fds()
87             if su:
88                 pw_ent = pwd.getpwnam(su)
89                 os.setegid(pw_ent[3])
90                 os.seteuid(pw_ent[2])
91             child_pid = os.fork()
92             if child_pid == 0: function(*args)
93         except:
94             os.seteuid(os.getuid())  # undo su so we can write the log file
95             os.setegid(os.getgid())
96             logger.log_exc("tools: fork_as")
97         os._exit(0)
98     else: os.waitpid(child_pid, 0)
99
100 ####################
101 # manage files
102 def pid_file():
103     """
104 We use a pid file to ensure that only one copy of NM is running at a given time.
105 If successful, this function will write a pid file containing the pid of the current process.
106 The return value is the pid of the other running process, or None otherwise.
107     """
108     other_pid = None
109     if os.access(PID_FILE, os.F_OK):  # check for a pid file
110         handle = open(PID_FILE)  # pid file exists, read it
111         other_pid = int(handle.read())
112         handle.close()
113         # check for a process with that pid by sending signal 0
114         try: os.kill(other_pid, 0)
115         except OSError, e:
116             if e.errno == errno.ESRCH: other_pid = None  # doesn't exist
117             else: raise  # who knows
118     if other_pid == None:
119         # write a new pid file
120         write_file(PID_FILE, lambda f: f.write(str(os.getpid())))
121     return other_pid
122
123 def write_file(filename, do_write, **kw_args):
124     """
125 Write file <filename> atomically by opening a temporary file,
126 using <do_write> to write that file, and then renaming the temporary file.
127     """
128     shutil.move(write_temp_file(do_write, **kw_args), filename)
129
130 def write_temp_file(do_write, mode=None, uidgid=None):
131     fd, temporary_filename = tempfile.mkstemp()
132     if mode: os.chmod(temporary_filename, mode)
133     if uidgid: os.chown(temporary_filename, *uidgid)
134     f = os.fdopen(fd, 'w')
135     try: do_write(f)
136     finally: f.close()
137     return temporary_filename
138
139 def replace_file_with_string (target, new_contents, chmod=None, remove_if_empty=False):
140     """
141 Replace a target file with a new contents
142 checks for changes: does not do anything if previous state was already right
143 can handle chmod if requested
144 can also remove resulting file if contents are void, if requested
145 performs atomically:
146 writes in a tmp file, which is then renamed (from sliverauth originally)
147 returns True if a change occurred, or the file is deleted
148     """
149     try:
150         current=file(target).read()
151     except:
152         current=""
153     if current==new_contents:
154         # if turns out to be an empty string, and remove_if_empty is set,
155         # then make sure to trash the file if it exists
156         if remove_if_empty and not new_contents and os.path.isfile(target):
157             logger.verbose("tools.replace_file_with_string: removing file %s"%target)
158             try: os.unlink(target)
159             finally: return True
160         return False
161     # overwrite target file: create a temp in the same directory
162     path=os.path.dirname(target) or '.'
163     fd, name = tempfile.mkstemp('','repl',path)
164     os.write(fd,new_contents)
165     os.close(fd)
166     if os.path.exists(target):
167         os.unlink(target)
168     shutil.move(name,target)
169     if chmod: os.chmod(target,chmod)
170     return True
171
172 ####################
173 # utilities functions to get (cached) information from the node
174
175 # get node_id from /etc/planetlab/node_id and cache it
176 _node_id=None
177 def node_id():
178     global _node_id
179     if _node_id is None:
180         try:
181             _node_id=int(file("/etc/planetlab/node_id").read())
182         except:
183             _node_id=""
184     return _node_id
185
186 _root_context_arch=None
187 def root_context_arch():
188     global _root_context_arch
189     if not _root_context_arch:
190         sp=subprocess.Popen(["uname","-i"],stdout=subprocess.PIPE)
191         (_root_context_arch,_)=sp.communicate()
192         _root_context_arch=_root_context_arch.strip()
193     return _root_context_arch
194
195
196 ####################
197 class NMLock:
198     def __init__(self, file):
199         logger.log("tools: Lock %s initialized." % file, 2)
200         self.fd = os.open(file, os.O_RDWR|os.O_CREAT, 0600)
201         flags = fcntl.fcntl(self.fd, fcntl.F_GETFD)
202         flags |= fcntl.FD_CLOEXEC
203         fcntl.fcntl(self.fd, fcntl.F_SETFD, flags)
204     def __del__(self):
205         os.close(self.fd)
206     def acquire(self):
207         logger.log("tools: Lock acquired.", 2)
208         fcntl.lockf(self.fd, fcntl.LOCK_SH)
209     def release(self):
210         logger.log("tools: Lock released.", 2)
211         fcntl.lockf(self.fd, fcntl.LOCK_UN)
212
213 ####################
214 # Utilities for getting the IP address of a LXC/Openvswitch slice. Do this by
215 # running ifconfig inside of the slice's context.
216
217 def get_sliver_process(slice_name, process_cmdline):
218     """ 
219     Utility function to find a process inside of an LXC sliver. Returns
220     (cgroup_fn, pid). cgroup_fn is the filename of the cgroup file for
221     the process, for example /proc/2592/cgroup. Pid is the process id of
222     the process. If the process is not found then (None, None) is returned.
223     """
224     try:
225         cmd = 'grep %s /proc/*/cgroup | grep freezer'%slice_name
226         output = os.popen(cmd).readlines()
227     except:
228         # the slice couldn't be found
229         logger.log("get_sliver_process: couldn't find slice %s" % slice_name)
230         return (None, None)
231
232     cgroup_fn = None
233     pid = None
234     for e in output:
235         try:
236             l = e.rstrip()
237             path = l.split(':')[0]
238             comp = l.rsplit(':')[-1]
239             slice_name_check = comp.rsplit('/')[-1]
240             # the lines below were added by Guilherme <gsm@machados.org>
241             # due to the ipv6 plugin requirements (LXC)
242             virt=get_node_virt()
243             if virt=='lxc':
244                 slice_name_check = slice_name_check.rsplit('.')[0]
245
246             if (slice_name_check == slice_name):
247                 slice_path = path
248                 pid = slice_path.split('/')[2]
249                 cmdline = open('/proc/%s/cmdline'%pid).read().rstrip('\n\x00')
250                 if (cmdline == process_cmdline):
251                     cgroup_fn = slice_path
252                     break
253         except:
254             break
255
256     if (not cgroup_fn) or (not pid):
257         logger.log("get_sliver_process: process %s not running in slice %s" % (process_cmdline, slice_name))
258         return (None, None)
259
260     return (cgroup_fn, pid)
261
262 ###################################################
263 # Added by Guilherme Sperb Machado <gsm@machados.org>
264 ###################################################
265
266 try:
267     import re
268     import socket
269     import fileinput
270 except:
271     logger.log("Could not import 're', 'socket', or 'fileinput' python packages.")
272
273 # TODO: is there anything better to do if the "libvirt", "sliver_libvirt",
274 # and "sliver_lxc" are not in place?
275 try:
276     import libvirt
277     from sliver_libvirt import Sliver_Libvirt
278     import sliver_lxc
279 except:
280     logger.log("Could not import 'sliver_lxc' or 'libvirt' or 'sliver_libvirt'.")
281 ###################################################
282
283 def get_sliver_ifconfig(slice_name, device="eth0"):
284     """ 
285     return the output of "ifconfig" run from inside the sliver.
286
287     side effects: adds "/usr/sbin" to sys.path
288     """
289
290     # See if setns is installed. If it's not then we're probably not running
291     # LXC.
292     if not os.path.exists("/usr/sbin/setns.so"):
293         return None
294
295     # setns is part of lxcsu and is installed to /usr/sbin
296     if not "/usr/sbin" in sys.path:
297         sys.path.append("/usr/sbin")
298     import setns
299
300     (cgroup_fn, pid) = get_sliver_process(slice_name, "/sbin/init")
301     if (not cgroup_fn) or (not pid):
302         return None
303
304     path = '/proc/%s/ns/net'%pid
305
306     result = None
307     try:
308         setns.chcontext(path)
309
310         args = ["/sbin/ifconfig", device]
311         sub = subprocess.Popen(args, stderr = subprocess.PIPE, stdout = subprocess.PIPE)
312         sub.wait()
313
314         if (sub.returncode != 0):
315             logger.log("get_slice_ifconfig: error in ifconfig: %s" % sub.stderr.read())
316
317         result = sub.stdout.read()
318     finally:
319         setns.chcontext("/proc/1/ns/net")
320
321     return result
322
323 def get_sliver_ip(slice_name):
324     ifconfig = get_sliver_ifconfig(slice_name)
325     if not ifconfig:
326         return None
327
328     for line in ifconfig.split("\n"):
329         if "inet addr:" in line:
330             # example: '          inet addr:192.168.122.189  Bcast:192.168.122.255  Mask:255.255.255.0'
331             parts = line.strip().split()
332             if len(parts)>=2 and parts[1].startswith("addr:"):
333                 return parts[1].split(":")[1]
334
335     return None
336
337 ###################################################
338 # Author: Guilherme Sperb Machado <gsm@machados.org>
339 ###################################################
340 # Get the slice ipv6 address
341 # Only for LXC!
342 ###################################################
343 def get_sliver_ipv6(slice_name):
344     ifconfig = get_sliver_ifconfig(slice_name)
345     if not ifconfig:
346         return None,None
347
348     # example: 'inet6 2001:67c:16dc:1302:5054:ff:fea7:7882  prefixlen 64  scopeid 0x0<global>'
349     prog = re.compile(r'inet6\s+(.*)\s+prefixlen\s+(\d+)\s+scopeid\s+(.+)<global>')
350     for line in ifconfig.split("\n"):
351         search = prog.search(line)
352         if search:
353             ipv6addr = search.group(1)
354             prefixlen = search.group(2)
355             return (ipv6addr,prefixlen)
356     return None,None
357
358 ###################################################
359 # Author: Guilherme Sperb Machado <gsm@machados.org>
360 ###################################################
361 # Check if the address is a AF_INET6 family address
362 ###################################################
363 def is_valid_ipv6(ipv6addr):
364     try:
365         socket.inet_pton(socket.AF_INET6, ipv6addr)
366     except socket.error:
367         return False
368     return True
369
370 ### this returns the kind of virtualization on the node
371 # either 'vs' or 'lxc'
372 # also caches it in /etc/planetlab/virt for next calls
373 # could be promoted to core nm if need be
374 virt_stamp="/etc/planetlab/virt"
375 def get_node_virt ():
376     try:
377         return file(virt_stamp).read().strip()
378     except:
379         pass
380     logger.log("Computing virt..")
381     try:
382         if subprocess.call ([ 'vserver', '--help' ]) ==0: virt='vs'
383         else:                                             virt='lxc'
384     except:
385         virt='lxc'
386     with file(virt_stamp,"w") as f:
387         f.write(virt)
388     return virt
389
390 ### this return True or False to indicate that systemctl is present on that box
391 # cache result in memory as _has_systemctl
392 _has_systemctl=None
393 def has_systemctl ():
394     global _has_systemctl
395     if _has_systemctl is None:
396         _has_systemctl = (subprocess.call([ 'systemctl', '--help' ]) == 0)
397     return _has_systemctl
398
399 ###################################################
400 # Author: Guilherme Sperb Machado <gsm@machados.org>
401 ###################################################
402 # This method was developed to support the ipv6 plugin
403 # Only for LXC!
404 ###################################################
405 def reboot_slivers():
406     type = 'sliver.LXC'
407     # connecting to the libvirtd
408     connLibvirt = Sliver_Libvirt.getConnection(type)
409     domains = connLibvirt.listAllDomains()
410     for domain in domains:
411         try:
412             domain.destroy()
413             logger.log("tools: DESTROYED %s" % (domain.name()) )
414             domain.create()
415             logger.log("tools: CREATED %s" % (domain.name()) )
416         except:
417             logger.log("tools: FAILED to reboot %s" % (domain.name()) )
418
419 ###################################################
420 # Author: Guilherme Sperb Machado <gsm@machados.org>
421 ###################################################
422 # Get the /etc/hosts file path
423 ###################################################
424 def get_hosts_file_path(slicename):
425     containerDir = os.path.join(sliver_lxc.Sliver_LXC.CON_BASE_DIR, slicename)
426     return os.path.join(containerDir, 'etc', 'hosts')
427
428 ###################################################
429 # Author: Guilherme Sperb Machado <gsm@machados.org>
430 ###################################################
431 # Search if there is a specific ipv6 address in the
432 # /etc/hosts file of a given slice
433 # If the parameter 'ipv6addr' is None, then search
434 # for any ipv6 address
435 ###################################################
436 def search_ipv6addr_hosts(slicename, ipv6addr):
437     hostsFilePath = get_hosts_file_path(slicename)
438     found=False
439     try:
440         for line in fileinput.input(r'%s' % (hostsFilePath)):
441             if ipv6addr is not None:
442                 if re.search(r'%s' % (ipv6addr), line):
443                     found=True
444             else:
445                 search = re.search(r'^(.*)\s+.*$', line)
446                 if search:
447                     ipv6candidate = search.group(1)
448                     ipv6candidatestrip = ipv6candidate.strip()
449                     valid = is_valid_ipv6(ipv6candidatestrip)
450                     if valid:
451                         found=True
452             fileinput.close()
453             return found
454     except:
455         logger.log("tools: FAILED to search %s in /etc/hosts file of slice=%s" % \
456                    (ipv6addr, slicename) )
457
458 ###################################################
459 # Author: Guilherme Sperb Machado <gsm@machados.org>
460 ###################################################
461 # Removes all ipv6 addresses from the /etc/hosts
462 # file of a given slice
463 ###################################################
464 def remove_all_ipv6addr_hosts(slicename, node):
465     hostsFilePath = get_hosts_file_path(slicename)
466     try:
467         for line in fileinput.input(r'%s' % (hostsFilePath), inplace=True):
468             search = re.search(r'^(.*)\s+(%s|%s)$' % (node,'localhost'), line)
469             if search:
470                 ipv6candidate = search.group(1)
471                 ipv6candidatestrip = ipv6candidate.strip()
472                 valid = is_valid_ipv6(ipv6candidatestrip)
473                 if not valid:
474                     print line,
475         fileinput.close()
476         logger.log("tools: REMOVED IPv6 address from /etc/hosts file of slice=%s" % \
477                    (slicename) )
478     except:
479         logger.log("tools: FAILED to remove the IPv6 address from /etc/hosts file of slice=%s" % \
480                    (slicename) )
481
482 ###################################################
483 # Author: Guilherme Sperb Machado <gsm@machados.org>
484 ###################################################
485 # Adds an ipv6 address to the /etc/hosts file within a slice
486 ###################################################
487 def add_ipv6addr_hosts_line(slicename, node, ipv6addr):
488     hostsFilePath = get_hosts_file_path(slicename)
489     logger.log("tools: %s" % (hostsFilePath) )
490     # debugging purposes:
491     #string = "127.0.0.1\tlocalhost\n192.168.100.179\tmyplc-node1-vm.mgmt.local\n"
492     #string = "127.0.0.1\tlocalhost\n"
493     try:
494         with open(hostsFilePath, "a") as file:
495             file.write(ipv6addr + " " + node + "\n")
496             file.close()
497         logger.log("tools: ADDED IPv6 address to /etc/hosts file of slice=%s" % \
498                    (slicename) )
499     except:
500         logger.log("tools: FAILED to add the IPv6 address to /etc/hosts file of slice=%s" % \
501                    (slicename) )
502
503
504
505 # how to run a command in a slice
506 # now this is a painful matter
507 # the problem is with capsh that forces a bash command to be injected in its exec'ed command
508 # so because lxcsu uses capsh, you cannot exec anything else than bash
509 # bottom line is, what actually needs to be called is
510 # vs:  vserver exec slicename command and its arguments
511 # lxc: lxcsu slicename "command and its arguments"
512 # which, OK, is no big deal as long as the command is simple enough,
513 # but do not stretch it with arguments that have spaces or need quoting as that will become a nightmare
514 def command_in_slice (slicename, argv):
515     virt=get_node_virt()
516     if virt=='vs':
517         return [ 'vserver', slicename, 'exec', ] + argv
518     elif virt=='lxc':
519         # wrap up argv in a single string for -c
520         return [ 'lxcsu', slicename, ] + [ " ".join(argv) ]
521     logger.log("command_in_slice: WARNING: could not find a valid virt")
522     return argv
523
524 ####################
525 def init_signals ():
526     def handler (signum, frame):
527         logger.log("Received signal %d - exiting"%signum)
528         os._exit(1)
529     signal.signal(signal.SIGHUP,handler)
530     signal.signal(signal.SIGQUIT,handler)
531     signal.signal(signal.SIGINT,handler)
532     signal.signal(signal.SIGTERM,handler)