Add a decomposition for torch.aten.argmin (#2613)
Adds a lowering for the torch.aten.argmin operator to linalg via decomposition into torch.aten.min.dim.

---------

Co-authored-by: Franz Haniel <[email protected]>
frafranz and FranzHaniel authored Dec 6, 2023
1 parent 6244f30 commit c011570
Showing 7 changed files with 197 additions and 79 deletions.
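
For orientation before the diffs: the decomposition mirrors an equivalence that holds at the PyTorch level, namely that argmin is the indices output of min along a dimension, with a flatten step when no dim is given. Below is a minimal illustrative sketch of that equivalence; it is not part of the commit itself.

import torch

# Illustration only (not part of this commit): argmin can be recovered from
# the indices returned by min along a dimension, which is the equivalence the
# decomposition of torch.aten.argmin into torch.aten.min.dim relies on.
x = torch.randn(3, 4)

# With an explicit dim, argmin is the `indices` output of min.dim.
assert torch.equal(torch.argmin(x, dim=1), torch.min(x, dim=1).indices)

# With dim=None, the input is first flattened to a 1-d tensor and reduced
# along dimension 0, mirroring how the pattern below handles a NoneType dim.
assert torch.equal(torch.argmin(x), torch.min(x.flatten(), dim=0).indices)
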
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6846,6 +6846,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n"
@@ -10659,6 +10663,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.argmin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %int0 = torch.constant.int 0\n"
27 changes: 15 additions & 12 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -840,12 +840,13 @@ class DecomposeAten_LogSoftmaxBackwardDataOp
};
} // namespace

// Decompose `AtenArgMaxOp` into `AtenMaxDimOp`.
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` and `AtenArgMinOp` into `AtenMinDimOp`.
namespace {
class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
template <typename OpTy, typename DecompOpTy>
class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenArgmaxOp op,
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
@@ -870,7 +871,7 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
.cast<BaseTensorType>();

// If the dim type is `NoneType` i.e. reduce along all the dimensions.
// `AtenMaxDimOp` doesn't support dim as `NoneType` so first the input
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input
// tensor is flattened to 1d tensor and then the reduction happens on the
// 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
@@ -885,13 +886,14 @@ class DecomposeAtenArgMaxOp : public OpRewritePattern<AtenArgmaxOp> {
input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
dim, end);
}
Value maxResult =
rewriter
.create<AtenMaxDimOp>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.getIndices();

rewriter.replaceOp(op, maxResult);
Value resultArg =
rewriter
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.getIndices();

rewriter.replaceOp(op, resultArg);
return success();
}
};
@@ -5774,7 +5776,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -414,6 +414,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();
@@ -445,6 +445,10 @@ def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] =
def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.argmax(self, dim, keepdim)

def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]:
# There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here.
return upstream_shape_functions.argmax(self, dim, keepdim)

# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor,
# making it impossible to add support for it using the current design of the shape library.
def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
@@ -3254,7 +3258,10 @@ def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Li

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
return torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇argmin〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int:
return torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
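
A brief note on the Python-level functions above: reusing the argmax shape function for argmin, and returning the integer 4 (the code for torch.int64) from the dtype function, is sound because argmin and argmax always agree on output shape and dtype for the same arguments. A small sanity check of that assumption, for illustration only and not part of the commit:

import torch

# Illustration only: argmin and argmax share output shape and dtype for the
# same (dim, keepdim) arguments, which is why the argmax shape function and
# the int64 dtype function can be reused for argmin above.
x = torch.rand(2, 3, 5)
for dim, keepdim in [(None, False), (1, False), (2, True)]:
    lo = torch.argmin(x, dim=dim, keepdim=keepdim)
    hi = torch.argmax(x, dim=dim, keepdim=keepdim)
    assert lo.shape == hi.shape
    assert lo.dtype == hi.dtype == torch.int64
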
@@ -40,7 +40,6 @@ def register_all_tests():
from . import type_conversion
from . import backprop
from . import reduction
from . import argmax
from . import matmul
from . import reshape_like
from . import scalar
65 changes: 0 additions & 65 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/argmax.py

This file was deleted.

165 changes: 165 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -755,6 +755,171 @@ def ReduceMinUnsignedIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=100))

# ==============================================================================

class ArgminModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])

def forward(self, a):
return torch.ops.aten.argmin(a)


@register_test_case(module_factory=lambda: ArgminModule())
def ArgminModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

# ==============================================================================

class ArgminIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])

def forward(self, a):
return torch.ops.aten.argmin(a)


@register_test_case(module_factory=lambda: ArgminIntModule())
def ArgminIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))

@register_test_case(module_factory=lambda: ArgminIntModule())
def ArgminIntModule_multiple_mins(module, tu: TestUtils):
# To cover the special case where the minimal value occurs more than once.
# The PyTorch convention is to consider the first occurrence as the argmin.
module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))

# ==============================================================================

class ArgminWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a, dim=1)

@register_test_case(module_factory=lambda: ArgminWithDimModule())
def ArgminModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ArgminKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmin(a, 0, True)

@register_test_case(module_factory=lambda: ArgminKeepDimsModule())
def ArgminModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

# ==============================================================================

class ArgmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])

def forward(self, a):
return torch.ops.aten.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxModule())
def ArgmaxModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

# ==============================================================================

class ArgmaxIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])

def forward(self, a):
return torch.ops.aten.argmax(a)


@register_test_case(module_factory=lambda: ArgmaxIntModule())
def ArgmaxIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100))

@register_test_case(module_factory=lambda: ArgmaxIntModule())
def ArgmaxIntModule_multiple_maxs(module, tu: TestUtils):
# To cover the special case where the maximal value occurs more than once.
# The PyTorch convention is to consider the first occurrence as the argmax.
module.forward(torch.full((3,4), tu.randint(1).item(), dtype=torch.int64))

# ==============================================================================

class ArgmaxWithDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a, dim=1)

@register_test_case(module_factory=lambda: ArgmaxWithDimModule())
def ArgmaxModule_with_dim(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ArgmaxKeepDimsModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.argmax(a, 0, True)

@register_test_case(module_factory=lambda: ArgmaxKeepDimsModule())
def ArgmaxModule_keepDim(module, tu: TestUtils):
module.forward(tu.rand(4, 6))

# ==============================================================================

class ReduceL1NormModule(torch.nn.Module):
def __init__(self):
super().__init__()
