Skip to content

Commit

Permalink
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
Browse files Browse the repository at this point in the history

The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where

1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W

Now this has been fixed in
llvm/llvm-project#73855 which broke the
torch-mlir lowering to that Op.

This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.

Fix #2622
  • Loading branch information
ubfx authored Dec 8, 2023
1 parent 8252656 commit fb21a85
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 23 deletions.
15 changes: 7 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
indices);
};

// expand F,C,H,W -> G,F/G,C,H,W
auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = makeShapeTorchCompatible(inType.getShape());
Expand All @@ -868,21 +869,19 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1);
auto expandOutputTensor = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>(
loc, outputTensorExpanded.getType(),
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr)
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);

SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices);
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
}

Type newResultType = getTypeConverter()->convertType(op.getType());
Expand Down
15 changes: 0 additions & 15 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@
"IscloseStaticModuleTrue_basic"
}

if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
LINALG_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}


TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors

Expand Down Expand Up @@ -316,13 +308,6 @@
"ArangeStartOutViewModule_basic",
}

if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
TORCHDYNAMO_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}

TORCHDYNAMO_CRASHING_SET = {
# No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
Expand Down

0 comments on commit fb21a85

Please sign in to comment.