Skip to content

Commit

Permalink
fix: use maxType when building expression in expression tree elaborat…
Browse files Browse the repository at this point in the history
…or (#4215)

The expression tree elaborator computes a "maxType" that every leaf term
can be coerced to, but the elaborator was not ensuring that the entire
expression tree would have maxType as its type. This led to unexpected
errors in examples such as
```lean
example (a : Nat) (b : Int) :
  a = id (a * b^2) := sorry
```
where it would say it could not synthesize an `HMul Int Int Nat`
instance (the `Nat` would propagate from the `a` on the LHS of the
equality). The issue in this case is that `HPow` uses default instances,
so while the expression tree elaborator decides that `a * b^2` should be
referring to an `Int`, the actual elaborated type is temporarily a
metavariable. Then, when the binrel elaborator is looking at both sides
of the equality, it decides that `Nat` will work and coercions don't
need to be inserted.

The fix is to unify the type of the resulting elaborated expression with
the computed maxType. One wrinkle is that `hasUncomparable` being false
is a valid test only if there are no leaf terms with unknown types (if
they become known, it could change `hasUncomparable` to true), so this
unification is only performed if the leaf terms all have known types.

Fixes issue described by Floris van Doorn on
[Zulip](https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/elaboration.20issue.20involving.20powers.20and.20sums/near/439243587).
  • Loading branch information
kmill authored May 18, 2024
1 parent 02b6fb3 commit b639d10
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
24 changes: 17 additions & 7 deletions src/Lean/Elab/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,10 @@ private def hasCoe (fromType toType : Expr) : TermElabM Bool := do

private structure AnalyzeResult where
max? : Option Expr := none
hasUncomparable : Bool := false -- `true` if there are two types `α` and `β` where we don't have coercions in any direction.
/-- `true` if there are two types `α` and `β` where we don't have coercions in any direction. -/
hasUncomparable : Bool := false
/-- `true` if there are any leaf terms with an unknown type (according to `isUnknown`). -/
hasUnknown : Bool := false

private def isUnknown : Expr → Bool
| .mvar .. => true
Expand All @@ -255,7 +258,7 @@ private def analyze (t : Tree) (expectedType? : Option Expr) : TermElabM Analyze
match expectedType? with
| none => pure none
| some expectedType =>
let expectedType ← instantiateMVars expectedType
let expectedType := (← instantiateMVars expectedType).cleanupAnnotations
if isUnknown expectedType then pure none else pure (some expectedType)
(go t *> get).run' { max? }
where
Expand All @@ -268,8 +271,10 @@ where
| .binop _ _ _ lhs rhs => go lhs; go rhs
| .unop _ _ arg => go arg
| .term _ _ val =>
let type ← instantiateMVars (← inferType val)
unless isUnknown type do
let type := (← instantiateMVars (← inferType val)).cleanupAnnotations
if isUnknown type then
modify fun s => { s with hasUnknown := true }
else
match (← get).max? with
| none => modify fun s => { s with max? := type }
| some max =>
Expand Down Expand Up @@ -430,7 +435,7 @@ mutual
| .unop ref f arg =>
return .unop ref f (← go arg none false false)
| .term ref trees e =>
let type ← instantiateMVars (← inferType e)
let type := (← instantiateMVars (← inferType e)).cleanupAnnotations
trace[Elab.binop] "visiting {e} : {type} =?= {maxType}"
if isUnknown type then
if let some f := f? then
Expand All @@ -448,12 +453,17 @@ mutual

private partial def toExpr (tree : Tree) (expectedType? : Option Expr) : TermElabM Expr := do
let r ← analyze tree expectedType?
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
trace[Elab.binop] "hasUncomparable: {r.hasUncomparable}, hasUnknown: {r.hasUnknown}, maxType: {r.max?}"
if r.hasUncomparable || r.max?.isNone then
let result ← toExprCore tree
ensureHasType expectedType? result
else
let result ← toExprCore (← applyCoe tree r.max?.get! (isPred := false))
unless r.hasUnknown do
-- Record the resulting maxType calculation.
-- We can do this when all the types are known, since in this case `hasUncomparable` is valid.
-- If they're not known, recording maxType like this can lead to heterogeneous operations failing to elaborate.
discard <| isDefEqGuarded (← inferType result) r.max?.get!
trace[Elab.binop] "result: {result}"
ensureHasType expectedType? result

Expand Down Expand Up @@ -519,7 +529,7 @@ def elabBinRelCore (noProp : Bool) (stx : Syntax) (expectedType? : Option Expr)
let rhs ← withRef rhsStx <| toTree rhsStx
let tree := .binop stx .regular f lhs rhs
let r ← analyze tree none
trace[Elab.binrel] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
trace[Elab.binrel] "hasUncomparable: {r.hasUncomparable}, hasUnknown: {r.hasUnknown}, maxType: {r.max?}"
if r.hasUncomparable || r.max?.isNone then
-- Use default elaboration strategy + `toBoolIfNecessary`
let lhs ← toExprCore lhs
Expand Down
11 changes: 0 additions & 11 deletions tests/lean/binopIssues.lean

This file was deleted.

4 changes: 0 additions & 4 deletions tests/lean/binopIssues.lean.expected.out

This file was deleted.

41 changes: 41 additions & 0 deletions tests/lean/run/binop.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/-!
# Tests for the expression tree elaborator (`binop%`, etc.)
-/

/-!
Some basic Int/Nat examples
-/

example (n : Nat) (i : Int) : n + i = i + n := by
rw [Int.add_comm]

def f1 (a : Int) (b c : Nat) : Int :=
a + (b - c)

def f2 (a : Int) (b c : Nat) : Int :=
(b - c) + a

/--
info: def f1 : Int → Nat → Nat → Int :=
fun a b c => a + (↑b - ↑c)
-/
#guard_msgs in
#print f1

/--
info: def f2 : Int → Nat → Nat → Int :=
fun a b c => ↑b - ↑c + a
-/
#guard_msgs in
#print f2


/-!
Interaction with default instances for pow. This used to fail with not being able
to synthesize an `HMul Int Int Nat` instance because the type of
the result of `*` wasn't being set to `Int`.
-/

/-- info: ∀ (a : Nat) (b : Int), ↑a = id (↑a * b ^ 2) : Prop -/
#guard_msgs in
#check ∀ (a : Nat) (b : Int), a = id (a * b^2)

0 comments on commit b639d10

Please sign in to comment.