// SPDX-License-Identifier: GPL-2.0
2 | |
#define _GNU_SOURCE

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>

#include <sys/ioctl.h>
#include <sys/random.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>
#include <linux/sockios.h>
31 | |
/* Fallback protocol/level numbers for libc headers that lack MPTCP. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family used by both peers; "-6" switches it to AF_INET6. */
static int pf = AF_INET;
/* Protocol of the connecting (client) socket; changed with "-t". */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol of the listening (server) socket; changed with "-r". */
static int proto_rx = IPPROTO_MPTCP;
42 | |
/* Report @msg together with the current errno text, then exit with status 1. */
static void die_perror(const char *msg)
{
	perror(msg);
	exit(EXIT_FAILURE);
}
48 | |
/* Print the command-line synopsis to stderr and exit with status @r. */
static void die_usage(int r)
{
	static const char usage[] =
		"Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

	fputs(usage, stderr);
	exit(r);
}
54 | |
/*
 * Print a printf-style message plus a trailing newline to stderr and
 * exit with status 1.
 *
 * The format attribute makes the compiler check every caller's arguments
 * against @fmt (-Wformat); noreturn tells it control never comes back.
 */
static void __attribute__((format(printf, 1, 2), noreturn))
xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
65 | |
/*
 * Translate a getaddrinfo() error code into a message. EAI_SYSTEM means
 * the actual cause is stored in errno, so strerror() is used for it.
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
73 | |
/*
 * getaddrinfo() that never fails: on error, print
 * "Fatal: getaddrinfo(node:service): reason" to stderr and exit(1).
 * NULL @node / @service are printed as empty strings.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(rc));
	exit(1);
}
88 | |
/*
 * Create a listening stream socket of protocol proto_rx, bound to the
 * numeric address @listenaddr and @port in address family pf.
 *
 * Every address returned by getaddrinfo() is tried until one binds.
 * SO_REUSEADDR failures are only reported. The process exits if no
 * address could be bound or listen() fails.
 *
 * Returns the listening socket descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		/* the lookup uses TCP; the socket itself is created with proto_rx */
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
134 | |
/*
 * Connect a stream socket of protocol @proto to @remoteaddr:@port in
 * address family pf.
 *
 * Every address returned by getaddrinfo() is tried in turn. A socket whose
 * connect() fails is closed before moving on. If at least one connect()
 * was attempted and all failed, the process exits via
 * die_perror("connect") with the last connect errno. If no socket could be
 * created at all, it exits with "could not create connect socket".
 *
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int connect_err = 0;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		/* remember why it failed, then try the next candidate */
		connect_err = errno;
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		if (connect_err) {
			errno = connect_err;
			die_perror("connect");
		}
		xerror("could not create connect socket");
	}

	return sock;
}
165 | |
/*
 * Map a protocol name given on the command line ("tcp" or "mptcp",
 * case-insensitive) to its IPPROTO_* number. Any other string prints the
 * usage and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} protos[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(protos) / sizeof(protos[0]); idx++) {
		if (!strcasecmp(s, protos[idx].name))
			return protos[idx].proto;
	}

	die_usage(1);
	return 0;
}
176 | |
177 | static void parse_opts(int argc, char **argv) |
178 | { |
179 | int c; |
180 | |
181 | while ((c = getopt(argc, argv, "h6t:r:" )) != -1) { |
182 | switch (c) { |
183 | case 'h': |
184 | die_usage(r: 0); |
185 | break; |
186 | case '6': |
187 | pf = AF_INET6; |
188 | break; |
189 | case 't': |
190 | proto_tx = protostr_to_num(s: optarg); |
191 | break; |
192 | case 'r': |
193 | proto_rx = protostr_to_num(s: optarg); |
194 | break; |
195 | default: |
196 | die_usage(r: 1); |
197 | break; |
198 | } |
199 | } |
200 | } |
201 | |
/*
 * Wait until the send queue of @fd is empty, i.e. until the peer has
 * acked everything written so far.
 *
 * Polls TIOCOUTQ (bytes queued) and SIOCOUTQNSD (queued but not yet sent)
 * about once per millisecond, for at most @timeout iterations, so roughly
 * @timeout ms.
 *
 * Exits if:
 *  - an ioctl fails;
 *  - more than @total bytes are ever queued;
 *  - data is still queued when the timeout expires.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		/* unsent data is a subset of the whole send queue */
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
234 | |
/*
 * Client side of the test. It sends data over @fd and is driven by the
 * server through the unix socket @unixfd:
 *
 *  - "xmit": announce, then send over @fd, a chunk of 128..4094 random
 *    uppercase letters.
 *  - "huge": announce a random total of 1 MiB up to (not including) 17 MiB,
 *    wait for the first chunk to be acked, then stream that many bytes.
 *  - "shut": wait until everything is acked, send one final byte, close
 *    @fd and report "closed".
 *
 * @unixfd is closed on return.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* a short read would leave stale bytes in buf2 for strncmp() */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
311 | |
/*
 * Walk the control messages of @msgh and copy the payload of the first
 * IPPROTO_TCP / TCP_CM_INQ cmsg into *@inqv. Exits if no such cmsg exists.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr = CMSG_FIRSTHDR(msgh);

	while (hdr != NULL) {
		if (hdr->cmsg_level == IPPROTO_TCP &&
		    hdr->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
			return;
		}
		hdr = CMSG_NXTHDR(msgh, hdr);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
325 | |
/*
 * recvmsg() wrapper used by process_one_client(): receive at most @len
 * bytes into the buffer of @msg's single iovec.
 *
 * msg_controllen is reset to @ctrl_len before every call. recvmsg()
 * overwrites it with the amount of control data actually returned, so
 * without the reset each later call would offer only the previous cmsg's
 * length as control buffer.
 *
 * Exits on error. Returns the number of bytes received (0 on EOF).
 */
static ssize_t xrecvmsg(int fd, struct msghdr *msg, size_t len,
			size_t ctrl_len)
{
	ssize_t ret;

	msg->msg_iov->iov_len = len;
	msg->msg_controllen = ctrl_len;

	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	return ret;
}

/*
 * Server side of the test on the accepted socket @fd (TCP_INQ enabled by
 * the caller), coordinated with the client through @unixfd. It checks the
 * TCP_CM_INQ value reported with each recvmsg():
 *
 *  - "xmit": wait until the announced small chunk is fully queued (FIONREAD).
 *    Read 1 byte and expect inq == len - 1, then read the rest and expect
 *    inq == 0.
 *  - "huge": drain the announced amount; inq must never exceed the bytes
 *    still outstanding.
 *  - "shut": after the client reports "closed", read its final byte and
 *    expect inq == 1 (FIN); the next read must return EOF with inq == 1.
 *
 * Closes @fd on return.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll every 1ms until the whole chunk is in the receive queue */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg(fd, &msg, 1, sizeof(msg_buf));
	assert(ret == 1);

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	ret = xrecvmsg(fd, &msg, sizeof(buf), sizeof(msg_buf));

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t remaining;

		ret = xrecvmsg(fd, &msg, sizeof(buf), sizeof(msg_buf));
		/* without this, EOF would spin here until the alarm fires */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);

		tot += ret;
		/* guard the subtraction below against size_t wrap-around */
		if ((size_t)tot > expect_len)
			xerror("received %zd bytes, but only %zu expected",
			       tot, expect_len);
		remaining = expect_len - (size_t)tot;

		get_tcp_inq(&msg, &tcp_inq);

		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, remaining, expect_len);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	ret = xrecvmsg(fd, &msg, 1, sizeof(msg_buf));
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	ret = xrecvmsg(fd, &msg, 1, sizeof(msg_buf));

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
458 | |
/*
 * accept() one connection on the listening socket @s, discarding the peer
 * address. Exits on failure. Returns the connected descriptor.
 */
static int xaccept(int s)
{
	/* the address-length argument is a pointer: pass NULL, not 0 */
	int fd = accept(s, NULL, NULL);

	if (fd < 0)
		die_perror("accept");

	return fd;
}
468 | |
469 | static int server(int unixfd) |
470 | { |
471 | int fd = -1, r, on = 1; |
472 | |
473 | switch (pf) { |
474 | case AF_INET: |
475 | fd = sock_listen_mptcp(listenaddr: "127.0.0.1" , port: "15432" ); |
476 | break; |
477 | case AF_INET6: |
478 | fd = sock_listen_mptcp(listenaddr: "::1" , port: "15432" ); |
479 | break; |
480 | default: |
481 | xerror(fmt: "Unknown pf %d\n" , pf); |
482 | break; |
483 | } |
484 | |
485 | r = write(unixfd, "conn" , 4); |
486 | assert(r == 4); |
487 | |
488 | alarm(15); |
489 | r = xaccept(s: fd); |
490 | |
491 | if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on))) |
492 | die_perror(msg: "setsockopt" ); |
493 | |
494 | process_one_client(fd: r, unixfd); |
495 | |
496 | return 0; |
497 | } |
498 | |
499 | static int client(int unixfd) |
500 | { |
501 | int fd = -1; |
502 | |
503 | alarm(15); |
504 | |
505 | switch (pf) { |
506 | case AF_INET: |
507 | fd = sock_connect_mptcp(remoteaddr: "127.0.0.1" , port: "15432" , proto: proto_tx); |
508 | break; |
509 | case AF_INET6: |
510 | fd = sock_connect_mptcp(remoteaddr: "::1" , port: "15432" , proto: proto_tx); |
511 | break; |
512 | default: |
513 | xerror(fmt: "Unknown pf %d\n" , pf); |
514 | } |
515 | |
516 | connect_one_server(fd, unixfd); |
517 | |
518 | return 0; |
519 | } |
520 | |
/* Seed rand() with a value read from getrandom(); exit(1) if that fails. */
static void init_rng(void)
{
	unsigned int seed;

	if (getrandom(&seed, sizeof(seed), 0) < 0)
		die_perror("getrandom");

	srand(seed);
}
532 | |
/*
 * fork() that exits on failure. The child reseeds rand() via init_rng()
 * so it does not reuse the parent's sequence. Returns fork()'s result.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid < 0)
		die_perror("fork");

	if (pid == 0)
		init_rng();

	return pid;
}
544 | |
/*
 * Interpret a waitpid() status for the child described by @what.
 *
 * - Normal exit: returns its exit status, printing a note to stderr when
 *   it is non-zero.
 * - Killed or stopped by a signal: reports it via xerror(), which exits.
 * - Any other status: returns 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
560 | |
/*
 * Test driver.
 *  - Parse options and create an AF_UNIX datagram socketpair for
 *    synchronisation.
 *  - Fork the server child, and start the client child only after the
 *    server has sent "conn" (i.e. it is listening).
 *  - Reap both children.
 *
 * Returns the server's non-zero status as reported by rcheck(), otherwise
 * the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket: it sends "conn" once listening */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));
	assert(memcmp(ready, "conn", sizeof(ready)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
600 | |