1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* |
3 | * Copyright (c) 2018 Facebook |
4 | */ |
5 | #include <linux/bpf.h> |
6 | #include <linux/err.h> |
7 | #include <linux/sock_diag.h> |
8 | #include <net/sock_reuseport.h> |
9 | |
/*
 * A BPF map whose slots hold RCU-protected socket pointers.
 *
 * The embedded bpf_map is the first member, so a pointer to it has
 * the same address as the enclosing reuseport_array.
 */
struct reuseport_array {
	struct bpf_map map;
	/*
	 * One slot per map index; reuseport_array_alloc() sizes this
	 * flexible array from attr->max_entries.
	 */
	struct sock __rcu *ptrs[];
};
14 | |
15 | static struct reuseport_array *reuseport_array(struct bpf_map *map) |
16 | { |
17 | return (struct reuseport_array *)map; |
18 | } |
19 | |
20 | /* The caller must hold the reuseport_lock */ |
21 | void bpf_sk_reuseport_detach(struct sock *sk) |
22 | { |
23 | struct sock __rcu **socks; |
24 | |
25 | write_lock_bh(&sk->sk_callback_lock); |
26 | socks = sk->sk_user_data; |
27 | if (socks) { |
28 | WRITE_ONCE(sk->sk_user_data, NULL); |
29 | /* |
30 | * Do not move this NULL assignment outside of |
31 | * sk->sk_callback_lock because there is |
32 | * a race with reuseport_array_free() |
33 | * which does not hold the reuseport_lock. |
34 | */ |
35 | RCU_INIT_POINTER(*socks, NULL); |
36 | } |
37 | write_unlock_bh(&sk->sk_callback_lock); |
38 | } |
39 | |
40 | static int reuseport_array_alloc_check(union bpf_attr *attr) |
41 | { |
42 | if (attr->value_size != sizeof(u32) && |
43 | attr->value_size != sizeof(u64)) |
44 | return -EINVAL; |
45 | |
46 | return array_map_alloc_check(attr); |
47 | } |
48 | |
49 | static void *reuseport_array_lookup_elem(struct bpf_map *map, void *key) |
50 | { |
51 | struct reuseport_array *array = reuseport_array(map); |
52 | u32 index = *(u32 *)key; |
53 | |
54 | if (unlikely(index >= array->map.max_entries)) |
55 | return NULL; |
56 | |
57 | return rcu_dereference(array->ptrs[index]); |
58 | } |
59 | |
60 | /* Called from syscall only */ |
61 | static int reuseport_array_delete_elem(struct bpf_map *map, void *key) |
62 | { |
63 | struct reuseport_array *array = reuseport_array(map); |
64 | u32 index = *(u32 *)key; |
65 | struct sock *sk; |
66 | int err; |
67 | |
68 | if (index >= map->max_entries) |
69 | return -E2BIG; |
70 | |
71 | if (!rcu_access_pointer(array->ptrs[index])) |
72 | return -ENOENT; |
73 | |
74 | spin_lock_bh(&reuseport_lock); |
75 | |
76 | sk = rcu_dereference_protected(array->ptrs[index], |
77 | lockdep_is_held(&reuseport_lock)); |
78 | if (sk) { |
79 | write_lock_bh(&sk->sk_callback_lock); |
80 | WRITE_ONCE(sk->sk_user_data, NULL); |
81 | RCU_INIT_POINTER(array->ptrs[index], NULL); |
82 | write_unlock_bh(&sk->sk_callback_lock); |
83 | err = 0; |
84 | } else { |
85 | err = -ENOENT; |
86 | } |
87 | |
88 | spin_unlock_bh(&reuseport_lock); |
89 | |
90 | return err; |
91 | } |
92 | |
93 | static void reuseport_array_free(struct bpf_map *map) |
94 | { |
95 | struct reuseport_array *array = reuseport_array(map); |
96 | struct sock *sk; |
97 | u32 i; |
98 | |
99 | synchronize_rcu(); |
100 | |
101 | /* |
102 | * ops->map_*_elem() will not be able to access this |
103 | * array now. Hence, this function only races with |
104 | * bpf_sk_reuseport_detach() which was triggerred by |
105 | * close() or disconnect(). |
106 | * |
107 | * This function and bpf_sk_reuseport_detach() are |
108 | * both removing sk from "array". Who removes it |
109 | * first does not matter. |
110 | * |
111 | * The only concern here is bpf_sk_reuseport_detach() |
112 | * may access "array" which is being freed here. |
113 | * bpf_sk_reuseport_detach() access this "array" |
114 | * through sk->sk_user_data _and_ with sk->sk_callback_lock |
115 | * held which is enough because this "array" is not freed |
116 | * until all sk->sk_user_data has stopped referencing this "array". |
117 | * |
118 | * Hence, due to the above, taking "reuseport_lock" is not |
119 | * needed here. |
120 | */ |
121 | |
122 | /* |
123 | * Since reuseport_lock is not taken, sk is accessed under |
124 | * rcu_read_lock() |
125 | */ |
126 | rcu_read_lock(); |
127 | for (i = 0; i < map->max_entries; i++) { |
128 | sk = rcu_dereference(array->ptrs[i]); |
129 | if (sk) { |
130 | write_lock_bh(&sk->sk_callback_lock); |
131 | /* |
132 | * No need for WRITE_ONCE(). At this point, |
133 | * no one is reading it without taking the |
134 | * sk->sk_callback_lock. |
135 | */ |
136 | sk->sk_user_data = NULL; |
137 | write_unlock_bh(&sk->sk_callback_lock); |
138 | RCU_INIT_POINTER(array->ptrs[i], NULL); |
139 | } |
140 | } |
141 | rcu_read_unlock(); |
142 | |
143 | /* |
144 | * Once reaching here, all sk->sk_user_data is not |
145 | * referenceing this "array". "array" can be freed now. |
146 | */ |
147 | bpf_map_area_free(array); |
148 | } |
149 | |
150 | static struct bpf_map *reuseport_array_alloc(union bpf_attr *attr) |
151 | { |
152 | int err, numa_node = bpf_map_attr_numa_node(attr); |
153 | struct reuseport_array *array; |
154 | u64 cost, array_size; |
155 | |
156 | if (!capable(CAP_SYS_ADMIN)) |
157 | return ERR_PTR(-EPERM); |
158 | |
159 | array_size = sizeof(*array); |
160 | array_size += (u64)attr->max_entries * sizeof(struct sock *); |
161 | |
162 | /* make sure there is no u32 overflow later in round_up() */ |
163 | cost = array_size; |
164 | if (cost >= U32_MAX - PAGE_SIZE) |
165 | return ERR_PTR(-ENOMEM); |
166 | cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT; |
167 | |
168 | err = bpf_map_precharge_memlock(cost); |
169 | if (err) |
170 | return ERR_PTR(err); |
171 | |
172 | /* allocate all map elements and zero-initialize them */ |
173 | array = bpf_map_area_alloc(array_size, numa_node); |
174 | if (!array) |
175 | return ERR_PTR(-ENOMEM); |
176 | |
177 | /* copy mandatory map attributes */ |
178 | bpf_map_init_from_attr(&array->map, attr); |
179 | array->map.pages = cost; |
180 | |
181 | return &array->map; |
182 | } |
183 | |
184 | int bpf_fd_reuseport_array_lookup_elem(struct bpf_map *map, void *key, |
185 | void *value) |
186 | { |
187 | struct sock *sk; |
188 | int err; |
189 | |
190 | if (map->value_size != sizeof(u64)) |
191 | return -ENOSPC; |
192 | |
193 | rcu_read_lock(); |
194 | sk = reuseport_array_lookup_elem(map, key); |
195 | if (sk) { |
196 | *(u64 *)value = sock_gen_cookie(sk); |
197 | err = 0; |
198 | } else { |
199 | err = -ENOENT; |
200 | } |
201 | rcu_read_unlock(); |
202 | |
203 | return err; |
204 | } |
205 | |
206 | static int |
207 | reuseport_array_update_check(const struct reuseport_array *array, |
208 | const struct sock *nsk, |
209 | const struct sock *osk, |
210 | const struct sock_reuseport *nsk_reuse, |
211 | u32 map_flags) |
212 | { |
213 | if (osk && map_flags == BPF_NOEXIST) |
214 | return -EEXIST; |
215 | |
216 | if (!osk && map_flags == BPF_EXIST) |
217 | return -ENOENT; |
218 | |
219 | if (nsk->sk_protocol != IPPROTO_UDP && nsk->sk_protocol != IPPROTO_TCP) |
220 | return -ENOTSUPP; |
221 | |
222 | if (nsk->sk_family != AF_INET && nsk->sk_family != AF_INET6) |
223 | return -ENOTSUPP; |
224 | |
225 | if (nsk->sk_type != SOCK_STREAM && nsk->sk_type != SOCK_DGRAM) |
226 | return -ENOTSUPP; |
227 | |
228 | /* |
229 | * sk must be hashed (i.e. listening in the TCP case or binded |
230 | * in the UDP case) and |
231 | * it must also be a SO_REUSEPORT sk (i.e. reuse cannot be NULL). |
232 | * |
233 | * Also, sk will be used in bpf helper that is protected by |
234 | * rcu_read_lock(). |
235 | */ |
236 | if (!sock_flag(nsk, SOCK_RCU_FREE) || !sk_hashed(nsk) || !nsk_reuse) |
237 | return -EINVAL; |
238 | |
239 | /* READ_ONCE because the sk->sk_callback_lock may not be held here */ |
240 | if (READ_ONCE(nsk->sk_user_data)) |
241 | return -EBUSY; |
242 | |
243 | return 0; |
244 | } |
245 | |
246 | /* |
247 | * Called from syscall only. |
248 | * The "nsk" in the fd refcnt. |
249 | * The "osk" and "reuse" are protected by reuseport_lock. |
250 | */ |
251 | int bpf_fd_reuseport_array_update_elem(struct bpf_map *map, void *key, |
252 | void *value, u64 map_flags) |
253 | { |
254 | struct reuseport_array *array = reuseport_array(map); |
255 | struct sock *free_osk = NULL, *osk, *nsk; |
256 | struct sock_reuseport *reuse; |
257 | u32 index = *(u32 *)key; |
258 | struct socket *socket; |
259 | int err, fd; |
260 | |
261 | if (map_flags > BPF_EXIST) |
262 | return -EINVAL; |
263 | |
264 | if (index >= map->max_entries) |
265 | return -E2BIG; |
266 | |
267 | if (map->value_size == sizeof(u64)) { |
268 | u64 fd64 = *(u64 *)value; |
269 | |
270 | if (fd64 > S32_MAX) |
271 | return -EINVAL; |
272 | fd = fd64; |
273 | } else { |
274 | fd = *(int *)value; |
275 | } |
276 | |
277 | socket = sockfd_lookup(fd, &err); |
278 | if (!socket) |
279 | return err; |
280 | |
281 | nsk = socket->sk; |
282 | if (!nsk) { |
283 | err = -EINVAL; |
284 | goto put_file; |
285 | } |
286 | |
287 | /* Quick checks before taking reuseport_lock */ |
288 | err = reuseport_array_update_check(array, nsk, |
289 | rcu_access_pointer(array->ptrs[index]), |
290 | rcu_access_pointer(nsk->sk_reuseport_cb), |
291 | map_flags); |
292 | if (err) |
293 | goto put_file; |
294 | |
295 | spin_lock_bh(&reuseport_lock); |
296 | /* |
297 | * Some of the checks only need reuseport_lock |
298 | * but it is done under sk_callback_lock also |
299 | * for simplicity reason. |
300 | */ |
301 | write_lock_bh(&nsk->sk_callback_lock); |
302 | |
303 | osk = rcu_dereference_protected(array->ptrs[index], |
304 | lockdep_is_held(&reuseport_lock)); |
305 | reuse = rcu_dereference_protected(nsk->sk_reuseport_cb, |
306 | lockdep_is_held(&reuseport_lock)); |
307 | err = reuseport_array_update_check(array, nsk, osk, reuse, map_flags); |
308 | if (err) |
309 | goto put_file_unlock; |
310 | |
311 | /* Ensure reuse->reuseport_id is set */ |
312 | err = reuseport_get_id(reuse); |
313 | if (err < 0) |
314 | goto put_file_unlock; |
315 | |
316 | WRITE_ONCE(nsk->sk_user_data, &array->ptrs[index]); |
317 | rcu_assign_pointer(array->ptrs[index], nsk); |
318 | free_osk = osk; |
319 | err = 0; |
320 | |
321 | put_file_unlock: |
322 | write_unlock_bh(&nsk->sk_callback_lock); |
323 | |
324 | if (free_osk) { |
325 | write_lock_bh(&free_osk->sk_callback_lock); |
326 | WRITE_ONCE(free_osk->sk_user_data, NULL); |
327 | write_unlock_bh(&free_osk->sk_callback_lock); |
328 | } |
329 | |
330 | spin_unlock_bh(&reuseport_lock); |
331 | put_file: |
332 | fput(socket->file); |
333 | return err; |
334 | } |
335 | |
336 | /* Called from syscall */ |
337 | static int reuseport_array_get_next_key(struct bpf_map *map, void *key, |
338 | void *next_key) |
339 | { |
340 | struct reuseport_array *array = reuseport_array(map); |
341 | u32 index = key ? *(u32 *)key : U32_MAX; |
342 | u32 *next = (u32 *)next_key; |
343 | |
344 | if (index >= array->map.max_entries) { |
345 | *next = 0; |
346 | return 0; |
347 | } |
348 | |
349 | if (index == array->map.max_entries - 1) |
350 | return -ENOENT; |
351 | |
352 | *next = index + 1; |
353 | return 0; |
354 | } |
355 | |
356 | const struct bpf_map_ops reuseport_array_ops = { |
357 | .map_alloc_check = reuseport_array_alloc_check, |
358 | .map_alloc = reuseport_array_alloc, |
359 | .map_free = reuseport_array_free, |
360 | .map_lookup_elem = reuseport_array_lookup_elem, |
361 | .map_get_next_key = reuseport_array_get_next_key, |
362 | .map_delete_elem = reuseport_array_delete_elem, |
363 | }; |
364 | |