summaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c49
1 files changed, 48 insertions, 1 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index dd3d6f7406f8..344019774ddd 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -212,6 +212,45 @@ static int vhost_worker(void *data)
}
}
+/* Helper to allocate iovec buffers for all vqs. */
+static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
+ UIO_MAXIOV, GFP_KERNEL);
+ dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
+ GFP_KERNEL);
+ dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
+ UIO_MAXIOV, GFP_KERNEL);
+
+ if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
+ !dev->vqs[i].heads)
+ goto err_nomem;
+ }
+ return 0;
+err_nomem:
+ for (; i >= 0; --i) {
+ kfree(dev->vqs[i].indirect);
+ kfree(dev->vqs[i].log);
+ kfree(dev->vqs[i].heads);
+ }
+ return -ENOMEM;
+}
+
+static void vhost_dev_free_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ kfree(dev->vqs[i].indirect);
+ dev->vqs[i].indirect = NULL;
+ kfree(dev->vqs[i].log);
+ dev->vqs[i].log = NULL;
+ kfree(dev->vqs[i].heads);
+ dev->vqs[i].heads = NULL;
+ }
+}
+
long vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue *vqs, int nvqs)
{
@@ -229,6 +268,9 @@ long vhost_dev_init(struct vhost_dev *dev,
dev->worker = NULL;
for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].log = NULL;
+ dev->vqs[i].indirect = NULL;
+ dev->vqs[i].heads = NULL;
dev->vqs[i].dev = dev;
mutex_init(&dev->vqs[i].mutex);
vhost_vq_reset(dev, dev->vqs + i);
@@ -295,6 +337,10 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)
if (err)
goto err_cgroup;
+ err = vhost_dev_alloc_iovecs(dev);
+ if (err)
+ goto err_cgroup;
+
return 0;
err_cgroup:
kthread_stop(worker);
@@ -345,6 +391,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
fput(dev->vqs[i].call);
vhost_vq_reset(dev, dev->vqs + i);
}
+ vhost_dev_free_iovecs(dev);
if (dev->log_ctx)
eventfd_ctx_put(dev->log_ctx);
dev->log_ctx = NULL;
@@ -947,7 +994,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
}
ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
- ARRAY_SIZE(vq->indirect));
+ UIO_MAXIOV);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret;