1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ |
3 | |
#include <linux/sched/signal.h>
#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/udp.h>
#include <net/inet_common.h>

#include "udp_impl.h"
10 | |
11 | static struct proto *udpv6_prot_saved __read_mostly; |
12 | |
13 | static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, |
14 | int flags, int *addr_len) |
15 | { |
16 | #if IS_ENABLED(CONFIG_IPV6) |
17 | if (sk->sk_family == AF_INET6) |
18 | return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); |
19 | #endif |
20 | return udp_prot.recvmsg(sk, msg, len, flags, addr_len); |
21 | } |
22 | |
23 | static bool udp_sk_has_data(struct sock *sk) |
24 | { |
25 | return !skb_queue_empty(list: &udp_sk(sk)->reader_queue) || |
26 | !skb_queue_empty(list: &sk->sk_receive_queue); |
27 | } |
28 | |
29 | static bool psock_has_data(struct sk_psock *psock) |
30 | { |
31 | return !skb_queue_empty(list: &psock->ingress_skb) || |
32 | !sk_psock_queue_empty(psock); |
33 | } |
34 | |
35 | #define udp_msg_has_data(__sk, __psock) \ |
36 | ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) |
37 | |
38 | static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, |
39 | long timeo) |
40 | { |
41 | DEFINE_WAIT_FUNC(wait, woken_wake_function); |
42 | int ret = 0; |
43 | |
44 | if (sk->sk_shutdown & RCV_SHUTDOWN) |
45 | return 1; |
46 | |
47 | if (!timeo) |
48 | return ret; |
49 | |
50 | add_wait_queue(wq_head: sk_sleep(sk), wq_entry: &wait); |
51 | sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); |
52 | ret = udp_msg_has_data(sk, psock); |
53 | if (!ret) { |
54 | wait_woken(wq_entry: &wait, TASK_INTERRUPTIBLE, timeout: timeo); |
55 | ret = udp_msg_has_data(sk, psock); |
56 | } |
57 | sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); |
58 | remove_wait_queue(wq_head: sk_sleep(sk), wq_entry: &wait); |
59 | return ret; |
60 | } |
61 | |
62 | static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, |
63 | int flags, int *addr_len) |
64 | { |
65 | struct sk_psock *psock; |
66 | int copied, ret; |
67 | |
68 | if (unlikely(flags & MSG_ERRQUEUE)) |
69 | return inet_recv_error(sk, msg, len, addr_len); |
70 | |
71 | if (!len) |
72 | return 0; |
73 | |
74 | psock = sk_psock_get(sk); |
75 | if (unlikely(!psock)) |
76 | return sk_udp_recvmsg(sk, msg, len, flags, addr_len); |
77 | |
78 | if (!psock_has_data(psock)) { |
79 | ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); |
80 | goto out; |
81 | } |
82 | |
83 | msg_bytes_ready: |
84 | copied = sk_msg_recvmsg(sk, psock, msg, len, flags); |
85 | if (!copied) { |
86 | long timeo; |
87 | int data; |
88 | |
89 | timeo = sock_rcvtimeo(sk, noblock: flags & MSG_DONTWAIT); |
90 | data = udp_msg_wait_data(sk, psock, timeo); |
91 | if (data) { |
92 | if (psock_has_data(psock)) |
93 | goto msg_bytes_ready; |
94 | ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); |
95 | goto out; |
96 | } |
97 | copied = -EAGAIN; |
98 | } |
99 | ret = copied; |
100 | out: |
101 | sk_psock_put(sk, psock); |
102 | return ret; |
103 | } |
104 | |
/* Slots in udp_bpf_prots[]: one BPF-aware proto per UDP address family. */
enum {
	UDP_BPF_IPV4 = 0,
	UDP_BPF_IPV6 = 1,
	UDP_BPF_NUM_PROTS,	/* number of slots; keep last */
};
110 | |
111 | static DEFINE_SPINLOCK(udpv6_prot_lock); |
112 | static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; |
113 | |
114 | static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) |
115 | { |
116 | *prot = *base; |
117 | prot->close = sock_map_close; |
118 | prot->recvmsg = udp_bpf_recvmsg; |
119 | prot->sock_is_readable = sk_msg_is_readable; |
120 | } |
121 | |
122 | static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) |
123 | { |
124 | if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { |
125 | spin_lock_bh(lock: &udpv6_prot_lock); |
126 | if (likely(ops != udpv6_prot_saved)) { |
127 | udp_bpf_rebuild_protos(prot: &udp_bpf_prots[UDP_BPF_IPV6], base: ops); |
128 | smp_store_release(&udpv6_prot_saved, ops); |
129 | } |
130 | spin_unlock_bh(lock: &udpv6_prot_lock); |
131 | } |
132 | } |
133 | |
134 | static int __init udp_bpf_v4_build_proto(void) |
135 | { |
136 | udp_bpf_rebuild_protos(prot: &udp_bpf_prots[UDP_BPF_IPV4], base: &udp_prot); |
137 | return 0; |
138 | } |
139 | late_initcall(udp_bpf_v4_build_proto); |
140 | |
141 | int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) |
142 | { |
143 | int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; |
144 | |
145 | if (restore) { |
146 | sk->sk_write_space = psock->saved_write_space; |
147 | sock_replace_proto(sk, proto: psock->sk_proto); |
148 | return 0; |
149 | } |
150 | |
151 | if (sk->sk_family == AF_INET6) |
152 | udp_bpf_check_v6_needs_rebuild(ops: psock->sk_proto); |
153 | |
154 | sock_replace_proto(sk, proto: &udp_bpf_prots[family]); |
155 | return 0; |
156 | } |
157 | EXPORT_SYMBOL_GPL(udp_bpf_update_proto); |
158 | |