From 3b01785fb447df883bcc891dcf1284b6f9faa889 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 14 Jun 2024 11:35:14 +0200 Subject: [PATCH] copy propagation --- .../Reg/Transformation/CopyPropagation.hs | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs diff --git a/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs b/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs new file mode 100644 index 0000000000..affbcd12ce --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs @@ -0,0 +1,51 @@ +module Juvix.Compiler.Reg.Transformation.CopyPropagation where + +import Data.HashMap.Strict qualified as HashMap +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Reg.Extra +import Juvix.Compiler.Reg.Transformation.Base + +type VarMap = HashMap VarRef VarRef + +copyPropagateFunction :: Code -> Code +copyPropagateFunction = + snd + . runIdentity + . recurseF + ForwardRecursorSig + { _forwardFun = \i acc -> return (go i acc), + _forwardCombine = combine + } + mempty + where + go :: Instruction -> VarMap -> (VarMap, Instruction) + go instr mpv = case instr' of + Assign InstrAssign {..} + | VRef v <- _instrAssignValue -> + (HashMap.insert _instrAssignResult v mpv', instr') + _ -> + (mpv', instr') + where + instr' = overValueRefs (adjustVarRef mpv) instr + mpv' = maybe mpv (filterOutVars mpv) (getResultVar instr') + + filterOutVars :: VarMap -> VarRef -> VarMap + filterOutVars mpv v = HashMap.filter (/= v) mpv + + adjustVarRef :: VarMap -> VarRef -> VarRef + adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of + VarGroupArgs -> vref + VarGroupLocal -> fromMaybe vref $ HashMap.lookup vref mpv + + combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction) + combine instr mpvs = (mpv, instr) + where + mpv' :| mpvs' = fmap HashMap.toList mpvs + mpv = + HashMap.fromList + . HashSet.toList + . foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv') + $ mpvs' + +copyPropagate :: InfoTable -> InfoTable +copyPropagate = mapT (const copyPropagateFunction)