1 | // SPDX-License-Identifier: GPL-2.0-only |
2 | /* Copyright (c) 2016 Facebook |
3 | */ |
4 | #include <linux/bpf.h> |
5 | #include <linux/if_link.h> |
6 | #include <assert.h> |
7 | #include <errno.h> |
8 | #include <signal.h> |
9 | #include <stdio.h> |
10 | #include <stdlib.h> |
11 | #include <string.h> |
12 | #include <net/if.h> |
13 | #include <arpa/inet.h> |
14 | #include <netinet/ether.h> |
15 | #include <unistd.h> |
16 | #include <time.h> |
17 | #include <bpf/libbpf.h> |
18 | #include <bpf/bpf.h> |
19 | #include "bpf_util.h" |
20 | #include "xdp_tx_iptunnel_common.h" |
21 | |
/* Seconds slept between two samples in poll_stats(); also the divisor used
 * to turn a per-interval packet delta into a pkts/s rate.
 */
#define STATS_INTERVAL_S 2U

/* Interface the XDP program goes on; stays -1 until -i is parsed, which
 * tells int_exit() there is nothing to detach.
 */
static int ifindex = -1;

/* Flags handed to bpf_xdp_attach()/bpf_xdp_query_id()/bpf_xdp_detach().
 * Starts as UPDATE_IF_NOEXIST (cleared by -F); main() then adds either
 * SKB mode (-S) or DRV mode.
 */
static __u32 xdp_flags = XDP_FLAGS_UPDATE_IF_NOEXIST;

/* Id of the XDP program loaded by this process; int_exit() only detaches
 * when the program found on the interface still has this id.
 */
static __u32 prog_id;

/* fd of the "rxcnt" map that poll_stats() reads. */
static int rxcnt_map_fd;
28 | |
29 | static void int_exit(int sig) |
30 | { |
31 | __u32 curr_prog_id = 0; |
32 | |
33 | if (ifindex > -1) { |
34 | if (bpf_xdp_query_id(ifindex, xdp_flags, &curr_prog_id)) { |
35 | printf(format: "bpf_xdp_query_id failed\n" ); |
36 | exit(status: 1); |
37 | } |
38 | if (prog_id == curr_prog_id) |
39 | bpf_xdp_detach(ifindex, xdp_flags, NULL); |
40 | else if (!curr_prog_id) |
41 | printf(format: "couldn't find a prog id on a given iface\n" ); |
42 | else |
43 | printf(format: "program on interface changed, not removing\n" ); |
44 | } |
45 | exit(status: 0); |
46 | } |
47 | |
48 | /* simple per-protocol drop counter |
49 | */ |
50 | static void poll_stats(unsigned int kill_after_s) |
51 | { |
52 | const unsigned int nr_protos = 256; |
53 | unsigned int nr_cpus = bpf_num_possible_cpus(); |
54 | time_t started_at = time(NULL); |
55 | __u64 values[nr_cpus], prev[nr_protos][nr_cpus]; |
56 | __u32 proto; |
57 | int i; |
58 | |
59 | memset(s: prev, c: 0, n: sizeof(prev)); |
60 | |
61 | while (!kill_after_s || time(NULL) - started_at <= kill_after_s) { |
62 | sleep(STATS_INTERVAL_S); |
63 | |
64 | for (proto = 0; proto < nr_protos; proto++) { |
65 | __u64 sum = 0; |
66 | |
67 | assert(bpf_map_lookup_elem(rxcnt_map_fd, &proto, |
68 | values) == 0); |
69 | for (i = 0; i < nr_cpus; i++) |
70 | sum += (values[i] - prev[proto][i]); |
71 | |
72 | if (sum) |
73 | printf(format: "proto %u: sum:%10llu pkts, rate:%10llu pkts/s\n" , |
74 | proto, sum, sum / STATS_INTERVAL_S); |
75 | memcpy(dest: prev[proto], src: values, n: sizeof(values)); |
76 | } |
77 | } |
78 | } |
79 | |
/* Print the command-line help to stdout: a short description of the
 * sample, the usage line built from @cmd, then one line per option.
 */
static void usage(const char *cmd)
{
	static const char *const option_help[] = {
		" -i <ifname|ifindex> Interface\n",
		" -a <vip-service-address> IPv4 or IPv6\n",
		" -p <vip-service-port> A port range (e.g. 433-444) is also allowed\n",
		" -s <source-ip> Used in the IPTunnel header\n",
		" -d <dest-ip> Used in the IPTunnel header\n",
		" -m <dest-MAC> Used in sending the IP Tunneled pkt\n",
		" -T <stop-after-X-seconds> Default: 0 (forever)\n",
		" -P <IP-Protocol> Default is TCP\n",
		" -S use skb-mode\n",
		" -N enforce native mode\n",
		" -F Force loading the XDP prog\n",
		" -h Display this help\n",
	};
	size_t line;

	fputs("Start a XDP prog which encapsulates incoming packets\n"
	      "in an IPv4/v6 header and XDP_TX it out. The dst <VIP:PORT>\n"
	      "is used to select packets to encapsulate\n\n", stdout);
	printf("Usage: %s [...]\n", cmd);

	for (line = 0; line < sizeof(option_help) / sizeof(option_help[0]); line++)
		fputs(option_help[line], stdout);
}
99 | |
/* Parse @ipstr as an IPv6 or IPv4 literal into the four-word buffer @addr.
 *
 * IPv6 is tried first. For an IPv4 literal only addr[0] receives the
 * address and addr[1..3] are zeroed.
 *
 * Returns AF_INET6 or AF_INET on success; otherwise prints a message to
 * stderr and returns AF_UNSPEC.
 */
static int parse_ipstr(const char *ipstr, unsigned int *addr)
{
	int family = AF_UNSPEC;

	if (inet_pton(AF_INET6, ipstr, addr) == 1) {
		family = AF_INET6;
	} else if (inet_pton(AF_INET, ipstr, addr) == 1) {
		unsigned int word;

		for (word = 1; word < 4; word++)
			addr[word] = 0;
		family = AF_INET;
	} else {
		fprintf(stderr, "%s is an invalid IP\n", ipstr);
	}

	return family;
}
112 | |
113 | static int parse_ports(const char *port_str, int *min_port, int *max_port) |
114 | { |
115 | char *end; |
116 | long tmp_min_port; |
117 | long tmp_max_port; |
118 | |
119 | tmp_min_port = strtol(nptr: optarg, endptr: &end, base: 10); |
120 | if (tmp_min_port < 1 || tmp_min_port > 65535) { |
121 | fprintf(stderr, format: "Invalid port(s):%s\n" , optarg); |
122 | return 1; |
123 | } |
124 | |
125 | if (*end == '-') { |
126 | end++; |
127 | tmp_max_port = strtol(nptr: end, NULL, base: 10); |
128 | if (tmp_max_port < 1 || tmp_max_port > 65535) { |
129 | fprintf(stderr, format: "Invalid port(s):%s\n" , optarg); |
130 | return 1; |
131 | } |
132 | } else { |
133 | tmp_max_port = tmp_min_port; |
134 | } |
135 | |
136 | if (tmp_min_port > tmp_max_port) { |
137 | fprintf(stderr, format: "Invalid port(s):%s\n" , optarg); |
138 | return 1; |
139 | } |
140 | |
141 | if (tmp_max_port - tmp_min_port + 1 > MAX_IPTNL_ENTRIES) { |
142 | fprintf(stderr, format: "Port range (%s) is larger than %u\n" , |
143 | port_str, MAX_IPTNL_ENTRIES); |
144 | return 1; |
145 | } |
146 | *min_port = tmp_min_port; |
147 | *max_port = tmp_max_port; |
148 | |
149 | return 0; |
150 | } |
151 | |
152 | int main(int argc, char **argv) |
153 | { |
154 | int min_port = 0, max_port = 0, vip2tnl_map_fd; |
155 | const char *optstr = "i:a:p:s:d:m:T:P:FSNh" ; |
156 | unsigned char opt_flags[256] = {}; |
157 | struct bpf_prog_info info = {}; |
158 | __u32 info_len = sizeof(info); |
159 | unsigned int kill_after_s = 0; |
160 | struct iptnl_info tnl = {}; |
161 | struct bpf_program *prog; |
162 | struct bpf_object *obj; |
163 | struct vip vip = {}; |
164 | char filename[256]; |
165 | int opt, prog_fd; |
166 | int i, err; |
167 | |
168 | tnl.family = AF_UNSPEC; |
169 | vip.protocol = IPPROTO_TCP; |
170 | |
171 | for (i = 0; i < strlen(s: optstr); i++) |
172 | if (optstr[i] != 'h' && 'a' <= optstr[i] && optstr[i] <= 'z') |
173 | opt_flags[(unsigned char)optstr[i]] = 1; |
174 | |
175 | while ((opt = getopt(argc: argc, argv: argv, shortopts: optstr)) != -1) { |
176 | unsigned short family; |
177 | unsigned int *v6; |
178 | |
179 | switch (opt) { |
180 | case 'i': |
181 | ifindex = if_nametoindex(ifname: optarg); |
182 | if (!ifindex) |
183 | ifindex = atoi(nptr: optarg); |
184 | break; |
185 | case 'a': |
186 | vip.family = parse_ipstr(ipstr: optarg, addr: vip.daddr.v6); |
187 | if (vip.family == AF_UNSPEC) |
188 | return 1; |
189 | break; |
190 | case 'p': |
191 | if (parse_ports(port_str: optarg, min_port: &min_port, max_port: &max_port)) |
192 | return 1; |
193 | break; |
194 | case 'P': |
195 | vip.protocol = atoi(nptr: optarg); |
196 | break; |
197 | case 's': |
198 | case 'd': |
199 | if (opt == 's') |
200 | v6 = tnl.saddr.v6; |
201 | else |
202 | v6 = tnl.daddr.v6; |
203 | |
204 | family = parse_ipstr(ipstr: optarg, addr: v6); |
205 | if (family == AF_UNSPEC) |
206 | return 1; |
207 | if (tnl.family == AF_UNSPEC) { |
208 | tnl.family = family; |
209 | } else if (tnl.family != family) { |
210 | fprintf(stderr, |
211 | format: "The IP version of the src and dst addresses used in the IP encapsulation does not match\n" ); |
212 | return 1; |
213 | } |
214 | break; |
215 | case 'm': |
216 | if (!ether_aton_r(asc: optarg, |
217 | addr: (struct ether_addr *)tnl.dmac)) { |
218 | fprintf(stderr, format: "Invalid mac address:%s\n" , |
219 | optarg); |
220 | return 1; |
221 | } |
222 | break; |
223 | case 'T': |
224 | kill_after_s = atoi(nptr: optarg); |
225 | break; |
226 | case 'S': |
227 | xdp_flags |= XDP_FLAGS_SKB_MODE; |
228 | break; |
229 | case 'N': |
230 | /* default, set below */ |
231 | break; |
232 | case 'F': |
233 | xdp_flags &= ~XDP_FLAGS_UPDATE_IF_NOEXIST; |
234 | break; |
235 | default: |
236 | usage(cmd: argv[0]); |
237 | return 1; |
238 | } |
239 | opt_flags[opt] = 0; |
240 | } |
241 | |
242 | if (!(xdp_flags & XDP_FLAGS_SKB_MODE)) |
243 | xdp_flags |= XDP_FLAGS_DRV_MODE; |
244 | |
245 | for (i = 0; i < strlen(s: optstr); i++) { |
246 | if (opt_flags[(unsigned int)optstr[i]]) { |
247 | fprintf(stderr, format: "Missing argument -%c\n" , optstr[i]); |
248 | usage(cmd: argv[0]); |
249 | return 1; |
250 | } |
251 | } |
252 | |
253 | if (!ifindex) { |
254 | fprintf(stderr, format: "Invalid ifname\n" ); |
255 | return 1; |
256 | } |
257 | |
258 | snprintf(s: filename, maxlen: sizeof(filename), format: "%s_kern.o" , argv[0]); |
259 | |
260 | obj = bpf_object__open_file(filename, NULL); |
261 | if (libbpf_get_error(obj)) |
262 | return 1; |
263 | |
264 | prog = bpf_object__next_program(obj, NULL); |
265 | bpf_program__set_type(prog, BPF_PROG_TYPE_XDP); |
266 | |
267 | err = bpf_object__load(obj); |
268 | if (err) { |
269 | printf(format: "bpf_object__load(): %s\n" , strerror(errno)); |
270 | return 1; |
271 | } |
272 | prog_fd = bpf_program__fd(prog); |
273 | |
274 | rxcnt_map_fd = bpf_object__find_map_fd_by_name(obj, "rxcnt" ); |
275 | vip2tnl_map_fd = bpf_object__find_map_fd_by_name(obj, "vip2tnl" ); |
276 | if (vip2tnl_map_fd < 0 || rxcnt_map_fd < 0) { |
277 | printf(format: "bpf_object__find_map_fd_by_name failed\n" ); |
278 | return 1; |
279 | } |
280 | |
281 | signal(SIGINT, handler: int_exit); |
282 | signal(SIGTERM, handler: int_exit); |
283 | |
284 | while (min_port <= max_port) { |
285 | vip.dport = htons(min_port++); |
286 | if (bpf_map_update_elem(vip2tnl_map_fd, &vip, &tnl, |
287 | BPF_NOEXIST)) { |
288 | perror(s: "bpf_map_update_elem(&vip2tnl)" ); |
289 | return 1; |
290 | } |
291 | } |
292 | |
293 | if (bpf_xdp_attach(ifindex, prog_fd, xdp_flags, NULL) < 0) { |
294 | printf(format: "link set xdp fd failed\n" ); |
295 | return 1; |
296 | } |
297 | |
298 | err = bpf_prog_get_info_by_fd(prog_fd, &info, &info_len); |
299 | if (err) { |
300 | printf(format: "can't get prog info - %s\n" , strerror(errno)); |
301 | return err; |
302 | } |
303 | prog_id = info.id; |
304 | |
305 | poll_stats(kill_after_s); |
306 | |
307 | bpf_xdp_detach(ifindex, xdp_flags, NULL); |
308 | |
309 | return 0; |
310 | } |
311 | |