diff --git a/source/runtime/backend/reg_alloc.cpp b/source/runtime/backend/reg_alloc.cpp index a4fe38f..9dddd96 100644 --- a/source/runtime/backend/reg_alloc.cpp +++ b/source/runtime/backend/reg_alloc.cpp @@ -6,12 +6,29 @@ namespace swift::runtime::backend { -ir::HostFPR RegAlloc::ValueFPR(const ir::Value& value) { - return {}; -} +RegAlloc::RegAlloc(u32 instr_size, const GPRSMask& gprs, const FPRSMask& fprs) + : alloc_result{instr_size}, gprs(gprs), fprs(fprs) {} -ir::HostGPR RegAlloc::ValueGPR(const ir::Value& value) { - return {}; +void RegAlloc::MapRegister(u32 id, ir::HostFPR fpr) { + alloc_result[id].type = FPR; + alloc_result[id].slot = fpr.id; } +void RegAlloc::MapRegister(u32 id, ir::HostGPR gpr) { + alloc_result[id].type = GPR; + alloc_result[id].slot = gpr.id; } + +ir::HostFPR RegAlloc::ValueFPR(const ir::Value& value) { return {}; } + +ir::HostGPR RegAlloc::ValueGPR(const ir::Value& value) { return {}; } + +ir::HostGPR RegAlloc::ValueGPR(u32 id) { return ir::HostGPR{alloc_result[id].slot}; } + +ir::HostFPR RegAlloc::ValueFPR(u32 id) { return ir::HostFPR{alloc_result[id].slot}; } + +const GPRSMask& RegAlloc::GetGprs() const { return gprs; } + +const FPRSMask& RegAlloc::GetFprs() const { return fprs; } + +} // namespace swift::runtime::backend diff --git a/source/runtime/backend/reg_alloc.h b/source/runtime/backend/reg_alloc.h index b60a1e2..a6d8b80 100644 --- a/source/runtime/backend/reg_alloc.h +++ b/source/runtime/backend/reg_alloc.h @@ -4,6 +4,7 @@ #pragma once +#include #include "base/common_funcs.h" #include "runtime/common/types.h" #include "runtime/ir/block.h" @@ -11,10 +12,59 @@ namespace swift::runtime::backend { +template +class RegisterMask { +public: + + explicit RegisterMask() : mask() {} + + explicit RegisterMask(T mask) : mask(mask) {} + + auto GetFirstMarked() { + return std::countr_zero(mask); + } + + auto GetFirstClear() { + return std::countr_one(mask); + } + + auto GetMarkedCount() { + return std::popcount(mask); + } + + auto GetClearCount() { + return GetAllCount() - std::popcount(mask); + } + + auto GetAllCount() { + return sizeof(T) * 8; + } + + bool Get(u32 bit) { + return mask & (T(1) << bit); + } + + void Mark(u32 bit) { + mask |= (T(1) << bit); + } + + void Clear(u32 bit) { + mask &= ~(T(1) << bit); + } + +private: + T mask; +}; + +using GPRSMask = RegisterMask; +using FPRSMask = RegisterMask; + class RegAlloc : DeleteCopyAndMove { public: - enum Type : u8 { + explicit RegAlloc(u32 instr_size, const GPRSMask& gprs, const FPRSMask& fprs); + + enum Type : u16 { NONE, GPR, FPR, @@ -22,21 +72,27 @@ class RegAlloc : DeleteCopyAndMove { }; struct Map { - union { - ir::HostGPR gpr; - ir::HostFPR fpr; - ir::SpillSlot spill; - }; Type type{NONE}; - u32 dirty_gprs{}; - u32 dirty_fprs{}; + u16 slot{}; + GPRSMask dirty_gprs{0}; + FPRSMask dirty_fprs{0}; }; + const GPRSMask& GetGprs() const; + const FPRSMask& GetFprs() const; + + void MapRegister(u32 id, ir::HostGPR gpr); + void MapRegister(u32 id, ir::HostFPR fpr); + ir::HostGPR ValueGPR(const ir::Value &value); ir::HostFPR ValueFPR(const ir::Value &value); + ir::HostGPR ValueGPR(u32 id); + ir::HostFPR ValueFPR(u32 id); private: Vector alloc_result; + const GPRSMask gprs; + const FPRSMask fprs; }; } diff --git a/source/runtime/ir/args.h b/source/runtime/ir/args.h index efe1b60..da6cdb5 100644 --- a/source/runtime/ir/args.h +++ b/source/runtime/ir/args.h @@ -81,7 +81,7 @@ class Value { void Use() const; void UnUse() const; - u16 Id() const; + [[nodiscard]] u16 Id() const; private: Inst* inst{}; diff --git a/source/runtime/ir/hir_builder.cpp b/source/runtime/ir/hir_builder.cpp index f9fd236..d85a603 100644 --- a/source/runtime/ir/hir_builder.cpp +++ b/source/runtime/ir/hir_builder.cpp @@ -176,7 +176,6 @@ void HIRFunction::MergeAdjacentBlocks(HIRBlock* left, HIRBlock* right) {} bool HIRFunction::SplitBlock(HIRBlock* new_block, HIRBlock* old_block) { return false; } void HIRFunction::IdByRPO() { - std::destroy(values.begin(), values.end()); u32 cur_inst_id{0}; // Re id inst StackVector function_values{}; diff --git a/source/runtime/ir/hir_builder.h b/source/runtime/ir/hir_builder.h index 9647c8f..848b773 100644 --- a/source/runtime/ir/hir_builder.h +++ b/source/runtime/ir/hir_builder.h @@ -116,7 +116,7 @@ struct HIRValue { void Use(Inst* inst, u8 idx); void UnUse(Inst* inst, u8 idx); - u16 GetOrderId() const; + [[nodiscard]] u16 GetOrderId() const; // for rbtree compare static NOINLINE int Compare(const HIRValue& lhs, const HIRValue& rhs) { @@ -229,6 +229,7 @@ class HIRFunction : public DataContext { ASSERT(current_block); auto inst = new Inst(op); inst->SetArgs(args...); + inst->SetId(inst_order_id++); current_block->block->AppendInst(inst); AppendValue(current_block, inst); return inst; diff --git a/source/runtime/ir/opts/register_alloc_pass.cpp b/source/runtime/ir/opts/register_alloc_pass.cpp index 2224497..2720a61 100644 --- a/source/runtime/ir/opts/register_alloc_pass.cpp +++ b/source/runtime/ir/opts/register_alloc_pass.cpp @@ -7,6 +7,7 @@ namespace swift::runtime::ir { struct LiveInterval { + Inst* inst{}; u32 start{}; u32 end{}; @@ -22,10 +23,26 @@ struct LiveInterval { class LinearScanAllocator { public: explicit LinearScanAllocator(HIRFunction* function, backend::RegAlloc* alloc) - : function(function), block(), reg_alloc(alloc), live_interval{function->MaxInstrCount()}, active_lives() {} + : function(function) + , block() + , reg_alloc(alloc) + , live_interval() + , active_lives() { + active_gprs = alloc->GetGprs(); + active_fprs = alloc->GetFprs(); + live_interval.reserve(function->MaxInstrCount()); + } explicit LinearScanAllocator(Block* block, backend::RegAlloc* alloc) - : function(), block(block), reg_alloc(alloc), live_interval{block->GetInstList().size()}, active_lives() {} + : function() + , block(block) + , reg_alloc(alloc) + , live_interval() + , active_lives() { + active_gprs = alloc->GetGprs(); + active_fprs = alloc->GetFprs(); + live_interval.reserve(block->GetInstList().size()); + } void AllocateRegisters() { // Step 1: Collect live intervals @@ -42,51 +59,93 @@ class LinearScanAllocator { for (auto& interval : live_interval) { ExpireOldIntervals(interval); -// if (active_lives.size() == static_cast(numAvailableRegisters)) { -// SpillAtInterval(interval); -// } else { -// int reg = AllocateRegister(); -// interval.registerAssigned = reg; -// active_lives.push_back(interval); -// AssignRegisterToVariable(interval.variable, reg); -// } + if (!IsFloatValue(interval.inst)) { + if (auto alloc = AllocGPR(); alloc >= 0) { + active_lives.push_back(interval); + reg_alloc->MapRegister(interval.inst->Id(), HostGPR{(u16)alloc}); + } else { + SpillAtInterval(interval); + } + } else { + if (auto alloc = AllocFPR(); alloc >= 0) { + active_lives.push_back(interval); + reg_alloc->MapRegister(interval.inst->Id(), HostFPR{(u16)alloc}); + } else { + SpillAtInterval(interval); + } + } } } private: - void CollectLiveIntervals(HIRFunction* hir_function) { for (auto& hir_value : hir_function->GetHIRValues()) { u32 end{hir_value.GetOrderId()}; std::for_each(hir_value.uses.begin(), hir_value.uses.end(), [&end](auto& use) { - end = std::max(end, (u32) use.inst->Id()); + end = std::max(end, (u32)use.inst->Id()); }); - live_interval[hir_value.GetOrderId()] = {hir_value.GetOrderId(), end}; + live_interval.push_back({ + hir_value.value.Def(), hir_value.GetOrderId(), end}); } } - void CollectLiveIntervals(Block *lir_block) { + void CollectLiveIntervals(Block* lir_block) {} - } - - void ExpireOldIntervals(LiveInterval ¤t) { + void ExpireOldIntervals(LiveInterval& current) { for (auto it = active_lives.begin(); it != active_lives.end();) { if (it->end < current.start) { -// FreeRegister(it->registerAssigned); - it = active_lives.erase(it); // Remove expired intervals + if (!IsFloatValue(it->inst)) { + FreeGPR(reg_alloc->ValueGPR(it->inst->Id()).id); + } else { + FreeFPR(reg_alloc->ValueFPR(it->inst->Id()).id); + } + it = active_lives.erase(it); // Remove expired intervals } else { ++it; } } } + void SpillAtInterval(LiveInterval& interval) {} + + bool IsFloatValue(Inst* inst) { + auto value_type = inst->ReturnType(); + return value_type >= ValueType::V8 && value_type <= ValueType::V256; + } + + int AllocGPR() { + if (auto alloc = active_gprs.GetFirstClear(); alloc >= 0) { + active_gprs.Mark(alloc); + return alloc; + } + return -1; + } + + int AllocFPR() { + if (auto alloc = active_fprs.GetFirstClear(); alloc >= 0) { + active_fprs.Mark(alloc); + return alloc; + } + return -1; + } + + void FreeGPR(u32 id) { + ASSERT(active_gprs.Get(id)); + active_gprs.Clear(id); + } + + void FreeFPR(u32 id) { + ASSERT(active_fprs.Get(id)); + active_fprs.Clear(id); + } + HIRFunction* function; - Block *block; + Block* block; backend::RegAlloc* reg_alloc; Vector live_interval; List active_lives; - u16 active_general_value{}; - u16 active_float_value{}; + backend::GPRSMask active_gprs; + backend::FPRSMask active_fprs; }; void RegisterAllocPass::Run(HIRBuilder* hir_builder, backend::RegAlloc* reg_alloc) { diff --git a/source/tests/main_case.cpp b/source/tests/main_case.cpp index 4b5d699..2b5c09b 100644 --- a/source/tests/main_case.cpp +++ b/source/tests/main_case.cpp @@ -2,6 +2,8 @@ #include "runtime/ir/hir_builder.h" #include "runtime/ir/opts/cfg_analysis_pass.h" #include "runtime/ir/opts/local_elimination_pass.h" +#include "runtime/ir/opts/reid_instr_pass.h" +#include "runtime/ir/opts/register_alloc_pass.h" #include "runtime/backend/mem_map.h" #include "compiler/slang/slang.h" @@ -57,14 +59,14 @@ TEST_CASE("Test runtime-ir") { hir_builder.Return(); CFGAnalysisPass::Run(&hir_builder); LocalEliminationPass::Run(&hir_builder); - - assert(local2.Defined()); + ReIdInstrPass::Run(&hir_builder); + RegAlloc reg_alloc{function->MaxInstrCount(), GPRSMask{0}, FPRSMask{0}}; + RegisterAllocPass::Run(&hir_builder, ®_alloc); MemMap mem_arena{0x100000, true}; auto res = mem_arena.Map(0x100000, 0, MemMap::ReadExe, false); ASSERT(res); - } TEST_CASE("Test runtime-ir-cfg") { @@ -110,6 +112,7 @@ TEST_CASE("Test runtime-ir-cfg") { hir_builder.Return(); CFGAnalysisPass::Run(&hir_builder); LocalEliminationPass::Run(&hir_builder); + ReIdInstrPass::Run(&hir_builder); assert(local2.Defined());