/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <net/tls.h>
#include <crypto/aead.h>
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>

#include "tls.h"

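/* Wrap the not-yet-consumed tail of @walk's current scatterlist entry in @sg
 * and chain the remainder of the source list after it, so the remaining data
 * can be handed to the AEAD transform as a fresh scatterlist.
 */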
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
	struct scatterlist *src = walk->sg;
	int diff = walk->offset - src->offset;

	sg_set_page(sg, sg_page(src),
		    src->length - diff, walk->offset);

	scatterwalk_crypto_chain(sg, sg_next(src), 2);
}

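/* Re-encrypt a single TLS record: pass the TLS header and explicit IV through
 * unchanged, rebuild the AAD from the header and @rcd_sn, then AEAD-encrypt
 * the record payload from @in into @out. *in_len is reduced by the amount of
 * input consumed and reaches zero on the last (possibly partial) record.
 */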
static int tls_enc_record(struct aead_request *aead_req,
			  struct crypto_aead *aead, char *aad,
			  char *iv, __be64 rcd_sn,
			  struct scatter_walk *in,
			  struct scatter_walk *out, int *in_len,
			  struct tls_prot_info *prot)
{
	unsigned char buf[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE];
	const struct tls_cipher_desc *cipher_desc;
	struct scatterlist sg_in[3];
	struct scatterlist sg_out[3];
	unsigned int buf_size;
	u16 len;
	int rc;

	cipher_desc = get_cipher_desc(prot->cipher_type);
	DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable);

	buf_size = TLS_HEADER_SIZE + cipher_desc->iv;
	len = min_t(int, *in_len, buf_size);

	scatterwalk_copychunks(buf, in, len, 0);
	scatterwalk_copychunks(buf, out, len, 1);

	*in_len -= len;
	if (!*in_len)
		return 0;

	scatterwalk_pagedone(in, 0, 1);
	scatterwalk_pagedone(out, 1, 1);

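	/* bytes 3-4 of the TLS record header carry the record length (big
	 * endian); drop the explicit IV that was copied through above
	 */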
	len = buf[4] | (buf[3] << 8);
	len -= cipher_desc->iv;

	tls_make_aad(aad, len - cipher_desc->tag, (char *)&rcd_sn, buf[0], prot);

	memcpy(iv + cipher_desc->salt, buf + TLS_HEADER_SIZE, cipher_desc->iv);

	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
	chain_to_walk(sg_in + 1, in);
	chain_to_walk(sg_out + 1, out);

	*in_len -= len;
	if (*in_len < 0) {
		*in_len += cipher_desc->tag;
		/* The input buffer doesn't contain the entire record.
		 * Trim len accordingly. The resulting authentication tag
		 * will contain garbage, but we don't care, so we won't
		 * include any of it in the output skb.
		 * Note that we assume the output buffer length
		 * is larger than the input buffer length + tag size.
		 */
		if (*in_len < 0)
			len += *in_len;

		*in_len = 0;
	}

	if (*in_len) {
		scatterwalk_copychunks(NULL, in, len, 2);
		scatterwalk_pagedone(in, 0, 1);
		scatterwalk_copychunks(NULL, out, len, 2);
		scatterwalk_pagedone(out, 1, 1);
	}

	len -= cipher_desc->tag;
	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);

	rc = crypto_aead_encrypt(aead_req);

	return rc;
}

static void tls_init_aead_request(struct aead_request *aead_req,
				  struct crypto_aead *aead)
{
	aead_request_set_tfm(aead_req, aead);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
}

static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
						   gfp_t flags)
{
	unsigned int req_size = sizeof(struct aead_request) +
		crypto_aead_reqsize(aead);
	struct aead_request *aead_req;

	aead_req = kzalloc(req_size, flags);
	if (aead_req)
		tls_init_aead_request(aead_req, aead);
	return aead_req;
}

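/* Walk the input scatterlist and re-encrypt record after record into the
 * output scatterlist, bumping the record sequence number for each one, until
 * @len bytes have been consumed or the AEAD reports an error.
 */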
static int tls_enc_records(struct aead_request *aead_req,
			   struct crypto_aead *aead, struct scatterlist *sg_in,
			   struct scatterlist *sg_out, char *aad, char *iv,
			   u64 rcd_sn, int len, struct tls_prot_info *prot)
{
	struct scatter_walk out, in;
	int rc;

	scatterwalk_start(&in, sg_in);
	scatterwalk_start(&out, sg_out);

	do {
		rc = tls_enc_record(aead_req, aead, aad, iv,
				    cpu_to_be64(rcd_sn), &in, &out, &len, prot);
		rcd_sn++;

	} while (rc == 0 && len);

	scatterwalk_done(&in, 0, 0);
	scatterwalk_done(&out, 1, 0);

	return rc;
}

/* Can't use icsk->icsk_af_ops->send_check here because the IP addresses
 * might have been changed by NAT.
 */
static void update_chksum(struct sk_buff *skb, int headln)
{
	struct tcphdr *th = tcp_hdr(skb);
	int datalen = skb->len - headln;
	const struct ipv6hdr *ipv6h;
	const struct iphdr *iph;

	/* We only changed the payload, so if we are using CHECKSUM_PARTIAL
	 * we don't need to update anything.
	 */
	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
		return;

	skb->ip_summed = CHECKSUM_PARTIAL;
	skb->csum_start = skb_transport_header(skb) - skb->head;
	skb->csum_offset = offsetof(struct tcphdr, check);

	if (skb->sk->sk_family == AF_INET6) {
		ipv6h = ipv6_hdr(skb);
		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
					     datalen, IPPROTO_TCP, 0);
	} else {
		iph = ip_hdr(skb);
		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
					       IPPROTO_TCP, 0);
	}
}

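/* Copy the headers and the ownership of the original skb onto the freshly
 * encrypted copy, recompute the TCP checksum and transfer the wmem
 * accounting delta to the socket.
 */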
static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
{
	struct sock *sk = skb->sk;
	int delta;

	skb_copy_header(nskb, skb);

	skb_put(nskb, skb->len);
	memcpy(nskb->data, skb->data, headln);

	nskb->destructor = skb->destructor;
	nskb->sk = sk;
	skb->destructor = NULL;
	skb->sk = NULL;

	update_chksum(nskb, headln);

	/* sock_efree means the skb must have gone through skb_orphan_partial() */
	if (nskb->destructor == sock_efree)
		return;

	delta = nskb->truesize - skb->truesize;
	if (likely(delta < 0))
		WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
	else if (delta)
		refcount_add(delta, &sk->sk_wmem_alloc);
}

/* This function may be called after the user socket is already
 * closed, so make sure we don't use anything freed during
 * tls_sk_proto_close here.
 */

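/* Look up the TLS record that covers the retransmitted TCP sequence and
 * populate @sg_in with the already-transmitted part of that record (taking
 * page references on its frags), followed by the payload of @skb itself.
 * *sync_size is the number of record bytes preceding the skb payload and
 * *resync_sgs the number of frag entries holding page references.
 */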
static int fill_sg_in(struct scatterlist *sg_in,
		      struct sk_buff *skb,
		      struct tls_offload_context_tx *ctx,
		      u64 *rcd_sn,
		      s32 *sync_size,
		      int *resync_sgs)
{
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	int payload_len = skb->len - tcp_payload_offset;
	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
	struct tls_record_info *record;
	unsigned long flags;
	int remaining;
	int i;

	spin_lock_irqsave(&ctx->lock, flags);
	record = tls_get_record(ctx, tcp_seq, rcd_sn);
	if (!record) {
		spin_unlock_irqrestore(&ctx->lock, flags);
		return -EINVAL;
	}

	*sync_size = tcp_seq - tls_record_start_seq(record);
	if (*sync_size < 0) {
		int is_start_marker = tls_record_is_start_marker(record);

		spin_unlock_irqrestore(&ctx->lock, flags);
		/* This should only occur if the relevant record was
		 * already acked. In that case it should be ok
		 * to drop the packet and avoid retransmission.
		 *
		 * There is a corner case where the packet contains
		 * both an acked and a non-acked record.
		 * We currently don't handle that case and rely
		 * on TCP to retransmit a packet that doesn't contain
		 * already acked payload.
		 */
		if (!is_start_marker)
			*sync_size = 0;
		return -EINVAL;
	}

	remaining = *sync_size;
	for (i = 0; remaining > 0; i++) {
		skb_frag_t *frag = &record->frags[i];

		__skb_frag_ref(frag);
		sg_set_page(sg_in + i, skb_frag_page(frag),
			    skb_frag_size(frag), skb_frag_off(frag));

		remaining -= skb_frag_size(frag);

		if (remaining < 0)
			sg_in[i].length += remaining;
	}
	*resync_sgs = i;

	spin_unlock_irqrestore(&ctx->lock, flags);
	if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0)
		return -EINVAL;

	return 0;
}

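/* Build the destination scatterlist: a scratch buffer that absorbs the
 * re-encrypted bytes preceding this skb, the payload area of @nskb itself,
 * and room for the authentication tag that the AEAD will append.
 */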
static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
			struct tls_context *tls_ctx,
			struct sk_buff *nskb,
			int tcp_payload_offset,
			int payload_len,
			int sync_size,
			void *dummy_buf)
{
	const struct tls_cipher_desc *cipher_desc =
		get_cipher_desc(tls_ctx->crypto_send.info.cipher_type);

	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
	/* Add room for authentication tag produced by crypto */
	dummy_buf += sync_size;
	sg_set_buf(&sg_out[2], dummy_buf, cipher_desc->tag);
}

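/* Allocate a copy of @skb and software-encrypt its TLS payload, together with
 * the preceding @sync_size bytes of the record, using the fallback AEAD.
 * Returns the encrypted copy, or NULL on failure.
 */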
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
				   struct scatterlist sg_out[3],
				   struct scatterlist *sg_in,
				   struct sk_buff *skb,
				   s32 sync_size, u64 rcd_sn)
{
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	int payload_len = skb->len - tcp_payload_offset;
	const struct tls_cipher_desc *cipher_desc;
	void *buf, *iv, *aad, *dummy_buf, *salt;
	struct aead_request *aead_req;
	struct sk_buff *nskb = NULL;
	int buf_len;

	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
	if (!aead_req)
		return NULL;

	cipher_desc = get_cipher_desc(tls_ctx->crypto_send.info.cipher_type);
	DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable);

	buf_len = cipher_desc->salt + cipher_desc->iv + TLS_AAD_SPACE_SIZE +
		  sync_size + cipher_desc->tag;
	buf = kmalloc(buf_len, GFP_ATOMIC);
	if (!buf)
		goto free_req;

	iv = buf;
	salt = crypto_info_salt(&tls_ctx->crypto_send.info, cipher_desc);
	memcpy(iv, salt, cipher_desc->salt);
	aad = buf + cipher_desc->salt + cipher_desc->iv;
	dummy_buf = aad + TLS_AAD_SPACE_SIZE;

	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
	if (!nskb)
		goto free_buf;

	skb_reserve(nskb, skb_headroom(skb));

	fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset,
		    payload_len, sync_size, dummy_buf);

	if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
			    rcd_sn, sync_size + payload_len,
			    &tls_ctx->prot_info) < 0)
		goto free_nskb;

	complete_skb(nskb, skb, tcp_payload_offset);

	/* validate_xmit_skb_list assumes that if the skb wasn't segmented
	 * nskb->prev will point to the skb itself
	 */
	nskb->prev = nskb;

free_buf:
	kfree(buf);
free_req:
	kfree(aead_req);
	return nskb;
free_nskb:
	kfree_skb(nskb);
	nskb = NULL;
	goto free_buf;
}

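/* Software fallback for a TLS-offload skb that the device can't encrypt,
 * typically a retransmission of data the device has already moved past.
 * Rebuild the affected record(s) from the offload context, encrypt them in
 * software and hand back a new skb; the original is consumed either way.
 */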
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	int payload_len = skb->len - tcp_payload_offset;
	struct scatterlist *sg_in, sg_out[3];
	struct sk_buff *nskb = NULL;
	int sg_in_max_elements;
	int resync_sgs = 0;
	s32 sync_size = 0;
	u64 rcd_sn;

	/* worst case is:
	 * MAX_SKB_FRAGS in tls_record_info
	 * MAX_SKB_FRAGS + 1 in SKB head and frags.
	 */
	sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;

	if (!payload_len)
		return skb;

	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
	if (!sg_in)
		goto free_orig;

	sg_init_table(sg_in, sg_in_max_elements);
	sg_init_table(sg_out, ARRAY_SIZE(sg_out));

	if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) {
		/* bypass packets before kernel TLS socket option was set */
		if (sync_size < 0 && payload_len <= -sync_size)
			nskb = skb_get(skb);
		goto put_sg;
	}

	nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn);

put_sg:
	while (resync_sgs)
		put_page(sg_page(&sg_in[--resync_sgs]));
	kfree(sg_in);
free_orig:
	if (nskb)
		consume_skb(skb);
	else
		kfree_skb(skb);
	return nskb;
}

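/* Transmit-path hook (sk->sk_validate_xmit_skb): let the skb through
 * untouched when it is headed for the netdev that holds the offload state
 * (or a bonding master); otherwise encrypt it with the software fallback.
 */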
struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
				      struct net_device *dev,
				      struct sk_buff *skb)
{
	if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev) ||
	    netif_is_bond_master(dev))
		return skb;

	return tls_sw_fallback(sk, skb);
}
EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);

struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk,
					 struct net_device *dev,
					 struct sk_buff *skb)
{
	return tls_sw_fallback(sk, skb);
}

struct sk_buff *tls_encrypt_skb(struct sk_buff *skb)
{
	return tls_sw_fallback(skb->sk, skb);
}
EXPORT_SYMBOL_GPL(tls_encrypt_skb);

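/* Allocate and key the synchronous AEAD transform used by the software
 * fallback path when TX offload is set up on a socket. Only cipher types
 * marked offloadable in the cipher descriptor table are accepted.
 */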
int tls_sw_fallback_init(struct sock *sk,
			 struct tls_offload_context_tx *offload_ctx,
			 struct tls_crypto_info *crypto_info)
{
	const struct tls_cipher_desc *cipher_desc;
	int rc;

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc || !cipher_desc->offloadable)
		return -EINVAL;

	offload_ctx->aead_send =
		crypto_alloc_aead(cipher_desc->cipher_name, 0, CRYPTO_ALG_ASYNC);
	if (IS_ERR(offload_ctx->aead_send)) {
		rc = PTR_ERR(offload_ctx->aead_send);
		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
		offload_ctx->aead_send = NULL;
		goto err_out;
	}

	rc = crypto_aead_setkey(offload_ctx->aead_send,
				crypto_info_key(crypto_info, cipher_desc),
				cipher_desc->key);
	if (rc)
		goto free_aead;

	rc = crypto_aead_setauthsize(offload_ctx->aead_send, cipher_desc->tag);
	if (rc)
		goto free_aead;

	return 0;
free_aead:
	crypto_free_aead(offload_ctx->aead_send);
err_out:
	return rc;
}