diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e94b966bc0fc..017078bd7bf7 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -463,6 +463,15 @@ TVM_DLL Pass UnifyThreadBinding(); */ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); +/*! + * \brief This pass is post-scheduling pass to convert all + * Parallel For loops to Serial ones. This is run + * to attain lesser memory and/or executor/backend + * does not support parallel launch of For loops. + * \return The pass. + */ +TVM_DLL Pass ConvertForLoopsToSerial(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f072f6b38a43..1abba77a801f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -715,3 +715,14 @@ def MergeDynamicSharedMemoryAllocations(): The result pass """ return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore + + +def ConvertForLoopsToSerial(): + """Convert Parallel For Loops to Serial For Loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertForLoopsToSerial() # type: ignore diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc new file mode 100644 index 000000000000..d01ae8a45113 --- /dev/null +++ b/src/tir/transforms/convert_for_loops_serial.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/transforms/convert_for_loops_serial.cc + * \brief Convert all for loops to serial for lesser memory consumption + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class ForLoopSerialConverter : public StmtExprMutator { + public: + ForLoopSerialConverter() = default; + Stmt operator()(const PrimFunc& func); + + private: + Stmt VisitStmt_(const ForNode* op) override; +}; + +Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { + if (op->kind == ForKind::kParallel) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, + op->annotations, op->span); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { + return this->VisitStmt(func->body); +} + +PrimFunc ConvertForLoopsToSerial(PrimFunc func) { + PrimFuncNode* fptr = func.CopyOnWrite(); + fptr->body = ForLoopSerialConverter()(func); + return func; +} + +namespace transform { + +Pass ConvertForLoopsToSerial() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertForLoopsToSerial(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") + .set_body_typed(ConvertForLoopsToSerial); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py new file mode 100644 index 000000000000..272e0d45410f --- /dev/null +++ b/tests/python/unittest/test_tir_transform_convert_for_loops_serial.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import ty +from tvm.tir import stmt_functor + +# fmt: off +@tvm.script.tir +def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) + placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") + for i0_i1_fused_3 in tir.parallel(0, 28): + for i2_3, i3_3 in tir.grid(28, 192): + tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) + for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): + for ax3_2 in tir.serial(0, 16): + Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") + tir.store(Conv2dOutput_3, 0, 0, True) + for rc_3 in tir.serial(0, 192): + tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) + tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): + primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 + mod = tvm.IRModule.from_expr(primfunc) + mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) + + def verify_serial_loops(stmt): + if isinstance(stmt, tvm.tir.For): + assert stmt.kind == tvm.tir.ForKind.SERIAL + + for _, primfunc in mod.functions.items(): + stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) + + +if __name__ == "__main__": + pytest.main([__file__])