Skip to content

Commit

Permalink
Updated impacted unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Oct 20, 2023
1 parent 9701459 commit b3f0a94
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
9 changes: 3 additions & 6 deletions tests/python/relax/test_optimize_layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@ def main(
(lv1, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv

Expand Down Expand Up @@ -256,10 +255,9 @@ def main(
(lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv6
R.output(gv)
return gv

Expand Down Expand Up @@ -399,10 +397,9 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
gv = R.call_tir(
Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relax/test_remove_redundant_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def main(
x: R.Tensor((1, 1001, 1, 1), dtype="float16")
) -> R.Tensor((1, 1001), dtype="float16"):
with R.dataflow():
lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
gv: R.Tensor((1, 1001), dtype="float16") = lv
gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
R.output(gv)
return gv

Expand Down

0 comments on commit b3f0a94

Please sign in to comment.