From 524a76e5af78fa577c46e0d24bdedd4254e07cd4 Mon Sep 17 00:00:00 2001 From: Felix Fietkau Date: Fri, 19 Apr 2024 16:43:35 +0200 Subject: [PATCH] ustream-ssl: add support for using a fd instead of ustream as backing This improves performance by avoiding double buffering Signed-off-by: Felix Fietkau --- ustream-internal.h | 3 +- ustream-io-openssl.c | 21 +++++++++-- ustream-io-wolfssl.c | 15 +++++--- ustream-mbedtls.c | 33 ++++++++++++++--- ustream-mbedtls.h | 1 - ustream-openssl.c | 15 +++++--- ustream-openssl.h | 2 -- ustream-ssl.c | 86 +++++++++++++++++++++++++++++++++----------- ustream-ssl.h | 2 ++ 9 files changed, 135 insertions(+), 43 deletions(-) diff --git a/ustream-internal.h b/ustream-internal.h index 50e105f..4eec9cd 100644 --- a/ustream-internal.h +++ b/ustream-internal.h @@ -34,7 +34,7 @@ enum ssl_conn_status { U_SSL_RETRY = -3, }; -void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *s); +void ustream_set_io(struct ustream_ssl *us); struct ustream_ssl_ctx *__ustream_ssl_context_new(bool server); int __ustream_ssl_add_ca_crt_file(struct ustream_ssl_ctx *ctx, const char *file); int __ustream_ssl_set_crt_file(struct ustream_ssl_ctx *ctx, const char *file); @@ -46,5 +46,6 @@ void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx); enum ssl_conn_status __ustream_ssl_connect(struct ustream_ssl *us); int __ustream_ssl_read(struct ustream_ssl *us, char *buf, int len); int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int len); +void __ustream_ssl_session_free(struct ustream_ssl *us); #endif diff --git a/ustream-io-openssl.c b/ustream-io-openssl.c index 7045bb6..4ca77de 100644 --- a/ustream-io-openssl.c +++ b/ustream-io-openssl.c @@ -137,8 +137,23 @@ static BIO *ustream_bio_new(struct ustream *s) return bio; } -__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn) +static BIO *fd_bio_new(int fd) { - BIO *bio = ustream_bio_new(conn); - SSL_set_bio(ssl, bio, bio); + BIO *bio = BIO_new(BIO_s_socket()); + + BIO_set_fd(bio, fd, BIO_NOCLOSE); + + return bio; +} + +__hidden void ustream_set_io(struct ustream_ssl *us) +{ + BIO *bio; + + if (us->conn) + bio = ustream_bio_new(us->conn); + else + bio = fd_bio_new(us->fd.fd); + + SSL_set_bio(us->ssl, bio, bio); } diff --git a/ustream-io-wolfssl.c b/ustream-io-wolfssl.c index 4ff85d3..0a97edc 100644 --- a/ustream-io-wolfssl.c +++ b/ustream-io-wolfssl.c @@ -65,10 +65,15 @@ static int io_send_cb(SSL* ssl, char *buf, int sz, void *ctx) return s_ustream_write(buf, sz, ctx); } -__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn) +__hidden void ustream_set_io(struct ustream_ssl *us) { - wolfSSL_SSLSetIORecv(ssl, io_recv_cb); - wolfSSL_SSLSetIOSend(ssl, io_send_cb); - wolfSSL_SetIOReadCtx(ssl, conn); - wolfSSL_SetIOWriteCtx(ssl, conn); + if (!us->conn) { + wolfSSL_set_fd(us->ssl, us->fd.fd); + return; + } + + wolfSSL_SSLSetIORecv(us->ssl, io_recv_cb); + wolfSSL_SSLSetIOSend(us->ssl, io_send_cb); + wolfSSL_SetIOReadCtx(us->ssl, us->conn); + wolfSSL_SetIOWriteCtx(us->ssl, us->conn); } diff --git a/ustream-mbedtls.c b/ustream-mbedtls.c index 6b8e1c0..361ff99 100644 --- a/ustream-mbedtls.c +++ b/ustream-mbedtls.c @@ -85,9 +85,32 @@ static int s_ustream_write(void *ctx, const unsigned char *buf, size_t len) return ret; } -__hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn) +static int s_fd_read(void *ctx, unsigned char *buf, size_t len) { - mbedtls_ssl_set_bio(ssl, conn, s_ustream_write, s_ustream_read, NULL); + struct uloop_fd *ufd = ctx; + mbedtls_net_context net = { + .fd = ufd->fd + }; + + return mbedtls_net_recv(&net, buf, len); +} + +static int s_fd_write(void *ctx, const unsigned char *buf, size_t len) +{ + struct uloop_fd *ufd = ctx; + mbedtls_net_context net = { + .fd = ufd->fd + }; + + return mbedtls_net_send(&net, buf, len); +} + +__hidden void ustream_set_io(struct ustream_ssl *us) +{ + if (us->conn) + mbedtls_ssl_set_bio(us->ssl, us->conn, s_ustream_write, s_ustream_read, NULL); + else + mbedtls_ssl_set_bio(us->ssl, &us->fd, s_fd_write, s_fd_read, NULL); } static int _random(void *ctx, unsigned char *out, size_t len) @@ -553,8 +576,8 @@ __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx) return ssl; } -__hidden void __ustream_ssl_session_free(void *ssl) +__hidden void __ustream_ssl_session_free(struct ustream_ssl *us) { - mbedtls_ssl_free(ssl); - free(ssl); + mbedtls_ssl_free(us->ssl); + free(us->ssl); } diff --git a/ustream-mbedtls.h b/ustream-mbedtls.h index 31df680..281b919 100644 --- a/ustream-mbedtls.h +++ b/ustream-mbedtls.h @@ -64,7 +64,6 @@ static inline void __ustream_ssl_update_peer_cn(struct ustream_ssl *us) mbedtls_ssl_set_hostname(us->ssl, us->peer_cn); } -void __ustream_ssl_session_free(void *ssl); void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx); #endif diff --git a/ustream-openssl.c b/ustream-openssl.c index 3d576be..b080081 100644 --- a/ustream-openssl.c +++ b/ustream-openssl.c @@ -245,13 +245,18 @@ __hidden void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx) free(ctx); } -void __ustream_ssl_session_free(void *ssl) +__hidden void __ustream_ssl_session_free(struct ustream_ssl *us) { - BIO *bio = SSL_get_wbio(ssl); - struct bio_ctx *ctx = BIO_get_data(bio); + BIO *bio = SSL_get_wbio(us->ssl); + struct bio_ctx *ctx; - SSL_shutdown(ssl); - SSL_free(ssl); + SSL_shutdown(us->ssl); + SSL_free(us->ssl); + + if (!us->conn) + return; + + ctx = BIO_get_data(bio); if (ctx) { BIO_meth_free(ctx->meth); free(ctx); diff --git a/ustream-openssl.h b/ustream-openssl.h index f547aa6..847f5aa 100644 --- a/ustream-openssl.h +++ b/ustream-openssl.h @@ -36,8 +36,6 @@ struct ustream_ssl_ctx { void *debug_cb_priv; }; -void __ustream_ssl_session_free(void *ssl); - struct bio_ctx { BIO_METHOD *meth; struct ustream *stream; diff --git a/ustream-ssl.c b/ustream-ssl.c index d3048ca..b076299 100644 --- a/ustream-ssl.c +++ b/ustream-ssl.c @@ -67,9 +67,8 @@ static void ustream_ssl_check_conn(struct ustream_ssl *us) } } -static bool __ustream_ssl_poll(struct ustream *s) +static bool __ustream_ssl_poll(struct ustream_ssl *us) { - struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream); char *buf; int len, ret; bool more = false; @@ -85,7 +84,8 @@ static bool __ustream_ssl_poll(struct ustream *s) ret = __ustream_ssl_read(us, buf, len); if (ret == U_SSL_PENDING) { - ustream_poll(us->conn); + if (us->conn) + ustream_poll(us->conn); ret = __ustream_ssl_read(us, buf, len); } @@ -110,7 +110,9 @@ static bool __ustream_ssl_poll(struct ustream *s) static void ustream_ssl_notify_read(struct ustream *s, int bytes) { - __ustream_ssl_poll(s); + struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream); + + __ustream_ssl_poll(us); } static void ustream_ssl_notify_write(struct ustream *s, int bytes) @@ -134,7 +136,7 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m if (!us->connected || us->error) return 0; - if (us->conn->w.data_bytes) + if (us->conn && us->conn->w.data_bytes) return 0; return __ustream_ssl_write(us, buf, len); @@ -143,8 +145,17 @@ static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool m static void ustream_ssl_set_read_blocked(struct ustream *s) { struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream); + unsigned int ev = ULOOP_WRITE | ULOOP_EDGE_TRIGGER; + + if (us->conn) { + ustream_set_read_blocked(us->conn, !!s->read_blocked); + return; + } + + if (!s->read_blocked) + ev |= ULOOP_READ; - ustream_set_read_blocked(us->conn, !!s->read_blocked); + uloop_fd_add(&us->fd, ev); } static void ustream_ssl_free(struct ustream *s) @@ -156,10 +167,12 @@ static void ustream_ssl_free(struct ustream *s) us->conn->notify_read = NULL; us->conn->notify_write = NULL; us->conn->notify_state = NULL; + } else { + uloop_fd_delete(&us->fd); } uloop_timeout_cancel(&us->error_timer); - __ustream_ssl_session_free(us->ssl); + __ustream_ssl_session_free(us); free(us->peer_cn); us->ctx = NULL; @@ -175,10 +188,19 @@ static void ustream_ssl_free(struct ustream *s) static bool ustream_ssl_poll(struct ustream *s) { struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream); - bool fd_poll; + bool fd_poll = false; + + if (us->conn) + fd_poll = ustream_poll(us->conn); + + return __ustream_ssl_poll(us) || fd_poll; +} + +static void ustream_ssl_fd_cb(struct uloop_fd *fd, unsigned int events) +{ + struct ustream_ssl *us = container_of(fd, struct ustream_ssl, fd); - fd_poll = ustream_poll(us->conn); - return __ustream_ssl_poll(us->conn) || fd_poll; + __ustream_ssl_poll(us); } static void ustream_ssl_stream_init(struct ustream_ssl *us) @@ -186,31 +208,31 @@ static void ustream_ssl_stream_init(struct ustream_ssl *us) struct ustream *conn = us->conn; struct ustream *s = &us->stream; - conn->notify_read = ustream_ssl_notify_read; - conn->notify_write = ustream_ssl_notify_write; - conn->notify_state = ustream_ssl_notify_state; + if (conn) { + conn->notify_read = ustream_ssl_notify_read; + conn->notify_write = ustream_ssl_notify_write; + conn->notify_state = ustream_ssl_notify_state; + } else { + us->fd.cb = ustream_ssl_fd_cb; + uloop_fd_add(&us->fd, ULOOP_READ | ULOOP_WRITE | ULOOP_EDGE_TRIGGER); + } + s->set_read_blocked = ustream_ssl_set_read_blocked; s->free = ustream_ssl_free; s->write = ustream_ssl_write; s->poll = ustream_ssl_poll; - s->set_read_blocked = ustream_ssl_set_read_blocked; ustream_init_defaults(s); } -static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server) +static int _ustream_ssl_init_common(struct ustream_ssl *us) { us->error_timer.cb = ustream_ssl_error_cb; - us->server = server; - us->conn = conn; - us->ctx = ctx; us->ssl = __ustream_ssl_session_new(us->ctx); if (!us->ssl) return -ENOMEM; - conn->r.max_buffers = 4; - conn->next = &us->stream; - ustream_set_io(ctx, us->ssl, conn); + ustream_set_io(us); ustream_ssl_stream_init(us); if (us->server_name) @@ -221,6 +243,27 @@ static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struc return 0; } +static int _ustream_ssl_init_fd(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server) +{ + us->server = server; + us->ctx = ctx; + us->fd.fd = fd; + + return _ustream_ssl_init_common(us); +} + +static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server) +{ + us->server = server; + us->ctx = ctx; + + us->conn = conn; + conn->r.max_buffers = 4; + conn->next = &us->stream; + + return _ustream_ssl_init_common(us); +} + static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name) { us->peer_cn = strdup(name); @@ -239,5 +282,6 @@ const struct ustream_ssl_ops ustream_ssl_ops = { .context_set_debug = __ustream_ssl_set_debug, .context_free = __ustream_ssl_context_free, .init = _ustream_ssl_init, + .init_fd = _ustream_ssl_init_fd, .set_peer_cn = _ustream_ssl_set_peer_cn, }; diff --git a/ustream-ssl.h b/ustream-ssl.h index b1115c6..fe545f4 100644 --- a/ustream-ssl.h +++ b/ustream-ssl.h @@ -25,6 +25,7 @@ struct ustream_ssl { struct ustream stream; struct ustream *conn; struct uloop_timeout error_timer; + struct uloop_fd fd; void (*notify_connected)(struct ustream_ssl *us); void (*notify_error)(struct ustream_ssl *us, int error, const char *str); @@ -56,6 +57,7 @@ struct ustream_ssl_ops { int (*context_add_ca_crt_file)(struct ustream_ssl_ctx *ctx, const char *file); void (*context_free)(struct ustream_ssl_ctx *ctx); + int (*init_fd)(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server); int (*init)(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server); int (*set_peer_cn)(struct ustream_ssl *conn, const char *name); -- 2.30.2