1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
5#include <assert.h>
6#include <errno.h>
7#include <fcntl.h>
8#include <limits.h>
9#include <string.h>
10#include <stdarg.h>
11#include <stdbool.h>
12#include <stdint.h>
13#include <inttypes.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <strings.h>
17#include <unistd.h>
18#include <time.h>
19
20#include <sys/ioctl.h>
21#include <sys/random.h>
22#include <sys/socket.h>
23#include <sys/types.h>
24#include <sys/wait.h>
25
26#include <netdb.h>
27#include <netinet/in.h>
28
29#include <linux/tcp.h>
30#include <linux/sockios.h>
31
/* Supply the MPTCP constants when the included headers do not define them. */
#if !defined(IPPROTO_MPTCP)
#define IPPROTO_MPTCP 262
#endif
#if !defined(SOL_MPTCP)
#define SOL_MPTCP 284
#endif

/* Address family used for name resolution; '-6' switches it to AF_INET6. */
static int pf = AF_INET;
/* Protocol handed to socket() on the connecting side ('-t'). */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol handed to socket() on the listening side ('-r'). */
static int proto_rx = IPPROTO_MPTCP;
42
/*
 * Report a failed call on stderr via perror(), using @msg as the prefix,
 * and terminate the process with exit status 1.  Does not return.
 */
static void die_perror(const char *msg)
{
	perror(msg);

	exit(1);
}
48
/* Write the command-line synopsis to stderr and exit with status @r. */
static void die_usage(int r)
{
	fputs("Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n", stderr);

	exit(r);
}
54
/*
 * Format a printf-style message to stderr, append a newline and terminate
 * the process with exit status 1.  Does not return.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	fputs("\n", stderr);
	exit(1);
}
65
/*
 * Translate a getaddrinfo() error code into text.  EAI_SYSTEM means the
 * detail is in errno, so strerror(errno) is returned for it; every other
 * code is passed to gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
73
/*
 * getaddrinfo() wrapper that never returns on failure: on error it prints
 * the node/service pair (NULL shown as an empty string) together with the
 * reason on stderr and exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *reason;
	int rc;

	rc = getaddrinfo(node, service, hints, res);
	if (rc == 0)
		return;

	/* resolve the text first, while errno is still that of the lookup */
	reason = getxinfo_strerr(rc);

	if (!node)
		node = "";
	if (!service)
		service = "";

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n", node, service, reason);
	exit(1);
}
88
/*
 * Create a listening stream socket bound to @listenaddr:@port.
 *
 * @listenaddr must be a numeric address (AI_NUMERICHOST).  The address
 * family comes from the global pf, the socket protocol from proto_rx.
 * Every resolved address is tried in turn until bind() succeeds.
 *
 * Returns the listening descriptor; exits the process if no socket could
 * be bound or listen() fails.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = IPPROTO_TCP,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
	};
	struct addrinfo *results, *cur;
	int reuse = 1;
	int lfd = -1;

	xgetaddrinfo(listenaddr, port, &hints, &results);

	for (cur = results; cur != NULL; cur = cur->ai_next) {
		lfd = socket(cur->ai_family, cur->ai_socktype, proto_rx);
		if (lfd < 0)
			continue;

		/* failure to set SO_REUSEADDR is reported but not fatal */
		if (setsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &reuse,
			       sizeof(reuse)) == -1)
			perror("setsockopt");

		if (bind(lfd, cur->ai_addr, cur->ai_addrlen) == 0)
			break; /* bound, stop searching */

		perror("bind");
		close(lfd);
		lfd = -1;
	}

	freeaddrinfo(results);

	if (lfd < 0)
		xerror("could not create listen socket");

	if (listen(lfd, 20) != 0)
		die_perror("listen");

	return lfd;
}
134
/*
 * Open a stream socket using protocol @proto and connect it to
 * @remoteaddr:@port.  The address family comes from the global pf.
 *
 * Addresses for which socket() fails are skipped; a connect() failure is
 * fatal.  Returns the connected descriptor, or exits the process if no
 * socket could be created.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = IPPROTO_TCP,
	};
	struct addrinfo *results, *cur;
	int cfd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &results);

	for (cur = results; cur != NULL; cur = cur->ai_next) {
		cfd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (cfd >= 0) {
			if (connect(cfd, cur->ai_addr, cur->ai_addrlen) != 0)
				die_perror("connect");
			break; /* connected */
		}
	}

	if (cfd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return cfd;
}
165
/*
 * Convert a protocol name ("tcp" or "mptcp", case-insensitive) into the
 * matching IPPROTO_* value.  Any other string prints the usage text and
 * exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(known) / sizeof(known[0]); idx++) {
		if (strcasecmp(s, known[idx].name) == 0)
			return known[idx].proto;
	}

	die_usage(1);
	return 0;
}
176
177static void parse_opts(int argc, char **argv)
178{
179 int c;
180
181 while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
182 switch (c) {
183 case 'h':
184 die_usage(r: 0);
185 break;
186 case '6':
187 pf = AF_INET6;
188 break;
189 case 't':
190 proto_tx = protostr_to_num(s: optarg);
191 break;
192 case 'r':
193 proto_rx = protostr_to_num(s: optarg);
194 break;
195 default:
196 die_usage(r: 1);
197 break;
198 }
199 }
200}
201
/*
 * Wait until the transmit queue of @fd is empty.
 *
 * Polls TIOCOUTQ and SIOCOUTQNSD up to @timeout times, sleeping 1ms
 * between polls (i.e. roughly @timeout milliseconds in total).  @total is
 * the number of bytes written so far; a queue larger than that is treated
 * as a fatal error.  Exits the process if data is still queued after the
 * last poll or if either ioctl fails.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* the cast also turns a negative count into a failure */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
234
/*
 * Client half of the test.  @fd is the connected stream socket, @unixfd
 * the control channel shared with the server process.
 *
 * Sequence:
 *   1. wait for "xmit", announce a random length (128..4094) on @unixfd
 *      and send that many bytes on @fd;
 *   2. wait for "huge", announce a random total between 1MiB and 17MiB,
 *      wait for the first chunk to leave the tx queue, then send the total;
 *   3. wait for "shut", wait for the tx queue to empty, send one more
 *      byte, close @fd and report "closed" on @unixfd.
 *
 * Any protocol violation aborts via assert() or exits the process.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* a short read would leave stale bytes in buf2, so check the size */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
311
/*
 * Extract the TCP_CM_INQ value from the control messages of @msgh and
 * store it in @inqv.  Exits the process if no such control message is
 * present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr = CMSG_FIRSTHDR(msgh);

	while (hdr != NULL) {
		if (hdr->cmsg_level == IPPROTO_TCP &&
		    hdr->cmsg_type == TCP_CM_INQ) {
			/* copy out rather than dereference the cmsg payload */
			memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
			return;
		}

		hdr = CMSG_NXTHDR(msgh, hdr);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
325
/*
 * Server half of the test.  @fd is the accepted stream socket (with
 * TCP_INQ already enabled by the caller), @unixfd the control channel
 * shared with the client process.
 *
 * The function drives the client through the "xmit", "huge" and "shut"
 * phases and, after every recvmsg(), checks that the TCP_CM_INQ control
 * message agrees with the number of bytes the client announced.  Any
 * mismatch aborts via assert() or exits the process.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll every 1ms until the whole first chunk sits in the rx queue */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		size_t remaining;

		iov.iov_len = sizeof(buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* the queue can never hold more than what is still owed */
		remaining = expect_len - (size_t)tot;
		if (tcp_inq > remaining)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, remaining, expect_len);

		assert(tcp_inq <= remaining);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
458
/*
 * accept() wrapper: returns the new connection's descriptor, or exits the
 * process if accept() fails.  The peer address is not requested.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
468
469static int server(int unixfd)
470{
471 int fd = -1, r, on = 1;
472
473 switch (pf) {
474 case AF_INET:
475 fd = sock_listen_mptcp(listenaddr: "127.0.0.1", port: "15432");
476 break;
477 case AF_INET6:
478 fd = sock_listen_mptcp(listenaddr: "::1", port: "15432");
479 break;
480 default:
481 xerror(fmt: "Unknown pf %d\n", pf);
482 break;
483 }
484
485 r = write(unixfd, "conn", 4);
486 assert(r == 4);
487
488 alarm(15);
489 r = xaccept(s: fd);
490
491 if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
492 die_perror(msg: "setsockopt");
493
494 process_one_client(fd: r, unixfd);
495
496 return 0;
497}
498
499static int client(int unixfd)
500{
501 int fd = -1;
502
503 alarm(15);
504
505 switch (pf) {
506 case AF_INET:
507 fd = sock_connect_mptcp(remoteaddr: "127.0.0.1", port: "15432", proto: proto_tx);
508 break;
509 case AF_INET6:
510 fd = sock_connect_mptcp(remoteaddr: "::1", port: "15432", proto: proto_tx);
511 break;
512 default:
513 xerror(fmt: "Unknown pf %d\n", pf);
514 }
515
516 connect_one_server(fd, unixfd);
517
518 return 0;
519}
520
/*
 * Seed rand() with an unsigned int obtained from getrandom().  Exits the
 * process with status 1 if getrandom() fails.
 */
static void init_rng(void)
{
	unsigned int seed;
	ssize_t got;

	got = getrandom(&seed, sizeof(seed), 0);
	if (got == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(seed);
}
532
/*
 * fork() wrapper: exits the process if fork() fails and re-seeds the
 * random number generator in the child.  Returns what fork() returned.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid < 0)
		die_perror("fork");

	/* child: seed rand() independently of the parent */
	if (pid == 0)
		init_rng();

	return pid;
}
544
/*
 * Turn a waitpid() status into an exit code for @what ("server"/"client").
 *
 * A normal exit yields the child's exit status (a non-zero status is also
 * logged to stderr).  A child killed or stopped by a signal is fatal.
 * Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
560
/*
 * Test driver: forks a server and a client process connected by an
 * AF_UNIX datagram socketpair used as control channel, waits for both and
 * returns the server's exit code if non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket: it sends a 4 byte message */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
600

/* Source: linux/tools/testing/selftests/net/mptcp/mptcp_inq.c */