1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* Copyright (c) 2022 Meta Platforms, Inc. and affiliates. */ |
3 | |
4 | #include <linux/bpf.h> |
5 | #include <bpf/bpf_helpers.h> |
6 | #include "bpf_misc.h" |
7 | #include "test_user_ringbuf.h" |
8 | |
/* License declaration, emitted into the "license" section of the object. */
char _license[] SEC("license") = "GPL";

/*
 * User ring buffer: written by userspace and drained by the programs below
 * through bpf_user_ringbuf_drain(). No max_entries is declared here; the
 * size is presumably set by the loader before load (TODO confirm).
 */
struct {
	__uint(type, BPF_MAP_TYPE_USER_RINGBUF);
} user_ringbuf SEC(".maps");

/*
 * Regular kernel-to-user ring buffer. publish_next_kern_msg() reserves and
 * submits protocol messages here. As with user_ringbuf, no max_entries is
 * declared here (presumably set by the loader -- confirm).
 */
struct {
	__uint(type, BPF_MAP_TYPE_RINGBUF);
} kernel_ringbuf SEC(".maps");
18 | |
/*
 * Shared globals.
 *   pid: compared with the current tgid by is_test_process(). Nothing in
 *        this file writes it; presumably userspace sets it before
 *        triggering the programs (confirm).
 *   err: status/error value written by the programs and callbacks below.
 *        Despite sitting next to pid, it is an output.
 *   val: not referenced by the programs below.
 */
int pid, err, val;

/* Number of samples consumed by the drain callbacks (record_sample, do_nothing_cb). */
int read = 0;

/*
 * Counters used for the end-to-end protocol test.
 *   kern_mutated:          updated by handle_sample_msg() from messages
 *                          drained out of user_ringbuf.
 *   user_mutated:          compared against expected_user_mutated in
 *                          publish_kern_messages(). Not written in this
 *                          file; presumably reported back by userspace.
 *   expected_user_mutated: accumulates, in publish_next_kern_msg(), the
 *                          effect of every message published to
 *                          kernel_ringbuf.
 */
__u64 kern_mutated = 0;
__u64 user_mutated = 0;
__u64 expected_user_mutated = 0;
28 | |
29 | static int |
30 | is_test_process(void) |
31 | { |
32 | int cur_pid = bpf_get_current_pid_tgid() >> 32; |
33 | |
34 | return cur_pid == pid; |
35 | } |
36 | |
37 | static long |
38 | record_sample(struct bpf_dynptr *dynptr, void *context) |
39 | { |
40 | const struct sample *sample = NULL; |
41 | struct sample stack_sample; |
42 | int status; |
43 | static int num_calls; |
44 | |
45 | if (num_calls++ % 2 == 0) { |
46 | status = bpf_dynptr_read(&stack_sample, sizeof(stack_sample), dynptr, 0, 0); |
47 | if (status) { |
48 | bpf_printk("bpf_dynptr_read() failed: %d\n" , status); |
49 | err = 1; |
50 | return 1; |
51 | } |
52 | } else { |
53 | sample = bpf_dynptr_data(dynptr, 0, sizeof(*sample)); |
54 | if (!sample) { |
55 | bpf_printk("Unexpectedly failed to get sample\n" ); |
56 | err = 2; |
57 | return 1; |
58 | } |
59 | stack_sample = *sample; |
60 | } |
61 | |
62 | __sync_fetch_and_add(&read, 1); |
63 | return 0; |
64 | } |
65 | |
66 | static void |
67 | handle_sample_msg(const struct test_msg *msg) |
68 | { |
69 | switch (msg->msg_op) { |
70 | case TEST_MSG_OP_INC64: |
71 | kern_mutated += msg->operand_64; |
72 | break; |
73 | case TEST_MSG_OP_INC32: |
74 | kern_mutated += msg->operand_32; |
75 | break; |
76 | case TEST_MSG_OP_MUL64: |
77 | kern_mutated *= msg->operand_64; |
78 | break; |
79 | case TEST_MSG_OP_MUL32: |
80 | kern_mutated *= msg->operand_32; |
81 | break; |
82 | default: |
83 | bpf_printk("Unrecognized op %d\n" , msg->msg_op); |
84 | err = 2; |
85 | } |
86 | } |
87 | |
88 | static long |
89 | read_protocol_msg(struct bpf_dynptr *dynptr, void *context) |
90 | { |
91 | const struct test_msg *msg = NULL; |
92 | |
93 | msg = bpf_dynptr_data(dynptr, 0, sizeof(*msg)); |
94 | if (!msg) { |
95 | err = 1; |
96 | bpf_printk("Unexpectedly failed to get msg\n" ); |
97 | return 0; |
98 | } |
99 | |
100 | handle_sample_msg(msg); |
101 | |
102 | return 0; |
103 | } |
104 | |
105 | static int publish_next_kern_msg(__u32 index, void *context) |
106 | { |
107 | struct test_msg *msg = NULL; |
108 | int operand_64 = TEST_OP_64; |
109 | int operand_32 = TEST_OP_32; |
110 | |
111 | msg = bpf_ringbuf_reserve(&kernel_ringbuf, sizeof(*msg), 0); |
112 | if (!msg) { |
113 | err = 4; |
114 | return 1; |
115 | } |
116 | |
117 | switch (index % TEST_MSG_OP_NUM_OPS) { |
118 | case TEST_MSG_OP_INC64: |
119 | msg->operand_64 = operand_64; |
120 | msg->msg_op = TEST_MSG_OP_INC64; |
121 | expected_user_mutated += operand_64; |
122 | break; |
123 | case TEST_MSG_OP_INC32: |
124 | msg->operand_32 = operand_32; |
125 | msg->msg_op = TEST_MSG_OP_INC32; |
126 | expected_user_mutated += operand_32; |
127 | break; |
128 | case TEST_MSG_OP_MUL64: |
129 | msg->operand_64 = operand_64; |
130 | msg->msg_op = TEST_MSG_OP_MUL64; |
131 | expected_user_mutated *= operand_64; |
132 | break; |
133 | case TEST_MSG_OP_MUL32: |
134 | msg->operand_32 = operand_32; |
135 | msg->msg_op = TEST_MSG_OP_MUL32; |
136 | expected_user_mutated *= operand_32; |
137 | break; |
138 | default: |
139 | bpf_ringbuf_discard(msg, 0); |
140 | err = 5; |
141 | return 1; |
142 | } |
143 | |
144 | bpf_ringbuf_submit(msg, 0); |
145 | |
146 | return 0; |
147 | } |
148 | |
149 | static void |
150 | publish_kern_messages(void) |
151 | { |
152 | if (expected_user_mutated != user_mutated) { |
153 | bpf_printk("%lu != %lu\n" , expected_user_mutated, user_mutated); |
154 | err = 3; |
155 | return; |
156 | } |
157 | |
158 | bpf_loop(8, publish_next_kern_msg, NULL, 0); |
159 | } |
160 | |
161 | SEC("fentry/" SYS_PREFIX "sys_prctl" ) |
162 | int test_user_ringbuf_protocol(void *ctx) |
163 | { |
164 | long status = 0; |
165 | |
166 | if (!is_test_process()) |
167 | return 0; |
168 | |
169 | status = bpf_user_ringbuf_drain(&user_ringbuf, read_protocol_msg, NULL, 0); |
170 | if (status < 0) { |
171 | bpf_printk("Drain returned: %ld\n" , status); |
172 | err = 1; |
173 | return 0; |
174 | } |
175 | |
176 | publish_kern_messages(); |
177 | |
178 | return 0; |
179 | } |
180 | |
181 | SEC("fentry/" SYS_PREFIX "sys_getpgid" ) |
182 | int test_user_ringbuf(void *ctx) |
183 | { |
184 | if (!is_test_process()) |
185 | return 0; |
186 | |
187 | err = bpf_user_ringbuf_drain(&user_ringbuf, record_sample, NULL, 0); |
188 | |
189 | return 0; |
190 | } |
191 | |
192 | static long |
193 | do_nothing_cb(struct bpf_dynptr *dynptr, void *context) |
194 | { |
195 | __sync_fetch_and_add(&read, 1); |
196 | return 0; |
197 | } |
198 | |
199 | SEC("fentry/" SYS_PREFIX "sys_prlimit64" ) |
200 | int test_user_ringbuf_epoll(void *ctx) |
201 | { |
202 | long num_samples; |
203 | |
204 | if (!is_test_process()) |
205 | return 0; |
206 | |
207 | num_samples = bpf_user_ringbuf_drain(&user_ringbuf, do_nothing_cb, NULL, 0); |
208 | if (num_samples <= 0) |
209 | err = 1; |
210 | |
211 | return 0; |
212 | } |
213 | |