Skip to content

Commit

Permalink
[HLSL] Implement HLSL Aggregate splatting (llvm#118992)
Browse files Browse the repository at this point in the history
Implement the HLSL aggregate splat cast, which handles splatting to arrays
and structs, and to vectors when the source is a vec1.
Closes llvm#100609 and Closes llvm#100619 
Depends on llvm#118842
  • Loading branch information
spall authored and joaosaffran committed Feb 14, 2025
1 parent 6e9d06a commit 9236110
Show file tree
Hide file tree
Showing 18 changed files with 305 additions and 4 deletions.
3 changes: 3 additions & 0 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLElementwiseCast)

// Splat cast for Aggregates (HLSL only).
CAST_OPERATION(HLSLAggregateSplatCast)

//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool ContainsBitField(QualType BaseTy);
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);

QualType getInoutParameterType(QualType Ty);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,7 @@ bool CastExpr::CastConsistency() const {
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15025,6 +15025,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointCast:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("invalid cast kind for integral value");

case CK_BitCast:
Expand Down Expand Up @@ -15903,6 +15904,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down
39 changes: 39 additions & 0 deletions clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

/// Emit an HLSL aggregate splat cast: convert the scalar \p SrcVal (of type
/// \p SrcTy) to the type of each flattened element of \p DestTy and store the
/// converted value through the matching flattened access into \p DestVal.
/// \p Loc is forwarded to the scalar conversion.
static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
                                       QualType DestTy, llvm::Value *SrcVal,
                                       QualType SrcTy, SourceLocation Loc) {
  // Flatten the destination: one (address, optional index) access per
  // element in FlatAccesses, with the element types in FlatTypes.
  SmallVector<std::pair<Address, llvm::Value *>, 16> FlatAccesses;
  SmallVector<QualType> FlatTypes;
  CGF.FlattenAccessAndType(DestVal, DestTy, FlatAccesses, FlatTypes);

  assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
  const unsigned NumElts = FlatAccesses.size();
  for (unsigned Elt = 0; Elt != NumElts; ++Elt) {
    const std::pair<Address, llvm::Value *> &Access = FlatAccesses[Elt];
    llvm::Value *Converted =
        CGF.EmitScalarConversion(SrcVal, SrcTy, FlatTypes[Elt], Loc);

    // When the access carries an index, load the current value at the
    // address, insert the converted scalar at that index, and store the
    // updated value back instead of the bare scalar.
    if (llvm::Value *Lane = Access.second) {
      llvm::Value *Loaded =
          CGF.Builder.CreateLoad(Access.first, "load.for.insert");
      Converted = CGF.Builder.CreateInsertElement(Loaded, Converted, Lane);
    }
    CGF.Builder.CreateStore(Converted, Access.first);
  }
}

// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
Expand Down Expand Up @@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;
case CK_HLSLAggregateSplatCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();

assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
break;
}
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
Expand Down Expand Up @@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return true;

case CK_BaseToDerivedMemberPointer:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,7 @@ class ConstExprEmitter
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
7 changes: 6 additions & 1 deletion clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return EmitScalarConversion(Visit(E), E->getType(), DestTy,
CE->getExprLoc());
}
// CK_HLSLAggregateSplatCast reaches this emitter only when splatting a
// vec1 to a vector. Sema has already inserted casts that reduce the source
// expression to a scalar and perform any required scalar conversion, so
// the regular vector splat code below can handle it.
case CK_HLSLAggregateSplatCast:
case CK_VectorSplat: {
llvm::Type *DstTy = ConvertType(DestTy);
Value *Elt = Visit(const_cast<Expr *>(E));
Expand Down Expand Up @@ -2800,7 +2805,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
SourceLocation Loc = CE->getExprLoc();
QualType SrcTy = E->getType();

assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
// RHS is an aggregate
Address SrcVal = RV.getAggregateAddress();
return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,

case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;

Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
case CK_ToVoid:
case CK_NonAtomicToAtomic:
case CK_HLSLArrayRValue:
case CK_HLSLAggregateSplatCast:
break;
}
}
Expand Down
26 changes: 24 additions & 2 deletions clang/lib/Sema/SemaCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2776,9 +2776,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
// This case should not trigger on regular vector splat
// vector cast, vector truncation, or special hlsl splat cases

QualType SrcTy = SrcExpr.get()->getType();
// This case should not trigger on regular vector cast, vector truncation
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
if (SrcTy->isConstantArrayType())
Expand All @@ -2789,6 +2789,28 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
return;
}

// This case should not trigger on regular vector splat
// If the relative order of this and the HLSLElementWise cast checks
// are changed, it might change which cast handles what in a few cases
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
const VectorType *VT = SrcTy->getAs<VectorType>();
// change splat from vec1 case to splat from scalar
if (VT && VT->getNumElements() == 1)
SrcExpr = Self.ImpCastExprToType(
SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
SrcExpr.get()->getValueKind(), nullptr, CCK);
// Inserting a scalar cast here allows for a simplified codegen in
// the case the destTy is a vector
if (const VectorType *DVT = DestType->getAs<VectorType>())
SrcExpr = Self.ImpCastExprToType(
SrcExpr.get(), DVT->getElementType(),
Self.PrepareScalarCast(SrcExpr, DVT->getElementType()),
SrcExpr.get()->getValueKind(), nullptr, CCK);
Kind = CK_HLSLAggregateSplatCast;
return;
}

if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
!isPlaceholder(BuiltinType::Overload)) {
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
Expand Down
41 changes: 40 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2717,6 +2717,9 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
// clarity of what types are supported
bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {

if (!SrcTy->isScalarType() || !DestTy->isScalarType())
return false;

if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
return true;

Expand Down Expand Up @@ -2778,7 +2781,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
}

// Detect if a type contains a bitfield. Will be removed when
// bitfield support is added to HLSLElementwiseCast
// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
bool SemaHLSL::ContainsBitField(QualType BaseTy) {
llvm::SmallVector<QualType, 16> WorkList;
WorkList.push_back(BaseTy);
Expand Down Expand Up @@ -2811,6 +2814,42 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
return false;
}

// Decide whether an HLSL aggregate splat cast from Src to DestTy is allowed.
// The cast is allowed when all of the following hold:
//  - DestTy is not a scalar;
//  - Src is a scalar (and DestTy is not a vector, since scalar-to-vector is
//    the regular vector splat), or Src is a vector of length 1;
//  - DestTy contains no bit-fields;
//  - no flattened element of DestTy is a union, and the Src element type can
//    be scalar-cast to every flattened element type of DestTy.
bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
  QualType ElemTy = Src->getType();
  const bool SrcIsScalar = ElemTy->isScalarType();

  // A scalar destination can never be the target of a splat.
  if (DestTy->isScalarType())
    return false;

  // Scalar-to-vector is left to the regular vector splat.
  if (SrcIsScalar && DestTy->isVectorType())
    return false;

  // The source must be either a scalar or a single-element vector.
  const VectorType *SrcVec = ElemTy->getAs<VectorType>();
  const bool SrcIsVec1 = SrcVec && SrcVec->getNumElements() == 1;
  if (!SrcIsScalar && !SrcIsVec1)
    return false;

  // For a vector source, the value being splatted is its element.
  if (SrcVec)
    ElemTy = SrcVec->getElementType();

  if (ContainsBitField(DestTy))
    return false;

  llvm::SmallVector<QualType> FlatDestTypes;
  BuildFlattenedTypeList(DestTy, FlatDestTypes);

  // Every flattened destination element must be a non-union type that the
  // source element type can be scalar-cast to.
  for (QualType FlatTy : FlatDestTypes) {
    if (FlatTy->isUnionType() || !CanPerformScalarCast(ElemTy, FlatTy))
      return false;
  }
  return true;
}

// Can we perform an HLSL Elementwise cast?
// TODO: update this code when matrices are added; see issue #88060
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
Expand Down
1 change: 1 addition & 0 deletions clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
case CK_MatrixCast:
case CK_VectorSplat:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
case CK_HLSLVectorTruncation: {
QualType resultType = CastE->getType();
if (CastE->isGLValue())
Expand Down
87 changes: 87 additions & 0 deletions clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s

// array splat
// The scalar literal 3 is cast to int[2] and assigned to B, so the expected
// IR addresses both elements of B and stores the constant 3 into each one.
// CHECK-LABEL: define void {{.*}}call4
// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
export void call4() {
int B[2] = {1,2};
B = (int[2])3;
}

// splat from vector of length 1
// The int1 A is cast to int[2] and assigned to B. The expected IR loads A,
// extracts its single element once, and stores that value into both
// elements of B.
// CHECK-LABEL: define void {{.*}}call8
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
export void call8() {
int1 A = {1};
int B[2] = {1,2};
B = (int[2])A;
}

// vector splat from vector of length 1
// The float1 B is cast to int4. The expected IR extracts the single float
// element, converts it to i32 with fptosi, and then splats it across a
// <4 x i32> using insertelement followed by shufflevector.
// CHECK-LABEL: define void {{.*}}call1
// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
export void call1() {
float1 B = {1.0};
int4 A = (int4)B;
}

// Destination type for the struct splat tests in this file. The two fields
// have different types (int and float), so a splat into S needs a different
// conversion for each field.
struct S {
int X;
float Y;
};

// struct splat from a scalar
// The plain int A is cast to S. The expected IR loads A once, stores it
// unchanged into field X, and stores its sitofp conversion into field Y.
// The variant whose source is a vector of length 1 is covered by call5.
// CHECK-LABEL: define void {{.*}}call3
// CHECK: [[A:%.*]] = alloca i32, align 4
// CHECK: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store i32 1, ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[A]], align 4
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L]] to float
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
export void call3() {
  int A = 1;
  S s = (S)A;
}

// struct splat from vector of length 1
// The int1 A is cast to S. The expected IR extracts the single element of A,
// stores it unchanged into field X, and stores its sitofp conversion into
// field Y.
// CHECK-LABEL: define void {{.*}}call5
// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
export void call5() {
int1 A = {1};
S s = (S)A;
}
Loading

0 comments on commit 9236110

Please sign in to comment.