Skip to content

Commit

Permalink
Do not trigger write barrier when the child is a permanently rooted o…
Browse files Browse the repository at this point in the history
…bject

Or when the parent and the child are the same object
  • Loading branch information
yuyichao committed Oct 30, 2017
1 parent aa8adc8 commit 5b5cf3a
Showing 1 changed file with 123 additions and 26 deletions.
149 changes: 123 additions & 26 deletions src/llvm-late-gc-lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,13 @@ struct State {
// The result of the local analysis
std::map<BasicBlock *, BBState> BBStates;

// Refinement map. If all of the values are rooted (-1 means an externally rooted value),
// Refinement map. If all of the values are rooted
// (-1 means an externally rooted value and -2 means a globally/permanently rooted value),
// the key is already rooted (but not the other way around).
// A value that can be refined to -2 never need any rooting or write barrier.
// A value that can be refined to -1 don't need local root but still need write barrier.
// At the end of `LocalScan` this map has a few properties
// 1. Values are either -1 or dominates the key
// 1. Values are either < 0 or dominates the key
// 2. Therefore this is a DAG
std::map<int, SmallVector<int, 1>> Refinements;

Expand Down Expand Up @@ -430,7 +433,7 @@ struct LateLowerGCFrame: public FunctionPass {
bool doFinalization(Module &) override;
bool runOnFunction(Function &F) override;
Instruction *get_pgcstack(Instruction *ptlsStates);
bool CleanupIR(Function &F);
bool CleanupIR(Function &F, State *S=nullptr);
void NoteUseChain(State &S, BBState &BBS, User *TheUser);
SmallVector<int, 1> GetPHIRefinements(PHINode *phi, State &S);
void FixUpRefinements(ArrayRef<int> PHINumbers, State &S);
Expand Down Expand Up @@ -571,9 +574,12 @@ int LateLowerGCFrame::NumberBase(State &S, Value *V, Value *CurrentV)
if (it != S.AllPtrNumbering.end())
return it->second;
int Number;
if (isa<Constant>(CurrentV) || isa<Argument>(CurrentV) ||
((isa<AllocaInst>(CurrentV) || isa<AddrSpaceCastInst>(CurrentV)) &&
getValueAddrSpace(CurrentV) != AddressSpace::Tracked)) {
if (isa<Constant>(CurrentV)) {
// Perm rooted
Number = -2;
} else if (isa<Argument>(CurrentV) ||
((isa<AllocaInst>(CurrentV) || isa<AddrSpaceCastInst>(CurrentV)) &&
getValueAddrSpace(CurrentV) != AddressSpace::Tracked)) {
// We know this is rooted in the parent
Number = -1;
} else if (isa<SelectInst>(CurrentV) && getValueAddrSpace(CurrentV) != AddressSpace::Tracked) {
Expand Down Expand Up @@ -666,7 +672,7 @@ static bool HasBitSet(const BitVector &BV, unsigned Bit) {
}

static void NoteDef(State &S, BBState &BBS, int Num, const std::vector<int> &SafepointsSoFar) {
assert(Num != -1);
assert(Num >= 0);
MaybeResize(BBS, Num);
assert(BBS.Defs[Num] == 0 && "SSA Violation or misnumbering?");
BBS.Defs[Num] = 1;
Expand Down Expand Up @@ -739,7 +745,7 @@ void LateLowerGCFrame::NoteUse(State &S, BBState &BBS, Value *V, BitVector &Uses
}
else {
int Num = Number(S, V);
if (Num == -1)
if (Num < 0)
return;
MaybeResize(BBS, Num);
Uses[Num] = 1;
Expand Down Expand Up @@ -821,7 +827,8 @@ static bool isLoadFromImmut(LoadInst *LI)
return false;
while (TBAA->getNumOperands() > 1) {
TBAA = cast<MDNode>(TBAA->getOperand(1).get());
if (cast<MDString>(TBAA->getOperand(0))->getString() == "jtbaa_immut") {
auto str = cast<MDString>(TBAA->getOperand(0))->getString();
if (str == "jtbaa_immut" || str == "jtbaa_const") {
return true;
}
}
Expand Down Expand Up @@ -869,6 +876,27 @@ SmallVector<int, 1> LateLowerGCFrame::GetPHIRefinements(PHINode *Phi, State &S)
return RefinedPtr;
}

JL_USED_FUNC static void DumpRefinements(State *S)
{
for (auto &kv: S->Refinements) {
int Num = kv.first;
if (Num < 0)
continue;
jl_safe_printf("Refinements for %d -- ", Num);
auto V = S->ReversePtrNumbering[Num];
llvm_dump(V);
for (auto refine: kv.second) {
if (refine < 0) {
jl_safe_printf(" %d\n", refine);
continue;
}
jl_safe_printf(" %d: ", refine);
auto R = S->ReversePtrNumbering[refine];
llvm_dump(R);
}
}
}

void LateLowerGCFrame::FixUpRefinements(ArrayRef<int> PHINumbers, State &S)
{
// Now we have all the possible refinement information, we can remove ones for the invalid
Expand All @@ -883,12 +911,14 @@ void LateLowerGCFrame::FixUpRefinements(ArrayRef<int> PHINumbers, State &S)
// We do this by first assuming all values to be externally rooted and then removing
// values that are or can be derived from non-externally rooted values recursively.
BitVector extern_rooted(S.MaxPtrNumber + 1, true);
BitVector perm_rooted(S.MaxPtrNumber + 1, true);
// * First clear all values that are not derived from anything.
// This only needs to be done once.
for (int i = 0; i <= S.MaxPtrNumber; i++) {
auto it = S.Refinements.find(i);
if (it == S.Refinements.end() || it->second.empty()) {
extern_rooted[i] = false;
perm_rooted[i] = false;
}
}
// * Then remove values reachable from those values recursively
Expand All @@ -901,20 +931,42 @@ void LateLowerGCFrame::FixUpRefinements(ArrayRef<int> PHINumbers, State &S)
if (!HasBitSet(extern_rooted, Num))
continue;
for (auto refine: kv.second) {
if (refine == -1)
if (refine == -2) {
continue;
if (!HasBitSet(extern_rooted, refine)) {
}
else if (refine == -1) {
if (HasBitSet(perm_rooted, Num)) {
changed = true;
perm_rooted[Num] = false;
}
continue;
}
else if (!HasBitSet(extern_rooted, refine)) {
changed = true;
extern_rooted[Num] = false;
perm_rooted[Num] = false;
break;
}
else if (!HasBitSet(perm_rooted, refine)) {
if (HasBitSet(perm_rooted, Num)) {
changed = true;
perm_rooted[Num] = false;
}
}
}
}
} while (changed);
// * Now the `extern_rooted` map is accurate, normalize all externally rooted values.
// * Now the `extern_rooted` and `perm_rooted` map is accurate,
// normalize all externally rooted values.
for (auto &kv: S.Refinements) {
int Num = kv.first;
if (HasBitSet(extern_rooted, Num)) {
if (HasBitSet(perm_rooted, Num)) {
// For permanently rooted values, set their refinements simply to `{-2}`
kv.second.resize(1);
kv.second[0] = -2;
continue;
}
else if (HasBitSet(extern_rooted, Num)) {
// For externally rooted values, set their refinements simply to `{-1}`
kv.second.resize(1);
kv.second[0] = -1;
Expand All @@ -923,6 +975,7 @@ void LateLowerGCFrame::FixUpRefinements(ArrayRef<int> PHINumbers, State &S)
for (auto &refine: kv.second) {
// For other values,
// remove all externally rooted values from their refinements (replace with -1)
// No need to handle -2 specially since it won't make a difference.
if (HasBitSet(extern_rooted, refine)) {
refine = -1;
}
Expand All @@ -938,15 +991,15 @@ void LateLowerGCFrame::FixUpRefinements(ArrayRef<int> PHINumbers, State &S)
BitVector visited(S.MaxPtrNumber + 1, false);
for (auto Num: PHINumbers) {
// Not sure if `Num` can be `-1`
if (Num == -1 || HasBitSet(extern_rooted, Num))
if (Num < 0 || HasBitSet(extern_rooted, Num))
continue;
visited[Num] = true;
auto Phi = cast<PHINode>(S.ReversePtrNumbering[Num]);
auto &RefinedPtr = S.Refinements[Num];
unsigned j = 0; // new length
for (unsigned i = 0; i < RefinedPtr.size(); i++) {
auto refine = RefinedPtr[i];
if (refine == -1 || visited[refine])
if (refine < 0 || visited[refine])
continue;
visited[refine] = true;
if (i != j)
Expand Down Expand Up @@ -1011,7 +1064,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
}
auto callee = CI->getCalledFunction();
if (callee && callee == typeof_func) {
MaybeNoteDef(S, BBS, CI, BBS.Safepoints, SmallVector<int, 1>{-1});
MaybeNoteDef(S, BBS, CI, BBS.Safepoints, SmallVector<int, 1>{-2});
}
else {
MaybeNoteDef(S, BBS, CI, BBS.Safepoints);
Expand Down Expand Up @@ -1080,7 +1133,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
else if (isLoadFromConstGV(LI)) {
// If this is a const load from a global,
// we know that the object is a constant as well and doesn't need rooting.
RefinedPtr.push_back(-1);
RefinedPtr.push_back(-2);
}
MaybeNoteDef(S, BBS, LI, BBS.Safepoints, std::move(RefinedPtr));
NoteOperandUses(S, BBS, I, BBS.UpExposedUsesUnrooted);
Expand All @@ -1092,7 +1145,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
if (S.AllPtrNumbering.find(SI) != S.AllPtrNumbering.end())
continue;
auto Num = LiftSelect(S, SI);
if (Num == -1)
if (Num < 0)
continue;
auto SelectBase = cast<SelectInst>(S.ReversePtrNumbering[Num]);
SmallVector<int, 1> RefinedPtr{Number(S, SelectBase->getTrueValue()),
Expand Down Expand Up @@ -1133,7 +1186,7 @@ State LateLowerGCFrame::LocalScan(Function &F) {
auto origin = ASCI->getPointerOperand()->stripPointerCasts();
if (auto LI = dyn_cast<LoadInst>(origin)) {
if (isLoadFromConstGV(LI)) {
RefinedPtr.push_back(-1);
RefinedPtr.push_back(-2);
}
}
MaybeNoteDef(S, BBS, ASCI, BBS.Safepoints, std::move(RefinedPtr));
Expand Down Expand Up @@ -1242,11 +1295,11 @@ void LateLowerGCFrame::RefineLiveSet(BitVector &LS, State &S)
changed = false;
for (auto &kv: S.Refinements) {
int Num = kv.first;
if (Num == -1 || HasBitSet(FullLS, Num) || kv.second.empty())
if (Num < 0 || HasBitSet(FullLS, Num) || kv.second.empty())
continue;
bool live = true;
for (auto &refine: kv.second) {
if (refine == -1 || HasBitSet(FullLS, refine))
if (refine < 0 || HasBitSet(FullLS, refine))
continue;
live = false;
break;
Expand All @@ -1268,7 +1321,7 @@ void LateLowerGCFrame::RefineLiveSet(BitVector &LS, State &S)
continue;
bool rooted = true;
for (auto RefPtr: RefinedPtr) {
if (RefPtr == -1 || HasBitSet(FullLS, RefPtr))
if (RefPtr < 0 || HasBitSet(FullLS, RefPtr))
continue;
rooted = false;
break;
Expand Down Expand Up @@ -1532,7 +1585,41 @@ Value *LateLowerGCFrame::EmitLoadTag(IRBuilder<> &builder, Type *T, Value *V)
return load;
}

bool LateLowerGCFrame::CleanupIR(Function &F) {
static SmallVector<int, 1> *FindRefinements(Value *V, State *S)
{
if (!S)
return nullptr;
auto it = S->AllPtrNumbering.find(V);
if (it == S->AllPtrNumbering.end())
return nullptr;
auto rit = S->Refinements.find(it->second);
return rit != S->Refinements.end() && !rit->second.empty() ? &rit->second : nullptr;
}

static bool IsPermRooted(Value *V, State *S)
{
if (isa<Constant>(V))
return true;
if (auto *RefinePtr = FindRefinements(V, S))
return RefinePtr->size() == 1 && (*RefinePtr)[0] == -2;
return false;
}

static inline void UpdatePtrNumbering(Value *From, Value *To, State *S)
{
if (!S)
return;
auto it = S->AllPtrNumbering.find(From);
if (it == S->AllPtrNumbering.end())
return;
auto Num = it->second;
S->AllPtrNumbering.erase(it);
if (To) {
S->AllPtrNumbering[To] = Num;
}
}

bool LateLowerGCFrame::CleanupIR(Function &F, State *S) {
bool ChangesMade = false;
// We create one alloca for all the jlcall frames that haven't been processed
// yet. LLVM would merge them anyway later, so might as well save it a bit
Expand Down Expand Up @@ -1567,6 +1654,7 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
auto *ASCI = new AddrSpaceCastInst(obj, T_pjlvalue, "", CI);
ASCI->takeName(CI);
CI->replaceAllUsesWith(ASCI);
UpdatePtrNumbering(CI, ASCI, S);
} else if (alloc_obj_func && callee == alloc_obj_func) {
assert(CI->getNumArgOperands() == 3);
auto sz = (size_t)cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
Expand Down Expand Up @@ -1594,6 +1682,7 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
EmitTagPtr(builder, T_prjlvalue, newI));
store->setMetadata(LLVMContext::MD_tbaa, tbaa_tag);
CI->replaceAllUsesWith(newI);
UpdatePtrNumbering(CI, newI, S);
} else if (typeof_func && callee == typeof_func) {
assert(CI->getNumArgOperands() == 1);
IRBuilder<> builder(CI);
Expand All @@ -1604,6 +1693,7 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
T_prjlvalue);
typ->takeName(CI);
CI->replaceAllUsesWith(typ);
UpdatePtrNumbering(CI, typ, S);
} else if (write_barrier_func && callee == write_barrier_func) {
// The replacement for this requires creating new BasicBlocks
// which messes up the loop. Queue all of them to be replaced later.
Expand Down Expand Up @@ -1655,6 +1745,7 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
#endif
NewCall->setDebugLoc(CI->getDebugLoc());
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
} else if (CI->getNumArgOperands() == CI->getNumOperands()) {
/* No operand bundle to lower */
++it;
Expand All @@ -1663,19 +1754,25 @@ bool LateLowerGCFrame::CleanupIR(Function &F) {
CallInst *NewCall = CallInst::Create(CI, None, CI);
NewCall->takeName(CI);
CI->replaceAllUsesWith(NewCall);
UpdatePtrNumbering(CI, NewCall, S);
}
if (!CI->use_empty()) {
CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
UpdatePtrNumbering(CI, nullptr, S);
}
it = CI->eraseFromParent();
ChangesMade = true;
}
}
for (auto CI: write_barriers) {
IRBuilder<> builder(CI);
builder.SetCurrentDebugLocation(CI->getDebugLoc());
auto parent = CI->getArgOperand(0);
auto child = CI->getArgOperand(1);
if (parent == child || IsPermRooted(child, S)) {
CI->eraseFromParent();
continue;
}
IRBuilder<> builder(CI);
builder.SetCurrentDebugLocation(CI->getDebugLoc());
auto parBits = builder.CreateAnd(EmitLoadTag(builder, T_size, parent), 3);
auto parOldMarked = builder.CreateICmpEQ(parBits, ConstantInt::get(T_size, 3));
auto mayTrigTerm = SplitBlockAndInsertIfThen(parOldMarked, CI, false);
Expand Down Expand Up @@ -1997,7 +2094,7 @@ bool LateLowerGCFrame::runOnFunction(Function &F) {
std::vector<int> Colors = ColorRoots(S);
std::map<Value *, std::pair<int, int>> CallFrames; // = OptimizeCallFrames(S, Ordering);
PlaceRootsAndUpdateCalls(Colors, S, CallFrames);
CleanupIR(F);
CleanupIR(F, &S);
return true;
}

Expand Down

0 comments on commit 5b5cf3a

Please sign in to comment.