// SPDX-License-Identifier: GPL-2.0
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>
#include <trace/events/sock.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/ipv6_stubs.h>
#endif
#include <net/hotdata.h>

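/* A complete record carried a zero SPI: this is a non-ESP (IKE) message.
 * Charge it to the socket's receive budget and queue it on ike_queue for
 * userspace to read via recvmsg(); otherwise drop it and bump XfrmInError.
 */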
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}

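/* A complete record carried a non-zero SPI: feed it straight into the
 * XFRM input path. The IP control block that TCP saved is restored
 * first, since the receive path needs it (at least IP6CB->nhoff).
 */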
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;

	skb_reset_transport_header(skb);

	/* restore IP CB, we need at least IP6CB->nhoff */
	memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	else
#endif
		xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}

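/* strparser rcv_msg callback: a full length-prefixed record (RFC 8229)
 * has been assembled. Single-byte 0xff payloads are keepalives and are
 * dropped silently. Otherwise the 2-byte length header is stripped and
 * the record is dispatched on its 32-bit SPI field: zero means non-ESP.
 */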
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	int len = rxm->full_len - 2;
	u32 nonesp_marker;
	int err;

	/* keepalive packet? */
	if (unlikely(len == 1)) {
		u8 data;

		err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
		if (err < 0) {
			XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
			kfree_skb(skb);
			return;
		}

		if (data == 0xff) {
			kfree_skb(skb);
			return;
		}
	}

	/* drop other short messages */
	if (unlikely(len <= sizeof(nonesp_marker))) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!pskb_pull(skb, rxm->offset + 2)) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}

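/* strparser parse_msg callback: peek at the 2-byte big-endian length
 * prefix and tell strparser how long the current record is. Returns 0
 * when the length field is not fully available yet.
 */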
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 2)
		return -EINVAL;

	return len;
}

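/* recvmsg() for the userspace (IKE) side of the socket: dequeue one
 * non-ESP message from ike_queue and copy it out, flagging MSG_TRUNC
 * if the caller's buffer is too small.
 */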
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb) {
		if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
			return 0;
		return err;
	}

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	return copied;
}

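/* Queue an skb from the XFRM output path for later transmission; the
 * actual send happens from espintcp_release() once the socket lock is
 * released. The queue length is bounded by max_backlog.
 */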
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >=
	    READ_ONCE(net_hotdata.max_backlog))
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

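/* Push a pending skb (an ESP record from the XFRM output path) onto the
 * TCP socket, looping until the whole record has been sent. On error the
 * progress saved in emsg lets a later call resume from the same offset.
 */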
static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

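/* Push a pending sk_msg (a userspace IKE message) page by page with
 * MSG_SPLICE_PAGES. On failure the current scatterlist position and
 * intra-page offset are saved in emsg so a later call can resume
 * exactly where transmission stopped.
 */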
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct msghdr msghdr = {
		.msg_flags = flags | MSG_SPLICE_PAGES | MSG_MORE,
	};
	struct sk_msg *skmsg = &emsg->skmsg;
	bool more = flags & MSG_MORE;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		struct bio_vec bvec;
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg) && !more)
			msghdr.msg_flags &= ~MSG_MORE;

		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msghdr.msg_iter, ITER_SOURCE, &bvec, 1, size);
		ret = tcp_sendmsg_locked(sk, &msghdr, size);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

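/* Flush the single pending partial message, if any. The tx_running flag
 * keeps the TX work queue and direct callers from racing each other
 * under the same socket lock.
 */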
static int espintcp_push_msgs(struct sock *sk, int flags)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, flags);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, flags);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return flags & MSG_DONTWAIT ? -EAGAIN : 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}

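/* Transmit an ESP skb generated by the XFRM output path. Only one
 * partial message is tracked at a time: if the previous one could not
 * be flushed, the new skb is dropped with -ENOBUFS.
 */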
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk, 0);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk, 0);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);

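/* sendmsg() for userspace (IKE): prepend the 2-byte length header, copy
 * the payload into an sk_msg and push it out. Only MSG_DONTWAIT is
 * supported; a partially sent message is kept and flushed later.
 */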
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	if (err < 0) {
		if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
			err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, ITER_SOURCE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
	/* this message could be partially sent, keep it */

	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}

static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);

static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	trace_sk_data_ready(sk);

	strp_data_ready(&ctx->strp);
}

static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk, 0);
	release_sock(sk);
}

static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}

static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	ctx->saved_destruct(sk);
	kfree(ctx);
}

bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);

static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);

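/* ULP init: allocate the espintcp context, attach strparser, swap in the
 * espintcp proto/ops for the socket's family and hook data_ready,
 * write_space and destruct. The IPv6 variants are built lazily from the
 * first IPv6 socket's prot/ops, under tcpv6_prot_mutex.
 */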
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;
	sk->sk_use_task_frag = false;

	return 0;

free:
	kfree(ctx);
	return err;
}

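/* release_cb: flush skbs that the XFRM output path queued with
 * espintcp_queue_out() while the socket was owned by the user.
 */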
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}

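/* Tear down the espintcp state: stop strparser, restore tcp_prot, purge
 * both queues and free any partial message before handing the socket to
 * tcp_close().
 */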
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}

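/* poll() wrapper: in addition to the usual TCP events, report
 * readability whenever non-ESP (IKE) messages are waiting in ike_queue.
 */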
static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}

static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops)
{
	memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
	memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
	espintcp_prot->sendmsg = espintcp_sendmsg;
	espintcp_prot->recvmsg = espintcp_recvmsg;
	espintcp_prot->close = espintcp_close;
	espintcp_prot->release_cb = espintcp_release;
	espintcp_ops->poll = espintcp_poll;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};

void __init espintcp_init(void)
{
	build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);

	tcp_register_ulp(&espintcp_ulp);
}