diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId.hs b/src/Juvix/Compiler/Reg/Data/TransformationId.hs index 8003a9599a..b06807306b 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId.hs @@ -6,8 +6,9 @@ import Juvix.Prelude data TransformationId = Identity - | SSA | Cleanup + | SSA + | InitBranchVars deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -21,14 +22,15 @@ toCTransformations :: [TransformationId] toCTransformations = [Cleanup] toCairoTransformations :: [TransformationId] -toCairoTransformations = [Cleanup, SSA] +toCairoTransformations = [Cleanup, SSA, InitBranchVars] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text transformationText = \case Identity -> strIdentity - SSA -> strSSA Cleanup -> strCleanup + SSA -> strSSA + InitBranchVars -> strInitBranchVars instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs index b05e08b320..fe54ffb955 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs @@ -11,8 +11,11 @@ strCairoPipeline = "pipeline-cairo" strIdentity :: Text strIdentity = "identity" +strCleanup :: Text +strCleanup = "cleanup" + strSSA :: Text strSSA = "ssa" -strCleanup :: Text -strCleanup = "cleanup" +strInitBranchVars :: Text +strInitBranchVars = "init-branch-vars" diff --git a/src/Juvix/Compiler/Reg/Extra/Recursors.hs b/src/Juvix/Compiler/Reg/Extra/Recursors.hs index 44f356c592..0b3d6557de 100644 --- a/src/Juvix/Compiler/Reg/Extra/Recursors.hs +++ b/src/Juvix/Compiler/Reg/Extra/Recursors.hs @@ -9,7 +9,17 @@ data ForwardRecursorSig m c = ForwardRecursorSig } data BackwardRecursorSig m a = BackwardRecursorSig - { _backwardFun :: Code -> a -> [a] -> m (a, Code), + { -- | In `_backwardFun is a as`: `is = i : is'` is the instruction list + -- currently being processed (the head `i` is the processed instruction, the + -- tail `is'` contains the instructions after it); `a` is the accumulator + -- for `is'`; `as` contains the accumulator values for the branches (for + -- `Branch` and `Case` instructions, otherwise empty). For the `Case` + -- instruction, the accumulator for the default branch (if present) is the + -- last element of `as`. + _backwardFun :: Code -> a -> [a] -> m (a, Code), + -- | `backwardAdjust a` adjusts the accumulator value when going backwards + -- into a branch. See also `FoldSig` in `Asm.Extra.Recursors` for more + -- explanations. _backwardAdjust :: a -> a } @@ -125,3 +135,25 @@ ifoldFM f a0 is0 = ifoldF :: (Monoid a) => (a -> Instruction -> a) -> a -> Code -> a ifoldF f a is = runIdentity (ifoldFM (\a' -> return . f a') a is) + +ifoldBM :: forall a m. (Monad m) => (a -> [a] -> Instruction -> m a) -> a -> Code -> m a +ifoldBM f a0 is0 = + fst + <$> recurseB + BackwardRecursorSig + { _backwardFun = go, + _backwardAdjust = id + } + a0 + is0 + where + go :: Code -> a -> [a] -> m (a, Code) + go is a as = case is of + i : _ -> do + a' <- f a as i + return (a', is) + [] -> + return (a, is) + +ifoldB :: (a -> [a] -> Instruction -> a) -> a -> Code -> a +ifoldB f a is = runIdentity (ifoldBM (\a' as' -> return . f a' as') a is) diff --git a/src/Juvix/Compiler/Reg/Transformation.hs b/src/Juvix/Compiler/Reg/Transformation.hs index b8b6e6c16c..d464969927 100644 --- a/src/Juvix/Compiler/Reg/Transformation.hs +++ b/src/Juvix/Compiler/Reg/Transformation.hs @@ -9,6 +9,7 @@ import Juvix.Compiler.Reg.Data.TransformationId import Juvix.Compiler.Reg.Transformation.Base import Juvix.Compiler.Reg.Transformation.Cleanup import Juvix.Compiler.Reg.Transformation.Identity +import Juvix.Compiler.Reg.Transformation.InitBranchVars import Juvix.Compiler.Reg.Transformation.SSA applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable @@ -17,5 +18,6 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts appTrans :: TransformationId -> InfoTable -> Sem r InfoTable appTrans = \case Identity -> return . identity - SSA -> return . computeSSA Cleanup -> return . cleanup + SSA -> return . computeSSA + InitBranchVars -> return . initBranchVars diff --git a/src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs b/src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs new file mode 100644 index 0000000000..5617cf122c --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs @@ -0,0 +1,91 @@ +module Juvix.Compiler.Reg.Transformation.InitBranchVars where + +import Data.Functor.Identity +import Data.HashSet qualified as HashSet +import Data.List qualified as List +import Juvix.Compiler.Reg.Extra +import Juvix.Compiler.Reg.Transformation.Base + +-- | Inserts assignments to initialize variables assigned in other branches. +-- Assumes the input is in SSA form (which is preserved). +initBranchVars :: InfoTable -> InfoTable +initBranchVars = mapT (const goFun) + where + goFun :: Code -> Code + goFun = + snd + . runIdentity + . recurseB + BackwardRecursorSig + { _backwardFun = \is a as -> return (go is a as), + _backwardAdjust = const mempty + } + mempty + + go :: Code -> HashSet VarRef -> [HashSet VarRef] -> (HashSet VarRef, Code) + go is a as = case is of + Branch InstrBranch {..} : is' -> case as of + [a1, a2] -> (a <> a', i' : is') + where + a' = a1 <> a2 + a1' = HashSet.difference a' a1 + a2' = HashSet.difference a' a2 + i' = + Branch + InstrBranch + { _instrBranchTrue = addInits a1' _instrBranchTrue, + _instrBranchFalse = addInits a2' _instrBranchFalse, + .. + } + _ -> impossible + Case InstrCase {..} : is' -> + (a <> a', i' : is') + where + a' = mconcat as + as' = map (HashSet.difference a') as + n = length _instrCaseBranches + brs' = zipWithExact goBranch (take n as') _instrCaseBranches + def' = maybe Nothing (Just . addInits (List.last as')) _instrCaseDefault + i' = + Case + InstrCase + { _instrCaseBranches = brs', + _instrCaseDefault = def', + .. + } + + goBranch :: HashSet VarRef -> CaseBranch -> CaseBranch + goBranch vars = over caseBranchCode (addInits vars) + i : _ -> + case getResultVar i of + Just v -> + (HashSet.insert v a <> mconcat as, is) + Nothing -> + (a <> mconcat as, is) + [] -> + (a <> mconcat as, is) + + addInits :: HashSet VarRef -> Code -> Code + addInits vars is = map mk (toList vars) ++ is + where + mk :: VarRef -> Instruction + mk vref = + Assign + InstrAssign + { _instrAssignResult = vref, + _instrAssignValue = Const ConstVoid + } + +checkInitialized :: InfoTable -> Bool +checkInitialized tab = all (goFun . (^. functionCode)) (tab ^. infoFunctions) + where + goFun :: Code -> Bool + goFun = snd . ifoldB go (mempty, True) + where + go :: (HashSet VarRef, Bool) -> [(HashSet VarRef, Bool)] -> Instruction -> (HashSet VarRef, Bool) + go (v, b) ls i = case getResultVar i of + Just vref -> (HashSet.insert vref v', b') + Nothing -> (v', b') + where + v' = v <> mconcat (map fst ls) + b' = b && allSame (map fst ls) && and (map snd ls) diff --git a/test/Reg/Transformation.hs b/test/Reg/Transformation.hs index 7e2a1033a7..7f0b9033d1 100644 --- a/test/Reg/Transformation.hs +++ b/test/Reg/Transformation.hs @@ -2,6 +2,7 @@ module Reg.Transformation where import Base import Reg.Transformation.Identity qualified as Identity +import Reg.Transformation.InitBranchVars qualified as InitBranchVars import Reg.Transformation.SSA qualified as SSA allTests :: TestTree @@ -9,5 +10,6 @@ allTests = testGroup "JuvixReg transformations" [ Identity.allTests, - SSA.allTests + SSA.allTests, + InitBranchVars.allTests ] diff --git a/test/Reg/Transformation/InitBranchVars.hs b/test/Reg/Transformation/InitBranchVars.hs new file mode 100644 index 0000000000..5a56347c15 --- /dev/null +++ b/test/Reg/Transformation/InitBranchVars.hs @@ -0,0 +1,25 @@ +module Reg.Transformation.InitBranchVars where + +import Base +import Juvix.Compiler.Reg.Transformation +import Juvix.Compiler.Reg.Transformation.InitBranchVars +import Juvix.Compiler.Reg.Transformation.SSA +import Reg.Parse.Positive qualified as Parse +import Reg.Transformation.Base + +allTests :: TestTree +allTests = testGroup "InitBranchVars" (map liftTest Parse.tests) + +pipe :: [TransformationId] +pipe = [SSA, InitBranchVars] + +liftTest :: Parse.PosTest -> TestTree +liftTest _testRun = + fromTest + Test + { _testTransformations = pipe, + _testAssertion = \tab -> do + unless (checkSSA tab) $ error "check ssa" + unless (checkInitialized tab) $ error "check initialized", + _testRun + }