ustream-ssl: add support for using a fd instead of ustream as backing
[project/ustream-ssl.git] / ustream-ssl.c
index 0ae5df678adc4b441010638f356dd6b97d43ddce..b07629931525fe285f94b8c28b50d95b04fed5b8 100644 (file)
@@ -17,6 +17,8 @@
  */
 
 #include <errno.h>
+#include <stdlib.h>
+#include <string.h>
 #include <libubox/ustream.h>
 
 #include "ustream-ssl.h"
@@ -38,15 +40,35 @@ static void ustream_ssl_check_conn(struct ustream_ssl *us)
                return;
 
        if (__ustream_ssl_connect(us) == U_SSL_OK) {
+
+               /* __ustream_ssl_connect() will also return U_SSL_OK when certificate
+                * verification failed!
+                *
+                * Applications may register a custom .notify_verify_error callback in the
+                * struct ustream_ssl which is called upon verification failures, but there
+                * is no straight forward way for the callback to terminate the connection
+                * initiation right away, e.g. through a true or false return value.
+                *
+                * Instead, existing implementations appear to set .eof field of the underlying
+                * ustream in the hope that this inhibits further operations on the stream.
+                *
+                * Declare this informal behaviour "official" and check for the state of the
+                * .eof member after __ustream_ssl_connect() returned, and do not write the
+                * pending data if it is set to true.
+                */
+
+               if (us->stream.eof)
+                       return;
+
                us->connected = true;
                if (us->notify_connected)
                        us->notify_connected(us);
+               ustream_write_pending(&us->stream);
        }
 }
 
-static bool __ustream_ssl_poll(struct ustream *s)
+static bool __ustream_ssl_poll(struct ustream_ssl *us)
 {
-       struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
        char *buf;
        int len, ret;
        bool more = false;
@@ -61,6 +83,12 @@ static bool __ustream_ssl_poll(struct ustream *s)
                        break;
 
                ret = __ustream_ssl_read(us, buf, len);
+               if (ret == U_SSL_PENDING) {
+                       if (us->conn)
+                               ustream_poll(us->conn);
+                       ret = __ustream_ssl_read(us, buf, len);
+               }
+
                switch (ret) {
                case U_SSL_PENDING:
                        return more;
@@ -82,7 +110,9 @@ static bool __ustream_ssl_poll(struct ustream *s)
 
 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
 {
-       __ustream_ssl_poll(s);
+       struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
+
+       __ustream_ssl_poll(us);
 }
 
 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
@@ -106,7 +136,7 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m
        if (!us->connected || us->error)
                return 0;
 
-       if (us->conn->w.data_bytes)
+       if (us->conn && us->conn->w.data_bytes)
                return 0;
 
        return __ustream_ssl_write(us, buf, len);
@@ -115,8 +145,17 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m
 static void ustream_ssl_set_read_blocked(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
+       unsigned int ev = ULOOP_WRITE | ULOOP_EDGE_TRIGGER;
+
+       if (us->conn) {
+               ustream_set_read_blocked(us->conn, !!s->read_blocked);
+               return;
+       }
+
+       if (!s->read_blocked)
+               ev |= ULOOP_READ;
 
-       ustream_set_read_blocked(us->conn, !!s->read_blocked);
+       uloop_fd_add(&us->fd, ev);
 }
 
 static void ustream_ssl_free(struct ustream *s)
@@ -128,24 +167,40 @@ static void ustream_ssl_free(struct ustream *s)
                us->conn->notify_read = NULL;
                us->conn->notify_write = NULL;
                us->conn->notify_state = NULL;
+       } else {
+               uloop_fd_delete(&us->fd);
        }
 
        uloop_timeout_cancel(&us->error_timer);
-       __ustream_ssl_session_free(us->ssl);
+       __ustream_ssl_session_free(us);
+       free(us->peer_cn);
+
        us->ctx = NULL;
        us->ssl = NULL;
        us->conn = NULL;
+       us->peer_cn = NULL;
        us->connected = false;
        us->error = false;
+       us->valid_cert = false;
+       us->valid_cn = false;
 }
 
 static bool ustream_ssl_poll(struct ustream *s)
 {
        struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
-       bool fd_poll;
+       bool fd_poll = false;
+
+       if (us->conn)
+               fd_poll = ustream_poll(us->conn);
+
+       return __ustream_ssl_poll(us) || fd_poll;
+}
+
+static void ustream_ssl_fd_cb(struct uloop_fd *fd, unsigned int events)
+{
+       struct ustream_ssl *us = container_of(fd, struct ustream_ssl, fd);
 
-       fd_poll = ustream_poll(us->conn);
-       return __ustream_ssl_poll(s) || fd_poll;
+       __ustream_ssl_poll(us);
 }
 
 static void ustream_ssl_stream_init(struct ustream_ssl *us)
@@ -153,32 +208,67 @@ static void ustream_ssl_stream_init(struct ustream_ssl *us)
        struct ustream *conn = us->conn;
        struct ustream *s = &us->stream;
 
-       conn->notify_read = ustream_ssl_notify_read;
-       conn->notify_write = ustream_ssl_notify_write;
-       conn->notify_state = ustream_ssl_notify_state;
+       if (conn) {
+               conn->notify_read = ustream_ssl_notify_read;
+               conn->notify_write = ustream_ssl_notify_write;
+               conn->notify_state = ustream_ssl_notify_state;
+       } else {
+               us->fd.cb = ustream_ssl_fd_cb;
+               uloop_fd_add(&us->fd, ULOOP_READ | ULOOP_WRITE | ULOOP_EDGE_TRIGGER);
+       }
 
+       s->set_read_blocked = ustream_ssl_set_read_blocked;
        s->free = ustream_ssl_free;
        s->write = ustream_ssl_write;
        s->poll = ustream_ssl_poll;
-       s->set_read_blocked = ustream_ssl_set_read_blocked;
        ustream_init_defaults(s);
 }
 
-static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, void *ctx, bool server)
+static int _ustream_ssl_init_common(struct ustream_ssl *us)
 {
        us->error_timer.cb = ustream_ssl_error_cb;
-       us->server = server;
-       us->conn = conn;
-       us->ctx = ctx;
 
        us->ssl = __ustream_ssl_session_new(us->ctx);
        if (!us->ssl)
                return -ENOMEM;
 
-       conn->next = &us->stream;
-       ustream_set_io(ctx, us->ssl, conn);
+       ustream_set_io(us);
        ustream_ssl_stream_init(us);
 
+       if (us->server_name)
+               __ustream_ssl_set_server_name(us);
+
+       ustream_ssl_check_conn(us);
+
+       return 0;
+}
+
+static int _ustream_ssl_init_fd(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server)
+{
+       us->server = server;
+       us->ctx = ctx;
+       us->fd.fd = fd;
+
+       return _ustream_ssl_init_common(us);
+}
+
+static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
+{
+       us->server = server;
+       us->ctx = ctx;
+
+       us->conn = conn;
+       conn->r.max_buffers = 4;
+       conn->next = &us->stream;
+
+       return _ustream_ssl_init_common(us);
+}
+
+static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
+{
+       us->peer_cn = strdup(name);
+       __ustream_ssl_update_peer_cn(us);
+
        return 0;
 }
 
@@ -186,6 +276,12 @@ const struct ustream_ssl_ops ustream_ssl_ops = {
        .context_new = __ustream_ssl_context_new,
        .context_set_crt_file = __ustream_ssl_set_crt_file,
        .context_set_key_file = __ustream_ssl_set_key_file,
+       .context_add_ca_crt_file = __ustream_ssl_add_ca_crt_file,
+       .context_set_ciphers = __ustream_ssl_set_ciphers,
+       .context_set_require_validation = __ustream_ssl_set_require_validation,
+       .context_set_debug = __ustream_ssl_set_debug,
        .context_free = __ustream_ssl_context_free,
        .init = _ustream_ssl_init,
+       .init_fd = _ustream_ssl_init_fd,
+       .set_peer_cn = _ustream_ssl_set_peer_cn,
 };