From c27024848b59493cdf74deda10141c5efc2f2d56 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AD=94=E6=B3=95=E5=B0=91=E5=A5=B3=E8=B5=B5=E5=BF=97?=
 =?UTF-8?q?=E8=BE=89?=
Date: Mon, 17 Jul 2023 17:39:26 +0800
Subject: [PATCH 1/7] [Lang] Implement struct DebugInfo and ErrorEmitter (#8284)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Issue: #

### Brief Summary

### 🤖 Generated by Copilot at d9d3f20

This pull request improves the error and warning reporting in Taichi by adding
source code location information to the Taichi IR and exceptions. It modifies
`taichi/common/exceptions.h` to define new exception classes, and
`taichi/ir/ir.h` to store and use location information in IR statements.

### Walkthrough

### 🤖 Generated by Copilot at d9d3f20

* Introduce a new ErrorEmitter struct that handles the formatting and throwing
  of errors and warnings in Taichi programs
  ([link](https://github.com/taichi-dev/taichi/pull/8284/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L3-R20),
  [link](https://github.com/taichi-dev/taichi/pull/8284/files?diff=unified&w=0#diff-61484fa2a50e309478017fb2a436198aa4b0afdf72a4039bf574fc4f2aedbe4eR390-R407))
* Refactor the exception hierarchy and create a new TaichiError class and a new
  TaichiWarning class that inherit from TaichiExceptionImpl and have different
  emit methods
  ([link](https://github.com/taichi-dev/taichi/pull/8284/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L18-R134))
* Add a new member src_location to the Stmt class to store the source code
  location of each statement in the IR
  ([link](https://github.com/taichi-dev/taichi/pull/8284/files?diff=unified&w=0#diff-61484fa2a50e309478017fb2a436198aa4b0afdf72a4039bf574fc4f2aedbe4eR422))
---
 taichi/codegen/llvm/codegen_llvm.cpp          |   8 +-
 taichi/codegen/spirv/spirv_codegen.cpp        |  18 +--
 taichi/common/exceptions.h                    | 135 ++++++++++++++--
 taichi/ir/expr.cpp                            |   6 +-
 taichi/ir/expr.h                              |   2 +
 taichi/ir/expression.h                        |  10 +-
 taichi/ir/frontend_ir.cpp                     |  44 +++---
 taichi/ir/ir.cpp                              |   2 +-
 taichi/ir/ir.h                                |  14 +-
 taichi/ir/statements.cpp                      |   2 +-
 taichi/program/program.cpp                    |   2 +-
 taichi/transforms/alg_simp.cpp                |   6 +-
 taichi/transforms/auto_diff.cpp               |   6 +-
 taichi/transforms/check_out_of_bound.cpp      |   8 +-
 taichi/transforms/demote_atomics.cpp          |   2 +-
 taichi/transforms/frontend_type_check.cpp     |   8 +-
 taichi/transforms/inlining.cpp                |   2 +-
 taichi/transforms/lower_ast.cpp               |   2 +-
 taichi/transforms/lower_matrix_ptr.cpp        |   2 +-
 taichi/transforms/offload.cpp                 |   2 +-
 taichi/transforms/simplify.cpp                |   3 +-
 taichi/transforms/type_check.cpp              |  12 +-
 tests/cpp/ir/frontend_type_inference_test.cpp |   3 +-
 23 files changed, 220 insertions(+), 79 deletions(-)

diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 9e487302ce9286..4ae01c67af1d3c 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -569,7 +569,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         llvm_val[stmt] =
             call("debug_add_" + stmt->ret_type->to_string(), get_arg(0),
                  llvm_val[stmt->lhs], llvm_val[stmt->rhs],
-                 builder->CreateGlobalStringPtr(stmt->tb));
+                 builder->CreateGlobalStringPtr(stmt->get_tb()));
 #endif
       } else {
         llvm_val[stmt] =
@@ -584,7 +584,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         llvm_val[stmt] =
             call("debug_sub_" + stmt->ret_type->to_string(), get_arg(0),
                  llvm_val[stmt->lhs], llvm_val[stmt->rhs],
-                 builder->CreateGlobalStringPtr(stmt->tb));
+                 builder->CreateGlobalStringPtr(stmt->get_tb()));
 #endif
       } else {
         llvm_val[stmt] =
@@ -599,7 +599,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         llvm_val[stmt] =
             call("debug_mul_" + stmt->ret_type->to_string(), get_arg(0),
                  llvm_val[stmt->lhs], llvm_val[stmt->rhs],
-                 builder->CreateGlobalStringPtr(stmt->tb));
+                 builder->CreateGlobalStringPtr(stmt->get_tb()));
 #endif
       } else {
         llvm_val[stmt] =
@@ -646,7 +646,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
         llvm_val[stmt] =
             call("debug_shl_" + stmt->ret_type->to_string(), get_arg(0),
                  llvm_val[stmt->lhs], llvm_val[stmt->rhs],
-                 builder->CreateGlobalStringPtr(stmt->tb));
+                 builder->CreateGlobalStringPtr(stmt->get_tb()));
       } else {
         llvm_val[stmt] =
             builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp
index 56fd1f454de31e..4d678bf89b31c8 100644
--- a/taichi/codegen/spirv/spirv_codegen.cpp
+++ b/taichi/codegen/spirv/spirv_codegen.cpp
@@ -1139,31 +1139,31 @@ class TaskCodegen : public IRVisitor {
     TI_WARN_IF(lhs_value.stype.id != rhs_value.stype.id,
                "${} type {} != ${} type {}\n{}", lhs_name,
                lhs_value.stype.dt->to_string(), rhs_name,
-               rhs_value.stype.dt->to_string(), bin->tb);
+               rhs_value.stype.dt->to_string(), bin->get_tb());
     bool debug = caps_->get(DeviceCapability::spirv_has_non_semantic_info);
     if (debug && op_type == BinaryOpType::add && is_integral(dst_type.dt)) {
       if (is_unsigned(dst_type.dt)) {
-        bin_value = generate_uadd_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_uadd_overflow(lhs_value, rhs_value, bin->get_tb());
       } else {
-        bin_value = generate_sadd_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_sadd_overflow(lhs_value, rhs_value, bin->get_tb());
       }
       bin_value = ir_->cast(dst_type, bin_value);
     } else if (debug && op_type == BinaryOpType::sub &&
                is_integral(dst_type.dt)) {
       if (is_unsigned(dst_type.dt)) {
-        bin_value = generate_usub_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_usub_overflow(lhs_value, rhs_value, bin->get_tb());
       } else {
-        bin_value = generate_ssub_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_ssub_overflow(lhs_value, rhs_value, bin->get_tb());
       }
       bin_value = ir_->cast(dst_type, bin_value);
     } else if (debug && op_type == BinaryOpType::mul &&
                is_integral(dst_type.dt)) {
       if (is_unsigned(dst_type.dt)) {
-        bin_value = generate_umul_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_umul_overflow(lhs_value, rhs_value, bin->get_tb());
       } else {
-        bin_value = generate_smul_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_smul_overflow(lhs_value, rhs_value, bin->get_tb());
       }
       bin_value = ir_->cast(dst_type, bin_value);
     }
@@ -1187,9 +1187,9 @@ class TaskCodegen : public IRVisitor {
     else if (debug && op_type == BinaryOpType::bit_shl) {
       if (is_unsigned(dst_type.dt)) {
-        bin_value = generate_ushl_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_ushl_overflow(lhs_value, rhs_value, bin->get_tb());
       } else {
-        bin_value = generate_sshl_overflow(lhs_value, rhs_value, bin->tb);
+        bin_value = generate_sshl_overflow(lhs_value, rhs_value, bin->get_tb());
       }
     }
     BINARY_OP_TO_SPIRV_BITWISE(bit_and, OpBitwiseAnd)
diff --git a/taichi/common/exceptions.h b/taichi/common/exceptions.h
index eb3570ffe23b97..23c656a3d92395 100644
--- a/taichi/common/exceptions.h
+++ b/taichi/common/exceptions.h
@@ -1,13 +1,50 @@
 #pragma once
 
+#include <string>
+#include <string_view>
+#include <exception>
+#include "taichi/common/logging.h"
+
 namespace taichi::lang {
 
 class IRModified {};
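The hunk below adds the two carriers this series is built on: `DebugInfo` (a traceback string plus a `Location`) and the `ErrorEmitter` entry point. A minimal usage sketch, assuming only the definitions added in this hunk; `report_bad_call` and the message text are illustrative, not taken from the patch:

```cpp
#include "taichi/common/exceptions.h"

using namespace taichi::lang;

// Hypothetical call site, sketching the contract set up below: ErrorEmitter
// throws TaichiError subclasses and logs TaichiWarning subclasses via emit().
// When the location argument is a DebugInfo pointer (a C++ API call coming
// from Python), the traceback prefix is deliberately skipped; for Stmt or
// Expression pointers, get_tb() is prepended to the message.
void report_bad_call(const DebugInfo *dbg) {
  ErrorEmitter(TaichiTypeError(), dbg,
               "unsupported operand type(s) for '+': 'f32' and 'i32'");
}
```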
+struct Location {
+  int line_number = 0;
+  std::string var_name = "";
+};
+
+struct DebugInfo {
+  Location src_loc;
+  std::string tb;
+
+  explicit DebugInfo() = default;
+
+  explicit DebugInfo(std::string tb_) : tb(tb_) {
+  }
+
+  explicit DebugInfo(const char *tb_) : tb(tb_) {
+  }
+
+  std::string const &get_tb() const {
+    return tb;
+  }
+
+  void set_tb(std::string const &tb) {
+    this->tb = tb;
+  }
+};
+
 class TaichiExceptionImpl : public std::exception {
+  friend struct ErrorEmitter;
+
+ protected:
   std::string msg_;
 
  public:
+  // Add default constructor to allow passing Exception to ErrorEmitter
+  // TODO: remove this and find a better way
+  explicit TaichiExceptionImpl() = default;
   explicit TaichiExceptionImpl(const std::string msg) : msg_(msg) {
   }
   const char *what() const noexcept override {
@@ -15,24 +52,106 @@ class TaichiExceptionImpl : public std::exception {
   }
 };
 
-class TaichiTypeError : public TaichiExceptionImpl {
+class TaichiError : public TaichiExceptionImpl {
   using TaichiExceptionImpl::TaichiExceptionImpl;
 };
 
-class TaichiSyntaxError : public TaichiExceptionImpl {
+class TaichiWarning : public TaichiExceptionImpl {
   using TaichiExceptionImpl::TaichiExceptionImpl;
+
+ protected:
+  static constexpr std::string_view name_;
+
+ public:
+  void emit() {
+    taichi::Logger::get_instance().warn(std::string(name_) + "\n" + msg_);
+  }
 };
 
-class TaichiIndexError : public TaichiExceptionImpl {
-  using TaichiExceptionImpl::TaichiExceptionImpl;
+class TaichiTypeError : public TaichiError {
+  using TaichiError::TaichiError;
 };
 
-class TaichiRuntimeError : public TaichiExceptionImpl {
-  using TaichiExceptionImpl::TaichiExceptionImpl;
+class TaichiSyntaxError : public TaichiError {
+  using TaichiError::TaichiError;
 };
 
-class TaichiAssertionError : public TaichiExceptionImpl {
-  using TaichiExceptionImpl::TaichiExceptionImpl;
+class TaichiIndexError : public TaichiError {
+  using TaichiError::TaichiError;
+};
+
+class TaichiRuntimeError : public TaichiError {
+  using TaichiError::TaichiError;
+};
+
+class TaichiAssertionError : public TaichiError {
+  using TaichiError::TaichiError;
+};
+
+class TaichiIrError : public TaichiError {
+  using TaichiError::TaichiError;
+};
+
+class TaichiCastWarning : public TaichiWarning {
+  using TaichiWarning::TaichiWarning;
+  static constexpr std::string_view name_ = "TaichiCastWarning";
+};
+
+class TaichiTypeWarning : public TaichiWarning {
+  using TaichiWarning::TaichiWarning;
+  static constexpr std::string_view name_ = "TaichiTypeWarning";
+};
+
+class TaichiIrWarning : public TaichiWarning {
+  using TaichiWarning::TaichiWarning;
+  static constexpr std::string_view name_ = "TaichiIrWarning";
+};
+
+class TaichiIndexWarning : public TaichiWarning {
+  using TaichiWarning::TaichiWarning;
+  static constexpr std::string_view name_ = "TaichiIndexWarning";
+};
+
+class TaichiRuntimeWarning : public TaichiWarning {
+  using TaichiWarning::TaichiWarning;
+  static constexpr std::string_view name_ = "TaichiRuntimeWarning";
+};
+
+struct ErrorEmitter {
+  ErrorEmitter() = delete;
+  ErrorEmitter(ErrorEmitter &) = delete;
+  ErrorEmitter(ErrorEmitter &&) = delete;
+
+  // Emit an error on stmt with error message
+  template <typename E,
+            typename = std::enable_if_t<
+                std::is_base_of_v<TaichiExceptionImpl, std::decay_t<E>>>,
+            // The expected type for T is `Stmt`, `Expression`, or `DebugInfo`.
+            // These types have a member function named get_tb() that returns
+            // trace back information as a `std::string`.
+            typename T,
+            typename = std::enable_if_t<std::is_same_v<
+                std::decay_t<decltype(std::declval<T>()->get_tb())>,
+                std::string>>>
+  ErrorEmitter(E &&error, T p_dbg_info, std::string &&error_msg) {
+    if constexpr ((std::is_same_v<std::decay_t<T>, DebugInfo *> ||
+                   std::is_same_v<std::decay_t<T>, const DebugInfo *>)&&std::
+                      is_base_of_v<TaichiError, std::decay_t<E>>) {
+      // Indicates a failed C++ API call from Python side, we should not print
+      // tb here
+      error.msg_ = error_msg;
+    } else {
+      error.msg_ = p_dbg_info->get_tb() + error_msg;
+    }
+
+    if constexpr (std::is_base_of_v<TaichiWarning, std::decay_t<E>>) {
+      error.emit();
+    } else if constexpr (std::is_base_of_v<TaichiError, std::decay_t<E>>) {
+      throw std::move(error);
+    } else {
+      TI_NOT_IMPLEMENTED;
+    }
+  }
+};
 
 }  // namespace taichi::lang
diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp
index 38ff682d0ef151..e81eaf776ffbe4 100644
--- a/taichi/ir/expr.cpp
+++ b/taichi/ir/expr.cpp
@@ -7,7 +7,11 @@ namespace taichi::lang {
 
 void Expr::set_tb(const std::string &tb) {
-  expr->tb = tb;
+  expr->set_tb(tb);
+}
+
+const std::string &Expr::get_tb() const {
+  return expr->get_tb();
 }
 
 DataType Expr::get_ret_type() const {
diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h
index 31a65ecd602587..7348ca9bc263b0 100644
--- a/taichi/ir/expr.h
+++ b/taichi/ir/expr.h
@@ -95,6 +95,8 @@ class Expr {
   // traceback for type checking error message
   void set_tb(const std::string &tb);
 
+  const std::string &get_tb() const;
+
   void set_adjoint(const Expr &o);
 
   void set_dual(const Expr &o);
diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h
index 918cb6a9f90329..2eca1890b1501a 100644
--- a/taichi/ir/expression.h
+++ b/taichi/ir/expression.h
@@ -15,7 +15,7 @@ class Expression {
   Stmt *stmt;
 
  public:
-  std::string tb;
+  DebugInfo dbg_info;
   std::map attributes;
   DataType ret_type;
@@ -59,6 +59,14 @@ class Expression {
   Stmt *get_flattened_stmt() const {
     return stmt;
   }
+
+  std::string const &get_tb() const {
+    return dbg_info.tb;
+  }
+
+  void set_tb(std::string const &tb) {
+    dbg_info.tb = tb;
+  }
 };
 
 class ExprGroup {
diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp
index 2e057d9e7275fa..99a4dd3fc88635 100644
--- a/taichi/ir/frontend_ir.cpp
+++ b/taichi/ir/frontend_ir.cpp
@@ -268,7 +268,7 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
     unary->cast_type = cast_type;
   }
   stmt = unary.get();
-  stmt->tb = tb;
+  stmt->set_tb(get_tb());
   stmt->ret_type = ret_type;
   ctx->push_back(std::move(unary));
 }
@@ -437,14 +437,14 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
     if_stmt->set_false_statements(std::move(false_block));
 
     auto ret = ctx->push_back<LocalLoadStmt>(result);
-    ret->tb = tb;
+    ret->set_tb(get_tb());
     stmt = ret;
     stmt->ret_type = ret_type;
     return;
   }
   auto rhs_stmt = flatten_rvalue(rhs, ctx);
   ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs_stmt, rhs_stmt));
-  ctx->stmts.back()->tb = tb;
+  ctx->stmts.back()->set_tb(get_tb());
   stmt = ctx->back_stmt();
   stmt->ret_type = ret_type;
 }
@@ -587,7 +587,7 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) {
     make_ifte(ctx, ret_type, op1, op2, op3);
   }
   stmt = ctx->back_stmt();
-  stmt->tb = tb;
+  stmt->set_tb(get_tb());
   stmt->ret_type = ret_type;
 }
@@ -602,7 +602,7 @@ void InternalFuncCallExpression::type_check(const CompileConfig *) {
 
 void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
   stmt = op->flatten(ctx, args, ret_type);
-  stmt->tb = tb;
+  stmt->set_tb(get_tb());
 }
 
 void ExternalTensorExpression::flatten(FlattenContext *ctx) {
@@ -614,7 +614,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
       Stmt::make<ArgLoadStmt>(arg_id, type,
                               /*is_ptr=*/true, /*create_load=*/false,
                               /*arg_depth=*/arg_depth);
-  ptr->tb = tb;
+  ptr->set_tb(get_tb());
ctx->push_back(std::move(ptr)); stmt = ctx->back_stmt(); } @@ -767,7 +767,7 @@ IndexExpression::IndexExpression(const Expr &var, const ExprGroup &indices, std::string tb) : var(var), indices_group({indices}) { - this->tb = tb; + this->set_tb(tb); } IndexExpression::IndexExpression(const Expr &var, @@ -780,7 +780,7 @@ IndexExpression::IndexExpression(const Expr &var, // axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where [3, 4] // means "m"-axis will return a TensorType with size of 2. In this case, we // should not expand indices_group due to its special semantics. - this->tb = tb; + this->set_tb(tb); } bool IndexExpression::is_field() const { @@ -914,13 +914,13 @@ void IndexExpression::flatten(FlattenContext *ctx) { } else if (is_tensor()) { stmt = make_tensor_access( ctx, var, indices_group, ret_type, - var->ret_type.ptr_removed()->as()->get_shape(), tb); + var->ret_type.ptr_removed()->as()->get_shape(), get_tb()); } else { throw TaichiTypeError( "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - stmt->tb = tb; + stmt->set_tb(get_tb()); } void RangeAssumptionExpression::type_check(const CompileConfig *) { @@ -1027,7 +1027,7 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { auto dest_stmt = flatten_lvalue(dest, ctx); stmt = ctx->push_back(op_type, dest_stmt, val_stmt); stmt->ret_type = stmt->as()->dest->ret_type; - stmt->tb = tb; + stmt->set_tb(get_tb()); } SNodeOpExpression::SNodeOpExpression(SNode *snode, @@ -1061,7 +1061,7 @@ void SNodeOpExpression::type_check(const CompileConfig *config) { auto promoted = promoted_type(dst_type, value_type); if (dst_type != promoted) { TI_WARN("Append may lose precision: {} <- {}\n{}", - dst_type->to_string(), value_type->to_string(), tb); + dst_type->to_string(), value_type->to_string(), get_tb()); } values[i] = cast(values[i], dst_type); values[i]->type_check(config); @@ -1078,7 +1078,7 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { snode->type != SNodeType::dynamic; auto ptr = ctx->push_back(snode, indices_stmt, true, is_cell_access); - ptr->tb = tb; + ptr->set_tb(get_tb()); if (op_type == SNodeOpType::is_active) { TI_ERROR_IF(snode->type != SNodeType::pointer && snode->type != SNodeType::hash && @@ -1091,17 +1091,17 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(SNodeOpType::get_addr, snode, ptr, nullptr); } else if (op_type == SNodeOpType::append) { auto alloca = ctx->push_back(PrimitiveType::i32); - alloca->set_tb(tb); + alloca->set_tb(get_tb()); auto addr = ctx->push_back(SNodeOpType::allocate, snode, ptr, alloca); - addr->set_tb(tb); + addr->set_tb(get_tb()); for (int i = 0; i < values.size(); i++) { auto value_stmt = flatten_rvalue(values[i], ctx); auto ch_addr = ctx->push_back(addr, snode, i); - ch_addr->set_tb(tb); - ctx->push_back(ch_addr, value_stmt)->set_tb(tb); + ch_addr->set_tb(get_tb()); + ctx->push_back(ch_addr, value_stmt)->set_tb(get_tb()); } - ctx->push_back(alloca)->set_tb(tb); + ctx->push_back(alloca)->set_tb(get_tb()); TI_ERROR_IF(snode->type != SNodeType::dynamic, "ti.append only works on dynamic nodes."); } @@ -1352,7 +1352,7 @@ void ASTBuilder::insert_assignment(Expr &lhs, lhs.set(rhs); } else if (lhs.expr->is_lvalue()) { auto stmt = std::make_unique(lhs, rhs); - stmt->tb = tb; + stmt->set_tb(tb); this->insert(std::move(stmt)); } else { @@ -1716,13 +1716,13 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { if (expr.is()) { id_expr = expr; } else { - id_expr = make_var(expr, expr->tb); + id_expr = 
make_var(expr, expr.get_tb()); } auto shape = tensor_type->get_shape(); if (shape.size() == 1) { for (int i = 0; i < shape[0]; i++) { auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i)), expr->tb)); + id_expr, ExprGroup(Expr(i)), expr.get_tb())); ind->type_check(nullptr); expanded_exprs.push_back(ind); } @@ -1731,7 +1731,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb)); + id_expr, ExprGroup(Expr(i), Expr(j)), expr.get_tb())); ind->type_check(nullptr); expanded_exprs.push_back(ind); } diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 6b76eda37839f7..969a5766e0a679 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -120,7 +120,7 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { instance_id = instance_id_counter++; id = instance_id; erased = stmt.erased; - tb = stmt.tb; + dbg_info = stmt.dbg_info; ret_type = stmt.ret_type; } diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 8fba4176a7667c..9ac7ab785192c9 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -399,8 +399,8 @@ class Stmt : public IRNode { Block *parent; bool erased; bool fields_registered; - std::string tb; DataType ret_type; + DebugInfo dbg_info; Stmt(); Stmt(const Stmt &stmt); @@ -440,6 +440,14 @@ class Stmt : public IRNode { return *operands[i]; } + TI_FORCE_INLINE std::string const &get_tb() const { + return dbg_info.tb; + } + + TI_FORCE_INLINE void set_tb(const std::string &tb) { + dbg_info.tb = tb; + } + std::vector get_operands() const; void set_operand(int i, Stmt *stmt); @@ -484,10 +492,6 @@ class Stmt : public IRNode { return make_typed(std::forward(args)...); } - void set_tb(const std::string &tb) { - this->tb = tb; - } - std::string type(); virtual std::unique_ptr clone() const { diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 1a3af91dd85fe1..1fd6274132437f 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -101,7 +101,7 @@ MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, const std::string &tb) { origin = origin_input; offset = offset_input; - this->tb = tb; + this->set_tb(tb); if (origin->is() || origin->is() || origin->is() || origin->is() || diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 739f9e6f9b7c08..9bb43f9f3913b5 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -314,7 +314,7 @@ Kernel &Program::get_snode_writer(SNode *snode) { std::vector{snode->num_active_indices}, snode->dt->get_compute_type()); argload_expr->type_check(&this->compile_config()); - builder.insert_assignment(expr, argload_expr, expr->tb); + builder.insert_assignment(expr, argload_expr, expr->get_tb()); }); ker.name = kernel_name; ker.is_accessor = true; diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 98715c57d37a43..b1e647f1ce5aec 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -382,7 +382,7 @@ class AlgSimp : public BasicStmtVisitor { Stmt::make(BinaryOpType::bit_shl, stmt->lhs, new_rhs); result->ret_type = stmt->ret_type; - result->set_tb(stmt->tb); + result->set_tb(stmt->get_tb()); stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); @@ -396,7 +396,7 @@ class AlgSimp : public BasicStmtVisitor { cast_to_result_type(a, stmt); auto sum = Stmt::make(BinaryOpType::add, a, a); sum->ret_type = 
a->ret_type; - sum->set_tb(stmt->tb); + sum->dbg_info = stmt->dbg_info; stmt->replace_usages_with(sum.get()); modifier.insert_before(stmt, std::move(sum)); modifier.erase(stmt); @@ -427,7 +427,7 @@ class AlgSimp : public BasicStmtVisitor { is_real(rhs->ret_type.get_element_type()) && stmt->op_type != BinaryOpType::floordiv) { if (alg_is_zero(rhs)) { - TI_WARN("Potential division by 0\n{}", stmt->tb); + TI_WARN("Potential division by 0\n{}", stmt->get_tb()); } else { // a / const -> a * (1 / const) Stmt *new_rhs = get_inverse(stmt); diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 878c88d317f925..b03e35b74b6935 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -1376,7 +1376,7 @@ class MakeAdjoint : public ADTransform { } else { TI_WARN("gradient of binary op {}\n{}", binary_op_type_name(bin->op_type), - bin->tb); + bin->get_tb()); TI_NOT_IMPLEMENTED; } } @@ -2012,7 +2012,7 @@ class MakeDual : public ADTransform { // do nothing } else { TI_WARN("gradient of binary op {}\n{}", binary_op_type_name(bin->op_type), - bin->tb); + bin->get_tb()); TI_NOT_IMPLEMENTED } } @@ -2455,7 +2455,7 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor { "(kernel={}) Breaks the global data access rule. Snode {} is " "overwritten unexpectedly.", kernel_name_, dest->snode->get_node_type_name()); - msg += "\n" + stmt->tb; + msg += "\n" + stmt->get_tb(); stmt->insert_before_me( Stmt::make(check_equal, msg, std::vector())); diff --git a/taichi/transforms/check_out_of_bound.cpp b/taichi/transforms/check_out_of_bound.cpp index 164100957b08ea..fe21ee2ed91fe8 100644 --- a/taichi/transforms/check_out_of_bound.cpp +++ b/taichi/transforms/check_out_of_bound.cpp @@ -90,7 +90,7 @@ class CheckOutOfBound : public BasicStmtVisitor { msg += ", "; msg += "%d"; } - msg += "]\n" + stmt->tb; + msg += "]\n" + stmt->get_tb(); new_stmts.push_back(result, msg, args); modifier.insert_before(stmt, std::move(new_stmts)); @@ -153,7 +153,7 @@ class CheckOutOfBound : public BasicStmtVisitor { msg += "%d"; } msg += ")"; - msg += "\n" + stmt->tb; + msg += "\n" + stmt->get_tb(); new_stmts.push_back(result, msg, args); modifier.insert_before(stmt, std::move(new_stmts)); @@ -198,7 +198,7 @@ class CheckOutOfBound : public BasicStmtVisitor { msg += ", "; msg += std::to_string(matrix_shape[i]); } - msg += "] matrix with index [%d]\n" + stmt->tb; + msg += "] matrix with index [%d]\n" + stmt->get_tb(); std::vector args = {index}; new_stmts.push_back(result, msg, args); @@ -219,7 +219,7 @@ class CheckOutOfBound : public BasicStmtVisitor { BinaryOpType::cmp_ge, stmt->rhs, compare_rhs.get()); compare->ret_type = PrimitiveType::i32; std::string msg = "Negative exponent in pow(int, int) is not allowed."; - msg += "\n" + stmt->tb; + msg += "\n" + stmt->get_tb(); auto assert_stmt = std::make_unique(compare.get(), msg, std::vector()); assert_stmt->accept(this); diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp index d92300ab0a1707..1520b3c3175df4 100644 --- a/taichi/transforms/demote_atomics.cpp +++ b/taichi/transforms/demote_atomics.cpp @@ -151,7 +151,7 @@ class DemoteAtomics : public BasicStmtVisitor { TI_WARN( "AtomicOp on QuantFloatType is not supported. 
" "Demoting to non-atomic RMW.\n{}", - stmt->tb); + stmt->get_tb()); demote = true; } } diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index 4a75b6683ac8d9..27eac4fc86c445 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -55,7 +55,7 @@ class FrontendTypeCheck : public IRVisitor { auto error = [&]() { throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", - stmt->tb, rhs_type->to_string(), + stmt->get_tb(), rhs_type->to_string(), lhs_type->to_string())); }; @@ -68,6 +68,8 @@ class FrontendTypeCheck : public IRVisitor { void visit(FrontendIfStmt *stmt) override { // TODO: use PrimitiveType::u1 when it's supported + std::cerr << fmt::format("[debug] stmt->dbg_info.tb {}\n", + stmt->dbg_info.tb); check_cond_type(stmt->condition, "if"); if (stmt->true_statements) stmt->true_statements->accept(this); @@ -109,7 +111,7 @@ class FrontendTypeCheck : public IRVisitor { if (unsupported_group.find(conversion) != std::string::npos) { throw TaichiTypeError(fmt::format("{}conversion '{}' is not supported.", - stmt->tb, conversion)); + stmt->get_tb(), conversion)); } if ((real_group.find(conversion) != std::string::npos && @@ -119,7 +121,7 @@ class FrontendTypeCheck : public IRVisitor { (unsigned_group.find(conversion) != std::string::npos && !(is_integral(data_type) && is_unsigned(data_type)))) { throw TaichiTypeError(fmt::format("{} '{}' doesn't match '{}'.", - stmt->tb, format_spec, + stmt->get_tb(), format_spec, data_type->to_string())); } } diff --git a/taichi/transforms/inlining.cpp b/taichi/transforms/inlining.cpp index 3c30f5859a749c..7be805e3ec0741 100644 --- a/taichi/transforms/inlining.cpp +++ b/taichi/transforms/inlining.cpp @@ -44,7 +44,7 @@ class Inliner : public BasicStmtVisitor { TI_WARN( "Multiple returns in function \"{}\" may not be handled " "properly.\n{}", - func->get_name(), stmt->tb); + func->get_name(), stmt->get_tb()); } // Use a local variable to store the return value auto *return_address = inlined_ir->as()->insert( diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 5d615bb42f3bea..b2f03e45e9fd9f 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -440,7 +440,7 @@ class LowerAST : public IRVisitor { dest.cast()->is_ptr); fctx.push_back(dest_stmt, expr_stmt); } - fctx.stmts.back()->set_tb(assign->tb); + fctx.stmts.back()->dbg_info = assign->dbg_info; assign->parent->replace_with(assign, std::move(fctx.stmts)); } diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index a9bea82ed074a4..e4deb59192e840 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -532,7 +532,7 @@ class LowerMatrixPtr : public BasicStmtVisitor { TI_ASSERT_INFO( origin->dynamic_indexable, "Element of the MatrixField is not dynamic indexable.\n{}", - stmt->tb); + stmt->get_tb()); auto stride = std::make_unique( TypedConstant(origin->dynamic_index_stride)); auto offset = std::make_unique( diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 4f860974d1f46d..ada6e692689d87 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -240,7 +240,7 @@ class Offloader { TI_WARN( "Specified block dim {} is bigger than SNode element size {}. 
" "Clipping.\n{}", - for_stmt->block_dim, snode_num_elements, for_stmt->tb); + for_stmt->block_dim, snode_num_elements, for_stmt->get_tb()); offloaded_struct_for->block_dim = snode_num_elements; } else { offloaded_struct_for->block_dim = for_stmt->block_dim; diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 1f5ea1e43aec99..415193cc0a1225 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -184,7 +184,8 @@ class BasicBlockSimplify : public IRVisitor { auto check_sum = Stmt::make(BinaryOpType::cmp_ge, sum.get(), zero.get()); auto assert = Stmt::make( - check_sum.get(), "The indices provided are too big!\n" + stmt->tb, + check_sum.get(), + "The indices provided are too big!\n" + stmt->get_tb(), std::vector()); // Because Taichi's assertion is checked only after the execution of the // kernel, when the linear index overflows and goes negative, we have to diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 28d439ce2e141c..35af0a04018fac 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -34,7 +34,7 @@ class TypeCheck : public IRVisitor { if (dst_type != promoted) { TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), stmt_name, dst_type->to_string(), val->ret_data_type_name(), - stmt->tb); + stmt->get_tb()); } val = insert_type_cast_before(stmt, val, dst_type); } @@ -145,7 +145,7 @@ class TypeCheck : public IRVisitor { TypeFactory::get_instance().get_pointer_type(stmt->snode->dt); } else TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(), - stmt->tb); + stmt->get_tb()); auto check_indices = [&](SNode *snode) { if (snode->num_active_indices != stmt->indices.size()) { TI_ERROR("[{}] {} has {} indices. 
Indexed with {}.", stmt->name(), @@ -159,7 +159,7 @@ class TypeCheck : public IRVisitor { TI_WARN( "[{}] Field index {} not int32, casting into int32 " "implicitly\n{}", - stmt->name(), i, stmt->tb); + stmt->name(), i, stmt->get_tb()); stmt->indices[i] = insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); } @@ -275,7 +275,7 @@ class TypeCheck : public IRVisitor { std::string msg = "Detected overflow for bit_shift_op with rhs = %d, exceeding limit of " "%d."; - msg += "\n" + stmt->tb; + msg += "\n" + stmt->get_tb(); std::vector args = {rhs, const_stmt.get()}; auto assert_stmt = Stmt::make(cond_stmt.get(), msg, std::move(args)); @@ -302,9 +302,9 @@ class TypeCheck : public IRVisitor { if (comment == "") { TI_WARN("[{}] Type mismatch (left = {}, right = {}, stmt_id = {})\n{}", stmt->name(), stmt->lhs->ret_data_type_name(), - stmt->rhs->ret_data_type_name(), stmt->id, stmt->tb); + stmt->rhs->ret_data_type_name(), stmt->id, stmt->get_tb()); } else { - TI_WARN("[{}] {}\n{}", stmt->name(), comment, stmt->tb); + TI_WARN("[{}] {}\n{}", stmt->name(), comment, stmt->get_tb()); } TI_WARN("Compilation stopped due to type mismatch."); throw std::runtime_error("Binary operator type mismatch"); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index f9f791e37dbe3f..4c7d7e75b297c7 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -32,7 +32,8 @@ TEST(FrontendTypeInference, Id) { auto kernel = std::make_unique(*prog, func, "fake_kernel"); auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); - auto id_i32 = kernel->context->builder().make_var(const_i32, const_i32->tb); + auto id_i32 = + kernel->context->builder().make_var(const_i32, const_i32->get_tb()); EXPECT_EQ(id_i32->ret_type, DataType(TypeFactory::get_instance().get_pointer_type( PrimitiveType::i32))); From 4fb52987875f03137f56553d2b80c7f3b8fef169 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=94=E6=B3=95=E5=B0=91=E5=A5=B3=E8=B5=B5=E5=BF=97?= =?UTF-8?q?=E8=BE=89?= Date: Tue, 18 Jul 2023 13:13:08 +0800 Subject: [PATCH 2/7] [Lang] Use ErrorEmitter in type check passes (#8285) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: # ### Brief Summary ### 🤖 Generated by Copilot at 79478ec This pull request introduces a new exception class `TaichiWarning` and its subclasses to handle warnings during compilation, and refactors the error and warning handling in various files using the `ErrorEmitter` class and `fmt::format`. It also adds source location and traceback information to statements and expressions for better error reporting, and improves the type checking logic for the frontend and backend. 
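The refactor is mechanical and repeats the same pattern across the passes; condensed from the hunks that follow (not a literal excerpt), a representative call site changes like this:

```cpp
// Before: every call site prepends the traceback by hand and throws directly.
throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", stmt->tb,
                                  rhs_type->to_string(),
                                  lhs_type->to_string()));

// After: ErrorEmitter prepends stmt->get_tb() itself and decides whether to
// throw (TaichiError subclasses) or log (TaichiWarning subclasses), so call
// sites only format the message body.
ErrorEmitter(TaichiTypeError(), stmt,
             fmt::format("cannot assign '{}' to '{}'", rhs_type->to_string(),
                         lhs_type->to_string()));
```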
### Walkthrough ### 🤖 Generated by Copilot at 79478ec * Add new classes and functions for error and warning reporting and handling in `exceptions.h` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L3-R20), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L18-R134)) * Add new structs for source location and debug information tracking in `ir.h` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-61484fa2a50e309478017fb2a436198aa4b0afdf72a4039bf574fc4f2aedbe4eR390-R407), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-61484fa2a50e309478017fb2a436198aa4b0afdf72a4039bf574fc4f2aedbe4eR422)) * Add new field for source location to the `Stmt` class in `ir.h` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-61484fa2a50e309478017fb2a436198aa4b0afdf72a4039bf574fc4f2aedbe4eR422)) * Replace `throw` statements with `ErrorEmitter` calls to emit error messages with expression and traceback information in `frontend_ir.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L199-R205), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L207-R215), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L221-R231), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L246-R258), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L285-R296), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L293-R307), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L824-R838), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L840-R855), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L864-R880), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L880-R903), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L896-R917), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L918-R937), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L932-R954), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L952-R972), 
[link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L975-R1000), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1212-R1254), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1224-R1268), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1241-R1285), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1343-R1388)) * Add type checking and warning logic for implicit casting in `frontend_ir.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35R1032-R1043), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1062-R1099), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1104-R1142)) * Add type checking and error logic for snode and texture operations in `frontend_ir.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1082-R1122), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1121-R1167), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1145-R1184), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1161-R1203), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1177-R1219), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L1187-R1229)) * Move header files from `frontend_ir.cpp` to `exceptions.h` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L12-R16), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L3-R20)) * Add an example function for using `ErrorEmitter` in `frontend_ir.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L173-R177)) * Modify `FrontendTypeCheck` class to use `ErrorEmitter` for error and warning messages in `frontend_type_check.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-540161ef0e79b2ecf9eabafa3601b58dcd92fcd69583b475dbb80e7d29f84506L9-R18), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-540161ef0e79b2ecf9eabafa3601b58dcd92fcd69583b475dbb80e7d29f84506L45-R76), 
[link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-540161ef0e79b2ecf9eabafa3601b58dcd92fcd69583b475dbb80e7d29f84506L53-R106), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-540161ef0e79b2ecf9eabafa3601b58dcd92fcd69583b475dbb80e7d29f84506L111-R153), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-540161ef0e79b2ecf9eabafa3601b58dcd92fcd69583b475dbb80e7d29f84506L121-R164)) * Modify `TypeCheck` class to use `ErrorEmitter` for error and warning messages in `type_check.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL147-R143), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL159-R152), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL278), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL301-R295)) * Remove `stmt_name` parameter and `stmt->tb` from error and warning messages in `type_check.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL21-R21), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL33-L38), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL83-R74), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL108-R97), [link](https://github.com/taichi-dev/taichi/pull/8285/files?diff=unified&w=0#diff-dd572dab7be4dbb5edc1043d6d6339b931ef35198b8657761ebf45a83e76ac2bL176-R166)) --- taichi/ir/frontend_ir.cpp | 225 ++++++++++++++-------- taichi/transforms/frontend_type_check.cpp | 85 +++++--- taichi/transforms/type_check.cpp | 56 ++---- 3 files changed, 228 insertions(+), 138 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 99a4dd3fc88635..31f363b8f44cf5 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -9,10 +9,15 @@ namespace taichi::lang { -#define TI_ASSERT_TYPE_CHECKED(x) \ - TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \ - "[{}] was not type-checked", \ - ExpressionHumanFriendlyPrinter::expr_to_string(x)) +#define TI_ASSERT_TYPE_CHECKED(x) \ + do { \ + if (x->ret_type == PrimitiveType::unknown) { \ + ErrorEmitter( \ + TaichiTypeError(), x.expr.get(), \ + fmt::format("[{}] was not type-checked", \ + ExpressionHumanFriendlyPrinter::expr_to_string(x))); \ + } \ + } while (false) static bool is_primitive_or_tensor_type(DataType &type) { return type->is() || type->is(); @@ -170,8 +175,11 @@ void TexturePtrExpression::flatten(FlattenContext *ctx) { } void RandExpression::type_check(const CompileConfig *) { - TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, - "Invalid dt [{}] for RandExpression", dt->to_string()); + if (!(dt->is() && dt != PrimitiveType::unknown)) { + ErrorEmitter( + TaichiTypeError(), this, + fmt::format("Invalid dt [{}] for RandExpression", dt->to_string())); + } ret_type = dt; } @@ -196,17 +204,20 @@ void 
UnaryOpExpression::type_check(const CompileConfig *config) { auto ret_primitive_type = ret_type; if (!operand_primitive_type->is()) { - throw TaichiTypeError(fmt::format( - "unsupported operand type(s) for '{}': '{}'", unary_op_type_name(type), - operand_primitive_type->to_string())); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("unsupported operand type(s) for '{}': '{}'", + unary_op_type_name(type), + operand_primitive_type->to_string())); } if ((type == UnaryOpType::round || type == UnaryOpType::floor || type == UnaryOpType::ceil || is_trigonometric(type)) && !is_real(operand_primitive_type)) - throw TaichiTypeError(fmt::format( - "'{}' takes real inputs only, however '{}' is provided", - unary_op_type_name(type), operand_primitive_type->to_string())); + ErrorEmitter( + TaichiTypeError(), this, + fmt::format("'{}' takes real inputs only, however '{}' is provided", + unary_op_type_name(type), + operand_primitive_type->to_string())); if ((type == UnaryOpType::sqrt || type == UnaryOpType::exp || type == UnaryOpType::log) && @@ -218,9 +229,11 @@ void UnaryOpExpression::type_check(const CompileConfig *config) { if ((type == UnaryOpType::bit_not || type == UnaryOpType::logic_not) && is_real(operand_primitive_type)) { - throw TaichiTypeError(fmt::format( - "'{}' takes integral inputs only, however '{}' is provided", - unary_op_type_name(type), operand_primitive_type->to_string())); + ErrorEmitter( + TaichiTypeError(), this, + fmt::format("'{}' takes integral inputs only, however '{}' is provided", + unary_op_type_name(type), + operand_primitive_type->to_string())); } if (type == UnaryOpType::logic_not) { @@ -243,9 +256,11 @@ void UnaryOpExpression::type_check(const CompileConfig *config) { } if (type == UnaryOpType::popcnt && is_real(operand_primitive_type)) { - throw TaichiTypeError(fmt::format( - "'{}' takes integral inputs only, however '{}' is provided", - unary_op_type_name(type), operand_primitive_type->to_string())); + ErrorEmitter( + TaichiTypeError(), this, + fmt::format("'{}' takes integral inputs only, however '{}' is provided", + unary_op_type_name(type), + operand_primitive_type->to_string())); } if (operand_type->is()) { @@ -282,7 +297,8 @@ Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) { // Only tensor shape will be checked here, since the dtype will // be promoted later at irpass::type_check() if (elt_type.get_shape() != dt.get_shape()) { - TI_ERROR("Cannot broadcast tensor to tensor"); + ErrorEmitter(TaichiTypeError(), elt.expr.get(), + "Cannot broadcast tensor to tensor"); } else { return elt; } @@ -290,9 +306,12 @@ Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) { auto tensor_type = dt->as(); auto tensor_elt_type = tensor_type->get_element_type(); - TI_ASSERT_INFO(tensor_elt_type->is(), - "Only primitive types are supported in Tensors, got {}", - tensor_elt_type->to_string()); + if (!tensor_elt_type->is()) { + ErrorEmitter( + TaichiTypeError(), elt.expr.get(), + fmt::format("Only primitive types are supported in Tensors, got {}", + tensor_elt_type->to_string())); + } std::vector broadcast_values(tensor_type->get_num_elements(), elt); auto matrix_expr = Expr::make( broadcast_values, tensor_type->get_shape(), elt_type); @@ -523,7 +542,8 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { auto op3_type = op3.get_rvalue_type(); auto error = [&]() { - throw TaichiTypeError( + ErrorEmitter( + TaichiTypeError(), this, fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'", ternary_type_name(type), 
op1_type->to_string(), op2_type->to_string(), op3_type->to_string())); @@ -822,7 +842,8 @@ static void field_validation(FieldExpression *field_expr, int index_dim) { int field_dim = field_expr->snode->num_active_indices; if (field_dim != index_dim) { - throw TaichiIndexError( + ErrorEmitter( + TaichiIndexError(), field_expr, fmt::format("Field with dim {} accessed with indices of dim {}", field_dim, index_dim)); } @@ -838,7 +859,10 @@ void IndexExpression::type_check(const CompileConfig *) { bool has_slice = !ret_shape.empty(); auto var_type = var.get_rvalue_type(); if (has_slice) { - TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices"); + if (!is_tensor()) { + ErrorEmitter(TaichiTypeError(), this, + "Slice or swizzle can only apply on matrices"); + } auto element_type = var_type->as()->get_element_type(); ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); } else if (is_field()) { // field @@ -862,7 +886,8 @@ void IndexExpression::type_check(const CompileConfig *) { int element_dim = external_tensor_expr->dt.get_shape().size(); int total_dim = ndim + element_dim; if (total_dim != index_dim + element_dim) { - throw TaichiTypeError( + ErrorEmitter( + TaichiIndexError(), this, fmt::format("Array with dim {} accessed with indices of dim {}", total_dim - element_dim, index_dim)); } @@ -878,12 +903,14 @@ void IndexExpression::type_check(const CompileConfig *) { auto tensor_type = var_type->as(); auto shape = tensor_type->get_shape(); if (indices_group[0].size() != shape.size()) { - TI_ERROR("Expected {} indices, got {}.", shape.size(), - indices_group[0].size()); + ErrorEmitter(TaichiIndexError(), this, + fmt::format("Expected {} indices, got {}.", shape.size(), + indices_group[0].size())); } ret_type = tensor_type->get_element_type(); } else { - throw TaichiTypeError( + ErrorEmitter( + TaichiIndexError(), this, "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } @@ -894,10 +921,10 @@ void IndexExpression::type_check(const CompileConfig *) { TI_ASSERT_TYPE_CHECKED(expr); auto expr_type = expr.get_rvalue_type(); if (!is_integral(expr_type)) - throw TaichiTypeError( - fmt::format("indices must be integers, however '{}' is " - "provided as index {}", - expr_type->to_string(), i)); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("indices must be integers, however '{}' is " + "provided as index {}", + expr_type->to_string(), i)); } } } @@ -916,7 +943,8 @@ void IndexExpression::flatten(FlattenContext *ctx) { ctx, var, indices_group, ret_type, var->ret_type.ptr_removed()->as()->get_shape(), get_tb()); } else { - throw TaichiTypeError( + ErrorEmitter( + TaichiIndexError(), this, "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } @@ -930,10 +958,10 @@ void RangeAssumptionExpression::type_check(const CompileConfig *) { auto base_type = base.get_rvalue_type(); if (!input_type->is() || !base_type->is() || input_type != base_type) - throw TaichiTypeError( - fmt::format("unsupported operand type(s) for " - "'range_assumption': '{}' and '{}'", - input_type->to_string(), base_type->to_string())); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("unsupported operand type(s) for " + "'range_assumption': '{}' and '{}'", + input_type->to_string(), base_type->to_string())); ret_type = input_type; } @@ -950,7 +978,8 @@ void LoopUniqueExpression::type_check(const CompileConfig *) { auto input_type = input.get_rvalue_type(); if (!input_type->is()) - throw TaichiTypeError( + ErrorEmitter( + 
TaichiTypeError(), this, fmt::format("unsupported operand type(s) for 'loop_unique': '{}'", input_type->to_string())); ret_type = input_type; @@ -973,10 +1002,12 @@ void AtomicOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(dest); TI_ASSERT_TYPE_CHECKED(val); auto error = [&]() { - throw TaichiTypeError(fmt::format( - "unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", - atomic_op_type_name(op_type), dest->ret_type->to_string(), - val->ret_type->to_string())); + ErrorEmitter( + TaichiTypeError(), this, + fmt::format( + "unsupported operand type(s) for 'atomic_{}': '{}' and '{}'", + atomic_op_type_name(op_type), dest->ret_type->to_string(), + val->ret_type->to_string())); }; // Broadcast val to dest if neccessary @@ -1008,6 +1039,18 @@ void AtomicOpExpression::type_check(const CompileConfig *config) { } else { error(); } + + auto const &ret_element_type = ret_type.get_element_type(); + if (ret_element_type != val_dtype) { + auto promoted = promoted_type(ret_element_type, val_dtype); + if (ret_element_type != promoted) { + ErrorEmitter( + TaichiCastWarning(), this, + fmt::format("Atomic {} may lose precision: {} <- {}", + atomic_op_type_name(op_type), + ret_element_type->to_string(), val_dtype->to_string())); + } + } } void AtomicOpExpression::flatten(FlattenContext *ctx) { @@ -1060,8 +1103,10 @@ void SNodeOpExpression::type_check(const CompileConfig *config) { auto value_type = values[i].get_rvalue_type(); auto promoted = promoted_type(dst_type, value_type); if (dst_type != promoted) { - TI_WARN("Append may lose precision: {} <- {}\n{}", - dst_type->to_string(), value_type->to_string(), get_tb()); + ErrorEmitter(TaichiCastWarning(), this, + fmt::format("Append may lose precision: {} <- {}\n{}", + dst_type->to_string(), value_type->to_string(), + get_tb())); } values[i] = cast(values[i], dst_type); values[i]->type_check(config); @@ -1080,10 +1125,12 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(snode, indices_stmt, true, is_cell_access); ptr->set_tb(get_tb()); if (op_type == SNodeOpType::is_active) { - TI_ERROR_IF(snode->type != SNodeType::pointer && - snode->type != SNodeType::hash && - snode->type != SNodeType::bitmasked, - "ti.is_active only works on pointer, hash or bitmasked nodes."); + if (!(snode->type == SNodeType::pointer || snode->type == SNodeType::hash || + snode->type == SNodeType::bitmasked)) { + ErrorEmitter( + TaichiTypeError(), this, + "ti.is_active only works on pointer, hash or bitmasked nodes."); + } ctx->push_back(SNodeOpType::is_active, snode, ptr, nullptr); } else if (op_type == SNodeOpType::length) { ctx->push_back(SNodeOpType::length, snode, ptr, nullptr); @@ -1102,8 +1149,10 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { ctx->push_back(ch_addr, value_stmt)->set_tb(get_tb()); } ctx->push_back(alloca)->set_tb(get_tb()); - TI_ERROR_IF(snode->type != SNodeType::dynamic, - "ti.append only works on dynamic nodes."); + if (snode->type != SNodeType::dynamic) { + ErrorEmitter(TaichiTypeError(), this, + "ti.append only works on dynamic nodes."); + } } stmt = ctx->back_stmt(); } @@ -1119,15 +1168,18 @@ void TextureOpExpression::type_check(const CompileConfig *config) { auto ptr = texture_ptr.cast(); if (op == TextureOpType::kSampleLod) { // UV, Lod - TI_ASSERT_INFO(args.size() == ptr->num_dims + 1, - "Invalid number of args for sample_lod Texture op with a " - "{}-dimension texture", - ptr->num_dims); + if (args.size() != ptr->num_dims + 1) { + ErrorEmitter(TaichiTypeError(), this, + 
fmt::format("Invalid number of args for sample_lod Texture " + "op with a {}-dimension texture", + ptr->num_dims)); + } for (int i = 0; i < ptr->num_dims; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); auto arg_type = args[i].get_rvalue_type(); if (arg_type != PrimitiveType::f32) { - throw TaichiTypeError( + ErrorEmitter( + TaichiTypeError(), this, fmt::format("Invalid type for texture sample_lod: '{}', all " "arguments must be f32", arg_type->to_string())); @@ -1143,7 +1195,8 @@ void TextureOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(args[i]); auto arg_type = args[i].get_rvalue_type(); if (arg_type != PrimitiveType::i32) { - throw TaichiTypeError( + ErrorEmitter( + TaichiTypeError(), this, fmt::format("Invalid type for texture fetch_texel: '{}', all " "arguments must be i32", arg_type->to_string())); @@ -1159,10 +1212,10 @@ void TextureOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(args[i]); auto arg_type = args[i].get_rvalue_type(); if (arg_type != PrimitiveType::i32) { - throw TaichiTypeError( - fmt::format("Invalid type for texture load: '{}', all " - "arguments must be i32", - arg_type->to_string())); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("Invalid type for texture load: '{}', all " + "arguments must be i32", + arg_type->to_string())); } } } else if (op == TextureOpType::kStore) { @@ -1175,20 +1228,20 @@ void TextureOpExpression::type_check(const CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(args[i]); auto arg_type = args[i].get_rvalue_type(); if (arg_type != PrimitiveType::i32) { - throw TaichiTypeError( - fmt::format("Invalid type for texture load: '{}', index " - "arguments must be i32", - arg_type->to_string())); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("Invalid type for texture load: '{}', index " + "arguments must be i32", + arg_type->to_string())); } } for (int i = ptr->num_dims; i < ptr->num_dims + 4; i++) { TI_ASSERT_TYPE_CHECKED(args[i]); auto arg_type = args[i].get_rvalue_type(); if (arg_type != PrimitiveType::f32) { - throw TaichiTypeError( - fmt::format("Invalid type for texture load: '{}', value " - "arguments must be f32", - arg_type->to_string())); + ErrorEmitter(TaichiTypeError(), this, + fmt::format("Invalid type for texture load: '{}', value " + "arguments must be f32", + arg_type->to_string())); } } } else { @@ -1210,9 +1263,11 @@ void TextureOpExpression::flatten(FlattenContext *ctx) { } void ConstExpression::type_check(const CompileConfig *) { - TI_ASSERT_INFO( - val.dt->is() && val.dt != PrimitiveType::unknown, - "Invalid dt [{}] for ConstExpression", val.dt->to_string()); + if (!(val.dt->is() && val.dt != PrimitiveType::unknown)) { + ErrorEmitter(TaichiTypeError(), this, + fmt::format("Invalid dt [{}] for ConstExpression", + val.dt->to_string())); + } ret_type = val.dt; } @@ -1222,10 +1277,13 @@ void ConstExpression::flatten(FlattenContext *ctx) { } void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) { - TI_ASSERT_INFO( - ptr.is() || ptr.is(), - "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression", - ExpressionHumanFriendlyPrinter::expr_to_string(ptr)); + if (!(ptr.is() || ptr.is())) { + ErrorEmitter( + TaichiTypeError(), this, + fmt::format( + "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression", + ExpressionHumanFriendlyPrinter::expr_to_string(ptr))); + } ret_type = PrimitiveType::i32; } @@ -1254,9 +1312,12 @@ void ExternalTensorBasePtrExpression::flatten(FlattenContext *ctx) { void GetElementExpression::type_check(const 
CompileConfig *config) { TI_ASSERT_TYPE_CHECKED(src); auto src_type = src->ret_type; - TI_ASSERT_INFO(src_type->is(), - "Invalid src [{}] for GetElementExpression", - ExpressionHumanFriendlyPrinter::expr_to_string(src)); + if (!src_type->is()) { + ErrorEmitter( + TaichiTypeError(), this, + fmt::format("Invalid src [{}] for GetElementExpression", + ExpressionHumanFriendlyPrinter::expr_to_string(src))); + } ret_type = src_type.ptr_removed()->as()->get_element_type(index); } @@ -1356,8 +1417,10 @@ void ASTBuilder::insert_assignment(Expr &lhs, this->insert(std::move(stmt)); } else { - TI_ERROR("Cannot assign to non-lvalue: {}", - ExpressionHumanFriendlyPrinter::expr_to_string(lhs)); + ErrorEmitter( + TaichiRuntimeError(), lhs.expr.get(), + fmt::format("Cannot assign to non-lvalue: {}", + ExpressionHumanFriendlyPrinter::expr_to_string(lhs))); } } diff --git a/taichi/transforms/frontend_type_check.cpp b/taichi/transforms/frontend_type_check.cpp index 27eac4fc86c445..6db23529efb7cf 100644 --- a/taichi/transforms/frontend_type_check.cpp +++ b/taichi/transforms/frontend_type_check.cpp @@ -6,13 +6,16 @@ namespace taichi::lang { class FrontendTypeCheck : public IRVisitor { - void check_cond_type(const Expr &cond, std::string stmt_name) { - auto cond_type = cond.get_rvalue_type(); - if (!cond_type->is() || !is_integral(cond_type)) - throw TaichiTypeError(fmt::format( - "`{0}` conditions must be an integer; found {1}. Consider using " - "`{0} x != 0` instead of `{0} x` for float values.", - stmt_name, cond_type->to_string())); + void check_cond_type(const Expr &cond, const std::string &stmt_name) { + DataType cond_type = cond.get_rvalue_type(); + if (!cond_type->is() || !is_integral(cond_type)) { + ErrorEmitter( + TaichiTypeError(), cond, + fmt::format("`{0}` conditions must be an integer; found {1}. " + "Consider using " + "`{0} x != 0` instead of `{0} x` for float values.", + stmt_name, cond_type->to_string())); + } } public: @@ -42,7 +45,35 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendSNodeOpStmt *stmt) override { - // Noop + if (!stmt->ret_type.ptr_removed().get_element_type()->is_primitive( + PrimitiveTypeID::unknown)) { + // pass + } else if (stmt->snode) { + stmt->ret_type = + TypeFactory::get_instance().get_pointer_type(stmt->snode->dt); + } else + ErrorEmitter(TaichiTypeWarning(), stmt, + "Type inference failed: snode is nullptr."); + auto check_indices = [&](SNode *snode) { + if (snode->num_active_indices != stmt->indices.size()) { + ErrorEmitter( + TaichiRuntimeError(), stmt, + fmt::format("{} has {} indices. Indexed with {}.", + snode->node_type_name, snode->num_active_indices, + stmt->indices.size())); + } + }; + auto is_cell_access = SNodeOpStmt::activation_related(stmt->op_type) && + stmt->snode->type != SNodeType::dynamic; + check_indices(is_cell_access ? 
stmt->snode : stmt->snode->parent); + for (int i = 0; i < stmt->indices.size(); i++) { + if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) { + ErrorEmitter( + TaichiCastWarning(), stmt, + fmt::format( + "Field index {} not int32, casting into int32 implicitly", i)); + } + } } void visit(FrontendAssertStmt *stmt) override { @@ -50,19 +81,28 @@ class FrontendTypeCheck : public IRVisitor { } void visit(FrontendAssignStmt *stmt) override { - auto lhs_type = stmt->lhs->ret_type.ptr_removed(); - auto rhs_type = stmt->rhs->ret_type.ptr_removed(); - - auto error = [&]() { - throw TaichiTypeError(fmt::format("{}cannot assign '{}' to '{}'", - stmt->get_tb(), rhs_type->to_string(), - lhs_type->to_string())); - }; + auto const &lhs_type = stmt->lhs->ret_type.ptr_removed(); + auto const &rhs_type = stmt->rhs->ret_type.ptr_removed(); // No implicit cast at frontend for now if (is_tensor(lhs_type) && is_tensor(rhs_type) && lhs_type.get_shape() != rhs_type.get_shape()) { - error(); + ErrorEmitter(TaichiTypeError(), stmt, + fmt::format("cannot assign '{}' to '{}'", + rhs_type->to_string(), lhs_type->to_string())); + } + + auto const &lhs_element_type = lhs_type.get_element_type(); + auto const &rhs_element_type = rhs_type.get_element_type(); + + if (lhs_element_type != rhs_element_type) { + auto promoted = promoted_type(lhs_element_type, rhs_element_type); + if (lhs_element_type != promoted) { + ErrorEmitter(TaichiCastWarning(), stmt, + fmt::format("Assign may lose precision: {} <- {}", + lhs_element_type->to_string(), + rhs_element_type->to_string())); + } } } @@ -110,8 +150,9 @@ class FrontendTypeCheck : public IRVisitor { constexpr std::string_view real_group = "fFeEaAgG"; if (unsupported_group.find(conversion) != std::string::npos) { - throw TaichiTypeError(fmt::format("{}conversion '{}' is not supported.", - stmt->get_tb(), conversion)); + ErrorEmitter( + TaichiTypeError(), stmt, + fmt::format("conversion '{}' is not supported.", conversion)); } if ((real_group.find(conversion) != std::string::npos && @@ -120,9 +161,9 @@ class FrontendTypeCheck : public IRVisitor { !(is_integral(data_type) && is_signed(data_type))) || (unsigned_group.find(conversion) != std::string::npos && !(is_integral(data_type) && is_unsigned(data_type)))) { - throw TaichiTypeError(fmt::format("{} '{}' doesn't match '{}'.", - stmt->get_tb(), format_spec, - data_type->to_string())); + ErrorEmitter(TaichiTypeError(), stmt, + fmt::format("'{}' doesn't match '{}'.", format_spec, + data_type->to_string())); } } } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 35af0a04018fac..1ccc1f091a9559 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -18,10 +18,7 @@ class TypeCheck : public IRVisitor { private: CompileConfig config_; - Type *type_check_store(Stmt *stmt, - Stmt *dst, - Stmt *&val, - const std::string &stmt_name) { + Type *type_check_store(Stmt *stmt, Stmt *dst, Stmt *&val) { auto dst_type = dst->ret_type.ptr_removed(); auto val_type = val->ret_type.ptr_removed(); if (is_quant(dst_type)) { @@ -30,12 +27,6 @@ class TypeCheck : public IRVisitor { dst_type = dst_type->get_compute_type(); } if (dst_type != val_type) { - auto promoted = promoted_type(dst_type, val_type); - if (dst_type != promoted) { - TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(), - stmt_name, dst_type->to_string(), val->ret_data_type_name(), - stmt->get_tb()); - } val = insert_type_cast_before(stmt, val, dst_type); } return dst_type; @@ -80,9 +71,7 @@ 
class TypeCheck : public IRVisitor { void visit(AtomicOpStmt *stmt) override { // TODO(type): test_ad_for fails if we assume dest is a pointer type. - stmt->ret_type = type_check_store( - stmt, stmt->dest, stmt->val, - fmt::format("Atomic {}", atomic_op_type_name(stmt->op_type))); + stmt->ret_type = type_check_store(stmt, stmt->dest, stmt->val); } void visit(LocalLoadStmt *stmt) override { @@ -105,8 +94,7 @@ class TypeCheck : public IRVisitor { // Infer data type for alloca stmt->dest->ret_type = stmt->val->ret_type; } - stmt->ret_type = - type_check_store(stmt, stmt->dest, stmt->val, "Local store"); + stmt->ret_type = type_check_store(stmt, stmt->dest, stmt->val); } void visit(GlobalLoadStmt *stmt) override { @@ -144,22 +132,24 @@ class TypeCheck : public IRVisitor { stmt->ret_type = TypeFactory::get_instance().get_pointer_type(stmt->snode->dt); } else - TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(), - stmt->get_tb()); + ErrorEmitter(TaichiTypeWarning(), stmt, + "Type inference failed: snode is nullptr."); auto check_indices = [&](SNode *snode) { if (snode->num_active_indices != stmt->indices.size()) { - TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(), - snode->node_type_name, snode->num_active_indices, - stmt->indices.size()); + ErrorEmitter( + TaichiRuntimeError(), stmt, + fmt::format("{} has {} indices. Indexed with {}.", + snode->node_type_name, snode->num_active_indices, + stmt->indices.size())); } }; check_indices(stmt->is_cell_access ? stmt->snode : stmt->snode->parent); for (int i = 0; i < stmt->indices.size(); i++) { if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) { - TI_WARN( - "[{}] Field index {} not int32, casting into int32 " - "implicitly\n{}", - stmt->name(), i, stmt->get_tb()); + ErrorEmitter( + TaichiCastWarning(), stmt, + fmt::format( + "Field index {} not int32, casting into int32 implicitly", i)); stmt->indices[i] = insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); } @@ -173,7 +163,7 @@ class TypeCheck : public IRVisitor { } void visit(GlobalStoreStmt *stmt) override { - type_check_store(stmt, stmt->dest, stmt->val, "Global store"); + type_check_store(stmt, stmt->dest, stmt->val); } void visit(RangeForStmt *stmt) override { @@ -298,16 +288,12 @@ class TypeCheck : public IRVisitor { } void visit(BinaryOpStmt *stmt) override { - auto error = [&](std::string comment = "") { - if (comment == "") { - TI_WARN("[{}] Type mismatch (left = {}, right = {}, stmt_id = {})\n{}", - stmt->name(), stmt->lhs->ret_data_type_name(), - stmt->rhs->ret_data_type_name(), stmt->id, stmt->get_tb()); - } else { - TI_WARN("[{}] {}\n{}", stmt->name(), comment, stmt->get_tb()); - } - TI_WARN("Compilation stopped due to type mismatch."); - throw std::runtime_error("Binary operator type mismatch"); + auto error = [&]() { + ErrorEmitter( + TaichiTypeError(), stmt, + fmt::format("Type mismatch (left = {}, right = {}, stmt_id = {})", + stmt->lhs->ret_data_type_name(), + stmt->rhs->ret_data_type_name(), stmt->id)); }; if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::unknown) && stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::unknown)) From 632584f9f3c28a9a3f932850d010f5e50432f2b7 Mon Sep 17 00:00:00 2001 From: Qian Bao Date: Tue, 18 Jul 2023 21:57:52 +0800 Subject: [PATCH 3/7] [bug] Fix build.py Powershell part. (#8291) Issue: # ### Brief Summary "-Interactive" seems not to be a valid argument for the current Powershell. 
On my Windows 11 machine, invoking `python build.py --shell` gives the following error: ![image](https://github.com/taichi-dev/taichi/assets/2747993/fd58e452-30fd-4d0f-a1fd-145c1fa3f548) This PR removes the "-Interactive" flag and adds an "-ExecutionPolicy" flag, which is required for running the remaining part of the `build.py` script. --- .github/workflows/scripts/ti_build/alter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scripts/ti_build/alter.py b/.github/workflows/scripts/ti_build/alter.py index 69622be5b50191..b1abcd7dcd959b 100644 --- a/.github/workflows/scripts/ti_build/alter.py +++ b/.github/workflows/scripts/ti_build/alter.py @@ -117,7 +117,7 @@ def enter_shell(): if shell.name in ("pwsh.exe", "powershell.exe"): pwsh = Command(shell.exe) path = _write_ti_pwshrc() - pwsh("-Interactive", "-NoExit", "-File", str(path)) + pwsh("-ExecutionPolicy", "Bypass", "-NoExit", "-File", str(path)) elif shell.name == "cmd.exe": cmd = Command(shell.exe) cmd("/k", "set", "PROMPT=TaichiBuild $P$G") From 6fc217fb70dcde0c2241e30a95d5125a328d2cc4 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Wed, 19 Jul 2023 15:22:58 +0800 Subject: [PATCH 4/7] [Lang] Support TensorType for SharedArray (#8258) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: # ### Brief Summary ### 🤖 Generated by Copilot at 8854acc This pull request adds support for shared arrays of matrix types in the Taichi language and IR. It enables nested matrix pointer statements for accessing such arrays, and enhances the scalarize and offload passes to handle them. It also adds a test case and some temporary debugging code. ### Walkthrough ### 🤖 Generated by Copilot at 8854acc * Add support for shared arrays of matrix types ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-aee943d584058490d7717d34c02a3783d3487694dc091653d42b202e45b1e097R57-R58), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L729-R733), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5L108-R110), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L491-R491), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L529-R559), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-714976dfca6f0a5c18d59780a3f0be6f007ff8cf8cddafedd2bf07f755519c2bR144-R165)) - Check and convert the dtype argument of the SharedArray constructor in `python/taichi/lang/simt/block.py` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-aee943d584058490d7717d34c02a3783d3487694dc091653d42b202e45b1e097R57-R58)) - Avoid treating shared arrays as tensors in the code generation for matrix pointer statements in `taichi/ir/frontend_ir.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-a6e92dd2dd707d705dc44ef91463ddc0423575188e6e8e0555de9e439db88c35L729-R733)) - Allow nested matrix pointer statements to access shared arrays of matrix types in `taichi/ir/statements.cpp` and `taichi/ir/statements.h` 
([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5L108-R110), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L491-R491)) - Merge with previous matrix pointer statements in the scalarize pass in `taichi/transforms/scalarize.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L529-R559), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528R65)) - Add a test case that creates and reads a shared array of vec4 type in `tests/python/test_shared_array.py` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-714976dfca6f0a5c18d59780a3f0be6f007ff8cf8cddafedd2bf07f755519c2bR144-R165), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-714976dfca6f0a5c18d59780a3f0be6f007ff8cf8cddafedd2bf07f755519c2bR5)) * Simplify the creation of constant statements for different types in `taichi/ir/statements.h` and `taichi/transforms/offload.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260R2021-R2050), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-d47b571f975c1002b8cb93634ac2a3d5f090f3fa9676ec3e0004c2ec4116ee21L541-R550)) - Add a template function that creates a vector of statements that represent a constant value of a given data type in `taichi/ir/statements.h` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260R2021-R2050)) - Use the template function to create a constant zero statement and a global store statement for a shared array allocation in `taichi/transforms/offload.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-d47b571f975c1002b8cb93634ac2a3d5f090f3fa9676ec3e0004c2ec4116ee21L541-R550)) * Add debugging code for matrix pointer statements in `taichi/ir/statements.cpp`, `taichi/transforms/compile_to_offloads.cpp`, and `taichi/transforms/lower_matrix_ptr.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R6), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R123), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR297-R299), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-9b36b48490841b4018aca81632ae1beac3b2fdf1ee95a5c65eb42b676654b82eR242-R243)) - Include the signal.h header file and call raise(SIGSEGV) to trigger a segmentation fault when the origin of a matrix pointer statement is invalid in `taichi/ir/statements.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R6), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R123)) - Call the die pass and print the IR after the offload 
pass in `taichi/transforms/compile_to_offloads.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR297-R299)) - Print whether a matrix pointer statement is null or not in `taichi/transforms/lower_matrix_ptr.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-9b36b48490841b4018aca81632ae1beac3b2fdf1ee95a5c65eb42b676654b82eR242-R243)) * Remove or add empty lines in `taichi/ir/type.h` and `taichi/transforms/scalarize.cpp` ([link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-87004ca8d67f31ff19e6bc1a62cd6e6d87c09b237197b93533ec6bf617f29149L743), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528R118), [link](https://github.com/taichi-dev/taichi/pull/8258/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1317)) --- python/taichi/lang/simt/block.py | 2 ++ taichi/ir/frontend_ir.cpp | 6 ++++- taichi/ir/statements.cpp | 3 ++- taichi/ir/statements.h | 2 +- taichi/ir/type.h | 1 - taichi/transforms/compile_to_offloads.cpp | 3 +++ taichi/transforms/offload.cpp | 23 ++++++---------- taichi/transforms/scalarize.cpp | 33 ++++++++++++++++++++--- tests/python/test_shared_array.py | 26 ++++++++++++++++++ 9 files changed, 77 insertions(+), 22 deletions(-) diff --git a/python/taichi/lang/simt/block.py b/python/taichi/lang/simt/block.py index f680b5dca152e9..80296a14c1414f 100644 --- a/python/taichi/lang/simt/block.py +++ b/python/taichi/lang/simt/block.py @@ -54,6 +54,8 @@ def __init__(self, shape, dtype): raise ValueError( f"ti.simt.block.shared_array shape must be an integer or a tuple of integers, but got {shape}" ) + if isinstance(dtype, impl.MatrixType): + dtype = dtype.tensor_type self.dtype = dtype self.shared_array_proxy = impl.expr_init_shared_array(self.shape, dtype) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 31f363b8f44cf5..95a1df3127c045 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -747,7 +747,11 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, ctx->push_back(alloca_stmt, var_stmt); var_stmt = alloca_stmt; } - if (ret_type.ptr_removed()->is()) { + + bool is_shared_array = + (var_stmt->is() && var_stmt->as()->is_shared); + + if (ret_type.ptr_removed()->is() && !is_shared_array) { std::vector stmts; for (auto &indices : indices_group) { stmts.push_back( diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 1fd6274132437f..ddf9fa494b298f 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -105,7 +105,8 @@ MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, if (origin->is() || origin->is() || origin->is() || origin->is() || - origin->is() || origin->is()) { + origin->is() || origin->is() || + origin->is()) { auto tensor_type = origin->ret_type.ptr_removed()->cast(); TI_ASSERT(tensor_type != nullptr); element_type() = tensor_type->get_element_type(); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index ec8c399094637f..c419e8e89c582c 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -499,7 +499,7 @@ class MatrixPtrStmt : public Stmt { */ bool offset_used_as_index() const { if (origin->is() || origin->is() || - origin->is()) { + origin->is() || origin->is()) { TI_ASSERT_INFO(origin->ret_type.ptr_removed()->is(), "MatrixPtrStmt can only be used for TensorType."); return 
true; diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 50881649e7e6cd..936138a0922fb6 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -772,7 +772,6 @@ Type::jsonserde_ptr_io(const T *&ptr, } } } - } // namespace taichi::lang namespace taichi::hashing { diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 23759f43ee65cc..4b770e2298b327 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -294,6 +294,9 @@ void offload_to_executable(IRNode *ir, !get_custom_cuda_library_path().empty()); if (config.real_matrix_scalarize) { if (irpass::scalarize(ir, half2_optimization_enabled)) { + irpass::die(ir); + print("DIE"); + // Remove redundant MatrixInitStmt inserted during scalarization irpass::full_simplify( ir, config, diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index ada6e692689d87..c402a56559965b 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -538,22 +538,15 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; - TypedConstant zero(alloca_type.get_element_type()); - auto const_zero_stmt = replacement.push_back(zero); - if (auto tensor_type = alloca_type->cast()) { - std::vector zero_values(tensor_type->get_num_elements(), - const_zero_stmt); - auto zero_matrix_init_stmt = - replacement.push_back(zero_values); - zero_matrix_init_stmt->ret_type = stmt->ret_type.ptr_removed(); - auto global_store_stmt = - replacement.push_back(ptr, zero_matrix_init_stmt); - stmt_to_offloaded_[global_store_stmt] = offloaded; - } else { - auto global_store_stmt = - replacement.push_back(ptr, const_zero_stmt); - stmt_to_offloaded_[global_store_stmt] = offloaded; + auto stmts = get_const_stmt_with_value(alloca_type, 0); + for (auto &stmt : stmts) { + replacement.push_back(std::move(stmt)); } + Stmt *const_stmt = replacement.back().get(); + + auto global_store_stmt = + replacement.push_back(ptr, const_stmt); + stmt_to_offloaded_[global_store_stmt] = offloaded; stmt->parent->replace_with(stmt, std::move(replacement), false); // To deal with the same offloaded visit_operand() diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 1dd36726c890b9..f840583892d92d 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -62,6 +62,7 @@ class Scalarize : public BasicStmtVisitor { auto const_stmt = std::make_unique( TypedConstant(get_data_type(), i)); + // Merge with previous MatrixPtrStmt auto matrix_ptr_stmt = std::make_unique(stmt->dest, const_stmt.get()); matrix_ptr_stmt->ret_type = primitive_type; @@ -114,6 +115,7 @@ class Scalarize : public BasicStmtVisitor { auto matrix_ptr_stmt = std::make_unique(stmt->src, const_stmt.get()); + matrix_ptr_stmt->ret_type = primitive_type; matrix_ptr_stmt->ret_type.set_is_pointer(true); @@ -526,8 +528,34 @@ class Scalarize : public BasicStmtVisitor { // scalarize to dest_i auto const_stmt = std::make_unique( TypedConstant(get_data_type(), i)); - auto matrix_ptr_stmt = - std::make_unique(stmt->dest, const_stmt.get()); + + // Merge with previous MatrixPtrStmt + std::unique_ptr matrix_ptr_stmt = nullptr; + if (stmt->dest->is()) { + /* + <*[Tensor (16) [Tensor (4) f32]]> $5 = alloca(shared) + <*[Tensor (4) f32]> $11 = shift ptr [$5 + $9] + <[Tensor (4) f32]> $12 : local store [$11 <- $10] + */ + auto matrix_ptr_stmt_ptr = stmt->dest->as(); + + auto base_offset = 
matrix_ptr_stmt_ptr->offset; + auto merged_offset = std::make_unique( + BinaryOpType::add, base_offset, const_stmt.get()); + merged_offset->ret_type = base_offset->ret_type; + matrix_ptr_stmt = std::make_unique( + matrix_ptr_stmt_ptr->origin, merged_offset.get()); + matrix_ptr_stmt->ret_type = primitive_type; + matrix_ptr_stmt->ret_type.set_is_pointer(true); + + delayed_modifier_.insert_before(stmt, std::move(merged_offset)); + } else { + matrix_ptr_stmt = + std::make_unique(stmt->dest, const_stmt.get()); + } + + matrix_ptr_stmt->ret_type = primitive_type; + matrix_ptr_stmt->ret_type.set_is_pointer(true); // scalarize to val_i auto val_stmt = val_values[i]; @@ -1315,7 +1343,6 @@ bool scalarize(IRNode *root, bool half2_optimization_enabled) { modified |= ScalarizePointers::run(root, scalarizable_allocas); modified |= ExtractLocalPointers::run(root); modified |= FuseMatrixPtr::run(root); - return modified; } diff --git a/tests/python/test_shared_array.py b/tests/python/test_shared_array.py index c18f87432957e2..64e499eef2dad2 100644 --- a/tests/python/test_shared_array.py +++ b/tests/python/test_shared_array.py @@ -2,6 +2,7 @@ import pytest import taichi as ti +from taichi.math import vec4 from tests import test_utils @@ -140,3 +141,28 @@ def atomic_test(out: ti.types.ndarray()): assert arr[32] == sum assert arr[128] == sum assert arr[224] == sum + + +@test_utils.test(arch=[ti.cuda]) +def test_shared_array_tensor_type(): + data_type = vec4 + block_dim = 16 + N = 64 + + y = ti.Vector.field(4, dtype=ti.f32, shape=(block_dim)) + + @ti.kernel + def test(): + ti.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + val = ti.Vector([1.0, 2.0, 3.0, 4.0]) + + shared_mem = ti.simt.block.SharedArray((block_dim), data_type) + shared_mem[tid] = val + ti.simt.block.sync() + + y[tid] += shared_mem[tid] + + test() + assert (y.to_numpy()[0] == [4.0, 8.0, 12.0, 16.0]).all() From 0ecd0adfa4383772f45f910702729d1ed8bdb007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=94=E6=B3=95=E5=B0=91=E5=A5=B3=E8=B5=B5=E5=BF=97?= =?UTF-8?q?=E8=BE=89?= Date: Wed, 19 Jul 2023 19:12:11 +0800 Subject: [PATCH 5/7] [Lang] Pass DebugInfo from Python to C++ for ndarray and field (#8286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: # ### Brief Summary ### 🤖 Generated by Copilot at 00d994a This pull request improves the error reporting and type-checking functionality for `Ndarray` and `SNode` objects in Taichi. It introduces new classes and structs such as `DebugInfo`, `TaichiError`, and `ErrorEmitter` to pass and handle the debug information and exceptions. It also refactors and simplifies the existing code in various files to use these new classes and structs, and to provide more consistent and informative messages to the user. 
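To make the new flow concrete, here is a minimal C++ sketch of the pattern this change converges on: a `DebugInfo` carrying a Python-side traceback is created at the call site, stored on the object, and handed to an `ErrorEmitter` when a diagnostic fires. The `DebugInfo`, `TaichiIndexWarning`, and `ErrorEmitter` definitions below are deliberately simplified stand-ins for illustration (as is the `demo.py` location string), not Taichi's actual implementations.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for illustration only; the real definitions live in
// taichi/common/exceptions.h and taichi/ir/ir.h.
struct DebugInfo {
  std::string tb;  // traceback string captured on the Python side
  DebugInfo() = default;
  explicit DebugInfo(std::string tb_) : tb(std::move(tb_)) {}
};

struct TaichiIndexWarning {
  static constexpr const char *name = "TaichiIndexWarning";
};

// Minimal ErrorEmitter: for a warning, print the message followed by the
// traceback carried in the DebugInfo.
template <typename Warning>
void ErrorEmitter(Warning, const DebugInfo *dbg_info, const std::string &msg) {
  std::cerr << Warning::name << ": " << msg << "\n" << dbg_info->tb << "\n";
}

// Sketch of an Ndarray that stores the DebugInfo it was constructed with, so
// later diagnostics can point back at the Python call site.
struct Ndarray {
  std::vector<int> shape;
  DebugInfo dbg_info;

  explicit Ndarray(std::vector<int> shape_, DebugInfo dbg_info_ = DebugInfo())
      : shape(std::move(shape_)), dbg_info(std::move(dbg_info_)) {
    const int64_t total = std::accumulate(shape.begin(), shape.end(),
                                          int64_t{1}, std::multiplies<>());
    if (total > std::numeric_limits<int32_t>::max()) {
      ErrorEmitter(TaichiIndexWarning(), &dbg_info,
                   "Ndarray index might be out of int32 boundary but int64 "
                   "indexing is not supported yet.");
    }
  }
};

int main() {
  // In Taichi the traceback would come from get_traceback() in Python; here we
  // pass a hypothetical location string by hand.
  Ndarray big({1 << 16, 1 << 16},
              DebugInfo("File \"demo.py\", line 3, in <module>"));
  return 0;
}
```

The point of the pattern: call sites decide only what to report, while where it happened travels with the object in its `DebugInfo`, so a warning raised deep inside C++ can still print the originating Python line.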
### Walkthrough ### 🤖 Generated by Copilot at 00d994a * Add new exception and warning classes for different types of errors and warnings (`[link](https://github.com/taichi-dev/taichi/pull/8286/files?diff=unified&w=0#diff-ccaf2900f7c75403d5aecff661ea03785341c847f0f7b1a5c75c6b93e5ede5d9L18-R134)`) --- cpp_examples/aot_save.cpp | 2 +- cpp_examples/autograd.cpp | 5 +- cpp_examples/run_snode.cpp | 2 +- python/taichi/lang/_ndarray.py | 6 +- python/taichi/lang/expr.py | 6 +- python/taichi/lang/matrix.py | 13 +++- python/taichi/lang/snode.py | 12 ++-- taichi/ir/snode.cpp | 67 ++++++++++------- taichi/ir/snode.h | 71 +++++++++++-------- taichi/program/ndarray.cpp | 20 ++++-- taichi/program/ndarray.h | 11 ++- taichi/program/program.cpp | 5 +- taichi/program/program.h | 3 +- taichi/program/snode_expr_utils.cpp | 2 +- taichi/python/export_lang.cpp | 23 ++++-- tests/cpp/analysis/bls_analyzer_test.cpp | 2 +- tests/cpp/aot/dx12/aot_save_load_test.cpp | 2 +- tests/cpp/codegen/refine_coordinates_test.cpp | 4 +- tests/cpp/struct/snode_tree_test.cpp | 4 +- .../cpp/transforms/make_block_local_test.cpp | 6 +- .../scalar_pointer_lowerer_test.cpp | 10 +-- 21 files changed, 173 insertions(+), 103 deletions(-) diff --git a/cpp_examples/aot_save.cpp b/cpp_examples/aot_save.cpp index 7defd217a8dca6..4ff6b2064b921b 100644 --- a/cpp_examples/aot_save.cpp +++ b/cpp_examples/aot_save.cpp @@ -13,7 +13,7 @@ void aot_save(taichi::Arch arch) { // program.materialize_runtime(); auto *root = new SNode(0, SNodeType::root); - auto *pointer = &root->dense(Axis(0), n, ""); + auto *pointer = &root->dense(Axis(0), n); auto *place = &pointer->insert_children(SNodeType::place); place->dt = PrimitiveType::i32; program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/true); diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index 135aea17b73944..6b503ddc4938a6 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -90,11 +90,10 @@ void autograd() { } }; - auto *snode = - &root->dense(Axis(0), n, "").insert_children(SNodeType::place); + auto *snode = &root->dense(Axis(0), n).insert_children(SNodeType::place); snode->dt = PrimitiveType::f32; snode->grad_info = std::make_unique( - &root->dense(Axis(0), n, "").insert_children(SNodeType::place)); + &root->dense(Axis(0), n).insert_children(SNodeType::place)); snode->get_adjoint()->dt = PrimitiveType::f32; snode->get_adjoint()->grad_info = std::make_unique(); return snode; diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index e0449a35ee5f6d..29d5224e991c5f 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -48,7 +48,7 @@ void run_snode() { int n = 10; program.materialize_runtime(); auto *root = new SNode(0, SNodeType::root); - auto *pointer = &root->pointer(Axis(0), n, ""); + auto *pointer = &root->pointer(Axis(0), n); auto *place = &pointer->insert_children(SNodeType::place); place->dt = PrimitiveType::i32; program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/false); diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index cb1e37c726b2ba..796366adf345e3 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -3,7 +3,7 @@ from taichi.lang import impl from taichi.lang.enums import Layout from taichi.lang.exception import TaichiIndexError -from taichi.lang.util import cook_dtype, python_scope, to_numpy_type +from taichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type from taichi.types import primitive_types from 
taichi.types.ndarray_type import NdarrayTypeMetadata from taichi.types.utils import is_real, is_signed @@ -237,7 +237,9 @@ class ScalarNdarray(Ndarray): def __init__(self, dtype, arr_shape): super().__init__() self.dtype = cook_dtype(dtype) - self.arr = impl.get_runtime().prog.create_ndarray(self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True) + self.arr = impl.get_runtime().prog.create_ndarray( + self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback()) + ) self.shape = tuple(self.arr.shape) self.element_type = dtype diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 2a04d77f7cc2be..68f4217bf15678 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -15,11 +15,13 @@ class Expr(TaichiOperations): def __init__(self, *args, tb=None, dtype=None): self.tb = tb + self.ptr_type_checked = False if len(args) == 1: if isinstance(args[0], _ti_core.Expr): self.ptr = args[0] elif isinstance(args[0], Expr): self.ptr = args[0].ptr + self.ptr_type_checked = args[0].ptr_type_checked self.tb = args[0].tb elif is_matrix_class(args[0]): self.ptr = make_matrix(args[0].to_list()).ptr @@ -39,7 +41,9 @@ def __init__(self, *args, tb=None, dtype=None): assert False if self.tb: self.ptr.set_tb(self.tb) - self.ptr.type_check(impl.get_runtime().prog.config()) + if not self.ptr_type_checked: + self.ptr.type_check(impl.get_runtime().prog.config()) + self.ptr_type_checked = True def is_tensor(self): return self.ptr.is_tensor() diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index be3fb5338e9b97..4e8abaf8d95aeb 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -20,6 +20,7 @@ from taichi.lang.field import Field, ScalarField, SNodeHostAccess from taichi.lang.util import ( cook_dtype, + get_traceback, in_python_scope, python_scope, taichi_scope, @@ -1655,7 +1656,11 @@ def __init__(self, n, m, dtype, shape): self.element_type = _type_factory.get_tensor_type((self.n, self.m), self.dtype) # TODO: we should pass in element_type, shape, layout instead. self.arr = impl.get_runtime().prog.create_ndarray( - cook_dtype(self.element_type), shape, Layout.AOS, zero_fill=True + cook_dtype(self.element_type), + shape, + Layout.AOS, + zero_fill=True, + dbg_info=ti_python_core.DebugInfo(get_traceback()), ) @property @@ -1765,7 +1770,11 @@ def __init__(self, n, dtype, shape): self.shape = tuple(shape) self.element_type = _type_factory.get_tensor_type((n,), self.dtype) self.arr = impl.get_runtime().prog.create_ndarray( - cook_dtype(self.element_type), shape, Layout.AOS, zero_fill=True + cook_dtype(self.element_type), + shape, + Layout.AOS, + zero_fill=True, + dbg_info=ti_python_core.DebugInfo(get_traceback()), ) @property diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index 7acddd9c42a3ec..e98e4f9d611850 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -35,7 +35,7 @@ def dense(self, axes, dimensions): """ if isinstance(dimensions, numbers.Number): dimensions = [dimensions] * len(axes) - return SNode(self.ptr.dense(axes, dimensions, get_traceback())) + return SNode(self.ptr.dense(axes, dimensions, _ti_core.DebugInfo(get_traceback()))) def pointer(self, axes, dimensions): """Adds a pointer SNode as a child component of `self`. 
@@ -51,7 +51,7 @@ def pointer(self, axes, dimensions): raise TaichiRuntimeError("Pointer SNode is not supported on this backend.") if isinstance(dimensions, numbers.Number): dimensions = [dimensions] * len(axes) - return SNode(self.ptr.pointer(axes, dimensions, get_traceback())) + return SNode(self.ptr.pointer(axes, dimensions, _ti_core.DebugInfo(get_traceback()))) @staticmethod def _hash(axes, dimensions): @@ -78,7 +78,7 @@ def dynamic(self, axis, dimension, chunk_size=None): assert len(axis) == 1 if chunk_size is None: chunk_size = dimension - return SNode(self.ptr.dynamic(axis[0], dimension, chunk_size, get_traceback())) + return SNode(self.ptr.dynamic(axis[0], dimension, chunk_size, _ti_core.DebugInfo(get_traceback()))) def bitmasked(self, axes, dimensions): """Adds a bitmasked SNode as a child component of `self`. @@ -94,7 +94,7 @@ def bitmasked(self, axes, dimensions): raise TaichiRuntimeError("Bitmasked SNode is not supported on this backend.") if isinstance(dimensions, numbers.Number): dimensions = [dimensions] * len(axes) - return SNode(self.ptr.bitmasked(axes, dimensions, get_traceback())) + return SNode(self.ptr.bitmasked(axes, dimensions, _ti_core.DebugInfo(get_traceback()))) def quant_array(self, axes, dimensions, max_num_bits): """Adds a quant_array SNode as a child component of `self`. @@ -109,7 +109,7 @@ def quant_array(self, axes, dimensions, max_num_bits): """ if isinstance(dimensions, numbers.Number): dimensions = [dimensions] * len(axes) - return SNode(self.ptr.quant_array(axes, dimensions, max_num_bits, get_traceback())) + return SNode(self.ptr.quant_array(axes, dimensions, max_num_bits, _ti_core.DebugInfo(get_traceback()))) def place(self, *args, offset=None): """Places a list of Taichi fields under the `self` container. @@ -129,7 +129,7 @@ def place(self, *args, offset=None): for arg in args: if isinstance(arg, BitpackedFields): bit_struct_type = arg.bit_struct_type_builder.build() - bit_struct_snode = self.ptr.bit_struct(bit_struct_type, get_traceback()) + bit_struct_snode = self.ptr.bit_struct(bit_struct_type, _ti_core.DebugInfo(get_traceback())) for field, id_in_bit_struct in arg.fields: bit_struct_snode.place(field, offset, id_in_bit_struct) elif isinstance(arg, Field): diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 6627263a3962ac..d26c02cab35610 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -37,35 +37,47 @@ SNode &SNode::insert_children(SNodeType t) { SNode &SNode::create_node(std::vector axes, std::vector sizes, SNodeType type, - const std::string &tb) { - TI_ASSERT(axes.size() == sizes.size() || sizes.size() == 1); + const DebugInfo &dbg_info) { if (sizes.size() == 1) { sizes = std::vector(axes.size(), sizes[0]); } + if (axes.size() != sizes.size()) { + ErrorEmitter( + TaichiRuntimeError(), &dbg_info, + fmt::format( + "axes and sizes must have the same size, but got {} and {}.", + axes.size(), sizes.size())); + } - if (type == SNodeType::hash) - TI_ASSERT_INFO(depth == 0, - "hashed node must be child of root due to initialization " - "memset limitation."); + if (type == SNodeType::hash && depth != 0) { + ErrorEmitter(TaichiRuntimeError(), &dbg_info, + "hashed node must be child of root due to initialization " + "memset limitation."); + } auto &new_node = insert_children(type); for (int i = 0; i < (int)axes.size(); i++) { if (sizes[i] <= 0) { - throw TaichiRuntimeError( - "Every dimension of a Taichi field should be positive"); + ErrorEmitter(TaichiRuntimeError(), &dbg_info, + fmt::format("Every dimension of a Taichi field should 
be " "positive, got {} in dimension {}.", sizes[i], i)); } int ind = axes[i].value; auto end = new_node.physical_index_position + new_node.num_active_indices; bool is_first_division = std::find(new_node.physical_index_position, end, ind) == end; if (is_first_division) { new_node.physical_index_position[new_node.num_active_indices++] = ind; } else if (!bit::is_power_of_two(sizes[i])) { ErrorEmitter( TaichiRuntimeWarning(), &dbg_info, fmt::format( "Shape {} is detected on non-first division of axis {}. For " "best performance, we recommend that you set it to a power of " "two.", sizes[i], char('i' + ind))); } new_node.extractors[ind].active = true; new_node.extractors[ind].num_elements_from_root *= sizes[i]; @@ -81,7 +93,8 @@ SNode &SNode::create_node(std::vector axes, acc_shape *= new_node.extractors[i].shape; } if (acc_shape > std::numeric_limits::max()) { - TI_WARN( + ErrorEmitter( + TaichiIndexWarning(), &dbg_info, "SNode index might be out of int32 boundary but int64 indexing is not " "supported yet. Struct fors might not work either."); } @@ -94,15 +107,19 @@ SNode &SNode::create_node(std::vector axes, active_extractor_counder += 1; SNode *p = new_node.parent; while (p) { - TI_ASSERT_INFO( - !p->extractors[i].active, - "Dynamic SNode must have a standalone dimensionality."); + if (p->extractors[i].active) { + ErrorEmitter( + TaichiRuntimeError(), &dbg_info, + "Dynamic SNode must have a standalone dimensionality."); + } p = p->parent; } } } - TI_ASSERT_INFO(active_extractor_counder == 1, + if (active_extractor_counder != 1) { + ErrorEmitter(TaichiRuntimeError(), &dbg_info, "Dynamic SNode can have only one index extractor."); + } } return new_node; } @@ -110,15 +127,15 @@ SNode &SNode::create_node(std::vector axes, SNode &SNode::dynamic(const Axis &expr, int n, int chunk_size, - const std::string &tb) { - auto &snode = create_node({expr}, {n}, SNodeType::dynamic, tb); + const DebugInfo &dbg_info) { + auto &snode = create_node({expr}, {n}, SNodeType::dynamic, dbg_info); snode.chunk_size = chunk_size; return snode; } SNode &SNode::bit_struct(BitStructType *bit_struct_type, - const std::string &tb) { - auto &snode = create_node({}, {}, SNodeType::bit_struct, tb); + const DebugInfo &dbg_info) { + auto &snode = create_node({}, {}, SNodeType::bit_struct, dbg_info); snode.dt = bit_struct_type; snode.physical_type = bit_struct_type->get_physical_type(); return snode; @@ -127,8 +144,8 @@ SNode &SNode::bit_struct(BitStructType *bit_struct_type, SNode &SNode::quant_array(const std::vector &axes, const std::vector &sizes, int bits, - const std::string &tb) { - auto &snode = create_node(axes, sizes, SNodeType::quant_array, tb); + const DebugInfo &dbg_info) { + auto &snode = create_node(axes, sizes, SNodeType::quant_array, dbg_info); snode.physical_type = TypeFactory::get_instance().get_primitive_int_type(bits, false); return snode; diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index fab5ca968541cf..7279b7d0f47fe6 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -138,81 +138,96 @@ class SNode { SNode &create_node(std::vector axes, std::vector sizes, SNodeType type, - const std::string &tb); + const DebugInfo &dbg_info = DebugInfo()); // SNodes maintains how flattened index bits are taken from indices SNode &dense(const std::vector &axes, 
const std::vector &sizes, - const std::string &tb) { - return create_node(axes, sizes, SNodeType::dense, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, sizes, SNodeType::dense, dbg_info); } SNode &dense(const std::vector &axes, int sizes, - const std::string &tb) { - return create_node(axes, std::vector{sizes}, SNodeType::dense, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, std::vector{sizes}, SNodeType::dense, + dbg_info); } - SNode &dense(const Axis &axis, int size, const std::string &tb) { - return SNode::dense(std::vector{axis}, size, tb); + SNode &dense(const Axis &axis, + int size, + const DebugInfo &dbg_info = DebugInfo()) { + return SNode::dense(std::vector{axis}, size, dbg_info); } SNode &pointer(const std::vector &axes, const std::vector &sizes, - const std::string &tb) { - return create_node(axes, sizes, SNodeType::pointer, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, sizes, SNodeType::pointer, dbg_info); } SNode &pointer(const std::vector &axes, int sizes, - const std::string &tb) { - return create_node(axes, std::vector{sizes}, SNodeType::pointer, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, std::vector{sizes}, SNodeType::pointer, + dbg_info); } - SNode &pointer(const Axis &axis, int size, const std::string &tb) { - return SNode::pointer(std::vector{axis}, size, tb); + SNode &pointer(const Axis &axis, + int size, + const DebugInfo &dbg_info = DebugInfo()) { + return SNode::pointer(std::vector{axis}, size, dbg_info); } SNode &bitmasked(const std::vector &axes, const std::vector &sizes, - const std::string &tb) { - return create_node(axes, sizes, SNodeType::bitmasked, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, sizes, SNodeType::bitmasked, dbg_info); } SNode &bitmasked(const std::vector &axes, int sizes, - const std::string &tb) { - return create_node(axes, std::vector{sizes}, SNodeType::bitmasked, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, std::vector{sizes}, SNodeType::bitmasked, + dbg_info); } - SNode &bitmasked(const Axis &axis, int size, const std::string &tb) { - return SNode::bitmasked(std::vector{axis}, size, tb); + SNode &bitmasked(const Axis &axis, + int size, + const DebugInfo &dbg_info = DebugInfo()) { + return SNode::bitmasked(std::vector{axis}, size, dbg_info); } SNode &hash(const std::vector &axes, const std::vector &sizes, - const std::string &tb) { - return create_node(axes, sizes, SNodeType::hash, tb); + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, sizes, SNodeType::hash, dbg_info); } - SNode &hash(const std::vector &axes, int sizes, const std::string &tb) { - return create_node(axes, std::vector{sizes}, SNodeType::hash, tb); + SNode &hash(const std::vector &axes, + int sizes, + const DebugInfo &dbg_info = DebugInfo()) { + return create_node(axes, std::vector{sizes}, SNodeType::hash, + dbg_info); } - SNode &hash(const Axis &axis, int size, const std::string &tb) { - return hash(std::vector{axis}, size, tb); + SNode &hash(const Axis &axis, + int size, + const DebugInfo &dbg_info = DebugInfo()) { + return hash(std::vector{axis}, size, dbg_info); } std::string type_name() { return snode_type_name(type); } - SNode &bit_struct(BitStructType *bit_struct_type, const std::string &tb); + SNode &bit_struct(BitStructType *bit_struct_type, + const DebugInfo &dbg_info = DebugInfo()); SNode &quant_array(const std::vector &axes, const std::vector &sizes, int 
bits, - const std::string &tb); + const DebugInfo &dbg_info = DebugInfo()); void print(); @@ -221,7 +236,7 @@ class SNode { SNode &dynamic(const Axis &expr, int n, int chunk_size, - const std::string &tb); + const DebugInfo &dbg_info = DebugInfo()); SNode &morton(bool val = true) { _morton = val; diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp index a83d024d2b00bc..54702d54b2436b 100644 --- a/taichi/program/ndarray.cpp +++ b/taichi/program/ndarray.cpp @@ -26,10 +26,12 @@ size_t flatten_index(const std::vector &shapes, Ndarray::Ndarray(Program *prog, const DataType type, const std::vector &shape_, - ExternalArrayLayout layout_) + ExternalArrayLayout layout_, + const DebugInfo &dbg_info_) : dtype(type), shape(shape_), layout(layout_), + dbg_info(dbg_info_), nelement_(std::accumulate(std::begin(shape_), std::end(shape_), 1, @@ -51,7 +53,8 @@ Ndarray::Ndarray(Program *prog, std::accumulate(std::begin(total_shape_), std::end(total_shape_), 1LL, std::multiplies<>()); if (total_num_scalar > std::numeric_limits::max()) { - TI_WARN( + ErrorEmitter( + TaichiIndexWarning(), &dbg_info, "Ndarray index might be out of int32 boundary but int64 indexing is " "not supported yet."); } @@ -62,11 +65,13 @@ Ndarray::Ndarray(Program *prog, Ndarray::Ndarray(DeviceAllocation &devalloc, const DataType type, const std::vector &shape, - ExternalArrayLayout layout) + ExternalArrayLayout layout, + const DebugInfo &dbg_info) : ndarray_alloc_(devalloc), dtype(type), shape(shape), layout(layout), + dbg_info(dbg_info), nelement_(std::accumulate(std::begin(shape), std::end(shape), 1, @@ -91,7 +96,8 @@ Ndarray::Ndarray(DeviceAllocation &devalloc, std::accumulate(std::begin(total_shape_), std::end(total_shape_), 1LL, std::multiplies<>()); if (total_num_scalar > std::numeric_limits::max()) { - TI_WARN( + ErrorEmitter( + TaichiIndexWarning(), &dbg_info, "Ndarray index might be out of int32 boundary but int64 indexing is " "not supported yet."); } @@ -101,11 +107,13 @@ Ndarray::Ndarray(DeviceAllocation &devalloc, const DataType type, const std::vector &shape, const std::vector &element_shape, - ExternalArrayLayout layout) + ExternalArrayLayout layout, + const DebugInfo &dbg_info) : Ndarray(devalloc, TypeFactory::create_tensor_type(element_shape, type), shape, - layout) { + layout, + dbg_info) { TI_ASSERT(type->is()); } diff --git a/taichi/program/ndarray.h b/taichi/program/ndarray.h index 067be0f4f85c84..b5eeb5a918126b 100644 --- a/taichi/program/ndarray.h +++ b/taichi/program/ndarray.h @@ -4,6 +4,7 @@ #include #include "taichi/inc/constants.h" +#include "taichi/ir/ir.h" #include "taichi/ir/type_utils.h" #include "taichi/rhi/device.h" @@ -21,7 +22,8 @@ class TI_DLL_EXPORT Ndarray { explicit Ndarray(Program *prog, const DataType type, const std::vector &shape, - ExternalArrayLayout layout = ExternalArrayLayout::kNull); + ExternalArrayLayout layout = ExternalArrayLayout::kNull, + const DebugInfo &dbg_info = DebugInfo()); /* Constructs a Ndarray from an existing DeviceAllocation. * It doesn't handle the allocation and deallocation. @@ -31,7 +33,8 @@ class TI_DLL_EXPORT Ndarray { explicit Ndarray(DeviceAllocation &devalloc, const DataType type, const std::vector &shape, - ExternalArrayLayout layout = ExternalArrayLayout::kNull); + ExternalArrayLayout layout = ExternalArrayLayout::kNull, + const DebugInfo &dbg_info = DebugInfo()); /* Constructs a Ndarray from an existing DeviceAllocation. 
* This is an overloaded constructor for constructing Ndarray with TensorType @@ -41,7 +44,8 @@ class TI_DLL_EXPORT Ndarray { const DataType type, const std::vector &shape, const std::vector &element_shape, - ExternalArrayLayout layout = ExternalArrayLayout::kNull); + ExternalArrayLayout layout = ExternalArrayLayout::kNull, + const DebugInfo &dbg_info = DebugInfo()); DeviceAllocation ndarray_alloc_{kDeviceNullAllocation}; DataType dtype; @@ -50,6 +54,7 @@ class TI_DLL_EXPORT Ndarray { // num_active_indices = shape.size() std::vector shape; ExternalArrayLayout layout{ExternalArrayLayout::kNull}; + DebugInfo dbg_info; std::vector get_element_shape() const; DataType get_element_data_type() const; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 9bb43f9f3913b5..153103ea1f1e91 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -375,8 +375,9 @@ std::size_t Program::get_snode_num_dynamically_allocated(SNode *snode) { Ndarray *Program::create_ndarray(const DataType type, const std::vector &shape, ExternalArrayLayout layout, - bool zero_fill) { - auto arr = std::make_unique(this, type, shape, layout); + bool zero_fill, + const DebugInfo &dbg_info) { + auto arr = std::make_unique(this, type, shape, layout, dbg_info); if (zero_fill) { Arch arch = compile_config().arch; if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu) { diff --git a/taichi/program/program.h b/taichi/program/program.h index c4083af554f8ee..257363997d8620 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -254,7 +254,8 @@ class TI_DLL_EXPORT Program { const DataType type, const std::vector &shape, ExternalArrayLayout layout = ExternalArrayLayout::kNull, - bool zero_fill = false); + bool zero_fill = false, + const DebugInfo &dbg_info = DebugInfo()); ArgPack *create_argpack(const DataType dt); diff --git a/taichi/program/snode_expr_utils.cpp b/taichi/program/snode_expr_utils.cpp index acc5962696be1d..6b1696f7c6a131 100644 --- a/taichi/program/snode_expr_utils.cpp +++ b/taichi/program/snode_expr_utils.cpp @@ -56,7 +56,7 @@ void place_child(Expr *expr_arg, SNodeFieldMap *snode_to_exprs) { if (parent->type == SNodeType::root) { // never directly place to root - auto &ds = parent->dense(std::vector(), {}, ""); + auto &ds = parent->dense(std::vector(), {}); place_child(expr_arg, offset, id_in_bit_struct, &ds, snode_to_exprs); } else { TI_ASSERT(expr_arg->is()); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 66cdd20f673649..cdd46c7e36e887 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -147,6 +147,13 @@ void export_lang(py::module &m) { return dt; })); + py::class_(m, "DebugInfo") + .def(py::init<>()) + .def(py::init()) + .def(py::init<>()) + .def_readwrite("tb", &DebugInfo::tb) + .def_readwrite("src_loc", &DebugInfo::src_loc); + py::class_(m, "CompileConfig") .def(py::init<>()) .def_readwrite("arch", &CompileConfig::arch) @@ -430,12 +437,14 @@ void export_lang(py::module &m) { "create_ndarray", [&](Program *program, const DataType &dt, const std::vector &shape, ExternalArrayLayout layout, - bool zero_fill) -> Ndarray * { - return program->create_ndarray(dt, shape, layout, zero_fill); + bool zero_fill, DebugInfo dbg_info) -> Ndarray * { + return program->create_ndarray(dt, shape, layout, zero_fill, + dbg_info); }, py::arg("dt"), py::arg("shape"), py::arg("layout") = ExternalArrayLayout::kNull, - py::arg("zero_fill") = false, py::return_value_policy::reference) + py::arg("zero_fill") = 
false, py::arg("dbg_info") = DebugInfo(), + py::return_value_policy::reference) .def("delete_ndarray", &Program::delete_ndarray) .def( "create_argpack", @@ -492,23 +501,23 @@ void export_lang(py::module &m) { .def("dense", (SNode & (SNode::*)(const std::vector &, const std::vector &, - const std::string &))(&SNode::dense), + const DebugInfo &))(&SNode::dense), py::return_value_policy::reference) .def("pointer", (SNode & (SNode::*)(const std::vector &, const std::vector &, - const std::string &))(&SNode::pointer), + const DebugInfo &))(&SNode::pointer), py::return_value_policy::reference) .def("hash", (SNode & (SNode::*)(const std::vector &, const std::vector &, - const std::string &))(&SNode::hash), + const DebugInfo &))(&SNode::hash), py::return_value_policy::reference) .def("dynamic", &SNode::dynamic, py::return_value_policy::reference) .def("bitmasked", (SNode & (SNode::*)(const std::vector &, const std::vector &, - const std::string &))(&SNode::bitmasked), + const DebugInfo &))(&SNode::bitmasked), py::return_value_policy::reference) .def("bit_struct", &SNode::bit_struct, py::return_value_policy::reference) .def("quant_array", &SNode::quant_array, diff --git a/tests/cpp/analysis/bls_analyzer_test.cpp b/tests/cpp/analysis/bls_analyzer_test.cpp index c511ffc8dc51ef..66e71f7077c71d 100644 --- a/tests/cpp/analysis/bls_analyzer_test.cpp +++ b/tests/cpp/analysis/bls_analyzer_test.cpp @@ -20,7 +20,7 @@ class BLSAnalyzerTest : public ::testing::Test { void SetUp() override { const std::vector axes = {Axis{0}, Axis{1}}; root_snode_ = std::make_unique(/*depth=*/0, /*t=*/SNodeType::root); - parent_snode_ = &(root_snode_->dense(axes, /*sizes=*/kBlockSize, "")); + parent_snode_ = &(root_snode_->dense(axes, /*sizes=*/kBlockSize)); child_snode_ = &(parent_snode_->insert_children(SNodeType::place)); child_snode_->dt = PrimitiveType::i32; diff --git a/tests/cpp/aot/dx12/aot_save_load_test.cpp b/tests/cpp/aot/dx12/aot_save_load_test.cpp index 0a620b5f02264d..3b53436a7abef6 100644 --- a/tests/cpp/aot/dx12/aot_save_load_test.cpp +++ b/tests/cpp/aot/dx12/aot_save_load_test.cpp @@ -22,7 +22,7 @@ namespace fs = std::filesystem; int n = 10; auto *root = new SNode(0, SNodeType::root); - auto *pointer = &root->dense(Axis(0), n, ""); + auto *pointer = &root->dense(Axis(0), n); auto *place = &pointer->insert_children(SNodeType::place); place->dt = PrimitiveType::i32; program.add_snode_tree(std::unique_ptr(root), /*compile_only=*/true); diff --git a/tests/cpp/codegen/refine_coordinates_test.cpp b/tests/cpp/codegen/refine_coordinates_test.cpp index 7d97def0ff8c76..dccd0fe93232f8 100644 --- a/tests/cpp/codegen/refine_coordinates_test.cpp +++ b/tests/cpp/codegen/refine_coordinates_test.cpp @@ -120,8 +120,8 @@ class RefineCoordinatesTest : public ::testing::Test { root_snode_ = std::make_unique(/*depth=*/0, /*t=*/SNodeType::root); const std::vector axes = {Axis{0}}; - ptr_snode_ = &(root_snode_->pointer(axes, kPointerSize, "")); - dense_snode_ = &(ptr_snode_->dense(axes, kDenseSize, "")); + ptr_snode_ = &(root_snode_->pointer(axes, kPointerSize)); + dense_snode_ = &(ptr_snode_->dense(axes, kDenseSize)); // Must end with a `place` SNode. 
auto &leaf_snode = dense_snode_->insert_children(SNodeType::place); leaf_snode.dt = PrimitiveType::f32; diff --git a/tests/cpp/struct/snode_tree_test.cpp b/tests/cpp/struct/snode_tree_test.cpp index defc783e8079d1..8ef32d5634be0e 100644 --- a/tests/cpp/struct/snode_tree_test.cpp +++ b/tests/cpp/struct/snode_tree_test.cpp @@ -10,8 +10,8 @@ TEST(SNodeTree, GetSNodeToRootMapping) { const std::vector axes = {Axis{0}}; std::vector all_snode_ids; for (int i = 0; i < 3; ++i) { - auto &ptr_snode = root.pointer(axes, kSNodeSize, ""); - auto &dense_snode = ptr_snode.dense(axes, kSNodeSize, ""); + auto &ptr_snode = root.pointer(axes, kSNodeSize); + auto &dense_snode = ptr_snode.dense(axes, kSNodeSize); auto &leaf_snode = dense_snode.insert_children(SNodeType::place); all_snode_ids.push_back(ptr_snode.id); all_snode_ids.push_back(dense_snode.id); diff --git a/tests/cpp/transforms/make_block_local_test.cpp b/tests/cpp/transforms/make_block_local_test.cpp index 1633a216d60f14..cbc1345e2c060e 100644 --- a/tests/cpp/transforms/make_block_local_test.cpp +++ b/tests/cpp/transforms/make_block_local_test.cpp @@ -41,14 +41,14 @@ class MakeBlockLocalTest : public ::testing::Test { // want to see if the tests can handle the loop index scaling multiplier // (block_size) and infer the BLS size correctly. const std::vector axes = {Axis{0}, Axis{1}}; - pointer_snode_ = &(root_snode_->pointer(axes, pointer_size, "")); + pointer_snode_ = &(root_snode_->pointer(axes, pointer_size)); - bls_snode_ = &(pointer_snode_->dense(axes, /*sizes=*/block_size, "")); + bls_snode_ = &(pointer_snode_->dense(axes, /*sizes=*/block_size)); bls_place_snode_ = &(bls_snode_->insert_children(SNodeType::place)); bls_place_snode_->dt = PrimitiveType::f32; struct_for_snode_ = &(pointer_snode_->dynamic({Axis{2}}, /*n=*/1024, - /*chunk_size=*/128, "")); + /*chunk_size=*/128)); struct_for_place_snode_ = &(struct_for_snode_->insert_children(SNodeType::place)); struct_for_place_snode_->dt = PrimitiveType::i32; diff --git a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp index 8c9db77ec11011..4b36ce79be78d5 100644 --- a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp +++ b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp @@ -35,8 +35,8 @@ class ScalarPointerLowererTest : public ::testing::Test { void SetUp() override { root_snode_ = std::make_unique(/*depth=*/0, /*t=*/SNodeType::root); const std::vector axes = {Axis{0}}; - ptr_snode_ = &(root_snode_->pointer(axes, kPointerSize, "")); - dense_snode_ = &(ptr_snode_->dense(axes, kDenseSize, "")); + ptr_snode_ = &(root_snode_->pointer(axes, kPointerSize)); + dense_snode_ = &(ptr_snode_->dense(axes, kDenseSize)); // Must end with a `place` SNode. 
leaf_snode_ = &(dense_snode_->insert_children(SNodeType::place)); leaf_snode_->dt = PrimitiveType::f32; @@ -106,9 +106,9 @@ TEST(ScalarPointerLowerer, EliminateModDiv) { VecStatement lowered; Stmt *index = builder.get_int32(2); auto root = std::make_unique(/*depth=*/0, SNodeType::root); - SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*size=*/7, "")); - SNode *dense_2 = &(root->dense({Axis{1}}, /*size=*/3, "")); - SNode *dense_3 = &(dense_2->dense({Axis{0}, Axis{1}}, /*size=*/{5, 8}, "")); + SNode *dense_1 = &(root->dense({Axis{2}, Axis{1}}, /*sizes=*/7)); + SNode *dense_2 = &(root->dense({Axis{1}}, /*size=*/3)); + SNode *dense_3 = &(dense_2->dense({Axis{0}, Axis{1}}, /*sizes=*/{5, 8})); SNode *leaf_1 = &(dense_1->insert_children(SNodeType::place)); SNode *leaf_2 = &(dense_3->insert_children(SNodeType::place)); LowererImpl lowerer_1{leaf_1, From 772199ece38647226c8778962ee05c76e9df08ea Mon Sep 17 00:00:00 2001 From: Qian Bao Date: Thu, 20 Jul 2023 12:08:48 +0800 Subject: [PATCH 6/7] [bug] Fix mmwrite format. (#8292) Issue: # ### Brief Summary There is an extra white space in the current `mmwrite()` method, which causes reading errors: ![image](https://github.com/taichi-dev/taichi/assets/2747993/1c7423f8-638e-401b-9c7b-4723e3e8fef0) This PR removes this white space. --- taichi/program/sparse_matrix.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/program/sparse_matrix.cpp b/taichi/program/sparse_matrix.cpp index c86cf4bdcd6901..2432b56ba886ad 100644 --- a/taichi/program/sparse_matrix.cpp +++ b/taichi/program/sparse_matrix.cpp @@ -271,7 +271,7 @@ const std::string EigenSparseMatrix::to_string() const { template void EigenSparseMatrix::mmwrite(const std::string &filename) { std::ofstream file(filename); - file << "%%MatrixMarket matrix coordinate real general\n %" << std::endl; + file << "%%MatrixMarket matrix coordinate real general\n%" << std::endl; file << matrix_.rows() << " " << matrix_.cols() << " " << matrix_.nonZeros() << std::endl; for (int k = 0; k < matrix_.outerSize(); ++k) { From c41eec1b6efad13e74b54bd3a365c10647f1ec06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=94=E6=B3=95=E5=B0=91=E5=A5=B3=E8=B5=B5=E5=BF=97?= =?UTF-8?q?=E8=BE=89?= Date: Fri, 21 Jul 2023 15:35:54 +0800 Subject: [PATCH 7/7] [lang] Passing DebugInfo instead of std::string traceback (#8293) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: # ### Brief Summary ### 🤖 Generated by Copilot at 3304bee Refactor the IR code to use `DebugInfo` objects to store and pass source info. This simplifies the IR construction and debugging, and avoids duplication of the `tb` argument. Modify various files in `python/taichi/lang` and `taichi/ir` to use the new `DebugInfo` constructor and argument. ### Walkthrough ### 🤖 Generated by Copilot at 3304bee * Unify the way source info is stored and passed in the IR by using `DebugInfo` objects instead of `Traceback` objects or raw strings. 
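Concretely, the heart of the change is that `Expression::flatten` implementations no longer patch a traceback string onto freshly created statements after the fact; the expression's `dbg_info` is handed straight to the statement constructor. A condensed, self-contained sketch of that flow (types and names heavily simplified for illustration; the real definitions live in `taichi/ir/ir.h`, `taichi/ir/statements.h`, and `taichi/ir/frontend_ir.cpp`):

```cpp
#include <memory>
#include <string>
#include <utility>

// Trimmed stand-ins for the real taichi/ir types; this sketches the
// plumbing pattern only, not the actual class definitions.
struct DebugInfo {
  std::string tb;  // traceback of the Python statement that built the node
  DebugInfo() = default;
  explicit DebugInfo(std::string tb_in) : tb(std::move(tb_in)) {}
};

struct Stmt {
  DebugInfo dbg_info;
  explicit Stmt(const DebugInfo &di = DebugInfo()) : dbg_info(di) {}
  virtual ~Stmt() = default;
  const std::string &get_tb() const { return dbg_info.tb; }
};

struct BinaryOpStmt : Stmt {
  Stmt *lhs, *rhs;
  BinaryOpStmt(Stmt *l, Stmt *r, const DebugInfo &di = DebugInfo())
      : Stmt(di), lhs(l), rhs(r) {}
};

struct Expression {
  DebugInfo dbg_info;  // captured once on the frontend, then forwarded
};

// Before this patch, flatten() built the statement and then patched the
// traceback on afterwards with stmt->set_tb(get_tb()). Forwarding the
// DebugInfo through the constructor makes it impossible to forget.
std::unique_ptr<Stmt> flatten_binary(const Expression &e, Stmt *lhs, Stmt *rhs) {
  return std::make_unique<BinaryOpStmt>(lhs, rhs, e.dbg_info);
}
```

The `= DebugInfo()` default on every new constructor parameter is what lets the series migrate call sites incrementally: untouched callers keep compiling with an empty location until they are converted.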
---
 python/taichi/lang/any_array.py               |   2 +-
 python/taichi/lang/ast/ast_transformer.py     |   4 +-
 python/taichi/lang/expr.py                    |   2 +-
 python/taichi/lang/impl.py                    |  15 ++-
 python/taichi/lang/mesh.py                    |   4 +-
 python/taichi/lang/ops.py                     |   2 +-
 python/taichi/lang/simt/block.py              |   2 +-
 taichi/ir/expr.cpp                            |   4 +-
 taichi/ir/expr.h                              |   4 +-
 taichi/ir/expression.h                        |   4 +
 taichi/ir/frontend_ir.cpp                     | 117 +++++++++---------
 taichi/ir/frontend_ir.h                       |  18 +--
 taichi/ir/ir.cpp                              |   4 +
 taichi/ir/ir.h                                |   1 +
 taichi/ir/statements.cpp                      |  39 ++++--
 taichi/ir/statements.h                        |  60 ++++++---
 taichi/math/svd.h                             |   5 +-
 taichi/program/program.cpp                    |   2 +-
 taichi/python/export_lang.cpp                 |   4 +-
 taichi/transforms/alg_simp.cpp                |   2 +-
 tests/cpp/ir/frontend_type_inference_test.cpp |   2 +-
 21 files changed, 175 insertions(+), 122 deletions(-)

diff --git a/python/taichi/lang/any_array.py b/python/taichi/lang/any_array.py
index fa552857a932cf..01bffab41b9cba 100644
--- a/python/taichi/lang/any_array.py
+++ b/python/taichi/lang/any_array.py
@@ -88,7 +88,7 @@ def subscript(self, i, j):
             ast_builder.expr_subscript(
                 self.arr.ptr,
                 make_expr_group(*indices),
-                impl.get_runtime().get_current_src_info(),
+                _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
             )
         )
diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index e9e93ca5cb3b7b..b763a9bcf2cc57 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -988,7 +988,7 @@ def build_Attribute(ctx, node):
                 .expr_subscript(
                     node.value.ptr.ptr,
                     make_expr_group(keygroup.index(node.attr)),
-                    impl.get_runtime().get_current_src_info(),
+
_ti_core.DebugInfo(impl.get_runtime().get_current_src_info()), ) ) else: @@ -997,7 +997,7 @@ def build_Attribute(ctx, node): node.value.ptr.ptr, [make_expr_group(keygroup.index(ch)) for ch in node.attr], (attr_len,), - impl.get_runtime().get_current_src_info(), + _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()), ) ) else: diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py index 68f4217bf15678..258bd76f2d2941 100644 --- a/python/taichi/lang/expr.py +++ b/python/taichi/lang/expr.py @@ -40,7 +40,7 @@ def __init__(self, *args, tb=None, dtype=None): else: assert False if self.tb: - self.ptr.set_tb(self.tb) + self.ptr.set_dbg_info(_ti_core.DebugInfo(self.tb)) if not self.ptr_type_checked: self.ptr.type_check(impl.get_runtime().prog.config()) self.ptr_type_checked = True diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index 0f59bae6caba6d..1a6f795943e797 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -97,7 +97,9 @@ def expr_init(rhs): if hasattr(rhs, "_data_oriented"): return rhs return Expr( - get_runtime().compiling_callable.ast_builder().expr_var(Expr(rhs).ptr, get_runtime().get_current_src_info()) + get_runtime() + .compiling_callable.ast_builder() + .expr_var(Expr(rhs).ptr, _ti_core.DebugInfo(get_runtime().get_current_src_info())) ) @@ -175,6 +177,7 @@ def validate_subscript_index(value, index): @taichi_scope def subscript(ast_builder, value, *_indices, skip_reordered=False): + dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info()) ast_builder = get_runtime().compiling_callable.ast_builder() # Directly evaluate in Python for non-Taichi types if not isinstance( @@ -251,14 +254,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): ) if isinstance(value, MatrixField): - return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) + return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info)) if isinstance(value, StructField): entries = {k: subscript(ast_builder, v, *indices) for k, v in value._items} entries["__struct_methods"] = value.struct_methods return _IntermediateStruct(entries) - return Expr(ast_builder.expr_subscript(_var, indices_expr_group, get_runtime().get_current_src_info())) + return Expr(ast_builder.expr_subscript(_var, indices_expr_group, dbg_info)) if isinstance(value, AnyArray): - return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) + return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info)) assert isinstance(value, Expr) # Index into TensorType # value: IndexExpression with ret_type = TensorType @@ -291,10 +294,10 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False): value.ptr, multiple_indices, return_shape, - get_runtime().get_current_src_info(), + dbg_info, ) ) - return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, get_runtime().get_current_src_info())) + return Expr(ast_builder.expr_subscript(value.ptr, indices_expr_group, dbg_info)) class SrcInfoGuard: diff --git a/python/taichi/lang/mesh.py b/python/taichi/lang/mesh.py index b360b64b0a2e3f..8e7724fb5c463b 100644 --- a/python/taichi/lang/mesh.py +++ b/python/taichi/lang/mesh.py @@ -597,7 +597,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr ast_builder.expr_subscript( attr.ptr, global_entry_expr_group, - impl.get_runtime().get_current_src_info(), + 
_ti_core.DebugInfo(impl.get_runtime().get_current_src_info()), ) ), ) @@ -612,7 +612,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr ast_builder.expr_subscript( var, global_entry_expr_group, - impl.get_runtime().get_current_src_info(), + _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()), ) ), ) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 650e42da65b319..49ed0a75046e72 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -1358,7 +1358,7 @@ def atomic_xor(x, y): @writeback_binary def assign(a, b): - impl.get_runtime().compiling_callable.ast_builder().expr_assign(a.ptr, b.ptr, stack_info()) + impl.get_runtime().compiling_callable.ast_builder().expr_assign(a.ptr, b.ptr, _ti_core.DebugInfo(stack_info())) return a diff --git a/python/taichi/lang/simt/block.py b/python/taichi/lang/simt/block.py index 80296a14c1414f..aef3544df31ead 100644 --- a/python/taichi/lang/simt/block.py +++ b/python/taichi/lang/simt/block.py @@ -66,6 +66,6 @@ def subscript(self, *indices): ast_builder.expr_subscript( self.shared_array_proxy, make_expr_group(*indices), - impl.get_runtime().get_current_src_info(), + _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()), ) ) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index e81eaf776ffbe4..0e4ad2708a4af3 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -6,8 +6,8 @@ namespace taichi::lang { -void Expr::set_tb(const std::string &tb) { - expr->set_tb(tb); +void Expr::set_dbg_info(const DebugInfo &dbg_info) { + expr->dbg_info = dbg_info; } const std::string &Expr::get_tb() const { diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index 7348ca9bc263b0..a9e63772962a43 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -92,8 +92,8 @@ class Expr { SNode *snode() const; - // traceback for type checking error message - void set_tb(const std::string &tb); + // debug info, contains traceback for type checking error message + void set_dbg_info(const DebugInfo &dbg_info); const std::string &get_tb() const; diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index 2eca1890b1501a..1a9c6642de345a 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -41,6 +41,10 @@ class Expression { stmt = nullptr; } + explicit Expression(const DebugInfo &dbg_info) : Expression() { + this->dbg_info = dbg_info; + } + virtual void type_check(const CompileConfig *config) = 0; virtual void accept(ExpressionVisitor *visitor) = 0; diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 95a1df3127c045..249cd57ae7b464 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -38,8 +38,10 @@ FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, FrontendReturnStmt::FrontendReturnStmt(const ExprGroup &group) : values(group) { } -FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) - : lhs(lhs), rhs(rhs) { +FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, + const Expr &rhs, + const DebugInfo &dbg_info) + : Stmt(dbg_info), lhs(lhs), rhs(rhs) { TI_ASSERT(lhs->is_lvalue()); if (lhs.is() && lhs->ret_type == PrimitiveType::unknown) { lhs.expr->ret_type = @@ -278,12 +280,11 @@ bool UnaryOpExpression::is_cast() const { void UnaryOpExpression::flatten(FlattenContext *ctx) { auto operand_stmt = flatten_rvalue(operand, ctx); - auto unary = std::make_unique(type, operand_stmt); + auto unary = std::make_unique(type, operand_stmt, dbg_info); if (is_cast()) { unary->cast_type = cast_type; } stmt = unary.get(); - 
stmt->set_tb(get_tb()); stmt->ret_type = ret_type; ctx->push_back(std::move(unary)); } @@ -455,15 +456,14 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { } if_stmt->set_false_statements(std::move(false_block)); - auto ret = ctx->push_back(result); - ret->set_tb(get_tb()); + auto ret = ctx->push_back(result, dbg_info); stmt = ret; stmt->ret_type = ret_type; return; } auto rhs_stmt = flatten_rvalue(rhs, ctx); - ctx->push_back(std::make_unique(type, lhs_stmt, rhs_stmt)); - ctx->stmts.back()->set_tb(get_tb()); + ctx->push_back( + std::make_unique(type, lhs_stmt, rhs_stmt, dbg_info)); stmt = ctx->back_stmt(); stmt->ret_type = ret_type; } @@ -601,13 +601,12 @@ void TernaryOpExpression::flatten(FlattenContext *ctx) { auto op1_stmt = flatten_rvalue(op1, ctx); auto op2_stmt = flatten_rvalue(op2, ctx); auto op3_stmt = flatten_rvalue(op3, ctx); - ctx->push_back( - std::make_unique(type, op1_stmt, op2_stmt, op3_stmt)); + ctx->push_back(std::make_unique(type, op1_stmt, op2_stmt, + op3_stmt, dbg_info)); } else if (type == TernaryOpType::ifte) { make_ifte(ctx, ret_type, op1, op2, op3); } stmt = ctx->back_stmt(); - stmt->set_tb(get_tb()); stmt->ret_type = ret_type; } @@ -622,7 +621,7 @@ void InternalFuncCallExpression::type_check(const CompileConfig *) { void InternalFuncCallExpression::flatten(FlattenContext *ctx) { stmt = op->flatten(ctx, args, ret_type); - stmt->set_tb(get_tb()); + stmt->dbg_info = dbg_info; } void ExternalTensorExpression::flatten(FlattenContext *ctx) { @@ -630,11 +629,10 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) { TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); type = TypeFactory::get_instance().get_pointer_type((Type *)type); - auto ptr = - Stmt::make(arg_id, type, /*is_ptr=*/true, - /*create_load=*/false, /*arg_depth=*/arg_depth); + auto ptr = Stmt::make( + arg_id, type, /*is_ptr=*/true, + /*create_load=*/false, /*arg_depth=*/arg_depth, /*dbg_info=*/dbg_info); - ptr->set_tb(get_tb()); ctx->push_back(std::move(ptr)); stmt = ctx->back_stmt(); } @@ -706,7 +704,7 @@ Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, Stmt *var_stmt, const ExprGroup &indices, const std::vector &shape, - const std::string &tb) { + const DebugInfo &dbg_info) { bool needs_dynamic_index = false; for (int i = 0; i < (int)indices.size(); ++i) { if (!indices[i].is()) { @@ -732,7 +730,7 @@ Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, } offset_stmt = ctx->push_back(TypedConstant(offset)); } - return ctx->push_back(var_stmt, offset_stmt, tb); + return ctx->push_back(var_stmt, offset_stmt, dbg_info); } Stmt *make_tensor_access(Expression::FlattenContext *ctx, @@ -740,7 +738,7 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, const std::vector &indices_group, DataType ret_type, std::vector shape, - const std::string &tb) { + const DebugInfo &dbg_info) { auto var_stmt = flatten_lvalue(var, ctx); if (!var->is_lvalue()) { auto alloca_stmt = ctx->push_back(var.get_rvalue_type()); @@ -754,13 +752,13 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, if (ret_type.ptr_removed()->is() && !is_shared_array) { std::vector stmts; for (auto &indices : indices_group) { - stmts.push_back( - make_tensor_access_single_element(ctx, var_stmt, indices, shape, tb)); + stmts.push_back(make_tensor_access_single_element(ctx, var_stmt, indices, + shape, dbg_info)); } return ctx->push_back(stmts, ret_type); } return make_tensor_access_single_element(ctx, var_stmt, indices_group[0], - shape, tb); + shape, 
dbg_info); } void MatrixExpression::type_check(const CompileConfig *config) { @@ -789,22 +787,23 @@ void MatrixExpression::flatten(FlattenContext *ctx) { IndexExpression::IndexExpression(const Expr &var, const ExprGroup &indices, - std::string tb) - : var(var), indices_group({indices}) { - this->set_tb(tb); + const DebugInfo &dbg_info) + : Expression(dbg_info), var(var), indices_group({indices}) { } IndexExpression::IndexExpression(const Expr &var, const std::vector &indices_group, const std::vector &ret_shape, - std::string tb) - : var(var), indices_group(indices_group), ret_shape(ret_shape) { + const DebugInfo &dbg_info) + : Expression(dbg_info), + var(var), + indices_group(indices_group), + ret_shape(ret_shape) { // IndexExpression with ret_shape is used for matrix slicing, where each entry // of ExprGroup is interpreted as a group of indices to return within each // axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where [3, 4] // means "m"-axis will return a TensorType with size of 2. In this case, we // should not expand indices_group due to its special semantics. - this->set_tb(tb); } bool IndexExpression::is_field() const { @@ -945,14 +944,14 @@ void IndexExpression::flatten(FlattenContext *ctx) { } else if (is_tensor()) { stmt = make_tensor_access( ctx, var, indices_group, ret_type, - var->ret_type.ptr_removed()->as()->get_shape(), get_tb()); + var->ret_type.ptr_removed()->as()->get_shape(), dbg_info); } else { ErrorEmitter( TaichiIndexError(), this, "Invalid IndexExpression: the source is not among field, ndarray or " "local tensor"); } - stmt->set_tb(get_tb()); + stmt->dbg_info = dbg_info; } void RangeAssumptionExpression::type_check(const CompileConfig *) { @@ -1072,9 +1071,8 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) { // expand rhs auto val_stmt = flatten_rvalue(val, ctx); auto dest_stmt = flatten_lvalue(dest, ctx); - stmt = ctx->push_back(op_type, dest_stmt, val_stmt); + stmt = ctx->push_back(op_type, dest_stmt, val_stmt, dbg_info); stmt->ret_type = stmt->as()->dest->ret_type; - stmt->set_tb(get_tb()); } SNodeOpExpression::SNodeOpExpression(SNode *snode, @@ -1107,10 +1105,10 @@ void SNodeOpExpression::type_check(const CompileConfig *config) { auto value_type = values[i].get_rvalue_type(); auto promoted = promoted_type(dst_type, value_type); if (dst_type != promoted) { - ErrorEmitter(TaichiCastWarning(), this, - fmt::format("Append may lose precision: {} <- {}\n{}", - dst_type->to_string(), value_type->to_string(), - get_tb())); + ErrorEmitter( + TaichiCastWarning(), this, + fmt::format("Append may lose precision: {} <- {}", + dst_type->to_string(), value_type->to_string())); } values[i] = cast(values[i], dst_type); values[i]->type_check(config); @@ -1125,9 +1123,8 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { } auto is_cell_access = SNodeOpStmt::activation_related(op_type) && snode->type != SNodeType::dynamic; - auto ptr = - ctx->push_back(snode, indices_stmt, true, is_cell_access); - ptr->set_tb(get_tb()); + auto ptr = ctx->push_back(snode, indices_stmt, true, + is_cell_access, dbg_info); if (op_type == SNodeOpType::is_active) { if (!(snode->type == SNodeType::pointer || snode->type == SNodeType::hash || snode->type == SNodeType::bitmasked)) { @@ -1141,18 +1138,16 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { } else if (op_type == SNodeOpType::get_addr) { ctx->push_back(SNodeOpType::get_addr, snode, ptr, nullptr); } else if (op_type == SNodeOpType::append) { - auto alloca = ctx->push_back(PrimitiveType::i32); - 
alloca->set_tb(get_tb()); - auto addr = - ctx->push_back(SNodeOpType::allocate, snode, ptr, alloca); - addr->set_tb(get_tb()); + auto alloca = ctx->push_back(PrimitiveType::i32, dbg_info); + auto addr = ctx->push_back(SNodeOpType::allocate, snode, ptr, + alloca, dbg_info); for (int i = 0; i < values.size(); i++) { auto value_stmt = flatten_rvalue(values[i], ctx); - auto ch_addr = ctx->push_back(addr, snode, i); - ch_addr->set_tb(get_tb()); - ctx->push_back(ch_addr, value_stmt)->set_tb(get_tb()); + auto ch_addr = ctx->push_back( + addr, snode, i, /*is_bit_vectorized = */ false, dbg_info); + ctx->push_back(ch_addr, value_stmt, dbg_info); } - ctx->push_back(alloca)->set_tb(get_tb()); + ctx->push_back(alloca, dbg_info); if (snode->type != SNodeType::dynamic) { ErrorEmitter(TaichiTypeError(), this, "ti.append only works on dynamic nodes."); @@ -1410,14 +1405,13 @@ void ASTBuilder::stop_gradient(SNode *snode) { void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs, - const std::string &tb) { + const DebugInfo &dbg_info) { // Inside a kernel or a function // Create an assignment in the IR if (lhs.expr == nullptr) { lhs.set(rhs); } else if (lhs.expr->is_lvalue()) { - auto stmt = std::make_unique(lhs, rhs); - stmt->set_tb(tb); + auto stmt = std::make_unique(lhs, rhs, dbg_info); this->insert(std::move(stmt)); } else { @@ -1428,9 +1422,9 @@ void ASTBuilder::insert_assignment(Expr &lhs, } } -Expr ASTBuilder::make_var(const Expr &x, std::string tb) { +Expr ASTBuilder::make_var(const Expr &x, const DebugInfo &dbg_info) { auto var = this->expr_alloca(); - this->insert_assignment(var, x, tb); + this->insert_assignment(var, x, dbg_info); return var; } @@ -1588,16 +1582,17 @@ Expr ASTBuilder::expr_alloca_shared_array(const std::vector &shape, return var; } -void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) { +void ASTBuilder::expr_assign(const Expr &lhs, + const Expr &rhs, + const DebugInfo &dbg_info) { TI_ASSERT(lhs->is_lvalue()); - auto stmt = std::make_unique(lhs, rhs); - stmt->set_tb(tb); + auto stmt = std::make_unique(lhs, rhs, dbg_info); this->insert(std::move(stmt)); } Expr ASTBuilder::expr_subscript(const Expr &expr, const ExprGroup &indices, - std::string tb) { + const DebugInfo &dbg_info) { TI_ASSERT(expr.is() || expr.is() || expr.is() || is_tensor(expr.expr->ret_type.ptr_removed())); @@ -1611,7 +1606,7 @@ Expr ASTBuilder::expr_subscript(const Expr &expr, auto expanded_expr_group = ExprGroup(); expanded_expr_group.exprs = expanded_indices; - return Expr::make(expr, expanded_expr_group, tb); + return Expr::make(expr, expanded_expr_group, dbg_info); } void ASTBuilder::create_assert_stmt(const Expr &cond, @@ -1783,13 +1778,13 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { if (expr.is()) { id_expr = expr; } else { - id_expr = make_var(expr, expr.get_tb()); + id_expr = make_var(expr, expr->dbg_info); } auto shape = tensor_type->get_shape(); if (shape.size() == 1) { for (int i = 0; i < shape[0]; i++) { auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i)), expr.get_tb())); + id_expr, ExprGroup(Expr(i)), expr->dbg_info)); ind->type_check(nullptr); expanded_exprs.push_back(ind); } @@ -1798,7 +1793,7 @@ std::vector ASTBuilder::expand_exprs(const std::vector &exprs) { for (int i = 0; i < shape[0]; i++) { for (int j = 0; j < shape[1]; j++) { auto ind = Expr(std::make_shared( - id_expr, ExprGroup(Expr(i), Expr(j)), expr.get_tb())); + id_expr, ExprGroup(Expr(i), Expr(j)), expr->dbg_info)); ind->type_check(nullptr); 
expanded_exprs.push_back(ind); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 1fc97612724a49..d1ed7a1689c807 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -140,7 +140,9 @@ class FrontendAssignStmt : public Stmt { public: Expr lhs, rhs; - FrontendAssignStmt(const Expr &lhs, const Expr &rhs); + FrontendAssignStmt(const Expr &lhs, + const Expr &rhs, + const DebugInfo &dbg_info = DebugInfo()); TI_DEFINE_ACCEPT TI_DEFINE_CLONE_FOR_FRONTEND_IR @@ -639,12 +641,12 @@ class IndexExpression : public Expression { IndexExpression(const Expr &var, const ExprGroup &indices, - std::string tb = ""); + const DebugInfo &dbg_info = DebugInfo()); IndexExpression(const Expr &var, const std::vector &indices_group, const std::vector &ret_shape, - std::string tb = ""); + const DebugInfo &dbg_info = DebugInfo()); void type_check(const CompileConfig *config) override; @@ -993,8 +995,8 @@ class ASTBuilder { void stop_gradient(SNode *); void insert_assignment(Expr &lhs, const Expr &rhs, - const std::string &tb = ""); - Expr make_var(const Expr &x, std::string tb); + const DebugInfo &dbg_info = DebugInfo()); + Expr make_var(const Expr &x, const DebugInfo &dbg_info = DebugInfo()); void insert_for(const Expr &s, const Expr &e, const std::function &func); @@ -1024,14 +1026,16 @@ class ASTBuilder { const DataType &element_type); Expr expr_subscript(const Expr &expr, const ExprGroup &indices, - std::string tb = ""); + const DebugInfo &dbg_info = DebugInfo()); Expr mesh_index_conversion(mesh::MeshPtr mesh_ptr, mesh::MeshElementType idx_type, const Expr &idx, mesh::ConvType &conv_type); - void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); + void expr_assign(const Expr &lhs, + const Expr &rhs, + const DebugInfo &dbg_info = DebugInfo()); std::optional insert_func_call(Function *func, const ExprGroup &args); void create_assert_stmt(const Expr &cond, const std::string &msg, diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 969a5766e0a679..cd9924e827caa4 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -124,6 +124,10 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { ret_type = stmt.ret_type; } +Stmt::Stmt(const DebugInfo &dbg_info) : Stmt() { + this->dbg_info = dbg_info; +} + Callable *Stmt::get_callable() const { Block *parent_block = parent; if (parent_block->parent_callable()) { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 9ac7ab785192c9..a62b00d8347627 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -390,6 +390,7 @@ class StmtFieldManager { class Stmt : public IRNode { protected: std::vector operands; + explicit Stmt(const DebugInfo &dbg_info); public: StmtFieldManager field_manager; diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index ddf9fa494b298f..97b5cf16dd66ae 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -6,8 +6,10 @@ namespace taichi::lang { -UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand) - : op_type(op_type), operand(operand) { +UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, + Stmt *operand, + const DebugInfo &dbg_info) + : Stmt(dbg_info), op_type(op_type), operand(operand) { TI_ASSERT(!operand->is()); cast_type = PrimitiveType::unknown; TI_STMT_REG_FIELDS; @@ -62,8 +64,10 @@ ExternalPtrStmt::ExternalPtrStmt(Stmt *base_ptr, GlobalPtrStmt::GlobalPtrStmt(SNode *snode, const std::vector &indices, bool activate, - bool is_cell_access) - : snode(snode), + bool is_cell_access, + const DebugInfo &dbg_info) + : Stmt(dbg_info), + snode(snode), 
indices(indices), activate(activate), is_cell_access(is_cell_access), @@ -98,10 +102,10 @@ MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector &stmts, MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, Stmt *offset_input, - const std::string &tb) { + const DebugInfo &dbg_info) { origin = origin_input; offset = offset_input; - this->set_tb(tb); + this->dbg_info = dbg_info; if (origin->is() || origin->is() || origin->is() || origin->is() || @@ -136,8 +140,9 @@ bool MatrixPtrStmt::common_statement_eliminable() const { SNodeOpStmt::SNodeOpStmt(SNodeOpType op_type, SNode *snode, Stmt *ptr, - Stmt *val) - : op_type(op_type), snode(snode), ptr(ptr), val(val) { + Stmt *val, + const DebugInfo &dbg_info) + : Stmt(dbg_info), op_type(op_type), snode(snode), ptr(ptr), val(val) { element_type() = PrimitiveType::i32; TI_STMT_REG_FIELDS; } @@ -314,8 +319,14 @@ std::unique_ptr WhileStmt::clone() const { return new_stmt; } -GetChStmt::GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized) - : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) { +GetChStmt::GetChStmt(Stmt *input_ptr, + int chid, + bool is_bit_vectorized, + const DebugInfo &dbg_info) + : Stmt(dbg_info), + input_ptr(input_ptr), + chid(chid), + is_bit_vectorized(is_bit_vectorized) { TI_ASSERT(input_ptr->is()); input_snode = input_ptr->as()->snode; output_snode = input_snode->ch[chid].get(); @@ -325,8 +336,12 @@ GetChStmt::GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized) GetChStmt::GetChStmt(Stmt *input_ptr, SNode *snode, int chid, - bool is_bit_vectorized) - : input_ptr(input_ptr), chid(chid), is_bit_vectorized(is_bit_vectorized) { + bool is_bit_vectorized, + const DebugInfo &dbg_info) + : Stmt(dbg_info), + input_ptr(input_ptr), + chid(chid), + is_bit_vectorized(is_bit_vectorized) { input_snode = snode; output_snode = input_snode->ch[chid].get(); TI_STMT_REG_FIELDS; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index c419e8e89c582c..7414824a0f8956 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -18,7 +18,8 @@ class Function; */ class AllocaStmt : public Stmt, public ir_traits::Store { public: - explicit AllocaStmt(DataType type) : is_shared(false) { + explicit AllocaStmt(DataType type, const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), is_shared(false) { if (type->is_primitive(PrimitiveTypeID::unknown)) { ret_type = type; } else { @@ -156,7 +157,9 @@ class UnaryOpStmt : public Stmt { Stmt *operand; DataType cast_type; - UnaryOpStmt(UnaryOpType op_type, Stmt *operand); + UnaryOpStmt(UnaryOpType op_type, + Stmt *operand, + const DebugInfo &dbg_info = DebugInfo()); bool same_operation(UnaryOpStmt *o) const; bool is_cast() const; @@ -200,8 +203,10 @@ class ArgLoadStmt : public Stmt { const DataType &dt, bool is_ptr, bool create_load, - int arg_depth) - : arg_id(arg_id), + int arg_depth, + const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), + arg_id(arg_id), is_ptr(is_ptr), create_load(create_load), arg_depth(arg_depth) { @@ -257,8 +262,10 @@ class BinaryOpStmt : public Stmt { BinaryOpStmt(BinaryOpType op_type, Stmt *lhs, Stmt *rhs, + const DebugInfo &dbg_info = DebugInfo(), bool is_bit_vectorized = false) - : op_type(op_type), + : Stmt(dbg_info), + op_type(op_type), lhs(lhs), rhs(rhs), is_bit_vectorized(is_bit_vectorized) { @@ -284,8 +291,12 @@ class TernaryOpStmt : public Stmt { TernaryOpType op_type; Stmt *op1, *op2, *op3; - TernaryOpStmt(TernaryOpType op_type, Stmt *op1, Stmt *op2, Stmt *op3) - : op_type(op_type), op1(op1), op2(op2), op3(op3) { + 
TernaryOpStmt(TernaryOpType op_type, + Stmt *op1, + Stmt *op2, + Stmt *op3, + const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), op_type(op_type), op1(op1), op2(op2), op3(op3) { TI_ASSERT(!op1->is()); TI_ASSERT(!op2->is()); TI_ASSERT(!op3->is()); @@ -311,8 +322,15 @@ class AtomicOpStmt : public Stmt, Stmt *dest, *val; bool is_reduction; - AtomicOpStmt(AtomicOpType op_type, Stmt *dest, Stmt *val) - : op_type(op_type), dest(dest), val(val), is_reduction(false) { + AtomicOpStmt(AtomicOpType op_type, + Stmt *dest, + Stmt *val, + const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), + op_type(op_type), + dest(dest), + val(val), + is_reduction(false) { TI_STMT_REG_FIELDS; } @@ -405,7 +423,8 @@ class GlobalPtrStmt : public Stmt { GlobalPtrStmt(SNode *snode, const std::vector &indices, bool activate = true, - bool is_cell_access = false); + bool is_cell_access = false, + const DebugInfo &dbg_info = DebugInfo()); bool has_global_side_effect() const override { return activate; @@ -486,7 +505,7 @@ class MatrixPtrStmt : public Stmt { Stmt *origin{nullptr}; Stmt *offset{nullptr}; - MatrixPtrStmt(Stmt *, Stmt *, const std::string & = ""); + MatrixPtrStmt(Stmt *, Stmt *, const DebugInfo & = DebugInfo()); /* TODO(zhanlue/yi): Unify semantics of offset in MatrixPtrStmt @@ -544,7 +563,8 @@ class SNodeOpStmt : public Stmt, public ir_traits::Store { SNodeOpStmt(SNodeOpType op_type, SNode *snode, Stmt *ptr, - Stmt *val = nullptr); + Stmt *val = nullptr, + const DebugInfo &dbg_info = DebugInfo()); static bool activation_related(SNodeOpType op); @@ -775,7 +795,10 @@ class GlobalStoreStmt : public Stmt, public ir_traits::Store { Stmt *dest; Stmt *val; - GlobalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { + GlobalStoreStmt(Stmt *dest, + Stmt *val, + const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), dest(dest), val(val) { TI_STMT_REG_FIELDS; } @@ -803,7 +826,8 @@ class LocalLoadStmt : public Stmt, public ir_traits::Load { public: Stmt *src; - explicit LocalLoadStmt(Stmt *src) : src(src) { + explicit LocalLoadStmt(Stmt *src, const DebugInfo &dbg_info = DebugInfo()) + : Stmt(dbg_info), src(src) { TI_STMT_REG_FIELDS; } @@ -1330,11 +1354,15 @@ class GetChStmt : public Stmt { // irpass::type_check() bool overrided_dtype = false; - GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized = false); + GetChStmt(Stmt *input_ptr, + int chid, + bool is_bit_vectorized = false, + const DebugInfo &dbg_info = DebugInfo()); GetChStmt(Stmt *input_ptr, SNode *snode, int chid, - bool is_bit_vectorized = false); + bool is_bit_vectorized = false, + const DebugInfo &dbg_info = DebugInfo()); bool has_global_side_effect() const override { return false; diff --git a/taichi/math/svd.h b/taichi/math/svd.h index 898d0306601922..7f87f514bff934 100644 --- a/taichi/math/svd.h +++ b/taichi/math/svd.h @@ -61,9 +61,8 @@ sifakis_svd_export(ASTBuilder *ast_builder, const Expr &mat, int num_iters) { constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f; constexpr Tf Cosine_Pi_Over_Eight = 0.9238795325112867f; - std::string tb = ""; - auto Var = [ast_builder, tb](const taichi::lang::Expr &x) { - return ast_builder->make_var(x, tb); + auto Var = [ast_builder](const taichi::lang::Expr &x) { + return ast_builder->make_var(x); }; auto Sfour_gamma_squared = Var(Expr(Tf(0.0))); diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 153103ea1f1e91..52607b5b8ed04a 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -314,7 +314,7 @@ Kernel &Program::get_snode_writer(SNode 
*snode) { std::vector{snode->num_active_indices}, snode->dt->get_compute_type()); argload_expr->type_check(&this->compile_config()); - builder.insert_assignment(expr, argload_expr, expr->get_tb()); + builder.insert_assignment(expr, argload_expr, expr->dbg_info); }); ker.name = kernel_name; ker.is_accessor = true; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index cdd46c7e36e887..afa4824315d7ff 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -828,7 +828,7 @@ void export_lang(py::module &m) { SNodeGradType::kPrimal; }) .def("is_lvalue", [](Expr *expr) { return expr->expr->is_lvalue(); }) - .def("set_tb", &Expr::set_tb) + .def("set_dbg_info", &Expr::set_dbg_info) .def("set_name", [&](Expr *expr, std::string na) { expr->cast()->name = na; @@ -1083,7 +1083,7 @@ void export_lang(py::module &m) { m.def( "subscript_with_multiple_indices", Expr::make &, - const std::vector &, std::string>); + const std::vector &, const DebugInfo &>); m.def("get_external_tensor_element_dim", [](const Expr &expr) { TI_ASSERT(expr.is()); diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index b1e647f1ce5aec..2594767e44b3aa 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -382,7 +382,7 @@ class AlgSimp : public BasicStmtVisitor { Stmt::make(BinaryOpType::bit_shl, stmt->lhs, new_rhs); result->ret_type = stmt->ret_type; - result->set_tb(stmt->get_tb()); + result->dbg_info = stmt->dbg_info; stmt->replace_usages_with(result.get()); modifier.insert_before(stmt, std::move(result)); modifier.erase(stmt); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index 4c7d7e75b297c7..e62129db347795 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -33,7 +33,7 @@ TEST(FrontendTypeInference, Id) { auto const_i32 = value(-(1 << 20)); const_i32->type_check(nullptr); auto id_i32 = - kernel->context->builder().make_var(const_i32, const_i32->get_tb()); + kernel->context->builder().make_var(const_i32, const_i32->dbg_info); EXPECT_EQ(id_i32->ret_type, DataType(TypeFactory::get_instance().get_pointer_type( PrimitiveType::i32)));
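As a closing illustration of why the plumbing matters: once every `Stmt` carries its `DebugInfo`, a checking pass can attach the offending Python line to a diagnostic without threading an extra string parameter through every call. A hedged sketch under the same simplified types as the commit-message example above; the `report_type_error` helper is invented for illustration, while the real code routes diagnostics through the `ErrorEmitter` introduced in the first patch of this series, which also knows how to emit warnings instead of throwing:

```cpp
#include <stdexcept>
#include <string>

// Minimal stand-ins, assuming the same simplified shapes as the earlier sketch.
struct DebugInfo { std::string tb; };
struct Stmt { DebugInfo dbg_info; };

// Hypothetical helper for illustration only: a pass can now cite the
// originating Python line using nothing but the statement itself.
[[noreturn]] void report_type_error(const Stmt &stmt, const std::string &msg) {
  throw std::runtime_error(msg + "\n" + stmt.dbg_info.tb);
}
```

A transform that replaces a statement must likewise copy the location across, as the `alg_simp.cpp` hunk above does with `result->dbg_info = stmt->dbg_info;`, otherwise the reported source location silently degrades to an empty traceback.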