diff --git a/dhall/src/Dhall/Eval.hs b/dhall/src/Dhall/Eval.hs index 73ba65b7a..c0ff76cc6 100644 --- a/dhall/src/Dhall/Eval.hs +++ b/dhall/src/Dhall/Eval.hs @@ -318,6 +318,12 @@ vCombine mk t u = t' (VRecordLit m, VRecordLit m') -> VRecordLit (Map.unionWith (vCombine Nothing) m m') + (VRecord m, u') | null m -> + u' + (t', VRecord m) | null m -> + t' + (VRecord m, VRecord m') -> + VRecord (Map.unionWith (vCombine Nothing) m m') (t', u') -> VCombine mk t' u' diff --git a/dhall/src/Dhall/Normalize.hs b/dhall/src/Dhall/Normalize.hs index c51377dfd..e17d4859e 100644 --- a/dhall/src/Dhall/Normalize.hs +++ b/dhall/src/Dhall/Normalize.hs @@ -548,15 +548,20 @@ normalizeWithM ctx e0 = loop (Syntax.denote e0) kts' = traverse (traverse loop) kts Combine cs mk x y -> decide <$> loop x <*> loop y where + mergeFields (RecordField _ expr _ _) (RecordField _ expr' _ _) = + Syntax.makeRecordField $ decide expr expr' decide (RecordLit m) r | Data.Foldable.null m = r decide l (RecordLit n) | Data.Foldable.null n = l decide (RecordLit m) (RecordLit n) = - RecordLit (Dhall.Map.unionWith f m n) - where - f (RecordField _ expr _ _) (RecordField _ expr' _ _) = - Syntax.makeRecordField $ decide expr expr' + RecordLit (Dhall.Map.unionWith mergeFields m n) + decide (Record m) r | Data.Foldable.null m = + r + decide l (Record n) | Data.Foldable.null n = + l + decide (Record m) (Record n) = + Record (Dhall.Map.unionWith mergeFields m n) decide l r = Combine cs mk l r CombineTypes cs x y -> decide <$> loop x <*> loop y @@ -949,6 +954,9 @@ isNormalized e0 = loop (Syntax.denote e0) decide (RecordLit m) _ | Data.Foldable.null m = False decide _ (RecordLit n) | Data.Foldable.null n = False decide (RecordLit _) (RecordLit _) = False + decide (Record m) _ | Data.Foldable.null m = False + decide _ (Record n) | Data.Foldable.null n = False + decide (Record _) (Record _) = False decide _ _ = True CombineTypes _ x y -> loop x && loop y && decide x y where diff --git a/dhall/src/Dhall/TypeCheck.hs b/dhall/src/Dhall/TypeCheck.hs index 9c2a30355..792d56f53 100644 --- a/dhall/src/Dhall/TypeCheck.hs +++ b/dhall/src/Dhall/TypeCheck.hs @@ -801,34 +801,22 @@ infer typer = loop Combine _ mk l r -> do _L' <- loop ctx l - let l'' = quote names (eval values l) - - _R' <- loop ctx r - - let r'' = quote names (eval values r) + let _L'' = quote names _L' - xLs' <- case _L' of - VRecord xLs' -> - return xLs' + let l' = eval values l - _ -> do - let _L'' = quote names _L' + let l'' = quote names l' - case mk of - Nothing -> die (MustCombineARecord '∧' l'' _L'') - Just t -> die (InvalidDuplicateField t l _L'') + _R' <- loop ctx r - xRs' <- case _R' of - VRecord xRs' -> - return xRs' + let _R'' = quote names _R' - _ -> do - let _R'' = quote names _R' + let r' = eval values r - case mk of - Nothing -> die (MustCombineARecord '∧' r'' _R'') - Just t -> die (InvalidDuplicateField t r _R'') + let r'' = quote names r' + -- The `Combine` operator should now work on record terms and also on record types. + -- We will use combineTypes or combineTypesCheck below as needed for each case. let combineTypes xs xLs₀' xRs₀' = do let combine x (VRecord xLs₁') (VRecord xRs₁') = combineTypes (x : xs) xLs₁' xRs₁' @@ -845,7 +833,45 @@ infer typer = loop return (VRecord xTs) - combineTypes [] xLs' xRs' + let combineTypesCheck xs xLs₀' xRs₀' = do + let combine x (VRecord xLs₁') (VRecord xRs₁') = + combineTypesCheck (x : xs) xLs₁' xRs₁' + + combine x _ _ = + die (FieldTypeCollision (NonEmpty.reverse (x :| xs))) + + let mL = Dhall.Map.toMap xLs₀' + let mR = Dhall.Map.toMap xRs₀' + + Foldable.sequence_ (Data.Map.intersectionWithKey combine mL mR) + + -- If both sides of `Combine` are record terms, we use combineTypes to figure out the resulting type. + -- If both sides are record types, we use combineTypesCheck and then return the upper bound of two types. + -- Otherwise there is a type error. + case (_L', l', _R', r') of + (VRecord xLs', _, VRecord xRs', _) -> do + combineTypes [] xLs' xRs' + + (VConst cL, VRecord xLs', VConst cR, VRecord xRs') -> do + let c = max cL cR + combineTypesCheck [] xLs' xRs' + return (VConst c) + + (_, _, VRecord _, _) -> do + case mk of + Nothing -> die (MustCombineARecord '∧' l'' _L'') + Just t -> die (InvalidDuplicateField t l _L'') + + (_, _, VConst _, _) -> do + case mk of + Nothing -> die (MustCombineARecord '∧' l'' _L'') + Just t -> die (InvalidDuplicateField t l _L'') + + _ -> do + case mk of + Nothing -> die (MustCombineARecord '∧' r'' _R'') + Just t -> die (InvalidDuplicateField t r _R'') + CombineTypes _ l r -> do _L' <- loop ctx l