3fd1a7f1add8bf939c87fe9da902673e295ad705
[nepi.git] / src / nepi / testbeds / planetlab / tunproto.py
1 # -*- coding: utf-8 -*-
2
3 import weakref
4 import os
5 import os.path
6 import rspawn
7 import subprocess
8 import threading
9 import base64
10 import time
11 import re
12 import sys
13 import logging
14
15 from nepi.util import server
16
class TunProtoBase(object):
    """Base class for one endpoint of a PlanetLab tunnel.

    Manages the remote ``tun_connect.py`` forwarder implementing the local
    side of a tunnel between two interfaces.  It:

    * prepares a working directory (``home_path``) on the local interface's
      node;
    * uploads and builds the helper sources there;
    * launches the forwarder through ``rspawn``;
    * afterwards queries, signals, stops and collects traces from it.

    Subclasses choose the protocol passed to :meth:`launch` and may override
    ``mode``.
    """

    def __init__(self, local, peer, home_path, key):
        """
        :param local: local interface object (only a weak reference is kept).
        :param peer: peer interface object (only a weak reference is kept).
        :param home_path: working directory used on the local node.
        :param key: tunnel key, handed to tun_connect as ``-k`` (udp/tcp) or
            ``-K`` (cross-slice gre).
        """
        # Weak references, since ifaces do have a reference to the
        # tunneling protocol implementation - we don't want strong
        # circular references.
        self.peer = weakref.ref(peer)
        self.local = weakref.ref(local)

        self.port = 15000
        self.mode = 'pl-tun'
        self.key = key
        self.cross_slice = False

        self.home_path = home_path

        self._started = False

        # PID/PPID of the remote forwarder, read from its pidfile
        self._pid = None
        self._ppid = None
        # Tunnel device name, parsed lazily from the capture log
        self._if_name = None

        # Addressing recorded by the last launch()
        self._pointopoint = None
        self._netprefix = None
        self._address = None

        # Logging
        self._logger = logging.getLogger('nepi.testbeds.planetlab')

    def __str__(self):
        local = self.local()
        if local:
            return '<%s for %s>' % (self.__class__.__name__, local)
        return super(TunProtoBase, self).__str__()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ssh_args(local):
        """Return the connection keyword arguments for ``local``'s node.

        These host/user/key arguments are shared by every
        ``server.popen_ssh_command`` and ``rspawn`` call in this class.
        """
        node = local.node
        return dict(
            host = node.hostname,
            port = None,
            user = node.slicename,
            agent = None,
            ident_key = node.ident_path,
            server_key = node.server_key,
        )

    def _ssh(self, local, cmd, **kwargs):
        """Run shell ``cmd`` on ``local``'s node.

        The call goes through ``server.eintr_retry`` like every other remote
        call here.  Extra keyword arguments (``timeout``, ``retry``,
        ``err_on_timeout``...) are forwarded to ``server.popen_ssh_command``.

        :returns: the ``((out, err), proc)`` result of ``popen_ssh_command``.
        """
        kwargs.update(self._ssh_args(local))
        return server.eintr_retry(server.popen_ssh_command)(cmd, **kwargs)

    def _deploy_target(self):
        """Return the live local interface, checking it is attached to a node.

        :raises RuntimeError: if the interface was garbage-collected or has
            no node.
        """
        local = self.local()
        if not local:
            raise RuntimeError("Lost reference to peering interfaces before launching")
        if not local.node:
            raise RuntimeError("Unconnected TUN - missing node")
        return local

    def _filter_sources(self, local):
        """Return the source files listed in ``local.filter_module.module``.

        The spec is a whitespace-separated list whose first entry is the
        filter module itself.

        :raises RuntimeError: if the spec names no file at all.
        """
        # split() with no argument already strips and drops empty tokens.
        # It also works for unicode specs, where map(str.strip, ...) did not.
        sources = local.filter_module.module.split()
        if not sources:
            raise RuntimeError(
                "Misconfigured TUN: %s - filter module lists no source files" % (local,))
        return sources

    @staticmethod
    def _positive_count(out):
        """Return True if ``out`` (output of ``grep -c``) reports a match."""
        out = out.strip()
        return out.isdigit() and int(out) > 0

    # ------------------------------------------------------------------
    # Deployment
    # ------------------------------------------------------------------

    def _make_home(self):
        """Create the remote home directory and clear old artifacts.

        Also removes the pidfile and compiled ``*.so`` files from earlier
        runs.  Old pidfiles from previous runs can be troublesome.

        :raises RuntimeError: if the interface is gone, unconnected, or the
            remote command fails.
        """
        local = self._deploy_target()

        cmd = "mkdir -p %(home)s ; rm -f %(home)s/pid %(home)s/*.so" % {
            'home' : server.shell_escape(self.home_path)
        }
        (out, err), proc = self._ssh(local, cmd, timeout = 60, retry = 3)

        if proc.wait():
            raise RuntimeError("Failed to set up TUN forwarder: %s %s" % (out, err))

    def _install_scripts(self):
        """Upload tun_connect and its helpers to the node and build them.

        Uploads tun_connect.py, tunalloc.c, the tunchannel and ipaddr2 modules
        and any filter sources.  Then, on the node, it compiles tunalloc.so,
        installs python-iovec, compiles a C filter module if one is
        configured, and installs python-passfd when the interface uses the
        ``fd`` protocol.

        :raises RuntimeError: on a lost/unconnected interface, an empty
            filter spec, or a failed upload/build.
        """
        local = self._deploy_target()

        # Install the tun_connect script and tunalloc utility
        from nepi.util import tunchannel
        from nepi.util import ipaddr2

        here = os.path.dirname(__file__)
        # Ship .py rather than .pyc/.pyo: compiled files are version-specific
        sources = [
            os.path.join(here, 'scripts', 'tun_connect.py'),
            os.path.join(here, 'scripts', 'tunalloc.c'),
            re.sub(r"([.]py)[co]$", r'\1', tunchannel.__file__, 1),
            re.sub(r"([.]py)[co]$", r'\1', ipaddr2.__file__, 1),
        ]
        if local.filter_module:
            filter_sources = self._filter_sources(local)
            filter_module = filter_sources[0]

            # Sources that don't exist as given are looked up in the
            # builtin scripts folder
            for i, source in enumerate(filter_sources):
                if not os.path.exists(source):
                    builtin = os.path.join(here, "scripts", source)
                    if os.path.exists(builtin):
                        filter_sources[i] = builtin

            sources.extend(set(filter_sources))
        else:
            filter_module = None
            filter_sources = None

        dest = "%s@%s:%s" % (
            local.node.slicename, local.node.hostname,
            os.path.join(self.home_path, '.'),)
        (out, err), proc = server.eintr_retry(server.popen_scp)(
            sources,
            dest,
            ident_key = local.node.ident_path,
            server_key = local.node.server_key
            )

        if proc.wait():
            raise RuntimeError("Failed to upload TUN connect script %r: %s %s" % (sources, out, err))

        # Make sure all dependencies are satisfied
        local.node.wait_dependencies()

        iovec_url = "http://nepi.pl.sophia.inria.fr/code/python-iovec/archive/tip.tar.gz"
        passfd_url = "http://nepi.pl.sophia.inria.fr/code/python-passfd/archive/tip.tar.gz"

        # Each step is formatted on its own, so file names containing '%'
        # are never re-interpreted by a second formatting pass.
        steps = [
            "cd %s" % (server.shell_escape(self.home_path),),
            "gcc -fPIC -shared tunalloc.c -o tunalloc.so",
            "wget -q -c -O python-iovec-src.tar.gz %s" % (iovec_url,),
            "mkdir -p python-iovec",
            "cd python-iovec",
            "tar xzf ../python-iovec-src.tar.gz --strip-components=1",
            "python setup.py build",
            "python setup.py install --install-lib ..",
            "cd ..",
        ]
        if filter_module is not None and filter_module.endswith('.c'):
            # C filter: compile all its sources into <module>.so
            steps.append("gcc -fPIC -shared %s -o %s.so" % (
                ' '.join(map(os.path.basename, filter_sources)),
                os.path.basename(filter_module).rsplit('.', 1)[0]))
        if local.tun_proto == "fd":
            # The fd protocol additionally needs python-passfd on the node
            steps.extend([
                "wget -q -c -O python-passfd-src.tar.gz %s" % (passfd_url,),
                "mkdir -p python-passfd",
                "cd python-passfd",
                "tar xzf ../python-passfd-src.tar.gz --strip-components=1",
                "python setup.py build",
                "python setup.py install --install-lib ..",
            ])
        cmd = " && ".join(steps)

        (out, err), proc = self._ssh(local, cmd, timeout = 300)

        if proc.wait():
            raise RuntimeError("Failed to set up TUN forwarder: %s %s" % (out, err))

    def launch(self, check_proto):
        """Deploy and start the forwarder for this endpoint.

        :param check_proto: protocol this endpoint speaks (``'udp'``,
            ``'tcp'``, ``'fd'`` or ``'gre'``).  It must equal the peer's
            ``tun_proto``.
        :raises RuntimeError: if an interface reference was lost, the peers
            disagree on protocol or cipher, the configuration is unsupported,
            or any remote step fails.
        """
        peer = self.peer()
        local = self.local()

        if not peer or not local:
            raise RuntimeError("Lost reference to peering interfaces before launching")

        peer_port = peer.tun_port
        peer_addr = peer.tun_addr
        peer_proto = peer.tun_proto
        peer_cipher = peer.tun_cipher

        local_port = self.port
        local_cap = local.capture
        self._address = local_addr = local.address
        self._netprefix = local_mask = local.netprefix
        local_snat = local.snat
        local_txq = local.txqueuelen
        self._pointopoint = local_p2p = local.pointopoint
        local_cipher = local.tun_cipher
        local_mcast = local.multicast
        local_bwlim = local.bwlimit
        local_mcastfwd = local.multicast_forwarder

        # Without an explicit point-to-point address, use the peer's address
        if not local_p2p and hasattr(peer, 'address'):
            self._pointopoint = local_p2p = peer.address

        if check_proto != peer_proto:
            raise RuntimeError("Peering protocol mismatch: %s != %s" % (check_proto, peer_proto))

        if local_cipher != peer_cipher:
            raise RuntimeError("Peering cipher mismatch: %s != %s" % (local_cipher, peer_cipher))

        if check_proto == 'gre' and local_cipher.lower() != 'plain':
            raise RuntimeError("Misconfigured TUN: %s - GRE tunnels do not support encryption. Got %s, you MUST use PLAIN" % (local, local_cipher,))

        if local.filter_module:
            if check_proto not in ('udp', 'tcp'):
                raise RuntimeError("Misconfigured TUN: %s - filtered tunnels only work with udp or tcp links" % (local,))
            # The filter module is uploaded to the home dir; C filters are
            # loaded from the .so built by _install_scripts
            filter_module = os.path.join('.', os.path.basename(self._filter_sources(local)[0]))
            if filter_module.endswith('.c'):
                filter_module = filter_module.rsplit('.', 1)[0] + '.so'
            filter_args = local.filter_module.args
        else:
            filter_module = None
            filter_args = None

        args = ["python", "tun_connect.py",
            "-m", str(self.mode),
            "-t", str(check_proto),
            "-A", str(local_addr),
            "-M", str(local_mask),
            "-C", str(local_cipher),
            ]

        if check_proto == 'fd':
            passfd_arg = str(peer_addr)
            if passfd_arg.startswith('\x00'):
                # cannot shell_encode null characters :(
                passfd_arg = "base64:" + base64.b64encode(passfd_arg)
            else:
                passfd_arg = '$HOME/' + server.shell_escape(passfd_arg)
            args.extend([
                "--pass-fd", passfd_arg
            ])
        elif check_proto == 'gre':
            if self.cross_slice:
                args.extend([
                    "-K", str(self.key.strip('='))
                ])

            args.extend([
                "-a", str(peer_addr),
            ])
        else:
            # both udp and tcp
            args.extend([
                "-P", str(local_port),
                "-p", str(peer_port),
                "-a", str(peer_addr),
                "-k", str(self.key)
            ])

        if local_snat:
            args.append("-S")
        if local_p2p:
            args.extend(("-Z", str(local_p2p)))
        if local_txq:
            args.extend(("-Q", str(local_txq)))
        if not local_cap:
            args.append("-N")
        elif local_cap == 'pcap':
            args.extend(('-c', 'pcap'))
        if local_bwlim:
            args.extend(("-b", str(local_bwlim * 1024)))
        if filter_module:
            args.extend(("--filter", filter_module))
        if filter_args:
            args.extend(("--filter-args", filter_args))
        if local_mcast and local_mcastfwd:
            args.extend(("--multicast-forwarder", local_mcastfwd))

        self._logger.info("Starting %s", self)

        self._make_home()
        self._install_scripts()

        # Start process in a "daemonized" way, using nohup and heavy
        # stdin/out redirection to avoid connection issues
        (out, err), proc = rspawn.remote_spawn(
            " ".join(args),

            pidfile = './pid',
            home = self.home_path,
            stdin = '/dev/null',
            stdout = 'capture',
            stderr = rspawn.STDOUT,
            sudo = True,

            **self._ssh_args(local)
            )

        if proc.wait():
            raise RuntimeError("Failed to set up TUN: %s %s" % (out, err))

        self._started = True

    def recover(self):
        """Re-attach to a forwarder that should still be running on its node.

        Marks the tunnel as started and reads its pidfile.
        """
        self._started = True
        self.checkpid()

    def wait(self):
        """Wait for the forwarder to log that it is connected.

        Polls the remote ``capture`` log with growing delays (from 2s, x1.1,
        capped at 30s) for up to 30 rounds.  If the process stops running it
        logs a warning and returns.

        :raises RuntimeError: if the local interface is gone, or the tunnel
            never reports "Connected" (the message includes the capture log).
        """
        local = self.local()
        if not local:
            raise RuntimeError("Lost reference to local interface")

        home = server.shell_escape(self.home_path)

        # Wait for the connection to be established
        retrytime = 2.0
        for spin in xrange(30):
            if self.status() != rspawn.RUNNING:
                self._logger.warning("FAILED TO CONNECT! %s", self)
                break

            # Connected?
            (out, err), proc = self._ssh(local,
                "cd %(home)s ; grep -a -c Connected capture" % dict(home = home),
                timeout = 60,
                err_on_timeout = False)
            proc.wait()

            if self._positive_count(out):
                break

            # At least listening?
            (out, err), proc = self._ssh(local,
                "cd %(home)s ; grep -a -c Listening capture" % dict(home = home),
                timeout = 60,
                err_on_timeout = False)
            proc.wait()

            if self._positive_count(out):
                self._logger.debug("Listening but not yet connected: %s", self)

            time.sleep(min(30.0, retrytime))
            retrytime *= 1.1
        else:
            (out, err), proc = self._ssh(local,
                "cat %(home)s/capture" % dict(home = home),
                timeout = 60,
                retry = 3,
                err_on_timeout = False)
            proc.wait()

            raise RuntimeError("FAILED TO CONNECT %s: %s%s" % (self, out, err))

    @property
    def if_name(self):
        """Name of the tunnel device, or None if it could not be determined.

        Parsed once from the first "Using tun:" line of the remote capture
        log (up to 30 attempts), then cached.
        """
        if not self._if_name:
            # Inspect the trace to check the assigned iface
            local = self.local()
            if local:
                cmd = "cd %(home)s ; grep -a 'Using tun:' capture | head -1" % dict(
                    home = server.shell_escape(self.home_path))
                for spin in xrange(30):
                    (out, err), proc = self._ssh(local, cmd,
                        timeout = 60,
                        err_on_timeout = False)

                    if proc.wait():
                        self._logger.debug("if_name: failed cmd %s", cmd)
                        time.sleep(1)
                        continue

                    out = out.strip()

                    match = re.match(r"Using +tun: +([-a-zA-Z0-9]*).*", out)
                    if match:
                        self._if_name = match.group(1)
                        break
                    elif out:
                        self._logger.debug("if_name: %r does not match expected pattern from cmd %s", out, cmd)
                    else:
                        self._logger.debug("if_name: empty output from cmd %s", cmd)
                    time.sleep(3)
                else:
                    self._logger.warning("if_name: Could not get interface name")
        return self._if_name

    def if_alive(self):
        """Return True if the tunnel device still exists on the node.

        Returns False when the name is unknown, the local interface is gone,
        the device is reported dead, or 30 probes give no clear answer.
        """
        name = self.if_name
        if not name:
            return False

        local = self.local()
        if not local:
            return False

        for attempt in xrange(30):
            (out, err), proc = self._ssh(local,
                "ip link show %s >/dev/null 2>&1 && echo ALIVE || echo DEAD" % (name,),
                timeout = 60,
                err_on_timeout = False)

            if proc.wait():
                time.sleep(1)
                continue

            state = out.strip()
            if state == 'DEAD':
                return False
            elif state == 'ALIVE':
                return True
        return False

    def checkpid(self):
        """Load the forwarder's PID/PPID from the remote pidfile if needed.

        Reads the pidfile only once the tunnel has been started and while
        either id is still unknown.

        :raises RuntimeError: if the local interface is gone.
        """
        local = self.local()

        if not local:
            raise RuntimeError("Lost reference to local interface")

        # Parenthesized on purpose: without it, `and` bound tighter than `or`.
        # A pidfile was then read before launch, and a stale pid read there
        # was never refreshed afterwards.
        if self._started and (not self._pid or not self._ppid):
            pidtuple = rspawn.remote_check_pid(
                os.path.join(self.home_path, 'pid'),
                **self._ssh_args(local)
                )

            if pidtuple:
                self._pid, self._ppid = pidtuple

    def status(self):
        """Return the ``rspawn`` status of the forwarder.

        ``rspawn.NOT_STARTED`` is returned until launched with a known PID
        and PPID.

        :raises RuntimeError: if the local interface is gone.
        """
        local = self.local()

        if not local:
            raise RuntimeError("Lost reference to local interface")

        self.checkpid()
        if not self._started or not self._pid or not self._ppid:
            return rspawn.NOT_STARTED
        return rspawn.remote_status(
            self._pid, self._ppid,
            **self._ssh_args(local)
            )

    def kill(self, nowait = True):
        """Stop the forwarder if it is running.

        :param nowait: forwarded to ``rspawn.remote_kill``.
        :raises RuntimeError: if the local interface is gone.
        """
        local = self.local()

        if not local:
            raise RuntimeError("Lost reference to local interface")

        if self.status() == rspawn.RUNNING:
            self._logger.info("Stopping %s", self)

            # kill by ppid+pid - SIGTERM first, then try SIGKILL
            rspawn.remote_kill(
                self._pid, self._ppid,
                sudo = True,
                nowait = nowait,
                **self._ssh_args(local)
                )

    def waitkill(self):
        """Wait for the forwarder and its device to go away, forcing if needed.

        Polls the status with growing delays.  If the process is still
        running after 30 rounds, it is killed synchronously.  The device is
        then polled the same way and, as a last resort, shut down through
        ``/vsys/vif_down.in``.
        """
        interval = 1.0
        for i in xrange(30):
            if self.status() != rspawn.RUNNING:
                self._logger.info("Stopped %s", self)
                break
            time.sleep(interval)
            interval = min(30.0, interval * 1.1)
        else:
            self.kill(nowait = False)

        if self.if_name:
            for i in xrange(30):
                if not self.if_alive():
                    self._logger.info("Device down %s", self)
                    break
                time.sleep(interval)
                interval = min(30.0, interval * 1.1)
            else:
                local = self.local()

                if local:
                    # Forcibly shut down interface
                    (out, err), proc = self._ssh(local,
                        "sudo -S bash -c 'echo %s > /vsys/vif_down.in'" % (self.if_name,),
                        timeout = 60,
                        err_on_timeout = False)
                    proc.wait()

    def _signal_tunnel(self, signame):
        """Send signal ``signame`` (e.g. ``'USR1'``) to the forwarder via sudo.

        This is a no-op if the local interface is gone.  When the PID is
        unknown it only logs a warning, instead of failing on ``%d``.
        """
        local = self.local()
        if not local:
            return
        if not self._pid:
            self._logger.warning("Cannot send %s to %s: unknown PID", signame, self)
            return

        (out, err), proc = self._ssh(local,
            "sudo -S bash -c 'kill -s %s %d'" % (signame, self._pid),
            timeout = 60,
            err_on_timeout = False)
        proc.wait()

    def if_down(self):
        """Ask the forwarder to bring the device down (sends SIGUSR1)."""
        # TODO!!! need to set the vif down with vsys/vif_down.in ... which
        # doesn't currently work.
        self._signal_tunnel('USR1')

    def if_up(self):
        """Ask the forwarder to bring the device up (sends SIGUSR2)."""
        # TODO!!! need to set the vif up with vsys/vif_up.in ... which
        # doesn't currently work.
        self._signal_tunnel('USR2')

    _TRACEMAP = {
        # tracename : (remotename, localname)
        'packets' : ('capture', 'capture'),
        'pcap' : ('pcap', 'capture.pcap'),
    }

    def remote_trace_path(self, whichtrace, tracemap = None):
        """Return the path of trace ``whichtrace`` on the node.

        Returns None for unknown traces.  Uses the *remote* name (index 0),
        which is the file ``sync_trace`` downloads.
        """
        tracemap = tracemap or self._TRACEMAP

        if whichtrace not in tracemap:
            return None

        return os.path.join(self.home_path, tracemap[whichtrace][0])

    def sync_trace(self, local_dir, whichtrace, tracemap = None):
        """Download trace ``whichtrace`` into ``local_dir``.

        :returns: the local file path, or None for an unknown trace or a
            lost local interface.
        :raises RuntimeError: if the local directory or the copy fails.
        """
        tracemap = tracemap or self._TRACEMAP

        if whichtrace not in tracemap:
            return None

        local = self.local()

        if not local:
            return None

        remote_name = tracemap[whichtrace][0]
        local_path = os.path.join(local_dir, tracemap[whichtrace][1])

        # create parent local folders
        parent = os.path.dirname(local_path)
        if parent:
            devnull_out = open(os.devnull, "w")
            try:
                devnull_in = open(os.devnull, "r")
                try:
                    proc = subprocess.Popen(
                        ["mkdir", "-p", parent],
                        stdout = devnull_out,
                        stdin = devnull_in)
                    failed = proc.wait()
                finally:
                    devnull_in.close()
            finally:
                devnull_out.close()

            if failed:
                raise RuntimeError("Failed to synchronize trace")

        # sync files
        (out, err), proc = server.eintr_retry(server.popen_scp)(
            '%s@%s:%s' % (local.node.slicename, local.node.hostname,
                os.path.join(self.home_path, remote_name)),
            local_path,
            port = None,
            agent = None,
            ident_key = local.node.ident_path,
            server_key = local.node.server_key
            )

        if proc.wait():
            raise RuntimeError("Failed to synchronize trace: %s %s" % (out, err))

        return local_path

    def shutdown(self):
        """Request the forwarder to stop without waiting."""
        self.kill()

    def destroy(self):
        """Stop the forwarder and wait until it and its device are gone."""
        self.waitkill()
657
class TunProtoUDP(TunProtoBase):
    """Tunnel endpoint launched with the ``udp`` protocol."""

    def __init__(self, local, peer, home_path, key):
        parent = super(TunProtoUDP, self)
        parent.__init__(local, peer, home_path, key)

    def launch(self):
        """Start this endpoint; the peer must also use ``udp``."""
        proto = 'udp'
        super(TunProtoUDP, self).launch(proto)
664
class TunProtoFD(TunProtoBase):
    """Tunnel endpoint launched with the ``fd`` (file-descriptor passing) protocol."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoFD, self).__init__(
            local = local, peer = peer, home_path = home_path, key = key)

    def launch(self):
        """Start this endpoint; the peer must also use ``fd``."""
        super(TunProtoFD, self).launch(check_proto = 'fd')
671
class TunProtoGRE(TunProtoBase):
    """Tunnel endpoint launched with the ``gre`` protocol."""

    def __init__(self, local, peer, home_path, key):
        parent = super(TunProtoGRE, self)
        parent.__init__(local, peer, home_path, key)
        # Replaces the base class default; handed to tun_connect via -m
        self.mode = 'pl-gre-ip'

    def launch(self):
        """Start this endpoint; the peer must also use ``gre``."""
        super(TunProtoGRE, self).launch(check_proto = 'gre')
679
class TunProtoTCP(TunProtoBase):
    """Tunnel endpoint launched with the ``tcp`` protocol."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoTCP, self).__init__(
            local = local, peer = peer, home_path = home_path, key = key)

    def launch(self):
        """Start this endpoint; the peer must also use ``tcp``."""
        proto = 'tcp'
        super(TunProtoTCP, self).launch(proto)
686
class TapProtoUDP(TunProtoUDP):
    """TAP variant of the UDP endpoint (tun_connect mode ``pl-tap``)."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoUDP, self).__init__(
            local = local, peer = peer, home_path = home_path, key = key)
        # Overrides the mode set by the base constructor
        self.mode = 'pl-tap'
691
class TapProtoTCP(TunProtoTCP):
    """TAP variant of the TCP endpoint (tun_connect mode ``pl-tap``)."""

    def __init__(self, local, peer, home_path, key):
        parent = super(TapProtoTCP, self)
        parent.__init__(local, peer, home_path, key)
        # Overrides the mode set by the base constructor
        self.mode = 'pl-tap'
696
class TapProtoFD(TunProtoFD):
    """TAP variant of the fd endpoint (tun_connect mode ``pl-tap``)."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoFD, self).__init__(
            local = local, peer = peer, home_path = home_path, key = key)
        # Overrides the mode set by the base constructor
        self.mode = 'pl-tap'
701
class TapProtoGRE(TunProtoGRE):
    """TAP variant of the GRE endpoint (tun_connect mode ``pl-gre-eth``)."""

    def __init__(self, local, peer, home_path, key):
        parent = super(TapProtoGRE, self)
        parent.__init__(local, peer, home_path, key)
        # Replaces the 'pl-gre-ip' mode chosen by TunProtoGRE
        self.mode = 'pl-gre-eth'
706
# Protocol name -> endpoint class, for TUN interfaces
TUN_PROTO_MAP = dict(
    tcp = TunProtoTCP,
    udp = TunProtoUDP,
    fd = TunProtoFD,
    gre = TunProtoGRE,
)

# Protocol name -> endpoint class, for TAP interfaces
TAP_PROTO_MAP = dict(
    tcp = TapProtoTCP,
    udp = TapProtoUDP,
    fd = TapProtoFD,
    gre = TapProtoGRE,
)
720