diff --git a/requirements_test.txt b/requirements_test.txt index 80686f769a146..2dd9ca670e20f 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -7,7 +7,7 @@ numpy psutil autograd requests -matplotlib +matplotlib<=3.7.3 cffi scipy setproctitle diff --git a/taichi/analysis/gather_dynamically_indexed_pointers.cpp b/taichi/analysis/gather_dynamically_indexed_pointers.cpp new file mode 100644 index 0000000000000..88f92dbe7fba0 --- /dev/null +++ b/taichi/analysis/gather_dynamically_indexed_pointers.cpp @@ -0,0 +1,124 @@ +#include "taichi/analysis/gather_uniquely_accessed_pointers.h" +#include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/visitors.h" +#include + +namespace taichi::lang { + +bool is_leaf_nodes_on_same_branch(SNode *snode0, SNode *snode1) { + // Verify: place snode + if (!snode0->is_place() || !snode1->is_place()) { + return false; + } + + // Check parent snode + if (snode0->parent != snode1->parent) { + return false; + } + + return true; +} + +class DynamicIndexingAnalyzer : public BasicStmtVisitor { + void record_dynamic_indexed_ptr(ExternalPtrStmt *extern_ptr) { + dynamically_indexed_ptrs_.insert(extern_ptr); + // Find aliased ExternPtrStmt + for (auto *other_extern_ptr : extern_ptrs_) { + if (other_extern_ptr != extern_ptr && + other_extern_ptr->base_ptr == extern_ptr->base_ptr) { + // Aliased ExternalPtrStmt, with same base_ptr and outter index + dynamically_indexed_ptrs_.insert(other_extern_ptr); + } + } + } + + void record_dynamic_indexed_ptr(GlobalPtrStmt *global_ptr) { + dynamically_indexed_ptrs_.insert(global_ptr); + // Find aliased GlobalPtrStmt + for (auto *other_global_ptr : global_ptrs_) { + if (other_global_ptr != global_ptr && + is_leaf_nodes_on_same_branch(other_global_ptr->snode, + global_ptr->snode)) { + dynamically_indexed_ptrs_.insert(other_global_ptr); + } + } + } + + public: + explicit DynamicIndexingAnalyzer(IRNode *node) { + } + + void visit(GlobalPtrStmt *stmt) override { + for (auto *index_stmt : stmt->indices) { + if (!index_stmt->is() && !index_stmt->is()) { + record_dynamic_indexed_ptr(stmt); + } + } + + global_ptrs_.insert(stmt); + } + + void visit(ExternalPtrStmt *stmt) override { + for (auto *index_stmt : stmt->indices) { + if (!index_stmt->is() && !index_stmt->is()) { + record_dynamic_indexed_ptr(stmt); + } + } + + extern_ptrs_.insert(stmt); + } + + void visit(MatrixPtrStmt *stmt) override { + GlobalPtrStmt *global_ptr = nullptr; + ExternalPtrStmt *extern_ptr = nullptr; + + if (stmt->origin->is()) { + global_ptr = stmt->origin->as(); + } else if (stmt->origin->is()) { + extern_ptr = stmt->origin->as(); + } else { + return; + } + + // Is dynamic index + if (stmt->offset->is()) { + return; + } + + if (global_ptr) { + record_dynamic_indexed_ptr(global_ptr); + } + + if (extern_ptr) { + record_dynamic_indexed_ptr(extern_ptr); + } + } + + std::unordered_set get_dynamically_indexed_ptrs() { + return dynamically_indexed_ptrs_; + } + + private: + using BasicStmtVisitor::visit; + std::unordered_set dynamically_indexed_ptrs_; + std::unordered_set global_ptrs_; + std::unordered_set extern_ptrs_; +}; + +namespace irpass::analysis { + +std::unordered_set gather_dynamically_indexed_pointers(IRNode *root) { + DynamicIndexingAnalyzer pass(root); + + // This pass is intended to run twice + root->accept(&pass); + root->accept(&pass); + + auto dynamically_indexed_ptrs = pass.get_dynamically_indexed_ptrs(); + return dynamically_indexed_ptrs; +} + +} // namespace irpass::analysis +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 7fd6c69d06ba7..df63a2d635d0f 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -111,6 +111,7 @@ std::tuple, hashing::Hasher>>, std::unordered_set> gather_uniquely_accessed_pointers(IRNode *root); +std::unordered_set gather_dynamically_indexed_pointers(IRNode *root); std::unique_ptr> gather_used_atomics( IRNode *root); diff --git a/taichi/transforms/cache_loop_invariant_global_vars.cpp b/taichi/transforms/cache_loop_invariant_global_vars.cpp index 5020dc620bced..6817747986c1f 100644 --- a/taichi/transforms/cache_loop_invariant_global_vars.cpp +++ b/taichi/transforms/cache_loop_invariant_global_vars.cpp @@ -26,6 +26,8 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { loop_unique_arr_ptr_; std::unordered_set loop_unique_matrix_ptr_; + std::unordered_set dynamic_indexed_ptrs_; + OffloadedStmt *current_offloaded; explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) @@ -44,6 +46,9 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { std::move(std::get<2>(uniquely_accessed_pointers)); } current_offloaded = stmt; + dynamic_indexed_ptrs_ = + irpass::analysis::gather_dynamically_indexed_pointers(stmt); + // We don't need to visit TLS/BLS prologues/epilogues. if (stmt->body) { if (stmt->task_type == OffloadedStmt::TaskType::range_for || @@ -56,6 +61,28 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { current_offloaded = nullptr; } + bool is_dynamically_indexed(Stmt *stmt) { + // Handle GlobalPtrStmt + Stmt *ptr_stmt = nullptr; + if (stmt->is()) { + ptr_stmt = stmt->as(); + } else if (stmt->is() && + stmt->as()->origin->is()) { + ptr_stmt = stmt->as()->origin->as(); + } else if (stmt->is()) { + ptr_stmt = stmt->as(); + } else if (stmt->is() && + stmt->as()->origin->is()) { + ptr_stmt = stmt->as()->origin->as(); + } + + if (ptr_stmt && dynamic_indexed_ptrs_.count(ptr_stmt)) { + return true; + } + + return false; + } + bool is_offload_unique(Stmt *stmt) { if (current_offloaded->task_type == OffloadedTaskType::serial) { return true; @@ -174,6 +201,9 @@ class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { std::optional find_cache_depth_if_cacheable(Stmt *operand, Block *current_scope) { + if (is_dynamically_indexed(operand)) { + return std::nullopt; + } if (!is_offload_unique(operand)) { return std::nullopt; } diff --git a/tests/python/test_cache_loop_invariant.py b/tests/python/test_cache_loop_invariant.py new file mode 100644 index 0000000000000..69c9af4e3c7e2 --- /dev/null +++ b/tests/python/test_cache_loop_invariant.py @@ -0,0 +1,27 @@ +import pytest +from taichi.lang import impl + +import taichi as ti +from tests import test_utils + + +@test_utils.test(arch=[ti.cuda, ti.cpu]) +def test_local_matrix_non_constant_index_real_matrix(): + N = 1 + x = ti.Vector.field(3, float, shape=1) + + @ti.kernel + def test_invariant_cache(): + for i in range(1): + x[i][1] = x[i][1] + 1.0 + for j in range(1): + x[i][1] = x[i][1] - 5.0 + for z in range(1): + idx = 0 + if z == 0: + idx = 1 + x_print = x[i][idx] + + assert x_print == x[i][1] + + test_invariant_cache()