-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR][USMP] Add a parallel to serial for loop converter pass #8469
Merged
areusch
merged 5 commits into
apache:main
from
manupak:usmp_parallel_for_to_serial_for_converter
Oct 12, 2021
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
bfcac8c
[TIR][USMP] Add a parallel to serial for loop converter pass
manupak 2b9f5d7
[TIR][USMP] Add a parallel to serial for loop converter pass
manupak 3dac799
[TIR][USMP] Add a parallel to serial for loop converter pass
manupak a424b9a
[TIR][USMP] Add a parallel to serial for loop converter pass
manupak 26571f8
[TIR][USMP] Add a parallel to serial for loop converter pass
manupak File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tir/transforms/convert_for_loops_serial.cc | ||
* \brief Convert all for loops to serial for lesser memory consumption | ||
*/ | ||
#include <tvm/arith/analyzer.h> | ||
#include <tvm/runtime/device_api.h> | ||
#include <tvm/tir/function.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
class ForLoopSerialConverter : public StmtExprMutator { | ||
public: | ||
ForLoopSerialConverter() = default; | ||
Stmt operator()(const PrimFunc& func); | ||
|
||
private: | ||
Stmt VisitStmt_(const ForNode* op) override; | ||
}; | ||
|
||
Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) { | ||
if (op->kind == ForKind::kParallel) { | ||
return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding, | ||
op->annotations, op->span); | ||
} | ||
return StmtExprMutator::VisitStmt_(op); | ||
} | ||
|
||
Stmt ForLoopSerialConverter::operator()(const PrimFunc& func) { | ||
return this->VisitStmt(func->body); | ||
} | ||
|
||
PrimFunc ConvertForLoopsToSerial(PrimFunc func) { | ||
PrimFuncNode* fptr = func.CopyOnWrite(); | ||
fptr->body = ForLoopSerialConverter()(func); | ||
return func; | ||
} | ||
|
||
namespace transform { | ||
|
||
Pass ConvertForLoopsToSerial() { | ||
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
return ConvertForLoopsToSerial(std::move(f)); | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.ConvertForLoopsToSerial", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.ConvertForLoopsToSerial") | ||
.set_body_typed(ConvertForLoopsToSerial); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
} // namespace tvm |
62 changes: 62 additions & 0 deletions
62
tests/python/unittest/test_tir_transform_convert_for_loops_serial.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
import pytest | ||
|
||
import tvm | ||
from tvm import tir, script | ||
from tvm.script import ty | ||
from tvm.tir import stmt_functor | ||
|
||
# fmt: off | ||
@tvm.script.tir | ||
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None: | ||
# function attr dict | ||
tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True}) | ||
placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1) | ||
placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) | ||
placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1) | ||
T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1) | ||
# body | ||
PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global") | ||
for i0_i1_fused_3 in tir.parallel(0, 28): | ||
for i2_3, i3_3 in tir.grid(28, 192): | ||
tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True) | ||
for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784): | ||
for ax3_2 in tir.serial(0, 16): | ||
Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global") | ||
tir.store(Conv2dOutput_3, 0, 0, True) | ||
for rc_3 in tir.serial(0, 192): | ||
tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True) | ||
tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) | ||
# fmt: on | ||
|
||
|
||
def test_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(): | ||
primfunc = fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2 | ||
mod = tvm.IRModule.from_expr(primfunc) | ||
mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod) | ||
|
||
def verify_serial_loops(stmt): | ||
if isinstance(stmt, tvm.tir.For): | ||
assert stmt.kind == tvm.tir.ForKind.SERIAL | ||
|
||
for _, primfunc in mod.functions.items(): | ||
stmt_functor.post_order_visit(primfunc.body, verify_serial_loops) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
want to assert that you find at least one kParallel for loop in here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to ? I mean its written in the test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's consider this blocked on testing infrastructure. a common pattern in tests is for the data to be re-used in multiple tests and then lose the "why" behind the test. that's where my request is coming from.