python: Workaround UNIX socket path length limits
[sliver-openvswitch.git] / python / ovs / stream.py
index 16e383a..c640ebf 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2010 Nicira Networks
+# Copyright (c) 2010, 2011, 2012 Nicira, Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # limitations under the License.
 
 import errno
 # limitations under the License.
 
 import errno
-import logging
 import os
 import os
-import select
 import socket
 import socket
-import sys
 
 import ovs.poller
 import ovs.socket_util
 
 import ovs.poller
 import ovs.socket_util
+import ovs.vlog
+
+vlog = ovs.vlog.Vlog("stream")
+
+
+def stream_or_pstream_needs_probes(name):
+    """ 1 if the stream or pstream specified by 'name' needs periodic probes to
+    verify connectivity.  For [p]streams which need probes, it can take a long
+    time to notice the connection was dropped.  Returns 0 if probes aren't
+    needed, and -1 if 'name' is invalid"""
+
+    if PassiveStream.is_valid_name(name) or Stream.is_valid_name(name):
+        # Only unix and punix are supported currently.
+        return 0
+    else:
+        return -1
+
 
 class Stream(object):
     """Bidirectional byte stream.  Currently only Unix domain sockets
     are implemented."""
 
 class Stream(object):
     """Bidirectional byte stream.  Currently only Unix domain sockets
     are implemented."""
-    n_unix_sockets = 0
 
     # States.
     __S_CONNECTING = 0
 
     # States.
     __S_CONNECTING = 0
@@ -37,17 +50,29 @@ class Stream(object):
     W_RECV = 1                  # Data received.
     W_SEND = 2                  # Send buffer room available.
 
     W_RECV = 1                  # Data received.
     W_SEND = 2                  # Send buffer room available.
 
+    _SOCKET_METHODS = {}
+
+    @staticmethod
+    def register_method(method, cls):
+        Stream._SOCKET_METHODS[method + ":"] = cls
+
+    @staticmethod
+    def _find_method(name):
+        for method, cls in Stream._SOCKET_METHODS.items():
+            if name.startswith(method):
+                return cls
+        return None
+
     @staticmethod
     def is_valid_name(name):
         """Returns True if 'name' is a stream name in the form "TYPE:ARGS" and
     @staticmethod
     def is_valid_name(name):
         """Returns True if 'name' is a stream name in the form "TYPE:ARGS" and
-        TYPE is a supported stream type (currently only "unix:"), otherwise
-        False."""
-        return name.startswith("unix:")
+        TYPE is a supported stream type (currently only "unix:" and "tcp:"),
+        otherwise False."""
+        return bool(Stream._find_method(name))
 
 
-    def __init__(self, socket, name, bind_path, status):
+    def __init__(self, socket, name, status):
         self.socket = socket
         self.name = name
         self.socket = socket
         self.name = name
-        self.bind_path = bind_path
         if status == errno.EAGAIN:
             self.state = Stream.__S_CONNECTING
         elif status == 0:
         if status == errno.EAGAIN:
             self.state = Stream.__S_CONNECTING
         elif status == 0:
@@ -57,12 +82,18 @@ class Stream(object):
 
         self.error = 0
 
 
         self.error = 0
 
+    # Default value of dscp bits for connection between controller and manager.
+    # Value of IPTOS_PREC_INTERNETCONTROL = 0xc0 which is defined
+    # in <netinet/ip.h> is used.
+    IPTOS_PREC_INTERNETCONTROL = 0xc0
+    DSCP_DEFAULT = IPTOS_PREC_INTERNETCONTROL >> 2
+
     @staticmethod
     @staticmethod
-    def open(name):
+    def open(name, dscp=DSCP_DEFAULT):
         """Attempts to connect a stream to a remote peer.  'name' is a
         connection name in the form "TYPE:ARGS", where TYPE is an active stream
         class's name and ARGS are stream class-specific.  Currently the only
         """Attempts to connect a stream to a remote peer.  'name' is a
         connection name in the form "TYPE:ARGS", where TYPE is an active stream
         class's name and ARGS are stream class-specific.  Currently the only
-        supported TYPE is "unix".
+        supported TYPEs are "unix" and "tcp".
 
         Returns (error, stream): on success 'error' is 0 and 'stream' is the
         new Stream, on failure 'error' is a positive errno value and 'stream'
 
         Returns (error, stream): on success 'error' is 0 and 'stream' is the
         new Stream, on failure 'error' is a positive errno value and 'stream'
@@ -71,21 +102,21 @@ class Stream(object):
         Never returns errno.EAGAIN or errno.EINPROGRESS.  Instead, returns 0
         and a new Stream.  The connect() method can be used to check for
         successful connection completion."""
         Never returns errno.EAGAIN or errno.EINPROGRESS.  Instead, returns 0
         and a new Stream.  The connect() method can be used to check for
         successful connection completion."""
-        if not Stream.is_valid_name(name):
+        cls = Stream._find_method(name)
+        if not cls:
             return errno.EAFNOSUPPORT, None
 
             return errno.EAFNOSUPPORT, None
 
-        Stream.n_unix_sockets += 1
-        bind_path = "/tmp/stream-unix.%ld.%d" % (os.getpid(),
-                                                 Stream.n_unix_sockets)
-        connect_path = name[5:]
-        error, sock = ovs.socket_util.make_unix_socket(socket.SOCK_STREAM,
-                                                       True, bind_path,
-                                                       connect_path)
+        suffix = name.split(":", 1)[1]
+        error, sock = cls._open(suffix, dscp)
         if error:
             return error, None
         else:
             status = ovs.socket_util.check_connection_completion(sock)
         if error:
             return error, None
         else:
             status = ovs.socket_util.check_connection_completion(sock)
-            return 0, Stream(sock, name, bind_path, status)
+            return 0, Stream(sock, name, status)
+
+    @staticmethod
+    def _open(suffix, dscp):
+        raise NotImplementedError("This method must be overrided by subclass")
 
     @staticmethod
     def open_block((error, stream)):
 
     @staticmethod
     def open_block((error, stream)):
@@ -103,11 +134,11 @@ class Stream(object):
                     break
                 stream.run()
                 poller = ovs.poller.Poller()
                     break
                 stream.run()
                 poller = ovs.poller.Poller()
-                stream.run_wait()
+                stream.run_wait(poller)
                 stream.connect_wait(poller)
                 poller.block()
             assert error != errno.EINPROGRESS
                 stream.connect_wait(poller)
                 poller.block()
             assert error != errno.EINPROGRESS
-        
+
         if error and stream:
             stream.close()
             stream = None
         if error and stream:
             stream.close()
             stream = None
@@ -115,9 +146,6 @@ class Stream(object):
 
     def close(self):
         self.socket.close()
 
     def close(self):
         self.socket.close()
-        if self.bind_path is not None:
-            ovs.fatal_signal.unlink_file_now(self.bind_path)
-            self.bind_path = None
 
     def __scs_connecting(self):
         retval = ovs.socket_util.check_connection_completion(self.socket)
 
     def __scs_connecting(self):
         retval = ovs.socket_util.check_connection_completion(self.socket)
@@ -133,20 +161,22 @@ class Stream(object):
         is complete, returns 0 if the connection was successful or a positive
         errno value if it failed.  If the connection is still in progress,
         returns errno.EAGAIN."""
         is complete, returns 0 if the connection was successful or a positive
         errno value if it failed.  If the connection is still in progress,
         returns errno.EAGAIN."""
-        last_state = -1         # Always differs from initial self.state
-        while self.state != last_state:
-            last_state = self.state
-            if self.state == Stream.__S_CONNECTING:
-                self.__scs_connecting()
-            elif self.state == Stream.__S_CONNECTED:
-                return 0
-            elif self.state == Stream.__S_DISCONNECTED:
-                return self.error
+
+        if self.state == Stream.__S_CONNECTING:
+            self.__scs_connecting()
+
+        if self.state == Stream.__S_CONNECTING:
+            return errno.EAGAIN
+        elif self.state == Stream.__S_CONNECTED:
+            return 0
+        else:
+            assert self.state == Stream.__S_DISCONNECTED
+            return self.error
 
     def recv(self, n):
         """Tries to receive up to 'n' bytes from this stream.  Returns a
         (error, string) tuple:
 
     def recv(self, n):
         """Tries to receive up to 'n' bytes from this stream.  Returns a
         (error, string) tuple:
-        
+
             - If successful, 'error' is zero and 'string' contains between 1
               and 'n' bytes of data.
 
             - If successful, 'error' is zero and 'string' contains between 1
               and 'n' bytes of data.
 
@@ -154,7 +184,7 @@ class Stream(object):
 
             - If the connection has been closed in the normal fashion or if 'n'
               is 0, the tuple is (0, "").
 
             - If the connection has been closed in the normal fashion or if 'n'
               is 0, the tuple is (0, "").
-        
+
         The recv function will not block waiting for data to arrive.  If no
         data have been received, it returns (errno.EAGAIN, "") immediately."""
 
         The recv function will not block waiting for data to arrive.  If no
         data have been received, it returns (errno.EAGAIN, "") immediately."""
 
@@ -206,27 +236,25 @@ class Stream(object):
 
         if self.state == Stream.__S_CONNECTING:
             wait = Stream.W_CONNECT
 
         if self.state == Stream.__S_CONNECTING:
             wait = Stream.W_CONNECT
-        if wait in (Stream.W_CONNECT, Stream.W_SEND):
-            poller.fd_wait(self.socket, select.POLLOUT)
+        if wait == Stream.W_RECV:
+            poller.fd_wait(self.socket, ovs.poller.POLLIN)
         else:
         else:
-            poller.fd_wait(self.socket, select.POLLIN)
+            poller.fd_wait(self.socket, ovs.poller.POLLOUT)
 
     def connect_wait(self, poller):
         self.wait(poller, Stream.W_CONNECT)
 
     def connect_wait(self, poller):
         self.wait(poller, Stream.W_CONNECT)
-        
+
     def recv_wait(self, poller):
         self.wait(poller, Stream.W_RECV)
     def recv_wait(self, poller):
         self.wait(poller, Stream.W_RECV)
-        
+
     def send_wait(self, poller):
         self.wait(poller, Stream.W_SEND)
     def send_wait(self, poller):
         self.wait(poller, Stream.W_SEND)
-        
-    def get_name(self):
-        return self.name
-        
+
     def __del__(self):
         # Don't delete the file: we might have forked.
         self.socket.close()
 
     def __del__(self):
         # Don't delete the file: we might have forked.
         self.socket.close()
 
+
 class PassiveStream(object):
     @staticmethod
     def is_valid_name(name):
 class PassiveStream(object):
     @staticmethod
     def is_valid_name(name):
@@ -262,7 +290,7 @@ class PassiveStream(object):
         try:
             sock.listen(10)
         except socket.error, e:
         try:
             sock.listen(10)
         except socket.error, e:
-            logging.error("%s: listen: %s" % (name, os.strerror(e.error)))
+            vlog.err("%s: listen: %s" % (name, os.strerror(e.error)))
             sock.close()
             return e.error, None
 
             sock.close()
             return e.error, None
 
@@ -288,29 +316,47 @@ class PassiveStream(object):
             try:
                 sock, addr = self.socket.accept()
                 ovs.socket_util.set_nonblocking(sock)
             try:
                 sock, addr = self.socket.accept()
                 ovs.socket_util.set_nonblocking(sock)
-                return 0, Stream(sock, "unix:%s" % addr, None, 0)
+                return 0, Stream(sock, "unix:%s" % addr, 0)
             except socket.error, e:
                 error = ovs.socket_util.get_exception_errno(e)
                 if error != errno.EAGAIN:
                     # XXX rate-limit
             except socket.error, e:
                 error = ovs.socket_util.get_exception_errno(e)
                 if error != errno.EAGAIN:
                     # XXX rate-limit
-                    logging.debug("accept: %s" % os.strerror(error))
+                    vlog.dbg("accept: %s" % os.strerror(error))
                 return error, None
 
     def wait(self, poller):
                 return error, None
 
     def wait(self, poller):
-        poller.fd_wait(self.socket, select.POLLIN)
+        poller.fd_wait(self.socket, ovs.poller.POLLIN)
 
     def __del__(self):
         # Don't delete the file: we might have forked.
         self.socket.close()
 
 
     def __del__(self):
         # Don't delete the file: we might have forked.
         self.socket.close()
 
-def usage(name, active, passive, bootstrap):
-    print
-    if active:
-        print("Active %s connection methods:" % name)
-        print("  unix:FILE               "
-               "Unix domain socket named FILE");
-
-    if passive:
-        print("Passive %s connection methods:" % name)
-        print("  punix:FILE              "
-              "listen on Unix domain socket FILE")
+
+def usage(name):
+    return """
+Active %s connection methods:
+  unix:FILE               Unix domain socket named FILE
+  tcp:IP:PORT             TCP socket to IP with port no of PORT
+
+Passive %s connection methods:
+  punix:FILE              Listen on Unix domain socket FILE""" % (name, name)
+
+
+class UnixStream(Stream):
+    @staticmethod
+    def _open(suffix, dscp):
+        connect_path = suffix
+        return  ovs.socket_util.make_unix_socket(socket.SOCK_STREAM,
+                                                 True, None, connect_path)
+Stream.register_method("unix", UnixStream)
+
+
+class TCPStream(Stream):
+    @staticmethod
+    def _open(suffix, dscp):
+        error, sock = ovs.socket_util.inet_open_active(socket.SOCK_STREAM,
+                                                       suffix, 0, dscp)
+        if not error:
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        return error, sock
+Stream.register_method("tcp", TCPStream)