From 60d8fbb5e669db4b85f0ccd9b86744a8355eb2d9 Mon Sep 17 00:00:00 2001 From: Felix Fietkau Date: Thu, 18 Apr 2024 12:42:01 +0200 Subject: [PATCH] mbedtls: handle session tickets for TLS 1.3 Store them inside the context in order to handle reconnect Signed-off-by: Felix Fietkau --- ustream-internal.h | 1 + ustream-mbedtls.c | 92 +++++++++++++++++++++++++++++++++------------- ustream-mbedtls.h | 3 ++ 3 files changed, 71 insertions(+), 25 deletions(-) diff --git a/ustream-internal.h b/ustream-internal.h index f8f28e1..50e105f 100644 --- a/ustream-internal.h +++ b/ustream-internal.h @@ -31,6 +31,7 @@ enum ssl_conn_status { U_SSL_OK = 0, U_SSL_PENDING = -1, U_SSL_ERROR = -2, + U_SSL_RETRY = -3, }; void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *s); diff --git a/ustream-mbedtls.c b/ustream-mbedtls.c index ff2c9a9..b671148 100644 --- a/ustream-mbedtls.c +++ b/ustream-mbedtls.c @@ -361,6 +361,7 @@ __hidden int __ustream_ssl_set_require_validation(struct ustream_ssl_ctx *ctx, b __hidden void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx) { + free(ctx->session_data); #if defined(MBEDTLS_SSL_CACHE_C) mbedtls_ssl_cache_free(&ctx->cache); #endif @@ -378,14 +379,48 @@ static void ustream_ssl_error(struct ustream_ssl *us, int ret) uloop_timeout_set(&us->error_timer, 0); } -static bool ssl_do_wait(int ret) +static void +__ustream_ssl_save_session(struct ustream_ssl *us) +{ + struct ustream_ssl_ctx *ctx = us->ctx; + mbedtls_ssl_session sess; + + if (ctx->server) + return; + + free(ctx->session_data); + ctx->session_data = NULL; + + mbedtls_ssl_session_init(&sess); + if (mbedtls_ssl_get_session(us->ssl, &sess) != 0) + return; + + mbedtls_ssl_session_save(&sess, NULL, 0, &ctx->session_data_len); + ctx->session_data = malloc(ctx->session_data_len); + if (mbedtls_ssl_session_save(&sess, ctx->session_data, ctx->session_data_len, + &ctx->session_data_len)) + ctx->session_data_len = 0; + mbedtls_ssl_session_free(&sess); +} + +static int ssl_check_return(struct ustream_ssl *us, int ret) { switch(ret) { case MBEDTLS_ERR_SSL_WANT_READ: case MBEDTLS_ERR_SSL_WANT_WRITE: - return true; + return U_SSL_PENDING; + case MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET: +#ifdef MBEDTLS_ECP_RESTARTABLE + case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS: +#endif + __ustream_ssl_save_session(us); + return U_SSL_RETRY; + case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY: + case MBEDTLS_ERR_NET_CONN_RESET: + return 0; default: - return false; + ustream_ssl_error(us, ret); + return U_SSL_ERROR; } } @@ -424,17 +459,17 @@ __hidden enum ssl_conn_status __ustream_ssl_connect(struct ustream_ssl *us) void *ssl = us->ssl; int r; - r = mbedtls_ssl_handshake(ssl); - if (r == 0) { - ustream_ssl_verify_cert(us); - return U_SSL_OK; - } + do { + r = mbedtls_ssl_handshake(ssl); + if (r == 0) { + ustream_ssl_verify_cert(us); + return U_SSL_OK; + } - if (ssl_do_wait(r)) - return U_SSL_PENDING; + r = ssl_check_return(us, r); + } while (r == U_SSL_RETRY); - ustream_ssl_error(us, r); - return U_SSL_ERROR; + return r; } __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int len) @@ -444,12 +479,14 @@ __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int le while (done != len) { ret = mbedtls_ssl_write(ssl, (const unsigned char *) buf + done, len - done); - if (ret < 0) { - if (ssl_do_wait(ret)) + ret = ssl_check_return(us, ret); + if (ret == U_SSL_RETRY) + continue; + + if (ret == U_SSL_PENDING) return done; - ustream_ssl_error(us, ret); return -1; } @@ -461,18 +498,15 @@ __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int le __hidden int __ustream_ssl_read(struct ustream_ssl *us, char *buf, int len) { - int ret = mbedtls_ssl_read(us->ssl, (unsigned char *) buf, len); - - if (ret < 0) { - if (ssl_do_wait(ret)) - return U_SSL_PENDING; + int ret; - if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) - return 0; + do { + ret = mbedtls_ssl_read(us->ssl, (unsigned char *) buf, len); + if (ret >= 0) + return ret; - ustream_ssl_error(us, ret); - return U_SSL_ERROR; - } + ret = ssl_check_return(us, ret); + } while (ret == U_SSL_RETRY); return ret; } @@ -491,6 +525,7 @@ __hidden void __ustream_ssl_set_debug(struct ustream_ssl_ctx *ctx, int level, __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx) { mbedtls_ssl_context *ssl; + mbedtls_ssl_session sess; ssl = calloc(1, sizeof(*ssl)); if (!ssl) @@ -503,6 +538,13 @@ __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx) return NULL; } + if (!ctx->session_data_len) + return ssl; + + mbedtls_ssl_session_init(&sess); + if (mbedtls_ssl_session_load(&sess, ctx->session_data, ctx->session_data_len) == 0) + mbedtls_ssl_set_session(ssl, &sess); + return ssl; } diff --git a/ustream-mbedtls.h b/ustream-mbedtls.h index ff907f6..31df680 100644 --- a/ustream-mbedtls.h +++ b/ustream-mbedtls.h @@ -43,6 +43,9 @@ struct ustream_ssl_ctx { void *debug_cb_priv; bool server; int *ciphersuites; + + void *session_data; + size_t session_data_len; }; static inline char *__ustream_ssl_strerror(int error, char *buffer, int len) -- 2.30.2