Added TCP-handshake for TunChannel and tun_connect.py
[nepi.git] / src / nepi / testbeds / planetlab / tunproto.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import weakref
5 import os
6 import os.path
7 import rspawn
8 import subprocess
9 import threading
10 import base64
11 import time
12 import re
13 import sys
14 import logging
15
16 from nepi.util import server
17
18 class TunProtoBase(object):
19     def __init__(self, local, peer, home_path, key):
20         # Weak references, since ifaces do have a reference to the
21         # tunneling protocol implementation - we don't want strong
22         # circular references.
23         self.peer = weakref.ref(peer)
24         self.local = weakref.ref(local)
25         
26         self.port = 15000
27         self.mode = 'pl-tun'
28         self.key = key
29         
30         self.home_path = home_path
31        
32         self._started = False
33
34         self._pid = None
35         self._ppid = None
36         self._if_name = None
37
38         # Logging
39         self._logger = logging.getLogger('nepi.testbeds.planetlab')
40     
41     def __str__(self):
42         local = self.local()
43         if local:
44             return '<%s for %s>' % (self.__class__.__name__, local)
45         else:
46             return super(TunProtoBase,self).__str__()
47
48     def _make_home(self):
49         local = self.local()
50         
51         if not local:
52             raise RuntimeError, "Lost reference to peering interfaces before launching"
53         if not local.node:
54             raise RuntimeError, "Unconnected TUN - missing node"
55         
56         # Make sure all the paths are created where 
57         # they have to be created for deployment
58         # Also remove pidfile, if there is one.
59         # Old pidfiles from previous runs can be troublesome.
60         cmd = "mkdir -p %(home)s ; rm -f %(home)s/pid %(home)s/*.so" % {
61             'home' : server.shell_escape(self.home_path)
62         }
63         (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
64             cmd,
65             host = local.node.hostname,
66             port = None,
67             user = local.node.slicename,
68             agent = None,
69             ident_key = local.node.ident_path,
70             server_key = local.node.server_key,
71             timeout = 60,
72             retry = 3
73             )
74         
75         if proc.wait():
76             raise RuntimeError, "Failed to set up TUN forwarder: %s %s" % (out,err,)
77     
78     def _install_scripts(self):
79         local = self.local()
80         
81         if not local:
82             raise RuntimeError, "Lost reference to peering interfaces before launching"
83         if not local.node:
84             raise RuntimeError, "Unconnected TUN - missing node"
85         
86         # Install the tun_connect script and tunalloc utility
87         from nepi.util import tunchannel
88         from nepi.util import ipaddr2
89         sources = [
90             os.path.join(os.path.dirname(__file__), 'scripts', 'tun_connect.py'),
91             os.path.join(os.path.dirname(__file__), 'scripts', 'tunalloc.c'),
92             re.sub(r"([.]py)[co]$", r'\1', tunchannel.__file__, 1), # pyc/o files are version-specific
93             re.sub(r"([.]py)[co]$", r'\1', ipaddr2.__file__, 1), # pyc/o files are version-specific
94         ]
95         if local.filter_module:
96             filter_sources = filter(bool,map(str.strip,local.filter_module.module.split()))
97             filter_module = filter_sources[0]
98             
99             # Translate paths to builtin sources
100             for i,source in enumerate(filter_sources):
101                 if not os.path.exists(source):
102                     # Um... try the builtin folder
103                     source = os.path.join(os.path.dirname(__file__), "scripts", source)
104                     if os.path.exists(source):
105                         # Yep... replace
106                         filter_sources[i] = source
107
108             sources.extend(set(filter_sources))
109                 
110         else:
111             filter_module = None
112             filter_sources = None
113         dest = "%s@%s:%s" % (
114             local.node.slicename, local.node.hostname, 
115             os.path.join(self.home_path,'.'),)
116         (out,err),proc = server.eintr_retry(server.popen_scp)(
117             sources,
118             dest,
119             ident_key = local.node.ident_path,
120             server_key = local.node.server_key
121             )
122     
123         if proc.wait():
124             raise RuntimeError, "Failed upload TUN connect script %r: %s %s" % (sources, out,err,)
125         
126         # Make sure all dependencies are satisfied
127         local.node.wait_dependencies()
128
129         cmd = ( (
130             "cd %(home)s && "
131             "gcc -fPIC -shared tunalloc.c -o tunalloc.so && "
132             
133             "wget -q -c -O python-iovec-src.tar.gz %(iovec_url)s && "
134             "mkdir -p python-iovec && "
135             "cd python-iovec && "
136             "tar xzf ../python-iovec-src.tar.gz --strip-components=1 && "
137             "python setup.py build && "
138             "python setup.py install --install-lib .. && "
139             "cd .. "
140             
141             + ( " && "
142                 "gcc -fPIC -shared %(sources)s -o %(module)s.so " % {
143                    'module' : os.path.basename(filter_module).rsplit('.',1)[0],
144                    'sources' : ' '.join(map(os.path.basename,filter_sources))
145                 }
146                 
147                 if filter_module is not None and filter_module.endswith('.c')
148                 else ""
149             )
150             
151             + ( " && "
152                 "wget -q -c -O python-passfd-src.tar.gz %(passfd_url)s && "
153                 "mkdir -p python-passfd && "
154                 "cd python-passfd && "
155                 "tar xzf ../python-passfd-src.tar.gz --strip-components=1 && "
156                 "python setup.py build && "
157                 "python setup.py install --install-lib .. "
158                 
159                 if local.tun_proto == "fd" 
160                 else ""
161             ) 
162           )
163         % {
164             'home' : server.shell_escape(self.home_path),
165             'passfd_url' : "http://yans.pl.sophia.inria.fr/code/hgwebdir.cgi/python-passfd/archive/2a6472c64c87.tar.gz",
166             'iovec_url' : "http://yans.pl.sophia.inria.fr/code/hgwebdir.cgi/python-iovec/archive/tip.tar.gz",
167         } )
168         (out,err),proc = server.popen_ssh_command(
169             cmd,
170             host = local.node.hostname,
171             port = None,
172             user = local.node.slicename,
173             agent = None,
174             ident_key = local.node.ident_path,
175             server_key = local.node.server_key,
176             timeout = 300
177             )
178         
179         if proc.wait():
180             raise RuntimeError, "Failed to set up TUN forwarder: %s %s" % (out,err,)
181         
182     def launch(self, check_proto):
183         peer = self.peer()
184         local = self.local()
185         
186         if not peer or not local:
187             raise RuntimeError, "Lost reference to peering interfaces before launching"
188         
189         peer_port = peer.tun_port
190         peer_addr = peer.tun_addr
191         peer_proto = peer.tun_proto
192         peer_cipher = peer.tun_cipher
193         
194         local_port = self.port
195         local_cap  = local.capture
196         local_addr = local.address
197         local_mask = local.netprefix
198         local_snat = local.snat
199         local_txq  = local.txqueuelen
200         local_p2p  = local.pointopoint
201         local_cipher = local.tun_cipher
202         local_mcast = local.multicast
203         local_bwlim = local.bwlimit
204         
205         if not local_p2p and hasattr(peer, 'address'):
206             local_p2p = peer.address
207
208         if check_proto != peer_proto:
209             raise RuntimeError, "Peering protocol mismatch: %s != %s" % (check_proto, peer_proto)
210         
211         if local_cipher != peer_cipher:
212             raise RuntimeError, "Peering cipher mismatch: %s != %s" % (local_cipher, peer_cipher)
213         
214         if check_proto == 'gre' and local_cipher.lower() != 'plain':
215             raise RuntimeError, "Misconfigured TUN: %s - GRE tunnels do not support encryption. Got %s, you MUST use PLAIN" % (local, local_cipher,)
216
217         if local.filter_module:
218             if check_proto not in ('udp', 'tcp'):
219                 raise RuntimeError, "Miscofnigured TUN: %s - filtered tunnels only work with udp or tcp links" % (local,)
220             filter_module = filter(bool,map(str.strip,local.filter_module.module.split()))
221             filter_module = os.path.join('.',os.path.basename(filter_module[0]))
222             if filter_module.endswith('.c'):
223                 filter_module = filter_module.rsplit('.',1)[0] + '.so'
224             filter_args = local.filter_module.args
225         else:
226             filter_module = None
227             filter_args = None
228         
229         args = ["python", "tun_connect.py", 
230             "-m", str(self.mode),
231             "-t", str(check_proto),
232             "-A", str(local_addr),
233             "-M", str(local_mask),
234             "-C", str(local_cipher)]
235         
236         if check_proto == 'fd':
237             passfd_arg = str(peer_addr)
238             if passfd_arg.startswith('\x00'):
239                 # cannot shell_encode null characters :(
240                 passfd_arg = "base64:"+base64.b64encode(passfd_arg)
241             else:
242                 passfd_arg = '$HOME/'+server.shell_escape(passfd_arg)
243             args.extend([
244                 "--pass-fd", passfd_arg
245             ])
246         elif check_proto == 'gre':
247             args.extend([
248                 "-K", str(min(local_port, peer_port)),
249                 "-a", str(peer_addr),
250             ])
251         # both udp and tcp
252         else:
253             args.extend([
254                 "-P", str(local_port),
255                 "-p", str(peer_port),
256                 "-a", str(peer_addr),
257                 "-k", str(self.key)
258             ])
259         
260         if local_snat:
261             args.append("-S")
262         if local_p2p:
263             args.extend(("-Z",str(local_p2p)))
264         if local_txq:
265             args.extend(("-Q",str(local_txq)))
266         if not local_cap:
267             args.append("-N")
268         elif local_cap == 'pcap':
269             args.extend(('-c','pcap'))
270         if local_mcast:
271             args.append("--multicast")
272         if local_bwlim:
273             args.extend(("-b",str(local_bwlim*1024)))
274         if filter_module:
275             args.extend(("--filter", filter_module))
276         if filter_args:
277             args.extend(("--filter-args", filter_args))
278
279         self._logger.info("Starting %s", self)
280         
281         self._make_home()
282         self._install_scripts()
283
284         # Start process in a "daemonized" way, using nohup and heavy
285         # stdin/out redirection to avoid connection issues
286         (out,err),proc = rspawn.remote_spawn(
287             " ".join(args),
288             
289             pidfile = './pid',
290             home = self.home_path,
291             stdin = '/dev/null',
292             stdout = 'capture',
293             stderr = rspawn.STDOUT,
294             sudo = True,
295             
296             host = local.node.hostname,
297             port = None,
298             user = local.node.slicename,
299             agent = None,
300             ident_key = local.node.ident_path,
301             server_key = local.node.server_key
302             )
303         
304         if proc.wait():
305             raise RuntimeError, "Failed to set up TUN: %s %s" % (out,err,)
306        
307         self._started = True
308     
309     def recover(self):
310         # Tunnel should be still running in its node
311         # Just check its pidfile and we're done
312         self._started = True
313         self.checkpid()
314     
315     def wait(self):
316         local = self.local()
317         
318         # Wait for the connection to be established
319         retrytime = 2.0
320         for spin in xrange(30):
321             if self.status() != rspawn.RUNNING:
322                 self._logger.warn("FAILED TO CONNECT! %s", self)
323                 break
324             
325             # Connected?
326             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
327                 "cd %(home)s ; grep -c Connected capture" % dict(
328                     home = server.shell_escape(self.home_path)),
329                 host = local.node.hostname,
330                 port = None,
331                 user = local.node.slicename,
332                 agent = None,
333                 ident_key = local.node.ident_path,
334                 server_key = local.node.server_key,
335                 timeout = 60,
336                 err_on_timeout = False
337                 )
338             proc.wait()
339
340             if out.strip() == '1':
341                 break
342
343             # At least listening?
344             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
345                 "cd %(home)s ; grep -c Listening capture" % dict(
346                     home = server.shell_escape(self.home_path)),
347                 host = local.node.hostname,
348                 port = None,
349                 user = local.node.slicename,
350                 agent = None,
351                 ident_key = local.node.ident_path,
352                 server_key = local.node.server_key,
353                 timeout = 60,
354                 err_on_timeout = False
355                 )
356             proc.wait()
357
358             time.sleep(min(30.0, retrytime))
359             retrytime *= 1.1
360         else:
361             (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
362                 "cat %(home)s/capture" % dict(
363                     home = server.shell_escape(self.home_path)),
364                 host = local.node.hostname,
365                 port = None,
366                 user = local.node.slicename,
367                 agent = None,
368                 ident_key = local.node.ident_path,
369                 server_key = local.node.server_key,
370                 timeout = 60,
371                 retry = 3,
372                 err_on_timeout = False
373                 )
374             proc.wait()
375
376             raise RuntimeError, "FAILED TO CONNECT %s: %s%s" % (self,out,err)
377     
378     @property
379     def if_name(self):
380         if not self._if_name:
381             # Inspect the trace to check the assigned iface
382             local = self.local()
383             if local:
384                 cmd = "cd %(home)s ; grep 'Using tun:' capture | head -1" % dict(
385                             home = server.shell_escape(self.home_path))
386                 for spin in xrange(30):
387                     (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
388                         cmd,
389                         host = local.node.hostname,
390                         port = None,
391                         user = local.node.slicename,
392                         agent = None,
393                         ident_key = local.node.ident_path,
394                         server_key = local.node.server_key,
395                         timeout = 60,
396                         err_on_timeout = False
397                         )
398                     
399                     if proc.wait():
400                         self._logger.debug("if_name: failed cmd %s", cmd)
401                         time.sleep(1)
402                         continue
403                     
404                     out = out.strip()
405                     
406                     match = re.match(r"Using +tun: +([-a-zA-Z0-9]*).*",out)
407                     if match:
408                         self._if_name = match.group(1)
409                         break
410                     elif out:
411                         self._logger.debug("if_name: %r does not match expected pattern from cmd %s", out, cmd)
412                     else:
413                         self._logger.debug("if_name: empty output from cmd %s", cmd)
414                     time.sleep(1)
415                 else:
416                     self._logger.warn("if_name: Could not get interface name")
417         return self._if_name
418     
419     def if_alive(self):
420         name = self.if_name
421         if name:
422             local = self.local()
423             for i in xrange(30):
424                 (out,err),proc = server.eintr_retry(server.popen_ssh_command)(
425                     "ip show %s >/dev/null 2>&1 && echo ALIVE || echo DEAD" % (name,),
426                     host = local.node.hostname,
427                     port = None,
428                     user = local.node.slicename,
429                     agent = None,
430                     ident_key = local.node.ident_path,
431                     server_key = local.node.server_key,
432                     timeout = 60,
433                     err_on_timeout = False
434                     )
435                 
436                 if proc.wait():
437                     time.sleep(1)
438                     continue
439                 
440                 if out.strip() == 'DEAD':
441                     return False
442                 elif out.strip() == 'ALIVE':
443                     return True
444         return False
445     
446     def checkpid(self):            
447         local = self.local()
448         
449         if not local:
450             raise RuntimeError, "Lost reference to local interface"
451         
452         # Get PID/PPID
453         # NOTE: wait a bit for the pidfile to be created
454         if self._started and not self._pid or not self._ppid:
455             pidtuple = rspawn.remote_check_pid(
456                 os.path.join(self.home_path,'pid'),
457                 host = local.node.hostname,
458                 port = None,
459                 user = local.node.slicename,
460                 agent = None,
461                 ident_key = local.node.ident_path,
462                 server_key = local.node.server_key
463                 )
464             
465             if pidtuple:
466                 self._pid, self._ppid = pidtuple
467     
468     def status(self):
469         local = self.local()
470         
471         if not local:
472             raise RuntimeError, "Lost reference to local interface"
473         
474         self.checkpid()
475         if not self._started:
476             return rspawn.NOT_STARTED
477         elif not self._pid or not self._ppid:
478             return rspawn.NOT_STARTED
479         else:
480             status = rspawn.remote_status(
481                 self._pid, self._ppid,
482                 host = local.node.hostname,
483                 port = None,
484                 user = local.node.slicename,
485                 agent = None,
486                 ident_key = local.node.ident_path,
487                 server_key = local.node.server_key
488                 )
489             return status
490     
491     def kill(self, nowait = True):
492         local = self.local()
493         
494         if not local:
495             raise RuntimeError, "Lost reference to local interface"
496         
497         status = self.status()
498         if status == rspawn.RUNNING:
499             self._logger.info("Stopping %s", self)
500             
501             # kill by ppid+pid - SIGTERM first, then try SIGKILL
502             rspawn.remote_kill(
503                 self._pid, self._ppid,
504                 host = local.node.hostname,
505                 port = None,
506                 user = local.node.slicename,
507                 agent = None,
508                 ident_key = local.node.ident_path,
509                 server_key = local.node.server_key,
510                 sudo = True,
511                 nowait = nowait
512                 )
513     
514     def waitkill(self):
515         interval = 1.0
516         for i in xrange(30):
517             status = self.status()
518             if status != rspawn.RUNNING:
519                 self._logger.info("Stopped %s", self)
520                 break
521             time.sleep(interval)
522             interval = min(30.0, interval * 1.1)
523         else:
524             self.kill(nowait=False)
525
526         if self.if_name:
527             for i in xrange(30):
528                 if not self.if_alive():
529                     self._logger.info("Device down %s", self)
530                     break
531                 time.sleep(interval)
532                 interval = min(30.0, interval * 1.1)
533     
534     _TRACEMAP = {
535         # tracename : (remotename, localname)
536         'packets' : ('capture','capture'),
537         'pcap' : ('pcap','capture.pcap'),
538     }
539     
540     def remote_trace_path(self, whichtrace, tracemap = None):
541         tracemap = self._TRACEMAP if not tracemap else tracemap
542         
543         if whichtrace not in tracemap:
544             return None
545         
546         return os.path.join(self.home_path, tracemap[whichtrace][1])
547         
548     def sync_trace(self, local_dir, whichtrace, tracemap = None):
549         tracemap = self._TRACEMAP if not tracemap else tracemap
550         
551         if whichtrace not in tracemap:
552             return None
553         
554         local = self.local()
555         
556         if not local:
557             return None
558         
559         local_path = os.path.join(local_dir, tracemap[whichtrace][1])
560         
561         # create parent local folders
562         if os.path.dirname(local_path):
563             proc = subprocess.Popen(
564                 ["mkdir", "-p", os.path.dirname(local_path)],
565                 stdout = open("/dev/null","w"),
566                 stdin = open("/dev/null","r"))
567
568             if proc.wait():
569                 raise RuntimeError, "Failed to synchronize trace"
570         
571         # sync files
572         (out,err),proc = server.popen_scp(
573             '%s@%s:%s' % (local.node.slicename, local.node.hostname, 
574                 os.path.join(self.home_path, tracemap[whichtrace][0])),
575             local_path,
576             port = None,
577             agent = None,
578             ident_key = local.node.ident_path,
579             server_key = local.node.server_key
580             )
581         
582         if proc.wait():
583             raise RuntimeError, "Failed to synchronize trace: %s %s" % (out,err,)
584         
585         return local_path
586         
587     def shutdown(self):
588         self.kill()
589     
590     def destroy(self):
591         self.waitkill()
592
class TunProtoUDP(TunProtoBase):
    """Tunnel endpoint that requires its peer to use the 'udp' protocol."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoUDP, self).__init__(local, peer, home_path, key)

    def launch(self):
        """Launch the remote forwarder for a 'udp' link."""
        super(TunProtoUDP, self).launch('udp')
599
class TunProtoFD(TunProtoBase):
    """Tunnel endpoint that requires its peer to use the 'fd' protocol."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoFD, self).__init__(local, peer, home_path, key)

    def launch(self):
        """Launch the remote forwarder for an 'fd' link."""
        super(TunProtoFD, self).launch('fd')
606
class TunProtoGRE(TunProtoBase):
    """Tunnel endpoint for 'gre' links, run in 'pl-gre-ip' mode."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoGRE, self).__init__(local, peer, home_path, key)
        # Replace the mode set by the base initializer
        self.mode = 'pl-gre-ip'

    def launch(self):
        """Launch the remote forwarder for a 'gre' link."""
        super(TunProtoGRE, self).launch('gre')
614
class TunProtoTCP(TunProtoBase):
    """Tunnel endpoint that requires its peer to use the 'tcp' protocol."""

    def __init__(self, local, peer, home_path, key):
        super(TunProtoTCP, self).__init__(local, peer, home_path, key)

    def launch(self):
        """Launch the remote forwarder for a 'tcp' link."""
        super(TunProtoTCP, self).launch('tcp')
621
class TapProtoUDP(TunProtoUDP):
    """'udp' endpoint run in 'pl-tap' mode instead of the base mode."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoUDP, self).__init__(local, peer, home_path, key)
        # Replace the mode set by the base initializer
        self.mode = 'pl-tap'
626
class TapProtoTCP(TunProtoTCP):
    """'tcp' endpoint run in 'pl-tap' mode instead of the base mode."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoTCP, self).__init__(local, peer, home_path, key)
        # Replace the mode set by the base initializer
        self.mode = 'pl-tap'
631
class TapProtoFD(TunProtoFD):
    """'fd' endpoint run in 'pl-tap' mode instead of the base mode."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoFD, self).__init__(local, peer, home_path, key)
        # Replace the mode set by the base initializer
        self.mode = 'pl-tap'
636
class TapProtoGRE(TunProtoGRE):
    """'gre' endpoint run in 'pl-gre-eth' mode instead of 'pl-gre-ip'."""

    def __init__(self, local, peer, home_path, key):
        super(TapProtoGRE, self).__init__(local, peer, home_path, key)
        # Replace the mode set by the parent initializer
        self.mode = 'pl-gre-eth'
641
# Implementation class for each protocol name, for TUN endpoints
TUN_PROTO_MAP = dict(
    tcp = TunProtoTCP,
    udp = TunProtoUDP,
    fd = TunProtoFD,
    gre = TunProtoGRE,
)

# Implementation class for each protocol name, for TAP endpoints
TAP_PROTO_MAP = dict(
    tcp = TapProtoTCP,
    udp = TapProtoUDP,
    fd = TapProtoFD,
    gre = TapProtoGRE,
)
655