// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */

#include <linux/skmsg.h>
#include <linux/bpf.h>
#include <net/sock.h>
#include <net/af_unix.h>

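/* True if @__sk has data pending on any queue a BPF-enabled socket can
 * receive from: the socket's own receive queue or the psock's ingress
 * skb/msg queues.
 */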
#define unix_sk_has_data(__sk, __psock)					\
		({	!skb_queue_empty(&__sk->sk_receive_queue) ||	\
			!skb_queue_empty(&__psock->ingress_skb) ||	\
			!list_empty(&__psock->ingress_msg);		\
		})

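/* Sleep until data shows up on the socket or psock queues, the receive
 * side is shut down, or @timeo expires. Called with u->iolock held; the
 * lock is dropped around the sleep and reacquired before returning.
 */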
static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			      long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	struct unix_sock *u = unix_sk(sk);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	if (!unix_sk_has_data(sk, psock)) {
		mutex_unlock(&u->iolock);
		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
		mutex_lock(&u->iolock);
		ret = unix_sk_has_data(sk, psock);
	}
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

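/* Hand off to the native af_unix receive path for this socket type. */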
static int __unix_recvmsg(struct sock *sk, struct msghdr *msg,
			  size_t len, int flags)
{
	if (sk->sk_type == SOCK_DGRAM)
		return __unix_dgram_recvmsg(sk, msg, len, flags);
	else
		return __unix_stream_recvmsg(sk, msg, len, flags);
}

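/* recvmsg handler installed on sockets with a psock attached. Messages
 * redirected by a BPF verdict program are consumed from the psock
 * ingress queue via sk_msg_recvmsg(); data that arrived through the
 * regular socket path falls back to __unix_recvmsg().
 */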
static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
			    size_t len, int flags, int *addr_len)
{
	struct unix_sock *u = unix_sk(sk);
	struct sk_psock *psock;
	int copied;

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return __unix_recvmsg(sk, msg, len, flags);

	mutex_lock(&u->iolock);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		mutex_unlock(&u->iolock);
		sk_psock_put(sk, psock);
		return __unix_recvmsg(sk, msg, len, flags);
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = unix_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			mutex_unlock(&u->iolock);
			sk_psock_put(sk, psock);
			return __unix_recvmsg(sk, msg, len, flags);
		}
		copied = -EAGAIN;
	}
	mutex_unlock(&u->iolock);
	sk_psock_put(sk, psock);
	return copied;
}

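/* BPF variants of the dgram/stream protos, plus the base proto each was
 * built from so the variant is only rebuilt when the base changes.
 */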
static struct proto *unix_dgram_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_dgram_prot_lock);
static struct proto unix_dgram_bpf_prot;

static struct proto *unix_stream_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_stream_prot_lock);
static struct proto unix_stream_bpf_prot;

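/* Clone the base proto and override the callbacks sockmap needs to hook. */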
static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
}

static void unix_stream_bpf_rebuild_protos(struct proto *prot,
					   const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
	prot->unhash = sock_map_unhash;
}

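/* Double-checked locking: rebuild the BPF proto only when @ops differs
 * from the proto it was last built from. The smp_load_acquire() /
 * smp_store_release() pair orders the rebuild against lockless readers.
 */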
static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) {
		spin_lock_bh(&unix_dgram_prot_lock);
		if (likely(ops != unix_dgram_prot_saved)) {
			unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops);
			smp_store_release(&unix_dgram_prot_saved, ops);
		}
		spin_unlock_bh(&unix_dgram_prot_lock);
	}
}

static void unix_stream_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) {
		spin_lock_bh(&unix_stream_prot_lock);
		if (likely(ops != unix_stream_prot_saved)) {
			unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops);
			smp_store_release(&unix_stream_prot_saved, ops);
		}
		spin_unlock_bh(&unix_stream_prot_lock);
	}
}

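/* Install (or restore) the BPF proto on a socket; invoked through the
 * psock_update_sk_prot() callback when the socket enters or leaves a
 * sockmap.
 */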
int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	if (sk->sk_type != SOCK_DGRAM)
		return -EOPNOTSUPP;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	unix_dgram_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_dgram_bpf_prot);
	return 0;
}

int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	struct sock *sk_pair;

	/* Restore does not decrement the sk_pair reference yet because we must
	 * keep a reference to the socket until after an RCU grace period
	 * and any pending sends have completed.
	 */
	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	/* psock_update_sk_prot can be called multiple times if psock is
	 * added to multiple maps and/or slots in the same map. There is
	 * also an edge case where replacing a psock with itself can trigger
	 * an extra psock_update_sk_prot during the insert process. So it
	 * must be safe to do multiple calls. Here we need to ensure we don't
	 * increment the refcnt through sock_hold many times. There will only
	 * be a single matching destroy operation.
	 */
	if (!psock->sk_pair) {
		sk_pair = unix_peer(sk);
		sock_hold(sk_pair);
		psock->sk_pair = sk_pair;
	}

	unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_stream_bpf_prot);
	return 0;
}

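/* Build the initial BPF protos from the base af_unix protos at init time. */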
void __init unix_bpf_build_proto(void)
{
	unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto);
	unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto);
}
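
/* Usage sketch (illustrative only, not part of this file): these protos
 * are typically swapped in when an AF_UNIX socket is inserted into a
 * sockmap from user space. A minimal, hypothetical libbpf-based example;
 * the map name and sizes here are assumptions, not taken from this file:
 *
 *	int sv[2];
 *	__u32 key = 0;
 *	__u64 val;
 *	int map_fd;
 *
 *	map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "unix_map",
 *				sizeof(__u32), sizeof(__u64), 2, NULL);
 *	socketpair(AF_UNIX, SOCK_STREAM, 0, sv);
 *	val = sv[0];
 *	bpf_map_update_elem(map_fd, &key, &val, BPF_ANY);
 *
 * The map update reaches unix_stream_bpf_update_proto() above via the
 * psock_update_sk_prot() callback, installing unix_stream_bpf_prot.
 */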