@@ -491,6 +491,79 @@ static bool isTrivialFiller(Expr *E) {
491
491
return false ;
492
492
}
493
493
494
+ // emit a flat cast where the RHS is a scalar, including vector
495
+ static void EmitHLSLScalarFlatCast (CodeGenFunction &CGF, Address DestVal,
496
+ QualType DestTy, llvm::Value *SrcVal,
497
+ QualType SrcTy, SourceLocation Loc) {
498
+ // Flatten our destination
499
+ SmallVector<QualType, 16 > DestTypes; // Flattened type
500
+ SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
501
+ // ^^ Flattened accesses to DestVal we want to store into
502
+ CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
503
+
504
+ assert (SrcTy->isVectorType () && " HLSL Flat cast doesn't handle splatting." );
505
+ const VectorType *VT = SrcTy->getAs <VectorType>();
506
+ SrcTy = VT->getElementType ();
507
+ assert (StoreGEPList.size () <= VT->getNumElements () &&
508
+ " Cannot perform HLSL flat cast when vector source \
509
+ object has less elements than flattened destination \
510
+ object." );
511
+ for (unsigned I = 0 , Size = StoreGEPList.size (); I < Size ; I++) {
512
+ llvm::Value *Load = CGF.Builder .CreateExtractElement (SrcVal, I, " vec.load" );
513
+ llvm::Value *Cast =
514
+ CGF.EmitScalarConversion (Load, SrcTy, DestTypes[I], Loc);
515
+
516
+ // store back
517
+ llvm::Value *Idx = StoreGEPList[I].second ;
518
+ if (Idx) {
519
+ llvm::Value *V =
520
+ CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
521
+ Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
522
+ }
523
+ CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
524
+ }
525
+ return ;
526
+ }
527
+
528
+ // emit a flat cast where the RHS is an aggregate
529
+ static void EmitHLSLElementwiseCast (CodeGenFunction &CGF, Address DestVal,
530
+ QualType DestTy, Address SrcVal,
531
+ QualType SrcTy, SourceLocation Loc) {
532
+ // Flatten our destination
533
+ SmallVector<QualType, 16 > DestTypes; // Flattened type
534
+ SmallVector<std::pair<Address, llvm::Value *>, 16 > StoreGEPList;
535
+ // ^^ Flattened accesses to DestVal we want to store into
536
+ CGF.FlattenAccessAndType (DestVal, DestTy, StoreGEPList, DestTypes);
537
+ // Flatten our src
538
+ SmallVector<QualType, 16 > SrcTypes; // Flattened type
539
+ SmallVector<std::pair<Address, llvm::Value *>, 16 > LoadGEPList;
540
+ // ^^ Flattened accesses to SrcVal we want to load from
541
+ CGF.FlattenAccessAndType (SrcVal, SrcTy, LoadGEPList, SrcTypes);
542
+
543
+ assert (StoreGEPList.size () <= LoadGEPList.size () &&
544
+ " Cannot perform HLSL flat cast when flattened source object \
545
+ has less elements than flattened destination object." );
546
+ // apply casts to what we load from LoadGEPList
547
+ // and store result in Dest
548
+ for (unsigned I = 0 , E = StoreGEPList.size (); I < E; I++) {
549
+ llvm::Value *Idx = LoadGEPList[I].second ;
550
+ llvm::Value *Load = CGF.Builder .CreateLoad (LoadGEPList[I].first , " load" );
551
+ Load =
552
+ Idx ? CGF.Builder .CreateExtractElement (Load, Idx, " vec.extract" ) : Load;
553
+ llvm::Value *Cast =
554
+ CGF.EmitScalarConversion (Load, SrcTypes[I], DestTypes[I], Loc);
555
+
556
+ // store back
557
+ Idx = StoreGEPList[I].second ;
558
+ if (Idx) {
559
+ llvm::Value *V =
560
+ CGF.Builder .CreateLoad (StoreGEPList[I].first , " load.for.insert" );
561
+ Cast = CGF.Builder .CreateInsertElement (V, Cast, Idx);
562
+ }
563
+ CGF.Builder .CreateStore (Cast, StoreGEPList[I].first );
564
+ }
565
+ }
566
+
494
567
// / Emit initialization of an array from an initializer list. ExprToVisit must
495
568
// / be either an InitListExpr or a CXXParenInitListExpr.
496
569
void AggExprEmitter::EmitArrayInit (Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +963,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
890
963
case CK_HLSLArrayRValue:
891
964
Visit (E->getSubExpr ());
892
965
break ;
893
-
966
+ case CK_HLSLElementwiseCast: {
967
+ Expr *Src = E->getSubExpr ();
968
+ QualType SrcTy = Src->getType ();
969
+ RValue RV = CGF.EmitAnyExpr (Src);
970
+ QualType DestTy = E->getType ();
971
+ Address DestVal = Dest.getAddress ();
972
+ SourceLocation Loc = E->getExprLoc ();
973
+
974
+ if (RV.isScalar ()) {
975
+ llvm::Value *SrcVal = RV.getScalarVal ();
976
+ EmitHLSLScalarFlatCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
977
+ } else {
978
+ assert (RV.isAggregate () &&
979
+ " Can't perform HLSL Aggregate cast on a complex type." );
980
+ Address SrcVal = RV.getAggregateAddress ();
981
+ EmitHLSLElementwiseCast (CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
982
+ }
983
+ break ;
984
+ }
894
985
case CK_NoOp:
895
986
case CK_UserDefinedConversion:
896
987
case CK_ConstructorConversion:
@@ -1461,6 +1552,7 @@ static bool castPreservesZero(const CastExpr *CE) {
1461
1552
case CK_NonAtomicToAtomic:
1462
1553
case CK_AtomicToNonAtomic:
1463
1554
case CK_HLSLVectorTruncation:
1555
+ case CK_HLSLElementwiseCast:
1464
1556
return true ;
1465
1557
1466
1558
case CK_BaseToDerivedMemberPointer:
0 commit comments