Merge branch 'ipv6'
[nodemanager.git] / tools.py
1 # -*- python-indent: 4 -*-
2
3 """A few things that didn't seem to fit anywhere else."""
4
5 import os, os.path
6 import pwd
7 import tempfile
8 import fcntl
9 import errno
10 import threading
11 import subprocess
12 import shutil
13 import sys
14 import signal
15
16 import logger
17
18 PID_FILE = '/var/run/nodemanager.pid'
19
20 ####################
21 def get_default_if():
22     interface = get_if_from_hwaddr(get_hwaddr_from_plnode())
23     if not interface: interface = "eth0"
24     return interface
25
26 def get_hwaddr_from_plnode():
27     try:
28         for line in open("/usr/boot/plnode.txt", 'r').readlines():
29             if line.startswith("NET_DEVICE"):
30                 return line.split("=")[1].strip().strip('"')
31     except:
32         pass
33     return None
34
35 def get_if_from_hwaddr(hwaddr):
36     import sioc
37     devs = sioc.gifconf()
38     for dev in devs:
39         dev_hwaddr = sioc.gifhwaddr(dev)
40         if dev_hwaddr == hwaddr: return dev
41     return None
42
43 ####################
44 # daemonizing
45 def as_daemon_thread(run):
46     """Call function <run> with no arguments in its own thread."""
47     thr = threading.Thread(target=run)
48     thr.setDaemon(True)
49     thr.start()
50
51 def close_nonstandard_fds():
52     """Close all open file descriptors other than 0, 1, and 2."""
53     _SC_OPEN_MAX = 4
54     for fd in range(3, os.sysconf(_SC_OPEN_MAX)):
55         try: os.close(fd)
56         except OSError: pass  # most likely an fd that isn't open
57
58 # after http://www.erlenstar.demon.co.uk/unix/faq_2.html
59 def daemon():
60     """Daemonize the current process."""
61     if os.fork() != 0: os._exit(0)
62     os.setsid()
63     if os.fork() != 0: os._exit(0)
64     os.chdir('/')
65     os.umask(0022)
66     devnull = os.open(os.devnull, os.O_RDWR)
67     os.dup2(devnull, 0)
68     # xxx fixme - this is just to make sure that nothing gets stupidly lost - should use devnull
69     crashlog = os.open('/var/log/nodemanager.daemon', os.O_RDWR | os.O_APPEND | os.O_CREAT, 0644)
70     os.dup2(crashlog, 1)
71     os.dup2(crashlog, 2)
72
73 def fork_as(su, function, *args):
74     """
75 fork(), cd / to avoid keeping unused directories open,
76 close all nonstandard file descriptors (to avoid capturing open sockets),
77 fork() again (to avoid zombies) and call <function>
78 with arguments <args> in the grandchild process.
79 If <su> is not None, set our group and user ids
80  appropriately in the child process.
81     """
82     child_pid = os.fork()
83     if child_pid == 0:
84         try:
85             os.chdir('/')
86             close_nonstandard_fds()
87             if su:
88                 pw_ent = pwd.getpwnam(su)
89                 os.setegid(pw_ent[3])
90                 os.seteuid(pw_ent[2])
91             child_pid = os.fork()
92             if child_pid == 0: function(*args)
93         except:
94             os.seteuid(os.getuid())  # undo su so we can write the log file
95             os.setegid(os.getgid())
96             logger.log_exc("tools: fork_as")
97         os._exit(0)
98     else: os.waitpid(child_pid, 0)
99
100 ####################
101 # manage files
102 def pid_file():
103     """
104 We use a pid file to ensure that only one copy of NM is running at a given time.
105 If successful, this function will write a pid file containing the pid of the current process.
106 The return value is the pid of the other running process, or None otherwise.
107     """
108     other_pid = None
109     if os.access(PID_FILE, os.F_OK):  # check for a pid file
110         handle = open(PID_FILE)  # pid file exists, read it
111         other_pid = int(handle.read())
112         handle.close()
113         # check for a process with that pid by sending signal 0
114         try: os.kill(other_pid, 0)
115         except OSError, e:
116             if e.errno == errno.ESRCH: other_pid = None  # doesn't exist
117             else: raise  # who knows
118     if other_pid == None:
119         # write a new pid file
120         write_file(PID_FILE, lambda f: f.write(str(os.getpid())))
121     return other_pid
122
123 def write_file(filename, do_write, **kw_args):
124     """
125 Write file <filename> atomically by opening a temporary file,
126 using <do_write> to write that file, and then renaming the temporary file.
127     """
128     shutil.move(write_temp_file(do_write, **kw_args), filename)
129
130 def write_temp_file(do_write, mode=None, uidgid=None):
131     fd, temporary_filename = tempfile.mkstemp()
132     if mode: os.chmod(temporary_filename, mode)
133     if uidgid: os.chown(temporary_filename, *uidgid)
134     f = os.fdopen(fd, 'w')
135     try: do_write(f)
136     finally: f.close()
137     return temporary_filename
138
139 def replace_file_with_string (target, new_contents, chmod=None, remove_if_empty=False):
140     """
141 Replace a target file with a new contents
142 checks for changes: does not do anything if previous state was already right
143 can handle chmod if requested
144 can also remove resulting file if contents are void, if requested
145 performs atomically:
146 writes in a tmp file, which is then renamed (from sliverauth originally)
147 returns True if a change occurred, or the file is deleted
148     """
149     try:
150         current=file(target).read()
151     except:
152         current=""
153     if current==new_contents:
154         # if turns out to be an empty string, and remove_if_empty is set,
155         # then make sure to trash the file if it exists
156         if remove_if_empty and not new_contents and os.path.isfile(target):
157             logger.verbose("tools.replace_file_with_string: removing file %s"%target)
158             try: os.unlink(target)
159             finally: return True
160         return False
161     # overwrite target file: create a temp in the same directory
162     path=os.path.dirname(target) or '.'
163     fd, name = tempfile.mkstemp('','repl',path)
164     os.write(fd,new_contents)
165     os.close(fd)
166     if os.path.exists(target):
167         os.unlink(target)
168     shutil.move(name,target)
169     if chmod: os.chmod(target,chmod)
170     return True
171
172 ####################
173 # utilities functions to get (cached) information from the node
174
175 # get node_id from /etc/planetlab/node_id and cache it
176 _node_id=None
177 def node_id():
178     global _node_id
179     if _node_id is None:
180         try:
181             _node_id=int(file("/etc/planetlab/node_id").read())
182         except:
183             _node_id=""
184     return _node_id
185
186 _root_context_arch=None
187 def root_context_arch():
188     global _root_context_arch
189     if not _root_context_arch:
190         sp=subprocess.Popen(["uname","-i"],stdout=subprocess.PIPE)
191         (_root_context_arch,_)=sp.communicate()
192         _root_context_arch=_root_context_arch.strip()
193     return _root_context_arch
194
195
196 ####################
197 class NMLock:
198     def __init__(self, file):
199         logger.log("tools: Lock %s initialized." % file, 2)
200         self.fd = os.open(file, os.O_RDWR|os.O_CREAT, 0600)
201         flags = fcntl.fcntl(self.fd, fcntl.F_GETFD)
202         flags |= fcntl.FD_CLOEXEC
203         fcntl.fcntl(self.fd, fcntl.F_SETFD, flags)
204     def __del__(self):
205         os.close(self.fd)
206     def acquire(self):
207         logger.log("tools: Lock acquired.", 2)
208         fcntl.lockf(self.fd, fcntl.LOCK_SH)
209     def release(self):
210         logger.log("tools: Lock released.", 2)
211         fcntl.lockf(self.fd, fcntl.LOCK_UN)
212
213 ####################
214 # Utilities for getting the IP address of a LXC/Openvswitch slice. Do this by
215 # running ifconfig inside of the slice's context.
216
217 def get_sliver_process(slice_name, process_cmdline):
218     """ Utility function to find a process inside of an LXC sliver. Returns
219         (cgroup_fn, pid). cgroup_fn is the filename of the cgroup file for
220         the process, for example /proc/2592/cgroup. Pid is the process id of
221         the process. If the process is not found then (None, None) is returned.
222     """
223     try:
224         cmd = 'grep %s /proc/*/cgroup | grep freezer'%slice_name
225         output = os.popen(cmd).readlines()
226     except:
227         # the slice couldn't be found
228         logger.log("get_sliver_process: couldn't find slice %s" % slice_name)
229         return (None, None)
230
231     cgroup_fn = None
232     pid = None
233     for e in output:
234         try:
235             l = e.rstrip()
236             path = l.split(':')[0]
237             comp = l.rsplit(':')[-1]
238             slice_name_check = comp.rsplit('/')[-1]
239             # the lines below were added by Guilherme <gsm@machados.org>
240             # due to the ipv6 plugin requirements (LXC)
241             virt=get_node_virt()
242             if virt=='lxc':
243                 slice_name_check = slice_name_check.rsplit('.')[0]
244
245             if (slice_name_check == slice_name):
246                 slice_path = path
247                 pid = slice_path.split('/')[2]
248                 cmdline = open('/proc/%s/cmdline'%pid).read().rstrip('\n\x00')
249                 if (cmdline == process_cmdline):
250                     cgroup_fn = slice_path
251                     break
252         except:
253             break
254
255     if (not cgroup_fn) or (not pid):
256         logger.log("get_sliver_process: process %s not running in slice %s" % (process_cmdline, slice_name))
257         return (None, None)
258
259     return (cgroup_fn, pid)
260
261 ###################################################
262 # Added by Guilherme Sperb Machado <gsm@machados.org>
263 ###################################################
264
265 try:
266     import re
267     import socket
268     import fileinput
269 except:
270     logger.log("Could not import 're', 'socket', or 'fileinput' python packages.")
271
272 # TODO: is there anything better to do if the "libvirt", "sliver_libvirt",
273 # and "sliver_lxc" are not in place?
274 try:
275     import libvirt
276     from sliver_libvirt import Sliver_Libvirt
277     import sliver_lxc
278 except:
279     logger.log("Could not import 'sliver_lxc' or 'libvirt' or 'sliver_libvirt'.")
280 ###################################################
281
282 def get_sliver_ifconfig(slice_name, device="eth0"):
283     """ return the output of "ifconfig" run from inside the sliver.
284
285         side effects: adds "/usr/sbin" to sys.path
286     """
287
288     # See if setns is installed. If it's not then we're probably not running
289     # LXC.
290     if not os.path.exists("/usr/sbin/setns.so"):
291         return None
292
293     # setns is part of lxcsu and is installed to /usr/sbin
294     if not "/usr/sbin" in sys.path:
295         sys.path.append("/usr/sbin")
296     import setns
297
298     (cgroup_fn, pid) = get_sliver_process(slice_name, "/sbin/init")
299     if (not cgroup_fn) or (not pid):
300         return None
301
302     path = '/proc/%s/ns/net'%pid
303
304     result = None
305     try:
306         setns.chcontext(path)
307
308         args = ["/sbin/ifconfig", device]
309         sub = subprocess.Popen(args, stderr = subprocess.PIPE, stdout = subprocess.PIPE)
310         sub.wait()
311
312         if (sub.returncode != 0):
313             logger.log("get_slice_ifconfig: error in ifconfig: %s" % sub.stderr.read())
314
315         result = sub.stdout.read()
316     finally:
317         setns.chcontext("/proc/1/ns/net")
318
319     return result
320
321 def get_sliver_ip(slice_name):
322     ifconfig = get_sliver_ifconfig(slice_name)
323     if not ifconfig:
324         return None
325
326     for line in ifconfig.split("\n"):
327         if "inet addr:" in line:
328             # example: '          inet addr:192.168.122.189  Bcast:192.168.122.255  Mask:255.255.255.0'
329             parts = line.strip().split()
330             if len(parts)>=2 and parts[1].startswith("addr:"):
331                 return parts[1].split(":")[1]
332
333     return None
334
335 ###################################################
336 # Author: Guilherme Sperb Machado <gsm@machados.org>
337 ###################################################
338 # Get the slice ipv6 address
339 # Only for LXC!
340 ###################################################
341 def get_sliver_ipv6(slice_name):
342     ifconfig = get_sliver_ifconfig(slice_name)
343     if not ifconfig:
344         return None,None
345
346     # example: 'inet6 2001:67c:16dc:1302:5054:ff:fea7:7882  prefixlen 64  scopeid 0x0<global>'
347     prog = re.compile(r'inet6\s+(.*)\s+prefixlen\s+(\d+)\s+scopeid\s+(.+)<global>')
348     for line in ifconfig.split("\n"):
349         search = prog.search(line)
350         if search:
351             ipv6addr = search.group(1)
352             prefixlen = search.group(2)
353             return (ipv6addr,prefixlen)
354     return None,None
355
356 ###################################################
357 # Author: Guilherme Sperb Machado <gsm@machados.org>
358 ###################################################
359 # Check if the address is a AF_INET6 family address
360 ###################################################
361 def is_valid_ipv6(ipv6addr):
362     try:
363         socket.inet_pton(socket.AF_INET6, ipv6addr)
364     except socket.error:
365         return False
366     return True
367
368 ### this returns the kind of virtualization on the node
369 # either 'vs' or 'lxc'
370 # also caches it in /etc/planetlab/virt for next calls
371 # could be promoted to core nm if need be
372 virt_stamp="/etc/planetlab/virt"
373 def get_node_virt ():
374     try:
375         return file(virt_stamp).read().strip()
376     except:
377         pass
378     logger.log("Computing virt..")
379     try:
380         if subprocess.call ([ 'vserver', '--help' ]) ==0: virt='vs'
381         else:                                             virt='lxc'
382     except:
383         virt='lxc'
384     with file(virt_stamp,"w") as f:
385         f.write(virt)
386     return virt
387
388 ### this return True or False to indicate that systemctl is present on that box
389 # cache result in memory as _has_systemctl
390 _has_systemctl=None
391 def has_systemctl ():
392     global _has_systemctl
393     if _has_systemctl is None:
394         _has_systemctl = (subprocess.call([ 'systemctl', '--help' ]) == 0)
395     return _has_systemctl
396
397 ###################################################
398 # Author: Guilherme Sperb Machado <gsm@machados.org>
399 ###################################################
400 # This method was developed to support the ipv6 plugin
401 # Only for LXC!
402 ###################################################
403 def reboot_slivers():
404     type = 'sliver.LXC'
405     # connecting to the libvirtd
406     connLibvirt = Sliver_Libvirt.getConnection(type)
407     domains = connLibvirt.listAllDomains()
408     for domain in domains:
409         try:
410             domain.destroy()
411             logger.log("tools: DESTROYED %s" % (domain.name()) )
412             domain.create()
413             logger.log("tools: CREATED %s" % (domain.name()) )
414         except:
415             logger.log("tools: FAILED to reboot %s" % (domain.name()) )
416
417 ###################################################
418 # Author: Guilherme Sperb Machado <gsm@machados.org>
419 ###################################################
420 # Get the /etc/hosts file path
421 ###################################################
422 def get_hosts_file_path(slicename):
423     containerDir = os.path.join(sliver_lxc.Sliver_LXC.CON_BASE_DIR, slicename)
424     return os.path.join(containerDir, 'etc', 'hosts')
425
426 ###################################################
427 # Author: Guilherme Sperb Machado <gsm@machados.org>
428 ###################################################
429 # Search if there is a specific ipv6 address in the
430 # /etc/hosts file of a given slice
431 # If the parameter 'ipv6addr' is None, then search
432 # for any ipv6 address
433 ###################################################
434 def search_ipv6addr_hosts(slicename, ipv6addr):
435     hostsFilePath = get_hosts_file_path(slicename)
436     found=False
437     try:
438         for line in fileinput.input(r'%s' % (hostsFilePath)):
439             if ipv6addr is not None:
440                 if re.search(r'%s' % (ipv6addr), line):
441                     found=True
442             else:
443                 search = re.search(r'^(.*)\s+.*$', line)
444                 if search:
445                     ipv6candidate = search.group(1)
446                     ipv6candidatestrip = ipv6candidate.strip()
447                     valid = is_valid_ipv6(ipv6candidatestrip)
448                     if valid:
449                         found=True
450             fileinput.close()
451             return found
452     except:
453         logger.log("tools: FAILED to search %s in /etc/hosts file of slice=%s" % \
454                    (ipv6addr, slicename) )
455
456 ###################################################
457 # Author: Guilherme Sperb Machado <gsm@machados.org>
458 ###################################################
459 # Removes all ipv6 addresses from the /etc/hosts
460 # file of a given slice
461 ###################################################
462 def remove_all_ipv6addr_hosts(slicename, node):
463     hostsFilePath = get_hosts_file_path(slicename)
464     try:
465         for line in fileinput.input(r'%s' % (hostsFilePath), inplace=True):
466             search = re.search(r'^(.*)\s+(%s|%s)$' % (node,'localhost'), line)
467             if search:
468                 ipv6candidate = search.group(1)
469                 ipv6candidatestrip = ipv6candidate.strip()
470                 valid = is_valid_ipv6(ipv6candidatestrip)
471                 if not valid:
472                     print line,
473         fileinput.close()
474         logger.log("tools: REMOVED IPv6 address from /etc/hosts file of slice=%s" % \
475                    (slicename) )
476     except:
477         logger.log("tools: FAILED to remove the IPv6 address from /etc/hosts file of slice=%s" % \
478                    (slicename) )
479
480 ###################################################
481 # Author: Guilherme Sperb Machado <gsm@machados.org>
482 ###################################################
483 # Adds an ipv6 address to the /etc/hosts file within a slice
484 ###################################################
485 def add_ipv6addr_hosts_line(slicename, node, ipv6addr):
486     hostsFilePath = get_hosts_file_path(slicename)
487     logger.log("tools: %s" % (hostsFilePath) )
488     # debugging purposes:
489     #string = "127.0.0.1\tlocalhost\n192.168.100.179\tmyplc-node1-vm.mgmt.local\n"
490     #string = "127.0.0.1\tlocalhost\n"
491     try:
492         with open(hostsFilePath, "a") as file:
493             file.write(ipv6addr + " " + node + "\n")
494             file.close()
495         logger.log("tools: ADDED IPv6 address to /etc/hosts file of slice=%s" % \
496                    (slicename) )
497     except:
498         logger.log("tools: FAILED to add the IPv6 address to /etc/hosts file of slice=%s" % \
499                    (slicename) )
500
501
502
503 # how to run a command in a slice
504 # now this is a painful matter
505 # the problem is with capsh that forces a bash command to be injected in its exec'ed command
506 # so because lxcsu uses capsh, you cannot exec anything else than bash
507 # bottom line is, what actually needs to be called is
508 # vs:  vserver exec slicename command and its arguments
509 # lxc: lxcsu slicename "command and its arguments"
510 # which, OK, is no big deal as long as the command is simple enough,
511 # but do not stretch it with arguments that have spaces or need quoting as that will become a nightmare
512 def command_in_slice (slicename, argv):
513     virt=get_node_virt()
514     if virt=='vs':
515         return [ 'vserver', slicename, 'exec', ] + argv
516     elif virt=='lxc':
517         # wrap up argv in a single string for -c
518         return [ 'lxcsu', slicename, ] + [ " ".join(argv) ]
519     logger.log("command_in_slice: WARNING: could not find a valid virt")
520     return argv
521
522 ####################
523 def init_signals ():
524     def handler (signum, frame):
525         logger.log("Received signal %d - exiting"%signum)
526         os._exit(1)
527     signal.signal(signal.SIGHUP,handler)
528     signal.signal(signal.SIGQUIT,handler)
529     signal.signal(signal.SIGINT,handler)
530     signal.signal(signal.SIGTERM,handler)