Skip to content

Commit

Permalink
[RISCV] Improve stack clash probe loop
Browse files Browse the repository at this point in the history
Limit the unrolled probe loop and emit a variable length probe loop
for bigger allocations.
We add a new pseudo instruction RISCV::PROBED_STACKALLOC that will
later be synthesized in a loop by `inlineStackProbe`.
  • Loading branch information
rzinsly committed Dec 5, 2024
1 parent f9c0f9c commit a1b32f8
Show file tree
Hide file tree
Showing 5 changed files with 896 additions and 33 deletions.
183 changes: 150 additions & 33 deletions llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,46 +608,97 @@ void RISCVFrameLowering::allocateStack(MachineBasicBlock &MBB,
return;
}

// Do an unrolled probe loop.
uint64_t CurrentOffset = 0;
bool IsRV64 = STI.is64Bit();
while (CurrentOffset + ProbeSize <= Offset) {
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
StackOffset::getFixed(-ProbeSize), MachineInstr::FrameSetup,
getStackAlign());
// s[d|w] zero, 0(sp)
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);
// Unroll the probe loop depending on the number of iterations.
if (Offset < ProbeSize * 5) {
uint64_t CurrentOffset = 0;
bool IsRV64 = STI.is64Bit();
while (CurrentOffset + ProbeSize <= Offset) {
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
StackOffset::getFixed(-ProbeSize), MachineInstr::FrameSetup,
getStackAlign());
// s[d|w] zero, 0(sp)
BuildMI(MBB, MBBI, DL, TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(MachineInstr::FrameSetup);

CurrentOffset += ProbeSize;
if (EmitCFI) {
// Emit ".cfi_def_cfa_offset CurrentOffset"
unsigned CFIIndex = MF.addFrameInst(
MCCFIInstruction::cfiDefCfaOffset(nullptr, CurrentOffset));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlag(MachineInstr::FrameSetup);
}
}

CurrentOffset += ProbeSize;
if (EmitCFI) {
// Emit ".cfi_def_cfa_offset CurrentOffset"
unsigned CFIIndex = MF.addFrameInst(
MCCFIInstruction::cfiDefCfaOffset(nullptr, CurrentOffset));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlag(MachineInstr::FrameSetup);
uint64_t Residual = Offset - CurrentOffset;
if (Residual) {
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg,
StackOffset::getFixed(-Residual), MachineInstr::FrameSetup,
getStackAlign());
if (EmitCFI) {
// Emit ".cfi_def_cfa_offset Offset"
unsigned CFIIndex =
MF.addFrameInst(MCCFIInstruction::cfiDefCfaOffset(nullptr, Offset));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlag(MachineInstr::FrameSetup);
}
}

return;
}

// Emit a variable-length allocation probing loop.
uint64_t RoundedSize = alignDown(Offset, ProbeSize);
uint64_t Residual = Offset - RoundedSize;

Register TargetReg = RISCV::X6;
// SUB TargetReg, SP, RoundedSize
RI->adjustReg(MBB, MBBI, DL, TargetReg, SPReg,
StackOffset::getFixed(-RoundedSize), MachineInstr::FrameSetup,
getStackAlign());

if (EmitCFI) {
// Set the CFA register to TargetReg.
unsigned Reg = STI.getRegisterInfo()->getDwarfRegNum(TargetReg, true);
unsigned CFIIndex =
MF.addFrameInst(MCCFIInstruction::cfiDefCfa(nullptr, Reg, RoundedSize));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlags(MachineInstr::FrameSetup);
}

// It will be expanded to a probe loop in `inlineStackProbe`.
BuildMI(MBB, MBBI, DL, TII->get(RISCV::PROBED_STACKALLOC))
.addReg(SPReg)
.addReg(TargetReg);

if (EmitCFI) {
// Set the CFA register back to SP.
unsigned Reg = STI.getRegisterInfo()->getDwarfRegNum(SPReg, true);
unsigned CFIIndex =
MF.addFrameInst(MCCFIInstruction::createDefCfaRegister(nullptr, Reg));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlags(MachineInstr::FrameSetup);
}

uint64_t Residual = Offset - CurrentOffset;
if (Residual) {
if (Residual)
RI->adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getFixed(-Residual),
MachineInstr::FrameSetup, getStackAlign());
if (EmitCFI) {
// Emit ".cfi_def_cfa_offset Offset"
unsigned CFIIndex =
MF.addFrameInst(MCCFIInstruction::cfiDefCfaOffset(nullptr, Offset));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlag(MachineInstr::FrameSetup);
}
}

return;
if (EmitCFI) {
// Emit ".cfi_def_cfa_offset Offset"
unsigned CFIIndex =
MF.addFrameInst(MCCFIInstruction::cfiDefCfaOffset(nullptr, Offset));
BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
.addCFIIndex(CFIIndex)
.setMIFlags(MachineInstr::FrameSetup);
}
}

void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
Expand Down Expand Up @@ -1962,3 +2013,69 @@ bool RISCVFrameLowering::isSupportedStackID(TargetStackID::Value ID) const {
TargetStackID::Value RISCVFrameLowering::getStackIDForScalableVectors() const {
return TargetStackID::ScalableVector;
}

// Synthesize the probe loop.
static void emitStackProbeInline(MachineFunction &MF, MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
DebugLoc DL) {

auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
const RISCVInstrInfo *TII = Subtarget.getInstrInfo();
bool IsRV64 = Subtarget.is64Bit();
Align StackAlign = Subtarget.getFrameLowering()->getStackAlign();
const RISCVTargetLowering *TLI = Subtarget.getTargetLowering();
uint64_t ProbeSize = TLI->getStackProbeSize(MF, StackAlign);

MachineFunction::iterator MBBInsertPoint = std::next(MBB.getIterator());
MachineBasicBlock *LoopTestMBB =
MF.CreateMachineBasicBlock(MBB.getBasicBlock());
MF.insert(MBBInsertPoint, LoopTestMBB);
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB.getBasicBlock());
MF.insert(MBBInsertPoint, ExitMBB);
MachineInstr::MIFlag Flags = MachineInstr::FrameSetup;
Register TargetReg = RISCV::X6;
Register ScratchReg = RISCV::X7;

// ScratchReg = ProbeSize
TII->movImm(MBB, MBBI, DL, ScratchReg, ProbeSize, Flags);

// LoopTest:
// SUB SP, SP, ProbeSize
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::SUB), SPReg)
.addReg(SPReg)
.addReg(ScratchReg)
.setMIFlags(Flags);

// s[d|w] zero, 0(sp)
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL,
TII->get(IsRV64 ? RISCV::SD : RISCV::SW))
.addReg(RISCV::X0)
.addReg(SPReg)
.addImm(0)
.setMIFlags(Flags);

// BNE SP, TargetReg, LoopTest
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(RISCV::BNE))
.addReg(SPReg)
.addReg(TargetReg)
.addMBB(LoopTestMBB)
.setMIFlags(Flags);

ExitMBB->splice(ExitMBB->end(), &MBB, std::next(MBBI), MBB.end());

LoopTestMBB->addSuccessor(ExitMBB);
LoopTestMBB->addSuccessor(LoopTestMBB);
MBB.addSuccessor(LoopTestMBB);
}

void RISCVFrameLowering::inlineStackProbe(MachineFunction &MF,
MachineBasicBlock &MBB) const {
auto Where = llvm::find_if(MBB, [](MachineInstr &MI) {
return MI.getOpcode() == RISCV::PROBED_STACKALLOC;
});
if (Where != MBB.end()) {
DebugLoc DL = MBB.findDebugLoc(Where);
emitStackProbeInline(MF, MBB, Where, DL);
Where->eraseFromParent();
}
}
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVFrameLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class RISCVFrameLowering : public TargetFrameLowering {

std::pair<int64_t, Align>
assignRVVStackObjectOffsets(MachineFunction &MF) const;
// Replace a StackProbe stub (if any) with the actual probe code inline
void inlineStackProbe(MachineFunction &MF,
MachineBasicBlock &PrologueMBB) const override;
};
} // namespace llvm
#endif
11 changes: 11 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,17 @@ def PseudoAddTPRel : Pseudo<(outs GPR:$rd),
def : Pat<(FrameAddrRegImm (iPTR GPR:$rs1), simm12:$imm12),
(ADDI GPR:$rs1, simm12:$imm12)>;

/// Stack probing

let hasSideEffects = 1, mayLoad = 1, mayStore = 1, isCodeGenOnly = 1 in {
// Probed stack allocation of a constant size, used in function prologues when
// stack-clash protection is enabled.
def PROBED_STACKALLOC : Pseudo<(outs GPR:$sp),
(ins GPR:$scratch),
[]>,
Sched<[]>;
}

/// HI and ADD_LO address nodes.

// Pseudo for a rematerializable LUI+ADDI sequence for loading an address.
Expand Down
Loading

0 comments on commit a1b32f8

Please sign in to comment.