Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve specialization optimization #2944

Merged
merged 12 commits into from
Aug 14, 2024
Merged
10 changes: 5 additions & 5 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
(blts, binstrs) <- addStdlibBuiltins (length pinstrs)
let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs)
(addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (length pinstrs + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
eassert (addr == length instrs + length cinstrs + length binstrs + length pinstrs)
massert (addr == length instrs + length cinstrs + length binstrs + length pinstrs)
registerLabelName endSym endName
registerLabelAddress endSym addr
return
Expand Down Expand Up @@ -181,15 +181,15 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
mapM_ goInstr _blockBody
case _blockNext of
Just block' -> do
eassert (isJust _blockFinal)
massert (isJust _blockFinal)
goFinalInstr (block' ^. Reg.blockLiveVars) (fromJust _blockFinal)
goBlock blts failLab liveVars0 mout block'
Nothing -> case _blockFinal of
Just instr ->
goFinalInstr liveVars0 instr
Nothing -> do
eassert (isJust mout)
eassert (HashSet.member (fromJust mout) liveVars0)
massert (isJust mout)
massert (HashSet.member (fromJust mout) liveVars0)
goCallBlock False Nothing liveVars0
where
output'' :: Instruction -> Sem r ()
Expand Down Expand Up @@ -634,7 +634,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI

goCase :: HashSet Reg.VarRef -> Reg.InstrCase -> Sem r ()
goCase liveVars Reg.InstrCase {..} = do
eassert (not (Reg.isInductiveRecord tab _instrCaseInductive))
massert (not (Reg.isInductiveRecord tab _instrCaseInductive))
syms <- replicateM (length tags) freshSymbol
symEnd <- freshSymbol
let symMap = HashMap.fromList $ zip tags syms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ checkImportNoPublic ::
Import 'Parsed ->
Sem r (Import 'Scoped)
checkImportNoPublic import_@Import {..} = do
eassert (_importPublic == NoPublic)
massert (_importPublic == NoPublic)
smodule <- readScopeModule import_
let sname :: S.TopModulePath = smodule ^. scopedModulePath
sname' :: S.Name = set S.nameConcrete (topModulePathToName _importModulePath) sname
Expand Down Expand Up @@ -1460,7 +1460,7 @@ checkSections sec = topBindings helper
-- section and start a new one
def@DefinitionFunctionDef {} : defs
| not (null ms) -> do
eassert (not (null acc))
massert (not (null acc))
sec' <- goDefsSection (nonEmpty' (reverse acc))
ms' <- goInductiveModules (nonEmpty' (reverse ms))
next' <- goDefs [] [] (def : defs)
Expand All @@ -1481,7 +1481,7 @@ checkSections sec = topBindings helper
let ms' = maybeToList m ++ ms
goDefs (def : acc) ms' defs
[] -> do
eassert (not (null acc))
massert (not (null acc))
sec' <- goDefsSection (nonEmpty' (reverse acc))
next' <- case nonEmpty (reverse ms) of
Nothing -> mapM goNonDefinitions _definitionsNext
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Concrete/Translation/FromSource.hs
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,8 @@ pnamedArgumentFunctionDef ::
(Members '[ParserResultBuilder, PragmasStash, JudocStash] r) =>
ParsecS r (NamedArgumentFunctionDef 'Parsed)
pnamedArgumentFunctionDef = do
optional_ stashJudoc
optional_ stashPragmas
fun <- functionDefinition True False Nothing
return
NamedArgumentFunctionDef
Expand Down
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Core/Data/BinderList.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import GHC.Show qualified as S
import Juvix.Compiler.Core.Language hiding (cons, drop, lookup, uncons)
import Juvix.Prelude qualified as Prelude

-- | if we have \x\y. b, the binderlist in b is [y, x]
-- | if we have `\\x\\y. b`, the binderlist in b is `[y, x]`
data BinderList a = BinderList
{ _blLength :: Int,
_blMap :: [a]
Expand All @@ -22,6 +22,9 @@ drop k (BinderList n l) = BinderList (n - k) (dropExact k l)
tail :: BinderList a -> BinderList a
tail = snd . fromJust . uncons

elem :: (Eq a) => BinderList a -> a -> Bool
elem bl a = a `Prelude.elem` (bl ^. blMap)

uncons :: BinderList a -> Maybe (a, BinderList a)
uncons l = second helper <$> Prelude.uncons (l ^. blMap)
where
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ unfoldLambdas' = first length . unfoldLambdas
lambdaTypes :: Node -> [Type]
lambdaTypes = map (\LambdaLhs {..} -> _lambdaLhsBinder ^. binderType) . fst . unfoldLambdas

lambdaBinders :: Node -> [Binder]
lambdaBinders = map (^. lambdaLhsBinder) . fst . unfoldLambdas

isConstructorApp :: Node -> Bool
isConstructorApp node = case node of
NCtr {} -> True
Expand Down
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Core/Info/PragmaInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ getInfoPragmas i =
setInfoPragmas :: [Pragmas] -> Info -> Info
setInfoPragmas = Info.insert . PragmasInfo

overInfoPragmas :: (Pragmas -> Pragmas) -> Info -> Info
overInfoPragmas f i = case Info.lookup kPragmasInfo i of
Just PragmasInfo {..} -> setInfoPragmas (map f _infoPragmas) i
Nothing -> i

getInfoPragma :: Info -> Pragmas
getInfoPragma i =
case Info.lookup kPragmaInfo i of
Expand All @@ -44,6 +49,11 @@ getInfoPragma i =
setInfoPragma :: Pragmas -> Info -> Info
setInfoPragma = Info.insert . PragmaInfo

overInfoPragma :: (Pragmas -> Pragmas) -> Info -> Info
overInfoPragma f i = case Info.lookup kPragmaInfo i of
Just PragmaInfo {..} -> setInfoPragma (f _infoPragma) i
Nothing -> i

getNodePragmas :: Node -> Pragmas
getNodePragmas = getInfoPragma . getInfo

Expand Down
26 changes: 24 additions & 2 deletions src/Juvix/Compiler/Core/Transformation/DisambiguateNames.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.NameInfo (setInfoName)
import Juvix.Compiler.Core.Info.PragmaInfo
import Juvix.Compiler.Core.Transformation.Base

disambiguateNodeNames' :: (BinderList Binder -> Text -> Text) -> Module -> Node -> Node
disambiguateNodeNames' disambiguate md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
go bl node = case node' of
NVar Var {..} ->
mkVar (setInfoName (BL.lookup _varIndex bl ^. binderName) _varInfo) _varIndex
NIdt Ident {..} ->
Expand Down Expand Up @@ -56,7 +57,28 @@ disambiguateNodeNames' disambiguate md = dmapL go
NPi pi
| varOccurs 0 (pi ^. piBody) ->
NPi (over piBinder (over binderName (disambiguate bl)) pi)
_ -> node
_ -> node'
where
node' = modifyInfo (overInfoPragma disambiguatePragmas . overInfoPragmas disambiguatePragmas) node

disambiguatePragmas :: Pragmas -> Pragmas
disambiguatePragmas =
over
pragmasSpecialiseArgs
(fmap $ over pragmaSpecialiseArgs (map disambiguateArg))
. over
pragmasSpecialiseBy
(fmap $ over pragmaSpecialiseBy (map (disambiguate' bl)))

disambiguateArg :: PragmaSpecialiseArg -> PragmaSpecialiseArg
disambiguateArg = \case
SpecialiseArgNum i -> SpecialiseArgNum i
SpecialiseArgNamed n -> SpecialiseArgNamed (disambiguate' bl n)

disambiguate' :: BinderList Binder -> Text -> Text
disambiguate' bl n
| elem n (map (^. binderName) (toList bl)) = n
| otherwise = disambiguate mempty n

disambiguateBinders :: BinderList Binder -> [Binder] -> [Binder]
disambiguateBinders bl = \case
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ optimize' :: CoreOptions -> Module -> Module
optimize' opts@CoreOptions {..} md =
filterUnreachable
. compose
(4 * _optOptimizationLevel)
(6 * _optOptimizationLevel)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why has this changed from 4 to 6?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because with maps and folds inside record definitions 4 iterations are not enough to inline & specialize everything in a typical case (without further nestings).

( doConstantFolding
. doSimplification 2
. doInlining
Expand Down
38 changes: 20 additions & 18 deletions src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ isSpecializable md node =
_ -> True
NLam {} -> True
NCst {} -> True
NCtr Constr {..} -> all (isSpecializable md) _constrArgs
NCtr Constr {..} ->
case lookupConstructorInfo md _constrTag ^. constructorPragmas . pragmasSpecialise of
Just (PragmaSpecialise False) -> False
_ -> True
NApp {} ->
let (h, _) = unfoldApps' node
in isSpecializable md h
Expand Down Expand Up @@ -102,7 +105,11 @@ convertNode = dmapLRM go
(tyargs, tgt) = unfoldPi' (ii ^. identifierType)
def = lookupIdentifierNode md _identSymbol
(lams, body) = unfoldLambdas def
argnames = map (^. lambdaLhsBinder . binderName) lams
argnames =
zipWith
(\mn lhs -> fromMaybe (lhs ^. lambdaLhsBinder . binderName) mn)
(ii ^. identifierArgNames ++ repeat Nothing)
lams

-- arguments marked for specialisation with `specialize: true`
psargs0 =
Expand All @@ -118,23 +125,21 @@ convertNode = dmapLRM go
| (isJust pspec || isJust pspecby || not (null psargs0)) && length args == argsNum -> do
let psargs1 = mapMaybe getArgIndex $ maybe [] (^. pragmaSpecialiseArgs) pspec
psargs2 = maybe [] (map (+ 1) . mapMaybe (`elemIndex` argnames) . (^. pragmaSpecialiseBy)) pspecby
-- psargs are the arguments explicitly marked for specialization
psargs = nubSort (psargs0 ++ psargs1 ++ psargs2)
-- assumption: all type variables are at the front
let specargs0 =
let -- specargs0 are the arguments actually selected for specialization
specargs0 =
filter
( \argNum ->
argNum <= argsNum
&& isSpecializable md (args' !! (argNum - 1))
&& isArgSpecializable md _identSymbol argNum
)
psargs
tyargsNum = length (takeWhile (isTypeConstr md) tyargs)
tyargnums = map fst $ filter (isTypeConstr md . snd) $ zip [1 .. argsNum] tyargs
-- in addition to the arguments explicitly marked for
-- specialisation, also specialise all type arguments
specargs =
nub $
[1 .. tyargsNum]
++ specargs0
specargs = nub $ tyargnums ++ specargs0
-- the arguments marked for specialisation which we don't
-- specialise now
remainingSpecargs =
Expand Down Expand Up @@ -165,12 +170,9 @@ convertNode = dmapLRM go
| null specargs0 ->
return $ End (mkApps' (NIdt idt) args')
| otherwise -> do
eassert (tyargsNum < argsNum)
eassert (length lams == argsNum)
eassert (length args' == argsNum)
eassert (argsNum <= length tyargs)
-- assumption: all type variables are at the front
eassert (not $ any (isTypeConstr md) (drop tyargsNum tyargs))
massert (length lams == argsNum)
massert (length args' == argsNum)
massert (argsNum <= length tyargs)
-- the specialisation signature: the values we specialise the arguments by
let specSigArgs = selectSpecargs specargs args'
specSig = (specSigArgs, specargs)
Expand Down Expand Up @@ -237,8 +239,8 @@ convertNode = dmapLRM go
| otherwise ->
return $ End $ mkApps' (NIdt idt) args'

-- assumption: all type arguments are substituted, so no binders in the type
-- list refer to other elements in the list
-- Because all type arguments are substituted (specialized), in the end no
-- binders in the resulting type list refer to other elements in the list
removeSpecTypeArgs :: [Int] -> [Node] -> [Type] -> [Type]
removeSpecTypeArgs = goRemove 1
where
Expand Down Expand Up @@ -269,7 +271,7 @@ convertNode = dmapLRM go

shiftSpecargs :: [Int] -> [Int] -> [Int]
shiftSpecargs specargs =
map (\argNum -> argNum - length (filter (argNum <) specargs))
map (\argNum -> argNum - length (filter (argNum >) specargs))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the change from < -> > fix a bug or is it related to another change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was incorrect, but we didn't see it because:

  1. This adjusts the indices of arguments marked for specialization which are not specialized because the provided argument has the wrong form, e.g., it is a variable (makes no sense to specialize by an unknown). They can be specialized later in another iteration if, e.g., inlining substitutes a concrete value for the variable, but this is not very common.
  2. If specialization fails nothing really bad happens -- the function application is just not specialized and thus less efficient.


-- Replace the calls to the function being specialised with the specialised
-- version (omitting the specialised arguments). We need to first replace
Expand Down
31 changes: 18 additions & 13 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ makeLenses ''PreMutual
mkIdentIndex :: Name -> Text
mkIdentIndex = show . (^. Internal.nameId)

computeImplicitArgs :: Internal.Expression -> [Bool]
computeImplicitArgs = \case
Internal.ExpressionFunction Internal.Function {..}
| _functionLeft ^. Internal.paramImplicit == Implicit ->
True : computeImplicitArgs _functionRight
| otherwise ->
False : computeImplicitArgs _functionRight
_ -> []

fromInternal :: (Members '[NameIdGen, Reader Store.ModuleTable, Error JuvixError] k) => Internal.InternalTypedResult -> Sem k CoreResult
fromInternal i = mapError (JuvixError . ErrBadScope) $ do
importTab <- asks Store.getInternalModuleTable
Expand Down Expand Up @@ -277,7 +286,7 @@ preFunctionDef f = do
sym <- freshSymbol
funTy <- fromTopIndex (goType (f ^. Internal.funDefType))
let _identifierName = f ^. Internal.funDefName . nameText
implParamsNum = implicitParametersNum (f ^. Internal.funDefType)
implArgs = computeImplicitArgs (f ^. Internal.funDefType)
info =
IdentifierInfo
{ _identifierName = normalizeBuiltinName (f ^. Internal.funDefBuiltin) (f ^. Internal.funDefName . nameText),
Expand All @@ -290,7 +299,7 @@ preFunctionDef f = do
_identifierIsExported = False,
_identifierBuiltin = f ^. Internal.funDefBuiltin,
_identifierPragmas =
adjustPragmas implParamsNum (f ^. Internal.funDefPragmas),
adjustPragmas' implArgs (f ^. Internal.funDefPragmas),
_identifierArgNames = argnames
}
case f ^. Internal.funDefBuiltin of
Expand Down Expand Up @@ -323,13 +332,6 @@ preFunctionDef f = do
">=" -> Str.natGe
_ -> name

implicitParametersNum :: Internal.Expression -> Int
implicitParametersNum = \case
Internal.ExpressionFunction Internal.Function {..}
| _functionLeft ^. Internal.paramImplicit == Implicit ->
implicitParametersNum _functionRight + 1
_ -> 0

getPatternName :: Internal.PatternArg -> Maybe Text
getPatternName pat = case pat ^. Internal.patternArgName of
Just n -> Just (n ^. nameText)
Expand Down Expand Up @@ -554,17 +556,20 @@ goLet l = goClauses (toList (l ^. Internal.letClauses))
funTy <- goType (f ^. Internal.funDefType)
funBody <- mkFunBody funTy f
rest <- localAddName (f ^. Internal.funDefName) (goClauses cs)
let name = f ^. Internal.funDefName . nameText
let implArgs = computeImplicitArgs (f ^. Internal.funDefType)
name = f ^. Internal.funDefName . nameText
loc = f ^. Internal.funDefName . nameLoc
info = setInfoPragma (f ^. Internal.funDefPragmas) mempty
body = modifyInfo (setInfoPragma (f ^. Internal.funDefPragmas)) funBody
pragmas = adjustPragmas' implArgs (f ^. Internal.funDefPragmas)
info = setInfoPragma pragmas mempty
body = modifyInfo (setInfoPragma pragmas) funBody
return $ mkLet info (Binder name (Just loc) funTy) body rest
goMutual :: Internal.MutualBlockLet -> Sem r Node
goMutual (Internal.MutualBlockLet funs) = do
let lfuns = toList funs
names = map (^. Internal.funDefName) lfuns
tys = map (^. Internal.funDefType) lfuns
pragmas = map (^. Internal.funDefPragmas) lfuns
implArgs = map (computeImplicitArgs . (^. Internal.funDefType)) lfuns
pragmas = zipWith adjustPragmas' implArgs (map (^. Internal.funDefPragmas) lfuns)
tys' <- mapM goType tys
localAddNames names $ do
vals' <- sequence [mkFunBody (shift (length names) ty) f | (ty, f) <- zipExact tys' lfuns]
Expand Down
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Core/Translation/Stripped/FromCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,14 @@ translateFunctionInfo tab IdentifierInfo {..} =
_functionBody =
translateFunction
_identifierArgsNum
(fromJust $ HashMap.lookup _identifierSymbol (tab ^. identContext)),
body,
_functionType = translateType _identifierType,
_functionArgsNum = _identifierArgsNum,
_functionArgsInfo = map translateArgInfo (typeArgsBinders _identifierType),
_functionArgsInfo = map translateArgInfo (lambdaBinders body),
_functionIsExported = _identifierIsExported
}
where
body = fromJust $ HashMap.lookup _identifierSymbol (tab ^. identContext)

translateArgInfo :: Binder -> Stripped.ArgumentInfo
translateArgInfo Binder {..} =
Expand Down
9 changes: 5 additions & 4 deletions src/Juvix/Compiler/Store/Core/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Juvix.Compiler.Store.Core.Extra where

import Juvix.Compiler.Core.Data.InfoTable qualified as Core
import Juvix.Compiler.Core.Extra qualified as Core
import Juvix.Compiler.Core.Info.PragmaInfo
import Juvix.Compiler.Store.Core.Data.InfoTable
import Juvix.Compiler.Store.Core.Language

Expand Down Expand Up @@ -72,9 +73,9 @@ toCoreNode = \case
NApp App {..} -> Core.mkApp' (toCoreNode _appLeft) (toCoreNode _appRight)
NBlt BuiltinApp {..} -> Core.mkBuiltinApp' _builtinAppOp (map toCoreNode _builtinAppArgs)
NCtr Constr {..} -> Core.mkConstr' _constrTag (map toCoreNode _constrArgs)
NLam Lambda {..} -> Core.mkLambda mempty (goBinder _lambdaBinder) (toCoreNode _lambdaBody)
NLam Lambda {..} -> Core.mkLambda (setInfoPragma (_lambdaInfo ^. lambdaInfoPragma) mempty) (goBinder _lambdaBinder) (toCoreNode _lambdaBody)
NLet Let {..} -> Core.NLet $ Core.Let mempty (goLetItem _letItem) (toCoreNode _letBody)
NRec LetRec {..} -> Core.NRec $ Core.LetRec mempty (fmap goLetItem _letRecValues) (toCoreNode _letRecBody)
NRec LetRec {..} -> Core.NRec $ Core.LetRec (setInfoPragmas (_letRecInfo ^. letRecInfoPragmas) mempty) (fmap goLetItem _letRecValues) (toCoreNode _letRecBody)
NCase Case {..} -> Core.mkCase' _caseInductive (toCoreNode _caseValue) (map goCaseBranch _caseBranches) (fmap toCoreNode _caseDefault)
NPi Pi {..} -> Core.mkPi mempty (goBinder _piBinder) (toCoreNode _piBody)
NUniv Univ {..} -> Core.mkUniv' _univLevel
Expand Down Expand Up @@ -159,9 +160,9 @@ fromCoreNode = \case
Core.NApp Core.App {..} -> NApp $ App () (fromCoreNode _appLeft) (fromCoreNode _appRight)
Core.NBlt Core.BuiltinApp {..} -> NBlt $ BuiltinApp () _builtinAppOp (map fromCoreNode _builtinAppArgs)
Core.NCtr Core.Constr {..} -> NCtr $ Constr () _constrTag (map fromCoreNode _constrArgs)
Core.NLam Core.Lambda {..} -> NLam $ Lambda () (goBinder _lambdaBinder) (fromCoreNode _lambdaBody)
Core.NLam Core.Lambda {..} -> NLam $ Lambda (LambdaInfo (getInfoPragma _lambdaInfo)) (goBinder _lambdaBinder) (fromCoreNode _lambdaBody)
Core.NLet Core.Let {..} -> NLet $ Let () (goLetItem _letItem) (fromCoreNode _letBody)
Core.NRec Core.LetRec {..} -> NRec $ LetRec () (fmap goLetItem _letRecValues) (fromCoreNode _letRecBody)
Core.NRec Core.LetRec {..} -> NRec $ LetRec (LetRecInfo (getInfoPragmas _letRecInfo)) (fmap goLetItem _letRecValues) (fromCoreNode _letRecBody)
Core.NCase Core.Case {..} -> NCase $ Case () _caseInductive (fromCoreNode _caseValue) (map goCaseBranch _caseBranches) (fmap fromCoreNode _caseDefault)
Core.NPi Core.Pi {..} -> NPi $ Pi () (goBinder _piBinder) (fromCoreNode _piBody)
Core.NUniv Core.Univ {..} -> NUniv $ Univ () _univLevel
Expand Down
Loading
Loading