Skip to content

Commit 733b7bd

Browse files
spall authored and sivan-shani committed
[HLSL] Implement HLSL Aggregate splatting (llvm#118992)
Implement HLSL Aggregate Splat casting, which handles splatting for arrays and structs, and for vectors when splatting from a vec1. Closes llvm#100609 and llvm#100619. Depends on llvm#118842.
1 parent 8921ea1 commit 733b7bd

18 files changed

+305
-4
lines changed

clang/include/clang/AST/OperationKinds.def

+3
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
370370
// Aggregate by Value cast (HLSL only).
371371
CAST_OPERATION(HLSLElementwiseCast)
372372

373+
// Splat cast for Aggregates (HLSL only).
374+
CAST_OPERATION(HLSLAggregateSplatCast)
375+
373376
//===- Binary Operations -------------------------------------------------===//
374377
// Operators listed in order of precedence.
375378
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
144144
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
145145
bool ContainsBitField(QualType BaseTy);
146146
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
147+
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
147148
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
148149

149150
QualType getInoutParameterType(QualType Ty);

clang/lib/AST/Expr.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1968,6 +1968,7 @@ bool CastExpr::CastConsistency() const {
19681968
case CK_HLSLArrayRValue:
19691969
case CK_HLSLVectorTruncation:
19701970
case CK_HLSLElementwiseCast:
1971+
case CK_HLSLAggregateSplatCast:
19711972
CheckNoBasePath:
19721973
assert(path_empty() && "Cast kind should not have a base path!");
19731974
break;

clang/lib/AST/ExprConstant.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -15025,6 +15025,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
1502515025
case CK_FixedPointCast:
1502615026
case CK_IntegralToFixedPoint:
1502715027
case CK_MatrixCast:
15028+
case CK_HLSLAggregateSplatCast:
1502815029
llvm_unreachable("invalid cast kind for integral value");
1502915030

1503015031
case CK_BitCast:
@@ -15903,6 +15904,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1590315904
case CK_MatrixCast:
1590415905
case CK_HLSLVectorTruncation:
1590515906
case CK_HLSLElementwiseCast:
15907+
case CK_HLSLAggregateSplatCast:
1590615908
llvm_unreachable("invalid cast kind for complex value");
1590715909

1590815910
case CK_LValueToRValue:

clang/lib/CodeGen/CGExpr.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
53395339
case CK_HLSLVectorTruncation:
53405340
case CK_HLSLArrayRValue:
53415341
case CK_HLSLElementwiseCast:
5342+
case CK_HLSLAggregateSplatCast:
53425343
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
53435344

53445345
case CK_Dependent:

clang/lib/CodeGen/CGExprAgg.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
498498
return false;
499499
}
500500

501+
// Emit an HLSL aggregate splat: store the scalar SrcVal (of type SrcTy) into
// every flattened element of the aggregate at DestVal (of type DestTy),
// converting it to each flattened element type first. Loc is passed through
// to the scalar conversion.
static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
502+
QualType DestTy, llvm::Value *SrcVal,
503+
QualType SrcTy, SourceLocation Loc) {
504+
// Flatten our destination
505+
SmallVector<QualType> DestTypes; // Flattened type
506+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
507+
// ^^ Flattened accesses to DestVal we want to store into
508+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
509+
510+
// The source must already be a scalar here; the loop converts it once per
// destination element. DestTypes is indexed with the same I as StoreGEPList,
// so presumably FlattenAccessAndType fills both lists in parallel - verify.
assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
511+
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
512+
llvm::Value *Cast =
513+
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
514+
515+
// store back
516+
llvm::Value *Idx = StoreGEPList[I].second;
517+
// A non-null index selects one element of a vector stored at this address:
// load the current vector, insert the converted value at Idx, and store the
// whole vector back. A null index means the address is the element itself.
if (Idx) {
518+
llvm::Value *V =
519+
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
520+
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
521+
}
522+
// NOTE(review): the converted value is stored as-is with CreateStore;
// confirm no value-to-memory conversion is required for any destination
// element type (e.g. bool) - TODO verify.
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
523+
}
524+
}
525+
501526
// emit a flat cast where the RHS is a scalar, including vector
502527
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
503528
QualType DestTy, llvm::Value *SrcVal,
@@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
970995
case CK_HLSLArrayRValue:
971996
Visit(E->getSubExpr());
972997
break;
998+
case CK_HLSLAggregateSplatCast: {
999+
Expr *Src = E->getSubExpr();
1000+
QualType SrcTy = Src->getType();
1001+
RValue RV = CGF.EmitAnyExpr(Src);
1002+
QualType DestTy = E->getType();
1003+
Address DestVal = Dest.getAddress();
1004+
SourceLocation Loc = E->getExprLoc();
1005+
1006+
assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
1007+
llvm::Value *SrcVal = RV.getScalarVal();
1008+
EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
1009+
break;
1010+
}
9731011
case CK_HLSLElementwiseCast: {
9741012
Expr *Src = E->getSubExpr();
9751013
QualType SrcTy = Src->getType();
@@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
15601598
case CK_AtomicToNonAtomic:
15611599
case CK_HLSLVectorTruncation:
15621600
case CK_HLSLElementwiseCast:
1601+
case CK_HLSLAggregateSplatCast:
15631602
return true;
15641603

15651604
case CK_BaseToDerivedMemberPointer:

clang/lib/CodeGen/CGExprComplex.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
611611
case CK_HLSLVectorTruncation:
612612
case CK_HLSLArrayRValue:
613613
case CK_HLSLElementwiseCast:
614+
case CK_HLSLAggregateSplatCast:
614615
llvm_unreachable("invalid cast kind for complex value");
615616

616617
case CK_FloatingRealToComplex:

clang/lib/CodeGen/CGExprConstant.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
13361336
case CK_HLSLVectorTruncation:
13371337
case CK_HLSLArrayRValue:
13381338
case CK_HLSLElementwiseCast:
1339+
case CK_HLSLAggregateSplatCast:
13391340
return nullptr;
13401341
}
13411342
llvm_unreachable("Invalid CastKind");

clang/lib/CodeGen/CGExprScalar.cpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
26432643
return EmitScalarConversion(Visit(E), E->getType(), DestTy,
26442644
CE->getExprLoc());
26452645
}
2646+
// Here CK_HLSLAggregateSplatCast only handles splatting to vectors from a
2647+
// vec1. Sema has already inserted casts that convert the Src Expr to a scalar
2648+
// and perform any necessary scalar cast, so this cast can be handled by the
2649+
// regular vector splat code below.
2650+
case CK_HLSLAggregateSplatCast:
26462651
case CK_VectorSplat: {
26472652
llvm::Type *DstTy = ConvertType(DestTy);
26482653
Value *Elt = Visit(const_cast<Expr *>(E));
@@ -2800,7 +2805,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
28002805
SourceLocation Loc = CE->getExprLoc();
28012806
QualType SrcTy = E->getType();
28022807

2803-
assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
2808+
assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
28042809
// RHS is an aggregate
28052810
Address SrcVal = RV.getAggregateAddress();
28062811
return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);

clang/lib/Edit/RewriteObjCFoundationAPI.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
10861086

10871087
case CK_HLSLVectorTruncation:
10881088
case CK_HLSLElementwiseCast:
1089+
case CK_HLSLAggregateSplatCast:
10891090
llvm_unreachable("HLSL-specific cast in Objective-C?");
10901091
break;
10911092

clang/lib/Sema/Sema.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
709709
case CK_ToVoid:
710710
case CK_NonAtomicToAtomic:
711711
case CK_HLSLArrayRValue:
712+
case CK_HLSLAggregateSplatCast:
712713
break;
713714
}
714715
}

clang/lib/Sema/SemaCast.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -2776,9 +2776,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27762776
CheckedConversionKind CCK = FunctionalStyle
27772777
? CheckedConversionKind::FunctionalCast
27782778
: CheckedConversionKind::CStyleCast;
2779-
// This case should not trigger on regular vector splat
2780-
// vector cast, vector truncation, or special hlsl splat cases
2779+
27812780
QualType SrcTy = SrcExpr.get()->getType();
2781+
// This case should not trigger on a regular vector cast or vector truncation
27822782
if (Self.getLangOpts().HLSL &&
27832783
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
27842784
if (SrcTy->isConstantArrayType())
@@ -2789,6 +2789,28 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27892789
return;
27902790
}
27912791

2792+
// This case should not trigger on a regular vector splat.
2793+
// If the relative order of this check and the HLSLElementwiseCast check
2794+
// is changed, it might change which cast handles what in a few cases.
2795+
if (Self.getLangOpts().HLSL &&
2796+
Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
2797+
const VectorType *VT = SrcTy->getAs<VectorType>();
2798+
// change splat from vec1 case to splat from scalar
2799+
if (VT && VT->getNumElements() == 1)
2800+
SrcExpr = Self.ImpCastExprToType(
2801+
SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
2802+
SrcExpr.get()->getValueKind(), nullptr, CCK);
2803+
// Inserting a scalar cast here allows for a simplified codegen in
2804+
// the case the destTy is a vector
2805+
if (const VectorType *DVT = DestType->getAs<VectorType>())
2806+
SrcExpr = Self.ImpCastExprToType(
2807+
SrcExpr.get(), DVT->getElementType(),
2808+
Self.PrepareScalarCast(SrcExpr, DVT->getElementType()),
2809+
SrcExpr.get()->getValueKind(), nullptr, CCK);
2810+
Kind = CK_HLSLAggregateSplatCast;
2811+
return;
2812+
}
2813+
27922814
if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
27932815
!isPlaceholder(BuiltinType::Overload)) {
27942816
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());

clang/lib/Sema/SemaHLSL.cpp

+40-1
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,9 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
27172717
// clarity of what types are supported
27182718
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
27192719

2720+
if (!SrcTy->isScalarType() || !DestTy->isScalarType())
2721+
return false;
2722+
27202723
if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
27212724
return true;
27222725

@@ -2778,7 +2781,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
27782781
}
27792782

27802783
// Detect if a type contains a bitfield. Will be removed when
2781-
// bitfield support is added to HLSLElementwiseCast
2784+
// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
27822785
bool SemaHLSL::ContainsBitField(QualType BaseTy) {
27832786
llvm::SmallVector<QualType, 16> WorkList;
27842787
WorkList.push_back(BaseTy);
@@ -2811,6 +2814,42 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
28112814
return false;
28122815
}
28132816

2817+
// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
2818+
// Src is a scalar or a vector of length 1
2819+
// Or if Dest is a vector and Src is a vector of length 1
2820+
bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
2821+
2822+
QualType SrcTy = Src->getType();
2823+
// Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
2824+
// going to be a vector splat from a scalar.
2825+
if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
2826+
DestTy->isScalarType())
2827+
return false;
2828+
2829+
// Null unless the source is a vector type.
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
2830+
2831+
// Src isn't a scalar or a vector of length 1
2832+
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
2833+
return false;
2834+
2835+
// For a length-1 vector source, judge convertibility by its element type.
if (SrcVecTy)
2836+
SrcTy = SrcVecTy->getElementType();
2837+
2838+
// Destinations that contain a bit-field are rejected.
if (ContainsBitField(DestTy))
2839+
return false;
2840+
2841+
// Flatten the destination into its element types and require that none of
// them is a union and that the source can be scalar-cast to each of them.
llvm::SmallVector<QualType> DestTypes;
2842+
BuildFlattenedTypeList(DestTy, DestTypes);
2843+
2844+
for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
2845+
if (DestTypes[I]->isUnionType())
2846+
return false;
2847+
if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
2848+
return false;
2849+
}
2850+
return true;
2851+
}
2852+
28142853
// Can we perform an HLSL Elementwise cast?
28152854
// TODO: update this code when matrices are added; see issue #88060
28162855
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
523523
case CK_MatrixCast:
524524
case CK_VectorSplat:
525525
case CK_HLSLElementwiseCast:
526+
case CK_HLSLAggregateSplatCast:
526527
case CK_HLSLVectorTruncation: {
527528
QualType resultType = CastE->getType();
528529
if (CastE->isGLValue())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
2+
3+
// array splat
4+
// CHECK-LABEL: define void {{.*}}call4
5+
// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
6+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
7+
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
8+
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
9+
// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
10+
// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
11+
export void call4() {
12+
int B[2] = {1,2};
13+
// Splat the scalar literal 3 across both elements of B.
B = (int[2])3;
14+
}
15+
16+
// splat from vector of length 1
17+
// CHECK-LABEL: define void {{.*}}call8
18+
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
19+
// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
20+
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
21+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
22+
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
23+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
24+
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
25+
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
26+
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
27+
// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
28+
export void call8() {
29+
int1 A = {1};
30+
int B[2] = {1,2};
31+
// Splat the single element of the length-1 vector A across both elements of B.
B = (int[2])A;
32+
}
33+
34+
// vector splat from vector of length 1
35+
// CHECK-LABEL: define void {{.*}}call1
36+
// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
37+
// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
38+
// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
39+
// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
40+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
41+
// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
42+
// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
43+
// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
44+
// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
45+
export void call1() {
46+
float1 B = {1.0};
47+
// Splat a float1 to an int4: the element is converted float -> int and then
// replicated into all four lanes.
int4 A = (int4)B;
48+
}
49+
50+
// Aggregate with one int and one float field; used as the destination type of
// the struct splat tests in this file.
struct S {
51+
int X;
52+
float Y;
53+
};
54+
55+
// struct splats
56+
// CHECK-LABEL: define void {{.*}}call3
57+
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
58+
// CHECK: [[s:%.*]] = alloca %struct.S, align 4
59+
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
60+
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
61+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
62+
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
63+
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
64+
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
65+
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
66+
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
67+
export void call3() {
68+
// NOTE(review): this body is identical to call5, and the source is an int1
// (vector of length 1) rather than a plain scalar - confirm whether a
// scalar-source struct splat was intended for this case.
int1 A = {1};
69+
S s = (S)A;
70+
}
71+
72+
// struct splat from vector of length 1
73+
// CHECK-LABEL: define void {{.*}}call5
74+
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
75+
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
76+
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
77+
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
78+
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
79+
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
80+
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
81+
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
82+
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
83+
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
84+
export void call5() {
85+
int1 A = {1};
86+
// Splat the single element of A into both fields of S (int X, float Y).
S s = (S)A;
87+
}

0 commit comments

Comments
 (0)