diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c index f655859eec000e..f1e23583e6c08e 100644 --- a/drivers/infiniband/hw/mlx5/odp.c +++ b/drivers/infiniband/hw/mlx5/odp.c @@ -228,13 +228,27 @@ static void destroy_unused_implicit_child_mr(struct mlx5_ib_mr *mr) unsigned long idx = ib_umem_start(odp) >> MLX5_IMR_MTT_SHIFT; struct mlx5_ib_mr *imr = mr->parent; + /* + * If userspace is racing freeing the parent implicit ODP MR then we can + * loose the race with parent destruction. In this case + * mlx5_ib_free_odp_mr() will free everything in the implicit_children + * xarray so NOP is fine. This child MR cannot be destroyed here because + * we are under its umem_mutex. + */ if (!refcount_inc_not_zero(&imr->mmkey.usecount)) return; - xa_erase(&imr->implicit_children, idx); + xa_lock(&imr->implicit_children); + if (__xa_cmpxchg(&imr->implicit_children, idx, mr, NULL, GFP_KERNEL) != + mr) { + xa_unlock(&imr->implicit_children); + return; + } + if (MLX5_CAP_ODP(mr_to_mdev(mr)->mdev, mem_page_fault)) - xa_erase(&mr_to_mdev(mr)->odp_mkeys, - mlx5_base_mkey(mr->mmkey.key)); + __xa_erase(&mr_to_mdev(mr)->odp_mkeys, + mlx5_base_mkey(mr->mmkey.key)); + xa_unlock(&imr->implicit_children); /* Freeing a MR is a sleeping operation, so bounce to a work queue */ INIT_WORK(&mr->odp_destroy.work, free_implicit_child_mr_work); @@ -502,18 +516,18 @@ static struct mlx5_ib_mr *implicit_get_child_mr(struct mlx5_ib_mr *imr, refcount_inc(&ret->mmkey.usecount); goto out_lock; } - xa_unlock(&imr->implicit_children); if (MLX5_CAP_ODP(dev->mdev, mem_page_fault)) { - ret = xa_store(&dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key), - &mr->mmkey, GFP_KERNEL); + ret = __xa_store(&dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key), + &mr->mmkey, GFP_KERNEL); if (xa_is_err(ret)) { ret = ERR_PTR(xa_err(ret)); - xa_erase(&imr->implicit_children, idx); - goto out_mr; + __xa_erase(&imr->implicit_children, idx); + goto out_lock; } mr->mmkey.type = MLX5_MKEY_IMPLICIT_CHILD; } + xa_unlock(&imr->implicit_children); mlx5_ib_dbg(mr_to_mdev(imr), "key %x mr %p\n", mr->mmkey.key, mr); return mr;