Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grouped conv2d: Use MLIR Op which matches memory layout of weight dimensions #2623

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
indices);
};

// expand F,C,H,W -> G,F/G,C,H,W
auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = makeShapeTorchCompatible(inType.getShape());
Expand All @@ -868,21 +869,19 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1);
auto expandOutputTensor = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>(
loc, outputTensorExpanded.getType(),
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr)
expandOutputTensor.getResult(), stridesAttr, dilationAttr)
.getResult(0);

SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices);
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
}

Type newResultType = getTypeConverter()->convertType(op.getType());
Expand Down
15 changes: 0 additions & 15 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@
"IscloseStaticModuleTrue_basic"
}

if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
LINALG_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}


TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors

Expand Down Expand Up @@ -316,13 +308,6 @@
"ArangeStartOutViewModule_basic",
}

if torch_version_for_comparison() >= version.parse("2.2.0.dev20231204"):
TORCHDYNAMO_XFAIL_SET |= {
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"ConvolutionModule2DGroups_basic",
}

TORCHDYNAMO_CRASHING_SET = {
# No upstream decompositions.
# %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
Expand Down