Skip to content

Commit

Permalink
[Lang] [ir] [cuda] Add clz instruction (#8276)
Browse files Browse the repository at this point in the history
Issue: #8212 

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5d312ab</samp>

This pull request implements a new `clz` function in Taichi, which
counts the number of leading zeros for a 32-bit integer. The function is
available as a Python function decorator, a unary operation in the
Taichi expression system, and a backend-specific intrinsic in the code
generation. The pull request modifies the relevant files in the
`python`, `taichi`, and `codegen` directories.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 5d312ab</samp>

* Add a new function `clz` to count the number of leading zeros for a
32-bit integer
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-059028cb0798284bed05638becbc32d256736846de19746e196fe5f5ee7fd061R1118-R1126),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-5b3923516b48467202850afb384ef9901ecefae0173f03bcc9055adffe96d738R814-R816),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-f95015864ea3251da5d376f2a11e8f5a0045d7aaf4602370471686f56561dafdR22),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-b0b26408cd63f0a7edc6e9a6936ec09df7dc5f37c2ab65d72b3f9125f1385ba1R90),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-af631a0c71978fe591e17005f01f7c06bc30ae36c65df306bbb3b08ade770167R941))
* Define the function `clz` in the `ops` module in
`python/taichi/lang/ops.py` using a unary operation wrapper
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-059028cb0798284bed05638becbc32d256736846de19746e196fe5f5ee7fd061R1118-R1126))
* Add a wrapper function `clz` in the `mathimpl` module in
`python/taichi/math/mathimpl.py` to allow using `clz` as a Taichi
function decorator
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-5b3923516b48467202850afb384ef9901ecefae0173f03bcc9055adffe96d738R814-R816))
* Add a new macro for the `clz` unary operation in
`taichi/inc/unary_op.inc.h` and `taichi/ir/expression_ops.h` to expand
to the corresponding enum value and expression class
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-f95015864ea3251da5d376f2a11e8f5a0045d7aaf4602370471686f56561dafdR22),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-b0b26408cd63f0a7edc6e9a6936ec09df7dc5f37c2ab65d72b3f9125f1385ba1R90))
* Add a new macro for the `clz` unary operation in
`taichi/python/export_lang.cpp` to bind the operation to the Python
interface
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-af631a0c71978fe591e17005f01f7c06bc30ae36c65df306bbb3b08ade770167R941))
* Implement the `clz` unary operation for different backends
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-50537ad5ea3b900c0d55a088f3cc285986340ad68c9b96fea481187c4dce49eaL289-R296),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313R210-R213),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-1620f2a387fc8acc55e2b2cfced07bb9cba59702609aae6e9489e703cbab5000R900-R904))
* Add a new case for the `clz` unary operation in the CUDA backend code
generation in `taichi/codegen/cuda/codegen_cuda.cpp`, which calls the
CUDA intrinsic function `__clz` and checks the input type
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-50537ad5ea3b900c0d55a088f3cc285986340ad68c9b96fea481187c4dce49eaL289-R296))
* Add a new case for the `clz` unary operation in the LLVM backend code
generation in `taichi/codegen/llvm/codegen_llvm.cpp`, which calls the
LLVM intrinsic function `ctlz` and assigns the result to the statement
value
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313R210-R213))
* Add a new case for the `clz` unary operation in the SPIRV backend code
generation in `taichi/codegen/spirv/spirv_codegen.cpp`, which calls the
GLSL 450 extended instruction `FindMSB` and subtracts the result from 32
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-1620f2a387fc8acc55e2b2cfced07bb9cba59702609aae6e9489e703cbab5000R900-R904))
* Add a new method for the `clz` unary operation in the IR builder
class, which is a helper class for constructing IR statements
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-bdb4f85a29d6478a4482d81ca072237534fb641b52f3c529aca93e872ade6fecR278-R281),
[link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dR177))
* Add a declaration for the `clz` unary operation method in the IR
builder class header file in `taichi/ir/ir_builder.h`
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-1894085b261e833e3e66924fc5b1cf63b9dd8b8aa0b3e78ec64366396131470dR177))
* Add a definition for the `clz` unary operation method in the IR
builder class source file in `taichi/ir/ir_builder.cpp`, which creates
and inserts a new unary operation statement with the `clz` type and the
input value
([link](https://github.com/taichi-dev/taichi/pull/8276/files?diff=unified&w=0#diff-bdb4f85a29d6478a4482d81ca072237534fb641b52f3c529aca93e872ade6fecR278-R281))

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bob Cao <[email protected]>
Co-authored-by: Lin Jiang <[email protected]>
  • Loading branch information
4 people authored Oct 31, 2023
1 parent c72897e commit ac49f79
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,18 @@ def py_ifte(cond, x1, x2):
return _ternary_operation(_ti_core.expr_ifte, py_ifte, cond, x1, x2)


def clz(a):
"""Count the number of leading zeros for a 32bit integer"""

def _clz(x):
for i in range(32):
if 2**i > x:
return 32 - i
return 0

return _unary_operation(_ti_core.expr_clz, _clz, a)


@writeback_binary
def atomic_add(x, y):
"""Atomically compute `x + y`, store the result in `x`,
Expand Down
6 changes: 6 additions & 0 deletions python/taichi/math/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,12 +812,18 @@ def popcnt(x):
return ops.popcnt(x)


@func
def clz(x):
return ops.clz(x)


__all__ = [
"acos",
"asin",
"atan2",
"ceil",
"clamp",
"clz",
"cos",
"cross",
"degrees",
Expand Down
9 changes: 9 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::clz) {
if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
stmt->ret_type = PrimitiveType::i32;
llvm_val[stmt] = call("__nv_clz", input);
} else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) {
llvm_val[stmt] = call("__nv_clzll", input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::log) {
if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
// logf has fast-math option
Expand Down
6 changes: 6 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
llvm_val[stmt] =
builder->CreateIntrinsic(llvm::Intrinsic::ctpop, {input_type}, {input});
}
else if (op == UnaryOpType::clz) {
llvm_val[stmt] = builder->CreateIntrinsic(
llvm::Intrinsic::ctlz, {input_type},
{input,
llvm::ConstantInt::get(llvm::Type::getInt1Ty(*llvm_context), 0)});
}
else {
TI_P(unary_op_type_name(op));
TI_NOT_IMPLEMENTED
Expand Down
6 changes: 6 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,12 @@ class TaskCodegen : public IRVisitor {
ir_->store_variable(val, v);
} else if (stmt->op_type == UnaryOpType::popcnt) {
val = ir_->popcnt(operand_val);
} else if (stmt->op_type == UnaryOpType::clz) {
uint32_t FindMSB_id = 74;
spirv::Value msb = ir_->call_glsl450(dst_type, FindMSB_id, operand_val);
spirv::Value bitcnt = ir_->int_immediate_number(ir_->i32_type(), 32);
spirv::Value one = ir_->int_immediate_number(ir_->i32_type(), 1);
val = ir_->sub(ir_->sub(bitcnt, msb), one);
}
#define UNARY_OP_TO_SPIRV(op, instruction, instruction_id, max_bits) \
else if (stmt->op_type == UnaryOpType::op) { \
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/unary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ PER_UNARY_OP(rcp)
PER_UNARY_OP(exp)
PER_UNARY_OP(log)
PER_UNARY_OP(popcnt)
PER_UNARY_OP(clz)
PER_UNARY_OP(rsqrt)
PER_UNARY_OP(bit_not)
PER_UNARY_OP(logic_not)
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ DEFINE_EXPRESSION_FUNC_UNARY(rsqrt)
DEFINE_EXPRESSION_FUNC_UNARY(exp)
DEFINE_EXPRESSION_FUNC_UNARY(log)
DEFINE_EXPRESSION_FUNC_UNARY(popcnt)
DEFINE_EXPRESSION_FUNC_UNARY(clz)
DEFINE_EXPRESSION_FUNC_UNARY(logic_not)
DEFINE_EXPRESSION_OP_UNARY(~, bit_not)
DEFINE_EXPRESSION_OP_UNARY(-, neg)
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ UnaryOpStmt *IRBuilder::create_popcnt(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::popcnt, value));
}

UnaryOpStmt *IRBuilder::create_clz(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::clz, value));
}

BinaryOpStmt *IRBuilder::create_add(Stmt *l, Stmt *r) {
return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::add, l, r));
}
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class IRBuilder {
UnaryOpStmt *create_exp(Stmt *value);
UnaryOpStmt *create_log(Stmt *value);
UnaryOpStmt *create_popcnt(Stmt *value);
UnaryOpStmt *create_clz(Stmt *value);

// Binary operations. Returns the result.
BinaryOpStmt *create_add(Stmt *l, Stmt *r);
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ void export_lang(py::module &m) {
DEFINE_EXPRESSION_OP(exp)
DEFINE_EXPRESSION_OP(log)
DEFINE_EXPRESSION_OP(popcnt)
DEFINE_EXPRESSION_OP(clz)

DEFINE_EXPRESSION_OP(select)
DEFINE_EXPRESSION_OP(ifte)
Expand Down
1 change: 1 addition & 0 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _get_expected_matrix_apis():
"cinv",
"clamp",
"clog",
"clz",
"cmul",
"cos",
"cpow",
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,22 @@ def test_u64(x: ti.uint64) -> ti.int32:
assert test_i64(10000) == 5


@test_utils.test(arch=[ti.cpu, ti.metal, ti.cuda, ti.vulkan])
def test_clz():
@ti.kernel
def test_i32(x: ti.int32) -> ti.int32:
return ti.math.clz(x)

# assert test_i32(0) == 32
assert test_i32(1) == 31
assert test_i32(2) == 30
assert test_i32(3) == 30
assert test_i32(4) == 29
assert test_i32(5) == 29
assert test_i32(1023) == 22
assert test_i32(1024) == 21


@test_utils.test(arch=[ti.metal])
def test_popcnt():
@ti.kernel
Expand Down

0 comments on commit ac49f79

Please sign in to comment.