// SPDX-License-Identifier: GPL-2.0
/*
 * Amazon Nitro Secure Module driver.
 *
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * The Nitro Secure Module implements commands via CBOR over virtio.
 * This driver exposes raw message ioctls on /dev/nsm that user
 * space can use to issue these commands.
 */
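
/*
 * Illustrative user-space usage (not part of this driver): a minimal
 * sketch of how a caller might drive NSM_IOCTL_RAW, assuming the
 * struct nsm_raw layout from <uapi/linux/nsm.h> as it is consumed
 * below (a request and a response descriptor, each holding a user
 * pointer in .addr and a byte count in .len). The request buffer must
 * contain a CBOR-encoded NSM command.
 *
 *	int fd = open("/dev/nsm", O_RDWR);
 *	struct nsm_raw raw = {
 *		.request.addr  = (__u64)(uintptr_t)req_buf,
 *		.request.len   = req_len,
 *		.response.addr = (__u64)(uintptr_t)resp_buf,
 *		.response.len  = sizeof(resp_buf),
 *	};
 *	int rc = ioctl(fd, NSM_IOCTL_RAW, &raw);
 *
 * On success, raw.response.len holds the (possibly truncated) size of
 * the CBOR reply copied into resp_buf.
 */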

#include <linux/file.h>
#include <linux/fs.h>
#include <linux/interrupt.h>
#include <linux/hw_random.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/uaccess.h>
#include <linux/uio.h>
#include <linux/virtio_config.h>
#include <linux/virtio_ids.h>
#include <linux/virtio.h>
#include <linux/wait.h>
#include <uapi/linux/nsm.h>

/* Timeout for NSM virtqueue response in milliseconds. */
#define NSM_DEFAULT_TIMEOUT_MSECS (120000) /* 2 minutes */

/* Maximum length input data */
struct nsm_data_req {
	u32 len;
	u8 data[NSM_REQUEST_MAX_SIZE];
};

/* Maximum length output data */
struct nsm_data_resp {
	u32 len;
	u8 data[NSM_RESPONSE_MAX_SIZE];
};

/* Full NSM request/response message */
struct nsm_msg {
	struct nsm_data_req req;
	struct nsm_data_resp resp;
};

struct nsm {
	struct virtio_device *vdev;
	struct virtqueue *vq;
	/* Serializes access to the virtqueue and the shared msg buffer */
	struct mutex lock;
	struct completion cmd_done;
	struct miscdevice misc;
	struct hwrng hwrng;
	struct work_struct misc_init;
	struct nsm_msg msg;
};

/* NSM device ID */
static const struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
	{ 0 },
};

static struct nsm *file_to_nsm(struct file *file)
{
	return container_of(file->private_data, struct nsm, misc);
}

static struct nsm *hwrng_to_nsm(struct hwrng *rng)
{
	return container_of(rng, struct nsm, hwrng);
}

#define CBOR_TYPE_MASK 0xE0
#define CBOR_TYPE_MAP 0xA0
#define CBOR_TYPE_TEXT 0x60
#define CBOR_TYPE_ARRAY 0x40
#define CBOR_HEADER_SIZE_SHORT 1

#define CBOR_SHORT_SIZE_MAX_VALUE 23
#define CBOR_LONG_SIZE_U8 24
#define CBOR_LONG_SIZE_U16 25
#define CBOR_LONG_SIZE_U32 26
#define CBOR_LONG_SIZE_U64 27
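
/*
 * Background note on the constants above (CBOR initial byte, RFC 8949):
 * the top three bits of the first byte select the major type (0xA0 map,
 * 0x60 text string, 0x40 byte string, which this driver treats as an
 * array of bytes), and the low five bits are the "additional info":
 * values 0..23 encode the length directly, while 24..27 mean the length
 * follows in the next 1, 2, 4 or 8 bytes. For example, the sequence
 * 0x58 0x20 introduces a 32-byte string: 0x40 | 24, then a one-byte
 * length of 0x20.
 */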

static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size)
{
	if (cbor_object_size == 0 || cbor_object == NULL)
		return false;

	return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY;
}

static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array)
{
	u8 cbor_short_size;
	void *array_len_p;
	u64 array_len;
	u64 array_offset;

	if (!cbor_object_is_array(cbor_object, cbor_object_size))
		return -EFAULT;

	cbor_short_size = (cbor_object[0] & 0x1F);

	/* Decoding byte array length */
	array_offset = CBOR_HEADER_SIZE_SHORT;
	if (cbor_short_size >= CBOR_LONG_SIZE_U8)
		array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8);

	if (cbor_object_size < array_offset)
		return -EFAULT;

	array_len_p = &cbor_object[1];

	switch (cbor_short_size) {
	case CBOR_LONG_SIZE_U8:
		array_len = *(u8 *)array_len_p;
		break;
	case CBOR_LONG_SIZE_U16:
		array_len = be16_to_cpup((__be16 *)array_len_p);
		break;
	case CBOR_LONG_SIZE_U32:
		array_len = be32_to_cpup((__be32 *)array_len_p);
		break;
	case CBOR_LONG_SIZE_U64:
		array_len = be64_to_cpup((__be64 *)array_len_p);
		break;
	default:
		/* Short encoding: values 0..23 are the length itself */
		if (cbor_short_size > CBOR_SHORT_SIZE_MAX_VALUE)
			return -EFAULT;
		array_len = cbor_short_size;
		break;
	}

	if (cbor_object_size < array_offset)
		return -EFAULT;

	if (cbor_object_size - array_offset < array_len)
		return -EFAULT;

	if (array_len > INT_MAX)
		return -EFAULT;

	*cbor_array = cbor_object + array_offset;
	return array_len;
}
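
/*
 * Worked example for cbor_object_get_array() (illustrative only): for a
 * payload starting 0x58 0x20 followed by 32 bytes, cbor_short_size is 24
 * (CBOR_LONG_SIZE_U8), so one extra length byte follows and array_offset
 * becomes 1 + 1 = 2; array_len is read as 0x20 (32), *cbor_array points
 * at the first of the 32 payload bytes, and the function returns 32.
 */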

/* Copy the request of a raw message to kernel space */
static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req,
			struct nsm_raw *raw)
{
	/* Verify the user input size. */
	if (raw->request.len > sizeof(req->data))
		return -EMSGSIZE;

	/* Copy the request payload */
	if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr),
			   raw->request.len))
		return -EFAULT;

	req->len = raw->request.len;

	return 0;
}

/* Copy the response of a raw message back to user space */
static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp,
			  struct nsm_raw *raw)
{
	/* Truncate any message that does not fit. */
	raw->response.len = min_t(u64, raw->response.len, resp->len);

	/* Copy the response content to user space */
	if (copy_to_user(u64_to_user_ptr(raw->response.addr),
			 resp->data, raw->response.len))
		return -EFAULT;

	return 0;
}

/* Virtqueue interrupt handler */
static void nsm_vq_callback(struct virtqueue *vq)
{
	struct nsm *nsm = vq->vdev->priv;

	complete(&nsm->cmd_done);
}

/* Forward a message to the NSM device and wait for the response from it */
static int nsm_sendrecv_msg_locked(struct nsm *nsm)
{
	struct device *dev = &nsm->vdev->dev;
	struct scatterlist sg_in, sg_out;
	struct nsm_msg *msg = &nsm->msg;
	struct virtqueue *vq = nsm->vq;
	unsigned int len;
	void *queue_buf;
	bool kicked;
	int rc;

	/* Initialize scatter-gather lists with request and response buffers. */
	sg_init_one(&sg_out, msg->req.data, msg->req.len);
	sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data));

	init_completion(&nsm->cmd_done);
	/* Add the request buffer (read by the device). */
	rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL);
	if (rc)
		return rc;

	/* Add the response buffer (written by the device). */
	rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL);
	if (rc)
		goto cleanup;

	kicked = virtqueue_kick(vq);
	if (!kicked) {
		/* Cannot kick the virtqueue. */
		rc = -EIO;
		goto cleanup;
	}

	/* If the kick succeeded, wait for the device's response. */
	if (!wait_for_completion_io_timeout(&nsm->cmd_done,
					    msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) {
		rc = -ETIMEDOUT;
		goto cleanup;
	}

	queue_buf = virtqueue_get_buf(vq, &len);
	if (!queue_buf || (queue_buf != msg->req.data)) {
		dev_err(dev, "wrong request buffer\n");
		rc = -ENODATA;
		goto cleanup;
	}

	queue_buf = virtqueue_get_buf(vq, &len);
	if (!queue_buf || (queue_buf != msg->resp.data)) {
		dev_err(dev, "wrong response buffer\n");
		rc = -ENODATA;
		goto cleanup;
	}

	msg->resp.len = len;

	rc = 0;

cleanup:
	if (rc) {
		/* Clean the virtqueue. */
		while (virtqueue_get_buf(vq, &len) != NULL)
			;
	}

	return rc;
}

static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req)
{
	/*
	 * 69                   # text(9)
	 *    47657452616E646F6D # "GetRandom"
	 */
	const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"),
			       'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };

	memcpy(req->data, request, sizeof(request));
	req->len = sizeof(request);

	return 0;
}

static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp,
				 void *out, size_t max)
{
	/*
	 * A1                          # map(1)
	 *    69                       # text(9) - Name of field
	 *       47657452616E646F6D    # "GetRandom"
	 *    A1                       # map(1) - The field itself
	 *       66                    # text(6)
	 *          72616E646F6D       # "random"
	 * # The rest of the response is random data
	 */
	const u8 response[] = { CBOR_TYPE_MAP + 1,
				CBOR_TYPE_TEXT + strlen("GetRandom"),
				'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
				CBOR_TYPE_MAP + 1,
				CBOR_TYPE_TEXT + strlen("random"),
				'r', 'a', 'n', 'd', 'o', 'm' };
	struct device *dev = &nsm->vdev->dev;
	u8 *rand_data = NULL;
	u8 *resp_ptr = resp->data;
	u64 resp_len = resp->len;
	int rc;

	if ((resp->len < sizeof(response) + 1) ||
	    (memcmp(resp_ptr, response, sizeof(response)) != 0)) {
		dev_err(dev, "Invalid response for GetRandom\n");
		return -EFAULT;
	}

	resp_ptr += sizeof(response);
	resp_len -= sizeof(response);

	rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data);
	if (rc < 0) {
		dev_err(dev, "GetRandom: Invalid CBOR encoding\n");
		return rc;
	}

	rc = min_t(size_t, rc, max);
	memcpy(out, rand_data, rc);

	return rc;
}

/*
 * HwRNG implementation
 */
static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait)
{
	struct nsm *nsm = hwrng_to_nsm(rng);
	struct device *dev = &nsm->vdev->dev;
	int rc = 0;

	/* NSM always needs to wait for a response */
	if (!wait)
		return 0;

	mutex_lock(&nsm->lock);

	rc = fill_req_get_random(nsm, &nsm->msg.req);
	if (rc != 0)
		goto out;

	rc = nsm_sendrecv_msg_locked(nsm);
	if (rc != 0)
		goto out;

	rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max);
	if (rc < 0)
		goto out;

	dev_dbg(dev, "RNG: returning rand bytes = %d\n", rc);
out:
	mutex_unlock(&nsm->lock);
	return rc;
}

static long nsm_dev_ioctl(struct file *file, unsigned int cmd,
			  unsigned long arg)
{
	void __user *argp = u64_to_user_ptr((u64)arg);
	struct nsm *nsm = file_to_nsm(file);
	struct nsm_raw raw;
	int r = 0;

	if (cmd != NSM_IOCTL_RAW)
		return -EINVAL;

	if (_IOC_SIZE(cmd) != sizeof(raw))
		return -EINVAL;

	/* Copy user argument struct to kernel argument struct */
	if (copy_from_user(&raw, argp, _IOC_SIZE(cmd)))
		return -EFAULT;

	mutex_lock(&nsm->lock);

	/* Convert kernel argument struct to device request */
	r = fill_req_raw(nsm, &nsm->msg.req, &raw);
	if (r)
		goto out;

	/* Send message to NSM and read reply */
	r = nsm_sendrecv_msg_locked(nsm);
	if (r)
		goto out;

	/* Parse device response into kernel argument struct */
	r = parse_resp_raw(nsm, &nsm->msg.resp, &raw);
	if (r)
		goto out;

	/* Copy kernel argument struct back to user argument struct */
	r = -EFAULT;
	if (copy_to_user(argp, &raw, sizeof(raw)))
		goto out;

	r = 0;

out:
	mutex_unlock(&nsm->lock);
	return r;
}

static int nsm_device_init_vq(struct virtio_device *vdev)
{
	struct virtqueue *vq = virtio_find_single_vq(vdev,
						     nsm_vq_callback, "nsm.vq.0");
	struct nsm *nsm = vdev->priv;

	if (IS_ERR(vq))
		return PTR_ERR(vq);

	nsm->vq = vq;

	return 0;
}

static const struct file_operations nsm_dev_fops = {
	.unlocked_ioctl = nsm_dev_ioctl,
	.compat_ioctl = compat_ptr_ioctl,
};

/* Handler for probing the NSM device */
static int nsm_device_probe(struct virtio_device *vdev)
{
	struct device *dev = &vdev->dev;
	struct nsm *nsm;
	int rc;

	nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL);
	if (!nsm)
		return -ENOMEM;

	vdev->priv = nsm;
	nsm->vdev = vdev;

	rc = nsm_device_init_vq(vdev);
	if (rc) {
		dev_err(dev, "queue failed to initialize: %d.\n", rc);
		goto err_init_vq;
	}

	mutex_init(&nsm->lock);

	/* Register as hwrng provider */
	nsm->hwrng = (struct hwrng) {
		.read = nsm_rng_read,
		.name = "nsm-hwrng",
		.quality = 1000,
	};

	rc = hwrng_register(&nsm->hwrng);
	if (rc) {
		dev_err(dev, "RNG initialization error: %d.\n", rc);
		goto err_hwrng;
	}

	/* Register /dev/nsm device node */
	nsm->misc = (struct miscdevice) {
		.minor = MISC_DYNAMIC_MINOR,
		.name = "nsm",
		.fops = &nsm_dev_fops,
		.mode = 0666,
	};

	rc = misc_register(&nsm->misc);
	if (rc) {
		dev_err(dev, "misc device registration error: %d.\n", rc);
		goto err_misc;
	}

	return 0;

err_misc:
	hwrng_unregister(&nsm->hwrng);
err_hwrng:
	vdev->config->del_vqs(vdev);
err_init_vq:
	return rc;
}

/* Handler for removing the NSM device */
static void nsm_device_remove(struct virtio_device *vdev)
{
	struct nsm *nsm = vdev->priv;

	hwrng_unregister(&nsm->hwrng);

	vdev->config->del_vqs(vdev);
	misc_deregister(&nsm->misc);
}

/* NSM virtio driver definition */
static struct virtio_driver virtio_nsm_driver = {
	.feature_table = 0,
	.feature_table_size = 0,
	.feature_table_legacy = 0,
	.feature_table_size_legacy = 0,
	.driver.name = KBUILD_MODNAME,
	.driver.owner = THIS_MODULE,
	.id_table = id_table,
	.probe = nsm_device_probe,
	.remove = nsm_device_remove,
};

module_virtio_driver(virtio_nsm_driver);
MODULE_DEVICE_TABLE(virtio, id_table);
MODULE_DESCRIPTION("Virtio NSM driver");
MODULE_LICENSE("GPL");
