Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copy propagation in JuvixReg #2828

Merged
merged 9 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ jobs:
- name: Install RISC0 VM
shell: bash
run: |
cargo install cargo-binstall --force
cargo install cargo-binstall@1.6.9 --force
cargo binstall [email protected] --no-confirm --force
cargo risczero install

Expand Down
10 changes: 7 additions & 3 deletions src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ data TransformationId
| Cleanup
| SSA
| InitBranchVars
| CopyPropagation
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -19,14 +20,16 @@ data PipelineId

type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

-- Note: this works only because for now we mark all variables as live. Liveness
-- information needs to be re-computed after copy propagation.
toCTransformations :: [TransformationId]
toCTransformations = [Cleanup]
toCTransformations = [Cleanup, CopyPropagation]

toRustTransformations :: [TransformationId]
toRustTransformations = [Cleanup]
toRustTransformations = [Cleanup, CopyPropagation]

toCasmTransformations :: [TransformationId]
toCasmTransformations = [Cleanup, SSA]
toCasmTransformations = [Cleanup, CopyPropagation, SSA]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand All @@ -35,6 +38,7 @@ instance TransformationId' TransformationId where
Cleanup -> strCleanup
SSA -> strSSA
InitBranchVars -> strInitBranchVars
CopyPropagation -> strCopyPropagation

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ strSSA = "ssa"

strInitBranchVars :: Text
strInitBranchVars = "init-branch-vars"

strCopyPropagation :: Text
strCopyPropagation = "copy-propagation"
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Reg/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,13 @@ overValueRefs f = \case

goBlock :: InstrBlock -> InstrBlock
goBlock x = x

updateLiveVars' :: (VarRef -> Maybe VarRef) -> Instruction -> Instruction
updateLiveVars' f = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe f) x
Call x -> Call $ over instrCallLiveVars (mapMaybe f) x
CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe f) x
instr -> instr

updateLiveVars :: (VarRef -> VarRef) -> Instruction -> Instruction
updateLiveVars f = updateLiveVars' (Just . f)
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where
import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Cleanup
import Juvix.Compiler.Reg.Transformation.CopyPropagation
import Juvix.Compiler.Reg.Transformation.IdentityTrans
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA
Expand All @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
Cleanup -> return . cleanup
SSA -> return . computeSSA
InitBranchVars -> return . initBranchVars
CopyPropagation -> return . copyPropagate
56 changes: 56 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module Juvix.Compiler.Reg.Transformation.CopyPropagation where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base

type VarMap = HashMap VarRef VarRef

copyPropagateFunction :: Code -> Code
copyPropagateFunction =
snd
. runIdentity
. recurseF
ForwardRecursorSig
{ _forwardFun = \i acc -> return (go i acc),
_forwardCombine = combine
}
mempty
where
go :: Instruction -> VarMap -> (VarMap, Instruction)
go instr mpv = case instr' of
Assign InstrAssign {..}
| VRef v <- _instrAssignValue ->
(HashMap.insert _instrAssignResult v mpv', instr')
_ ->
(mpv', instr')
where
instr' = overValueRefs (adjustVarRef mpv) instr
mpv' = maybe mpv (filterOutVars mpv) (getResultVar instr)

filterOutVars :: VarMap -> VarRef -> VarMap
filterOutVars mpv v = HashMap.delete v $ HashMap.filter (/= v) mpv

adjustVarRef :: VarMap -> VarRef -> VarRef
adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of
VarGroupArgs -> vref
VarGroupLocal -> fromMaybe vref $ HashMap.lookup vref mpv

combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction)
combine instr mpvs = (mpv, instr')
where
mpv' :| mpvs' = fmap HashMap.toList mpvs
mpv =
HashMap.fromList
. HashSet.toList
. foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv')
$ mpvs'

instr' = case instr of
Branch x -> Branch $ over instrBranchOutVar (fmap (adjustVarRef mpv)) x
Case x -> Case $ over instrCaseOutVar (fmap (adjustVarRef mpv)) x
_ -> impossible

copyPropagate :: InfoTable -> InfoTable
copyPropagate = mapT (const copyPropagateFunction)
11 changes: 2 additions & 9 deletions src/Juvix/Compiler/Reg/Transformation/SSA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,13 @@ computeFunctionSSA =
where
go :: Instruction -> IndexMap VarRef -> (IndexMap VarRef, Instruction)
go instr mp = case getResultVar instr' of
Just vref -> (mp', updateLiveVars mp' (setResultVar instr' (mkVarRef VarGroupLocal idx)))
Just vref -> (mp', updateLiveVars' (adjustVarRef' mp') (setResultVar instr' (mkVarRef VarGroupLocal idx)))
where
(idx, mp') = IndexMap.assign mp vref
Nothing -> (mp, updateLiveVars mp instr')
Nothing -> (mp, updateLiveVars' (adjustVarRef' mp) instr')
where
instr' = overValueRefs (adjustVarRef mp) instr

updateLiveVars :: IndexMap VarRef -> Instruction -> Instruction
updateLiveVars mp = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe (adjustVarRef' mp)) x
Call x -> Call $ over instrCallLiveVars (mapMaybe (adjustVarRef' mp)) x
CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe (adjustVarRef' mp)) x
instr -> instr

-- For branches, when necessary we insert assignments unifying the renamed
-- output variables into a single output variable for both branches.
combine :: Instruction -> NonEmpty (IndexMap VarRef) -> (IndexMap VarRef, Instruction)
Expand Down
15 changes: 14 additions & 1 deletion test/Reg/Parse/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ testDescr PosTest {..} =
filterTests :: [String] -> [PosTest] -> [PosTest]
filterTests incl = filter (\PosTest {..} -> _name `elem` incl)

filterOutTests :: [String] -> [PosTest] -> [PosTest]
filterOutTests excl = filter (\PosTest {..} -> _name `notElem` excl)

allTests :: TestTree
allTests =
testGroup
Expand Down Expand Up @@ -223,5 +226,15 @@ tests =
"Test038: Apply & argsnum"
$(mkRelDir ".")
$(mkRelFile "test038.jvr")
$(mkRelFile "out/test038.out")
$(mkRelFile "out/test038.out"),
PosTest
"Test039: Copy & constant propagation"
$(mkRelDir ".")
$(mkRelFile "test039.jvr")
$(mkRelFile "out/test039.out"),
PosTest
"Test040: Copy & constant propagation with branches"
$(mkRelDir ".")
$(mkRelFile "test040.jvr")
$(mkRelFile "out/test040.out")
]
4 changes: 3 additions & 1 deletion test/Reg/Transformation.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Reg.Transformation where

import Base
import Reg.Transformation.CopyPropagation qualified as CopyPropagation
import Reg.Transformation.IdentityTrans qualified as IdentityTrans
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
import Reg.Transformation.SSA qualified as SSA
Expand All @@ -11,5 +12,6 @@ allTests =
"JuvixReg transformations"
[ IdentityTrans.allTests,
SSA.allTests,
InitBranchVars.allTests
InitBranchVars.allTests,
CopyPropagation.allTests
]
21 changes: 21 additions & 0 deletions test/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Reg.Transformation.CopyPropagation where

import Base
import Juvix.Compiler.Reg.Transformation
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "Copy Propagation" (map liftTest Parse.tests)

pipe :: [TransformationId]
pipe = [CopyPropagation]

liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = const (return ()),
_testRun
}
2 changes: 1 addition & 1 deletion test/Reg/Transformation/InitBranchVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "InitBranchVars" (map liftTest Parse.tests)
allTests = testGroup "InitBranchVars" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests)

pipe :: [TransformationId]
pipe = [SSA, InitBranchVars]
Expand Down
2 changes: 1 addition & 1 deletion test/Reg/Transformation/SSA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "SSA" (map liftTest Parse.tests)
allTests = testGroup "SSA" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests)

pipe :: [TransformationId]
pipe = [SSA]
Expand Down
1 change: 1 addition & 0 deletions tests/Reg/positive/out/test039.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
82
1 change: 1 addition & 0 deletions tests/Reg/positive/out/test040.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
79
84 changes: 84 additions & 0 deletions tests/Reg/positive/test039.jvr
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
-- Copy & constant propagation

type either {
left : integer -> either;
right : bool -> either;
}

function main() : * {
tmp[0] = 7;
tmp[1] = tmp[0];
tmp[0] = tmp[1];
tmp[2] = tmp[0];
-- tmp[2] = 7

tmp[1] = tmp[0];
tmp[0] = add tmp[1] 1;
tmp[2] = add tmp[2] tmp[1];
-- tmp[2] = 14

tmp[0] = 19;
tmp[1] = tmp[0];
tmp[0] = add tmp[1] 1;
tmp[3] = add tmp[0] tmp[1];
tmp[4] = tmp[3];
tmp[2] = add tmp[4] tmp[2];
-- tmp[2] = 53

tmp[1] = eq tmp[2] 54;
tmp[0] = 4;
tmp[3] = 3;
tmp[4] = tmp[0];
tmp[5] = 4;
tmp[6] = tmp[5];
br tmp[1] {
true: {
tmp[4] = 7;
};
false: {
tmp[3] = tmp[6];
};
};
tmp[2] = add tmp[2] tmp[4];
tmp[2] = add tmp[2] tmp[3];
-- tmp[2] = 61

tmp[0] = alloc left (3);
tmp[1] = 17;
tmp[3] = tmp[1];
case[either] tmp[0] {
left: {
tmp[4] = tmp[0].left[0];
tmp[1] = tmp[4];
tmp[3] = tmp[1];
};
right: {
nop;
};
};
tmp[2] = add tmp[2] tmp[3];
-- tmp[2] = 64

tmp[0] = alloc right (true);
tmp[1] = 17;
tmp[3] = tmp[1];
case[either] tmp[0] {
left: {
tmp[1] = tmp[0].left[0];
};
right: {
br tmp[0].right[0] {
true: {
tmp[1] = add tmp[3] 1;
};
false: {
nop;
};
};
};
};
tmp[2] = add tmp[2] tmp[1];
-- tmp[2] = 82

ret tmp[2];
}
Loading
Loading