1 | // SPDX-License-Identifier: GPL-2.0 |
2 | |
3 | #define _GNU_SOURCE |
4 | |
5 | #include <arpa/inet.h> |
6 | #include <errno.h> |
7 | #include <error.h> |
8 | #include <fcntl.h> |
9 | #include <limits.h> |
10 | #include <linux/filter.h> |
11 | #include <linux/bpf.h> |
12 | #include <linux/if_packet.h> |
13 | #include <linux/if_vlan.h> |
14 | #include <linux/virtio_net.h> |
15 | #include <net/if.h> |
16 | #include <net/ethernet.h> |
17 | #include <netinet/ip.h> |
18 | #include <netinet/udp.h> |
19 | #include <poll.h> |
20 | #include <sched.h> |
21 | #include <stdbool.h> |
22 | #include <stdint.h> |
23 | #include <stdio.h> |
24 | #include <stdlib.h> |
25 | #include <string.h> |
26 | #include <sys/mman.h> |
27 | #include <sys/socket.h> |
28 | #include <sys/stat.h> |
29 | #include <sys/types.h> |
30 | #include <unistd.h> |
31 | |
32 | #include "psock_lib.h" |
33 | |
34 | static bool cfg_use_bind; |
35 | static bool cfg_use_csum_off; |
36 | static bool cfg_use_csum_off_bad; |
37 | static bool cfg_use_dgram; |
38 | static bool cfg_use_gso; |
39 | static bool cfg_use_qdisc_bypass; |
40 | static bool cfg_use_vlan; |
41 | static bool cfg_use_vnet; |
42 | |
43 | static char *cfg_ifname = "lo" ; |
44 | static int cfg_mtu = 1500; |
45 | static int cfg_payload_len = DATA_LEN; |
46 | static int cfg_truncate_len = INT_MAX; |
47 | static uint16_t cfg_port = 8000; |
48 | |
49 | /* test sending up to max mtu + 1 */ |
50 | #define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1) |
51 | |
52 | static char tbuf[TEST_SZ], rbuf[TEST_SZ]; |
53 | |
/*
 * Add up @num_u16 consecutive 16-bit words beginning at @start.
 *
 * Words are summed as stored (no byte swapping) into a wide accumulator;
 * folding the carries down to 16 bits is left to the caller.
 * A @num_u16 of zero or less yields 0.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
	const uint16_t *word = start;
	unsigned long total = 0;
	int remaining;

	for (remaining = num_u16; remaining > 0; remaining--)
		total += *word++;

	return total;
}
64 | |
/*
 * Compute a 16-bit one's-complement checksum.
 *
 * @start/@num_u16: the 16-bit words to cover
 * @sum:            initial value added to the running total
 *
 * Returns the bitwise complement of the folded sum, truncated to 16 bits.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
			      unsigned long sum)
{
	unsigned long acc = sum + add_csum_hword(start, num_u16);

	/* fold the carries back into the low 16 bits until none remain */
	while ((acc >> 16) != 0) {
		unsigned long carry = acc >> 16;

		acc = (acc & 0xffff) + carry;
	}

	return (uint16_t)~acc;
}
75 | |
76 | static int (void *) |
77 | { |
78 | struct virtio_net_hdr *vh = header; |
79 | |
80 | vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr); |
81 | |
82 | if (cfg_use_csum_off) { |
83 | vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM; |
84 | vh->csum_start = ETH_HLEN + sizeof(struct iphdr); |
85 | vh->csum_offset = __builtin_offsetof(struct udphdr, check); |
86 | |
87 | /* position check field exactly one byte beyond end of packet */ |
88 | if (cfg_use_csum_off_bad) |
89 | vh->csum_start += sizeof(struct udphdr) + cfg_payload_len - |
90 | vh->csum_offset - 1; |
91 | } |
92 | |
93 | if (cfg_use_gso) { |
94 | vh->gso_type = VIRTIO_NET_HDR_GSO_UDP; |
95 | vh->gso_size = cfg_mtu - sizeof(struct iphdr); |
96 | } |
97 | |
98 | return sizeof(*vh); |
99 | } |
100 | |
101 | static int (void *) |
102 | { |
103 | struct ethhdr *eth = header; |
104 | |
105 | if (cfg_use_vlan) { |
106 | uint16_t *tag = header + ETH_HLEN; |
107 | |
108 | eth->h_proto = htons(ETH_P_8021Q); |
109 | tag[1] = htons(ETH_P_IP); |
110 | return ETH_HLEN + 4; |
111 | } |
112 | |
113 | eth->h_proto = htons(ETH_P_IP); |
114 | return ETH_HLEN; |
115 | } |
116 | |
/*
 * Write a 20-byte IPv4 header for a UDP datagram at @header.
 *
 * @payload_len is the UDP payload size; tot_len is derived from it.
 * The packet goes from 172.17.0.2 to 172.17.0.1 with a ttl of 8.
 *
 * Returns the header length in bytes (ihl * 4).
 */
static int build_ipv4_header(void *header, int payload_len)
{
	struct iphdr *iph = header;

	iph->ihl = 5;
	iph->version = 4;
	iph->ttl = 8;
	iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
	iph->id = htons(1337);
	iph->protocol = IPPROTO_UDP;
	iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
	iph->daddr = htonl((172 << 24) | (17 << 16) | 1);

	/* the checksum covers the header itself, so its own field must be 0 */
	iph->check = 0;
	/* ihl counts 32-bit words, i.e. twice as many 16-bit words */
	iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

	return iph->ihl << 2;
}
133 | |
134 | static int (void *, int payload_len) |
135 | { |
136 | const int alen = sizeof(uint32_t); |
137 | struct udphdr *udph = header; |
138 | int len = sizeof(*udph) + payload_len; |
139 | |
140 | udph->source = htons(9); |
141 | udph->dest = htons(cfg_port); |
142 | udph->len = htons(len); |
143 | |
144 | if (cfg_use_csum_off) |
145 | udph->check = build_ip_csum(start: header - (2 * alen), num_u16: alen, |
146 | htons(IPPROTO_UDP) + udph->len); |
147 | else |
148 | udph->check = 0; |
149 | |
150 | return sizeof(*udph); |
151 | } |
152 | |
153 | static int build_packet(int payload_len) |
154 | { |
155 | int off = 0; |
156 | |
157 | off += build_vnet_header(header: tbuf); |
158 | off += build_eth_header(header: tbuf + off); |
159 | off += build_ipv4_header(header: tbuf + off, payload_len); |
160 | off += build_udp_header(header: tbuf + off, payload_len); |
161 | |
162 | if (off + payload_len > sizeof(tbuf)) |
163 | error(1, 0, "payload length exceeds max" ); |
164 | |
165 | memset(tbuf + off, DATA_CHAR, payload_len); |
166 | |
167 | return off + payload_len; |
168 | } |
169 | |
170 | static void do_bind(int fd) |
171 | { |
172 | struct sockaddr_ll laddr = {0}; |
173 | |
174 | laddr.sll_family = AF_PACKET; |
175 | laddr.sll_protocol = htons(ETH_P_IP); |
176 | laddr.sll_ifindex = if_nametoindex(cfg_ifname); |
177 | if (!laddr.sll_ifindex) |
178 | error(1, errno, "if_nametoindex" ); |
179 | |
180 | if (bind(fd, (void *)&laddr, sizeof(laddr))) |
181 | error(1, errno, "bind" ); |
182 | } |
183 | |
184 | static void do_send(int fd, char *buf, int len) |
185 | { |
186 | int ret; |
187 | |
188 | if (!cfg_use_vnet) { |
189 | buf += sizeof(struct virtio_net_hdr); |
190 | len -= sizeof(struct virtio_net_hdr); |
191 | } |
192 | if (cfg_use_dgram) { |
193 | buf += ETH_HLEN; |
194 | len -= ETH_HLEN; |
195 | } |
196 | |
197 | if (cfg_use_bind) { |
198 | ret = write(fd, buf, len); |
199 | } else { |
200 | struct sockaddr_ll laddr = {0}; |
201 | |
202 | laddr.sll_protocol = htons(ETH_P_IP); |
203 | laddr.sll_ifindex = if_nametoindex(cfg_ifname); |
204 | if (!laddr.sll_ifindex) |
205 | error(1, errno, "if_nametoindex" ); |
206 | |
207 | ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr)); |
208 | } |
209 | |
210 | if (ret == -1) |
211 | error(1, errno, "write" ); |
212 | if (ret != len) |
213 | error(1, 0, "write: %u %u" , ret, len); |
214 | |
215 | fprintf(stderr, "tx: %u\n" , ret); |
216 | } |
217 | |
218 | static int do_tx(void) |
219 | { |
220 | const int one = 1; |
221 | int fd, len; |
222 | |
223 | fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0); |
224 | if (fd == -1) |
225 | error(1, errno, "socket t" ); |
226 | |
227 | if (cfg_use_bind) |
228 | do_bind(fd); |
229 | |
230 | if (cfg_use_qdisc_bypass && |
231 | setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one))) |
232 | error(1, errno, "setsockopt qdisc bypass" ); |
233 | |
234 | if (cfg_use_vnet && |
235 | setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one))) |
236 | error(1, errno, "setsockopt vnet" ); |
237 | |
238 | len = build_packet(payload_len: cfg_payload_len); |
239 | |
240 | if (cfg_truncate_len < len) |
241 | len = cfg_truncate_len; |
242 | |
243 | do_send(fd, buf: tbuf, len); |
244 | |
245 | if (close(fd)) |
246 | error(1, errno, "close t" ); |
247 | |
248 | return len; |
249 | } |
250 | |
251 | static int setup_rx(void) |
252 | { |
253 | struct timeval tv = { .tv_usec = 100 * 1000 }; |
254 | struct sockaddr_in raddr = {0}; |
255 | int fd; |
256 | |
257 | fd = socket(PF_INET, SOCK_DGRAM, 0); |
258 | if (fd == -1) |
259 | error(1, errno, "socket r" ); |
260 | |
261 | if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv))) |
262 | error(1, errno, "setsockopt rcv timeout" ); |
263 | |
264 | raddr.sin_family = AF_INET; |
265 | raddr.sin_port = htons(cfg_port); |
266 | raddr.sin_addr.s_addr = htonl(INADDR_ANY); |
267 | |
268 | if (bind(fd, (void *)&raddr, sizeof(raddr))) |
269 | error(1, errno, "bind r" ); |
270 | |
271 | return fd; |
272 | } |
273 | |
274 | static void do_rx(int fd, int expected_len, char *expected) |
275 | { |
276 | int ret; |
277 | |
278 | ret = recv(fd, rbuf, sizeof(rbuf), 0); |
279 | if (ret == -1) |
280 | error(1, errno, "recv" ); |
281 | if (ret != expected_len) |
282 | error(1, 0, "recv: %u != %u" , ret, expected_len); |
283 | |
284 | if (memcmp(p: rbuf, q: expected, size: ret)) |
285 | error(1, 0, "recv: data mismatch" ); |
286 | |
287 | fprintf(stderr, "rx: %u\n" , ret); |
288 | } |
289 | |
/*
 * Create the raw packet socket used to observe the transmitted frame.
 *
 * The socket gets a 100 ms receive timeout, has the filter from
 * pair_udp_setfilter() attached and is bound with do_bind().
 * Returns its file descriptor; exits on failure.
 */
static int setup_sniffer(void)
{
	const struct timeval timeout = { .tv_usec = 100 * 1000 };
	int sock = socket(PF_PACKET, SOCK_RAW, 0);

	if (sock == -1)
		error(1, errno, "socket p");

	if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout,
		       sizeof(timeout)) != 0)
		error(1, errno, "setsockopt rcv timeout");

	pair_udp_setfilter(sock);
	do_bind(sock);

	return sock;
}
307 | |
308 | static void parse_opts(int argc, char **argv) |
309 | { |
310 | int c; |
311 | |
312 | while ((c = getopt(argc, argv, "bcCdgl:qt:vV" )) != -1) { |
313 | switch (c) { |
314 | case 'b': |
315 | cfg_use_bind = true; |
316 | break; |
317 | case 'c': |
318 | cfg_use_csum_off = true; |
319 | break; |
320 | case 'C': |
321 | cfg_use_csum_off_bad = true; |
322 | break; |
323 | case 'd': |
324 | cfg_use_dgram = true; |
325 | break; |
326 | case 'g': |
327 | cfg_use_gso = true; |
328 | break; |
329 | case 'l': |
330 | cfg_payload_len = strtoul(optarg, NULL, 0); |
331 | break; |
332 | case 'q': |
333 | cfg_use_qdisc_bypass = true; |
334 | break; |
335 | case 't': |
336 | cfg_truncate_len = strtoul(optarg, NULL, 0); |
337 | break; |
338 | case 'v': |
339 | cfg_use_vnet = true; |
340 | break; |
341 | case 'V': |
342 | cfg_use_vlan = true; |
343 | break; |
344 | default: |
345 | error(1, 0, "%s: parse error" , argv[0]); |
346 | } |
347 | } |
348 | |
349 | if (cfg_use_vlan && cfg_use_dgram) |
350 | error(1, 0, "option vlan (-V) conflicts with dgram (-d)" ); |
351 | |
352 | if (cfg_use_csum_off && !cfg_use_vnet) |
353 | error(1, 0, "option csum offload (-c) requires vnet (-v)" ); |
354 | |
355 | if (cfg_use_csum_off_bad && !cfg_use_csum_off) |
356 | error(1, 0, "option csum bad (-C) requires csum offload (-c)" ); |
357 | |
358 | if (cfg_use_gso && !cfg_use_csum_off) |
359 | error(1, 0, "option gso (-g) requires csum offload (-c)" ); |
360 | } |
361 | |
362 | static void run_test(void) |
363 | { |
364 | int fdr, fds, total_len; |
365 | |
366 | fdr = setup_rx(); |
367 | fds = setup_sniffer(); |
368 | |
369 | total_len = do_tx(); |
370 | |
371 | /* BPF filter accepts only this length, vlan changes MAC */ |
372 | if (cfg_payload_len == DATA_LEN && !cfg_use_vlan) |
373 | do_rx(fd: fds, expected_len: total_len - sizeof(struct virtio_net_hdr), |
374 | expected: tbuf + sizeof(struct virtio_net_hdr)); |
375 | |
376 | do_rx(fd: fdr, expected_len: cfg_payload_len, expected: tbuf + total_len - cfg_payload_len); |
377 | |
378 | if (close(fds)) |
379 | error(1, errno, "close s" ); |
380 | if (close(fdr)) |
381 | error(1, errno, "close r" ); |
382 | } |
383 | |
/*
 * Entry point: parse options, prepare the loopback device through three
 * shell commands (mtu, address, accept_local sysctl), run the test and
 * print "OK" on success. Any failing step exits with an error.
 */
int main(int argc, char **argv)
{
	static const struct {
		const char *cmd;
		const char *errmsg;
	} setup[] = {
		{ "ip link set dev lo mtu 1500", "ip link set mtu" },
		{ "ip addr add dev lo 172.17.0.1/24", "ip addr add" },
		{ "sysctl -w net.ipv4.conf.lo.accept_local=1",
		  "sysctl lo.accept_local" },
	};
	size_t step;

	parse_opts(argc, argv);

	for (step = 0; step < sizeof(setup) / sizeof(setup[0]); step++) {
		if (system(setup[step].cmd) != 0)
			error(1, errno, "%s", setup[step].errmsg);
	}

	run_test();

	fprintf(stderr, "OK\n\n");
	return 0;
}
400 | |