mbedtls: handle session tickets for TLS 1.3
authorFelix Fietkau <nbd@nbd.name>
Thu, 18 Apr 2024 10:42:01 +0000 (12:42 +0200)
committerFelix Fietkau <nbd@nbd.name>
Thu, 18 Apr 2024 11:10:54 +0000 (13:10 +0200)
Store them inside the context in order to handle reconnect

Signed-off-by: Felix Fietkau <nbd@nbd.name>
ustream-internal.h
ustream-mbedtls.c
ustream-mbedtls.h

index f8f28e1ab9a3f0e2bd707e47c93cb7100faad1fc..50e105f0ddb6f90c5e6305ffb40f19c7fabb5720 100644 (file)
@@ -31,6 +31,7 @@ enum ssl_conn_status {
        U_SSL_OK = 0,
        U_SSL_PENDING = -1,
        U_SSL_ERROR = -2,
+       U_SSL_RETRY = -3,
 };
 
 void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *s);
index ff2c9a9aad936deba7ac1f18f6b3b929355ae6dd..b6711488139d4bd916e13c9e13f1821c0201af33 100644 (file)
@@ -361,6 +361,7 @@ __hidden int __ustream_ssl_set_require_validation(struct ustream_ssl_ctx *ctx, b
 
 __hidden void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx)
 {
+       free(ctx->session_data);
 #if defined(MBEDTLS_SSL_CACHE_C)
        mbedtls_ssl_cache_free(&ctx->cache);
 #endif
@@ -378,14 +379,48 @@ static void ustream_ssl_error(struct ustream_ssl *us, int ret)
        uloop_timeout_set(&us->error_timer, 0);
 }
 
-static bool ssl_do_wait(int ret)
+static void
+__ustream_ssl_save_session(struct ustream_ssl *us)
+{
+       struct ustream_ssl_ctx *ctx = us->ctx;
+       mbedtls_ssl_session sess;
+
+       if (ctx->server)
+               return;
+
+       free(ctx->session_data);
+       ctx->session_data = NULL;
+
+       mbedtls_ssl_session_init(&sess);
+       if (mbedtls_ssl_get_session(us->ssl, &sess) != 0)
+               return;
+
+       mbedtls_ssl_session_save(&sess, NULL, 0, &ctx->session_data_len);
+       ctx->session_data = malloc(ctx->session_data_len);
+       if (mbedtls_ssl_session_save(&sess, ctx->session_data, ctx->session_data_len,
+                                    &ctx->session_data_len))
+               ctx->session_data_len = 0;
+       mbedtls_ssl_session_free(&sess);
+}
+
+static int ssl_check_return(struct ustream_ssl *us, int ret)
 {
        switch(ret) {
        case MBEDTLS_ERR_SSL_WANT_READ:
        case MBEDTLS_ERR_SSL_WANT_WRITE:
-               return true;
+               return U_SSL_PENDING;
+       case MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET:
+#ifdef MBEDTLS_ECP_RESTARTABLE
+       case MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS:
+#endif
+               __ustream_ssl_save_session(us);
+               return U_SSL_RETRY;
+       case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
+       case MBEDTLS_ERR_NET_CONN_RESET:
+               return 0;
        default:
-               return false;
+               ustream_ssl_error(us, ret);
+               return U_SSL_ERROR;
        }
 }
 
@@ -424,17 +459,17 @@ __hidden enum ssl_conn_status __ustream_ssl_connect(struct ustream_ssl *us)
        void *ssl = us->ssl;
        int r;
 
-       r = mbedtls_ssl_handshake(ssl);
-       if (r == 0) {
-               ustream_ssl_verify_cert(us);
-               return U_SSL_OK;
-       }
+       do {
+               r = mbedtls_ssl_handshake(ssl);
+               if (r == 0) {
+                       ustream_ssl_verify_cert(us);
+                       return U_SSL_OK;
+               }
 
-       if (ssl_do_wait(r))
-               return U_SSL_PENDING;
+               r = ssl_check_return(us, r);
+       } while (r == U_SSL_RETRY);
 
-       ustream_ssl_error(us, r);
-       return U_SSL_ERROR;
+       return r;
 }
 
 __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int len)
@@ -444,12 +479,14 @@ __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int le
 
        while (done != len) {
                ret = mbedtls_ssl_write(ssl, (const unsigned char *) buf + done, len - done);
-
                if (ret < 0) {
-                       if (ssl_do_wait(ret))
+                       ret = ssl_check_return(us, ret);
+                       if (ret == U_SSL_RETRY)
+                               continue;
+
+                       if (ret == U_SSL_PENDING)
                                return done;
 
-                       ustream_ssl_error(us, ret);
                        return -1;
                }
 
@@ -461,18 +498,15 @@ __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int le
 
 __hidden int __ustream_ssl_read(struct ustream_ssl *us, char *buf, int len)
 {
-       int ret = mbedtls_ssl_read(us->ssl, (unsigned char *) buf, len);
-
-       if (ret < 0) {
-               if (ssl_do_wait(ret))
-                       return U_SSL_PENDING;
+       int ret;
 
-               if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
-                       return 0;
+       do {
+               ret = mbedtls_ssl_read(us->ssl, (unsigned char *) buf, len);
+               if (ret >= 0)
+                       return ret;
 
-               ustream_ssl_error(us, ret);
-               return U_SSL_ERROR;
-       }
+               ret = ssl_check_return(us, ret);
+       } while (ret == U_SSL_RETRY);
 
        return ret;
 }
@@ -491,6 +525,7 @@ __hidden void __ustream_ssl_set_debug(struct ustream_ssl_ctx *ctx, int level,
 __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx)
 {
        mbedtls_ssl_context *ssl;
+       mbedtls_ssl_session sess;
 
        ssl = calloc(1, sizeof(*ssl));
        if (!ssl)
@@ -503,6 +538,13 @@ __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx)
                return NULL;
        }
 
+       if (!ctx->session_data_len)
+               return ssl;
+
+       mbedtls_ssl_session_init(&sess);
+       if (mbedtls_ssl_session_load(&sess, ctx->session_data, ctx->session_data_len) == 0)
+               mbedtls_ssl_set_session(ssl, &sess);
+
        return ssl;
 }
 
index ff907f6d5bc1708d618c5384f0cd094406f68ccc..31df680d2e1d9915bf57b4de54675afda0619d3e 100644 (file)
@@ -43,6 +43,9 @@ struct ustream_ssl_ctx {
        void *debug_cb_priv;
        bool server;
        int *ciphersuites;
+
+       void *session_data;
+       size_t session_data_len;
 };
 
 static inline char *__ustream_ssl_strerror(int error, char *buffer, int len)