Skip to content

Commit c5f99e1

Browse files
authored
[SandboxVec][Legality] Fix legality of SelectInst (#125005)
SelectInsts need special treatment because they are not always straightforward to vectorize. This patch disables vectorization unless they are trivially vectorizable.
1 parent f10979f commit c5f99e1

File tree

3 files changed

+120
-2
lines changed

3 files changed

+120
-2
lines changed

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,15 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
116116
return std::nullopt;
117117
return ResultReason::DiffOpcodes;
118118
}
119-
case Instruction::Opcode::Select:
119+
case Instruction::Opcode::Select: {
120+
auto *Sel0 = cast<SelectInst>(Bndl[0]);
121+
auto *Cond0 = Sel0->getCondition();
122+
if (VecUtils::getNumLanes(Cond0) != VecUtils::getNumLanes(Sel0))
123+
// TODO: For now we don't vectorize if the lanes in the condition don't
124+
// match those of the select instruction.
125+
return ResultReason::Unimplemented;
126+
return std::nullopt;
127+
}
120128
case Instruction::Opcode::FNeg:
121129
case Instruction::Opcode::Add:
122130
case Instruction::Opcode::FAdd:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
3+
4+
; This file includes tests for opcodes that need special checks.
5+
6+
; TODO: Selects with conditions of diff number of lanes than the instruction itself need special treatment.
7+
define void @selects_with_diff_cond_lanes(ptr %ptr, i1 %cond0, i1 %cond1, <2 x i8> %op0, <2 x i8> %op1) {
8+
; CHECK-LABEL: define void @selects_with_diff_cond_lanes(
9+
; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND0:%.*]], i1 [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
10+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
11+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
12+
; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
13+
; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
14+
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND0]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
15+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
16+
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
17+
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
18+
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
19+
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
20+
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
21+
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
22+
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
23+
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
24+
; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
25+
; CHECK-NEXT: ret void
26+
;
27+
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
28+
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
29+
%ld0 = load <2 x i8>, ptr %ptr0
30+
%ld1 = load <2 x i8>, ptr %ptr1
31+
%sel0 = select i1 %cond0, <2 x i8> %ld0, <2 x i8> %ld0
32+
%sel1 = select i1 %cond1, <2 x i8> %ld1, <2 x i8> %ld1
33+
store <2 x i8> %sel0, ptr %ptr0
34+
store <2 x i8> %sel1, ptr %ptr1
35+
ret void
36+
}
37+
38+
; TODO: Selects that share the same condition need special treatment.
39+
define void @selects_with_common_condition_but_diff_lanes(ptr %ptr, i1 %cond, <2 x i8> %op0, <2 x i8> %op1) {
40+
; CHECK-LABEL: define void @selects_with_common_condition_but_diff_lanes(
41+
; CHECK-SAME: ptr [[PTR:%.*]], i1 [[COND:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
42+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
43+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 1
44+
; CHECK-NEXT: [[LD0:%.*]] = load <2 x i8>, ptr [[PTR0]], align 2
45+
; CHECK-NEXT: [[LD1:%.*]] = load <2 x i8>, ptr [[PTR1]], align 2
46+
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[COND]], <2 x i8> [[LD0]], <2 x i8> [[LD0]]
47+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND]], <2 x i8> [[LD1]], <2 x i8> [[LD1]]
48+
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i8> [[SEL0]], i32 0
49+
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i8> poison, i8 [[VPACK]], i32 0
50+
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i8> [[SEL0]], i32 1
51+
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i8> [[VPACK1]], i8 [[VPACK2]], i32 1
52+
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i8> [[SEL1]], i32 0
53+
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i8> [[VPACK3]], i8 [[VPACK4]], i32 2
54+
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i8> [[SEL1]], i32 1
55+
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i8> [[VPACK5]], i8 [[VPACK6]], i32 3
56+
; CHECK-NEXT: store <4 x i8> [[VPACK7]], ptr [[PTR0]], align 2
57+
; CHECK-NEXT: ret void
58+
;
59+
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
60+
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
61+
%ld0 = load <2 x i8>, ptr %ptr0
62+
%ld1 = load <2 x i8>, ptr %ptr1
63+
%sel0 = select i1 %cond, <2 x i8> %ld0, <2 x i8> %ld0
64+
%sel1 = select i1 %cond, <2 x i8> %ld1, <2 x i8> %ld1
65+
store <2 x i8> %sel0, ptr %ptr0
66+
store <2 x i8> %sel1, ptr %ptr1
67+
ret void
68+
}
69+
70+
; Selects with conditions of the same number of lanes as the instruction itself be vectorized as usual.
71+
define void @selects_same_cond_lanes(ptr %ptr, <2 x i1> %cond0, <2 x i1> %cond1, <2 x i8> %op0, <2 x i8> %op1) {
72+
; CHECK-LABEL: define void @selects_same_cond_lanes(
73+
; CHECK-SAME: ptr [[PTR:%.*]], <2 x i1> [[COND0:%.*]], <2 x i1> [[COND1:%.*]], <2 x i8> [[OP0:%.*]], <2 x i8> [[OP1:%.*]]) {
74+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x i8>, ptr [[PTR]], i32 0
75+
; CHECK-NEXT: [[VPACK:%.*]] = extractelement <2 x i1> [[COND0]], i32 0
76+
; CHECK-NEXT: [[VPACK1:%.*]] = insertelement <4 x i1> poison, i1 [[VPACK]], i32 0
77+
; CHECK-NEXT: [[VPACK2:%.*]] = extractelement <2 x i1> [[COND0]], i32 1
78+
; CHECK-NEXT: [[VPACK3:%.*]] = insertelement <4 x i1> [[VPACK1]], i1 [[VPACK2]], i32 1
79+
; CHECK-NEXT: [[VPACK4:%.*]] = extractelement <2 x i1> [[COND1]], i32 0
80+
; CHECK-NEXT: [[VPACK5:%.*]] = insertelement <4 x i1> [[VPACK3]], i1 [[VPACK4]], i32 2
81+
; CHECK-NEXT: [[VPACK6:%.*]] = extractelement <2 x i1> [[COND1]], i32 1
82+
; CHECK-NEXT: [[VPACK7:%.*]] = insertelement <4 x i1> [[VPACK5]], i1 [[VPACK6]], i32 3
83+
; CHECK-NEXT: [[VECL:%.*]] = load <4 x i8>, ptr [[PTR0]], align 2
84+
; CHECK-NEXT: [[VEC:%.*]] = select <4 x i1> [[VPACK7]], <4 x i8> [[VECL]], <4 x i8> [[VECL]]
85+
; CHECK-NEXT: store <4 x i8> [[VEC]], ptr [[PTR0]], align 2
86+
; CHECK-NEXT: ret void
87+
;
88+
%ptr0 = getelementptr <2 x i8>, ptr %ptr, i32 0
89+
%ptr1 = getelementptr <2 x i8>, ptr %ptr, i32 1
90+
%ld0 = load <2 x i8>, ptr %ptr0
91+
%ld1 = load <2 x i8>, ptr %ptr1
92+
%sel0 = select <2 x i1> %cond0, <2 x i8> %ld0, <2 x i8> %ld0
93+
%sel1 = select <2 x i1> %cond1, <2 x i8> %ld1, <2 x i8> %ld1
94+
store <2 x i8> %sel0, ptr %ptr0
95+
store <2 x i8> %sel1, ptr %ptr1
96+
ret void
97+
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

+14-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ static sandboxir::BasicBlock *getBasicBlockByName(sandboxir::Function *F,
6767

6868
TEST_F(LegalityTest, LegalitySkipSchedule) {
6969
parseIR(C, R"IR(
70-
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
70+
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2, i1 %c0, i1 %c1) {
7171
entry:
7272
%gep0 = getelementptr float, ptr %ptr, i32 0
7373
%gep1 = getelementptr float, ptr %ptr, i32 1
@@ -93,6 +93,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
9393
%trunc32to8 = trunc i32 %v2 to i8
9494
%cmpSLT = icmp slt i64 %v0, %v1
9595
%cmpSGT = icmp sgt i64 %v0, %v1
96+
%sel0 = select i1 %c0, <2 x float> %vec2, <2 x float> %vec2
97+
%sel1 = select i1 %c1, <2 x float> %vec2, <2 x float> %vec2
9698
ret void
9799
}
98100
)IR");
@@ -128,6 +130,8 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
128130
auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
129131
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
130132
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
133+
auto *Sel0 = cast<sandboxir::SelectInst>(&*It++);
134+
auto *Sel1 = cast<sandboxir::SelectInst>(&*It++);
131135

132136
llvm::sandboxir::InstrMaps IMaps(Ctx);
133137
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
@@ -241,6 +245,15 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
241245
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
242246
sandboxir::ResultReason::RepeatedInstrs);
243247
}
248+
{
249+
// For now don't vectorize Selects when the number of elements of conditions
250+
// doesn't match the operands.
251+
const auto &Result =
252+
Legality.canVectorize({Sel0, Sel1}, /*SkipScheduling=*/true);
253+
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
254+
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
255+
sandboxir::ResultReason::Unimplemented);
256+
}
244257
}
245258

246259
TEST_F(LegalityTest, LegalitySchedule) {

0 commit comments

Comments
 (0)