#
# NEPI, a framework to manage network experiments
# Copyright (C) 2013 INRIA
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation;
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
# Author: Alina Quereilhac
# Claudio Freire
#
from __future__ import print_function
import select
import sys
import os
import struct
import socket
import threading
import traceback
import errno
import fcntl
import random
import traceback
import functools
import collections
import ctypes
import time
def ipfmt(ip):
ipbytes = map(ord,ip.decode("hex"))
return '.'.join(map(str,ipbytes))
tagtype = {
'0806' : 'arp',
'0800' : 'ipv4',
'8870' : 'jumbo',
'8863' : 'PPPoE discover',
'8864' : 'PPPoE',
'86dd' : 'ipv6',
}
def etherProto(packet, len=len):
if len(packet) > 14:
if packet[12] == "\x81" and packet[13] == "\x00":
# tagged
return packet[16:18]
else:
# untagged
return packet[12:14]
# default: ip
return "\x08\x00"
def formatPacket(packet, ether_mode):
if ether_mode:
stripped_packet = etherStrip(packet)
if not stripped_packet:
packet = packet.encode("hex")
if len(packet) < 28:
return "malformed eth " + packet.encode("hex")
else:
if packet[24:28] == "8100":
# tagged
ethertype = tagtype.get(packet[32:36], 'eth')
return ethertype + " " + ( '-'.join( (
packet[0:12], # MAC dest
packet[12:24], # MAC src
packet[24:32], # VLAN tag
packet[32:36], # Ethertype/len
packet[36:], # Payload
) ) )
else:
# untagged
ethertype = tagtype.get(packet[24:28], 'eth')
return ethertype + " " + ( '-'.join( (
packet[0:12], # MAC dest
packet[12:24], # MAC src
packet[24:28], # Ethertype/len
packet[28:], # Payload
) ) )
else:
packet = stripped_packet
packet = packet.encode("hex")
if len(packet) < 48:
return "malformed ip " + packet
else:
return "ip " + ( '-'.join( (
packet[0:1], #version
packet[1:2], #header length
packet[2:4], #diffserv/ECN
packet[4:8], #total length
packet[8:12], #ident
packet[12:16], #flags/fragment offs
packet[16:18], #ttl
packet[18:20], #ip-proto
packet[20:24], #checksum
ipfmt(packet[24:32]), # src-ip
ipfmt(packet[32:40]), # dst-ip
packet[40:48] if (int(packet[1],16) > 5) else "", # options
packet[48:] if (int(packet[1],16) > 5) else packet[40:], # payload
) ) )
def _packetReady(buf, ether_mode=False, len=len, str=str):
if not buf:
return False
rv = False
while not rv:
if len(buf[0]) < 4:
rv = False
elif ether_mode:
rv = True
else:
_,totallen = struct.unpack('HH',buf[0][:4])
totallen = socket.htons(totallen)
rv = len(buf[0]) >= totallen
if not rv and len(buf) > 1:
# collapse only first two buffers
# as needed, to mantain len(buf) meaningful
p1 = buf.popleft()
buf[0] = p1+str(buf[0])
else:
return rv
return rv
def _pullPacket(buf, ether_mode=False, len=len, buffer=buffer):
if ether_mode:
return buf.popleft()
else:
_,totallen = struct.unpack('HH',buf[0][:4])
totallen = socket.htons(totallen)
if len(buf[0]) > totallen:
rv = buffer(buf[0],0,totallen)
buf[0] = buffer(buf[0],totallen)
else:
rv = buf.popleft()
return rv
def etherStrip(buf, buffer=buffer, len=len):
if len(buf) < 14:
return ""
if buf[12:14] == '\x08\x10' and buf[16:18] == '\x08\x00':
# tagged ethernet frame
return buffer(buf, 18)
elif buf[12:14] == '\x08\x00':
# untagged ethernet frame
return buffer(buf, 14)
else:
return ""
def etherWrap(packet):
return ''.join((
"\x00"*6*2 # bogus src and dst mac
+"\x08\x00", # IPv4
packet, # payload
"\x00"*4, # bogus crc
))
def piStrip(buf, len=len):
if len(buf) < 4:
return buf
else:
return buffer(buf,4)
def piWrap(buf, ether_mode, etherProto=etherProto):
if ether_mode:
proto = etherProto(buf)
else:
proto = "\x08\x00"
return ''.join((
"\x00\x00", # PI: 16 bits flags
proto, # 16 bits proto
buf,
))
_padmap = [ chr(padding) * padding for padding in xrange(127) ]
del padding
def encrypt(packet, crypter, len=len, padmap=_padmap):
# pad
padding = crypter.block_size - len(packet) % crypter.block_size
packet += padmap[padding]
# encrypt
return crypter.encrypt(packet)
def decrypt(packet, crypter, ord=ord):
if packet:
# decrypt
packet = crypter.decrypt(packet)
# un-pad
padding = ord(packet[-1])
if not (0 < padding <= crypter.block_size):
# wrong padding
raise RuntimeError("Truncated packet %s")
packet = packet[:-padding]
return packet
def nonblock(fd):
try:
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fl |= os.O_NONBLOCK
fcntl.fcntl(fd, fcntl.F_SETFL, fl)
return True
except:
traceback.print_exc(file=sys.stderr)
# Just ignore
return False
def tun_fwd(tun, remote, with_pi, ether_mode, cipher_key, udp, TERMINATE, SUSPEND,
stderr = sys.stderr, reconnect = None, rwrite = None, rread = None,
tunqueue = 1000, tunkqueue = 1000, cipher = 'AES', accept_local = None,
accept_remote = None, slowlocal = True, queueclass = None,
bwlimit = None, len = len, max = max, min = min, buffer = buffer,
OSError = OSError, select = select.select, selecterror = select.error,
os = os, socket = socket,
retrycodes=(os.errno.EWOULDBLOCK, os.errno.EAGAIN, os.errno.EINTR) ):
crypto_mode = False
crypter = None
try:
if cipher_key and cipher:
import Crypto.Cipher
import hashlib
__import__('Crypto.Cipher.'+cipher)
ciphername = cipher
cipher = getattr(Crypto.Cipher, cipher)
hashed_key = hashlib.sha256(cipher_key).digest()
if ciphername == 'AES':
hashed_key = hashed_key[:16]
elif ciphername == 'Blowfish':
hashed_key = hashed_key[:24]
elif ciphername == 'DES':
hashed_key = hashed_key[:8]
elif ciphername == 'DES3':
hashed_key = hashed_key[:24]
crypter = cipher.new(
hashed_key,
cipher.MODE_ECB)
crypto_mode = True
except:
# We don't want decription to work only on one side,
# This could break things really bad
#crypto_mode = False
#crypter = None
traceback.print_exc(file=sys.stderr)
raise
if stderr is not None:
if crypto_mode:
print("Packets are transmitted in CIPHER", file=stderr)
else:
print("Packets are transmitted in PLAINTEXT", file=stderr)
if hasattr(remote, 'fileno'):
remote_fd = remote.fileno()
if rwrite is None:
def rwrite(remote, packet, os_write=os.write):
return os_write(remote_fd, packet)
if rread is None:
def rread(remote, maxlen, os_read=os.read):
return os_read(remote_fd, maxlen)
rnonblock = nonblock(remote)
tnonblock = nonblock(tun)
# Pick up TUN/TAP writing method
if with_pi:
try:
import iovec
# We have iovec, so we can skip PI injection
# and use iovec which does it natively
if ether_mode:
twrite = iovec.ethpiwrite
tread = iovec.piread2
else:
twrite = iovec.ippiwrite
tread = iovec.piread2
except ImportError:
# We have to inject PI headers pythonically
def twrite(fd, packet, oswrite=os.write, piWrap=piWrap, ether_mode=ether_mode):
return oswrite(fd, piWrap(packet, ether_mode))
# For reading, we strip PI headers with buffer slicing and that's it
def tread(fd, maxlen, osread=os.read, piStrip=piStrip):
return piStrip(osread(fd, maxlen))
else:
# No need to inject PI headers
twrite = os.write
tread = os.read
encrypt_ = encrypt
decrypt_ = decrypt
xrange_ = xrange
if accept_local is not None:
def tread(fd, maxlen, _tread=tread, accept=accept_local):
packet = _tread(fd, maxlen)
if accept(packet, 0):
return packet
else:
return None
if accept_remote is not None:
if crypto_mode:
def decrypt_(packet, crypter, decrypt_=decrypt_, accept=accept_remote):
packet = decrypt_(packet, crypter)
if accept(packet, 1):
return packet
else:
return None
else:
def rread(fd, maxlen, _rread=rread, accept=accept_remote):
packet = _rread(fd, maxlen)
if accept(packet, 1):
return packet
else:
return None
maxbkbuf = maxfwbuf = max(10,tunqueue-tunkqueue)
tunhurry = max(0,maxbkbuf/2)
if queueclass is None:
queueclass = collections.deque
maxbatch = 2000
maxtbatch = 50
else:
maxfwbuf = maxbkbuf = 2000000000
maxbatch = 50
maxtbatch = 30
tunhurry = 30
fwbuf = queueclass()
bkbuf = queueclass()
nfwbuf = 0
nbkbuf = 0
# backwards queue functions
# they may need packet inspection to
# reconstruct packet boundaries
if ether_mode or udp:
packetReady = bool
pullPacket = queueclass.popleft
reschedule = queueclass.appendleft
else:
packetReady = _packetReady
pullPacket = _pullPacket
reschedule = queueclass.appendleft
# forward queue functions
# no packet inspection needed
fpacketReady = bool
fpullPacket = queueclass.popleft
freschedule = queueclass.appendleft
tunfd = tun.fileno()
os_read = os.read
os_write = os.write
tget = time.time
maxbwfree = bwfree = 1500 * tunqueue
lastbwtime = tget()
remoteok = True
while not TERMINATE:
# The SUSPEND flag has been set. This means we need to wait on
# the SUSPEND condition until it is released.
while SUSPEND and not TERMINATE:
time.sleep(0.5)
wset = []
if packetReady(bkbuf):
wset.append(tun)
if remoteok and fpacketReady(fwbuf) and (not bwlimit or bwfree > 0):
wset.append(remote)
rset = []
if len(fwbuf) < maxfwbuf:
rset.append(tun)
if remoteok and len(bkbuf) < maxbkbuf:
rset.append(remote)
if remoteok:
eset = (tun,remote)
else:
eset = (tun,)
try:
rdrdy, wrdy, errs = select(rset,wset,eset,1)
except selecterror as e:
if e.args[0] == errno.EINTR:
# just retry
continue
else:
traceback.print_exc(file=sys.stderr)
# If the SUSPEND flag has been set, then the TUN will be in a bad
# state and the select error should be ignores.
if SUSPEND:
continue
else:
raise
# check for errors
if errs:
if reconnect is not None and remote in errs and tun not in errs:
remote = reconnect()
if hasattr(remote, 'fileno'):
remote_fd = remote.fileno()
elif udp and remote in errs and tun not in errs:
# In UDP mode, those are always transient errors
# Usually, an error will imply a read-ready socket
# that will raise an "Connection refused" error, so
# disable read-readiness just for now, and retry
# the select
remoteok = False
continue
else:
break
else:
remoteok = True
# check to see if we can write
#rr = wr = rt = wt = 0
if remote in wrdy:
sent = 0
try:
try:
for x in xrange(maxbatch):
packet = pullPacket(fwbuf)
if crypto_mode:
packet = encrypt_(packet, crypter)
sentnow = rwrite(remote, packet)
sent += sentnow
#wr += 1
if not udp and 0 <= sentnow < len(packet):
# packet partially sent
# reschedule the remaining part
# this doesn't happen ever in udp mode
freschedule(fwbuf, buffer(packet,sentnow))
if not rnonblock or not fpacketReady(fwbuf):
break
except OSError as e:
# This except handles the entire While block on PURPOSE
# as an optimization (setting a try/except block is expensive)
# The only operation that can raise this exception is rwrite
if e.errno in retrycodes:
# re-schedule packet
freschedule(fwbuf, packet)
else:
raise
except:
if reconnect is not None:
# in UDP mode, sometimes connected sockets can return a connection refused.
# Give the caller a chance to reconnect
remote = reconnect()
if hasattr(remote, 'fileno'):
remote_fd = remote.fileno()
elif not udp:
# in UDP mode, we ignore errors - packet loss man...
raise
#traceback.print_exc(file=sys.stderr)
if bwlimit:
bwfree -= sent
if tun in wrdy:
try:
for x in xrange(maxtbatch):
packet = pullPacket(bkbuf)
twrite(tunfd, packet)
#wt += 1
# Do not inject packets into the TUN faster than they arrive, unless we're falling
# behind. TUN devices discard packets if their queue is full (tunkqueue), but they
# don't block either (they're always ready to write), so if we flood the device
# we'll have high packet loss.
if not tnonblock or (slowlocal and len(bkbuf) < tunhurry) or not packetReady(bkbuf):
break
else:
if slowlocal:
# Give some time for the kernel to process the packets
time.sleep(0)
except OSError as e:
# This except handles the entire While block on PURPOSE
# as an optimization (setting a try/except block is expensive)
# The only operation that can raise this exception is os_write
if e.errno in retrycodes:
# re-schedule packet
reschedule(bkbuf, packet)
else:
raise
# check incoming data packets
if tun in rdrdy:
try:
for x in xrange(maxbatch):
packet = tread(tunfd,2000) # tun.read blocks until it gets 2k!
if not packet:
continue
#rt += 1
fwbuf.append(packet)
if not tnonblock or len(fwbuf) >= maxfwbuf:
break
except OSError as e:
# This except handles the entire While block on PURPOSE
# as an optimization (setting a try/except block is expensive)
# The only operation that can raise this exception is os_read
if e.errno not in retrycodes:
raise
if remote in rdrdy:
try:
try:
for x in xrange(maxbatch):
packet = rread(remote,2000)
#rr += 1
if crypto_mode:
packet = decrypt_(packet, crypter)
if not packet:
continue
elif not packet:
if not udp and packet == "":
# Connection broken, try to reconnect (or just die)
raise RuntimeError("Connection broken")
else:
continue
bkbuf.append(packet)
if not rnonblock or len(bkbuf) >= maxbkbuf:
break
except OSError as e:
# This except handles the entire While block on PURPOSE
# as an optimization (setting a try/except block is expensive)
# The only operation that can raise this exception is rread
if e.errno not in retrycodes:
raise
except Exception as e:
if reconnect is not None:
# in UDP mode, sometimes connected sockets can return a connection refused
# on read. Give the caller a chance to reconnect
remote = reconnect()
if hasattr(remote, 'fileno'):
remote_fd = remote.fileno()
elif not udp:
# in UDP mode, we ignore errors - packet loss man...
raise
traceback.print_exc(file=sys.stderr)
if bwlimit:
tnow = tget()
delta = tnow - lastbwtime
if delta > 0.001:
delta = int(bwlimit * delta)
if delta > 0:
bwfree = min(bwfree+delta, maxbwfree)
lastbwtime = tnow
#print >>sys.stderr, "rr:%d\twr:%d\trt:%d\twt:%d" % (rr,wr,rt,wt)
def udp_connect(TERMINATE, local_addr, local_port, peer_addr, peer_port):
rsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
retrydelay = 1.0
for i in xrange(30):
# TERMINATE is a array. An item can be added to TERMINATE, from
# outside this function to force termination of the loop
if TERMINATE:
raise OSError("Killed")
try:
rsock.bind((local_addr, local_port))
break
except socket.error:
# wait a while, retry
print("%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),), file=sys.stderr)
time.sleep(min(30.0,retrydelay))
retrydelay *= 1.1
else:
rsock.bind((local_addr, local_port))
print("Listening UDP at: %s:%d" % (local_addr, local_port), file=sys.stderr)
print("Connecting UDP to: %s:%d" % (peer_addr, peer_port), file=sys.stderr)
rsock.connect((peer_addr, peer_port))
return rsock
def udp_handshake(TERMINATE, rsock):
endme = False
def keepalive():
while not endme and not TERMINATE:
try:
rsock.send('')
except:
pass
time.sleep(1)
try:
rsock.send('')
except:
pass
keepalive_thread = threading.Thread(target=keepalive)
keepalive_thread.start()
for i in xrange(900):
if TERMINATE:
raise OSError("Killed")
try:
heartbeat = rsock.recv(10)
break
except:
time.sleep(1)
else:
heartbeat = rsock.recv(10)
endme = True
keepalive_thread.join()
def udp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
rsock = udp_connect(TERMINATE, local_addr, local_port, peer_addr,
peer_port)
udp_handshake(TERMINATE, rsock)
return rsock
def tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port):
sock = None
retrydelay = 1.0
# The peer has a firewall that prevents a response to the connect, we
# will be forever blocked in the connect, so we put a reasonable timeout.
rsock.settimeout(10)
# We wait for
for i in xrange(30):
if stop:
break
if TERMINATE:
raise OSError("Killed")
try:
rsock.connect((peer_addr, peer_port))
sock = rsock
break
except socket.error:
# wait a while, retry
print("%s: Could not connect. Retrying in a sec..." % (time.strftime('%c'),), file=sys.stderr)
time.sleep(min(30.0,retrydelay))
retrydelay *= 1.1
else:
rsock.connect((peer_addr, peer_port))
sock = rsock
if sock:
print("tcp_connect: TCP sock connected to remote %s:%s" % (peer_addr, peer_port), file=sys.stderr)
sock.settimeout(0)
print("tcp_connect: disabling NAGLE", file=sys.stderr)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return sock
def tcp_listen(TERMINATE, stop, lsock, local_addr, local_port):
sock = None
retrydelay = 1.0
# We try to bind to the local virtual interface.
# It might not exist yet so we wait in a loop.
for i in xrange(30):
if stop:
break
if TERMINATE:
raise OSError("Killed")
try:
lsock.bind((local_addr, local_port))
break
except socket.error:
# wait a while, retry
print("%s: Could not bind. Retrying in a sec..." % (time.strftime('%c'),), file=sys.stderr)
time.sleep(min(30.0,retrydelay))
retrydelay *= 1.1
else:
lsock.bind((local_addr, local_port))
print("tcp_listen: TCP sock listening in local sock %s:%s" % (local_addr, local_port), file=sys.stderr)
# Now we wait until the other side connects.
# The other side might not be ready yet, so we also wait in a loop for timeouts.
timeout = 1
lsock.listen(1)
for i in xrange(30):
if TERMINATE:
raise OSError("Killed")
rlist, wlist, xlist = select.select([lsock], [], [], timeout)
if stop:
break
if lsock in rlist:
sock,raddr = lsock.accept()
print("tcp_listen: TCP connection accepted in local sock %s:%s" % (local_addr, local_port), file=sys.stderr)
break
timeout += 5
return sock
def tcp_handshake(rsock, listen, hand):
# we are going to use a barrier algorithm to decide wich side listen.
# each side will "roll a dice" and send the resulting value to the other
# side.
win = False
rsock.settimeout(10)
try:
rsock.send(hand)
peer_hand = rsock.recv(4)
if not peer_hand:
print("tcp_handshake: connection reset by peer", file=sys.stderr)
return False
else:
print("tcp_handshake: hand %r, peer_hand %r" % (hand, peer_hand), file=sys.stderr)
if hand < peer_hand:
if listen:
win = True
elif hand > peer_hand:
if not listen:
win = True
finally:
rsock.settimeout(0)
return win
def tcp_establish(TERMINATE, local_addr, local_port, peer_addr, peer_port):
def listen(stop, hand, lsock, lresult):
win = False
rsock = tcp_listen(TERMINATE, stop, lsock, local_addr, local_port)
if rsock:
win = tcp_handshake(rsock, True, hand)
stop.append(True)
lresult.append((win, rsock))
def connect(stop, hand, rsock, rresult):
win = False
rsock = tcp_connect(TERMINATE, stop, rsock, peer_addr, peer_port)
if rsock:
win = tcp_handshake(rsock, False, hand)
stop.append(True)
rresult.append((win, rsock))
end = False
sock = None
for i in xrange(0, 50):
if end:
break
if TERMINATE:
raise OSError("Killed")
hand = struct.pack("!L", random.randint(0, 2**30))
stop = []
lresult = []
rresult = []
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
listen_thread = threading.Thread(target=listen, args=(stop, hand, lsock, lresult))
connect_thread = threading.Thread(target=connect, args=(stop, hand, rsock, rresult))
connect_thread.start()
listen_thread.start()
connect_thread.join()
listen_thread.join()
(lwin, lrsock) = lresult[0]
(rwin, rrsock) = rresult[0]
if not lrsock or not rrsock:
if not lrsock:
sock = rrsock
if not rrsock:
sock = lrsock
end = True
# both socket are connected
else:
if lwin:
sock = lrsock
end = True
elif rwin:
sock = rrsock
end = True
if not sock:
raise OSError("Error: tcp_establish could not establish connection.")
return sock