|
@@ -645,6 +645,18 @@ static void set_device_domain(struct amd_iommu *iommu,
|
|
*
|
|
*
|
|
*****************************************************************************/
|
|
*****************************************************************************/
|
|
|
|
|
|
|
|
+/*
|
|
|
|
+ * This function checks if the driver got a valid device from the caller to
|
|
|
|
+ * avoid dereferencing invalid pointers.
|
|
|
|
+ */
|
|
|
|
+static bool check_device(struct device *dev)
|
|
|
|
+{
|
|
|
|
+ if (!dev || !dev->dma_mask)
|
|
|
|
+ return false;
|
|
|
|
+
|
|
|
|
+ return true;
|
|
|
|
+}
|
|
|
|
+
|
|
/*
|
|
/*
|
|
* In the dma_ops path we only have the struct device. This function
|
|
* In the dma_ops path we only have the struct device. This function
|
|
* finds the corresponding IOMMU, the protection domain and the
|
|
* finds the corresponding IOMMU, the protection domain and the
|
|
@@ -661,18 +673,19 @@ static int get_device_resources(struct device *dev,
|
|
struct pci_dev *pcidev;
|
|
struct pci_dev *pcidev;
|
|
u16 _bdf;
|
|
u16 _bdf;
|
|
|
|
|
|
- BUG_ON(!dev || dev->bus != &pci_bus_type || !dev->dma_mask);
|
|
|
|
|
|
+ *iommu = NULL;
|
|
|
|
+ *domain = NULL;
|
|
|
|
+ *bdf = 0xffff;
|
|
|
|
+
|
|
|
|
+ if (dev->bus != &pci_bus_type)
|
|
|
|
+ return 0;
|
|
|
|
|
|
pcidev = to_pci_dev(dev);
|
|
pcidev = to_pci_dev(dev);
|
|
_bdf = calc_devid(pcidev->bus->number, pcidev->devfn);
|
|
_bdf = calc_devid(pcidev->bus->number, pcidev->devfn);
|
|
|
|
|
|
/* device not translated by any IOMMU in the system? */
|
|
/* device not translated by any IOMMU in the system? */
|
|
- if (_bdf > amd_iommu_last_bdf) {
|
|
|
|
- *iommu = NULL;
|
|
|
|
- *domain = NULL;
|
|
|
|
- *bdf = 0xffff;
|
|
|
|
|
|
+ if (_bdf > amd_iommu_last_bdf)
|
|
return 0;
|
|
return 0;
|
|
- }
|
|
|
|
|
|
|
|
*bdf = amd_iommu_alias_table[_bdf];
|
|
*bdf = amd_iommu_alias_table[_bdf];
|
|
|
|
|
|
@@ -826,6 +839,9 @@ static dma_addr_t map_single(struct device *dev, phys_addr_t paddr,
|
|
u16 devid;
|
|
u16 devid;
|
|
dma_addr_t addr;
|
|
dma_addr_t addr;
|
|
|
|
|
|
|
|
+ if (!check_device(dev))
|
|
|
|
+ return bad_dma_address;
|
|
|
|
+
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
|
|
|
|
if (iommu == NULL || domain == NULL)
|
|
if (iommu == NULL || domain == NULL)
|
|
@@ -860,7 +876,8 @@ static void unmap_single(struct device *dev, dma_addr_t dma_addr,
|
|
struct protection_domain *domain;
|
|
struct protection_domain *domain;
|
|
u16 devid;
|
|
u16 devid;
|
|
|
|
|
|
- if (!get_device_resources(dev, &iommu, &domain, &devid))
|
|
|
|
|
|
+ if (!check_device(dev) ||
|
|
|
|
+ !get_device_resources(dev, &iommu, &domain, &devid))
|
|
/* device not handled by any AMD IOMMU */
|
|
/* device not handled by any AMD IOMMU */
|
|
return;
|
|
return;
|
|
|
|
|
|
@@ -910,6 +927,9 @@ static int map_sg(struct device *dev, struct scatterlist *sglist,
|
|
phys_addr_t paddr;
|
|
phys_addr_t paddr;
|
|
int mapped_elems = 0;
|
|
int mapped_elems = 0;
|
|
|
|
|
|
|
|
+ if (!check_device(dev))
|
|
|
|
+ return 0;
|
|
|
|
+
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
|
|
|
|
if (!iommu || !domain)
|
|
if (!iommu || !domain)
|
|
@@ -967,7 +987,8 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist,
|
|
u16 devid;
|
|
u16 devid;
|
|
int i;
|
|
int i;
|
|
|
|
|
|
- if (!get_device_resources(dev, &iommu, &domain, &devid))
|
|
|
|
|
|
+ if (!check_device(dev) ||
|
|
|
|
+ !get_device_resources(dev, &iommu, &domain, &devid))
|
|
return;
|
|
return;
|
|
|
|
|
|
spin_lock_irqsave(&domain->lock, flags);
|
|
spin_lock_irqsave(&domain->lock, flags);
|
|
@@ -999,6 +1020,9 @@ static void *alloc_coherent(struct device *dev, size_t size,
|
|
u16 devid;
|
|
u16 devid;
|
|
phys_addr_t paddr;
|
|
phys_addr_t paddr;
|
|
|
|
|
|
|
|
+ if (!check_device(dev))
|
|
|
|
+ return NULL;
|
|
|
|
+
|
|
virt_addr = (void *)__get_free_pages(flag, get_order(size));
|
|
virt_addr = (void *)__get_free_pages(flag, get_order(size));
|
|
if (!virt_addr)
|
|
if (!virt_addr)
|
|
return 0;
|
|
return 0;
|
|
@@ -1047,6 +1071,9 @@ static void free_coherent(struct device *dev, size_t size,
|
|
struct protection_domain *domain;
|
|
struct protection_domain *domain;
|
|
u16 devid;
|
|
u16 devid;
|
|
|
|
|
|
|
|
+ if (!check_device(dev))
|
|
|
|
+ return;
|
|
|
|
+
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
get_device_resources(dev, &iommu, &domain, &devid);
|
|
|
|
|
|
if (!iommu || !domain)
|
|
if (!iommu || !domain)
|