1 | /* |
2 | * Copyright (c) 2006, 2020 Oracle and/or its affiliates. |
3 | * |
4 | * This software is available to you under a choice of one of two |
5 | * licenses. You may choose to be licensed under the terms of the GNU |
6 | * General Public License (GPL) Version 2, available from the file |
7 | * COPYING in the main directory of this source tree, or the |
8 | * OpenIB.org BSD license below: |
9 | * |
10 | * Redistribution and use in source and binary forms, with or |
11 | * without modification, are permitted provided that the following |
12 | * conditions are met: |
13 | * |
14 | * - Redistributions of source code must retain the above |
15 | * copyright notice, this list of conditions and the following |
16 | * disclaimer. |
17 | * |
18 | * - Redistributions in binary form must reproduce the above |
19 | * copyright notice, this list of conditions and the following |
20 | * disclaimer in the documentation and/or other materials |
21 | * provided with the distribution. |
22 | * |
23 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
24 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
25 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
26 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS |
27 | * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN |
28 | * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
29 | * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
30 | * SOFTWARE. |
31 | * |
32 | */ |
33 | #include <linux/kernel.h> |
34 | #include <linux/slab.h> |
35 | #include <linux/export.h> |
36 | #include <linux/skbuff.h> |
37 | #include <linux/list.h> |
38 | #include <linux/errqueue.h> |
39 | |
40 | #include "rds.h" |
41 | |
42 | static unsigned int rds_exthdr_size[__RDS_EXTHDR_MAX] = { |
43 | [RDS_EXTHDR_NONE] = 0, |
44 | [RDS_EXTHDR_VERSION] = sizeof(struct rds_ext_header_version), |
45 | [RDS_EXTHDR_RDMA] = sizeof(struct rds_ext_header_rdma), |
46 | [RDS_EXTHDR_RDMA_DEST] = sizeof(struct rds_ext_header_rdma_dest), |
47 | [RDS_EXTHDR_NPATHS] = sizeof(u16), |
48 | [RDS_EXTHDR_GEN_NUM] = sizeof(u32), |
49 | }; |
50 | |
51 | void rds_message_addref(struct rds_message *rm) |
52 | { |
53 | rdsdebug("addref rm %p ref %d\n" , rm, refcount_read(&rm->m_refcount)); |
54 | refcount_inc(r: &rm->m_refcount); |
55 | } |
56 | EXPORT_SYMBOL_GPL(rds_message_addref); |
57 | |
58 | static inline bool rds_zcookie_add(struct rds_msg_zcopy_info *info, u32 cookie) |
59 | { |
60 | struct rds_zcopy_cookies *ck = &info->zcookies; |
61 | int ncookies = ck->num; |
62 | |
63 | if (ncookies == RDS_MAX_ZCOOKIES) |
64 | return false; |
65 | ck->cookies[ncookies] = cookie; |
66 | ck->num = ++ncookies; |
67 | return true; |
68 | } |
69 | |
70 | static struct rds_msg_zcopy_info *rds_info_from_znotifier(struct rds_znotifier *znotif) |
71 | { |
72 | return container_of(znotif, struct rds_msg_zcopy_info, znotif); |
73 | } |
74 | |
75 | void rds_notify_msg_zcopy_purge(struct rds_msg_zcopy_queue *q) |
76 | { |
77 | unsigned long flags; |
78 | LIST_HEAD(copy); |
79 | struct rds_msg_zcopy_info *info, *tmp; |
80 | |
81 | spin_lock_irqsave(&q->lock, flags); |
82 | list_splice(list: &q->zcookie_head, head: ©); |
83 | INIT_LIST_HEAD(list: &q->zcookie_head); |
84 | spin_unlock_irqrestore(lock: &q->lock, flags); |
85 | |
86 | list_for_each_entry_safe(info, tmp, ©, rs_zcookie_next) { |
87 | list_del(entry: &info->rs_zcookie_next); |
88 | kfree(objp: info); |
89 | } |
90 | } |
91 | |
92 | static void rds_rm_zerocopy_callback(struct rds_sock *rs, |
93 | struct rds_znotifier *znotif) |
94 | { |
95 | struct rds_msg_zcopy_info *info; |
96 | struct rds_msg_zcopy_queue *q; |
97 | u32 cookie = znotif->z_cookie; |
98 | struct rds_zcopy_cookies *ck; |
99 | struct list_head *head; |
100 | unsigned long flags; |
101 | |
102 | mm_unaccount_pinned_pages(mmp: &znotif->z_mmp); |
103 | q = &rs->rs_zcookie_queue; |
104 | spin_lock_irqsave(&q->lock, flags); |
105 | head = &q->zcookie_head; |
106 | if (!list_empty(head)) { |
107 | info = list_first_entry(head, struct rds_msg_zcopy_info, |
108 | rs_zcookie_next); |
109 | if (rds_zcookie_add(info, cookie)) { |
110 | spin_unlock_irqrestore(lock: &q->lock, flags); |
111 | kfree(objp: rds_info_from_znotifier(znotif)); |
112 | /* caller invokes rds_wake_sk_sleep() */ |
113 | return; |
114 | } |
115 | } |
116 | |
117 | info = rds_info_from_znotifier(znotif); |
118 | ck = &info->zcookies; |
119 | memset(ck, 0, sizeof(*ck)); |
120 | WARN_ON(!rds_zcookie_add(info, cookie)); |
121 | list_add_tail(new: &info->rs_zcookie_next, head: &q->zcookie_head); |
122 | |
123 | spin_unlock_irqrestore(lock: &q->lock, flags); |
124 | /* caller invokes rds_wake_sk_sleep() */ |
125 | } |
126 | |
127 | /* |
128 | * This relies on dma_map_sg() not touching sg[].page during merging. |
129 | */ |
130 | static void rds_message_purge(struct rds_message *rm) |
131 | { |
132 | unsigned long i, flags; |
133 | bool zcopy = false; |
134 | |
135 | if (unlikely(test_bit(RDS_MSG_PAGEVEC, &rm->m_flags))) |
136 | return; |
137 | |
138 | spin_lock_irqsave(&rm->m_rs_lock, flags); |
139 | if (rm->m_rs) { |
140 | struct rds_sock *rs = rm->m_rs; |
141 | |
142 | if (rm->data.op_mmp_znotifier) { |
143 | zcopy = true; |
144 | rds_rm_zerocopy_callback(rs, znotif: rm->data.op_mmp_znotifier); |
145 | rds_wake_sk_sleep(rs); |
146 | rm->data.op_mmp_znotifier = NULL; |
147 | } |
148 | sock_put(sk: rds_rs_to_sk(rs)); |
149 | rm->m_rs = NULL; |
150 | } |
151 | spin_unlock_irqrestore(lock: &rm->m_rs_lock, flags); |
152 | |
153 | for (i = 0; i < rm->data.op_nents; i++) { |
154 | /* XXX will have to put_page for page refs */ |
155 | if (!zcopy) |
156 | __free_page(sg_page(&rm->data.op_sg[i])); |
157 | else |
158 | put_page(page: sg_page(sg: &rm->data.op_sg[i])); |
159 | } |
160 | rm->data.op_nents = 0; |
161 | |
162 | if (rm->rdma.op_active) |
163 | rds_rdma_free_op(ro: &rm->rdma); |
164 | if (rm->rdma.op_rdma_mr) |
165 | kref_put(kref: &rm->rdma.op_rdma_mr->r_kref, release: __rds_put_mr_final); |
166 | |
167 | if (rm->atomic.op_active) |
168 | rds_atomic_free_op(ao: &rm->atomic); |
169 | if (rm->atomic.op_rdma_mr) |
170 | kref_put(kref: &rm->atomic.op_rdma_mr->r_kref, release: __rds_put_mr_final); |
171 | } |
172 | |
173 | void rds_message_put(struct rds_message *rm) |
174 | { |
175 | rdsdebug("put rm %p ref %d\n" , rm, refcount_read(&rm->m_refcount)); |
176 | WARN(!refcount_read(&rm->m_refcount), "danger refcount zero on %p\n" , rm); |
177 | if (refcount_dec_and_test(r: &rm->m_refcount)) { |
178 | BUG_ON(!list_empty(&rm->m_sock_item)); |
179 | BUG_ON(!list_empty(&rm->m_conn_item)); |
180 | rds_message_purge(rm); |
181 | |
182 | kfree(objp: rm); |
183 | } |
184 | } |
185 | EXPORT_SYMBOL_GPL(rds_message_put); |
186 | |
187 | void (struct rds_header *hdr, __be16 sport, |
188 | __be16 dport, u64 seq) |
189 | { |
190 | hdr->h_flags = 0; |
191 | hdr->h_sport = sport; |
192 | hdr->h_dport = dport; |
193 | hdr->h_sequence = cpu_to_be64(seq); |
194 | hdr->h_exthdr[0] = RDS_EXTHDR_NONE; |
195 | } |
196 | EXPORT_SYMBOL_GPL(rds_message_populate_header); |
197 | |
198 | int rds_message_add_extension(struct rds_header *hdr, unsigned int type, |
199 | const void *data, unsigned int len) |
200 | { |
201 | unsigned int ext_len = sizeof(u8) + len; |
202 | unsigned char *dst; |
203 | |
204 | /* For now, refuse to add more than one extension header */ |
205 | if (hdr->h_exthdr[0] != RDS_EXTHDR_NONE) |
206 | return 0; |
207 | |
208 | if (type >= __RDS_EXTHDR_MAX || len != rds_exthdr_size[type]) |
209 | return 0; |
210 | |
211 | if (ext_len >= RDS_HEADER_EXT_SPACE) |
212 | return 0; |
213 | dst = hdr->h_exthdr; |
214 | |
215 | *dst++ = type; |
216 | memcpy(dst, data, len); |
217 | |
218 | dst[len] = RDS_EXTHDR_NONE; |
219 | return 1; |
220 | } |
221 | EXPORT_SYMBOL_GPL(rds_message_add_extension); |
222 | |
223 | /* |
224 | * If a message has extension headers, retrieve them here. |
225 | * Call like this: |
226 | * |
227 | * unsigned int pos = 0; |
228 | * |
229 | * while (1) { |
230 | * buflen = sizeof(buffer); |
231 | * type = rds_message_next_extension(hdr, &pos, buffer, &buflen); |
232 | * if (type == RDS_EXTHDR_NONE) |
233 | * break; |
234 | * ... |
235 | * } |
236 | */ |
237 | int rds_message_next_extension(struct rds_header *hdr, |
238 | unsigned int *pos, void *buf, unsigned int *buflen) |
239 | { |
240 | unsigned int offset, ext_type, ext_len; |
241 | u8 *src = hdr->h_exthdr; |
242 | |
243 | offset = *pos; |
244 | if (offset >= RDS_HEADER_EXT_SPACE) |
245 | goto none; |
246 | |
247 | /* Get the extension type and length. For now, the |
248 | * length is implied by the extension type. */ |
249 | ext_type = src[offset++]; |
250 | |
251 | if (ext_type == RDS_EXTHDR_NONE || ext_type >= __RDS_EXTHDR_MAX) |
252 | goto none; |
253 | ext_len = rds_exthdr_size[ext_type]; |
254 | if (offset + ext_len > RDS_HEADER_EXT_SPACE) |
255 | goto none; |
256 | |
257 | *pos = offset + ext_len; |
258 | if (ext_len < *buflen) |
259 | *buflen = ext_len; |
260 | memcpy(buf, src + offset, *buflen); |
261 | return ext_type; |
262 | |
263 | none: |
264 | *pos = RDS_HEADER_EXT_SPACE; |
265 | *buflen = 0; |
266 | return RDS_EXTHDR_NONE; |
267 | } |
268 | |
269 | int rds_message_add_rdma_dest_extension(struct rds_header *hdr, u32 r_key, u32 offset) |
270 | { |
271 | struct rds_ext_header_rdma_dest ext_hdr; |
272 | |
273 | ext_hdr.h_rdma_rkey = cpu_to_be32(r_key); |
274 | ext_hdr.h_rdma_offset = cpu_to_be32(offset); |
275 | return rds_message_add_extension(hdr, RDS_EXTHDR_RDMA_DEST, &ext_hdr, sizeof(ext_hdr)); |
276 | } |
277 | EXPORT_SYMBOL_GPL(rds_message_add_rdma_dest_extension); |
278 | |
279 | /* |
280 | * Each rds_message is allocated with extra space for the scatterlist entries |
281 | * rds ops will need. This is to minimize memory allocation count. Then, each rds op |
282 | * can grab SGs when initializing its part of the rds_message. |
283 | */ |
284 | struct rds_message *rds_message_alloc(unsigned int , gfp_t gfp) |
285 | { |
286 | struct rds_message *rm; |
287 | |
288 | if (extra_len > KMALLOC_MAX_SIZE - sizeof(struct rds_message)) |
289 | return NULL; |
290 | |
291 | rm = kzalloc(size: sizeof(struct rds_message) + extra_len, flags: gfp); |
292 | if (!rm) |
293 | goto out; |
294 | |
295 | rm->m_used_sgs = 0; |
296 | rm->m_total_sgs = extra_len / sizeof(struct scatterlist); |
297 | |
298 | refcount_set(r: &rm->m_refcount, n: 1); |
299 | INIT_LIST_HEAD(list: &rm->m_sock_item); |
300 | INIT_LIST_HEAD(list: &rm->m_conn_item); |
301 | spin_lock_init(&rm->m_rs_lock); |
302 | init_waitqueue_head(&rm->m_flush_wait); |
303 | |
304 | out: |
305 | return rm; |
306 | } |
307 | |
308 | /* |
309 | * RDS ops use this to grab SG entries from the rm's sg pool. |
310 | */ |
311 | struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents) |
312 | { |
313 | struct scatterlist *sg_first = (struct scatterlist *) &rm[1]; |
314 | struct scatterlist *sg_ret; |
315 | |
316 | if (nents <= 0) { |
317 | pr_warn("rds: alloc sgs failed! nents <= 0\n" ); |
318 | return ERR_PTR(error: -EINVAL); |
319 | } |
320 | |
321 | if (rm->m_used_sgs + nents > rm->m_total_sgs) { |
322 | pr_warn("rds: alloc sgs failed! total %d used %d nents %d\n" , |
323 | rm->m_total_sgs, rm->m_used_sgs, nents); |
324 | return ERR_PTR(error: -ENOMEM); |
325 | } |
326 | |
327 | sg_ret = &sg_first[rm->m_used_sgs]; |
328 | sg_init_table(sg_ret, nents); |
329 | rm->m_used_sgs += nents; |
330 | |
331 | return sg_ret; |
332 | } |
333 | |
334 | struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned int total_len) |
335 | { |
336 | struct rds_message *rm; |
337 | unsigned int i; |
338 | int num_sgs = DIV_ROUND_UP(total_len, PAGE_SIZE); |
339 | int = num_sgs * sizeof(struct scatterlist); |
340 | |
341 | rm = rds_message_alloc(extra_len: extra_bytes, GFP_NOWAIT); |
342 | if (!rm) |
343 | return ERR_PTR(error: -ENOMEM); |
344 | |
345 | set_bit(RDS_MSG_PAGEVEC, addr: &rm->m_flags); |
346 | rm->m_inc.i_hdr.h_len = cpu_to_be32(total_len); |
347 | rm->data.op_nents = DIV_ROUND_UP(total_len, PAGE_SIZE); |
348 | rm->data.op_sg = rds_message_alloc_sgs(rm, nents: num_sgs); |
349 | if (IS_ERR(ptr: rm->data.op_sg)) { |
350 | void *err = ERR_CAST(ptr: rm->data.op_sg); |
351 | rds_message_put(rm); |
352 | return err; |
353 | } |
354 | |
355 | for (i = 0; i < rm->data.op_nents; ++i) { |
356 | sg_set_page(sg: &rm->data.op_sg[i], |
357 | virt_to_page((void *)page_addrs[i]), |
358 | PAGE_SIZE, offset: 0); |
359 | } |
360 | |
361 | return rm; |
362 | } |
363 | |
364 | static int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter *from) |
365 | { |
366 | struct scatterlist *sg; |
367 | int ret = 0; |
368 | int length = iov_iter_count(i: from); |
369 | struct rds_msg_zcopy_info *info; |
370 | |
371 | rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); |
372 | |
373 | /* |
374 | * now allocate and copy in the data payload. |
375 | */ |
376 | sg = rm->data.op_sg; |
377 | |
378 | info = kzalloc(size: sizeof(*info), GFP_KERNEL); |
379 | if (!info) |
380 | return -ENOMEM; |
381 | INIT_LIST_HEAD(list: &info->rs_zcookie_next); |
382 | rm->data.op_mmp_znotifier = &info->znotif; |
383 | if (mm_account_pinned_pages(mmp: &rm->data.op_mmp_znotifier->z_mmp, |
384 | size: length)) { |
385 | ret = -ENOMEM; |
386 | goto err; |
387 | } |
388 | while (iov_iter_count(i: from)) { |
389 | struct page *pages; |
390 | size_t start; |
391 | ssize_t copied; |
392 | |
393 | copied = iov_iter_get_pages2(i: from, pages: &pages, PAGE_SIZE, |
394 | maxpages: 1, start: &start); |
395 | if (copied < 0) { |
396 | struct mmpin *mmp; |
397 | int i; |
398 | |
399 | for (i = 0; i < rm->data.op_nents; i++) |
400 | put_page(page: sg_page(sg: &rm->data.op_sg[i])); |
401 | mmp = &rm->data.op_mmp_znotifier->z_mmp; |
402 | mm_unaccount_pinned_pages(mmp); |
403 | ret = -EFAULT; |
404 | goto err; |
405 | } |
406 | length -= copied; |
407 | sg_set_page(sg, page: pages, len: copied, offset: start); |
408 | rm->data.op_nents++; |
409 | sg++; |
410 | } |
411 | WARN_ON_ONCE(length != 0); |
412 | return ret; |
413 | err: |
414 | kfree(objp: info); |
415 | rm->data.op_mmp_znotifier = NULL; |
416 | return ret; |
417 | } |
418 | |
419 | int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, |
420 | bool zcopy) |
421 | { |
422 | unsigned long to_copy, nbytes; |
423 | unsigned long sg_off; |
424 | struct scatterlist *sg; |
425 | int ret = 0; |
426 | |
427 | rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); |
428 | |
429 | /* now allocate and copy in the data payload. */ |
430 | sg = rm->data.op_sg; |
431 | sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */ |
432 | |
433 | if (zcopy) |
434 | return rds_message_zcopy_from_user(rm, from); |
435 | |
436 | while (iov_iter_count(i: from)) { |
437 | if (!sg_page(sg)) { |
438 | ret = rds_page_remainder_alloc(scat: sg, bytes: iov_iter_count(i: from), |
439 | GFP_HIGHUSER); |
440 | if (ret) |
441 | return ret; |
442 | rm->data.op_nents++; |
443 | sg_off = 0; |
444 | } |
445 | |
446 | to_copy = min_t(unsigned long, iov_iter_count(from), |
447 | sg->length - sg_off); |
448 | |
449 | rds_stats_add(s_copy_from_user, to_copy); |
450 | nbytes = copy_page_from_iter(page: sg_page(sg), offset: sg->offset + sg_off, |
451 | bytes: to_copy, i: from); |
452 | if (nbytes != to_copy) |
453 | return -EFAULT; |
454 | |
455 | sg_off += to_copy; |
456 | |
457 | if (sg_off == sg->length) |
458 | sg++; |
459 | } |
460 | |
461 | return ret; |
462 | } |
463 | |
464 | int rds_message_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) |
465 | { |
466 | struct rds_message *rm; |
467 | struct scatterlist *sg; |
468 | unsigned long to_copy; |
469 | unsigned long vec_off; |
470 | int copied; |
471 | int ret; |
472 | u32 len; |
473 | |
474 | rm = container_of(inc, struct rds_message, m_inc); |
475 | len = be32_to_cpu(rm->m_inc.i_hdr.h_len); |
476 | |
477 | sg = rm->data.op_sg; |
478 | vec_off = 0; |
479 | copied = 0; |
480 | |
481 | while (iov_iter_count(i: to) && copied < len) { |
482 | to_copy = min_t(unsigned long, iov_iter_count(to), |
483 | sg->length - vec_off); |
484 | to_copy = min_t(unsigned long, to_copy, len - copied); |
485 | |
486 | rds_stats_add(s_copy_to_user, to_copy); |
487 | ret = copy_page_to_iter(page: sg_page(sg), offset: sg->offset + vec_off, |
488 | bytes: to_copy, i: to); |
489 | if (ret != to_copy) |
490 | return -EFAULT; |
491 | |
492 | vec_off += to_copy; |
493 | copied += to_copy; |
494 | |
495 | if (vec_off == sg->length) { |
496 | vec_off = 0; |
497 | sg++; |
498 | } |
499 | } |
500 | |
501 | return copied; |
502 | } |
503 | |
504 | /* |
505 | * If the message is still on the send queue, wait until the transport |
506 | * is done with it. This is particularly important for RDMA operations. |
507 | */ |
508 | void rds_message_wait(struct rds_message *rm) |
509 | { |
510 | wait_event_interruptible(rm->m_flush_wait, |
511 | !test_bit(RDS_MSG_MAPPED, &rm->m_flags)); |
512 | } |
513 | |
514 | void rds_message_unmapped(struct rds_message *rm) |
515 | { |
516 | clear_bit(RDS_MSG_MAPPED, addr: &rm->m_flags); |
517 | wake_up_interruptible(&rm->m_flush_wait); |
518 | } |
519 | EXPORT_SYMBOL_GPL(rds_message_unmapped); |
520 | |