Skip to content

Commit

Permalink
fix: simp usedSimps (#3821)
Browse files Browse the repository at this point in the history
When `discharge?` failed, the `usedSimps` was being restored, but the
cache wasn't. This bug was exposed by issue #3710.

This PR makes the following changes:
- We restore the `cache` at `discharge?`. We use `SMap` to ensure the
operation is efficient.
- We don't need the field `dischargeDepth` anymore at `Simp.Result`.
- `UsedSimps` should use `PHashMap` since it is not used linearly.

closes #3710

---------

Co-authored-by: Mario Carneiro <[email protected]>
  • Loading branch information
leodemoura and digama0 authored Apr 2, 2024
1 parent 0684c95 commit f35fc18
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/Lean/Data/PersistentHashMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ def map {α : Type u} {β : Type v} {σ : Type u} {_ : BEq α} {_ : Hashable α}
def toList {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : List (α × β) :=
m.foldl (init := []) fun ps k v => (k, v) :: ps

def toArray {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : Array (α × β) :=
m.foldl (init := #[]) fun ps k v => ps.push (k, v)

structure Stats where
numNodes : Nat := 0
numNull : Nat := 0
Expand Down
3 changes: 3 additions & 0 deletions src/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Data.Hashable
import Lean.Data.KVMap
import Lean.Data.SMap
import Lean.Level

namespace Lean
Expand Down Expand Up @@ -1389,6 +1390,8 @@ def mkDecIsFalse (pred proof : Expr) :=

abbrev ExprMap (α : Type) := HashMap Expr α
abbrev PersistentExprMap (α : Type) := PHashMap Expr α
abbrev SExprMap (α : Type) := SMap Expr α

abbrev ExprSet := HashSet Expr
abbrev PersistentExprSet := PHashSet Expr
abbrev PExprSet := PersistentExprSet
Expand Down
11 changes: 2 additions & 9 deletions src/Lean/Meta/Tactic/Simp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,7 @@ def simpStep (e : Expr) : SimpM Result := do

def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do
if cfg.memoize && r.cache then
let ctx ← readThe Simp.Context
let dischargeDepth := ctx.dischargeDepth
modify fun s => { s with cache := s.cache.insert e { r with dischargeDepth } }
modify fun s => { s with cache := s.cache.insert e r }
return r

partial def simpLoop (e : Expr) : SimpM Result := withIncRecDepth do
Expand Down Expand Up @@ -634,12 +632,7 @@ where
if cfg.memoize then
let cache := (← get).cache
if let some result := cache.find? e then
/-
If the result was cached at a dischargeDepth > the current one, it may not be valid.
See issue #1234
-/
if result.dischargeDepth ≤ (← readThe Simp.Context).dischargeDepth then
return result
return result
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e

Expand Down
5 changes: 3 additions & 2 deletions src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ inductive DischargeResult where
deriving DecidableEq

/--
Wrapper for invoking `discharge?`. It checks for maximum discharge depth, create trace nodes, and ensure
Wrapper for invoking `discharge?` method. It checks for maximum discharge depth, create trace nodes, and ensure
the generated proof was successfully assigned to `x`.
-/
def discharge?' (thmId : Origin) (x : Expr) (type : Expr) : SimpM Bool := do
Expand All @@ -44,8 +44,9 @@ def discharge?' (thmId : Origin) (x : Expr) (type : Expr) : SimpM Bool := do
else withTheReader Context (fun ctx => { ctx with dischargeDepth := ctx.dischargeDepth + 1 }) do
-- We save the state, so that `UsedTheorems` does not accumulate
-- `simp` lemmas used during unsuccessful discharging.
-- We use `withPreservedCache` to ensure the cache is restored after `discharge?`
let usedTheorems := (← get).usedTheorems
match (← discharge? type) with
match (← withPreservedCache <| (← getMethods).discharge? type) with
| some proof =>
unless (← isDefEq x proof) do
modify fun s => { s with usedTheorems }
Expand Down
27 changes: 17 additions & 10 deletions src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ structure Result where
/-- A proof that `$e = $expr`, where the simplified expression is on the RHS.
If `none`, the proof is assumed to be `refl`. -/
proof? : Option Expr := none
/-- Save the field `dischargeDepth` at `Simp.Context` because it impacts the simplifier result. -/
dischargeDepth : UInt32 := 0
/-- If `cache := true` the result is cached. -/
cache : Bool := true
deriving Inhabited
Expand All @@ -44,7 +42,8 @@ def Result.mkEqSymm (e : Expr) (r : Simp.Result) : MetaM Simp.Result :=
| none => return { r with expr := e }
| some p => return { r with expr := e, proof? := some (← Meta.mkEqSymm p) }

abbrev Cache := ExprMap Result
-- We use `SExprMap` because we want to discard cached results after a `discharge?`
abbrev Cache := SExprMap Result

abbrev CongrCache := ExprMap (Option CongrTheorem)

Expand Down Expand Up @@ -92,7 +91,8 @@ structure Context where
def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool :=
ctx.simpTheorems.isDeclToUnfold declName

abbrev UsedSimps := HashMap Origin Nat
-- We should use `PHashMap` because we backtrack the contents of `UsedSimps`
abbrev UsedSimps := PHashMap Origin Nat

structure State where
cache : Cache := {}
Expand Down Expand Up @@ -254,9 +254,6 @@ def pre (e : Expr) : SimpM Step := do
def post (e : Expr) : SimpM Step := do
(← getMethods).post e

def discharge? (e : Expr) : SimpM (Option Expr) := do
(← getMethods).discharge? e

@[inline] def getContext : SimpM Context :=
readThe Context

Expand All @@ -272,16 +269,26 @@ def getSimpTheorems : SimpM SimpTheoremsArray :=
def getSimpCongrTheorems : SimpM SimpCongrTheorems :=
return (← readThe Context).congrTheorems

@[inline] def savingCache (x : SimpM α) : SimpM α := do
@[inline] def withPreservedCache (x : SimpM α) : SimpM α := do
-- Recall that `cache.map₁` should be used linearly but `cache.map₂` is great for copies.
let savedMap₂ := (← get).cache.map₂
let savedStage₁ := (← get).cache.stage₁
modify fun s => { s with cache := s.cache.switch }
try x finally modify fun s => { s with cache.map₂ := savedMap₂, cache.stage₁ := savedStage₁ }

/--
Save current cache, reset it, execute `x`, and then restore original cache.
-/
@[inline] def withFreshCache (x : SimpM α) : SimpM α := do
let cacheSaved := (← get).cache
modify fun s => { s with cache := {} }
try x finally modify fun s => { s with cache := cacheSaved }

@[inline] def withSimpTheorems (s : SimpTheoremsArray) (x : SimpM α) : SimpM α := do
savingCache <| withTheReader Context (fun ctx => { ctx with simpTheorems := s }) x
withFreshCache <| withTheReader Context (fun ctx => { ctx with simpTheorems := s }) x

@[inline] def withDischarger (discharge? : Expr → SimpM (Option Expr)) (x : SimpM α) : SimpM α :=
savingCache <| withReader (fun r => { MethodsRef.toMethods r with discharge? }.toMethodsRef) x
withFreshCache <| withReader (fun r => { MethodsRef.toMethods r with discharge? }.toMethodsRef) x

def recordSimpTheorem (thmId : Origin) : SimpM Unit := do
/-
Expand Down
33 changes: 33 additions & 0 deletions tests/lean/run/3710.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
def Set := Nat → Prop

namespace Set

def singleton (a : Nat) : Set := fun b ↦ b = a

def compl (s : Set) : Set := fun x ↦ ¬ s x

@[simp]
theorem compl_iff (s : Set) (x : Nat) : s.compl x ↔ ¬ s x := Iff.rfl

@[simp]
theorem singleton_iff {a b : Nat} : singleton b a ↔ a = b := Iff.rfl

open Classical

noncomputable def indicator (s : Set) (x : Nat) : Nat := if s x then 1 else 0

@[simp] -- remove `simp` attribute --> works (and the trace changes)
theorem indicator_of {s : Set} {a : Nat} (h : s a) : indicator s a = 1 := if_pos h

@[simp]
theorem indicator_of_not {s : Set} {a : Nat} (h : ¬ s a) : indicator s a = 0 := if_neg h

/--
info: Try this: simp only [compl_iff, singleton_iff, not_true_eq_false, not_false_eq_true, indicator_of_not]
-/
#guard_msgs in
theorem test : indicator (compl <| singleton 0) 0 = 0 := by
simp? -- should not leave out `singleton_iff`

theorem test' : indicator (compl <| singleton 0) 0 = 0 := by
simp only [compl_iff, singleton_iff, not_true_eq_false, not_false_eq_true, indicator_of_not]

0 comments on commit f35fc18

Please sign in to comment.