// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

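/* Report the socket readable when the psock ingress queue holds parsed
 * msgs, so that poll()/epoll wake readers even though sk_receive_queue
 * may be empty.
 */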
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

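/* Wait until either the psock ingress queue or the regular TCP receive
 * queue has data, or the receive timeout expires. Returns non-zero when
 * the wait condition was met.
 */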
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

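/* Copy up to @len bytes from the psock ingress queue into @msg. With
 * MSG_PEEK the scatterlist state is left untouched; otherwise consumed
 * bytes are uncharged and fully drained msgs are freed. Called with the
 * socket lock held. Returns the number of bytes copied, or -EFAULT if
 * the copy to the iterator fails.
 */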
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

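/* recvmsg hook installed via tcp_bpf_prots. Data redirected to this socket
 * by a BPF verdict sits on the psock ingress queue; fall back to
 * tcp_recvmsg() when no psock is attached or when data shows up on the
 * regular receive queue instead.
 */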
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

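/* Ingress redirect path: move up to @apply_bytes worth of scatterlist
 * elements from @msg onto a newly allocated sk_msg, charge them to @sk,
 * queue the new msg on the psock ingress list and wake any reader.
 */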
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

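/* Push scatterlist pages from @msg out through the TCP stack, honouring
 * @apply_bytes. If a TLS TX ULP is present, the locked sendpage path is
 * used with MSG_SENDPAGE_NOPOLICY so the BPF policy is not re-run on the
 * redirected data. Called with the socket lock held.
 */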
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

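/* Entry point for SK_REDIRECT verdicts: deliver @msg to @sk either by
 * queueing it on the target's ingress path or by pushing it out on
 * egress. If the target has no psock, the msg is freed and 0 returned.
 */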
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

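/* Run the msg verdict program (unless a previous verdict still applies to
 * outstanding apply_bytes) and act on the result: pass the data to the TCP
 * stack, redirect it to another socket, or drop it. Also handles cork
 * buffering and adjusts *copied so the byte count returned to userspace
 * reflects dropped data.
 */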
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		if (msg->sg.size < delta)
			delta -= msg->sg.size;
		else
			delta = 0;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

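/* sendmsg hook used once a msg parser program is attached. User data is
 * staged into an sk_msg scatterlist (or the psock cork buffer) and then
 * run through tcp_bpf_send_verdict().
 */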
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

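/* sendpage hook: link the page into an sk_msg (or the psock cork buffer)
 * without copying and run it through tcp_bpf_send_verdict().
 */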
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

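/* Detach psock state from @sk: free any cork buffer, purge queued ingress
 * msgs and drop all sockmap/sockhash links.
 */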
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

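/* unhash/close replacements: tear down the psock state, then hand off to
 * the callbacks saved when the BPF proto was installed. Without a psock
 * they simply call through to the current sk_prot hooks.
 */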
static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

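/* Build the BPF proto variants from @base: TCP_BPF_BASE overrides the
 * receive side hooks, TCP_BPF_TX additionally overrides sendmsg/sendpage
 * for sockets that have a msg parser program attached.
 */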
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

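/* The IPv6 proto ops may not be available at init time (ipv6 can be built
 * as a module), so the IPv6 variants are rebuilt lazily the first time an
 * IPv6 socket shows up with a proto we have not seen yet.
 */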
static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

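/* Install the BPF proto on @sk. Fails if no psock is attached, if a proto
 * was already saved, or if the socket's ops are not the plain TCP ones
 * that the fallback paths above assume.
 */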
int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}