Skip to content

Commit

Permalink
feat: register allocator pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ganyao114 committed Dec 8, 2023
1 parent 6346021 commit 065e459
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 41 deletions.
27 changes: 22 additions & 5 deletions source/runtime/backend/reg_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,29 @@

namespace swift::runtime::backend {

ir::HostFPR RegAlloc::ValueFPR(const ir::Value& value) {
return {};
}
RegAlloc::RegAlloc(u32 instr_size, const GPRSMask& gprs, const FPRSMask& fprs)
: alloc_result{instr_size}, gprs(gprs), fprs(fprs) {}

ir::HostGPR RegAlloc::ValueGPR(const ir::Value& value) {
return {};
void RegAlloc::MapRegister(u32 id, ir::HostFPR fpr) {
alloc_result[id].type = FPR;
alloc_result[id].slot = fpr.id;
}

void RegAlloc::MapRegister(u32 id, ir::HostGPR gpr) {
alloc_result[id].type = GPR;
alloc_result[id].slot = gpr.id;
}

ir::HostFPR RegAlloc::ValueFPR(const ir::Value& value) { return {}; }

ir::HostGPR RegAlloc::ValueGPR(const ir::Value& value) { return {}; }

ir::HostGPR RegAlloc::ValueGPR(u32 id) { return ir::HostGPR{alloc_result[id].slot}; }

ir::HostFPR RegAlloc::ValueFPR(u32 id) { return ir::HostFPR{alloc_result[id].slot}; }

const GPRSMask& RegAlloc::GetGprs() const { return gprs; }

const FPRSMask& RegAlloc::GetFprs() const { return fprs; }

} // namespace swift::runtime::backend
72 changes: 64 additions & 8 deletions source/runtime/backend/reg_alloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,95 @@

#pragma once

#include <bit>
#include "base/common_funcs.h"
#include "runtime/common/types.h"
#include "runtime/ir/block.h"
#include "runtime/ir/host_reg.h"

namespace swift::runtime::backend {

template<typename T = u32>
class RegisterMask {
public:

explicit RegisterMask() : mask() {}

explicit RegisterMask(T mask) : mask(mask) {}

auto GetFirstMarked() {
return std::countr_zero(mask);
}

auto GetFirstClear() {
return std::countr_one(mask);
}

auto GetMarkedCount() {
return std::popcount(mask);
}

auto GetClearCount() {
return GetAllCount() - std::popcount(mask);
}

auto GetAllCount() {
return sizeof(T) * 8;
}

bool Get(u32 bit) {
return mask & (T(1) << bit);
}

void Mark(u32 bit) {
mask |= (T(1) << bit);
}

void Clear(u32 bit) {
mask &= ~(T(1) << bit);
}

private:
T mask;
};

using GPRSMask = RegisterMask<u32>;
using FPRSMask = RegisterMask<u32>;

class RegAlloc : DeleteCopyAndMove {
public:

enum Type : u8 {
explicit RegAlloc(u32 instr_size, const GPRSMask& gprs, const FPRSMask& fprs);

enum Type : u16 {
NONE,
GPR,
FPR,
MEM
};

struct Map {
union {
ir::HostGPR gpr;
ir::HostFPR fpr;
ir::SpillSlot spill;
};
Type type{NONE};
u32 dirty_gprs{};
u32 dirty_fprs{};
u16 slot{};
GPRSMask dirty_gprs{0};
FPRSMask dirty_fprs{0};
};

const GPRSMask& GetGprs() const;
const FPRSMask& GetFprs() const;

void MapRegister(u32 id, ir::HostGPR gpr);
void MapRegister(u32 id, ir::HostFPR fpr);

ir::HostGPR ValueGPR(const ir::Value &value);
ir::HostFPR ValueFPR(const ir::Value &value);
ir::HostGPR ValueGPR(u32 id);
ir::HostFPR ValueFPR(u32 id);

private:
Vector<Map> alloc_result;
const GPRSMask gprs;
const FPRSMask fprs;
};

}
2 changes: 1 addition & 1 deletion source/runtime/ir/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Value {

void Use() const;
void UnUse() const;
u16 Id() const;
[[nodiscard]] u16 Id() const;

private:
Inst* inst{};
Expand Down
1 change: 0 additions & 1 deletion source/runtime/ir/hir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ void HIRFunction::MergeAdjacentBlocks(HIRBlock* left, HIRBlock* right) {}
bool HIRFunction::SplitBlock(HIRBlock* new_block, HIRBlock* old_block) { return false; }

void HIRFunction::IdByRPO() {
std::destroy(values.begin(), values.end());
u32 cur_inst_id{0};
// Re id inst
StackVector<HIRValue*, 32> function_values{};
Expand Down
3 changes: 2 additions & 1 deletion source/runtime/ir/hir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct HIRValue {
void Use(Inst* inst, u8 idx);
void UnUse(Inst* inst, u8 idx);

u16 GetOrderId() const;
[[nodiscard]] u16 GetOrderId() const;

// for rbtree compare
static NOINLINE int Compare(const HIRValue& lhs, const HIRValue& rhs) {
Expand Down Expand Up @@ -229,6 +229,7 @@ class HIRFunction : public DataContext {
ASSERT(current_block);
auto inst = new Inst(op);
inst->SetArgs(args...);
inst->SetId(inst_order_id++);
current_block->block->AppendInst(inst);
AppendValue(current_block, inst);
return inst;
Expand Down
103 changes: 81 additions & 22 deletions source/runtime/ir/opts/register_alloc_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
namespace swift::runtime::ir {

struct LiveInterval {
Inst* inst{};
u32 start{};
u32 end{};

Expand All @@ -22,10 +23,26 @@ struct LiveInterval {
class LinearScanAllocator {
public:
explicit LinearScanAllocator(HIRFunction* function, backend::RegAlloc* alloc)
: function(function), block(), reg_alloc(alloc), live_interval{function->MaxInstrCount()}, active_lives() {}
: function(function)
, block()
, reg_alloc(alloc)
, live_interval()
, active_lives() {
active_gprs = alloc->GetGprs();
active_fprs = alloc->GetFprs();
live_interval.reserve(function->MaxInstrCount());
}

explicit LinearScanAllocator(Block* block, backend::RegAlloc* alloc)
: function(), block(block), reg_alloc(alloc), live_interval{block->GetInstList().size()}, active_lives() {}
: function()
, block(block)
, reg_alloc(alloc)
, live_interval()
, active_lives() {
active_gprs = alloc->GetGprs();
active_fprs = alloc->GetFprs();
live_interval.reserve(block->GetInstList().size());
}

void AllocateRegisters() {
// Step 1: Collect live intervals
Expand All @@ -42,51 +59,93 @@ class LinearScanAllocator {
for (auto& interval : live_interval) {
ExpireOldIntervals(interval);

// if (active_lives.size() == static_cast<size_t>(numAvailableRegisters)) {
// SpillAtInterval(interval);
// } else {
// int reg = AllocateRegister();
// interval.registerAssigned = reg;
// active_lives.push_back(interval);
// AssignRegisterToVariable(interval.variable, reg);
// }
if (!IsFloatValue(interval.inst)) {
if (auto alloc = AllocGPR(); alloc >= 0) {
active_lives.push_back(interval);
reg_alloc->MapRegister(interval.inst->Id(), HostGPR{(u16)alloc});
} else {
SpillAtInterval(interval);
}
} else {
if (auto alloc = AllocFPR(); alloc >= 0) {
active_lives.push_back(interval);
reg_alloc->MapRegister(interval.inst->Id(), HostFPR{(u16)alloc});
} else {
SpillAtInterval(interval);
}
}
}
}

private:

void CollectLiveIntervals(HIRFunction* hir_function) {
for (auto& hir_value : hir_function->GetHIRValues()) {
u32 end{hir_value.GetOrderId()};
std::for_each(hir_value.uses.begin(), hir_value.uses.end(), [&end](auto& use) {
end = std::max(end, (u32) use.inst->Id());
end = std::max(end, (u32)use.inst->Id());
});
live_interval[hir_value.GetOrderId()] = {hir_value.GetOrderId(), end};
live_interval.push_back({
hir_value.value.Def(), hir_value.GetOrderId(), end});
}
}

void CollectLiveIntervals(Block *lir_block) {
void CollectLiveIntervals(Block* lir_block) {}

}

void ExpireOldIntervals(LiveInterval &current) {
void ExpireOldIntervals(LiveInterval& current) {
for (auto it = active_lives.begin(); it != active_lives.end();) {
if (it->end < current.start) {
// FreeRegister(it->registerAssigned);
it = active_lives.erase(it); // Remove expired intervals
if (!IsFloatValue(it->inst)) {
FreeGPR(reg_alloc->ValueGPR(it->inst->Id()).id);
} else {
FreeFPR(reg_alloc->ValueFPR(it->inst->Id()).id);
}
it = active_lives.erase(it); // Remove expired intervals
} else {
++it;
}
}
}

void SpillAtInterval(LiveInterval& interval) {}

bool IsFloatValue(Inst* inst) {
auto value_type = inst->ReturnType();
return value_type >= ValueType::V8 && value_type <= ValueType::V256;
}

int AllocGPR() {
if (auto alloc = active_gprs.GetFirstClear(); alloc >= 0) {
active_gprs.Mark(alloc);
return alloc;
}
return -1;
}

int AllocFPR() {
if (auto alloc = active_fprs.GetFirstClear(); alloc >= 0) {
active_fprs.Mark(alloc);
return alloc;
}
return -1;
}

void FreeGPR(u32 id) {
ASSERT(active_gprs.Get(id));
active_gprs.Clear(id);
}

void FreeFPR(u32 id) {
ASSERT(active_fprs.Get(id));
active_fprs.Clear(id);
}

HIRFunction* function;
Block *block;
Block* block;
backend::RegAlloc* reg_alloc;
Vector<LiveInterval> live_interval;
List<LiveInterval> active_lives;
u16 active_general_value{};
u16 active_float_value{};
backend::GPRSMask active_gprs;
backend::FPRSMask active_fprs;
};

void RegisterAllocPass::Run(HIRBuilder* hir_builder, backend::RegAlloc* reg_alloc) {
Expand Down
9 changes: 6 additions & 3 deletions source/tests/main_case.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#include "runtime/ir/hir_builder.h"
#include "runtime/ir/opts/cfg_analysis_pass.h"
#include "runtime/ir/opts/local_elimination_pass.h"
#include "runtime/ir/opts/reid_instr_pass.h"
#include "runtime/ir/opts/register_alloc_pass.h"
#include "runtime/backend/mem_map.h"
#include "compiler/slang/slang.h"

Expand Down Expand Up @@ -57,14 +59,14 @@ TEST_CASE("Test runtime-ir") {
hir_builder.Return();
CFGAnalysisPass::Run(&hir_builder);
LocalEliminationPass::Run(&hir_builder);

assert(local2.Defined());
ReIdInstrPass::Run(&hir_builder);
RegAlloc reg_alloc{function->MaxInstrCount(), GPRSMask{0}, FPRSMask{0}};
RegisterAllocPass::Run(&hir_builder, &reg_alloc);

MemMap mem_arena{0x100000, true};

auto res = mem_arena.Map(0x100000, 0, MemMap::ReadExe, false);
ASSERT(res);

}

TEST_CASE("Test runtime-ir-cfg") {
Expand Down Expand Up @@ -110,6 +112,7 @@ TEST_CASE("Test runtime-ir-cfg") {
hir_builder.Return();
CFGAnalysisPass::Run(&hir_builder);
LocalEliminationPass::Run(&hir_builder);
ReIdInstrPass::Run(&hir_builder);

assert(local2.Defined());

Expand Down

0 comments on commit 065e459

Please sign in to comment.