-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR][USMP] Add a parallel to serial for loop converter pass (#8469)
* [TIR][USMP] Add a parallel to serial for loop converter pass This is an optional pass that converts all parallel for loops in TIR to serial ones, for reasons such as the executor not supporting parallel launch of for loops (e.g., AoT) or allocating space for parallel for loops not being desired. * Additionally adding FFI scaffolding for USMP Change-Id: Id5e8ccb90140d2d3ae113b20a3ca152a54497c45 * [TIR][USMP] Add a parallel to serial for loop converter pass * remove unused import Change-Id: I29d5fdec92120418596f9dba1d6630f65620a603 * [TIR][USMP] Add a parallel to serial for loop converter pass * moved the pass to tir namespace Change-Id: I74720ca2f566066b3a4f22f504d8f0f684c99dc2 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed docstring Change-Id: I73bb9867fe2ed6a86f65666493c5c6e3edf87b49 * [TIR][USMP] Add a parallel to serial for loop converter pass * fixed mypy lint error Change-Id: I226ef27d5536674fbe4b2d2c6ff47b8cb3b41431
- Loading branch information
Showing
4 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tir/transforms/convert_for_loops_serial.cc | ||
* \brief Convert all parallel for loops to serial ones to reduce memory consumption | ||
*/ | ||
#include <tvm/arith/analyzer.h> | ||
#include <tvm/runtime/device_api.h> | ||
#include <tvm/tir/function.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
class ForLoopSerialConverter : public StmtExprMutator { | ||
public: | ||
ForLoopSerialConverter() = default; | ||
Stmt operator()(const PrimFunc& func); | ||
|
||
private: | ||
Stmt VisitStmt_(const ForNode* op) override; | ||
}; | ||
|
||
/*!
 * \brief Rewrite a for loop, turning a parallel loop into a serial one.
 *
 * The loop is handed to the base mutator first so that the statements it
 * contains are visited too. Building the serial loop directly from the
 * unvisited body would leave any parallel loop nested inside a converted
 * parallel loop unchanged.
 *
 * \param op The for loop being visited.
 * \return A loop of kind kSerial when \p op is kParallel; otherwise the
 *         result of the base mutator.
 */
Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) {
  // Recurse before deciding what to return, so nested loops are processed.
  Stmt visited = StmtExprMutator::VisitStmt_(op);
  if (op->kind != ForKind::kParallel) {
    return visited;
  }
  // Rebuild from the visited node (not from op) to keep the rewritten body.
  For loop = Downcast<For>(visited);
  return For(loop->loop_var, loop->min, loop->extent, ForKind::kSerial, loop->body,
             loop->thread_binding, loop->annotations, loop->span);
}
|
||
/*!
 * \brief Apply the converter to a PrimFunc.
 * \param func The function to process; it is only read here.
 * \return The statement obtained by visiting func->body.
 */
Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) {
  const Stmt& original_body = func->body;
  return VisitStmt(original_body);
}
|
||
/*!
 * \brief Replace the body of a PrimFunc with the output of ForLoopSerialConverter.
 * \param func The function to rewrite (taken by value).
 * \return The function with its body replaced.
 */
PrimFunc ConvertForLoopsToSerial(PrimFunc func) {
  Stmt serial_body = ForLoopSerialConverter()(func);
  // Obtain a writable node through CopyOnWrite() before storing the new body.
  func.CopyOnWrite()->body = std::move(serial_body);
  return func;
}
|
||
namespace transform { | ||
|
||
Pass ConvertForLoopsToSerial() { | ||
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
return ConvertForLoopsToSerial(std::move(f)); | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") | ||
.set_body_typed(ConvertForLoopsToSerial); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
} // namespace tvm |
62 changes: 62 additions & 0 deletions
62
tests/python/unittest/test_tir_transform_convert_for_loops_serial.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
import pytest | ||
|
||
import tvm | ||
from tvm import tir, script | ||
from tvm.script import ty | ||
from tvm.tir import stmt_functor | ||
|
||
# fmt: off | ||
# Test fixture: a TVMScript function whose body contains two top-level
# tir.parallel loops, each wrapping further tir.grid / tir.serial loops.
@tvm.script.tir
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None:
    # function attr dict
    tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global")
    # First parallel loop: copies placeholder_33 into PaddedInput_3.
    for i0_i1_fused_3 in tir.parallel(0, 28):
        for i2_3, i3_3 in tir.grid(28, 192):
            tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True)
    # Second parallel loop: accumulates into Conv2dOutput_3 over rc_3, then writes T_cast_9.
    for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784):
        for ax3_2 in tir.serial(0, 16):
            Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global")
            tir.store(Conv2dOutput_3, 0, 0, True)
            for rc_3 in tir.serial(0, 192):
                tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True)
            tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
# fmt: on | ||
|
||
|
||
def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2():
    """Run ConvertForLoopsToSerial on the fixture and check every For node is serial."""
    module = tvm.IRModule.from_expr(fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2)
    serial_pass = tvm.tir.transform.ConvertForLoopsToSerial()
    module = serial_pass(module)

    def verify_serial_loops(node):
        # Only For statements are checked; every other node is ignored.
        if not isinstance(node, tvm.tir.For):
            return
        assert node.kind == tvm.tir.ForKind.SERIAL

    for _, converted_func in module.functions.items():
        stmt_functor.post_order_visit(converted_func.body, verify_serial_loops)
|
||
|
||
# Allow running this test file directly as a script.
if __name__ == "__main__":
    pytest.main(args=[__file__])