1 | // SPDX-License-Identifier: GPL-2.0 |
2 | |
/*
 * Test key rotation for TCP Fast Open (TFO).
 * New keys are rotated in two steps:
 * 1) Add the new key as the backup key, behind the primary key.
 * 2) Make the new key the primary by swapping the backup and primary keys.
 *
 * The rotation is done in stages across multiple sockets bound to the
 * same port via SO_REUSEPORT. This simulates key rotation behind, say, a
 * load balancer. The goal is that no cookie is ever rejected during the
 * rotation, which is verified by checking that TcpExtTCPFastOpenPassiveFail
 * remains 0. (This program only drives the connections; it does not read
 * the counter itself.)
 */
15 | #define _GNU_SOURCE |
16 | #include <arpa/inet.h> |
17 | #include <errno.h> |
18 | #include <error.h> |
19 | #include <stdbool.h> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | #include <string.h> |
23 | #include <sys/epoll.h> |
24 | #include <unistd.h> |
25 | #include <netinet/tcp.h> |
26 | #include <fcntl.h> |
27 | #include <time.h> |
28 | |
29 | #include "../kselftest.h" |
30 | |
#ifndef TCP_FASTOPEN_KEY
/* Fallback for C library headers that predate this socket option. */
#define TCP_FASTOPEN_KEY 33
#endif

/* Number of SO_REUSEPORT listeners sharing PORT. */
#define N_LISTEN 10
/* sysctl used to read/write the keys when -s is not given. */
#define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key"
/* Size in bytes of one TFO key (four 32-bit words). */
#define KEY_LENGTH 16

/* Listener sockets and the open handle on PROC_FASTOPEN_KEY. */
static int rcv_fds[N_LISTEN];
static int proc_fd;

/* Command-line switches, filled in by parse_opts(). */
static bool do_ipv6;
static bool do_sockopt;
static bool do_rotate;
/* Bytes handed to setsockopt(TCP_FASTOPEN_KEY): one key, two with -r. */
static int key_len = KEY_LENGTH;

/* Loopback destinations and the port every listener binds to. */
static const char *IP4_ADDR = "127.0.0.1";
static const char *IP6_ADDR = "::1";
static const int PORT = 8891;
48 | |
49 | static void get_keys(int fd, uint32_t *keys) |
50 | { |
51 | char buf[128]; |
52 | socklen_t len = KEY_LENGTH * 2; |
53 | |
54 | if (do_sockopt) { |
55 | if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len)) |
56 | error(1, errno, "Unable to get key" ); |
57 | return; |
58 | } |
59 | lseek(proc_fd, 0, SEEK_SET); |
60 | if (read(proc_fd, buf, sizeof(buf)) <= 0) |
61 | error(1, errno, "Unable to read %s" , PROC_FASTOPEN_KEY); |
62 | if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x" , keys, keys + 1, keys + 2, |
63 | keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8) |
64 | error(1, 0, "Unable to parse %s" , PROC_FASTOPEN_KEY); |
65 | } |
66 | |
67 | static void set_keys(int fd, uint32_t *keys) |
68 | { |
69 | char buf[128]; |
70 | |
71 | if (do_sockopt) { |
72 | if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, |
73 | key_len)) |
74 | error(1, errno, "Unable to set key" ); |
75 | return; |
76 | } |
77 | if (do_rotate) |
78 | snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x" , |
79 | keys[0], keys[1], keys[2], keys[3], keys[4], keys[5], |
80 | keys[6], keys[7]); |
81 | else |
82 | snprintf(buf, 128, "%08x-%08x-%08x-%08x" , |
83 | keys[0], keys[1], keys[2], keys[3]); |
84 | lseek(proc_fd, 0, SEEK_SET); |
85 | if (write(proc_fd, buf, sizeof(buf)) <= 0) |
86 | error(1, errno, "Unable to write %s" , PROC_FASTOPEN_KEY); |
87 | } |
88 | |
89 | static void build_rcv_fd(int family, int proto, int *rcv_fds) |
90 | { |
91 | struct sockaddr_in addr4 = {0}; |
92 | struct sockaddr_in6 addr6 = {0}; |
93 | struct sockaddr *addr; |
94 | int opt = 1, i, sz; |
95 | int qlen = 100; |
96 | uint32_t keys[8]; |
97 | |
98 | switch (family) { |
99 | case AF_INET: |
100 | addr4.sin_family = family; |
101 | addr4.sin_addr.s_addr = htonl(INADDR_ANY); |
102 | addr4.sin_port = htons(PORT); |
103 | sz = sizeof(addr4); |
104 | addr = (struct sockaddr *)&addr4; |
105 | break; |
106 | case AF_INET6: |
107 | addr6.sin6_family = AF_INET6; |
108 | addr6.sin6_addr = in6addr_any; |
109 | addr6.sin6_port = htons(PORT); |
110 | sz = sizeof(addr6); |
111 | addr = (struct sockaddr *)&addr6; |
112 | break; |
113 | default: |
114 | error(1, 0, "Unsupported family %d" , family); |
115 | /* clang does not recognize error() above as terminating |
116 | * the program, so it complains that saddr, sz are |
117 | * not initialized when this code path is taken. Silence it. |
118 | */ |
119 | return; |
120 | } |
121 | for (i = 0; i < ARRAY_SIZE(keys); i++) |
122 | keys[i] = rand(); |
123 | for (i = 0; i < N_LISTEN; i++) { |
124 | rcv_fds[i] = socket(family, proto, 0); |
125 | if (rcv_fds[i] < 0) |
126 | error(1, errno, "failed to create receive socket" ); |
127 | if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt, |
128 | sizeof(opt))) |
129 | error(1, errno, "failed to set SO_REUSEPORT" ); |
130 | if (bind(rcv_fds[i], addr, sz)) |
131 | error(1, errno, "failed to bind receive socket" ); |
132 | if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen, |
133 | sizeof(qlen))) |
134 | error(1, errno, "failed to set TCP_FASTOPEN" ); |
135 | set_keys(rcv_fds[i], keys); |
136 | if (proto == SOCK_STREAM && listen(rcv_fds[i], 10)) |
137 | error(1, errno, "failed to listen on receive port" ); |
138 | } |
139 | } |
140 | |
141 | static int connect_and_send(int family, int proto) |
142 | { |
143 | struct sockaddr_in saddr4 = {0}; |
144 | struct sockaddr_in daddr4 = {0}; |
145 | struct sockaddr_in6 saddr6 = {0}; |
146 | struct sockaddr_in6 daddr6 = {0}; |
147 | struct sockaddr *saddr, *daddr; |
148 | int fd, sz, ret; |
149 | char data[1]; |
150 | |
151 | switch (family) { |
152 | case AF_INET: |
153 | saddr4.sin_family = AF_INET; |
154 | saddr4.sin_addr.s_addr = htonl(INADDR_ANY); |
155 | saddr4.sin_port = 0; |
156 | |
157 | daddr4.sin_family = AF_INET; |
158 | if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr)) |
159 | error(1, errno, "inet_pton failed: %s" , IP4_ADDR); |
160 | daddr4.sin_port = htons(PORT); |
161 | |
162 | sz = sizeof(saddr4); |
163 | saddr = (struct sockaddr *)&saddr4; |
164 | daddr = (struct sockaddr *)&daddr4; |
165 | break; |
166 | case AF_INET6: |
167 | saddr6.sin6_family = AF_INET6; |
168 | saddr6.sin6_addr = in6addr_any; |
169 | |
170 | daddr6.sin6_family = AF_INET6; |
171 | if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr)) |
172 | error(1, errno, "inet_pton failed: %s" , IP6_ADDR); |
173 | daddr6.sin6_port = htons(PORT); |
174 | |
175 | sz = sizeof(saddr6); |
176 | saddr = (struct sockaddr *)&saddr6; |
177 | daddr = (struct sockaddr *)&daddr6; |
178 | break; |
179 | default: |
180 | error(1, 0, "Unsupported family %d" , family); |
181 | /* clang does not recognize error() above as terminating |
182 | * the program, so it complains that saddr, daddr, sz are |
183 | * not initialized when this code path is taken. Silence it. |
184 | */ |
185 | return -1; |
186 | } |
187 | fd = socket(family, proto, 0); |
188 | if (fd < 0) |
189 | error(1, errno, "failed to create send socket" ); |
190 | if (bind(fd, saddr, sz)) |
191 | error(1, errno, "failed to bind send socket" ); |
192 | data[0] = 'a'; |
193 | ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz); |
194 | if (ret != 1) |
195 | error(1, errno, "failed to sendto" ); |
196 | |
197 | return fd; |
198 | } |
199 | |
200 | static bool is_listen_fd(int fd) |
201 | { |
202 | int i; |
203 | |
204 | for (i = 0; i < N_LISTEN; i++) { |
205 | if (rcv_fds[i] == fd) |
206 | return true; |
207 | } |
208 | return false; |
209 | } |
210 | |
211 | static void rotate_key(int fd) |
212 | { |
213 | static int iter; |
214 | static uint32_t new_key[4]; |
215 | uint32_t keys[8]; |
216 | uint32_t tmp_key[4]; |
217 | int i; |
218 | |
219 | if (iter < N_LISTEN) { |
220 | /* first set new key as backups */ |
221 | if (iter == 0) { |
222 | for (i = 0; i < ARRAY_SIZE(new_key); i++) |
223 | new_key[i] = rand(); |
224 | } |
225 | get_keys(fd, keys); |
226 | memcpy(keys + 4, new_key, KEY_LENGTH); |
227 | set_keys(fd, keys); |
228 | } else { |
229 | /* swap the keys */ |
230 | get_keys(fd, keys); |
231 | memcpy(tmp_key, keys + 4, KEY_LENGTH); |
232 | memcpy(keys + 4, keys, KEY_LENGTH); |
233 | memcpy(keys, tmp_key, KEY_LENGTH); |
234 | set_keys(fd, keys); |
235 | } |
236 | if (++iter >= (N_LISTEN * 2)) |
237 | iter = 0; |
238 | } |
239 | |
240 | static void run_one_test(int family) |
241 | { |
242 | struct epoll_event ev; |
243 | int i, send_fd; |
244 | int n_loops = 10000; |
245 | int rotate_key_fd = 0; |
246 | int key_rotate_interval = 50; |
247 | int fd, epfd; |
248 | char buf[1]; |
249 | |
250 | build_rcv_fd(family, SOCK_STREAM, rcv_fds); |
251 | epfd = epoll_create(1); |
252 | if (epfd < 0) |
253 | error(1, errno, "failed to create epoll" ); |
254 | ev.events = EPOLLIN; |
255 | for (i = 0; i < N_LISTEN; i++) { |
256 | ev.data.fd = rcv_fds[i]; |
257 | if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev)) |
258 | error(1, errno, "failed to register sock epoll" ); |
259 | } |
260 | while (n_loops--) { |
261 | send_fd = connect_and_send(family, SOCK_STREAM); |
262 | if (do_rotate && ((n_loops % key_rotate_interval) == 0)) { |
263 | rotate_key(fd: rcv_fds[rotate_key_fd]); |
264 | if (++rotate_key_fd >= N_LISTEN) |
265 | rotate_key_fd = 0; |
266 | } |
267 | while (1) { |
268 | i = epoll_wait(epfd, &ev, 1, -1); |
269 | if (i < 0) |
270 | error(1, errno, "epoll_wait failed" ); |
271 | if (is_listen_fd(ev.data.fd)) { |
272 | fd = accept(ev.data.fd, NULL, NULL); |
273 | if (fd < 0) |
274 | error(1, errno, "failed to accept" ); |
275 | ev.data.fd = fd; |
276 | if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev)) |
277 | error(1, errno, "failed epoll add" ); |
278 | continue; |
279 | } |
280 | i = recv(ev.data.fd, buf, sizeof(buf), 0); |
281 | if (i != 1) |
282 | error(1, errno, "failed recv data" ); |
283 | if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL)) |
284 | error(1, errno, "failed epoll del" ); |
285 | close(ev.data.fd); |
286 | break; |
287 | } |
288 | close(send_fd); |
289 | } |
290 | for (i = 0; i < N_LISTEN; i++) |
291 | close(rcv_fds[i]); |
292 | } |
293 | |
294 | static void parse_opts(int argc, char **argv) |
295 | { |
296 | int c; |
297 | |
298 | while ((c = getopt(argc, argv, "46sr" )) != -1) { |
299 | switch (c) { |
300 | case '4': |
301 | do_ipv6 = false; |
302 | break; |
303 | case '6': |
304 | do_ipv6 = true; |
305 | break; |
306 | case 's': |
307 | do_sockopt = true; |
308 | break; |
309 | case 'r': |
310 | do_rotate = true; |
311 | key_len = KEY_LENGTH * 2; |
312 | break; |
313 | default: |
314 | error(1, 0, "%s: parse error" , argv[0]); |
315 | } |
316 | } |
317 | } |
318 | |
319 | int main(int argc, char **argv) |
320 | { |
321 | parse_opts(argc, argv); |
322 | proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR); |
323 | if (proc_fd < 0) |
324 | error(1, errno, "Unable to open %s" , PROC_FASTOPEN_KEY); |
325 | srand(time(NULL)); |
326 | if (do_ipv6) |
327 | run_one_test(AF_INET6); |
328 | else |
329 | run_one_test(AF_INET); |
330 | close(proc_fd); |
331 | fprintf(stderr, "PASS\n" ); |
332 | return 0; |
333 | } |
334 | |