// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>
#include <linux/util_macros.h>

#include <net/inet_common.h>
#include <net/tls.h>

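/* Advance copied_seq over data consumed outside of recvmsg(), e.g. an skb
 * eaten by the BPF verdict program, so receive-queue accounting and window
 * updates stay in sync. Skbs owned by the strparser are skipped here; they
 * are accounted for on the strparser read path.
 */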
void tcp_eat_skb(struct sock *sk, struct sk_buff *skb)
{
	struct tcp_sock *tcp;
	int copied;

	if (!skb || !skb->len || !sk_is_tcp(sk))
		return;

	if (skb_bpf_strparser(skb))
		return;

	tcp = tcp_sk(sk);
	copied = tcp->copied_seq + skb->len;
	WRITE_ONCE(tcp->copied_seq, copied);
	tcp_rcv_space_adjust(sk);
	__tcp_cleanup_rbuf(sk, skb->len);
}

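/* Redirect to the socket's own ingress path: transfer up to @apply_bytes of
 * @msg onto a freshly allocated sk_msg, charge it to @sk's memory, queue it
 * on @psock->ingress_msg and wake any reader.
 */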
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes) {
				if (sge->length)
					sk_msg_iter_var_prev(i);
				break;
			}
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

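/* Transmit up to @apply_bytes of @msg on @sk via tcp_sendmsg_locked() with
 * MSG_SPLICE_PAGES, advancing the scatterlist as pages are fully sent. If
 * @uncharge is set, sent bytes are released from send-memory accounting.
 * The caller must hold the socket lock.
 */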
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	struct msghdr msghdr = {};
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		struct bio_vec bvec;
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		msghdr.msg_flags = flags | MSG_SPLICE_PAGES;
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp)
			msghdr.msg_flags |= MSG_SENDPAGE_NOPOLICY;

		if (size < sge->length && msg->sg.start != msg->sg.end)
			msghdr.msg_flags |= MSG_MORE;

		bvec_set_page(&bvec, page, size, off);
		iov_iter_bvec(&msghdr.msg_iter, ITER_SOURCE, &bvec, 1, size);
		ret = tcp_sendmsg_locked(sk, &msghdr, size);
		if (ret <= 0)
			return ret;

		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

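/* Like tcp_bpf_push(), but takes and releases the socket lock around it. */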
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

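/* Entry point for msg redirection from another socket's send path: place
 * @bytes of @msg either on @sk's ingress queue or on its transmit path.
 */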
int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress,
			  struct sk_msg *msg, u32 bytes, int flags)
{
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
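/* Wait up to @timeo for data on either the psock ingress queue or the socket
 * receive queue. Returns non-zero once data is available or the socket is
 * shut down for receive, 0 on timeout, or a negative error.
 */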
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty_lockless(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

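/* Return true if the message at the head of the ingress queue is an empty
 * message carrying a FIN, i.e. the peer performed an orderly shutdown.
 */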
static bool is_next_msg_fin(struct sk_psock *psock)
{
	struct scatterlist *sge;
	struct sk_msg *msg_rx;
	int i;

	msg_rx = sk_psock_peek_msg(psock);
	i = msg_rx->sg.start;
	sge = sk_msg_elem(msg_rx, i);
	if (!sge->length) {
		struct sk_buff *skb = msg_rx->skb;

		if (skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)
			return true;
	}
	return false;
}

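/* recvmsg handler used when a stream/skb verdict program is attached. Data
 * is read from the psock ingress queue while copied_seq and the receive
 * window are kept in sync with what the application actually consumed.
 */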
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int flags,
				  int *addr_len)
{
	struct tcp_sock *tcp = tcp_sk(sk);
	int peek = flags & MSG_PEEK;
	u32 seq = tcp->copied_seq;
	struct sk_psock *psock;
	int copied = 0;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	lock_sock(sk);

	/* We may have received data on the sk_receive_queue pre-accept, and
	 * then we cannot use read_skb in this context because we haven't been
	 * assigned a sk_socket yet, so we have no link to the ops. The
	 * work-around is to check the sk_receive_queue and in these cases
	 * read skbs off the queue again. The read_skb hook is not running at
	 * this point because of lock_sock, so we avoid having multiple
	 * runners in read_skb.
	 */
	if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
		tcp_data_ready(sk);
		/* This handles the ENOMEM errors if we both receive data
		 * pre-accept and are already under memory pressure. At least
		 * let the user know to retry.
		 */
		if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
			copied = -EAGAIN;
			goto out;
		}
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	/* The typical case for EFAULT is that the socket was gracefully shut
	 * down with a FIN pkt. The other case, an error on copy_page_to_iter,
	 * would be unexpected, so check for it here. On FIN, return zero as
	 * the correct return code.
	 */
	if (copied == -EFAULT) {
		bool is_fin = is_next_msg_fin(psock);

		if (is_fin) {
			copied = 0;
			seq++;
			goto out;
		}
	}
	seq += copied;
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data < 0) {
			copied = data;
			goto unlock;
		}
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	if (!peek)
		WRITE_ONCE(tcp->copied_seq, seq);
	tcp_rcv_space_adjust(sk);
	if (copied > 0)
		__tcp_cleanup_rbuf(sk, copied);

unlock:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

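/* recvmsg handler used when only a msg parser program is attached. Data may
 * sit on the psock ingress queue (redirected here from another socket) or on
 * the regular receive queue, in which case we fall back to tcp_recvmsg().
 */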
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data < 0) {
			ret = data;
			goto unlock;
		}
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;

unlock:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

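/* Run the msg verdict program over @msg, honouring cork/apply state, and
 * carry out its decision: pass the data to the TCP stack, redirect it to
 * another socket, or drop it, adjusting *copied accordingly.
 */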
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg), redir_ingress;
	struct sock *sk_redir;
	u32 tosend, origsize, sent, delta = 0;
	u32 eval;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track the delta in msg size to add/subtract it from the
		 * copied size returned to the user on SK_DROP. This ensures
		 * the user doesn't get a positive return code with
		 * msg_cut_data and an SK_DROP verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;
	eval = __SK_NONE;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		redir_ingress = psock->redir_ingress;
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);

		origsize = msg->sg.size;
		ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
					    msg, tosend, flags);
		sent = origsize - msg->sg.size;

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, tosend - sent);
			goto more_data;
		}
	}
	return ret;
}

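/* sendmsg handler installed while a msg parser program is attached: copy the
 * user iterator into an sk_msg, apply cork semantics, then hand the message
 * to tcp_bpf_send_verdict() for the BPF verdict and actual transmission.
 */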
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

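/* Build the BASE/TX/RX/TXRX proto variants on top of @base, overriding only
 * the callbacks sockmap needs.
 */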
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].destroy = sock_map_destroy;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;

	prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

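/* tcpv6_prot comes from the (possibly modular) IPv6 code, so the IPv6 proto
 * variants are built lazily the first time they are needed and rebuilt if
 * the underlying proto changes.
 */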
static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg ? 0 : -ENOTSUPP;
}

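/* Switch @sk to the tcp_bpf proto variant matching the programs attached to
 * @psock, or restore the socket's original proto when @restore is set.
 */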
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			sock_replace_proto(sk, psock->sk_proto);
		}
		return 0;
	}

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	sock_replace_proto(sk, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	struct proto *prot = newsk->sk_prot;

	if (is_insidevar(prot, tcp_bpf_prots))
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */