1 | // SPDX-License-Identifier: GPL-2.0-or-later |
2 | |
3 | #include <crypto/hash.h> |
4 | #include <linux/cpu.h> |
5 | #include <linux/kref.h> |
6 | #include <linux/module.h> |
7 | #include <linux/mutex.h> |
8 | #include <linux/percpu.h> |
9 | #include <linux/workqueue.h> |
10 | #include <net/tcp.h> |
11 | |
12 | static size_t __scratch_size; |
13 | static DEFINE_PER_CPU(void __rcu *, sigpool_scratch); |
14 | |
15 | struct sigpool_entry { |
16 | struct crypto_ahash *hash; |
17 | const char *alg; |
18 | struct kref kref; |
19 | uint16_t needs_key:1, |
20 | reserved:15; |
21 | }; |
22 | |
23 | #define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry)) |
24 | static struct sigpool_entry cpool[CPOOL_SIZE]; |
25 | static unsigned int cpool_populated; |
26 | static DEFINE_MUTEX(cpool_mutex); |
27 | |
28 | /* Slow-path */ |
29 | struct scratches_to_free { |
30 | struct rcu_head rcu; |
31 | unsigned int cnt; |
32 | void *scratches[]; |
33 | }; |
34 | |
35 | static void free_old_scratches(struct rcu_head *head) |
36 | { |
37 | struct scratches_to_free *stf; |
38 | |
39 | stf = container_of(head, struct scratches_to_free, rcu); |
40 | while (stf->cnt--) |
41 | kfree(objp: stf->scratches[stf->cnt]); |
42 | kfree(objp: stf); |
43 | } |
44 | |
45 | /** |
46 | * sigpool_reserve_scratch - re-allocates scratch buffer, slow-path |
47 | * @size: request size for the scratch/temp buffer |
48 | */ |
49 | static int sigpool_reserve_scratch(size_t size) |
50 | { |
51 | struct scratches_to_free *stf; |
52 | size_t stf_sz = struct_size(stf, scratches, num_possible_cpus()); |
53 | int cpu, err = 0; |
54 | |
55 | lockdep_assert_held(&cpool_mutex); |
56 | if (__scratch_size >= size) |
57 | return 0; |
58 | |
59 | stf = kmalloc(size: stf_sz, GFP_KERNEL); |
60 | if (!stf) |
61 | return -ENOMEM; |
62 | stf->cnt = 0; |
63 | |
64 | size = max(size, __scratch_size); |
65 | cpus_read_lock(); |
66 | for_each_possible_cpu(cpu) { |
67 | void *scratch, *old_scratch; |
68 | |
69 | scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu)); |
70 | if (!scratch) { |
71 | err = -ENOMEM; |
72 | break; |
73 | } |
74 | |
75 | old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch, cpu), |
76 | scratch, lockdep_is_held(&cpool_mutex)); |
77 | if (!cpu_online(cpu) || !old_scratch) { |
78 | kfree(objp: old_scratch); |
79 | continue; |
80 | } |
81 | stf->scratches[stf->cnt++] = old_scratch; |
82 | } |
83 | cpus_read_unlock(); |
84 | if (!err) |
85 | __scratch_size = size; |
86 | |
87 | call_rcu(head: &stf->rcu, func: free_old_scratches); |
88 | return err; |
89 | } |
90 | |
91 | static void sigpool_scratch_free(void) |
92 | { |
93 | int cpu; |
94 | |
95 | for_each_possible_cpu(cpu) |
96 | kfree(rcu_replace_pointer(per_cpu(sigpool_scratch, cpu), |
97 | NULL, lockdep_is_held(&cpool_mutex))); |
98 | __scratch_size = 0; |
99 | } |
100 | |
/*
 * Probes whether @hash supports crypto_clone_ahash(), which
 * tcp_sigpool_start() depends on, by making a throw-away clone.
 *
 * Return: 0 if cloning works, otherwise the crypto_clone_ahash() error.
 */
static int __cpool_try_clone(struct crypto_ahash *hash)
{
	struct crypto_ahash *clone = crypto_clone_ahash(hash);

	if (IS_ERR(clone))
		return PTR_ERR(clone);

	crypto_free_ahash(clone);
	return 0;
}
112 | |
113 | static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg) |
114 | { |
115 | struct crypto_ahash *cpu0_hash; |
116 | int ret; |
117 | |
118 | e->alg = kstrdup(s: alg, GFP_KERNEL); |
119 | if (!e->alg) |
120 | return -ENOMEM; |
121 | |
122 | cpu0_hash = crypto_alloc_ahash(alg_name: alg, type: 0, CRYPTO_ALG_ASYNC); |
123 | if (IS_ERR(ptr: cpu0_hash)) { |
124 | ret = PTR_ERR(ptr: cpu0_hash); |
125 | goto out_free_alg; |
126 | } |
127 | |
128 | e->needs_key = crypto_ahash_get_flags(tfm: cpu0_hash) & CRYPTO_TFM_NEED_KEY; |
129 | |
130 | ret = __cpool_try_clone(hash: cpu0_hash); |
131 | if (ret) |
132 | goto out_free_cpu0_hash; |
133 | e->hash = cpu0_hash; |
134 | kref_init(kref: &e->kref); |
135 | return 0; |
136 | |
137 | out_free_cpu0_hash: |
138 | crypto_free_ahash(tfm: cpu0_hash); |
139 | out_free_alg: |
140 | kfree(objp: e->alg); |
141 | e->alg = NULL; |
142 | return ret; |
143 | } |
144 | |
145 | /** |
146 | * tcp_sigpool_alloc_ahash - allocates pool for ahash requests |
147 | * @alg: name of async hash algorithm |
148 | * @scratch_size: reserve a tcp_sigpool::scratch buffer of this size |
149 | */ |
150 | int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size) |
151 | { |
152 | int i, ret; |
153 | |
154 | /* slow-path */ |
155 | mutex_lock(&cpool_mutex); |
156 | ret = sigpool_reserve_scratch(size: scratch_size); |
157 | if (ret) |
158 | goto out; |
159 | for (i = 0; i < cpool_populated; i++) { |
160 | if (!cpool[i].alg) |
161 | continue; |
162 | if (strcmp(cpool[i].alg, alg)) |
163 | continue; |
164 | |
165 | /* pairs with tcp_sigpool_release() */ |
166 | if (!kref_get_unless_zero(kref: &cpool[i].kref)) |
167 | kref_init(kref: &cpool[i].kref); |
168 | ret = i; |
169 | goto out; |
170 | } |
171 | |
172 | for (i = 0; i < cpool_populated; i++) { |
173 | if (!cpool[i].alg) |
174 | break; |
175 | } |
176 | if (i >= CPOOL_SIZE) { |
177 | ret = -ENOSPC; |
178 | goto out; |
179 | } |
180 | |
181 | ret = __cpool_alloc_ahash(e: &cpool[i], alg); |
182 | if (!ret) { |
183 | ret = i; |
184 | if (i == cpool_populated) |
185 | cpool_populated++; |
186 | } |
187 | out: |
188 | mutex_unlock(lock: &cpool_mutex); |
189 | return ret; |
190 | } |
191 | EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash); |
192 | |
193 | static void __cpool_free_entry(struct sigpool_entry *e) |
194 | { |
195 | crypto_free_ahash(tfm: e->hash); |
196 | kfree(objp: e->alg); |
197 | memset(e, 0, sizeof(*e)); |
198 | } |
199 | |
200 | static void cpool_cleanup_work_cb(struct work_struct *work) |
201 | { |
202 | bool free_scratch = true; |
203 | unsigned int i; |
204 | |
205 | mutex_lock(&cpool_mutex); |
206 | for (i = 0; i < cpool_populated; i++) { |
207 | if (kref_read(kref: &cpool[i].kref) > 0) { |
208 | free_scratch = false; |
209 | continue; |
210 | } |
211 | if (!cpool[i].alg) |
212 | continue; |
213 | __cpool_free_entry(e: &cpool[i]); |
214 | } |
215 | if (free_scratch) |
216 | sigpool_scratch_free(); |
217 | mutex_unlock(lock: &cpool_mutex); |
218 | } |
219 | |
220 | static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb); |
221 | static void cpool_schedule_cleanup(struct kref *kref) |
222 | { |
223 | schedule_work(work: &cpool_cleanup_work); |
224 | } |
225 | |
226 | /** |
227 | * tcp_sigpool_release - decreases number of users for a pool. If it was |
228 | * the last user of the pool, releases any memory that was consumed. |
229 | * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() |
230 | */ |
231 | void tcp_sigpool_release(unsigned int id) |
232 | { |
233 | if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) |
234 | return; |
235 | |
236 | /* slow-path */ |
237 | kref_put(kref: &cpool[id].kref, release: cpool_schedule_cleanup); |
238 | } |
239 | EXPORT_SYMBOL_GPL(tcp_sigpool_release); |
240 | |
241 | /** |
242 | * tcp_sigpool_get - increases number of users (refcounter) for a pool |
243 | * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() |
244 | */ |
245 | void tcp_sigpool_get(unsigned int id) |
246 | { |
247 | if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) |
248 | return; |
249 | kref_get(kref: &cpool[id].kref); |
250 | } |
251 | EXPORT_SYMBOL_GPL(tcp_sigpool_get); |
252 | |
253 | int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH) |
254 | { |
255 | struct crypto_ahash *hash; |
256 | |
257 | rcu_read_lock_bh(); |
258 | if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) { |
259 | rcu_read_unlock_bh(); |
260 | return -EINVAL; |
261 | } |
262 | |
263 | hash = crypto_clone_ahash(tfm: cpool[id].hash); |
264 | if (IS_ERR(ptr: hash)) { |
265 | rcu_read_unlock_bh(); |
266 | return PTR_ERR(ptr: hash); |
267 | } |
268 | |
269 | c->req = ahash_request_alloc(tfm: hash, GFP_ATOMIC); |
270 | if (!c->req) { |
271 | crypto_free_ahash(tfm: hash); |
272 | rcu_read_unlock_bh(); |
273 | return -ENOMEM; |
274 | } |
275 | ahash_request_set_callback(req: c->req, flags: 0, NULL, NULL); |
276 | |
277 | /* Pairs with tcp_sigpool_reserve_scratch(), scratch area is |
278 | * valid (allocated) until tcp_sigpool_end(). |
279 | */ |
280 | c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch)); |
281 | return 0; |
282 | } |
283 | EXPORT_SYMBOL_GPL(tcp_sigpool_start); |
284 | |
285 | void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH) |
286 | { |
287 | struct crypto_ahash *hash = crypto_ahash_reqtfm(req: c->req); |
288 | |
289 | rcu_read_unlock_bh(); |
290 | ahash_request_free(req: c->req); |
291 | crypto_free_ahash(tfm: hash); |
292 | } |
293 | EXPORT_SYMBOL_GPL(tcp_sigpool_end); |
294 | |
295 | /** |
296 | * tcp_sigpool_algo - return algorithm of tcp_sigpool |
297 | * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() |
298 | * @buf: buffer to return name of algorithm |
299 | * @buf_len: size of @buf |
300 | */ |
301 | size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len) |
302 | { |
303 | if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) |
304 | return -EINVAL; |
305 | |
306 | return strscpy(buf, cpool[id].alg, buf_len); |
307 | } |
308 | EXPORT_SYMBOL_GPL(tcp_sigpool_algo); |
309 | |
310 | /** |
311 | * tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool |
312 | * @hp: tcp_sigpool pointer |
313 | * @skb: buffer to add sign for |
314 | * @header_len: TCP header length for this segment |
315 | */ |
316 | int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp, |
317 | const struct sk_buff *skb, |
318 | unsigned int ) |
319 | { |
320 | const unsigned int head_data_len = skb_headlen(skb) > header_len ? |
321 | skb_headlen(skb) - header_len : 0; |
322 | const struct skb_shared_info *shi = skb_shinfo(skb); |
323 | const struct tcphdr *tp = tcp_hdr(skb); |
324 | struct ahash_request *req = hp->req; |
325 | struct sk_buff *frag_iter; |
326 | struct scatterlist sg; |
327 | unsigned int i; |
328 | |
329 | sg_init_table(&sg, 1); |
330 | |
331 | sg_set_buf(sg: &sg, buf: ((u8 *)tp) + header_len, buflen: head_data_len); |
332 | ahash_request_set_crypt(req, src: &sg, NULL, nbytes: head_data_len); |
333 | if (crypto_ahash_update(req)) |
334 | return 1; |
335 | |
336 | for (i = 0; i < shi->nr_frags; ++i) { |
337 | const skb_frag_t *f = &shi->frags[i]; |
338 | unsigned int offset = skb_frag_off(frag: f); |
339 | struct page *page; |
340 | |
341 | page = skb_frag_page(frag: f) + (offset >> PAGE_SHIFT); |
342 | sg_set_page(sg: &sg, page, len: skb_frag_size(frag: f), offset_in_page(offset)); |
343 | ahash_request_set_crypt(req, src: &sg, NULL, nbytes: skb_frag_size(frag: f)); |
344 | if (crypto_ahash_update(req)) |
345 | return 1; |
346 | } |
347 | |
348 | skb_walk_frags(skb, frag_iter) |
349 | if (tcp_sigpool_hash_skb_data(hp, skb: frag_iter, header_len: 0)) |
350 | return 1; |
351 | |
352 | return 0; |
353 | } |
354 | EXPORT_SYMBOL(tcp_sigpool_hash_skb_data); |
355 | |
356 | MODULE_LICENSE("GPL" ); |
357 | MODULE_DESCRIPTION("Per-CPU pool of crypto requests" ); |
358 | |