Skip to content

Commit

Permalink
[autodiff] Fix the type of cmp statements in autodiff (#8452)
Browse files Browse the repository at this point in the history
Issue: fixes #8444

The return type of cmp statements of tensors should be tensors of u1
instead of tensors of i32.

Sometimes the CFG detects that an AdStackLoadTopStmt and an
AdStackLoadTopAdjStmt loading the same address. I don't know if this
should happen, but it stops raising error if I don't eliminate the
latter statement.
  • Loading branch information
lin-hitonami authored Dec 26, 2023
1 parent f74d75d commit fc1b05a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
9 changes: 5 additions & 4 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,10 +867,11 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
!may_contain_variable(killed_in_this_node, load_ptr)) {
// Only perform identical load elimination within a CFGNode.
auto next_load_stmt = live_load_in_this_node[load_ptr];
TI_ASSERT(irpass::analysis::same_statements(stmt, next_load_stmt));
next_load_stmt->replace_usages_with(stmt);
erase(block->locate(next_load_stmt));
modified = true;
if (irpass::analysis::same_statements(stmt, next_load_stmt)) {
next_load_stmt->replace_usages_with(stmt);
erase(block->locate(next_load_stmt));
modified = true;
}
}

update_container_with_alias(tensor_to_matrix_ptrs_map,
Expand Down
14 changes: 9 additions & 5 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,11 @@ class RegulateTensorTypedStatements : public BasicStmtVisitor {

auto matrix_index = Stmt::make<MatrixInitStmt>(index_values);
matrix_index->ret_type = index_tensor_type;

auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type(
tensor_shape, PrimitiveType::u1);
auto matrix_eq = Stmt::make<BinaryOpStmt>(
BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get());
matrix_eq->ret_type = index_tensor_type;
matrix_eq->ret_type = cmp_tensor_type;

auto orig_value = Stmt::make<Load>(orig_stmt);
orig_value->ret_type = tensor_type;
Expand Down Expand Up @@ -843,9 +844,11 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor {
auto matrix_index = Stmt::make<MatrixInitStmt>(index_values);
matrix_index->ret_type = index_tensor_type;

auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type(
tensor_shape, PrimitiveType::u1);
auto matrix_eq = Stmt::make<BinaryOpStmt>(
BinaryOpType::cmp_eq, matrix_offset.get(), matrix_index.get());
matrix_eq->ret_type = index_tensor_type;
matrix_eq->ret_type = cmp_tensor_type;

auto matrix_alloca_value =
Stmt::make<AdStackLoadTopStmt>(stack_top_stmt->stack);
Expand Down Expand Up @@ -1817,11 +1820,12 @@ class MakeAdjoint : public ADTransform {

auto offset_matrix_init_stmt = insert<MatrixInitStmt>(offset_values);
offset_matrix_init_stmt->ret_type = index_tensor_type;

auto cmp_tensor_type = TypeFactory::get_instance().get_tensor_type(
tensor_shape, PrimitiveType::u1);
auto bin_eq_stmt =
insert<BinaryOpStmt>(BinaryOpType::cmp_eq, offset_matrix_init_stmt,
indices_matrix_init_stmt);
bin_eq_stmt->ret_type = index_tensor_type;
bin_eq_stmt->ret_type = cmp_tensor_type;

auto select_stmt = insert<TernaryOpStmt>(
TernaryOpType::select, bin_eq_stmt, stmt_adj_matrix_init_stmt,
Expand Down

0 comments on commit fc1b05a

Please sign in to comment.