|
@@ -15,6 +15,7 @@
|
|
#include <linux/vhost.h>
|
|
#include <linux/vhost.h>
|
|
#include <linux/virtio_net.h>
|
|
#include <linux/virtio_net.h>
|
|
#include <linux/mm.h>
|
|
#include <linux/mm.h>
|
|
|
|
+#include <linux/mmu_context.h>
|
|
#include <linux/miscdevice.h>
|
|
#include <linux/miscdevice.h>
|
|
#include <linux/mutex.h>
|
|
#include <linux/mutex.h>
|
|
#include <linux/rcupdate.h>
|
|
#include <linux/rcupdate.h>
|
|
@@ -29,8 +30,6 @@
|
|
#include <linux/if_packet.h>
|
|
#include <linux/if_packet.h>
|
|
#include <linux/if_arp.h>
|
|
#include <linux/if_arp.h>
|
|
|
|
|
|
-#include <net/sock.h>
|
|
|
|
-
|
|
|
|
#include "vhost.h"
|
|
#include "vhost.h"
|
|
|
|
|
|
enum {
|
|
enum {
|
|
@@ -157,7 +156,6 @@ static void vhost_vq_reset(struct vhost_dev *dev,
|
|
vq->avail_idx = 0;
|
|
vq->avail_idx = 0;
|
|
vq->last_used_idx = 0;
|
|
vq->last_used_idx = 0;
|
|
vq->used_flags = 0;
|
|
vq->used_flags = 0;
|
|
- vq->used_flags = 0;
|
|
|
|
vq->log_used = false;
|
|
vq->log_used = false;
|
|
vq->log_addr = -1ull;
|
|
vq->log_addr = -1ull;
|
|
vq->vhost_hlen = 0;
|
|
vq->vhost_hlen = 0;
|
|
@@ -178,6 +176,8 @@ static int vhost_worker(void *data)
|
|
struct vhost_work *work = NULL;
|
|
struct vhost_work *work = NULL;
|
|
unsigned uninitialized_var(seq);
|
|
unsigned uninitialized_var(seq);
|
|
|
|
|
|
|
|
+ use_mm(dev->mm);
|
|
|
|
+
|
|
for (;;) {
|
|
for (;;) {
|
|
/* mb paired w/ kthread_stop */
|
|
/* mb paired w/ kthread_stop */
|
|
set_current_state(TASK_INTERRUPTIBLE);
|
|
set_current_state(TASK_INTERRUPTIBLE);
|
|
@@ -192,7 +192,7 @@ static int vhost_worker(void *data)
|
|
if (kthread_should_stop()) {
|
|
if (kthread_should_stop()) {
|
|
spin_unlock_irq(&dev->work_lock);
|
|
spin_unlock_irq(&dev->work_lock);
|
|
__set_current_state(TASK_RUNNING);
|
|
__set_current_state(TASK_RUNNING);
|
|
- return 0;
|
|
|
|
|
|
+ break;
|
|
}
|
|
}
|
|
if (!list_empty(&dev->work_list)) {
|
|
if (!list_empty(&dev->work_list)) {
|
|
work = list_first_entry(&dev->work_list,
|
|
work = list_first_entry(&dev->work_list,
|
|
@@ -210,6 +210,8 @@ static int vhost_worker(void *data)
|
|
schedule();
|
|
schedule();
|
|
|
|
|
|
}
|
|
}
|
|
|
|
+ unuse_mm(dev->mm);
|
|
|
|
+ return 0;
|
|
}
|
|
}
|
|
|
|
|
|
/* Helper to allocate iovec buffers for all vqs. */
|
|
/* Helper to allocate iovec buffers for all vqs. */
|
|
@@ -402,15 +404,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
|
|
kfree(rcu_dereference_protected(dev->memory,
|
|
kfree(rcu_dereference_protected(dev->memory,
|
|
lockdep_is_held(&dev->mutex)));
|
|
lockdep_is_held(&dev->mutex)));
|
|
RCU_INIT_POINTER(dev->memory, NULL);
|
|
RCU_INIT_POINTER(dev->memory, NULL);
|
|
- if (dev->mm)
|
|
|
|
- mmput(dev->mm);
|
|
|
|
- dev->mm = NULL;
|
|
|
|
-
|
|
|
|
WARN_ON(!list_empty(&dev->work_list));
|
|
WARN_ON(!list_empty(&dev->work_list));
|
|
if (dev->worker) {
|
|
if (dev->worker) {
|
|
kthread_stop(dev->worker);
|
|
kthread_stop(dev->worker);
|
|
dev->worker = NULL;
|
|
dev->worker = NULL;
|
|
}
|
|
}
|
|
|
|
+ if (dev->mm)
|
|
|
|
+ mmput(dev->mm);
|
|
|
|
+ dev->mm = NULL;
|
|
}
|
|
}
|
|
|
|
|
|
static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
|
|
static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
|
|
@@ -881,14 +882,15 @@ static int set_bit_to_user(int nr, void __user *addr)
|
|
static int log_write(void __user *log_base,
|
|
static int log_write(void __user *log_base,
|
|
u64 write_address, u64 write_length)
|
|
u64 write_address, u64 write_length)
|
|
{
|
|
{
|
|
|
|
+ u64 write_page = write_address / VHOST_PAGE_SIZE;
|
|
int r;
|
|
int r;
|
|
if (!write_length)
|
|
if (!write_length)
|
|
return 0;
|
|
return 0;
|
|
- write_address /= VHOST_PAGE_SIZE;
|
|
|
|
|
|
+ write_length += write_address % VHOST_PAGE_SIZE;
|
|
for (;;) {
|
|
for (;;) {
|
|
u64 base = (u64)(unsigned long)log_base;
|
|
u64 base = (u64)(unsigned long)log_base;
|
|
- u64 log = base + write_address / 8;
|
|
|
|
- int bit = write_address % 8;
|
|
|
|
|
|
+ u64 log = base + write_page / 8;
|
|
|
|
+ int bit = write_page % 8;
|
|
if ((u64)(unsigned long)log != log)
|
|
if ((u64)(unsigned long)log != log)
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
|
|
r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
|
|
@@ -897,7 +899,7 @@ static int log_write(void __user *log_base,
|
|
if (write_length <= VHOST_PAGE_SIZE)
|
|
if (write_length <= VHOST_PAGE_SIZE)
|
|
break;
|
|
break;
|
|
write_length -= VHOST_PAGE_SIZE;
|
|
write_length -= VHOST_PAGE_SIZE;
|
|
- write_address += VHOST_PAGE_SIZE;
|
|
|
|
|
|
+ write_page += 1;
|
|
}
|
|
}
|
|
return r;
|
|
return r;
|
|
}
|
|
}
|
|
@@ -1092,7 +1094,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
|
|
|
|
|
|
/* Check it isn't doing very strange things with descriptor numbers. */
|
|
/* Check it isn't doing very strange things with descriptor numbers. */
|
|
last_avail_idx = vq->last_avail_idx;
|
|
last_avail_idx = vq->last_avail_idx;
|
|
- if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) {
|
|
|
|
|
|
+ if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
|
|
vq_err(vq, "Failed to access avail idx at %p\n",
|
|
vq_err(vq, "Failed to access avail idx at %p\n",
|
|
&vq->avail->idx);
|
|
&vq->avail->idx);
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
@@ -1113,8 +1115,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
|
|
|
|
|
|
/* Grab the next descriptor number they're advertising, and increment
|
|
/* Grab the next descriptor number they're advertising, and increment
|
|
* the index we've seen. */
|
|
* the index we've seen. */
|
|
- if (unlikely(get_user(head,
|
|
|
|
- &vq->avail->ring[last_avail_idx % vq->num]))) {
|
|
|
|
|
|
+ if (unlikely(__get_user(head,
|
|
|
|
+ &vq->avail->ring[last_avail_idx % vq->num]))) {
|
|
vq_err(vq, "Failed to read head: idx %d address %p\n",
|
|
vq_err(vq, "Failed to read head: idx %d address %p\n",
|
|
last_avail_idx,
|
|
last_avail_idx,
|
|
&vq->avail->ring[last_avail_idx % vq->num]);
|
|
&vq->avail->ring[last_avail_idx % vq->num]);
|
|
@@ -1213,17 +1215,17 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
|
|
/* The virtqueue contains a ring of used buffers. Get a pointer to the
|
|
/* The virtqueue contains a ring of used buffers. Get a pointer to the
|
|
* next entry in that used ring. */
|
|
* next entry in that used ring. */
|
|
used = &vq->used->ring[vq->last_used_idx % vq->num];
|
|
used = &vq->used->ring[vq->last_used_idx % vq->num];
|
|
- if (put_user(head, &used->id)) {
|
|
|
|
|
|
+ if (__put_user(head, &used->id)) {
|
|
vq_err(vq, "Failed to write used id");
|
|
vq_err(vq, "Failed to write used id");
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
}
|
|
}
|
|
- if (put_user(len, &used->len)) {
|
|
|
|
|
|
+ if (__put_user(len, &used->len)) {
|
|
vq_err(vq, "Failed to write used len");
|
|
vq_err(vq, "Failed to write used len");
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
}
|
|
}
|
|
/* Make sure buffer is written before we update index. */
|
|
/* Make sure buffer is written before we update index. */
|
|
smp_wmb();
|
|
smp_wmb();
|
|
- if (put_user(vq->last_used_idx + 1, &vq->used->idx)) {
|
|
|
|
|
|
+ if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) {
|
|
vq_err(vq, "Failed to increment used idx");
|
|
vq_err(vq, "Failed to increment used idx");
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
}
|
|
}
|
|
@@ -1255,7 +1257,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
|
|
|
|
|
|
start = vq->last_used_idx % vq->num;
|
|
start = vq->last_used_idx % vq->num;
|
|
used = vq->used->ring + start;
|
|
used = vq->used->ring + start;
|
|
- if (copy_to_user(used, heads, count * sizeof *used)) {
|
|
|
|
|
|
+ if (__copy_to_user(used, heads, count * sizeof *used)) {
|
|
vq_err(vq, "Failed to write used");
|
|
vq_err(vq, "Failed to write used");
|
|
return -EFAULT;
|
|
return -EFAULT;
|
|
}
|
|
}
|
|
@@ -1316,7 +1318,7 @@ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
|
|
* interrupts. */
|
|
* interrupts. */
|
|
smp_mb();
|
|
smp_mb();
|
|
|
|
|
|
- if (get_user(flags, &vq->avail->flags)) {
|
|
|
|
|
|
+ if (__get_user(flags, &vq->avail->flags)) {
|
|
vq_err(vq, "Failed to get flags");
|
|
vq_err(vq, "Failed to get flags");
|
|
return;
|
|
return;
|
|
}
|
|
}
|
|
@@ -1367,7 +1369,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
|
|
/* They could have slipped one in as we were doing that: make
|
|
/* They could have slipped one in as we were doing that: make
|
|
* sure it's written, then check again. */
|
|
* sure it's written, then check again. */
|
|
smp_mb();
|
|
smp_mb();
|
|
- r = get_user(avail_idx, &vq->avail->idx);
|
|
|
|
|
|
+ r = __get_user(avail_idx, &vq->avail->idx);
|
|
if (r) {
|
|
if (r) {
|
|
vq_err(vq, "Failed to check avail idx at %p: %d\n",
|
|
vq_err(vq, "Failed to check avail idx at %p: %d\n",
|
|
&vq->avail->idx, r);
|
|
&vq->avail->idx, r);
|