// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */

#include <linux/skbuff.h>
#include <linux/workqueue.h>
#include <net/strparser.h>
#include <net/tcp.h>
#include <net/sock.h>
#include <net/tls.h>

#include "tls.h"

static struct workqueue_struct *tls_strp_wq;

static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
{
	if (strp->stopped)
		return;

	strp->stopped = 1;

	/* Report an error on the lower socket */
	WRITE_ONCE(strp->sk->sk_err, -err);
	/* Paired with smp_rmb() in tcp_poll() */
	smp_wmb();
	sk_error_report(strp->sk);
}

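/* Free the anchor skb. Outside of copy mode the frag_list points at
 * skbs still owned by the TCP receive queue, so detach it before
 * freeing.
 */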
static void tls_strp_anchor_free(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
	if (!strp->copy_mode)
		shinfo->frag_list = NULL;
	consume_skb(strp->anchor);
	strp->anchor = NULL;
}

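/* Allocate a page-frag backed skb and copy @len bytes of data from
 * @in_skb, starting at @offset, into its frags.
 */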
static struct sk_buff *
tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
		  int offset, int len)
{
	struct sk_buff *skb;
	int i, err;

	skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
				   &err, strp->sk->sk_allocation);
	if (!skb)
		return NULL;

	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		skb_frag_t *frag = &skb_shinfo(skb)->frags[i];

		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag),
					   skb_frag_size(frag)));
		offset += skb_frag_size(frag);
	}

	skb->len = len;
	skb->data_len = len;
	skb_copy_header(skb, in_skb);
	return skb;
}

/* Create a new skb with the contents of input copied to its page frags */
static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
{
	struct strp_msg *rxm;
	struct sk_buff *skb;

	skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
				strp->stm.full_len);
	if (!skb)
		return NULL;

	rxm = strp_msg(skb);
	rxm->offset = 0;
	return skb;
}

/* Steal the input skb; the input msg is invalid after calling this function */
struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;

#ifdef CONFIG_TLS_DEVICE
	DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
#else
	/* This function turns an input into an output;
	 * that can only happen if we have offload.
	 */
	WARN_ON(1);
#endif

	if (strp->copy_mode) {
		struct sk_buff *skb;

		/* Replace the anchor with an empty skb; this is a little
		 * dangerous, but __tls_cur_msg() warns on empty skbs
		 * so hopefully we'll catch abuses.
		 */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return NULL;

		swap(strp->anchor, skb);
		return skb;
	}

	return tls_strp_msg_make_copy(strp);
}

/* Force the input skb to be in copy mode. The data ownership remains
 * with the input skb itself (meaning unpause will wipe it) but it can
 * be modified.
 */
int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
{
	struct tls_strparser *strp = &ctx->strp;
	struct sk_buff *skb;

	if (strp->copy_mode)
		return 0;

	skb = tls_strp_msg_make_copy(strp);
	if (!skb)
		return -ENOMEM;

	tls_strp_anchor_free(strp);
	strp->anchor = skb;

	tcp_read_done(strp->sk, strp->stm.full_len);
	strp->copy_mode = 1;

	return 0;
}

/* Make a clone (in the skb sense) of the input msg to keep a reference
 * to the underlying data. The reference-holding skbs get placed on
 * @dst.
 */
int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);

	if (strp->copy_mode) {
		struct sk_buff *skb;

		WARN_ON_ONCE(!shinfo->nr_frags);

		/* We can't skb_clone() the anchor, it gets wiped by unpause */
		skb = alloc_skb(0, strp->sk->sk_allocation);
		if (!skb)
			return -ENOMEM;

		__skb_queue_tail(dst, strp->anchor);
		strp->anchor = skb;
	} else {
		struct sk_buff *iter, *clone;
		int chunk, len, offset;

		offset = strp->stm.offset;
		len = strp->stm.full_len;
		iter = shinfo->frag_list;

		while (len > 0) {
			if (iter->len <= offset) {
				offset -= iter->len;
				goto next;
			}

			chunk = iter->len - offset;
			offset = 0;

			clone = skb_clone(iter, strp->sk->sk_allocation);
			if (!clone)
				return -ENOMEM;
			__skb_queue_tail(dst, clone);

			len -= chunk;
next:
			iter = iter->next;
		}
	}

	return 0;
}

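/* Release the data held by the anchor after a copy-mode read: drop the
 * page frag references, free any skbs chained on the frag_list and
 * clear the copy / mixed-decrypted state.
 */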
static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
{
	struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
	int i;

	DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);

	for (i = 0; i < shinfo->nr_frags; i++)
		__skb_frag_unref(&shinfo->frags[i], false);
	shinfo->nr_frags = 0;
	if (strp->copy_mode) {
		kfree_skb_list(shinfo->frag_list);
		shinfo->frag_list = NULL;
	}
	strp->copy_mode = 0;
	strp->mixed_decrypted = 0;
}

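/* Copy data from @in_skb into the page frags pre-allocated in the anchor
 * (@skb). Once the record header has been received, tls_rx_msg_size()
 * gives us the full record length and any over-read is trimmed back.
 */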
static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
				struct sk_buff *in_skb, unsigned int offset,
				size_t in_len)
{
	size_t len, chunk;
	skb_frag_t *frag;
	int sz;

	frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];

	len = in_len;
	/* First make sure we got the header */
	if (!strp->stm.full_len) {
		/* Assume one page is more than enough for headers */
		chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);

		sz = tls_rx_msg_size(strp, skb);
		if (sz < 0)
			return sz;

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (unlikely(sz && sz < skb->len)) {
			int over = skb->len - sz;

			WARN_ON_ONCE(over > chunk);
			skb->len -= over;
			skb->data_len -= over;
			skb_frag_size_add(frag, -over);

			chunk -= over;
		}

		frag++;
		len -= chunk;
		offset += chunk;

		strp->stm.full_len = sz;
		if (!strp->stm.full_len)
			goto read_done;
	}

	/* Load up more data */
	while (len && strp->stm.full_len > skb->len) {
		chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
		chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
		WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
					   skb_frag_address(frag) +
					   skb_frag_size(frag),
					   chunk));

		skb->len += chunk;
		skb->data_len += chunk;
		skb_frag_size_add(frag, chunk);
		frag++;
		len -= chunk;
		offset += chunk;
	}

read_done:
	return in_len - len;
}

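/* Used when the input mixes decrypted and non-decrypted skbs: copy the
 * data into a new skb and chain it on the anchor's frag_list so that
 * the decrypted state of each chunk is preserved.
 */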
static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
			       struct sk_buff *in_skb, unsigned int offset,
			       size_t in_len)
{
	struct sk_buff *nskb, *first, *last;
	struct skb_shared_info *shinfo;
	size_t chunk;
	int sz;

	if (strp->stm.full_len)
		chunk = strp->stm.full_len - skb->len;
	else
		chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
	chunk = min(chunk, in_len);

	nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
	if (!nskb)
		return -ENOMEM;

	shinfo = skb_shinfo(skb);
	if (!shinfo->frag_list) {
		shinfo->frag_list = nskb;
		nskb->prev = nskb;
	} else {
		first = shinfo->frag_list;
		last = first->prev;
		last->next = nskb;
		first->prev = nskb;
	}

	skb->len += chunk;
	skb->data_len += chunk;

	if (!strp->stm.full_len) {
		sz = tls_rx_msg_size(strp, skb);
		if (sz < 0)
			return sz;

		/* We may have over-read, sz == 0 is guaranteed under-read */
		if (unlikely(sz && sz < skb->len)) {
			int over = skb->len - sz;

			WARN_ON_ONCE(over > chunk);
			skb->len -= over;
			skb->data_len -= over;
			__pskb_trim(nskb, nskb->len - over);

			chunk -= over;
		}

		strp->stm.full_len = sz;
	}

	return chunk;
}

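/* recv_actor for tcp_read_sock(): picks the frag or skb flavour of
 * copy-in depending on whether the decrypted state of the input is mixed.
 */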
static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
			   unsigned int offset, size_t in_len)
{
	struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
	struct sk_buff *skb;
	int ret;

	if (strp->msg_ready)
		return 0;

	skb = strp->anchor;
	if (!skb->len)
		skb_copy_decrypted(skb, in_skb);
	else
		strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);

	if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
		ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
	else
		ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
	if (ret < 0) {
		desc->error = ret;
		ret = 0;
	}

	if (strp->stm.full_len && strp->stm.full_len == skb->len) {
		desc->count = 0;

		strp->msg_ready = 1;
		tls_rx_msg_ready(strp);
	}

	return ret;
}

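/* Pull queued data out of the TCP socket and feed it to tls_strp_copyin() */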
static int tls_strp_read_copyin(struct tls_strparser *strp)
{
	read_descriptor_t desc;

	desc.arg.data = strp;
	desc.error = 0;
	desc.count = 1; /* give more than one skb per call */

	/* sk should be locked here, so okay to do read_sock */
	tcp_read_sock(strp->sk, &desc, tls_strp_copyin);

	return desc.error;
}

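/* Fall back to copy mode: allocate enough page frags for the record
 * (or a maximum-size record if the length isn't known yet) and copy
 * the data out of the TCP receive queue. Short queues are left alone
 * unless the socket is under receive memory pressure.
 */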
static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
{
	struct skb_shared_info *shinfo;
	struct page *page;
	int need_spc, len;

	/* If the rbuf is small or rcv window has collapsed to 0 we need
	 * to read the data out. Otherwise the connection will stall.
	 * Without pressure, a threshold of INT_MAX will never be ready.
	 */
	if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
		return 0;

	shinfo = skb_shinfo(strp->anchor);
	shinfo->frag_list = NULL;

	/* If we don't know the length go max plus page for cipher overhead */
	need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;

	for (len = need_spc; len > 0; len -= PAGE_SIZE) {
		page = alloc_page(strp->sk->sk_allocation);
		if (!page) {
			tls_strp_flush_anchor_copy(strp);
			return -ENOMEM;
		}

		skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
				   page, 0, 0);
	}

	strp->copy_mode = 1;
	strp->stm.offset = 0;

	strp->anchor->len = 0;
	strp->anchor->data_len = 0;
	strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);

	tls_strp_read_copyin(strp);

	return 0;
}

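/* Walk the skbs linked off the anchor and check that they are contiguous
 * in TCP sequence space and share the same decrypted state.
 */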
static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
{
	unsigned int len = strp->stm.offset + strp->stm.full_len;
	struct sk_buff *first, *skb;
	u32 seq;

	first = skb_shinfo(strp->anchor)->frag_list;
	skb = first;
	seq = TCP_SKB_CB(first)->seq;

	/* Make sure there's no duplicate data in the queue,
	 * and the decrypted status matches.
	 */
	while (skb->len < len) {
		seq += skb->len;
		len -= skb->len;
		skb = skb->next;

		if (TCP_SKB_CB(skb)->seq != seq)
			return false;
		if (skb_cmp_decrypted(first, skb))
			return false;
	}

	return true;
}

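/* Attach the TCP receive queue to the anchor (zero-copy mode) and record
 * how far into the first skb the unread data starts.
 */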
static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
{
	struct tcp_sock *tp = tcp_sk(strp->sk);
	struct sk_buff *first;
	u32 offset;

	first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
	if (WARN_ON_ONCE(!first))
		return;

	/* Bestow the state onto the anchor */
	strp->anchor->len = offset + len;
	strp->anchor->data_len = offset + len;
	strp->anchor->truesize = offset + len;

	skb_shinfo(strp->anchor)->frag_list = first;

	skb_copy_header(strp->anchor, first);
	strp->anchor->destructor = NULL;

	strp->stm.offset = offset;
}

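/* Prepare the message in the anchor for the upper layer: in zero-copy
 * mode @force_refresh re-attaches the TCP queue, then the strp_msg and
 * tls_msg control blocks are filled in from the parser state.
 */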
void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
{
	struct strp_msg *rxm;
	struct tls_msg *tlm;

	DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
	DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);

	if (!strp->copy_mode && force_refresh) {
		if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
			return;

		tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
	}

	rxm = strp_msg(strp->anchor);
	rxm->full_len = strp->stm.full_len;
	rxm->offset = strp->stm.offset;
	tlm = tls_msg(strp->anchor);
	tlm->control = strp->mark;
}

/* Called with lock held on lower socket */
static int tls_strp_read_sock(struct tls_strparser *strp)
{
	int sz, inq;

	inq = tcp_inq(strp->sk);
	if (inq < 1)
		return 0;

	if (unlikely(strp->copy_mode))
		return tls_strp_read_copyin(strp);

	if (inq < strp->stm.full_len)
		return tls_strp_read_copy(strp, true);

	if (!strp->stm.full_len) {
		tls_strp_load_anchor_with_queue(strp, inq);

		sz = tls_rx_msg_size(strp, strp->anchor);
		if (sz < 0) {
			tls_strp_abort_strp(strp, sz);
			return sz;
		}

		strp->stm.full_len = sz;

		if (!strp->stm.full_len || inq < strp->stm.full_len)
			return tls_strp_read_copy(strp, true);
	}

	if (!tls_strp_check_queue_ok(strp))
		return tls_strp_read_copy(strp, false);

	strp->msg_ready = 1;
	tls_rx_msg_ready(strp);

	return 0;
}

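/* Try to parse the next record; retry from the workqueue on memory pressure */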
void tls_strp_check_rcv(struct tls_strparser *strp)
{
	if (unlikely(strp->stopped) || strp->msg_ready)
		return;

	if (tls_strp_read_sock(strp) == -ENOMEM)
		queue_work(tls_strp_wq, &strp->work);
}

/* Lower sock lock held */
void tls_strp_data_ready(struct tls_strparser *strp)
{
	/* This check is needed to synchronize with do_tls_strp_work.
	 * do_tls_strp_work acquires a process lock (lock_sock) whereas
	 * the lock held here is bh_lock_sock. The two locks can be
	 * held by different threads at the same time, but bh_lock_sock
	 * allows a thread in BH context to safely check if the process
	 * lock is held. In this case, if the lock is held, queue work.
	 */
	if (sock_owned_by_user_nocheck(strp->sk)) {
		queue_work(tls_strp_wq, &strp->work);
		return;
	}

	tls_strp_check_rcv(strp);
}

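/* Deferred parsing, run when the data-ready callback found the socket
 * owned by user context or a previous attempt hit memory pressure.
 */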
static void tls_strp_work(struct work_struct *w)
{
	struct tls_strparser *strp =
		container_of(w, struct tls_strparser, work);

	lock_sock(strp->sk);
	tls_strp_check_rcv(strp);
	release_sock(strp->sk);
}

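/* The upper layer is done with the current record: return the data to
 * the TCP socket (zero-copy) or free the copy, then look for the next
 * record.
 */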
void tls_strp_msg_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stm.full_len);

	if (likely(!strp->copy_mode))
		tcp_read_done(strp->sk, strp->stm.full_len);
	else
		tls_strp_flush_anchor_copy(strp);

	strp->msg_ready = 0;
	memset(&strp->stm, 0, sizeof(strp->stm));

	tls_strp_check_rcv(strp);
}

void tls_strp_stop(struct tls_strparser *strp)
{
	strp->stopped = 1;
}

int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
{
	memset(strp, 0, sizeof(*strp));

	strp->sk = sk;

	strp->anchor = alloc_skb(0, GFP_KERNEL);
	if (!strp->anchor)
		return -ENOMEM;

	INIT_WORK(&strp->work, tls_strp_work);

	return 0;
}

/* strp must already be stopped so that tls_strp_recv will no longer be called.
 * Note that tls_strp_done is not called with the lower socket held.
 */
void tls_strp_done(struct tls_strparser *strp)
{
	WARN_ON(!strp->stopped);

	cancel_work_sync(&strp->work);
	tls_strp_anchor_free(strp);
}

int __init tls_strp_dev_init(void)
{
	tls_strp_wq = create_workqueue("tls-strp");
	if (unlikely(!tls_strp_wq))
		return -ENOMEM;

	return 0;
}

void tls_strp_dev_exit(void)
{
	destroy_workqueue(tls_strp_wq);
}