diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 2594767e44b3a..2ad9bee429d03 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -279,11 +279,13 @@ class AlgSimp : public BasicStmtVisitor { cast_to_result_type(one, stmt); auto new_exponent = Stmt::make(UnaryOpType::neg, stmt->rhs); + new_exponent->ret_type = stmt->rhs->ret_type; auto a_to_n = Stmt::make(BinaryOpType::pow, stmt->lhs, new_exponent.get()); a_to_n->ret_type = stmt->ret_type; auto result = Stmt::make(BinaryOpType::div, one, a_to_n.get()); + result->ret_type = stmt->ret_type; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(new_exponent)); modifier.insert_before(stmt, std::move(a_to_n)); diff --git a/tests/python/test_optimization.py b/tests/python/test_optimization.py index c966464647da4..1500663e75f44 100644 --- a/tests/python/test_optimization.py +++ b/tests/python/test_optimization.py @@ -153,3 +153,19 @@ def my_cast(x: ti.f32) -> ti.u32: return ti.cast(y, ti.u32) assert my_cast(-1) == 4294967295 + + +@test_utils.test() +def test_negative_exp(): + @ti.dataclass + class Particle: + epsilon: ti.f32 + + @ti.kernel + def test() -> ti.f32: + p1 = Particle() + p1.epsilon = 1.0 + e = p1.epsilon + return e**-1 + + assert test() == 1.0