ustream-ssl: add support for using an fd instead of a ustream as the backing
[project/ustream-ssl.git] / ustream-ssl.c
1 /*
2 * ustream-ssl - library for SSL over ustream
3 *
4 * Copyright (C) 2012 Felix Fietkau <nbd@openwrt.org>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19 #include <errno.h>
20 #include <stdlib.h>
21 #include <string.h>
22 #include <libubox/ustream.h>
23
24 #include "ustream-ssl.h"
25 #include "ustream-internal.h"
26
27 static void ustream_ssl_error_cb(struct uloop_timeout *t)
28 {
29 struct ustream_ssl *us = container_of(t, struct ustream_ssl, error_timer);
30 static char buffer[128];
31 int error = us->error;
32
33 if (us->notify_error)
34 us->notify_error(us, error, __ustream_ssl_strerror(us->error, buffer, sizeof(buffer)));
35 }
36
37 static void ustream_ssl_check_conn(struct ustream_ssl *us)
38 {
39 if (us->connected || us->error)
40 return;
41
42 if (__ustream_ssl_connect(us) == U_SSL_OK) {
43
44 /* __ustream_ssl_connect() will also return U_SSL_OK when certificate
45 * verification failed!
46 *
47 * Applications may register a custom .notify_verify_error callback in the
48 * struct ustream_ssl which is called upon verification failures, but there
49 * is no straight forward way for the callback to terminate the connection
50 * initiation right away, e.g. through a true or false return value.
51 *
52 * Instead, existing implementations appear to set .eof field of the underlying
53 * ustream in the hope that this inhibits further operations on the stream.
54 *
55 * Declare this informal behaviour "official" and check for the state of the
56 * .eof member after __ustream_ssl_connect() returned, and do not write the
57 * pending data if it is set to true.
58 */
59
60 if (us->stream.eof)
61 return;
62
63 us->connected = true;
64 if (us->notify_connected)
65 us->notify_connected(us);
66 ustream_write_pending(&us->stream);
67 }
68 }
69
70 static bool __ustream_ssl_poll(struct ustream_ssl *us)
71 {
72 char *buf;
73 int len, ret;
74 bool more = false;
75
76 ustream_ssl_check_conn(us);
77 if (!us->connected || us->error)
78 return false;
79
80 do {
81 buf = ustream_reserve(&us->stream, 1, &len);
82 if (!len)
83 break;
84
85 ret = __ustream_ssl_read(us, buf, len);
86 if (ret == U_SSL_PENDING) {
87 if (us->conn)
88 ustream_poll(us->conn);
89 ret = __ustream_ssl_read(us, buf, len);
90 }
91
92 switch (ret) {
93 case U_SSL_PENDING:
94 return more;
95 case U_SSL_ERROR:
96 return false;
97 case 0:
98 us->stream.eof = true;
99 ustream_state_change(&us->stream);
100 return false;
101 default:
102 ustream_fill_read(&us->stream, ret);
103 more = true;
104 continue;
105 }
106 } while (1);
107
108 return more;
109 }
110
111 static void ustream_ssl_notify_read(struct ustream *s, int bytes)
112 {
113 struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
114
115 __ustream_ssl_poll(us);
116 }
117
118 static void ustream_ssl_notify_write(struct ustream *s, int bytes)
119 {
120 struct ustream_ssl *us = container_of(s->next, struct ustream_ssl, stream);
121
122 ustream_ssl_check_conn(us);
123 ustream_write_pending(s->next);
124 }
125
126 static void ustream_ssl_notify_state(struct ustream *s)
127 {
128 s->next->write_error = true;
129 ustream_state_change(s->next);
130 }
131
132 static int ustream_ssl_write(struct ustream *s, const char *buf, int len, bool more)
133 {
134 struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
135
136 if (!us->connected || us->error)
137 return 0;
138
139 if (us->conn && us->conn->w.data_bytes)
140 return 0;
141
142 return __ustream_ssl_write(us, buf, len);
143 }
144
145 static void ustream_ssl_set_read_blocked(struct ustream *s)
146 {
147 struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
148 unsigned int ev = ULOOP_WRITE | ULOOP_EDGE_TRIGGER;
149
150 if (us->conn) {
151 ustream_set_read_blocked(us->conn, !!s->read_blocked);
152 return;
153 }
154
155 if (!s->read_blocked)
156 ev |= ULOOP_READ;
157
158 uloop_fd_add(&us->fd, ev);
159 }
160
161 static void ustream_ssl_free(struct ustream *s)
162 {
163 struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
164
165 if (us->conn) {
166 us->conn->next = NULL;
167 us->conn->notify_read = NULL;
168 us->conn->notify_write = NULL;
169 us->conn->notify_state = NULL;
170 } else {
171 uloop_fd_delete(&us->fd);
172 }
173
174 uloop_timeout_cancel(&us->error_timer);
175 __ustream_ssl_session_free(us);
176 free(us->peer_cn);
177
178 us->ctx = NULL;
179 us->ssl = NULL;
180 us->conn = NULL;
181 us->peer_cn = NULL;
182 us->connected = false;
183 us->error = false;
184 us->valid_cert = false;
185 us->valid_cn = false;
186 }
187
188 static bool ustream_ssl_poll(struct ustream *s)
189 {
190 struct ustream_ssl *us = container_of(s, struct ustream_ssl, stream);
191 bool fd_poll = false;
192
193 if (us->conn)
194 fd_poll = ustream_poll(us->conn);
195
196 return __ustream_ssl_poll(us) || fd_poll;
197 }
198
199 static void ustream_ssl_fd_cb(struct uloop_fd *fd, unsigned int events)
200 {
201 struct ustream_ssl *us = container_of(fd, struct ustream_ssl, fd);
202
203 __ustream_ssl_poll(us);
204 }
205
206 static void ustream_ssl_stream_init(struct ustream_ssl *us)
207 {
208 struct ustream *conn = us->conn;
209 struct ustream *s = &us->stream;
210
211 if (conn) {
212 conn->notify_read = ustream_ssl_notify_read;
213 conn->notify_write = ustream_ssl_notify_write;
214 conn->notify_state = ustream_ssl_notify_state;
215 } else {
216 us->fd.cb = ustream_ssl_fd_cb;
217 uloop_fd_add(&us->fd, ULOOP_READ | ULOOP_WRITE | ULOOP_EDGE_TRIGGER);
218 }
219
220 s->set_read_blocked = ustream_ssl_set_read_blocked;
221 s->free = ustream_ssl_free;
222 s->write = ustream_ssl_write;
223 s->poll = ustream_ssl_poll;
224 ustream_init_defaults(s);
225 }
226
227 static int _ustream_ssl_init_common(struct ustream_ssl *us)
228 {
229 us->error_timer.cb = ustream_ssl_error_cb;
230
231 us->ssl = __ustream_ssl_session_new(us->ctx);
232 if (!us->ssl)
233 return -ENOMEM;
234
235 ustream_set_io(us);
236 ustream_ssl_stream_init(us);
237
238 if (us->server_name)
239 __ustream_ssl_set_server_name(us);
240
241 ustream_ssl_check_conn(us);
242
243 return 0;
244 }
245
246 static int _ustream_ssl_init_fd(struct ustream_ssl *us, int fd, struct ustream_ssl_ctx *ctx, bool server)
247 {
248 us->server = server;
249 us->ctx = ctx;
250 us->fd.fd = fd;
251
252 return _ustream_ssl_init_common(us);
253 }
254
255 static int _ustream_ssl_init(struct ustream_ssl *us, struct ustream *conn, struct ustream_ssl_ctx *ctx, bool server)
256 {
257 us->server = server;
258 us->ctx = ctx;
259
260 us->conn = conn;
261 conn->r.max_buffers = 4;
262 conn->next = &us->stream;
263
264 return _ustream_ssl_init_common(us);
265 }
266
267 static int _ustream_ssl_set_peer_cn(struct ustream_ssl *us, const char *name)
268 {
269 us->peer_cn = strdup(name);
270 __ustream_ssl_update_peer_cn(us);
271
272 return 0;
273 }
274
275 const struct ustream_ssl_ops ustream_ssl_ops = {
276 .context_new = __ustream_ssl_context_new,
277 .context_set_crt_file = __ustream_ssl_set_crt_file,
278 .context_set_key_file = __ustream_ssl_set_key_file,
279 .context_add_ca_crt_file = __ustream_ssl_add_ca_crt_file,
280 .context_set_ciphers = __ustream_ssl_set_ciphers,
281 .context_set_require_validation = __ustream_ssl_set_require_validation,
282 .context_set_debug = __ustream_ssl_set_debug,
283 .context_free = __ustream_ssl_context_free,
284 .init = _ustream_ssl_init,
285 .init_fd = _ustream_ssl_init_fd,
286 .set_peer_cn = _ustream_ssl_set_peer_cn,
287 };