diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 59920b5a4dd20..fd6b5303a2570 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -125,6 +125,7 @@ class VectorCombine {
   bool foldShuffleFromReductions(Instruction &I);
   bool foldCastFromReductions(Instruction &I);
   bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
+  bool foldInterleaveIntrinsics(Instruction &I);
   bool shrinkType(Instruction &I);
 
   void replaceValue(Value &Old, Value &New) {
@@ -3145,6 +3146,47 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
   return true;
 }
 
+/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
+/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
+/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
+/// before casting it back into `<vscale x 16 x i32>`.
+bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
+  const APInt *SplatVal0, *SplatVal1;
+  if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
+                     m_APInt(SplatVal0), m_APInt(SplatVal1))))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
+                    << "\n");
+
+  auto *VTy =
+      cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
+  auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
+  unsigned Width = VTy->getElementType()->getIntegerBitWidth();
+
+  // We use <= rather than < here: if the costs of the interleave2 intrinsic
+  // and the bitcast are both invalid, we want to bail out. Even if both
+  // costs are valid and equal, it's probably not a good idea to emit a
+  // high-cost constant splat.
+  if (TTI.getInstructionCost(&I, CostKind) <=
+      TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
+                           TTI::CastContextHint::None, CostKind)) {
+    LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
+                      << *I.getType() << " is too high.\n");
+    return false;
+  }
+
+  APInt NewSplatVal = SplatVal1->zext(Width * 2);
+  NewSplatVal <<= Width;
+  NewSplatVal |= SplatVal0->zext(Width * 2);
+  auto *NewSplat = ConstantVector::getSplat(
+      ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
+
+  IRBuilder<> Builder(&I);
+  replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -3189,6 +3231,7 @@ bool VectorCombine::run() {
       MadeChange |= scalarizeBinopOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
+      MadeChange |= foldInterleaveIntrinsics(I);
     }
 
     if (Opcode == Instruction::Store)
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat-e64.ll b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat-e64.ll
new file mode 100644
index 0000000000000..26a5d2e7f849b
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat-e64.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
+; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s
+
+; We should not form an i128 vector.
+
+define void @interleave2_const_splat_nxv8i64(ptr %dst) {
+; CHECK-LABEL: define void @interleave2_const_splat_nxv8i64(
+; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[INTERLEAVE2:%.*]] = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> [[INTERLEAVE2]], ptr [[DST]], <vscale x 8 x i1> splat (i1 true), i32 88)
+; CHECK-NEXT:    ret void
+;
+  %interleave2 = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
+  call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> %interleave2, ptr %dst, <vscale x 8 x i1> splat (i1 true), i32 88)
+  ret void
+}
diff --git a/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
new file mode 100644
index 0000000000000..477a593ec51e9
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/RISCV/vector-interleave2-splat.ll
@@ -0,0 +1,21 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
+; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s
+; RUN: opt -S -mtriple=riscv64 -mattr=+zve32x %s -passes=vector-combine | FileCheck %s --check-prefix=ZVE32X
+
+define void @interleave2_const_splat_nxv16i32(ptr %dst) {
+; CHECK-LABEL: define void @interleave2_const_splat_nxv16i32(
+; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
+; CHECK-NEXT:    ret void
+;
+; ZVE32X-LABEL: define void @interleave2_const_splat_nxv16i32(
+; ZVE32X-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
+; ZVE32X-NEXT:    [[INTERLEAVE2:%.*]] = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
+; ZVE32X-NEXT:    call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> [[INTERLEAVE2]], ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
+; ZVE32X-NEXT:    ret void
+;
+  %interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
+  call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
+  ret void
+}
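For illustration, here is a minimal hand-written before/after sketch of the fold, mirroring the nxv16i32 test above (this is an explanatory example, not autogenerated checker output or part of the patch). The two 32-bit splat values are packed into one 64-bit constant, (777 << 32) | 666 = 3337189589658, so the interleave collapses into a bitcast of a single double-width splat:

; Before the fold:
  %v = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))

; After the fold: one double-width splat, cast back to the result type.
; On a little-endian target such as RISC-V, the low half of each i64 lane
; holds 666 and the high half holds 777, matching the interleaved order.
  %v = bitcast <vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>

The TTI cost check is what keeps this sound on targets where the widened element type is unprofitable or unsupported, as the two negative tests show: with +zve32x there is no i64 vector support, and widening nxv8i64 would require forming an i128 vector.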