Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VectorCombine] Fold vector.interleave2 with two constant splats #125144

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class VectorCombine {
bool foldShuffleFromReductions(Instruction &I);
bool foldCastFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
bool foldInterleaveIntrinsics(Instruction &I);
bool shrinkType(Instruction &I);

void replaceValue(Value &Old, Value &New) {
Expand Down Expand Up @@ -3145,6 +3146,47 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
return true;
}

/// Fold a `vector.interleave2` of two constant integer splats into a bitcast
/// of a single splat whose elements are twice as wide. For instance,
/// interleaving `<vscale x 8 x i32> <splat of 666>` with
/// `<vscale x 8 x i32> <splat of 777>` becomes the larger splat
/// `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>`, which is then cast
/// back into `<vscale x 16 x i32>`.
///
/// \returns true if \p I was replaced; false if \p I does not match the
/// pattern or the bitcast is not strictly cheaper than the interleave.
bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
// Only proceed when I is an interleave2 call whose two operands both match
// m_APInt; the matched constants are captured in SplatVal0 / SplatVal1.
const APInt *SplatVal0, *SplatVal1;
if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
m_APInt(SplatVal0), m_APInt(SplatVal1))))
return false;

LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
<< "\n");

// VTy is the type of the first interleave2 operand, ExtVTy the corresponding
// vector type with extended elements, and Width the bit width of one
// original element.
auto *VTy =
cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
unsigned Width = VTy->getElementType()->getIntegerBitWidth();

// Just in case the costs of the interleave2 intrinsic and the bitcast are
// both invalid, in which case we want to bail out, we use <= rather
// than < here. Even if they both have valid and equal costs, it's probably
// not a good idea to emit a high-cost constant splat.
if (TTI.getInstructionCost(&I, CostKind) <=
TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're really just worrying about the legalization cost here, should ExtVTy be an illegal type.

TTI::CastContextHint::None, CostKind)) {
LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
<< *I.getType() << " is too high.\n");
return false;
}

// Pack both splat values into one element of 2 * Width bits: SplatVal1 goes
// into the high half and SplatVal0 into the low half.
// NOTE(review): putting SplatVal0 in the low half presumably relies on the
// bitcast mapping the low bits to the even result lanes (little-endian
// layout) -- confirm the behavior for big-endian targets.
APInt NewSplatVal = SplatVal1->zext(Width * 2);
NewSplatVal <<= Width;
NewSplatVal |= SplatVal0->zext(Width * 2);
auto *NewSplat = ConstantVector::getSplat(
ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));

// Cast the wide splat back to I's original type and hand it to replaceValue
// as the replacement for I.
IRBuilder<> Builder(&I);
replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
Expand Down Expand Up @@ -3189,6 +3231,7 @@ bool VectorCombine::run() {
MadeChange |= scalarizeBinopOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
}

if (Opcode == Instruction::Store)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s

; We should not form an i128 vector: the operands below are i64 splats, and the
; assertions expect the interleave2 call to be left untouched.

define void @interleave2_const_splat_nxv8i64(ptr %dst) {
Copy link
Collaborator

@topperc topperc Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have some bad interaction with zve32x that required a separate test file?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have some bad interaction with zve32x that required a separate test file?

Correct, from what I'd tried zve32x doesn't like SEW=64 in general.

; CHECK-LABEL: define void @interleave2_const_splat_nxv8i64(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[INTERLEAVE2:%.*]] = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
; CHECK-NEXT: call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> [[INTERLEAVE2]], ptr [[DST]], <vscale x 8 x i1> splat (i1 true), i32 88)
; CHECK-NEXT: ret void
;
%interleave2 = call <vscale x 8 x i64> @llvm.vector.interleave2.nxv8i64(<vscale x 4 x i64> splat (i64 666), <vscale x 4 x i64> splat (i64 777))
call void @llvm.vp.store.nxv8i64.p0(<vscale x 8 x i64> %interleave2, ptr %dst, <vscale x 8 x i1> splat (i1 true), i32 88)
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -mtriple=riscv64 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv32 -mattr=+v %s -passes=vector-combine | FileCheck %s
; RUN: opt -S -mtriple=riscv64 -mattr=+zve32x %s -passes=vector-combine | FileCheck %s --check-prefix=ZVE32X

; Interleaves splat (i32 666) with splat (i32 777). Under the +v RUN lines the
; interleave2 call is expected to be replaced by a bitcast of an i64 splat of
; 3337189589658, i.e. (777 << 32) | 666. Under the +zve32x RUN line the call is
; expected to remain as written.
define void @interleave2_const_splat_nxv16i32(ptr %dst) {
; CHECK-LABEL: define void @interleave2_const_splat_nxv16i32(
; CHECK-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> bitcast (<vscale x 8 x i64> splat (i64 3337189589658) to <vscale x 16 x i32>), ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; CHECK-NEXT: ret void
;
; ZVE32X-LABEL: define void @interleave2_const_splat_nxv16i32(
; ZVE32X-SAME: ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
; ZVE32X-NEXT: [[INTERLEAVE2:%.*]] = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
; ZVE32X-NEXT: call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> [[INTERLEAVE2]], ptr [[DST]], <vscale x 16 x i1> splat (i1 true), i32 88)
; ZVE32X-NEXT: ret void
;
%interleave2 = call <vscale x 16 x i32> @llvm.vector.interleave2.nxv16i32(<vscale x 8 x i32> splat (i32 666), <vscale x 8 x i32> splat (i32 777))
call void @llvm.vp.store.nxv16i32.p0(<vscale x 16 x i32> %interleave2, ptr %dst, <vscale x 16 x i1> splat (i1 true), i32 88)
ret void
}