1 | // SPDX-License-Identifier: GPL-2.0 |
/*
 * Amazon Nitro Secure Module driver.
 *
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * The Nitro Secure Module implements commands via CBOR over virtio.
 * This driver exposes a raw message ioctl on /dev/nsm that user
 * space can use to issue these commands.
 */
11 | |
12 | #include <linux/file.h> |
13 | #include <linux/fs.h> |
14 | #include <linux/interrupt.h> |
15 | #include <linux/hw_random.h> |
16 | #include <linux/miscdevice.h> |
17 | #include <linux/module.h> |
18 | #include <linux/mutex.h> |
19 | #include <linux/slab.h> |
20 | #include <linux/string.h> |
21 | #include <linux/uaccess.h> |
22 | #include <linux/uio.h> |
23 | #include <linux/virtio_config.h> |
24 | #include <linux/virtio_ids.h> |
25 | #include <linux/virtio.h> |
26 | #include <linux/wait.h> |
27 | #include <uapi/linux/nsm.h> |
28 | |
/* How long to wait for an NSM virtqueue response, in milliseconds. */
#define NSM_DEFAULT_TIMEOUT_MSECS (2 * 60 * 1000) /* 2 minutes */
31 | |
32 | /* Maximum length input data */ |
33 | struct nsm_data_req { |
34 | u32 len; |
35 | u8 data[NSM_REQUEST_MAX_SIZE]; |
36 | }; |
37 | |
38 | /* Maximum length output data */ |
39 | struct nsm_data_resp { |
40 | u32 len; |
41 | u8 data[NSM_RESPONSE_MAX_SIZE]; |
42 | }; |
43 | |
44 | /* Full NSM request/response message */ |
45 | struct nsm_msg { |
46 | struct nsm_data_req req; |
47 | struct nsm_data_resp resp; |
48 | }; |
49 | |
50 | struct nsm { |
51 | struct virtio_device *vdev; |
52 | struct virtqueue *vq; |
53 | struct mutex lock; |
54 | struct completion cmd_done; |
55 | struct miscdevice misc; |
56 | struct hwrng hwrng; |
57 | struct work_struct misc_init; |
58 | struct nsm_msg msg; |
59 | }; |
60 | |
61 | /* NSM device ID */ |
62 | static const struct virtio_device_id id_table[] = { |
63 | { VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID }, |
64 | { 0 }, |
65 | }; |
66 | |
67 | static struct nsm *file_to_nsm(struct file *file) |
68 | { |
69 | return container_of(file->private_data, struct nsm, misc); |
70 | } |
71 | |
72 | static struct nsm *hwrng_to_nsm(struct hwrng *rng) |
73 | { |
74 | return container_of(rng, struct nsm, hwrng); |
75 | } |
76 | |
/* The top three bits of a CBOR initial byte select the type. */
#define CBOR_TYPE_MASK  0xE0
#define CBOR_TYPE_MAP 0xA0
#define CBOR_TYPE_TEXT  0x60
#define CBOR_TYPE_ARRAY 0x40
/* Size in bytes of the initial byte that precedes any extended length. */
#define CBOR_HEADER_SIZE_SHORT  1

/*
 * The low five bits of the initial byte hold the length directly (up to
 * CBOR_SHORT_SIZE_MAX_VALUE) or select the width of the length field
 * that follows.
 */
#define CBOR_SHORT_SIZE_MAX_VALUE  23
#define CBOR_LONG_SIZE_U8  24
#define CBOR_LONG_SIZE_U16 25
#define CBOR_LONG_SIZE_U32 26
#define CBOR_LONG_SIZE_U64 27
88 | |
89 | static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size) |
90 | { |
91 | if (cbor_object_size == 0 || cbor_object == NULL) |
92 | return false; |
93 | |
94 | return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY; |
95 | } |
96 | |
97 | static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array) |
98 | { |
99 | u8 cbor_short_size; |
100 | void *array_len_p; |
101 | u64 array_len; |
102 | u64 array_offset; |
103 | |
104 | if (!cbor_object_is_array(cbor_object, cbor_object_size)) |
105 | return -EFAULT; |
106 | |
107 | cbor_short_size = (cbor_object[0] & 0x1F); |
108 | |
109 | /* Decoding byte array length */ |
110 | array_offset = CBOR_HEADER_SIZE_SHORT; |
111 | if (cbor_short_size >= CBOR_LONG_SIZE_U8) |
112 | array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8); |
113 | |
114 | if (cbor_object_size < array_offset) |
115 | return -EFAULT; |
116 | |
117 | array_len_p = &cbor_object[1]; |
118 | |
119 | switch (cbor_short_size) { |
120 | case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */ |
121 | array_len = cbor_short_size; |
122 | break; |
123 | case CBOR_LONG_SIZE_U8: |
124 | array_len = *(u8 *)array_len_p; |
125 | break; |
126 | case CBOR_LONG_SIZE_U16: |
127 | array_len = be16_to_cpup(p: (__be16 *)array_len_p); |
128 | break; |
129 | case CBOR_LONG_SIZE_U32: |
130 | array_len = be32_to_cpup(p: (__be32 *)array_len_p); |
131 | break; |
132 | case CBOR_LONG_SIZE_U64: |
133 | array_len = be64_to_cpup(p: (__be64 *)array_len_p); |
134 | break; |
135 | } |
136 | |
137 | if (cbor_object_size < array_offset) |
138 | return -EFAULT; |
139 | |
140 | if (cbor_object_size - array_offset < array_len) |
141 | return -EFAULT; |
142 | |
143 | if (array_len > INT_MAX) |
144 | return -EFAULT; |
145 | |
146 | *cbor_array = cbor_object + array_offset; |
147 | return array_len; |
148 | } |
149 | |
150 | /* Copy the request of a raw message to kernel space */ |
151 | static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req, |
152 | struct nsm_raw *raw) |
153 | { |
154 | /* Verify the user input size. */ |
155 | if (raw->request.len > sizeof(req->data)) |
156 | return -EMSGSIZE; |
157 | |
158 | /* Copy the request payload */ |
159 | if (copy_from_user(to: req->data, u64_to_user_ptr(raw->request.addr), |
160 | n: raw->request.len)) |
161 | return -EFAULT; |
162 | |
163 | req->len = raw->request.len; |
164 | |
165 | return 0; |
166 | } |
167 | |
168 | /* Copy the response of a raw message back to user-space */ |
169 | static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp, |
170 | struct nsm_raw *raw) |
171 | { |
172 | /* Truncate any message that does not fit. */ |
173 | raw->response.len = min_t(u64, raw->response.len, resp->len); |
174 | |
175 | /* Copy the response content to user space */ |
176 | if (copy_to_user(u64_to_user_ptr(raw->response.addr), |
177 | from: resp->data, n: raw->response.len)) |
178 | return -EFAULT; |
179 | |
180 | return 0; |
181 | } |
182 | |
183 | /* Virtqueue interrupt handler */ |
184 | static void nsm_vq_callback(struct virtqueue *vq) |
185 | { |
186 | struct nsm *nsm = vq->vdev->priv; |
187 | |
188 | complete(&nsm->cmd_done); |
189 | } |
190 | |
191 | /* Forward a message to the NSM device and wait for the response from it */ |
192 | static int nsm_sendrecv_msg_locked(struct nsm *nsm) |
193 | { |
194 | struct device *dev = &nsm->vdev->dev; |
195 | struct scatterlist sg_in, sg_out; |
196 | struct nsm_msg *msg = &nsm->msg; |
197 | struct virtqueue *vq = nsm->vq; |
198 | unsigned int len; |
199 | void *queue_buf; |
200 | bool kicked; |
201 | int rc; |
202 | |
203 | /* Initialize scatter-gather lists with request and response buffers. */ |
204 | sg_init_one(&sg_out, msg->req.data, msg->req.len); |
205 | sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data)); |
206 | |
207 | init_completion(x: &nsm->cmd_done); |
208 | /* Add the request buffer (read by the device). */ |
209 | rc = virtqueue_add_outbuf(vq, sg: &sg_out, num: 1, data: msg->req.data, GFP_KERNEL); |
210 | if (rc) |
211 | return rc; |
212 | |
213 | /* Add the response buffer (written by the device). */ |
214 | rc = virtqueue_add_inbuf(vq, sg: &sg_in, num: 1, data: msg->resp.data, GFP_KERNEL); |
215 | if (rc) |
216 | goto cleanup; |
217 | |
218 | kicked = virtqueue_kick(vq); |
219 | if (!kicked) { |
220 | /* Cannot kick the virtqueue. */ |
221 | rc = -EIO; |
222 | goto cleanup; |
223 | } |
224 | |
225 | /* If the kick succeeded, wait for the device's response. */ |
226 | if (!wait_for_completion_io_timeout(x: &nsm->cmd_done, |
227 | timeout: msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) { |
228 | rc = -ETIMEDOUT; |
229 | goto cleanup; |
230 | } |
231 | |
232 | queue_buf = virtqueue_get_buf(vq, len: &len); |
233 | if (!queue_buf || (queue_buf != msg->req.data)) { |
234 | dev_err(dev, "wrong request buffer." ); |
235 | rc = -ENODATA; |
236 | goto cleanup; |
237 | } |
238 | |
239 | queue_buf = virtqueue_get_buf(vq, len: &len); |
240 | if (!queue_buf || (queue_buf != msg->resp.data)) { |
241 | dev_err(dev, "wrong response buffer." ); |
242 | rc = -ENODATA; |
243 | goto cleanup; |
244 | } |
245 | |
246 | msg->resp.len = len; |
247 | |
248 | rc = 0; |
249 | |
250 | cleanup: |
251 | if (rc) { |
252 | /* Clean the virtqueue. */ |
253 | while (virtqueue_get_buf(vq, len: &len) != NULL) |
254 | ; |
255 | } |
256 | |
257 | return rc; |
258 | } |
259 | |
260 | static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req) |
261 | { |
262 | /* |
263 | * 69 # text(9) |
264 | * 47657452616E646F6D # "GetRandom" |
265 | */ |
266 | const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom" ), |
267 | 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' }; |
268 | |
269 | memcpy(req->data, request, sizeof(request)); |
270 | req->len = sizeof(request); |
271 | |
272 | return 0; |
273 | } |
274 | |
275 | static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp, |
276 | void *out, size_t max) |
277 | { |
278 | /* |
279 | * A1 # map(1) |
280 | * 69 # text(9) - Name of field |
281 | * 47657452616E646F6D # "GetRandom" |
282 | * A1 # map(1) - The field itself |
283 | * 66 # text(6) |
284 | * 72616E646F6D # "random" |
285 | * # The rest of the response is random data |
286 | */ |
287 | const u8 response[] = { CBOR_TYPE_MAP + 1, |
288 | CBOR_TYPE_TEXT + strlen("GetRandom" ), |
289 | 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm', |
290 | CBOR_TYPE_MAP + 1, |
291 | CBOR_TYPE_TEXT + strlen("random" ), |
292 | 'r', 'a', 'n', 'd', 'o', 'm' }; |
293 | struct device *dev = &nsm->vdev->dev; |
294 | u8 *rand_data = NULL; |
295 | u8 *resp_ptr = resp->data; |
296 | u64 resp_len = resp->len; |
297 | int rc; |
298 | |
299 | if ((resp->len < sizeof(response) + 1) || |
300 | (memcmp(p: resp_ptr, q: response, size: sizeof(response)) != 0)) { |
301 | dev_err(dev, "Invalid response for GetRandom" ); |
302 | return -EFAULT; |
303 | } |
304 | |
305 | resp_ptr += sizeof(response); |
306 | resp_len -= sizeof(response); |
307 | |
308 | rc = cbor_object_get_array(cbor_object: resp_ptr, cbor_object_size: resp_len, cbor_array: &rand_data); |
309 | if (rc < 0) { |
310 | dev_err(dev, "GetRandom: Invalid CBOR encoding\n" ); |
311 | return rc; |
312 | } |
313 | |
314 | rc = min_t(size_t, rc, max); |
315 | memcpy(out, rand_data, rc); |
316 | |
317 | return rc; |
318 | } |
319 | |
320 | /* |
321 | * HwRNG implementation |
322 | */ |
323 | static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait) |
324 | { |
325 | struct nsm *nsm = hwrng_to_nsm(rng); |
326 | struct device *dev = &nsm->vdev->dev; |
327 | int rc = 0; |
328 | |
329 | /* NSM always needs to wait for a response */ |
330 | if (!wait) |
331 | return 0; |
332 | |
333 | mutex_lock(&nsm->lock); |
334 | |
335 | rc = fill_req_get_random(nsm, req: &nsm->msg.req); |
336 | if (rc != 0) |
337 | goto out; |
338 | |
339 | rc = nsm_sendrecv_msg_locked(nsm); |
340 | if (rc != 0) |
341 | goto out; |
342 | |
343 | rc = parse_resp_get_random(nsm, resp: &nsm->msg.resp, out: data, max); |
344 | if (rc < 0) |
345 | goto out; |
346 | |
347 | dev_dbg(dev, "RNG: returning rand bytes = %d" , rc); |
348 | out: |
349 | mutex_unlock(lock: &nsm->lock); |
350 | return rc; |
351 | } |
352 | |
353 | static long nsm_dev_ioctl(struct file *file, unsigned int cmd, |
354 | unsigned long arg) |
355 | { |
356 | void __user *argp = u64_to_user_ptr((u64)arg); |
357 | struct nsm *nsm = file_to_nsm(file); |
358 | struct nsm_raw raw; |
359 | int r = 0; |
360 | |
361 | if (cmd != NSM_IOCTL_RAW) |
362 | return -EINVAL; |
363 | |
364 | if (_IOC_SIZE(cmd) != sizeof(raw)) |
365 | return -EINVAL; |
366 | |
367 | /* Copy user argument struct to kernel argument struct */ |
368 | r = -EFAULT; |
369 | if (copy_from_user(to: &raw, from: argp, _IOC_SIZE(cmd))) |
370 | goto out; |
371 | |
372 | mutex_lock(&nsm->lock); |
373 | |
374 | /* Convert kernel argument struct to device request */ |
375 | r = fill_req_raw(nsm, req: &nsm->msg.req, raw: &raw); |
376 | if (r) |
377 | goto out; |
378 | |
379 | /* Send message to NSM and read reply */ |
380 | r = nsm_sendrecv_msg_locked(nsm); |
381 | if (r) |
382 | goto out; |
383 | |
384 | /* Parse device response into kernel argument struct */ |
385 | r = parse_resp_raw(nsm, resp: &nsm->msg.resp, raw: &raw); |
386 | if (r) |
387 | goto out; |
388 | |
389 | /* Copy kernel argument struct back to user argument struct */ |
390 | r = -EFAULT; |
391 | if (copy_to_user(to: argp, from: &raw, n: sizeof(raw))) |
392 | goto out; |
393 | |
394 | r = 0; |
395 | |
396 | out: |
397 | mutex_unlock(lock: &nsm->lock); |
398 | return r; |
399 | } |
400 | |
401 | static int nsm_device_init_vq(struct virtio_device *vdev) |
402 | { |
403 | struct virtqueue *vq = virtio_find_single_vq(vdev, |
404 | c: nsm_vq_callback, n: "nsm.vq.0" ); |
405 | struct nsm *nsm = vdev->priv; |
406 | |
407 | if (IS_ERR(ptr: vq)) |
408 | return PTR_ERR(ptr: vq); |
409 | |
410 | nsm->vq = vq; |
411 | |
412 | return 0; |
413 | } |
414 | |
415 | static const struct file_operations nsm_dev_fops = { |
416 | .unlocked_ioctl = nsm_dev_ioctl, |
417 | .compat_ioctl = compat_ptr_ioctl, |
418 | }; |
419 | |
420 | /* Handler for probing the NSM device */ |
421 | static int nsm_device_probe(struct virtio_device *vdev) |
422 | { |
423 | struct device *dev = &vdev->dev; |
424 | struct nsm *nsm; |
425 | int rc; |
426 | |
427 | nsm = devm_kzalloc(dev: &vdev->dev, size: sizeof(*nsm), GFP_KERNEL); |
428 | if (!nsm) |
429 | return -ENOMEM; |
430 | |
431 | vdev->priv = nsm; |
432 | nsm->vdev = vdev; |
433 | |
434 | rc = nsm_device_init_vq(vdev); |
435 | if (rc) { |
436 | dev_err(dev, "queue failed to initialize: %d.\n" , rc); |
437 | goto err_init_vq; |
438 | } |
439 | |
440 | mutex_init(&nsm->lock); |
441 | |
442 | /* Register as hwrng provider */ |
443 | nsm->hwrng = (struct hwrng) { |
444 | .read = nsm_rng_read, |
445 | .name = "nsm-hwrng" , |
446 | .quality = 1000, |
447 | }; |
448 | |
449 | rc = hwrng_register(rng: &nsm->hwrng); |
450 | if (rc) { |
451 | dev_err(dev, "RNG initialization error: %d.\n" , rc); |
452 | goto err_hwrng; |
453 | } |
454 | |
455 | /* Register /dev/nsm device node */ |
456 | nsm->misc = (struct miscdevice) { |
457 | .minor = MISC_DYNAMIC_MINOR, |
458 | .name = "nsm" , |
459 | .fops = &nsm_dev_fops, |
460 | .mode = 0666, |
461 | }; |
462 | |
463 | rc = misc_register(misc: &nsm->misc); |
464 | if (rc) { |
465 | dev_err(dev, "misc device registration error: %d.\n" , rc); |
466 | goto err_misc; |
467 | } |
468 | |
469 | return 0; |
470 | |
471 | err_misc: |
472 | hwrng_unregister(rng: &nsm->hwrng); |
473 | err_hwrng: |
474 | vdev->config->del_vqs(vdev); |
475 | err_init_vq: |
476 | return rc; |
477 | } |
478 | |
479 | /* Handler for removing the NSM device */ |
480 | static void nsm_device_remove(struct virtio_device *vdev) |
481 | { |
482 | struct nsm *nsm = vdev->priv; |
483 | |
484 | hwrng_unregister(rng: &nsm->hwrng); |
485 | |
486 | vdev->config->del_vqs(vdev); |
487 | misc_deregister(misc: &nsm->misc); |
488 | } |
489 | |
490 | /* NSM device configuration structure */ |
491 | static struct virtio_driver virtio_nsm_driver = { |
492 | .feature_table = 0, |
493 | .feature_table_size = 0, |
494 | .feature_table_legacy = 0, |
495 | .feature_table_size_legacy = 0, |
496 | .driver.name = KBUILD_MODNAME, |
497 | .driver.owner = THIS_MODULE, |
498 | .id_table = id_table, |
499 | .probe = nsm_device_probe, |
500 | .remove = nsm_device_remove, |
501 | }; |
502 | |
503 | module_virtio_driver(virtio_nsm_driver); |
504 | MODULE_DEVICE_TABLE(virtio, id_table); |
505 | MODULE_DESCRIPTION("Virtio NSM driver" ); |
506 | MODULE_LICENSE("GPL" ); |
507 | |