Add the ability to connect to a vconn asynchronously.
[sliver-openvswitch.git] / lib / vconn-ssl.c
index 199c987..ed2e035 100644 (file)
@@ -33,7 +33,9 @@
 #include "socket-util.h"
 #include "util.h"
 #include "openflow.h"
+#include "poll-loop.h"
 #include "ofp-print.h"
+#include "socket-util.h"
 #include "vconn.h"
 
 #include "vlog.h"
@@ -42,8 +44,8 @@
 /* Active SSL. */
 
 enum ssl_state {
-    STATE_SSL_CONNECTING,
-    STATE_CONNECTED
+    STATE_TCP_CONNECTING,
+    STATE_SSL_CONNECTING
 };
 
 enum session_type {
@@ -61,6 +63,7 @@ struct ssl_vconn
     SSL *ssl;
     struct buffer *rxbuf;
     struct buffer *txbuf;
+    struct poll_waiter *tx_waiter;
 };
 
 /* SSL context created by ssl_init(). */
@@ -71,15 +74,15 @@ static bool has_private_key, has_certificate, has_ca_cert;
 
 static int ssl_init(void);
 static int do_ssl_init(void);
-static void connect_completed(struct ssl_vconn *, int error);
 static bool ssl_wants_io(int ssl_error);
 static void ssl_close(struct vconn *);
-static bool state_machine(struct ssl_vconn *sslv);
+static int interpret_ssl_error(const char *function, int ret, int error);
+static void ssl_do_tx(int fd, short int revents, void *vconn_);
 static DH *tmp_dh_callback(SSL *ssl, int is_export UNUSED, int keylength);
 
 static int
 new_ssl_vconn(const char *name, int fd, enum session_type type,
-              struct vconn **vconnp)
+              enum ssl_state state, struct vconn **vconnp)
 {
     struct ssl_vconn *sslv;
     SSL *ssl = NULL;
@@ -104,13 +107,7 @@ new_ssl_vconn(const char *name, int fd, enum session_type type,
         goto error;
     }
 
-    /* Make 'fd' non-blocking and disable Nagle. */
-    retval = set_nonblocking(fd);
-    if (retval) {
-        VLOG_ERR("%s: set_nonblocking: %s", name, strerror(retval));
-        close(fd);
-        return retval;
-    }
+    /* Disable Nagle. */
     retval = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof on);
     if (retval) {
         VLOG_ERR("%s: setsockopt(TCP_NODELAY): %s", name, strerror(errno));
@@ -133,12 +130,14 @@ new_ssl_vconn(const char *name, int fd, enum session_type type,
     /* Create and return the ssl_vconn. */
     sslv = xmalloc(sizeof *sslv);
     sslv->vconn.class = &ssl_vconn_class;
-    sslv->state = STATE_SSL_CONNECTING;
+    sslv->vconn.connect_status = EAGAIN;
+    sslv->state = state;
     sslv->type = type;
     sslv->fd = fd;
     sslv->ssl = ssl;
     sslv->rxbuf = NULL;
     sslv->txbuf = NULL;
+    sslv->tx_waiter = NULL;
     *vconnp = &sslv->vconn;
     return 0;
 
@@ -194,95 +193,74 @@ ssl_open(const char *name, char *suffix, struct vconn **vconnp)
         VLOG_ERR("%s: socket: %s", name, strerror(errno));
         return errno;
     }
+    retval = set_nonblocking(fd);
+    if (retval) {
+        close(fd);
+        return retval;
+    }
 
-    /* Connect socket (blocking). */
+    /* Connect socket. */
     retval = connect(fd, (struct sockaddr *) &sin, sizeof sin);
     if (retval < 0) {
-        int error = errno;
-        VLOG_ERR("%s: connect: %s", name, strerror(error));
-        close(fd);
-        return error;
+        if (errno == EINPROGRESS) {
+            return new_ssl_vconn(name, fd, CLIENT, STATE_TCP_CONNECTING,
+                                 vconnp);
+        } else {
+            int error = errno;
+            VLOG_ERR("%s: connect: %s", name, strerror(error));
+            close(fd);
+            return error;
+        }
+    } else {
+        return new_ssl_vconn(name, fd, CLIENT, STATE_SSL_CONNECTING,
+                             vconnp);
     }
-
-    /* Make an ssl_vconn for the socket. */
-    return new_ssl_vconn(name, fd, CLIENT, vconnp);
 }
 
-static void
-ssl_close(struct vconn *vconn)
+static int
+ssl_connect(struct vconn *vconn)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    SSL_free(sslv->ssl);
-    close(sslv->fd);
-    free(sslv);
-}
+    int retval;
 
-static bool
-ssl_want_io_to_events(SSL *ssl, short int *events)
-{
-    if (SSL_want_read(ssl)) {
-        *events |= POLLIN;
-        return true;
-    } else if (SSL_want_write(ssl)) {
-        *events |= POLLOUT;
-        return true;
-    } else {
-        return false;
-    }
-}
+    switch (sslv->state) {
+    case STATE_TCP_CONNECTING:
+        retval = check_connection_completion(sslv->fd);
+        if (retval) {
+            return retval;
+        }
+        sslv->state = STATE_SSL_CONNECTING;
+        /* Fall through. */
 
-static bool
-ssl_prepoll(struct vconn *vconn, int want, struct pollfd *pfd)
-{
-    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    pfd->fd = sslv->fd;
-    if (!state_machine(sslv)) {
-        switch (sslv->state) {
-        case STATE_SSL_CONNECTING:
-            if (!ssl_want_io_to_events(sslv->ssl, &pfd->events)) {
-                /* state_machine() should have transitioned us away to another
-                 * state. */
-                NOT_REACHED();
+    case STATE_SSL_CONNECTING:
+        retval = (sslv->type == CLIENT
+                   ? SSL_connect(sslv->ssl) : SSL_accept(sslv->ssl));
+        if (retval != 1) {
+            int error = SSL_get_error(sslv->ssl, retval);
+            if (retval < 0 && ssl_wants_io(error)) {
+                return EAGAIN;
+            } else {
+                interpret_ssl_error((sslv->type == CLIENT ? "SSL_connect"
+                                     : "SSL_accept"), retval, error);
+                shutdown(sslv->fd, SHUT_RDWR);
+                return EPROTO;
             }
-            break;
-        default:
-            NOT_REACHED();
-        }
-    } else if (sslv->connect_error) {
-        pfd->events = 0;
-        return true;
-    } else if (!ssl_want_io_to_events(sslv->ssl, &pfd->events)) {
-        if (want & WANT_RECV) {
-            pfd->events |= POLLIN;
-        }
-        if (want & WANT_SEND || sslv->txbuf) {
-            pfd->events |= POLLOUT;
+        } else {
+            return 0;
         }
     }
-    return false;
+
+    NOT_REACHED();
 }
 
 static void
-ssl_postpoll(struct vconn *vconn, short int *revents)
+ssl_close(struct vconn *vconn)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
-    if (!state_machine(sslv)) {
-        *revents = 0;
-    } else if (sslv->connect_error) {
-        *revents |= POLLERR;
-    } else if (*revents & POLLOUT && sslv->txbuf) {
-        ssize_t n = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
-        if (n > 0) {
-            buffer_pull(sslv->txbuf, n);
-            if (sslv->txbuf->size == 0) {
-                buffer_delete(sslv->txbuf);
-                sslv->txbuf = NULL;
-            }
-        }
-        if (sslv->txbuf) {
-            *revents &= ~POLLOUT;
-        }
-    }
+    poll_cancel(sslv->tx_waiter);
+    SSL_free(sslv->ssl);
+    close(sslv->fd);
+    free(sslv);
 }
 
 static int
@@ -355,12 +333,6 @@ ssl_recv(struct vconn *vconn, struct buffer **bufferp)
     size_t want_bytes;
     ssize_t ret;
 
-    if (!state_machine(sslv)) {
-        return EAGAIN;
-    } else if (sslv->connect_error) {
-        return sslv->connect_error;
-    }
-
     if (sslv->rxbuf == NULL) {
         sslv->rxbuf = buffer_new(1564);
     }
@@ -412,18 +384,53 @@ again:
     }
 }
 
+static void
+ssl_clear_txbuf(struct ssl_vconn *sslv)
+{
+    buffer_delete(sslv->txbuf);
+    sslv->txbuf = NULL;
+    sslv->tx_waiter = NULL;
+}
+
+static void
+ssl_register_tx_waiter(struct vconn *vconn) 
+{
+    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
+    short int events = SSL_want_read(sslv->ssl) ? POLLIN : POLLOUT;
+    sslv->tx_waiter = poll_fd_callback(sslv->fd, events, ssl_do_tx, vconn);
+}
+
+static void
+ssl_do_tx(int fd UNUSED, short int revents UNUSED, void *vconn_)
+{
+    struct vconn *vconn = vconn_;
+    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
+    int ret = SSL_write(sslv->ssl, sslv->txbuf->data, sslv->txbuf->size);
+    if (ret > 0) {
+        buffer_pull(sslv->txbuf, ret);
+        if (sslv->txbuf->size == 0) {
+            ssl_clear_txbuf(sslv);
+            return;
+        }
+    } else {
+        int error = SSL_get_error(sslv->ssl, ret);
+        if (error == SSL_ERROR_ZERO_RETURN) {
+            /* Connection closed (EOF). */
+            VLOG_WARN("SSL_write: connection close");
+        } else if (interpret_ssl_error("SSL_write", ret, error) != EAGAIN) {
+            ssl_clear_txbuf(sslv);
+            return;
+        }
+    }
+    ssl_register_tx_waiter(vconn);
+}
+
 static int
 ssl_send(struct vconn *vconn, struct buffer *buffer)
 {
     struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
     ssize_t ret;
 
-    if (!state_machine(sslv)) {
-        return EAGAIN;
-    } else if (sslv->connect_error) {
-        return sslv->connect_error;
-    }
-
     if (sslv->txbuf) {
         return EAGAIN;
     }
@@ -435,6 +442,7 @@ ssl_send(struct vconn *vconn, struct buffer *buffer)
         } else {
             sslv->txbuf = buffer;
             buffer_pull(buffer, ret);
+            ssl_register_tx_waiter(vconn);
         }
         return 0;
     } else {
@@ -449,14 +457,65 @@ ssl_send(struct vconn *vconn, struct buffer *buffer)
     }
 }
 
+static bool
+ssl_needs_wait(struct ssl_vconn *sslv) 
+{
+    if (SSL_want_read(sslv->ssl)) {
+        poll_fd_wait(sslv->fd, POLLIN, NULL);
+        return true;
+    } else if (SSL_want_write(sslv->ssl)) {
+        poll_fd_wait(sslv->fd, POLLOUT, NULL);
+        return true;
+    } else {
+        return false;
+    }
+}
+
+static void
+ssl_wait(struct vconn *vconn, enum vconn_wait_type wait)
+{
+    struct ssl_vconn *sslv = ssl_vconn_cast(vconn);
+
+    switch (wait) {
+    case WAIT_CONNECT:
+        if (vconn_connect(vconn) != EAGAIN) {
+            poll_immediate_wake();
+        } else if (sslv->state == STATE_TCP_CONNECTING) {
+            poll_fd_wait(sslv->fd, POLLOUT, NULL);
+        } else if (!ssl_needs_wait(sslv)) {
+            NOT_REACHED();
+        }
+        break;
+
+    case WAIT_RECV:
+        if (!ssl_needs_wait(sslv)) {
+            if (SSL_pending(sslv->ssl)) {
+                poll_immediate_wake();
+            } else {
+                poll_fd_wait(sslv->fd, POLLIN, NULL);
+            }
+        }
+        break;
+
+    case WAIT_SEND:
+        if (!sslv->txbuf && !ssl_needs_wait(sslv)) {
+            poll_fd_wait(sslv->fd, POLLOUT, NULL);
+        }
+        break;
+
+    default:
+        NOT_REACHED();
+    }
+}
+
 struct vconn_class ssl_vconn_class = {
     .name = "ssl",
     .open = ssl_open,
     .close = ssl_close,
-    .prepoll = ssl_prepoll,
-    .postpoll = ssl_postpoll,
+    .connect = ssl_connect,
     .recv = ssl_recv,
     .send = ssl_send,
+    .wait = ssl_wait,
 };
 \f
 /* Passive SSL. */
@@ -524,13 +583,13 @@ pssl_open(const char *name, char *suffix, struct vconn **vconnp)
 
     retval = set_nonblocking(fd);
     if (retval) {
-        VLOG_ERR("%s: set_nonblocking: %s", name, strerror(retval));
         close(fd);
         return retval;
     }
 
     pssl = xmalloc(sizeof *pssl);
     pssl->vconn.class = &pssl_vconn_class;
+    pssl->vconn.connect_status = 0;
     pssl->fd = fd;
     *vconnp = &pssl->vconn;
     return 0;
@@ -544,41 +603,46 @@ pssl_close(struct vconn *vconn)
     free(pssl);
 }
 
-static bool
-pssl_prepoll(struct vconn *vconn, int want, struct pollfd *pfd)
-{
-    struct pssl_vconn *pssl = pssl_vconn_cast(vconn);
-    pfd->fd = pssl->fd;
-    if (want & WANT_ACCEPT) {
-        pfd->events |= POLLIN;
-    }
-    return false;
-}
-
 static int
 pssl_accept(struct vconn *vconn, struct vconn **new_vconnp)
 {
     struct pssl_vconn *pssl = pssl_vconn_cast(vconn);
     int new_fd;
+    int error;
 
     new_fd = accept(pssl->fd, NULL, NULL);
     if (new_fd < 0) {
         int error = errno;
         if (error != EAGAIN) {
-            VLOG_DBG("pssl: accept: %s", strerror(error));
+            VLOG_DBG("accept: %s", strerror(error));
         }
         return error;
     }
 
-    return new_ssl_vconn("ssl" /* FIXME */, new_fd, SERVER, new_vconnp);
+    error = set_nonblocking(new_fd);
+    if (error) {
+        close(new_fd);
+        return error;
+    }
+
+    return new_ssl_vconn("ssl" /* FIXME */, new_fd,
+                         SERVER, STATE_SSL_CONNECTING, new_vconnp);
+}
+
+static void
+pssl_wait(struct vconn *vconn, enum vconn_wait_type wait)
+{
+    struct pssl_vconn *pssl = pssl_vconn_cast(vconn);
+    assert(wait == WAIT_ACCEPT);
+    poll_fd_wait(pssl->fd, POLLIN, NULL);
 }
 
 struct vconn_class pssl_vconn_class = {
     .name = "pssl",
     .open = pssl_open,
     .close = pssl_close,
-    .prepoll = pssl_prepoll,
     .accept = pssl_accept,
+    .wait = pssl_wait,
 };
 \f
 /*
@@ -633,37 +697,6 @@ do_ssl_init(void)
     return 0;
 }
 
-static bool
-state_machine(struct ssl_vconn *sslv)
-{
-    if (sslv->state == STATE_SSL_CONNECTING) {
-        int ret = (sslv->type == CLIENT
-                   ? SSL_connect(sslv->ssl) : SSL_accept(sslv->ssl));
-        if (ret != 1) {
-            int error = SSL_get_error(sslv->ssl, ret);
-            if (ret < 0 && ssl_wants_io(error)) {
-                /* Stay in this state to repeat the SSL_connect later. */
-                return false;
-            } else {
-                interpret_ssl_error((sslv->type == CLIENT ? "SSL_connect"
-                                     : "SSL_accept"), ret, error);
-                shutdown(sslv->fd, SHUT_RDWR);
-                connect_completed(sslv, EPROTO);
-            }
-        } else {
-            connect_completed(sslv, 0);
-        }
-    }
-    return sslv->state == STATE_CONNECTED;
-}
-
-static void
-connect_completed(struct ssl_vconn *sslv, int error)
-{
-    sslv->state = STATE_CONNECTED;
-    sslv->connect_error = error;
-}
-
 static DH *
 tmp_dh_callback(SSL *ssl, int is_export UNUSED, int keylength)
 {