1 | // SPDX-License-Identifier: GPL-2.0-only |
2 | /* |
3 | * Kernel/userspace transport abstraction for Hyper-V util driver. |
4 | * |
5 | * Copyright (C) 2015, Vitaly Kuznetsov <vkuznets@redhat.com> |
6 | */ |
7 | |
8 | #include <linux/slab.h> |
9 | #include <linux/fs.h> |
10 | #include <linux/poll.h> |
11 | |
12 | #include "hyperv_vmbus.h" |
13 | #include "hv_utils_transport.h" |
14 | |
/*
 * All live transports. hvt_cn_callback() walks this list to route an
 * incoming connector message by cn_id; hvutil_transport_init() adds to it
 * and hvutil_transport_destroy() removes from it. Every access is made
 * under hvt_list_lock.
 */
static LIST_HEAD(hvt_list);
static DEFINE_SPINLOCK(hvt_list_lock);
17 | |
18 | static void hvt_reset(struct hvutil_transport *hvt) |
19 | { |
20 | kfree(objp: hvt->outmsg); |
21 | hvt->outmsg = NULL; |
22 | hvt->outmsg_len = 0; |
23 | if (hvt->on_reset) |
24 | hvt->on_reset(); |
25 | } |
26 | |
27 | static ssize_t hvt_op_read(struct file *file, char __user *buf, |
28 | size_t count, loff_t *ppos) |
29 | { |
30 | struct hvutil_transport *hvt; |
31 | int ret; |
32 | |
33 | hvt = container_of(file->f_op, struct hvutil_transport, fops); |
34 | |
35 | if (wait_event_interruptible(hvt->outmsg_q, hvt->outmsg_len > 0 || |
36 | hvt->mode != HVUTIL_TRANSPORT_CHARDEV)) |
37 | return -EINTR; |
38 | |
39 | mutex_lock(&hvt->lock); |
40 | |
41 | if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) { |
42 | ret = -EBADF; |
43 | goto out_unlock; |
44 | } |
45 | |
46 | if (!hvt->outmsg) { |
47 | ret = -EAGAIN; |
48 | goto out_unlock; |
49 | } |
50 | |
51 | if (count < hvt->outmsg_len) { |
52 | ret = -EINVAL; |
53 | goto out_unlock; |
54 | } |
55 | |
56 | if (!copy_to_user(to: buf, from: hvt->outmsg, n: hvt->outmsg_len)) |
57 | ret = hvt->outmsg_len; |
58 | else |
59 | ret = -EFAULT; |
60 | |
61 | kfree(objp: hvt->outmsg); |
62 | hvt->outmsg = NULL; |
63 | hvt->outmsg_len = 0; |
64 | |
65 | if (hvt->on_read) |
66 | hvt->on_read(); |
67 | hvt->on_read = NULL; |
68 | |
69 | out_unlock: |
70 | mutex_unlock(lock: &hvt->lock); |
71 | return ret; |
72 | } |
73 | |
74 | static ssize_t hvt_op_write(struct file *file, const char __user *buf, |
75 | size_t count, loff_t *ppos) |
76 | { |
77 | struct hvutil_transport *hvt; |
78 | u8 *inmsg; |
79 | int ret; |
80 | |
81 | hvt = container_of(file->f_op, struct hvutil_transport, fops); |
82 | |
83 | inmsg = memdup_user(buf, count); |
84 | if (IS_ERR(ptr: inmsg)) |
85 | return PTR_ERR(ptr: inmsg); |
86 | |
87 | if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) |
88 | ret = -EBADF; |
89 | else |
90 | ret = hvt->on_msg(inmsg, count); |
91 | |
92 | kfree(objp: inmsg); |
93 | |
94 | return ret ? ret : count; |
95 | } |
96 | |
97 | static __poll_t hvt_op_poll(struct file *file, poll_table *wait) |
98 | { |
99 | struct hvutil_transport *hvt; |
100 | |
101 | hvt = container_of(file->f_op, struct hvutil_transport, fops); |
102 | |
103 | poll_wait(filp: file, wait_address: &hvt->outmsg_q, p: wait); |
104 | |
105 | if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) |
106 | return EPOLLERR | EPOLLHUP; |
107 | |
108 | if (hvt->outmsg_len > 0) |
109 | return EPOLLIN | EPOLLRDNORM; |
110 | |
111 | return 0; |
112 | } |
113 | |
114 | static int hvt_op_open(struct inode *inode, struct file *file) |
115 | { |
116 | struct hvutil_transport *hvt; |
117 | int ret = 0; |
118 | bool issue_reset = false; |
119 | |
120 | hvt = container_of(file->f_op, struct hvutil_transport, fops); |
121 | |
122 | mutex_lock(&hvt->lock); |
123 | |
124 | if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) { |
125 | ret = -EBADF; |
126 | } else if (hvt->mode == HVUTIL_TRANSPORT_INIT) { |
127 | /* |
128 | * Switching to CHARDEV mode. We switch bach to INIT when |
129 | * device gets released. |
130 | */ |
131 | hvt->mode = HVUTIL_TRANSPORT_CHARDEV; |
132 | } |
133 | else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) { |
134 | /* |
135 | * We're switching from netlink communication to using char |
136 | * device. Issue the reset first. |
137 | */ |
138 | issue_reset = true; |
139 | hvt->mode = HVUTIL_TRANSPORT_CHARDEV; |
140 | } else { |
141 | ret = -EBUSY; |
142 | } |
143 | |
144 | if (issue_reset) |
145 | hvt_reset(hvt); |
146 | |
147 | mutex_unlock(lock: &hvt->lock); |
148 | |
149 | return ret; |
150 | } |
151 | |
152 | static void hvt_transport_free(struct hvutil_transport *hvt) |
153 | { |
154 | misc_deregister(misc: &hvt->mdev); |
155 | kfree(objp: hvt->outmsg); |
156 | kfree(objp: hvt); |
157 | } |
158 | |
159 | static int hvt_op_release(struct inode *inode, struct file *file) |
160 | { |
161 | struct hvutil_transport *hvt; |
162 | int mode_old; |
163 | |
164 | hvt = container_of(file->f_op, struct hvutil_transport, fops); |
165 | |
166 | mutex_lock(&hvt->lock); |
167 | mode_old = hvt->mode; |
168 | if (hvt->mode != HVUTIL_TRANSPORT_DESTROY) |
169 | hvt->mode = HVUTIL_TRANSPORT_INIT; |
170 | /* |
171 | * Cleanup message buffers to avoid spurious messages when the daemon |
172 | * connects back. |
173 | */ |
174 | hvt_reset(hvt); |
175 | |
176 | if (mode_old == HVUTIL_TRANSPORT_DESTROY) |
177 | complete(&hvt->release); |
178 | |
179 | mutex_unlock(lock: &hvt->lock); |
180 | |
181 | return 0; |
182 | } |
183 | |
184 | static void hvt_cn_callback(struct cn_msg *msg, struct netlink_skb_parms *nsp) |
185 | { |
186 | struct hvutil_transport *hvt, *hvt_found = NULL; |
187 | |
188 | spin_lock(lock: &hvt_list_lock); |
189 | list_for_each_entry(hvt, &hvt_list, list) { |
190 | if (hvt->cn_id.idx == msg->id.idx && |
191 | hvt->cn_id.val == msg->id.val) { |
192 | hvt_found = hvt; |
193 | break; |
194 | } |
195 | } |
196 | spin_unlock(lock: &hvt_list_lock); |
197 | if (!hvt_found) { |
198 | pr_warn("hvt_cn_callback: spurious message received!\n" ); |
199 | return; |
200 | } |
201 | |
202 | /* |
203 | * Switching to NETLINK mode. Switching to CHARDEV happens when someone |
204 | * opens the device. |
205 | */ |
206 | mutex_lock(&hvt->lock); |
207 | if (hvt->mode == HVUTIL_TRANSPORT_INIT) |
208 | hvt->mode = HVUTIL_TRANSPORT_NETLINK; |
209 | |
210 | if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) |
211 | hvt_found->on_msg(msg->data, msg->len); |
212 | else |
213 | pr_warn("hvt_cn_callback: unexpected netlink message!\n" ); |
214 | mutex_unlock(lock: &hvt->lock); |
215 | } |
216 | |
217 | int hvutil_transport_send(struct hvutil_transport *hvt, void *msg, int len, |
218 | void (*on_read_cb)(void)) |
219 | { |
220 | struct cn_msg *cn_msg; |
221 | int ret = 0; |
222 | |
223 | if (hvt->mode == HVUTIL_TRANSPORT_INIT || |
224 | hvt->mode == HVUTIL_TRANSPORT_DESTROY) { |
225 | return -EINVAL; |
226 | } else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) { |
227 | cn_msg = kzalloc(size: sizeof(*cn_msg) + len, GFP_ATOMIC); |
228 | if (!cn_msg) |
229 | return -ENOMEM; |
230 | cn_msg->id.idx = hvt->cn_id.idx; |
231 | cn_msg->id.val = hvt->cn_id.val; |
232 | cn_msg->len = len; |
233 | memcpy(cn_msg->data, msg, len); |
234 | ret = cn_netlink_send(msg: cn_msg, portid: 0, group: 0, GFP_ATOMIC); |
235 | kfree(objp: cn_msg); |
236 | /* |
237 | * We don't know when netlink messages are delivered but unlike |
238 | * in CHARDEV mode we're not blocked and we can send next |
239 | * messages right away. |
240 | */ |
241 | if (on_read_cb) |
242 | on_read_cb(); |
243 | return ret; |
244 | } |
245 | /* HVUTIL_TRANSPORT_CHARDEV */ |
246 | mutex_lock(&hvt->lock); |
247 | if (hvt->mode != HVUTIL_TRANSPORT_CHARDEV) { |
248 | ret = -EINVAL; |
249 | goto out_unlock; |
250 | } |
251 | |
252 | if (hvt->outmsg) { |
253 | /* Previous message wasn't received */ |
254 | ret = -EFAULT; |
255 | goto out_unlock; |
256 | } |
257 | hvt->outmsg = kzalloc(size: len, GFP_KERNEL); |
258 | if (hvt->outmsg) { |
259 | memcpy(hvt->outmsg, msg, len); |
260 | hvt->outmsg_len = len; |
261 | hvt->on_read = on_read_cb; |
262 | wake_up_interruptible(&hvt->outmsg_q); |
263 | } else |
264 | ret = -ENOMEM; |
265 | out_unlock: |
266 | mutex_unlock(lock: &hvt->lock); |
267 | return ret; |
268 | } |
269 | |
270 | struct hvutil_transport *hvutil_transport_init(const char *name, |
271 | u32 cn_idx, u32 cn_val, |
272 | int (*on_msg)(void *, int), |
273 | void (*on_reset)(void)) |
274 | { |
275 | struct hvutil_transport *hvt; |
276 | |
277 | hvt = kzalloc(size: sizeof(*hvt), GFP_KERNEL); |
278 | if (!hvt) |
279 | return NULL; |
280 | |
281 | hvt->cn_id.idx = cn_idx; |
282 | hvt->cn_id.val = cn_val; |
283 | |
284 | hvt->mdev.minor = MISC_DYNAMIC_MINOR; |
285 | hvt->mdev.name = name; |
286 | |
287 | hvt->fops.owner = THIS_MODULE; |
288 | hvt->fops.read = hvt_op_read; |
289 | hvt->fops.write = hvt_op_write; |
290 | hvt->fops.poll = hvt_op_poll; |
291 | hvt->fops.open = hvt_op_open; |
292 | hvt->fops.release = hvt_op_release; |
293 | |
294 | hvt->mdev.fops = &hvt->fops; |
295 | |
296 | init_waitqueue_head(&hvt->outmsg_q); |
297 | mutex_init(&hvt->lock); |
298 | init_completion(x: &hvt->release); |
299 | |
300 | spin_lock(lock: &hvt_list_lock); |
301 | list_add(new: &hvt->list, head: &hvt_list); |
302 | spin_unlock(lock: &hvt_list_lock); |
303 | |
304 | hvt->on_msg = on_msg; |
305 | hvt->on_reset = on_reset; |
306 | |
307 | if (misc_register(misc: &hvt->mdev)) |
308 | goto err_free_hvt; |
309 | |
310 | /* Use cn_id.idx/cn_id.val to determine if we need to setup netlink */ |
311 | if (hvt->cn_id.idx > 0 && hvt->cn_id.val > 0 && |
312 | cn_add_callback(id: &hvt->cn_id, name, callback: hvt_cn_callback)) |
313 | goto err_free_hvt; |
314 | |
315 | return hvt; |
316 | |
317 | err_free_hvt: |
318 | spin_lock(lock: &hvt_list_lock); |
319 | list_del(entry: &hvt->list); |
320 | spin_unlock(lock: &hvt_list_lock); |
321 | kfree(objp: hvt); |
322 | return NULL; |
323 | } |
324 | |
325 | void hvutil_transport_destroy(struct hvutil_transport *hvt) |
326 | { |
327 | int mode_old; |
328 | |
329 | mutex_lock(&hvt->lock); |
330 | mode_old = hvt->mode; |
331 | hvt->mode = HVUTIL_TRANSPORT_DESTROY; |
332 | wake_up_interruptible(&hvt->outmsg_q); |
333 | mutex_unlock(lock: &hvt->lock); |
334 | |
335 | /* |
336 | * In case we were in 'chardev' mode we still have an open fd so we |
337 | * have to defer freeing the device. Netlink interface can be freed |
338 | * now. |
339 | */ |
340 | spin_lock(lock: &hvt_list_lock); |
341 | list_del(entry: &hvt->list); |
342 | spin_unlock(lock: &hvt_list_lock); |
343 | if (hvt->cn_id.idx > 0 && hvt->cn_id.val > 0) |
344 | cn_del_callback(id: &hvt->cn_id); |
345 | |
346 | if (mode_old == HVUTIL_TRANSPORT_CHARDEV) |
347 | wait_for_completion(&hvt->release); |
348 | |
349 | hvt_transport_free(hvt); |
350 | } |
351 | |