diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 459222234bc37..cb58abec61c11 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2347,6 +2347,11 @@ class VPWidenPHIRecipe : public VPSingleDefRecipe { /// Returns the \p I th incoming VPBasicBlock. VPBasicBlock *getIncomingBlock(unsigned I) { return IncomingBlocks[I]; } + /// Set the \p I th incoming VPBasicBlock to \p IncomingBlock. + void setIncomingBlock(unsigned I, VPBasicBlock *IncomingBlock) { + IncomingBlocks[I] = IncomingBlock; + } + /// Returns the \p I th incoming VPValue. VPValue *getIncomingValue(unsigned I) { return getOperand(I); } }; diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.h b/llvm/lib/Transforms/Vectorize/VPlanUtils.h index 6ddb88308955f..ac5e1978fcfbe 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.h +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.h @@ -169,8 +169,16 @@ class VPBlockUtils { static void reassociateBlocks(VPBlockBase *Old, VPBlockBase *New) { for (auto *Pred : to_vector(Old->getPredecessors())) Pred->replaceSuccessor(Old, New); - for (auto *Succ : to_vector(Old->getSuccessors())) + for (auto *Succ : to_vector(Old->getSuccessors())) { Succ->replacePredecessor(Old, New); + + // Replace any references to Old in widened phi incoming blocks. + for (auto &R : Succ->getEntryBasicBlock()->phis()) + if (auto *WidenPhiR = dyn_cast(&R)) + for (unsigned I = 0; I < WidenPhiR->getNumOperands(); I++) + if (WidenPhiR->getIncomingBlock(I) == Old) + WidenPhiR->setIncomingBlock(I, cast(New)); + } New->setPredecessors(Old->getPredecessors()); New->setSuccessors(Old->getSuccessors()); Old->clearPredecessors(); diff --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp index e7987a95f1ca2..0b57f8084e5f4 100644 --- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp @@ -659,6 +659,49 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) { } } +TEST_F(VPBasicBlockTest, reassociateBlocks) { + { + // Ensure that when we reassociate a basic block, we make sure to update any + // references to it in VPWidenPHIRecipes' incoming blocks. + VPlan &Plan = getPlan(); + VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("VPBB1"); + VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2"); + VPBlockUtils::connectBlocks(VPBB1, VPBB2); + + auto *WidenPhi = new VPWidenPHIRecipe(nullptr); + IntegerType *Int32 = IntegerType::get(C, 32); + VPValue *Val = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1)); + WidenPhi->addIncoming(Val, VPBB1); + VPBB2->appendRecipe(WidenPhi); + + VPBasicBlock *VPBBNew = Plan.createVPBasicBlock("VPBBNew"); + VPBlockUtils::reassociateBlocks(VPBB1, VPBBNew); + EXPECT_EQ(VPBB2->getSinglePredecessor(), VPBBNew); + EXPECT_EQ(WidenPhi->getIncomingBlock(0), VPBBNew); + } + + { + // Ensure that we update VPWidenPHIRecipes that are nested inside a + // VPRegionBlock. + VPlan &Plan = getPlan(); + VPBasicBlock *VPBB1 = Plan.createVPBasicBlock("VPBB1"); + VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2"); + VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB2, "R1"); + VPBlockUtils::connectBlocks(VPBB1, R1); + + auto *WidenPhi = new VPWidenPHIRecipe(nullptr); + IntegerType *Int32 = IntegerType::get(C, 32); + VPValue *Val = Plan.getOrAddLiveIn(ConstantInt::get(Int32, 1)); + WidenPhi->addIncoming(Val, VPBB1); + VPBB2->appendRecipe(WidenPhi); + + VPBasicBlock *VPBBNew = Plan.createVPBasicBlock("VPBBNew"); + VPBlockUtils::reassociateBlocks(VPBB1, VPBBNew); + EXPECT_EQ(R1->getSinglePredecessor(), VPBBNew); + EXPECT_EQ(WidenPhi->getIncomingBlock(0), VPBBNew); + } +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) TEST_F(VPBasicBlockTest, print) { VPInstruction *TC = new VPInstruction(Instruction::Add, {});