diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc..b4a5461f4405d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
 /// Fold an insert or extract operation into a poison value when a poison index
 /// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
-                               ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+                                                ArrayRef<int64_t> staticPos,
+                                                int64_t poisonVal) {
   if (!llvm::is_contained(staticPos, poisonVal))
-    return ub::PoisonAttr();
+    return {};
   return ub::PoisonAttr::get(context);
 }
 
+/// Fold a vector extract from a poison source.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+  if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+    return srcAttr;
+
+  return {};
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc8..6135a1290d559 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
   ImplicitLocOpBuilder b(source.getLoc(), builder);
   SmallVector<Value> vs;
   for (int64_t i = 0; i < m; ++i)
-    vs.push_back(b.create<vector::ExtractOp>(source, i));
+    vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
 
   // Interleave 32-bit lanes using
   //   8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       SmallVector<int64_t> insertIdxs(extractIdxs);
       applyPermutationToVector(insertIdxs, prunedTransp);
       Value extractOp =
-          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
-      result =
-          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+          rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+      result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+                                                       insertIdxs);
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbf..c1e3850f05c5e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
           readOp.getPermutationMapAttr(), readOp.getPadding(),
           readOp.getMask(), readOp.getInBoundsAttr());
 
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
     rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
     Value resultTensor;
     for (SmallVector<int64_t> elementOffsets :
          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
-      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
       SmallVector<int64_t> operandShape = applyPermutationMap(
           permutationMap, ArrayRef<int64_t>(*targetShape));
       SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
-      slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, operand, operandOffets, operandShape, operandStrides);
+      slicesOperands[index] =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, operand, operandOffets, operandShape, operandStrides);
     };
 
     // Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
     for (const auto &it : accCache) {
       SmallVector<int64_t> dstStrides(it.first.size(), 1);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, it.second, result, it.first, dstStrides);
     }
     rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
          StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> operands;
       SmallVector<int64_t> operandStrides(offsets.size(), 1);
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, reductionOp.getSource(), offsets, *targetShape,
+              operandStrides);
       operands.push_back(slicedOperand);
       SmallVector<int64_t> dstShape;
       SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
       if (accIt != accCache.end())
         acc = accIt->second;
       else
-        acc = rewriter.create<vector::ExtractStridedSliceOp>(
+        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
       operands.push_back(acc);
       auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
         rewriter.getZeroAttr(reductionOp.getDestType()));
     for (const auto &it : accCache) {
       SmallVector<int64_t> dstStrides(it.first.size(), 1);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, it.second, result, it.first, dstStrides);
     }
     rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
           continue;
         }
         extractOperands.push_back(
-            rewriter.create<vector::ExtractStridedSliceOp>(
+            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                 loc, operand.get(), offsets, *targetShape, strides));
       }
       Operation *newOp = cloneOpWithOperandsAndTypes(
           rewriter, loc, op, extractOperands, newVecType);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, newOp->getResult(0), result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> strides(offsets.size(), 1);
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionOp.getVector(), offsets, *targetShape, strides);
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, reductionOp.getVector(), offsets, *targetShape, strides);
       Operation *newOp = cloneOpWithOperandsAndTypes(
           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
       Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
         permutedShape[indices.value()] = (*targetShape)[indices.index()];
       }
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
-          strides);
-      Value transposedSlice =
-          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+              strides);
+      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+          loc, slicedOperand, permutation);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, transposedSlice, result, elementOffsets, strides);
     }
     rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
       // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
       // vectors.
-      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
-      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
-      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+      Value passThruSubVec =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+              strides);
       auto slicedGather = rewriter.create<vector::GatherOp>(
           loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
           indexSubVec, maskSubVec, passThruSubVec);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, slicedGather, result, elementOffsets, strides);
     }
     rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a..a74e562ad2f68 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,10 +132,35 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
 
 // -----
 
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+  // CHECK-NOT:  vector.extract
+  // CHECK-NEXT: return %[[UB]] : f32
+  %0 = ub.poison : vector<4x8xf32>
+  %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+  // CHECK-NOT:  vector.extract
+  // CHECK-NEXT: return %[[UB]] : vector<8xf32>
+  %0 = ub.poison : vector<4x8xf32>
+  %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+  return %1 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_scalar_poison_idx
 func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
   // CHECK-NOT:  vector.extract
-  // CHECK-NEXT: ub.poison : f32
+  // CHECK-NEXT: return %[[UB]] : f32
   %0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
   return %0 : f32
 }
@@ -144,8 +169,9 @@ func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
 
 // CHECK-LABEL: @extract_vector_poison_idx
 func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<5xf32>
   // CHECK-NOT:  vector.extract
-  // CHECK-NEXT: ub.poison : vector<5xf32>
+  // CHECK-NEXT: return %[[UB]] : vector<5xf32>
   %0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
   return %0 : vector<5xf32>
 }
 
 // -----
 
 // CHECK-LABEL: @extract_multiple_poison_idx
 func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>) -> vector<8xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
   // CHECK-NOT:  vector.extract
-  // CHECK-NEXT: ub.poison : vector<8xf32>
+  // CHECK-NEXT: return %[[UB]] : vector<8xf32>
   %0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
   return %0 : vector<8xf32>
 }
@@ -2886,13 +2913,47 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   return %1 : vector<4xi8>
 }
 
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+    -> vector<4x8xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+  // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
+  // CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
+  %0 = ub.poison : f32
+  %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+    -> vector<4x8xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
+  // CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
+  %0 = ub.poison : vector<8xf32>
+  %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+  return %1 : vector<4x8xf32>
+}
+
+
 // -----
 
 // CHECK-LABEL: @insert_scalar_poison_idx
 func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
     -> vector<4x5xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
   // CHECK-NOT:  vector.insert
-  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  // CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
   %0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
   return %0 : vector<4x5xf32>
 }
@@ -2902,8 +2963,9 @@ func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
 // CHECK-LABEL: @insert_vector_poison_idx
 func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
     -> vector<4x5xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
   // CHECK-NOT:  vector.insert
-  // CHECK-NEXT: ub.poison : vector<4x5xf32>
+  // CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
   %0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
   return %0 : vector<4x5xf32>
 }
@@ -2913,8 +2975,9 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
 // CHECK-LABEL: @insert_multiple_poison_idx
 func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
     -> vector<4x5x8xf32> {
+  // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5x8xf32>
   // CHECK-NOT:  vector.insert
-  // CHECK-NEXT: ub.poison : vector<4x5x8xf32>
+  // CHECK-NEXT: return %[[UB]] : vector<4x5x8xf32>
   %0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
   return %0 : vector<4x5x8xf32>
 }