Skip to content

Commit

Permalink
[TOPI] GPU scatter_add using atomic (apache#7044)
Browse files Browse the repository at this point in the history
* use atomic add for faster 1d scatter add

* update tests

* run black

* more pylint fix

* remove fp64 bintcount test

Co-authored-by: masa <[email protected]>
  • Loading branch information
2 people authored and Tushar Dey committed Jan 20, 2021
1 parent 4a3fe16 commit 9c25460
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
17 changes: 14 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,18 +1921,29 @@ def empty(self, inputs, input_types):
def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
input_type = _infer_type(data).checked_type.dtype
if input_type == "int64":
logging.warning(
"Casting an int64 input to int32, since we do not have int64 atomic add"
"needed for bincount yet."
)
data = _op.cast(data, "int32")
maximum = _op.max(data)
dim = maximum + _expr.const(1, dtype="int64")
dim = maximum + _expr.const(1, dtype="int32")
if weights:
weight_type = _infer_type(weights).checked_type
out_dtype = weight_type.dtype
updates = weights
else:
out_dtype = "int64"
out_dtype = "int32"
updates = _op.ones_like(data)

counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
return _op.scatter_add(counts, data, updates, axis=0)
out = _op.scatter_add(counts, data, updates, axis=0)
if input_type == "int32":
# Torch always outputs int64 results for bincount
return _op.cast(out, "int64")
return out

def scatter_add(self, inputs, input_types):
data = inputs[0]
Expand Down
80 changes: 79 additions & 1 deletion python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm
from tvm import te
from ..scatter import _verify_scatter_nd_inputs
from .nms import atomic_add


def ceil_div(a, b):
Expand Down Expand Up @@ -470,6 +471,83 @@ def update_func(dst_ptr, dst_index, update):
return out


def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
"""Generate scatter add ir for 1d inputs, using atomic_add instruction
Parameters
----------
data : tir.Tensor
The input data to the operator.
indices : tir.Tensor
The index locations to update.
updates : tir.Tensor
The values to update.
axis : int
The axis to scatter on
out : tir.Tensor
The output tensor.
Returns
-------
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

with ib.new_scope():
nthread_bx = ceil_div(n, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < n):
out_ptr[tid] = data_ptr[tid]

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)

ni = indices.shape[0]

atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")

with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.if_scope(tid < ni):
index = indices_ptr[tid]
with ib.if_scope(index < 0):
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]),
updates_ptr[tid],
)
with ib.else_scope():
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]),
updates_ptr[tid],
)

return ib.get()


def scatter_add(data, indices, updates, axis=0):
"""Update data by adding values in updates at positions defined by indices
Expand Down Expand Up @@ -501,7 +579,7 @@ def scatter_add(data, indices, updates, axis=0):
assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"

ir_funcs = {
1: gen_ir_1d,
1: gen_scatter_add_1d_atomic,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
Expand Down
10 changes: 5 additions & 5 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3355,12 +3355,12 @@ def test_bincount():
def test_fn(x, weights=None):
return torch.bincount(x, weights=weights)

inp = torch.randint(0, 8, (5,), dtype=torch.int64)
weights = torch.linspace(0, 1, steps=5)
inp = torch.randint(0, 100, (10000,), dtype=torch.int64)
weights = torch.linspace(0, 100, steps=10000)

verify_trace_model(test_fn, [inp], ["llvm"])
verify_trace_model(test_fn, [inp, weights], ["llvm"])
verify_trace_model(test_fn, [inp, weights.to(torch.float64)], ["llvm"])
targets = ["llvm", "cuda"]
verify_trace_model(test_fn, [inp], targets)
verify_trace_model(test_fn, [inp, weights], targets)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,11 +1049,15 @@ def verify_scatter_add(dshape, ishape, axis=0):
ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
if target == "nvptx":
# TODO(masahi): support atomic in LLVM codegen
continue
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_scatter_add((10,), (10,), 0)
verify_scatter_add((1000,), (1000,), 0)
verify_scatter_add((10, 5), (10, 5), -2)
verify_scatter_add((10, 5), (10, 5), -1)
verify_scatter_add((10, 5), (3, 5), 0)
Expand Down

0 comments on commit 9c25460

Please sign in to comment.