From bf9ebb9a5d75390362e6118c4b511de3e0262a90 Mon Sep 17 00:00:00 2001 From: Jan Mas Rovira Date: Fri, 7 Jun 2024 18:09:41 +0200 Subject: [PATCH] add NormalizedExpression --- .../Compiler/Core/Translation/FromInternal.hs | 2 +- src/Juvix/Compiler/Internal/Language.hs | 10 +++ .../Internal/Translation/FromInternal.hs | 2 +- .../Analysis/Positivity/Checker.hs | 2 +- .../Analysis/TypeChecking/CheckerNew.hs | 66 ++++++++++--------- .../Analysis/TypeChecking/Data/Inference.hs | 49 +++++++++++--- .../Analysis/TypeChecking/Error/Types.hs | 60 +++++++++++------ .../Analysis/TypeChecking/Traits/Resolver.hs | 6 +- 8 files changed, 132 insertions(+), 65 deletions(-) diff --git a/src/Juvix/Compiler/Core/Translation/FromInternal.hs b/src/Juvix/Compiler/Core/Translation/FromInternal.hs index bba87ca07e..d0c19ac596 100644 --- a/src/Juvix/Compiler/Core/Translation/FromInternal.hs +++ b/src/Juvix/Compiler/Core/Translation/FromInternal.hs @@ -362,7 +362,7 @@ goType :: Sem r Type goType ty = do normTy <- InternalTyped.strongNormalize'' ty - squashApps <$> goExpression normTy + squashApps <$> goExpression (normTy ^. Internal.normalizedExpression) mkFunBody :: forall r. diff --git a/src/Juvix/Compiler/Internal/Language.hs b/src/Juvix/Compiler/Internal/Language.hs index b6d4f7dc1d..c6e042fbfc 100644 --- a/src/Juvix/Compiler/Internal/Language.hs +++ b/src/Juvix/Compiler/Internal/Language.hs @@ -428,6 +428,12 @@ newtype ModuleIndex = ModuleIndex } deriving stock (Data) +-- | An expression that maybe has been normalized +data NormalizedExpression = NormalizedExpression + { _normalizedExpression :: Expression, + _normalizedExpressionOriginal :: Expression + } + makeLenses ''ModuleIndex makeLenses ''ArgInfo makeLenses ''WildcardConstructor @@ -454,6 +460,10 @@ makeLenses ''FunctionParameter makeLenses ''InductiveParameter makeLenses ''ConstructorDef makeLenses ''ConstructorApp +makeLenses ''NormalizedExpression + +instance HasLoc NormalizedExpression where + getLoc = getLoc . (^. normalizedExpressionOriginal) instance Eq ModuleIndex where (==) = (==) `on` (^. moduleIxModule . moduleName) diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal.hs index 6f75820379..00c484fd2c 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal.hs @@ -38,7 +38,7 @@ typeCheckExpressionType exp = do . mapError (JuvixError @TypeCheckerError) . runInferenceDef $ inferExpression Nothing exp - >>= traverseOf typedType strongNormalize + >>= traverseOf typedType strongNormalize_ typeCheckExpression :: (Members '[Error JuvixError, State Artifacts, Termination] r) => diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Positivity/Checker.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Positivity/Checker.hs index 807fb42ed5..e1af4e12b4 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Positivity/Checker.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/Positivity/Checker.hs @@ -64,7 +64,7 @@ checkStrictlyPositiveOccurrences :: CheckPositivityArgs -> Sem r () checkStrictlyPositiveOccurrences p = do - typeOfConstr <- strongNormalize (p ^. checkPositivityArgsTypeOfConstructor) + typeOfConstr <- strongNormalize_ (p ^. checkPositivityArgsTypeOfConstructor) go False typeOfConstr where indInfo = p ^. checkPositivityArgsInductive diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/CheckerNew.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/CheckerNew.hs index e5001b7456..18c7708567 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/CheckerNew.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/CheckerNew.hs @@ -369,12 +369,11 @@ checkInstanceType :: checkInstanceType FunctionDef {..} = do ty <- strongNormalize _funDefType let mi = - instanceFromTypedExpression - ( TypedExpression - { _typedType = ty, - _typedExpression = ExpressionIden (IdenFunction _funDefName) - } - ) + instanceFromTypedExpression $ + TypedExpression + { _typedType = ty ^. normalizedExpression, + _typedExpression = ExpressionIden (IdenFunction _funDefName) + } case mi of Just ii@InstanceInfo {..} -> do tab <- ask @@ -401,10 +400,16 @@ checkInstanceType FunctionDef {..} = do _ -> throw (ErrNotATrait (NotATrait _paramType)) -checkInstanceParam :: (Member (Error TypeCheckerError) r) => InfoTable -> Expression -> Sem r () -checkInstanceParam tab ty = case traitFromExpression mempty ty of +checkInstanceParam :: + (Member (Error TypeCheckerError) r) => + InfoTable -> + NormalizedExpression -> + Sem r () +checkInstanceParam tab ty' = case traitFromExpression mempty ty of Just InstanceApp {..} | isTrait tab _instanceAppHead -> return () _ -> throw (ErrNotATrait (NotATrait ty)) + where + ty = ty' ^. normalizedExpression checkCoercionType :: forall r. @@ -416,7 +421,7 @@ checkCoercionType FunctionDef {..} = do let mi = coercionFromTypedExpression ( TypedExpression - { _typedType = ty, + { _typedType = ty ^. normalizedExpression, _typedExpression = ExpressionIden (IdenFunction _funDefName) } ) @@ -456,11 +461,16 @@ checkExpression expectedTy e = do e' <- strongNormalize e inferred' <- strongNormalize (inferred ^. typedType) expected' <- strongNormalize expectedTy + let thing = + WrongTypeThingExpression + MkWrongTypeThingExpression + { _wrongTypeNormalizedExpression = e', + _wrongTypeInferredExpression = inferred ^. typedExpression + } throw . ErrWrongType $ WrongType - { _wrongTypeThing = Left e', - _wrongTypeThingWithHoles = Just (Left (inferred ^. typedExpression)), + { _wrongTypeThing = thing, _wrongTypeActual = inferred', _wrongTypeExpected = expected' } @@ -497,7 +507,7 @@ checkFunctionParameter FunctionParameter {..} = do checkInstanceParam tab ty' return FunctionParameter - { _paramType = ty', + { _paramType = ty' ^. normalizedExpression, _paramName, _paramImplicit } @@ -728,15 +738,13 @@ checkPattern = go constrName = a ^. constrAppConstructor err :: MatchError -> Sem r () err m = - throw - ( ErrWrongType - WrongType - { _wrongTypeThing = Right pat, - _wrongTypeThingWithHoles = Nothing, - _wrongTypeExpected = m ^. matchErrorRight, - _wrongTypeActual = m ^. matchErrorLeft - } - ) + throw $ + ErrWrongType + WrongType + { _wrongTypeThing = WrongTypeThingPattern pat, + _wrongTypeExpected = m ^. matchErrorRight, + _wrongTypeActual = m ^. matchErrorLeft + } case s of Left hole -> do let indParams = info ^. constructorInfoInductiveParameters @@ -754,15 +762,13 @@ checkPattern = go Right (ind, tyArgs) -> do when (ind /= constrIndName) - ( throw - ( ErrWrongConstructorType - WrongConstructorType - { _wrongCtorTypeName = constrName, - _wrongCtorTypeExpected = ind, - _wrongCtorTypeActual = constrIndName - } - ) - ) + $ throw + $ ErrWrongConstructorType + WrongConstructorType + { _wrongCtorTypeName = constrName, + _wrongCtorTypeExpected = ind, + _wrongCtorTypeActual = constrIndName + } PatternConstructorApp <$> goConstr (IdenInductive ind) a tyArgs goConstr :: Iden -> ConstructorApp -> [(InductiveParameter, Expression)] -> Sem r ConstructorApp diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Data/Inference.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Data/Inference.hs index ae5c585bbd..7b249d242d 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Data/Inference.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Data/Inference.hs @@ -8,6 +8,7 @@ module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.TypeChecking.Da queryMetavar, registerIdenType, strongNormalize'', + strongNormalize_, iniState, strongNormalize, weakNormalize, @@ -35,8 +36,8 @@ data MetavarState Refined Expression data MatchError = MatchError - { _matchErrorLeft :: Expression, - _matchErrorRight :: Expression + { _matchErrorLeft :: NormalizedExpression, + _matchErrorRight :: NormalizedExpression } makeLenses ''MatchError @@ -46,7 +47,7 @@ data Inference :: Effect where QueryMetavar :: Hole -> Inference m (Maybe Expression) RegisterIdenType :: Name -> Expression -> Inference m () RememberFunctionDef :: FunctionDef -> Inference m () - StrongNormalize :: Expression -> Inference m Expression + StrongNormalize :: Expression -> Inference m NormalizedExpression WeakNormalize :: Expression -> Inference m Expression makeSem ''Inference @@ -131,9 +132,18 @@ queryMetavarFinal h = do Just (ExpressionHole h') -> queryMetavarFinal h' _ -> return m +strongNormalize_ :: (Members '[Inference] r) => Expression -> Sem r Expression +strongNormalize_ = fmap (^. normalizedExpression) . strongNormalize + -- FIXME the returned expression should have the same location as the original -strongNormalize' :: forall r. (Members '[ResultBuilder, State InferenceState, NameIdGen] r) => Expression -> Sem r Expression -strongNormalize' = go +strongNormalize' :: forall r. (Members '[ResultBuilder, State InferenceState, NameIdGen] r) => Expression -> Sem r NormalizedExpression +strongNormalize' original = do + normalized <- go original + return + NormalizedExpression + { _normalizedExpression = normalized, + _normalizedExpressionOriginal = original + } where go :: Expression -> Sem r Expression go e = case e of @@ -362,14 +372,30 @@ runInferenceState inis = reinterpret (runState inis) $ \case where ok :: Sem r (Maybe MatchError) ok = return Nothing + check :: Bool -> Sem r (Maybe MatchError) check b | b = ok | otherwise = err + bicheck :: Sem r (Maybe MatchError) -> Sem r (Maybe MatchError) -> Sem r (Maybe MatchError) bicheck = liftA2 (<|>) + + normalizedB = + NormalizedExpression + { _normalizedExpression = normB, + _normalizedExpressionOriginal = inputB + } + + normalizedA = + NormalizedExpression + { _normalizedExpression = normA, + _normalizedExpressionOriginal = inputA + } + err :: Sem r (Maybe MatchError) - err = return (Just (MatchError normA normB)) + err = return (Just (MatchError normalizedA normalizedB)) + goHole :: Hole -> Expression -> Sem r (Maybe MatchError) goHole h t = do r <- queryMetavar' h @@ -382,7 +408,7 @@ runInferenceState inis = reinterpret (runState inis) $ \case | ExpressionHole h' <- holTy, h' == hol = return () | otherwise = do - holTy' <- strongNormalize' holTy + holTy' <- (^. normalizedExpression) <$> strongNormalize' holTy let er = ErrUnsolvedMeta UnsolvedMeta @@ -392,7 +418,7 @@ runInferenceState inis = reinterpret (runState inis) $ \case when (LeafExpressionHole hol `elem` holTy' ^.. leafExpressions) (throw er) s <- gets (fromJust . (^. inferenceMap . at hol)) case s of - Fresh -> modify (over inferenceMap (HashMap.insert hol (Refined holTy'))) + Fresh -> modify (set (inferenceMap . at hol) (Just (Refined holTy'))) Refined {} -> impossible goIden :: Iden -> Iden -> Sem r (Maybe MatchError) @@ -521,7 +547,7 @@ functionDefEval f = do return r where strongNorm :: (Members '[ResultBuilder, NameIdGen] r) => Expression -> Sem r Expression - strongNorm = evalState iniState . strongNormalize' + strongNorm = evalState iniState . fmap (^. normalizedExpression) . strongNormalize' isUniverse :: Expression -> Bool isUniverse = \case @@ -583,7 +609,10 @@ registerFunctionDef :: (Members '[ResultBuilder, Error TypeCheckerError, NameIdG registerFunctionDef f = whenJustM (functionDefEval f) $ \e -> addFunctionDef (f ^. funDefName) e -strongNormalize'' :: (Members '[Reader FunctionsTable, NameIdGen] r) => Expression -> Sem r Expression +strongNormalize'' :: + (Members '[Reader FunctionsTable, NameIdGen] r) => + Expression -> + Sem r NormalizedExpression strongNormalize'' ty = do ftab <- ask let importCtx = diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Error/Types.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Error/Types.hs index 2518b40448..86462bd544 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Error/Types.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Error/Types.hs @@ -159,18 +159,36 @@ instance ToGenericError WrongConstructorAppArgs where pat :: Int -> Doc ann pat n = pretty n <+> plural "pattern" "patterns" n +data WrongTypeThing + = WrongTypeThingPattern Pattern + | WrongTypeThingExpression WrongTypeThingExpression + +data WrongTypeThingExpression = MkWrongTypeThingExpression + { _wrongTypeNormalizedExpression :: NormalizedExpression, + _wrongTypeInferredExpression :: Expression + } + -- | the type of an expression does not match the inferred type data WrongType = WrongType - { _wrongTypeThing :: Either Expression Pattern, - _wrongTypeThingWithHoles :: Maybe (Either Expression Pattern), - _wrongTypeExpected :: Expression, - _wrongTypeActual :: Expression + { _wrongTypeThing :: WrongTypeThing, + _wrongTypeExpected :: NormalizedExpression, + _wrongTypeActual :: NormalizedExpression } makeLenses ''WrongType +makeLenses ''WrongTypeThingExpression + +instance HasLoc WrongTypeThing where + getLoc = \case + WrongTypeThingPattern p -> getLoc p + WrongTypeThingExpression e -> getLoc e +instance HasLoc WrongTypeThingExpression where + getLoc = getLoc . (^. wrongTypeNormalizedExpression) + +-- TODO we should show both the normalized and original version of the expression when relevant. instance ToGenericError WrongType where - genericError e = ask >>= generr + genericError err = ask >>= generr where generr opts = return @@ -181,24 +199,28 @@ instance ToGenericError WrongType where } where opts' = fromGenericOptions opts - i = either getLoc getLoc (e ^. wrongTypeThing) + i = getLoc (err ^. wrongTypeThing) msg = "The" + <+> thingName <+> thing - <+> either (ppCode opts') (ppCode opts') subjectThing <+> "has type:" <> line - <> indent' (ppCode opts' (e ^. wrongTypeActual)) + <> indent' (ppCode opts' (err ^. wrongTypeActual ^. normalizedExpression)) <> line <> "but is expected to have type:" <> line - <> indent' (ppCode opts' (e ^. wrongTypeExpected)) - thing :: Doc a - thing = case subjectThing of - Left {} -> "expression" - Right {} -> "pattern" - subjectThing :: Either Expression Pattern - subjectThing = fromMaybe (e ^. wrongTypeThing) (e ^. wrongTypeThingWithHoles) + <> indent' (ppCode opts' (err ^. wrongTypeExpected . normalizedExpression)) + + thingName :: Doc a + thingName = case err ^. wrongTypeThing of + WrongTypeThingExpression {} -> "expression" + WrongTypeThingPattern {} -> "pattern" + + thing :: Doc CodeAnn + thing = case err ^. wrongTypeThing of + WrongTypeThingExpression e -> ppCode opts' (e ^. wrongTypeInferredExpression) + WrongTypeThingPattern p -> ppCode opts' p -- | The left hand expression of a function application is not -- a function type. @@ -463,7 +485,7 @@ instance ToGenericError CoercionCycles where <> indent' (hsep (ppCode opts' <$> take 10 (toList (e ^. coercionCycles)))) data NoInstance = NoInstance - { _noInstanceType :: Expression, + { _noInstanceType :: NormalizedExpression, _noInstanceLoc :: Interval } @@ -484,10 +506,10 @@ instance ToGenericError NoInstance where i = e ^. noInstanceLoc msg = "No trait instance found for:" - <+> ppCode opts' (e ^. noInstanceType) + <+> ppCode opts' (e ^. noInstanceType . normalizedExpression) data AmbiguousInstances = AmbiguousInstances - { _ambiguousInstancesType :: Expression, + { _ambiguousInstancesType :: NormalizedExpression, _ambiguousInstancesInfos :: [InstanceInfo], _ambiguousInstancesLoc :: Interval } @@ -510,7 +532,7 @@ instance ToGenericError AmbiguousInstances where locs = itemize $ map (pretty . getLoc . (^. instanceInfoResult)) (e ^. ambiguousInstancesInfos) msg = "Multiple trait instances found for" - <+> ppCode opts' (e ^. ambiguousInstancesType) + <+> ppCode opts' (e ^. ambiguousInstancesType . normalizedExpression) <> line <> "Matching instances found at:" <> line diff --git a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Traits/Resolver.hs b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Traits/Resolver.hs index 872e9c8da2..2e4b2c2964 100644 --- a/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Traits/Resolver.hs +++ b/src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/Traits/Resolver.hs @@ -32,13 +32,13 @@ resolveTraitInstance :: TypedHole -> Sem r Expression resolveTraitInstance TypedHole {..} = do - vars <- overM localTypes (mapM strongNormalize) _typedHoleLocalVars + vars <- overM localTypes (mapM strongNormalize_) _typedHoleLocalVars infoTab <- ask tab0 <- getCombinedInstanceTable let tab = foldr (flip updateInstanceTable) tab0 (varsToInstances infoTab vars) ty <- strongNormalize _typedHoleType ctab <- getCombinedCoercionTable - is <- lookupInstance ctab tab ty + is <- lookupInstance ctab tab (ty ^. normalizedExpression) case is of [(cs, ii, subs)] -> expandArity loc (subsIToE subs) (ii ^. instanceInfoArgs) (ii ^. instanceInfoResult) @@ -257,7 +257,7 @@ lookupInstance :: InstanceTable -> Expression -> Sem r [(CoercionChain, InstanceInfo, SubsI)] -lookupInstance ctab tab ty = do +lookupInstance ctab tab ty = case traitFromExpression mempty ty of Just InstanceApp {..} -> lookupInstance' [] False ctab tab _instanceAppHead _instanceAppArgs