[simt] Support "__syncthreads_and", "__syncthreads_or", and "__syncthreads_count" from CUDA. (#8297)

Issue: #8289 

### Brief Summary

From the CUDA documentation:

Devices of compute capability 2.x and higher support three variations of
`__syncthreads()`, described below.
```cpp
int __syncthreads_count(int predicate);
```
is identical to `__syncthreads()` with the additional feature that it
evaluates `predicate` for all threads of the block and returns the number
of threads for which `predicate` evaluates to non-zero.
```cpp
int __syncthreads_and(int predicate);
```
is identical to `__syncthreads()` with the additional feature that it
evaluates `predicate` for all threads of the block and returns non-zero if
and only if `predicate` evaluates to non-zero for all of them.
```cpp
int __syncthreads_or(int predicate);
```
is identical to `__syncthreads()` with the additional feature that it
evaluates `predicate` for all threads of the block and returns non-zero if
and only if `predicate` evaluates to non-zero for any of them.
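
As a rough host-side analogy (illustrative only; the real intrinsics also act as a block-wide barrier), the three return values correspond to a count/all/any reduction of the per-thread predicate:

```python
import numpy as np

# Hypothetical predicate values for the 256 threads of one block.
predicate = np.random.randint(0, 2, size=256)

count_result = int(np.count_nonzero(predicate))  # __syncthreads_count
and_result = int(np.all(predicate != 0))         # __syncthreads_and
or_result = int(np.any(predicate != 0))          # __syncthreads_or
```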

This PR adds these three operations for CUDA only; the API looks
like:
```python
def sync_all_nonzero(predicate): ...  # __syncthreads_and

def sync_any_nonzero(predicate): ...  # __syncthreads_or

def sync_count_nonzero(predicate): ...  # __syncthreads_count
```
The predicate is always expected to be `ti.int32`.
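
A minimal usage sketch (mirroring the new tests; the field names and sizes here are illustrative):

```python
import taichi as ti

ti.init(arch=ti.cuda)

pred = ti.field(dtype=ti.i32, shape=256)
result = ti.field(dtype=ti.i32, shape=256)

@ti.kernel
def reduce_block():
    ti.loop_config(block_dim=256)  # one CUDA block covers the whole range
    for i in range(256):
        # Every thread contributes pred[i]; after the barrier, all threads
        # of the block receive the number of non-zero predicates.
        result[i] = ti.simt.block.sync_count_nonzero(pred[i])
```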
### Walkthrough
Overall, the code is adapted from the existing CUDA warp operations, so the
implementation is pretty straightforward. I added tests similar to those for
the warp operations, and they all pass on my local machine.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
wanmeihuali and pre-commit-ci[bot] authored Oct 31, 2023
1 parent 1ae0e46 commit b8d7ffd
Showing 6 changed files with 129 additions and 0 deletions.
21 changes: 21 additions & 0 deletions python/taichi/lang/simt/block.py
@@ -17,6 +17,27 @@ def sync():
raise ValueError(f"ti.block.shared_array is not supported for arch {arch}")


def sync_all_nonzero(predicate):
arch = impl.get_runtime().prog.config().arch
if arch == _ti_core.cuda:
return impl.call_internal("block_barrier_and_i32", predicate, with_runtime_context=False)
raise ValueError(f"ti.block.sync_all_nonzero is not supported for arch {arch}")


def sync_any_nonzero(predicate):
arch = impl.get_runtime().prog.config().arch
if arch == _ti_core.cuda:
return impl.call_internal("block_barrier_or_i32", predicate, with_runtime_context=False)
raise ValueError(f"ti.block.sync_any_nonzero is not supported for arch {arch}")


def sync_count_nonzero(predicate):
arch = impl.get_runtime().prog.config().arch
if arch == _ti_core.cuda:
return impl.call_internal("block_barrier_count_i32", predicate, with_runtime_context=False)
raise ValueError(f"ti.block.sync_count_nonzero is not supported for arch {arch}")


def mem_sync():
arch = impl.get_runtime().prog.config().arch
if arch == _ti_core.cuda:
3 changes: 3 additions & 0 deletions taichi/inc/internal_ops.inc.h
@@ -46,6 +46,9 @@ PER_INTERNAL_OP(subgroupInclusiveXor)

// CUDA
PER_INTERNAL_OP(block_barrier)
PER_INTERNAL_OP(block_barrier_and_i32)
PER_INTERNAL_OP(block_barrier_or_i32)
PER_INTERNAL_OP(block_barrier_count_i32)
PER_INTERNAL_OP(grid_memfence)
PER_INTERNAL_OP(cuda_all_sync_i32)
PER_INTERNAL_OP(cuda_any_sync_i32)
3 changes: 3 additions & 0 deletions taichi/ir/type_system.cpp
@@ -329,6 +329,9 @@ void Operations::init_internals() {
PLAIN_OP(cuda_match_##name##_sync_##dt, u32, false, u32, dt)

PLAIN_OP(block_barrier, i32_void, false);
PLAIN_OP(block_barrier_and_i32, i32, false, i32);
PLAIN_OP(block_barrier_or_i32, i32, false, i32);
PLAIN_OP(block_barrier_count_i32, i32, false, i32);
PLAIN_OP(grid_memfence, i32_void, false);
CUDA_VOTE_SYNC(all);
CUDA_VOTE_SYNC(any);
3 changes: 3 additions & 0 deletions taichi/runtime/llvm/llvm_context.cpp
@@ -432,6 +432,9 @@ std::unique_ptr<llvm::Module> TaichiLLVMContext::module_from_file(
patch_intrinsic("block_dim", Intrinsic::nvvm_read_ptx_sreg_ntid_x);
patch_intrinsic("grid_dim", Intrinsic::nvvm_read_ptx_sreg_nctaid_x);
patch_intrinsic("block_barrier", Intrinsic::nvvm_barrier0, false);
patch_intrinsic("block_barrier_and_i32", Intrinsic::nvvm_barrier0_and);
patch_intrinsic("block_barrier_or_i32", Intrinsic::nvvm_barrier0_or);
patch_intrinsic("block_barrier_count_i32", Intrinsic::nvvm_barrier0_popc);
patch_intrinsic("warp_barrier", Intrinsic::nvvm_bar_warp_sync, false);
patch_intrinsic("block_memfence", Intrinsic::nvvm_membar_cta, false);
patch_intrinsic("grid_memfence", Intrinsic::nvvm_membar_gl, false);
12 changes: 12 additions & 0 deletions taichi/runtime/llvm/runtime_module/runtime.cpp
@@ -1169,6 +1169,18 @@ uint32 cuda_active_mask() {
void block_barrier() {
}

int32 block_barrier_and_i32(int32 predicate) {
return 0;
}

int32 block_barrier_or_i32(int32 predicate) {
return 0;
}

int32 block_barrier_count_i32(int32 predicate) {
return 0;
}

void warp_barrier(uint32 mask) {
}

87 changes: 87 additions & 0 deletions tests/python/test_simt.py
@@ -35,6 +35,34 @@ def foo():
assert a[i] == 0


@test_utils.test(arch=ti.cuda)
def test_sync_all_nonzero():
a = ti.field(dtype=ti.i32, shape=256)
b = ti.field(dtype=ti.i32, shape=256)

@ti.kernel
def foo():
ti.loop_config(block_dim=256)
for i in range(256):
a[i] = ti.simt.block.sync_all_nonzero(b[i])

for i in range(256):
b[i] = 1
a[i] = -1

foo()

for i in range(256):
assert a[i] == 1

b[np.random.randint(0, 256)] = 0

foo()

for i in range(256):
assert a[i] == 0


@test_utils.test(arch=ti.cuda)
def test_any_nonzero():
a = ti.field(dtype=ti.i32, shape=32)
@@ -63,6 +91,65 @@ def foo():
assert a[i] == 1


@test_utils.test(arch=ti.cuda)
def test_sync_any_nonzero():
a = ti.field(dtype=ti.i32, shape=256)
b = ti.field(dtype=ti.i32, shape=256)

@ti.kernel
def foo():
ti.loop_config(block_dim=256)
for i in range(256):
a[i] = ti.simt.block.sync_any_nonzero(b[i])

for i in range(256):
b[i] = 0
a[i] = -1

foo()

for i in range(256):
assert a[i] == 0

b[np.random.randint(0, 256)] = 1

foo()

for i in range(256):
assert a[i] == 1


@test_utils.test(arch=ti.cuda)
def test_sync_count_nonzero():
a = ti.field(dtype=ti.i32, shape=256)
b = ti.field(dtype=ti.i32, shape=256)

@ti.kernel
def foo():
ti.loop_config(block_dim=256)
for i in range(256):
a[i] = ti.simt.block.sync_count_nonzero(b[i])

for i in range(256):
b[i] = 0
a[i] = -1

foo()

for i in range(256):
assert a[i] == 0

random_idx_count = np.random.randint(0, 256)
random_idx = np.random.choice(256, random_idx_count, replace=False)
for i in range(random_idx_count):
b[random_idx[i]] = 1

foo()

for i in range(256):
assert a[i] == random_idx_count


@test_utils.test(arch=ti.cuda)
def test_unique():
a = ti.field(dtype=ti.u32, shape=32)
