[Relay][Convert Layout] Enable layout transformation for image.resize op (apache#8205)

* Enable layout transformation for image.resize op

* Change str map function to str and index retrieval

* Fix for pytorch frontend segmentation models test
jtuyls authored and ylc committed Jan 13, 2022
1 parent 61fab5d commit cf9777c
Showing 3 changed files with 143 additions and 0 deletions.
31 changes: 31 additions & 0 deletions python/tvm/relay/op/image/_image.py
@@ -26,6 +26,7 @@
from .. import op as reg
from .. import strategy
from ..op import OpPattern
from .image import resize


# resize
@@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type):
reg.register_injective_schedule("image.resize")


@reg.register_convert_op_layout("image.resize")
def convert_image_resize(attrs, inputs, tinfos, desired_layouts):
"""Convert Layout pass registration for image resize op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current resize op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layouts : list of layout strings
List of layouts defining our desired
layout for the data input.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""

new_attrs = dict(attrs)
assert len(desired_layouts) == 1, "Only one desired layout is expected"
desired_layout = str(desired_layouts[0])
assert desired_layout != "default", "Layout cannot be default"
new_attrs["layout"] = desired_layout
return resize(*inputs, **new_attrs)


@script
def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
    out = output_tensor((4,), "int64")
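
The registration above is what the ConvertLayout pass looks up when "image.resize" appears in its desired-layouts map. A minimal usage sketch (illustrative shapes; assumes the relay API at this commit, where the 2-D resize op is still named image.resize):

import tvm
from tvm import relay

# Build a small graph whose resize runs in the default NCHW layout.
x = relay.var("x", shape=(1, 2, 4, 4))
y = relay.image.resize(x, (8, 8))  # layout defaults to "NCHW"
mod = tvm.IRModule.from_expr(relay.Function([x], y))

# Ask ConvertLayout to run image.resize in NHWC; the pass rewrites the op and
# inserts layout_transform ops around it, as the tests below verify.
seq = tvm.transform.Sequential([relay.transform.ConvertLayout({"image.resize": ["NHWC"]})])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)
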
26 changes: 26 additions & 0 deletions src/relay/op/image/resize.cc
@@ -33,6 +33,31 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(ResizeAttrs);

template <typename T>
Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());

if (new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout new_layout = new_in_layouts[0];
Layout old_layout = old_in_layouts[0];
if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) &&
new_layout->axes.size() == old_layout->axes.size()) {
// Follow input layout
params->layout = new_layout.name();
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}

bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 2);
@@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
.add_type_rel("Resize", ResizeRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ResizeInferCorrectLayout<ResizeAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(Resize3dAttrs);
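
Because image.resize now advertises FInferCorrectLayout, it also follows the layout chosen for a producer op even when it is not listed in desired_layouts itself. A sketch of that propagation (same shapes as the conv-plus-resize test below; illustrative only):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 56, 56, 64))
w = relay.var("weight", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
y = relay.image.resize(y, (112, 112), layout="NHWC")
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

# Only nn.conv2d is converted explicitly; the resize picks up NCHW from its
# input through the new ResizeInferCorrectLayout hook instead of keeping NHWC.
seq = tvm.transform.Sequential([relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)
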
86 changes: 86 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -1797,6 +1797,90 @@ def expected():
_test_conv_reduce_convert_layout2()


def test_image_resize_convert_layout():
    def _test_image_resize_convert_layout_nchw_to_nhwc():
        def before():
            x = relay.var("x", shape=(1, 2, 4, 4))
            y = relay.image.resize(x, (8, 8))
            y = relay.Function([x], y)
            return y

        def expected():
            x = relay.var("x", shape=(1, 2, 4, 4))
            x = relay.layout_transform(x, "NCHW", "NHWC")
            y = relay.image.resize(x, (8, 8), layout="NHWC")
            y = relay.layout_transform(y, "NHWC", "NCHW")
            y = relay.Function(relay.analysis.free_vars(y), y)
            return y

        a = before()
        a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]}))
        b = run_opt_pass(expected(), transform.InferType())

        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

    def _test_image_resize_convert_layout_nhwc_to_nchw():
        def before():
            x = relay.var("x", shape=(1, 4, 4, 2))
            y = relay.image.resize(x, (8, 8), layout="NHWC")
            y = relay.Function([x], y)
            return y

        def expected():
            x = relay.var("x", shape=(1, 4, 4, 2))
            x = relay.layout_transform(x, "NHWC", "NCHW")
            y = relay.image.resize(x, (8, 8), layout="NCHW")
            y = relay.layout_transform(y, "NCHW", "NHWC")
            y = relay.Function(relay.analysis.free_vars(y), y)
            return y

        a = before()
        a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]}))
        b = run_opt_pass(expected(), transform.InferType())

        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

    _test_image_resize_convert_layout_nchw_to_nhwc()
    _test_image_resize_convert_layout_nhwc_to_nchw()


def test_conv_image_resize_convert_layout():
"""Check that layout transforms are propagated through image resize."""

def before():
x = relay.var("x", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
y = relay.nn.conv2d(
x,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.image.resize(y, (112, 112), layout="NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y

def expected():
x = relay.var("x", shape=(1, 56, 56, 64))
w = relay.var("weight", shape=(3, 3, 64, 64))
x = relay.layout_transform(x, "NHWC", "NCHW")
w = relay.layout_transform(w, "HWIO", "OIHW")
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = relay.image.resize(y, (112, 112), layout="NCHW")
y = relay.layout_transform(y, "NCHW", "NHWC")
y = relay.Function(analysis.free_vars(y), y)
return y

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
    test_qnn_binary_no_convert_layout()
    test_no_convert_layout()
@@ -1828,3 +1912,5 @@ def expected():
    test_conv_squeeze_convert_layout()
    test_conv_reduce_convert_layout()
    test_conv_strided_slice_axes_convert_layout()
    test_image_resize_convert_layout()
    test_conv_image_resize_convert_layout()
