Skip to content

Commit

Permalink
try something else..
Browse files Browse the repository at this point in the history
  • Loading branch information
ubfx committed Dec 8, 2023
1 parent a872c60 commit 123f638
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,15 +848,14 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
indices);
};

// expand F,C,H,W -> F/G,G,C,H,W
// expand F,C,H,W -> G,F/G,C,H,W
auto expandWeight = [&](Value tensor) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = makeShapeTorchCompatible(inType.getShape());

SmallVector<int64_t> outShape{(inShape[0] == kUnknownSize
? kUnknownSize
: inShape[0] / groupSize),
groupSize};
SmallVector<int64_t> outShape{
groupSize, (inShape[0] == kUnknownSize ? kUnknownSize
: inShape[0] / groupSize)};
outShape.append(inShape.begin() + 1, inShape.end());

SmallVector<ReassociationIndices> indices{{0, 1}};
Expand All @@ -870,21 +869,19 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

Value paddedInputExpanded = expandGroups(paddedInput, 1);
Value weightExpanded = expandWeight(weight);
Value outputTensorExpanded = expandGroups(outputTensor, 1);
auto expandOutputTensor = expandGroups(outputTensor, 1);

// TODO: add 1D and 3D case
conv = rewriter
.create<linalg::Conv2DNgchwFgchwOp>(
loc, outputTensorExpanded.getType(),
.create<linalg::Conv2DNgchwGfchwOp>(
loc, expandOutputTensor.getResultType(),
ValueRange{paddedInputExpanded, weightExpanded},
outputTensorExpanded, stridesAttr, dilationAttr)
expandOutputTensor, stridesAttr, dilationAttr)
.getResult(0);

SmallVector<ReassociationIndices> indices{{0}, {1, 2}};
for (auto dim = 3; dim <= (int64_t)inRank; dim++)
indices.push_back({dim});
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv, indices);
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
}

Type newResultType = getTypeConverter()->convertType(op.getType());
Expand Down

0 comments on commit 123f638

Please sign in to comment.