From 4442b927dec87fa73e059678004a4c3ece860449 Mon Sep 17 00:00:00 2001 From: flupe Date: Thu, 14 Dec 2023 14:59:03 +0100 Subject: [PATCH] add inlining pragma --- src/Agda2Hs/Compile.hs | 5 ++- src/Agda2Hs/Compile/Function.hs | 27 +++++++++++++- src/Agda2Hs/Compile/Term.hs | 66 +++++++++++++++++++++++++++------ src/Agda2Hs/Compile/Utils.hs | 11 ++++-- src/Agda2Hs/HsUtils.hs | 2 +- src/Agda2Hs/Pragma.hs | 4 +- test/AllTests.agda | 2 + test/Inlining.agda | 35 +++++++++++++++++ test/golden/AllTests.hs | 1 + test/golden/Inlining.hs | 14 +++++++ 10 files changed, 147 insertions(+), 20 deletions(-) create mode 100644 test/Inlining.agda create mode 100644 test/golden/Inlining.hs diff --git a/src/Agda2Hs/Compile.hs b/src/Agda2Hs/Compile.hs index 682d24fc..e90d592c 100644 --- a/src/Agda2Hs/Compile.hs +++ b/src/Agda2Hs/Compile.hs @@ -9,6 +9,7 @@ import qualified Data.Map as M import Agda.Compiler.Backend import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName ) import Agda.TypeChecking.Pretty +import Agda.TypeChecking.Monad.Signature ( isInlineFun ) import Agda.Utils.Null import Agda.Utils.Monad ( whenM ) @@ -16,7 +17,7 @@ import qualified Language.Haskell.Exts.Extension as Hs import Agda2Hs.Compile.ClassInstance ( compileInstance ) import Agda2Hs.Compile.Data ( compileData ) -import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma ) +import Agda2Hs.Compile.Function ( compileFun, checkTransparentPragma, checkInlinePragma ) import Agda2Hs.Compile.Postulate ( compilePostulate ) import Agda2Hs.Compile.Record ( compileRecord, checkUnboxPragma ) import Agda2Hs.Compile.Types @@ -91,6 +92,8 @@ compile opts tlm _ def = withCurrentModule (qnameModule $ defName def) $ runC tl tag <$> compileFun True def (DefaultPragma ds, _, Record{}) -> tag . single <$> compileRecord (ToRecord ds) def + (InlinePragma, _, Function{}) -> do + checkInlinePragma def >> return [] _ -> genericDocError =<< do text "Don't know how to compile" <+> prettyTCM (defName def) diff --git a/src/Agda2Hs/Compile/Function.hs b/src/Agda2Hs/Compile/Function.hs index 2124d9ef..558ffc71 100644 --- a/src/Agda2Hs/Compile/Function.hs +++ b/src/Agda2Hs/Compile/Function.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE OverloadedStrings, ViewPatterns, NamedFieldPuns #-} module Agda2Hs.Compile.Function where import Control.Monad ( (>=>), filterM, forM_ ) @@ -26,7 +26,7 @@ import Agda.TypeChecking.Substitute import Agda.TypeChecking.Telescope ( telView ) import Agda.TypeChecking.Sort ( ifIsSort ) -import Agda.Utils.Functor ( (<&>) ) +import Agda.Utils.Functor ( (<&>), dget) import Agda.Utils.Impossible ( __IMPOSSIBLE__ ) import Agda.Utils.List import Agda.Utils.Maybe @@ -294,3 +294,26 @@ checkTransparentPragma def = compileFun False def >>= \case errNotTransparent = genericDocError =<< "Cannot make function" <+> prettyTCM (defName def) <+> "transparent." <+> "A transparent function must have exactly one non-erased argument and return it unchanged." + +checkInlinePragma :: Definition -> C () +checkInlinePragma def@Defn{defName = f} = do + let Function{funClauses = cs} = theDef def + case filter (isJust . clauseBody) cs of + [c] -> do + let Clause{clauseTel,namedClausePats = naps} = c + unlessM (allM (dget . dget <$> naps) allowedPat) $ genericDocError =<< + "Cannot make function" <+> prettyTCM (defName def) <+> "inlinable." <+> + "Inline functions can only use variable patterns, dot patterns, or transparent record constructor patterns." + _ -> + genericDocError =<< + "Cannot make function" <+> prettyTCM f <+> "inlinable." <+> + "An inline function must have exactly one clause." + where allowedPat :: DeBruijnPattern -> C Bool + allowedPat VarP{} = pure True + allowedPat DotP{} = pure True + -- only allow matching on (unboxed) record constructors + allowedPat (ConP ch ci cargs) = + isUnboxConstructor (conName ch) >>= \case + Just _ -> allM cargs (allowedPat . dget . dget) + Nothing -> pure False + allowedPat _ = pure False diff --git a/src/Agda2Hs/Compile/Term.hs b/src/Agda2Hs/Compile/Term.hs index e13e615b..fda6bd42 100644 --- a/src/Agda2Hs/Compile/Term.hs +++ b/src/Agda2Hs/Compile/Term.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ViewPatterns, NamedFieldPuns #-} module Agda2Hs.Compile.Term where import Control.Arrow ( (>>>), (&&&) ) @@ -7,6 +8,7 @@ import Control.Monad.Reader import Data.List ( isPrefixOf ) import Data.Maybe ( fromMaybe, isJust ) import qualified Data.Text as Text ( unpack ) +import qualified Data.Set as Set (singleton) import qualified Language.Haskell.Exts as Hs @@ -18,8 +20,8 @@ import Agda.Syntax.Internal import Agda.TypeChecking.Monad import Agda.TypeChecking.Pretty -import Agda.TypeChecking.Reduce ( instantiate ) -import Agda.TypeChecking.Substitute ( Apply(applyE) ) +import Agda.TypeChecking.Reduce ( instantiate, unfoldDefinitionStep ) +import Agda.TypeChecking.Substitute ( Apply(applyE), raise, mkAbs ) import Agda.Utils.Lens @@ -228,11 +230,9 @@ compileTerm v = do | Just semantics <- isSpecialTerm f -> do reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of special function" semantics f es - | otherwise -> isClassFunction f >>= \case - True -> compileClassFunApp f es - False -> (isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f >>= \case - True -> compileErasedApp es - False -> do + | otherwise -> ifM (isClassFunction f) (compileClassFunApp f es) $ do + ifM ((isJust <$> isUnboxProjection f) `or2M` isTransparentFunction f) (compileErasedApp es) $ do + ifM (isInlinedFunction f) (compileInlineFunctionApp f es) $ do reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of regular function" -- Drop module parameters of local `where` functions moduleArgs <- getDefFreeVars f @@ -281,14 +281,56 @@ compileTerm v = do Just _ -> compileErasedApp es Nothing -> (`app` es) . Hs.Con () =<< compileQName (conName h) --- `compileErasedApp` compiles an application of an erased constructor --- or projection. +-- `compileErasedApp` compiles an application of an unboxed constructor +-- or unboxed projection or transparent function. +-- Precondition is that at most one elim is preserved. compileErasedApp :: Elims -> C (Hs.Exp ()) compileErasedApp es = do - reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of erased function" + reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of transparent function or erased unboxed constructor" compileElims es >>= \case - [] -> return $ hsVar "id" - (v:vs) -> return $ v `eApp` vs + [] -> return $ hsVar "id" + [v] -> return v + _ -> __IMPOSSIBLE__ + +-- | Compile the application of a function definition marked as inlinable. +-- The provided arguments will get substituted in the function body, and the missing arguments +-- will get quantified with lambdas. +compileInlineFunctionApp :: QName -> Elims -> C (Hs.Exp ()) +compileInlineFunctionApp f es = do + reportSDoc "agda2hs.compile.term" 12 $ text "Compiling application of inline function" + Function { funClauses = cs } <- theDef <$> getConstInfo f + let [ Clause { namedClausePats = pats + , clauseBody = Just body + , clauseTel + } ] = filter (isJust . clauseBody) cs + etaExpand (drop (length es) pats) es >>= compileTerm + where + -- inline functions can only have transparent constructor patterns and variable patterns + extractPatName :: DeBruijnPattern -> ArgName + extractPatName (VarP _ v) = dbPatVarName v + extractPatName (ConP _ _ args) = + let arg = namedThing $ unArg $ head $ filter (usableModality `and2M` visible) args + in extractPatName arg + extractPatName _ = __IMPOSSIBLE__ + + extractName :: NamedArg DeBruijnPattern -> ArgName + extractName (unArg -> np) + | Just n <- nameOf np = rangedThing (woThing n) + | otherwise = extractPatName (namedThing np) + + etaExpand :: NAPs -> Elims -> C Term + etaExpand [] es = do + r <- liftReduce + $ locallyReduceDefs (OnlyReduceDefs $ Set.singleton f) + $ unfoldDefinitionStep (Def f es) f es + case r of + YesReduction _ t -> pure t + _ -> genericDocError =<< text "Could not reduce inline function" <+> prettyTCM f + + etaExpand (p:ps) es = + let ai = argInfo p in + Lam ai . mkAbs (extractName p) + <$> etaExpand ps (raise 1 es ++ [ Apply $ Arg ai $ var 0 ]) -- `compileClassFunApp` is used when we have a record projection and we want to -- drop the first visible arg (the record) diff --git a/src/Agda2Hs/Compile/Utils.hs b/src/Agda2Hs/Compile/Utils.hs index 349ea771..ccf89be1 100644 --- a/src/Agda2Hs/Compile/Utils.hs +++ b/src/Agda2Hs/Compile/Utils.hs @@ -191,9 +191,14 @@ isTransparentFunction :: QName -> C Bool isTransparentFunction q = do getConstInfo q >>= \case Defn{defName = r, theDef = Function{}} -> - processPragma r <&> \case - TransparentPragma -> True - _ -> False + (TransparentPragma ==) <$> processPragma r + _ -> return False + +isInlinedFunction :: QName -> C Bool +isInlinedFunction q = do + getConstInfo q >>= \case + Defn{defName = r, theDef = Function{}} -> + (InlinePragma ==) <$> processPragma r _ -> return False checkInstance :: Term -> C () diff --git a/src/Agda2Hs/HsUtils.hs b/src/Agda2Hs/HsUtils.hs index e7426c6d..328b327a 100644 --- a/src/Agda2Hs/HsUtils.hs +++ b/src/Agda2Hs/HsUtils.hs @@ -267,4 +267,4 @@ patToExp = \case _ -> Nothing data Strictness = Lazy | Strict - deriving Show + deriving (Eq, Show) diff --git a/src/Agda2Hs/Pragma.hs b/src/Agda2Hs/Pragma.hs index 0f32fac4..6321b221 100644 --- a/src/Agda2Hs/Pragma.hs +++ b/src/Agda2Hs/Pragma.hs @@ -48,6 +48,7 @@ getForeignPragmas exts = do data ParsedPragma = NoPragma + | InlinePragma | DefaultPragma [Hs.Deriving ()] | ClassPragma [String] | ExistingClassPragma @@ -55,7 +56,7 @@ data ParsedPragma | TransparentPragma | NewTypePragma [Hs.Deriving ()] | DerivePragma (Maybe (Hs.DerivStrategy ())) - deriving Show + deriving (Eq, Show) derivePragma :: String derivePragma = "derive" @@ -85,6 +86,7 @@ processPragma qn = liftTCM (getUniqueCompilerPragma pragmaName qn) >>= \case Nothing -> return NoPragma Just (CompilerPragma _ s) | "class" `isPrefixOf` s -> return $ ClassPragma (words $ drop 5 s) + | s == "inline" -> return InlinePragma | s == "existing-class" -> return ExistingClassPragma | s == "unboxed" -> return $ UnboxPragma Lazy | s == "unboxed-strict" -> return $ UnboxPragma Strict diff --git a/test/AllTests.agda b/test/AllTests.agda index 633901ff..5b327cf4 100644 --- a/test/AllTests.agda +++ b/test/AllTests.agda @@ -63,6 +63,7 @@ import Issue210 import ModuleParameters import ModuleParametersImports import Coerce +import Inlining {-# FOREIGN AGDA2HS import Issue14 @@ -126,4 +127,5 @@ import Issue210 import ModuleParameters import ModuleParametersImports import Coerce +import Inlining #-} diff --git a/test/Inlining.agda b/test/Inlining.agda new file mode 100644 index 00000000..a0d6aab4 --- /dev/null +++ b/test/Inlining.agda @@ -0,0 +1,35 @@ +module Inlining where + +open import Haskell.Prelude + +record Wrap (a : Set) : Set where + constructor Wrapped + field + unwrap : a +open Wrap public +{-# COMPILE AGDA2HS Wrap unboxed #-} + +mapWrap : (f : a → b) → Wrap a → Wrap b +mapWrap f (Wrapped x) = Wrapped (f x) +{-# COMPILE AGDA2HS mapWrap inline #-} + +mapWrap2 : (f : a → b → c) → Wrap a → Wrap b → Wrap c +mapWrap2 f (Wrapped x) (Wrapped y) = Wrapped (f x y) +{-# COMPILE AGDA2HS mapWrap2 inline #-} + +test1 : Wrap Int → Wrap Int +test1 x = mapWrap (1 +_) x +{-# COMPILE AGDA2HS test1 #-} + +test2 : Wrap Int → Wrap Int → Wrap Int +test2 x y = mapWrap2 _+_ x y +{-# COMPILE AGDA2HS test2 #-} + +-- partial application of inline function +test3 : Wrap Int → Wrap Int → Wrap Int +test3 x = mapWrap2 _+_ x +{-# COMPILE AGDA2HS test3 #-} + +test4 : Wrap Int → Wrap Int → Wrap Int +test4 = mapWrap2 _+_ +{-# COMPILE AGDA2HS test4 #-} diff --git a/test/golden/AllTests.hs b/test/golden/AllTests.hs index d8478991..4098b4f0 100644 --- a/test/golden/AllTests.hs +++ b/test/golden/AllTests.hs @@ -61,4 +61,5 @@ import Issue210 import ModuleParameters import ModuleParametersImports import Coerce +import Inlining diff --git a/test/golden/Inlining.hs b/test/golden/Inlining.hs new file mode 100644 index 00000000..47c763cd --- /dev/null +++ b/test/golden/Inlining.hs @@ -0,0 +1,14 @@ +module Inlining where + +test1 :: Int -> Int +test1 x = 1 + x + +test2 :: Int -> Int -> Int +test2 x y = x + y + +test3 :: Int -> Int -> Int +test3 x = \ y -> x + y + +test4 :: Int -> Int -> Int +test4 = \ x y -> x + y +