Skip to content

Commit 69c4ea1

Browse files
spallIcohedron
authored and committed
[HLSL] Implement HLSL Elementwise casting (excluding splat cases); Re-land llvm#118842 (llvm#126258)
Implement HLSLElementwiseCast, excluding support for splat cases. Do not support casting types that contain bitfields. Partly closes llvm#100609 and partly closes llvm#100619. Re-land llvm#118842 after fixing a warning-as-error found by a buildbot.
1 parent a47fb98 commit 69c4ea1

20 files changed

+807
-6
lines changed

clang/include/clang/AST/OperationKinds.def

+3
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
367367
// Non-decaying array RValue cast (HLSL only).
368368
CAST_OPERATION(HLSLArrayRValue)
369369

370+
// Aggregate by Value cast (HLSL only).
371+
CAST_OPERATION(HLSLElementwiseCast)
372+
370373
//===- Binary Operations -------------------------------------------------===//
371374
// Operators listed in order of precedence.
372375
// Note that additions to this should also update the StmtVisitor class,

clang/include/clang/Sema/SemaHLSL.h

+3
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ class SemaHLSL : public SemaBase {
141141
// Diagnose whether the input ID is uint/uint2/uint3 type.
142142
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
143143

144+
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
145+
bool ContainsBitField(QualType BaseTy);
146+
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
144147
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
145148

146149
QualType getInoutParameterType(QualType Ty);

clang/lib/AST/Expr.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,7 @@ bool CastExpr::CastConsistency() const {
19561956
case CK_FixedPointToBoolean:
19571957
case CK_HLSLArrayRValue:
19581958
case CK_HLSLVectorTruncation:
1959+
case CK_HLSLElementwiseCast:
19591960
CheckNoBasePath:
19601961
assert(path_empty() && "Cast kind should not have a base path!");
19611962
break;

clang/lib/AST/ExprConstant.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -15047,6 +15047,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
1504715047
case CK_NoOp:
1504815048
case CK_LValueToRValueBitCast:
1504915049
case CK_HLSLArrayRValue:
15050+
case CK_HLSLElementwiseCast:
1505015051
return ExprEvaluatorBaseTy::VisitCastExpr(E);
1505115052

1505215053
case CK_MemberPointerToBoolean:
@@ -15905,6 +15906,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
1590515906
case CK_IntegralToFixedPoint:
1590615907
case CK_MatrixCast:
1590715908
case CK_HLSLVectorTruncation:
15909+
case CK_HLSLElementwiseCast:
1590815910
llvm_unreachable("invalid cast kind for complex value");
1590915911

1591015912
case CK_LValueToRValue:

clang/lib/CodeGen/CGExpr.cpp

+73
Original file line numberDiff line numberDiff line change
@@ -5338,6 +5338,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
53385338
case CK_MatrixCast:
53395339
case CK_HLSLVectorTruncation:
53405340
case CK_HLSLArrayRValue:
5341+
case CK_HLSLElementwiseCast:
53415342
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
53425343

53435344
case CK_Dependent:
@@ -6376,3 +6377,75 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
63766377
LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
63776378
return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
63786379
}
6380+
6381+
/// Flatten the object of type \p AddrType stored at \p Addr into its leaf
/// elements, visiting array elements and record members depth-first with the
/// first element/member first.
///
/// One entry is appended to each output list per leaf:
///  - \p AccessList gets the address to access and, for an element of a
///    vector, the element index to use with extractelement/insertelement
///    (nullptr for a non-vector leaf).
///  - \p FlatTypes gets the leaf type (the element type for vector leaves).
///
/// Constant arrays, records and vectors are expanded; every other type is
/// treated as a scalar leaf. Matrix types and unions trip an assertion.
void CodeGenFunction::FlattenAccessAndType(
    Address Addr, QualType AddrType,
    SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
    SmallVectorImpl<QualType> &FlatTypes) {
  using IndexList = llvm::SmallVector<llvm::Value *, 4>;

  // WorkList is list of type we are processing + the Index List to access
  // the field of that type in Addr for use in a GEP
  llvm::SmallVector<std::pair<QualType, IndexList>, 16> WorkList;
  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);

  // Queue child number ChildNo of the entry reached through ParentIdx. The
  // index list is built directly inside the new work item.
  auto PushChild = [&WorkList, IdxTy](QualType ChildTy,
                                      const IndexList &ParentIdx,
                                      uint64_t ChildNo) {
    WorkList.emplace_back(ChildTy, ParentIdx);
    WorkList.back().second.push_back(llvm::ConstantInt::get(IdxTy, ChildNo));
  };

  // Addr should be a pointer so we need to 'dereference' it
  WorkList.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});

  while (!WorkList.empty()) {
    auto [T, IdxList] = WorkList.pop_back_val();
    T = T.getCanonicalType().getUnqualifiedType();
    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
    if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
      // Children are pushed last-to-first so that they are popped (and
      // therefore emitted) first-to-last.
      for (uint64_t I = CAT->getZExtSize(); I-- > 0;)
        PushChild(CAT->getElementType(), IdxList, I);
    } else if (const auto *RT = dyn_cast<RecordType>(T)) {
      const RecordDecl *Record = RT->getDecl();
      assert(!Record->isUnion() && "Union types not supported in flat cast.");

      const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);

      llvm::SmallVector<QualType, 16> FieldTypes;
      if (CXXD && CXXD->isStandardLayout())
        Record = CXXD->getStandardLayoutBaseWithFields();

      // deal with potential base classes
      if (CXXD && !CXXD->isStandardLayout()) {
        for (auto &Base : CXXD->bases())
          FieldTypes.push_back(Base.getType());
      }

      for (auto *FD : Record->fields())
        FieldTypes.push_back(FD->getType());

      // NOTE(review): the GEP index used for each member is its position in
      // FieldTypes; this presumes that position equals the member's element
      // index in the LLVM struct type -- confirm for records whose fields
      // live in a base class.
      for (size_t I = FieldTypes.size(); I-- > 0;)
        PushChild(FieldTypes[I], IdxList, I);
    } else if (const auto *VT = dyn_cast<VectorType>(T)) {
      llvm::Type *LLVMT = ConvertTypeForMem(T);
      CharUnits Align = getContext().getTypeAlignInChars(T);
      Address GEP =
          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
      for (unsigned I = 0, E = VT->getNumElements(); I < E; I++) {
        llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
        // gep on vector fields is not recommended so combine gep with
        // extract/insert
        AccessList.emplace_back(GEP, Idx);
        FlatTypes.push_back(VT->getElementType());
      }
    } else {
      // a scalar/builtin type
      llvm::Type *LLVMT = ConvertTypeForMem(T);
      CharUnits Align = getContext().getTypeAlignInChars(T);
      Address GEP =
          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
      AccessList.emplace_back(GEP, nullptr);
      FlatTypes.push_back(T);
    }
  }
}

clang/lib/CodeGen/CGExprAgg.cpp

+93-1
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,79 @@ static bool isTrivialFiller(Expr *E) {
491491
return false;
492492
}
493493

494+
// emit a flat cast where the RHS is a scalar, including vector
495+
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
496+
QualType DestTy, llvm::Value *SrcVal,
497+
QualType SrcTy, SourceLocation Loc) {
498+
// Flatten our destination
499+
SmallVector<QualType, 16> DestTypes; // Flattened type
500+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
501+
// ^^ Flattened accesses to DestVal we want to store into
502+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
503+
504+
assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
505+
const VectorType *VT = SrcTy->getAs<VectorType>();
506+
SrcTy = VT->getElementType();
507+
assert(StoreGEPList.size() <= VT->getNumElements() &&
508+
"Cannot perform HLSL flat cast when vector source \
509+
object has less elements than flattened destination \
510+
object.");
511+
for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
512+
llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
513+
llvm::Value *Cast =
514+
CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);
515+
516+
// store back
517+
llvm::Value *Idx = StoreGEPList[I].second;
518+
if (Idx) {
519+
llvm::Value *V =
520+
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
521+
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
522+
}
523+
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
524+
}
525+
return;
526+
}
527+
528+
// emit a flat cast where the RHS is an aggregate
529+
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
530+
QualType DestTy, Address SrcVal,
531+
QualType SrcTy, SourceLocation Loc) {
532+
// Flatten our destination
533+
SmallVector<QualType, 16> DestTypes; // Flattened type
534+
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
535+
// ^^ Flattened accesses to DestVal we want to store into
536+
CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
537+
// Flatten our src
538+
SmallVector<QualType, 16> SrcTypes; // Flattened type
539+
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
540+
// ^^ Flattened accesses to SrcVal we want to load from
541+
CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
542+
543+
assert(StoreGEPList.size() <= LoadGEPList.size() &&
544+
"Cannot perform HLSL flat cast when flattened source object \
545+
has less elements than flattened destination object.");
546+
// apply casts to what we load from LoadGEPList
547+
// and store result in Dest
548+
for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
549+
llvm::Value *Idx = LoadGEPList[I].second;
550+
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
551+
Load =
552+
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
553+
llvm::Value *Cast =
554+
CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);
555+
556+
// store back
557+
Idx = StoreGEPList[I].second;
558+
if (Idx) {
559+
llvm::Value *V =
560+
CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
561+
Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
562+
}
563+
CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
564+
}
565+
}
566+
494567
/// Emit initialization of an array from an initializer list. ExprToVisit must
495568
/// be either an InitListExpr or a CXXParenInitListExpr.
496569
void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +963,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
890963
case CK_HLSLArrayRValue:
891964
Visit(E->getSubExpr());
892965
break;
893-
966+
case CK_HLSLElementwiseCast: {
967+
Expr *Src = E->getSubExpr();
968+
QualType SrcTy = Src->getType();
969+
RValue RV = CGF.EmitAnyExpr(Src);
970+
QualType DestTy = E->getType();
971+
Address DestVal = Dest.getAddress();
972+
SourceLocation Loc = E->getExprLoc();
973+
974+
if (RV.isScalar()) {
975+
llvm::Value *SrcVal = RV.getScalarVal();
976+
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
977+
} else {
978+
assert(RV.isAggregate() &&
979+
"Can't perform HLSL Aggregate cast on a complex type.");
980+
Address SrcVal = RV.getAggregateAddress();
981+
EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
982+
}
983+
break;
984+
}
894985
case CK_NoOp:
895986
case CK_UserDefinedConversion:
896987
case CK_ConstructorConversion:
@@ -1461,6 +1552,7 @@ static bool castPreservesZero(const CastExpr *CE) {
14611552
case CK_NonAtomicToAtomic:
14621553
case CK_AtomicToNonAtomic:
14631554
case CK_HLSLVectorTruncation:
1555+
case CK_HLSLElementwiseCast:
14641556
return true;
14651557

14661558
case CK_BaseToDerivedMemberPointer:

clang/lib/CodeGen/CGExprComplex.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
610610
case CK_MatrixCast:
611611
case CK_HLSLVectorTruncation:
612612
case CK_HLSLArrayRValue:
613+
case CK_HLSLElementwiseCast:
613614
llvm_unreachable("invalid cast kind for complex value");
614615

615616
case CK_FloatingRealToComplex:

clang/lib/CodeGen/CGExprConstant.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,7 @@ class ConstExprEmitter
13351335
case CK_MatrixCast:
13361336
case CK_HLSLVectorTruncation:
13371337
case CK_HLSLArrayRValue:
1338+
case CK_HLSLElementwiseCast:
13381339
return nullptr;
13391340
}
13401341
llvm_unreachable("Invalid CastKind");

clang/lib/CodeGen/CGExprScalar.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -2269,6 +2269,42 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
22692269
return true;
22702270
}
22712271

2272+
// RHS is an aggregate type
2273+
static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address RHSVal,
2274+
QualType RHSTy, QualType LHSTy,
2275+
SourceLocation Loc) {
2276+
SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
2277+
SmallVector<QualType, 16> SrcTypes; // Flattened type
2278+
CGF.FlattenAccessAndType(RHSVal, RHSTy, LoadGEPList, SrcTypes);
2279+
// LHS is either a vector or a builtin?
2280+
// if its a vector create a temp alloca to store into and return that
2281+
if (auto *VecTy = LHSTy->getAs<VectorType>()) {
2282+
assert(SrcTypes.size() >= VecTy->getNumElements() &&
2283+
"Flattened type on RHS must have more elements than vector on LHS.");
2284+
llvm::Value *V =
2285+
CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
2286+
// write to V.
2287+
for (unsigned I = 0, E = VecTy->getNumElements(); I < E; I++) {
2288+
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
2289+
llvm::Value *Idx = LoadGEPList[I].second;
2290+
Load = Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract")
2291+
: Load;
2292+
llvm::Value *Cast = CGF.EmitScalarConversion(
2293+
Load, SrcTypes[I], VecTy->getElementType(), Loc);
2294+
V = CGF.Builder.CreateInsertElement(V, Cast, I);
2295+
}
2296+
return V;
2297+
}
2298+
// i its a builtin just do an extract element or load.
2299+
assert(LHSTy->isBuiltinType() &&
2300+
"Destination type must be a vector or builtin type.");
2301+
llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[0].first, "load");
2302+
llvm::Value *Idx = LoadGEPList[0].second;
2303+
Load =
2304+
Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
2305+
return CGF.EmitScalarConversion(Load, LHSTy, SrcTypes[0], Loc);
2306+
}
2307+
22722308
// VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts
22732309
// have to handle a more broad range of conversions than explicit casts, as they
22742310
// handle things like function to ptr-to-function decay etc.
@@ -2759,7 +2795,16 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
27592795
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
27602796
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
27612797
}
2798+
case CK_HLSLElementwiseCast: {
2799+
RValue RV = CGF.EmitAnyExpr(E);
2800+
SourceLocation Loc = CE->getExprLoc();
2801+
QualType SrcTy = E->getType();
27622802

2803+
assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
2804+
// RHS is an aggregate
2805+
Address SrcVal = RV.getAggregateAddress();
2806+
return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);
2807+
}
27632808
} // end of switch
27642809

27652810
llvm_unreachable("unknown scalar cast");

clang/lib/CodeGen/CodeGenFunction.h

+5
Original file line numberDiff line numberDiff line change
@@ -4439,6 +4439,11 @@ class CodeGenFunction : public CodeGenTypeCache {
44394439
AggValueSlot slot = AggValueSlot::ignored());
44404440
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
44414441

4442+
void FlattenAccessAndType(
4443+
Address Addr, QualType AddrTy,
4444+
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
4445+
SmallVectorImpl<QualType> &FlatTypes);
4446+
44424447
llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
44434448
const ObjCIvarDecl *Ivar);
44444449
llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,

clang/lib/Edit/RewriteObjCFoundationAPI.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
10851085
llvm_unreachable("OpenCL-specific cast in Objective-C?");
10861086

10871087
case CK_HLSLVectorTruncation:
1088+
case CK_HLSLElementwiseCast:
10881089
llvm_unreachable("HLSL-specific cast in Objective-C?");
10891090
break;
10901091

clang/lib/Sema/SemaCast.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "clang/Basic/TargetInfo.h"
2424
#include "clang/Lex/Preprocessor.h"
2525
#include "clang/Sema/Initialization.h"
26+
#include "clang/Sema/SemaHLSL.h"
2627
#include "clang/Sema/SemaObjC.h"
2728
#include "clang/Sema/SemaRISCV.h"
2829
#include "llvm/ADT/SmallVector.h"
@@ -2772,6 +2773,22 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
27722773
return;
27732774
}
27742775

2776+
CheckedConversionKind CCK = FunctionalStyle
2777+
? CheckedConversionKind::FunctionalCast
2778+
: CheckedConversionKind::CStyleCast;
2779+
// This case should not trigger on regular vector splat
2780+
// vector cast, vector truncation, or special hlsl splat cases
2781+
QualType SrcTy = SrcExpr.get()->getType();
2782+
if (Self.getLangOpts().HLSL &&
2783+
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
2784+
if (SrcTy->isConstantArrayType())
2785+
SrcExpr = Self.ImpCastExprToType(
2786+
SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy),
2787+
CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
2788+
Kind = CK_HLSLElementwiseCast;
2789+
return;
2790+
}
2791+
27752792
if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
27762793
!isPlaceholder(BuiltinType::Overload)) {
27772794
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
@@ -2824,9 +2841,6 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
28242841
if (isValidCast(tcr))
28252842
Kind = CK_NoOp;
28262843

2827-
CheckedConversionKind CCK = FunctionalStyle
2828-
? CheckedConversionKind::FunctionalCast
2829-
: CheckedConversionKind::CStyleCast;
28302844
if (tcr == TC_NotApplicable) {
28312845
tcr = TryAddressSpaceCast(Self, SrcExpr, DestType, /*CStyle*/ true, msg,
28322846
Kind);

0 commit comments

Comments
 (0)