Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[HLSL] Implement HLSL Flat casting (excluding splat cases)" #126149

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions clang/include/clang/AST/OperationKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,6 @@ CAST_OPERATION(HLSLVectorTruncation)
// Non-decaying array RValue cast (HLSL only).
CAST_OPERATION(HLSLArrayRValue)

// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLElementwiseCast)

//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
Expand Down
3 changes: 0 additions & 3 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ class SemaHLSL : public SemaBase {
// Diagnose whether the input ID is uint/uint2/uint3 type.
bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);

bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool ContainsBitField(QualType BaseTy);
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);

QualType getInoutParameterType(QualType Ty);
Expand Down
1 change: 0 additions & 1 deletion clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,6 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointToBoolean:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
Expand Down
2 changes: 0 additions & 2 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15047,7 +15047,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_NoOp:
case CK_LValueToRValueBitCast:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
return ExprEvaluatorBaseTy::VisitCastExpr(E);

case CK_MemberPointerToBoolean:
Expand Down Expand Up @@ -15906,7 +15905,6 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_LValueToRValue:
Expand Down
73 changes: 0 additions & 73 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5338,7 +5338,6 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");

case CK_Dependent:
Expand Down Expand Up @@ -6377,75 +6376,3 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
/// Emit a pseudo-object expression and return it as an l-value.
///
/// Forwards to emitPseudoObjectExpr with the boolean argument set to true and
/// an ignored aggregate slot, then hands back the LV member of the result.
LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
  auto Emitted = emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored());
  return Emitted.LV;
}

/// Flatten the object at \p Addr (of type \p AddrType) into its leaf elements.
///
/// Constant arrays, records and vectors are expanded depth-first so that
/// leaves are produced in first-to-last order. For every leaf one entry is
/// appended to each output list:
///  - \p AccessList receives {address, index}. For a vector lane the address
///    is that of the whole vector and the index selects the lane; for any
///    other leaf the index is null.
///  - \p FlatTypes receives the type of that leaf (the element type for a
///    vector lane).
///
/// Matrix types and unions are rejected by assertion.
void CodeGenFunction::FlattenAccessAndType(
    Address Addr, QualType AddrType,
    SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
    SmallVectorImpl<QualType> &FlatTypes) {
  using IndexPath = llvm::SmallVector<llvm::Value *, 4>;

  // Each pending entry pairs a type still to be expanded with the GEP index
  // path that reaches it from Addr.
  llvm::SmallVector<std::pair<QualType, IndexPath>, 16> Pending;
  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
  // Addr should be a pointer, so the path starts with a 0 index to
  // 'dereference' it.
  Pending.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});

  while (!Pending.empty()) {
    std::pair<QualType, IndexPath> Item = Pending.pop_back_val();
    QualType T = Item.first.getCanonicalType().getUnqualifiedType();
    const IndexPath &Path = Item.second;
    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");

    // Arrays: push elements last-to-first so they are popped first-to-last.
    if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
      for (uint64_t I = CAT->getZExtSize(); I-- > 0;) {
        IndexPath ElemPath = Path;
        ElemPath.push_back(llvm::ConstantInt::get(IdxTy, I));
        Pending.emplace_back(CAT->getElementType(), ElemPath);
      }
      continue;
    }

    // Records: collect the member types, then push them in reverse.
    if (const auto *RT = dyn_cast<RecordType>(T)) {
      const RecordDecl *Record = RT->getDecl();
      assert(!Record->isUnion() && "Union types not supported in flat cast.");

      llvm::SmallVector<QualType, 16> MemberTypes;
      if (const auto *CXXD = dyn_cast<CXXRecordDecl>(Record)) {
        if (CXXD->isStandardLayout()) {
          // Standard layout: take the fields from the one class in the
          // hierarchy that declares them.
          Record = CXXD->getStandardLayoutBaseWithFields();
        } else {
          // Otherwise the base classes come first, ahead of the own fields.
          for (auto &Base : CXXD->bases())
            MemberTypes.push_back(Base.getType());
        }
      }

      for (auto *FD : Record->fields())
        MemberTypes.push_back(FD->getType());

      for (uint64_t I = MemberTypes.size(); I-- > 0;) {
        IndexPath MemberPath = Path;
        MemberPath.push_back(llvm::ConstantInt::get(IdxTy, I));
        Pending.push_back({MemberTypes[I], MemberPath});
      }
      continue;
    }

    // Leaf: either a vector or a scalar/builtin type. Both need the address
    // of the leaf itself; only the GEP name differs.
    const auto *VT = dyn_cast<VectorType>(T);
    llvm::Type *MemTy = ConvertTypeForMem(T);
    CharUnits Align = getContext().getTypeAlignInChars(T);
    Address Leaf = Builder.CreateInBoundsGEP(Addr, Path, MemTy, Align,
                                             VT ? "vector.gep" : "gep");

    if (!VT) {
      AccessList.emplace_back(Leaf, nullptr);
      FlatTypes.push_back(T);
      continue;
    }

    // gep on vector fields is not recommended, so record the vector address
    // once per lane together with the lane index for extract/insert.
    for (unsigned Lane = 0, NumLanes = VT->getNumElements(); Lane < NumLanes;
         Lane++) {
      AccessList.emplace_back(Leaf, llvm::ConstantInt::get(IdxTy, Lane));
      FlatTypes.push_back(VT->getElementType());
    }
  }
}
94 changes: 1 addition & 93 deletions clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,79 +491,6 @@ static bool isTrivialFiller(Expr *E) {
return false;
}

/// Emit an HLSL flat cast whose source is a scalar r-value, i.e. a vector.
///
/// The destination is flattened with FlattenAccessAndType; source lane I is
/// extracted, converted to the I-th flattened destination type, and stored
/// into the I-th flattened destination slot.
///
/// \param CGF     Function being emitted into.
/// \param DestVal Address of the destination object.
/// \param DestTy  Type of the destination object.
/// \param SrcVal  Source vector value.
/// \param SrcTy   Type of \p SrcVal; must be a vector type (splatting is not
///                handled here).
/// \param Loc     Source location passed to the scalar conversions.
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
                                   QualType DestTy, llvm::Value *SrcVal,
                                   QualType SrcTy, SourceLocation Loc) {
  // Flatten our destination
  SmallVector<QualType, 16> DestTypes; // Flattened type
  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
  // ^^ Flattened accesses to DestVal we want to store into
  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);

  assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
  const VectorType *VT = SrcTy->getAs<VectorType>();
  SrcTy = VT->getElementType();
  // Adjacent literals instead of backslash continuations, so the message does
  // not pick up the indentation of the continuation lines.
  assert(StoreGEPList.size() <= VT->getNumElements() &&
         "Cannot perform HLSL flat cast when vector source "
         "object has less elements than flattened destination "
         "object.");
  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
    llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
    llvm::Value *Cast =
        CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);

    // store back
    llvm::Value *Idx = StoreGEPList[I].second;
    if (Idx) {
      // The destination slot is a vector lane: read the vector, replace the
      // lane, and write the whole vector back.
      llvm::Value *V =
          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
    }
    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
  }
}

/// Emit an HLSL flat cast whose source is an aggregate.
///
/// Both the destination and the source are flattened with
/// FlattenAccessAndType. For each flattened destination slot I, the I-th
/// flattened source element is loaded, converted to the destination slot's
/// type, and stored.
///
/// \param CGF     Function being emitted into.
/// \param DestVal Address of the destination object.
/// \param DestTy  Type of the destination object.
/// \param SrcVal  Address of the source aggregate.
/// \param SrcTy   Type of the source aggregate.
/// \param Loc     Source location passed to the scalar conversions.
static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
                                    QualType DestTy, Address SrcVal,
                                    QualType SrcTy, SourceLocation Loc) {
  // Flatten our destination
  SmallVector<QualType, 16> DestTypes; // Flattened type
  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
  // ^^ Flattened accesses to DestVal we want to store into
  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
  // Flatten our src
  SmallVector<QualType, 16> SrcTypes; // Flattened type
  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
  // ^^ Flattened accesses to SrcVal we want to load from
  CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);

  // Adjacent literals instead of backslash continuations, so the message does
  // not pick up the indentation of the continuation lines.
  assert(StoreGEPList.size() <= LoadGEPList.size() &&
         "Cannot perform HLSL flat cast when flattened source object "
         "has less elements than flattened destination object.");
  // apply casts to what we load from LoadGEPList
  // and store result in Dest
  for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
    llvm::Value *Idx = LoadGEPList[I].second;
    llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
    // A non-null index means the source element is a vector lane.
    Load =
        Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
    llvm::Value *Cast =
        CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);

    // store back
    Idx = StoreGEPList[I].second;
    if (Idx) {
      // The destination slot is a vector lane: read the vector, replace the
      // lane, and write the whole vector back.
      llvm::Value *V =
          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
    }
    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
  }
}

/// Emit initialization of an array from an initializer list. ExprToVisit must
/// be either an InitListExpr or a CXXParenInitListExpr.
void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
Expand Down Expand Up @@ -963,25 +890,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
RValue RV = CGF.EmitAnyExpr(Src);
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();

if (RV.isScalar()) {
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
} else {
assert(RV.isAggregate() &&
"Can't perform HLSL Aggregate cast on a complex type.");
Address SrcVal = RV.getAggregateAddress();
EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
}
break;
}

case CK_NoOp:
case CK_UserDefinedConversion:
case CK_ConstructorConversion:
Expand Down Expand Up @@ -1552,7 +1461,6 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_NonAtomicToAtomic:
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
return true;

case CK_BaseToDerivedMemberPointer:
Expand Down
1 change: 0 additions & 1 deletion clang/lib/CodeGen/CGExprComplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,6 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
llvm_unreachable("invalid cast kind for complex value");

case CK_FloatingRealToComplex:
Expand Down
1 change: 0 additions & 1 deletion clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,6 @@ class ConstExprEmitter
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
Expand Down
45 changes: 0 additions & 45 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2269,42 +2269,6 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
return true;
}

/// Emit an HLSL elementwise cast from an aggregate source to a vector or
/// builtin destination, returning the resulting scalar/vector value.
///
/// The source is flattened with FlattenAccessAndType. For a vector
/// destination, the first N flattened source elements are converted to the
/// vector's element type and inserted lane by lane. For a builtin
/// destination, only the first flattened source element is converted.
///
/// \param CGF    Function being emitted into.
/// \param RHSVal Address of the source aggregate.
/// \param RHSTy  Type of the source aggregate.
/// \param LHSTy  Destination type; must be a vector or builtin type.
/// \param Loc    Source location passed to the scalar conversions.
static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address RHSVal,
                                      QualType RHSTy, QualType LHSTy,
                                      SourceLocation Loc) {
  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
  SmallVector<QualType, 16> SrcTypes; // Flattened type
  CGF.FlattenAccessAndType(RHSVal, RHSTy, LoadGEPList, SrcTypes);
  // LHS is either a vector or a builtin.
  // If it's a vector, create a temp alloca, load from it to obtain a starting
  // vector value, and insert each converted element into that value.
  if (auto *VecTy = LHSTy->getAs<VectorType>()) {
    assert(SrcTypes.size() >= VecTy->getNumElements() &&
           "Flattened type on RHS must have more elements than vector on LHS.");
    llvm::Value *V =
        CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
    // write to V.
    for (unsigned I = 0, E = VecTy->getNumElements(); I < E; I++) {
      llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
      llvm::Value *Idx = LoadGEPList[I].second;
      // A non-null index means the source element is a vector lane.
      Load = Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract")
                 : Load;
      llvm::Value *Cast = CGF.EmitScalarConversion(
          Load, SrcTypes[I], VecTy->getElementType(), Loc);
      V = CGF.Builder.CreateInsertElement(V, Cast, I);
    }
    return V;
  }
  // If it's a builtin, just do an extract element or load.
  assert(LHSTy->isBuiltinType() &&
         "Destination type must be a vector or builtin type.");
  llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[0].first, "load");
  llvm::Value *Idx = LoadGEPList[0].second;
  Load =
      Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
  // The loaded value has the source element type, so the conversion goes from
  // SrcTypes[0] to LHSTy (source type first, destination type second).
  return CGF.EmitScalarConversion(Load, SrcTypes[0], LHSTy, Loc);
}

// VisitCastExpr - Emit code for an explicit or implicit cast. Implicit casts
// have to handle a more broad range of conversions than explicit casts, as they
// handle things like function to ptr-to-function decay etc.
Expand Down Expand Up @@ -2795,16 +2759,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
QualType SrcTy = E->getType();

assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
// RHS is an aggregate
Address SrcVal = RV.getAggregateAddress();
return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);
}
} // end of switch

llvm_unreachable("unknown scalar cast");
Expand Down
5 changes: 0 additions & 5 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4439,11 +4439,6 @@ class CodeGenFunction : public CodeGenTypeCache {
AggValueSlot slot = AggValueSlot::ignored());
LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);

void FlattenAccessAndType(
Address Addr, QualType AddrTy,
SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
SmallVectorImpl<QualType> &FlatTypes);

llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
const ObjCIvarDecl *Ivar);
llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,
Expand Down
1 change: 0 additions & 1 deletion clang/lib/Edit/RewriteObjCFoundationAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,6 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
llvm_unreachable("OpenCL-specific cast in Objective-C?");

case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;

Expand Down
20 changes: 3 additions & 17 deletions clang/lib/Sema/SemaCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "clang/Basic/TargetInfo.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Initialization.h"
#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaObjC.h"
#include "clang/Sema/SemaRISCV.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -2773,22 +2772,6 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
return;
}

CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
// This case should not trigger on regular vector splat
// vector cast, vector truncation, or special hlsl splat cases
QualType SrcTy = SrcExpr.get()->getType();
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
if (SrcTy->isConstantArrayType())
SrcExpr = Self.ImpCastExprToType(
SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy),
CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
Kind = CK_HLSLElementwiseCast;
return;
}

if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
!isPlaceholder(BuiltinType::Overload)) {
SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
Expand Down Expand Up @@ -2841,6 +2824,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
if (isValidCast(tcr))
Kind = CK_NoOp;

CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
if (tcr == TC_NotApplicable) {
tcr = TryAddressSpaceCast(Self, SrcExpr, DestType, /*CStyle*/ true, msg,
Kind);
Expand Down
Loading