add callbacks for debug messages
[project/ustream-ssl.git] / ustream-mbedtls.c
1 /*
2 * ustream-ssl - library for SSL over ustream
3 *
4 * Copyright (C) 2012 Felix Fietkau <nbd@openwrt.org>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
#include <sys/types.h>
#include <sys/random.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "ustream-ssl.h"
#include "ustream-internal.h"
#include <psa/crypto.h>
#include <mbedtls/debug.h>
30
31 static void debug_cb(void *ctx_p, int level,
32 const char *file, int line,
33 const char *str)
34 {
35 struct ustream_ssl_ctx *ctx = ctx_p;
36 const char *fstr;
37 char buf[512];
38 int len;
39
40 if (!ctx->debug_cb)
41 return;
42
43 while ((fstr = strstr(file + 1, "library/")) != NULL)
44 file = fstr;
45
46 len = snprintf(buf, sizeof(buf), "%s:%04d: %s", file, line, str);
47 if (len >= (int)sizeof(buf))
48 len = (int)sizeof(buf) - 1;
49 if (buf[len - 1] == '\n')
50 buf[len - 1] = 0;
51 ctx->debug_cb(ctx->debug_cb_priv, level, buf);
52 }
53
54 static int s_ustream_read(void *ctx, unsigned char *buf, size_t len)
55 {
56 struct ustream *s = ctx;
57 char *sbuf;
58 int slen;
59
60 if (s->eof)
61 return 0;
62
63 sbuf = ustream_get_read_buf(s, &slen);
64 if ((size_t) slen > len)
65 slen = len;
66
67 if (!slen)
68 return MBEDTLS_ERR_SSL_WANT_READ;
69
70 memcpy(buf, sbuf, slen);
71 ustream_consume(s, slen);
72
73 return slen;
74 }
75
76 static int s_ustream_write(void *ctx, const unsigned char *buf, size_t len)
77 {
78 struct ustream *s = ctx;
79 int ret;
80
81 ret = ustream_write(s, (const char *) buf, len, false);
82 if (ret < 0 || s->write_error)
83 return MBEDTLS_ERR_NET_SEND_FAILED;
84
85 return ret;
86 }
87
88 __hidden void ustream_set_io(struct ustream_ssl_ctx *ctx, void *ssl, struct ustream *conn)
89 {
90 mbedtls_ssl_set_bio(ssl, conn, s_ustream_write, s_ustream_read, NULL);
91 }
92
93 static int _random(void *ctx, unsigned char *out, size_t len)
94 {
95 #ifdef linux
96 if (getrandom(out, len, 0) != (ssize_t) len)
97 return MBEDTLS_ERR_ENTROPY_SOURCE_FAILED;
98 #else
99 static FILE *f;
100
101 if (!f)
102 f = fopen("/dev/urandom", "r");
103 if (fread(out, len, 1, f) != 1)
104 return MBEDTLS_ERR_ENTROPY_SOURCE_FAILED;
105 #endif
106
107 return 0;
108 }
109
/* Expands to the two AES-GCM suites (128-bit/SHA256, 256-bit/SHA384) for
 * key exchange @v, e.g. ECDHE_RSA -> MBEDTLS_TLS_ECDHE_RSA_WITH_AES_... */
#define AES_GCM_CIPHERS(v) \
	MBEDTLS_TLS_##v##_WITH_AES_128_GCM_SHA256, \
	MBEDTLS_TLS_##v##_WITH_AES_256_GCM_SHA384

/* Expands to the two AES-CBC ..._CBC_SHA suites (128- and 256-bit) for
 * key exchange @v. */
#define AES_CBC_CIPHERS(v) \
	MBEDTLS_TLS_##v##_WITH_AES_128_CBC_SHA, \
	MBEDTLS_TLS_##v##_WITH_AES_256_CBC_SHA

/* All four AES suites for key exchange @v: GCM variants first, then CBC. */
#define AES_CIPHERS(v) \
	AES_GCM_CIPHERS(v), \
	AES_CBC_CIPHERS(v)
121
/*
 * Default cipher suites for server contexts, installed with
 * mbedtls_ssl_conf_ciphersuites().  The array is 0-terminated and its
 * order is the preference order: ECDHE with ChaCha20-Poly1305 / AES-GCM
 * first, ECDHE_RSA with AES-CBC next, plain RSA key exchange last.
 * ECDHE_ECDSA is offered only with AEAD ciphers here (no CBC).
 */
static const int default_ciphersuites_server[] =
{
	MBEDTLS_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
	AES_GCM_CIPHERS(ECDHE_ECDSA),
	MBEDTLS_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	AES_GCM_CIPHERS(ECDHE_RSA),
	AES_CBC_CIPHERS(ECDHE_RSA),
	AES_CIPHERS(RSA),
	0	/* list terminator */
};
132
/*
 * Default cipher suites for client contexts, installed with
 * mbedtls_ssl_conf_ciphersuites().  The array is 0-terminated and listed
 * in preference order.  It is a superset of the server list: it adds
 * DHE_RSA key exchange, ECDHE_ECDSA with AES-CBC, and the 3DES-CBC
 * suites when the linked mbedtls still defines them.  The wider list is
 * presumably meant to let clients reach older servers (NOTE(review):
 * confirm 3DES is still wanted on 2.x builds).
 */
static const int default_ciphersuites_client[] =
{
	MBEDTLS_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
	AES_GCM_CIPHERS(ECDHE_ECDSA),
	MBEDTLS_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	AES_GCM_CIPHERS(ECDHE_RSA),
	MBEDTLS_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	AES_GCM_CIPHERS(DHE_RSA),
	AES_CBC_CIPHERS(ECDHE_ECDSA),
	AES_CBC_CIPHERS(ECDHE_RSA),
	AES_CBC_CIPHERS(DHE_RSA),
/* Removed in Mbed TLS 3.0.0 */
#ifdef MBEDTLS_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA
	MBEDTLS_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
#endif
	AES_CIPHERS(RSA),
/* Removed in Mbed TLS 3.0.0 */
#ifdef MBEDTLS_TLS_RSA_WITH_3DES_EDE_CBC_SHA
	MBEDTLS_TLS_RSA_WITH_3DES_EDE_CBC_SHA,
#endif
	0	/* list terminator */
};
155
156
157 __hidden struct ustream_ssl_ctx *
158 __ustream_ssl_context_new(bool server)
159 {
160 struct ustream_ssl_ctx *ctx;
161 mbedtls_ssl_config *conf;
162 int ep;
163
164 #ifdef MBEDTLS_PSA_CRYPTO_C
165 static bool psa_init;
166
167 if (!psa_init && !psa_crypto_init())
168 psa_init = true;
169 #endif
170
171 ctx = calloc(1, sizeof(*ctx));
172 if (!ctx)
173 return NULL;
174
175 ctx->server = server;
176 mbedtls_pk_init(&ctx->key);
177 mbedtls_x509_crt_init(&ctx->cert);
178 mbedtls_x509_crt_init(&ctx->ca_cert);
179
180 #if defined(MBEDTLS_SSL_CACHE_C)
181 mbedtls_ssl_cache_init(&ctx->cache);
182 mbedtls_ssl_cache_set_timeout(&ctx->cache, 30 * 60);
183 mbedtls_ssl_cache_set_max_entries(&ctx->cache, 5);
184 #endif
185
186 conf = &ctx->conf;
187 mbedtls_ssl_config_init(conf);
188
189 ep = server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
190
191 mbedtls_ssl_config_defaults(conf, ep, MBEDTLS_SSL_TRANSPORT_STREAM,
192 MBEDTLS_SSL_PRESET_DEFAULT);
193 mbedtls_ssl_conf_rng(conf, _random, NULL);
194
195 if (server) {
196 mbedtls_ssl_conf_authmode(conf, MBEDTLS_SSL_VERIFY_NONE);
197 mbedtls_ssl_conf_ciphersuites(conf, default_ciphersuites_server);
198 mbedtls_ssl_conf_min_version(conf, MBEDTLS_SSL_MAJOR_VERSION_3,
199 MBEDTLS_SSL_MINOR_VERSION_3);
200 } else {
201 mbedtls_ssl_conf_authmode(conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
202 mbedtls_ssl_conf_ciphersuites(conf, default_ciphersuites_client);
203 }
204
205 #if defined(MBEDTLS_SSL_CACHE_C)
206 mbedtls_ssl_conf_session_cache(conf, &ctx->cache,
207 mbedtls_ssl_cache_get,
208 mbedtls_ssl_cache_set);
209 #endif
210 return ctx;
211 }
212
213 static void ustream_ssl_update_own_cert(struct ustream_ssl_ctx *ctx)
214 {
215 if (!ctx->cert.version)
216 return;
217
218 if (mbedtls_pk_get_type(&ctx->key) == MBEDTLS_PK_NONE)
219 return;
220
221 mbedtls_ssl_conf_own_cert(&ctx->conf, &ctx->cert, &ctx->key);
222 }
223
224 __hidden int __ustream_ssl_add_ca_crt_file(struct ustream_ssl_ctx *ctx, const char *file)
225 {
226 int ret;
227
228 ret = mbedtls_x509_crt_parse_file(&ctx->ca_cert, file);
229 if (ret)
230 return -1;
231
232 mbedtls_ssl_conf_ca_chain(&ctx->conf, &ctx->ca_cert, NULL);
233 mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
234 return 0;
235 }
236
237 __hidden int __ustream_ssl_set_crt_file(struct ustream_ssl_ctx *ctx, const char *file)
238 {
239 int ret;
240
241 ret = mbedtls_x509_crt_parse_file(&ctx->cert, file);
242 if (ret)
243 return -1;
244
245 ustream_ssl_update_own_cert(ctx);
246 return 0;
247 }
248
249 __hidden int __ustream_ssl_set_key_file(struct ustream_ssl_ctx *ctx, const char *file)
250 {
251 int ret;
252
253 #if (MBEDTLS_VERSION_NUMBER >= 0x03000000)
254 ret = mbedtls_pk_parse_keyfile(&ctx->key, file, NULL, _random, NULL);
255 #else
256 ret = mbedtls_pk_parse_keyfile(&ctx->key, file, NULL);
257 #endif
258 if (ret)
259 return -1;
260
261 ustream_ssl_update_own_cert(ctx);
262 return 0;
263 }
264
265 __hidden int __ustream_ssl_set_ciphers(struct ustream_ssl_ctx *ctx, const char *ciphers)
266 {
267 int *ciphersuites = NULL, *tmp, id;
268 char *cipherstr, *p, *last, c;
269 size_t len = 0;
270
271 if (ciphers == NULL)
272 return -1;
273
274 cipherstr = strdup(ciphers);
275
276 if (cipherstr == NULL)
277 return -1;
278
279 for (p = cipherstr, last = p;; p++) {
280 if (*p == ':' || *p == 0) {
281 c = *p;
282 *p = 0;
283
284 id = mbedtls_ssl_get_ciphersuite_id(last);
285
286 if (id != 0) {
287 tmp = realloc(ciphersuites, (len + 2) * sizeof(int));
288
289 if (tmp == NULL) {
290 free(ciphersuites);
291 free(cipherstr);
292
293 return -1;
294 }
295
296 ciphersuites = tmp;
297 ciphersuites[len++] = id;
298 ciphersuites[len] = 0;
299 }
300
301 if (c == 0)
302 break;
303
304 last = p + 1;
305 }
306
307 /*
308 * mbedTLS expects cipher names with dashes while many sources elsewhere
309 * like the Firefox wiki or Wireshark specify ciphers with underscores,
310 * so simply convert all underscores to dashes to accept both notations.
311 */
312 else if (*p == '_') {
313 *p = '-';
314 }
315 }
316
317 free(cipherstr);
318
319 if (len == 0)
320 return -1;
321
322 mbedtls_ssl_conf_ciphersuites(&ctx->conf, ciphersuites);
323 free(ctx->ciphersuites);
324
325 ctx->ciphersuites = ciphersuites;
326
327 return 0;
328 }
329
330 __hidden int __ustream_ssl_set_require_validation(struct ustream_ssl_ctx *ctx, bool require)
331 {
332 int mode = MBEDTLS_SSL_VERIFY_OPTIONAL;
333
334 if (!require)
335 mode = MBEDTLS_SSL_VERIFY_NONE;
336
337 mbedtls_ssl_conf_authmode(&ctx->conf, mode);
338
339 return 0;
340 }
341
342 __hidden void __ustream_ssl_context_free(struct ustream_ssl_ctx *ctx)
343 {
344 #if defined(MBEDTLS_SSL_CACHE_C)
345 mbedtls_ssl_cache_free(&ctx->cache);
346 #endif
347 mbedtls_pk_free(&ctx->key);
348 mbedtls_x509_crt_free(&ctx->ca_cert);
349 mbedtls_x509_crt_free(&ctx->cert);
350 mbedtls_ssl_config_free(&ctx->conf);
351 free(ctx->ciphersuites);
352 free(ctx);
353 }
354
355 static void ustream_ssl_error(struct ustream_ssl *us, int ret)
356 {
357 us->error = ret;
358 uloop_timeout_set(&us->error_timer, 0);
359 }
360
361 static bool ssl_do_wait(int ret)
362 {
363 switch(ret) {
364 case MBEDTLS_ERR_SSL_WANT_READ:
365 case MBEDTLS_ERR_SSL_WANT_WRITE:
366 return true;
367 default:
368 return false;
369 }
370 }
371
372 static void ustream_ssl_verify_cert(struct ustream_ssl *us)
373 {
374 void *ssl = us->ssl;
375 const char *msg = NULL;
376 bool cn_mismatch;
377 int r;
378
379 r = mbedtls_ssl_get_verify_result(ssl);
380 cn_mismatch = r & MBEDTLS_X509_BADCERT_CN_MISMATCH;
381 r &= ~MBEDTLS_X509_BADCERT_CN_MISMATCH;
382
383 if (r & MBEDTLS_X509_BADCERT_EXPIRED)
384 msg = "certificate has expired";
385 else if (r & MBEDTLS_X509_BADCERT_REVOKED)
386 msg = "certificate has been revoked";
387 else if (r & MBEDTLS_X509_BADCERT_NOT_TRUSTED)
388 msg = "certificate is self-signed or not signed by a trusted CA";
389 else
390 msg = "unknown error";
391
392 if (r) {
393 if (us->notify_verify_error)
394 us->notify_verify_error(us, r, msg);
395 return;
396 }
397
398 if (!cn_mismatch)
399 us->valid_cn = true;
400 }
401
402 __hidden enum ssl_conn_status __ustream_ssl_connect(struct ustream_ssl *us)
403 {
404 void *ssl = us->ssl;
405 int r;
406
407 r = mbedtls_ssl_handshake(ssl);
408 if (r == 0) {
409 ustream_ssl_verify_cert(us);
410 return U_SSL_OK;
411 }
412
413 if (ssl_do_wait(r))
414 return U_SSL_PENDING;
415
416 ustream_ssl_error(us, r);
417 return U_SSL_ERROR;
418 }
419
420 __hidden int __ustream_ssl_write(struct ustream_ssl *us, const char *buf, int len)
421 {
422 void *ssl = us->ssl;
423 int done = 0, ret = 0;
424
425 while (done != len) {
426 ret = mbedtls_ssl_write(ssl, (const unsigned char *) buf + done, len - done);
427
428 if (ret < 0) {
429 if (ssl_do_wait(ret))
430 return done;
431
432 ustream_ssl_error(us, ret);
433 return -1;
434 }
435
436 done += ret;
437 }
438
439 return done;
440 }
441
442 __hidden int __ustream_ssl_read(struct ustream_ssl *us, char *buf, int len)
443 {
444 int ret = mbedtls_ssl_read(us->ssl, (unsigned char *) buf, len);
445
446 if (ret < 0) {
447 if (ssl_do_wait(ret))
448 return U_SSL_PENDING;
449
450 if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
451 return 0;
452
453 ustream_ssl_error(us, ret);
454 return U_SSL_ERROR;
455 }
456
457 return ret;
458 }
459
460 __hidden void __ustream_ssl_set_debug(struct ustream_ssl_ctx *ctx, int level,
461 ustream_ssl_debug_cb cb, void *cb_priv)
462 {
463 ctx->debug_cb = cb;
464 ctx->debug_cb_priv = cb_priv;
465 mbedtls_ssl_conf_dbg(&ctx->conf, debug_cb, ctx);
466 mbedtls_debug_set_threshold(level);
467 }
468
469 __hidden void *__ustream_ssl_session_new(struct ustream_ssl_ctx *ctx)
470 {
471 mbedtls_ssl_context *ssl;
472
473 ssl = calloc(1, sizeof(*ssl));
474 if (!ssl)
475 return NULL;
476
477 mbedtls_ssl_init(ssl);
478
479 if (mbedtls_ssl_setup(ssl, &ctx->conf)) {
480 free(ssl);
481 return NULL;
482 }
483
484 return ssl;
485 }
486
487 __hidden void __ustream_ssl_session_free(void *ssl)
488 {
489 mbedtls_ssl_free(ssl);
490 free(ssl);
491 }