Skip to content

Commit

Permalink
perf: globally cache intermediate type class results
Browse files Browse the repository at this point in the history
  • Loading branch information
JovanGerb committed May 13, 2024
1 parent fad881f commit eaae586
Showing 1 changed file with 89 additions and 46 deletions.
135 changes: 89 additions & 46 deletions src/Lean/Meta/SynthInstance.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ structure Instance where

structure GeneratorNode where
mvar : Expr
mvarType : Expr
key : Expr
mctx : MetavarContext
instances : Array Instance
Expand Down Expand Up @@ -187,6 +188,7 @@ structure State where
generatorStack : Array GeneratorNode := #[]
resumeStack : Array (ConsumerNode × Answer) := #[]
tableEntries : HashMap Expr TableEntry := {}
cacheEntries : Array ((LocalInstances × Expr) × Expr) := #[]

abbrev SynthM := ReaderT Context $ StateRefT State MetaM

Expand Down Expand Up @@ -247,7 +249,7 @@ def mkGeneratorNode? (key mvar : Expr) : MetaM (Option GeneratorNode) := do
else
let mctx ← getMCtx
return some {
mvar, key, mctx, instances
mvar, mvarType, key, mctx, instances
typeHasMVars := mvarType.hasMVar
currInstanceIdx := instances.size
}
Expand Down Expand Up @@ -490,6 +492,22 @@ private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM
trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}"
return some (mvarType', transformer)

def checkGlobalCache (mvar : Expr) (mctx : MetavarContext) : MetaM (Option (Option Answer)) :=
withMCtx mctx do
let mvarType ← inferType mvar
let mvarType ← instantiateMVars mvarType
if mvarType.hasMVar then
return none
match (← get).cache.synthInstance.find? (← getLocalInstances, mvarType) with
| none => return none
| some none => return some none
| some (some inst) => return some $ some {
result := ← abstractMVars inst
resultType := mvarType
size := 1 }



/-- Process the next subgoal in the given consumer node. -/
def consume (cNode : ConsumerNode) : SynthM Unit := do
/- Filter out subgoals that have already been assigned when solving typing constraints.
Expand All @@ -510,35 +528,43 @@ def consume (cNode : ConsumerNode) : SynthM Unit := do
match cNode.subgoals with
| [] => addAnswer cNode
| mvar::_ =>
let waiter := Waiter.consumerNode cNode
let key ← mkTableKeyFor cNode.mctx mvar
let entry? ← findEntry? key
match entry? with
| none =>
-- Remove unused arguments and try again, see comment at `removeUnusedArguments?`
match (← removeUnusedArguments? cNode.mctx mvar) with
| none => newSubgoal cNode.mctx key mvar waiter
| some (mvarType', transformer) =>
let key' ← withMCtx cNode.mctx <| mkTableKey mvarType'
match (← findEntry? key') with
| none =>
let (mctx', mvar') ← withMCtx cNode.mctx do
let mvar' ← mkFreshExprMVar mvarType'
return (← getMCtx, mvar')
newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals })
| some entry' =>
let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do
let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr]
let trAnswrType ← inferType trAnswr
pure { a with result.expr := trAnswr, resultType := trAnswrType }
modify fun s =>
{ s with
resumeStack := answers'.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key' { entry' with waiters := entry'.waiters.push waiter } }
| some entry => modify fun s =>
{ s with
resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key { entry with waiters := entry.waiters.push waiter } }
let waiter := Waiter.consumerNode cNode

match ← checkGlobalCache mvar cNode.mctx with
| some result =>
if let some answer := result then
modify fun s =>
{ s with
resumeStack := s.resumeStack.push (cNode, answer) }
| none =>
let key ← mkTableKeyFor cNode.mctx mvar
let entry? ← findEntry? key
match entry? with
| none =>
-- Remove unused arguments and try again, see comment at `removeUnusedArguments?`
match (← removeUnusedArguments? cNode.mctx mvar) with
| none => newSubgoal cNode.mctx key mvar waiter
| some (mvarType', transformer) =>
let key' ← withMCtx cNode.mctx <| mkTableKey mvarType'
match (← findEntry? key') with
| none =>
let (mctx', mvar') ← withMCtx cNode.mctx do
let mvar' ← mkFreshExprMVar mvarType'
return (← getMCtx, mvar')
newSubgoal mctx' key' mvar' (Waiter.consumerNode { cNode with mctx := mctx', subgoals := mvar'::cNode.subgoals })
| some entry' =>
let answers' ← entry'.answers.mapM fun a => withMCtx cNode.mctx do
let trAnswr := Expr.betaRev transformer #[← instantiateMVars a.result.expr]
let trAnswrType ← inferType trAnswr
pure { a with result.expr := trAnswr, resultType := trAnswrType }
modify fun s =>
{ s with
resumeStack := answers'.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key' { entry' with waiters := entry'.waiters.push waiter } }
| some entry => modify fun s =>
{ s with
resumeStack := entry.answers.foldl (fun s answer => s.push (cNode, answer)) s.resumeStack,
tableEntries := s.tableEntries.insert key { entry with waiters := entry.waiters.push waiter } }

def getTop : SynthM GeneratorNode :=
return (← get).generatorStack.back
Expand All @@ -551,6 +577,14 @@ def generate : SynthM Unit := do
let gNode ← getTop
if gNode.currInstanceIdx == 0 then
modify fun s => { s with generatorStack := s.generatorStack.pop }
unless gNode.typeHasMVars do
if let some entry := (← get).tableEntries.find? gNode.key then
if h : entry.answers.size > 0 then
let answer := entry.answers[0].result
if answer.numMVars == 0 then
let inst := answer.expr
let cacheKey := (← getLocalInstances, gNode.mvarType)
modify fun s => { s with cacheEntries := s.cacheEntries.push (cacheKey, inst)}
else
let key := gNode.key
let idx := gNode.currInstanceIdx - 1
Expand All @@ -561,13 +595,18 @@ def generate : SynthM Unit := do
if backward.synthInstance.canonInstances.get (← getOptions) then
unless gNode.typeHasMVars do
if let some entry := (← get).tableEntries.find? key then
unless entry.answers.isEmpty do
if h : entry.answers.size > 0 then
/-
We already have an answer for this node, and since its type does not have metavariables,
we can skip other solutions because we assume instances are "morally canonical".
We have added this optimization to address issue #3996.
-/
modify fun s => { s with generatorStack := s.generatorStack.pop }
let answer := entry.answers[0].result
if answer.numMVars == 0 then
let inst := answer.expr
let cacheKey := (← getLocalInstances, gNode.mvarType)
modify fun s => { s with cacheEntries := s.cacheEntries.push (cacheKey, inst)}
return
discard do withMCtx mctx do
withTraceNode `Meta.synthInstance
Expand Down Expand Up @@ -628,18 +667,22 @@ partial def synth : SynthM (Option AbstractMVarsResult) := do

def main (type : Expr) (maxResultSize : Nat) : MetaM (Option AbstractMVarsResult) :=
withCurrHeartbeats do
let mvar ← mkFreshExprMVar type
let key ← mkTableKey type
let action : SynthM (Option AbstractMVarsResult) := do
newSubgoal (← getMCtx) key mvar Waiter.root
synth
tryCatchRuntimeEx
(action.run { maxResultSize := maxResultSize, maxHeartbeats := getMaxHeartbeats (← getOptions) } |>.run' {})
fun ex =>
if ex.isRuntime then
throwError "failed to synthesize{indentExpr type}\n{ex.toMessageData}"
else
throw ex
let mvar ← mkFreshExprMVar type
let key ← mkTableKey type
let action : SynthM (Option AbstractMVarsResult) := do
newSubgoal (← getMCtx) key mvar Waiter.root
synth
let (result, { cacheEntries, ..}) ← tryCatchRuntimeEx
(action.run { maxResultSize := maxResultSize, maxHeartbeats := getMaxHeartbeats (← getOptions) } |>.run {})
fun ex =>
if ex.isRuntime then
throwError "failed to synthesize{indentExpr type}\n{ex.toMessageData}"
else
throw ex
let cache := (← get).cache.synthInstance
let cache ← cacheEntries.foldlM (fun c (k, e) => return c.insert k e) cache
modify fun s => { s with cache.synthInstance := cache}
return result

end SynthInstance

Expand Down Expand Up @@ -699,7 +742,8 @@ private def preprocessOutParam (type : Expr) : MetaM Expr :=
Remark: we use a different option for controlling the maximum result size for coercions.
-/

def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (Option Expr) := do profileitM Exception "typeclass inference" (← getOptions) (decl := type.getAppFn.constName?.getD .anonymous) do
def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (Option Expr) := do
profileitM Exception "typeclass inference" (← getOptions) (decl := type.getAppFn.constName?.getD .anonymous) do
let opts ← getOptions
let maxResultSize := maxResultSize?.getD (synthInstance.maxSize.get opts)
withTraceNode `Meta.synthInstance
Expand All @@ -710,7 +754,6 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
let localInsts ← getLocalInstances
let type ← instantiateMVars type
let type ← preprocess type
let s ← get
let rec assignOutParams (result : Expr) : MetaM Bool := do
let resultType ← inferType result
/- Output parameters of local instances may be marked as `syntheticOpaque` by the application-elaborator.
Expand All @@ -720,7 +763,7 @@ def synthInstance? (type : Expr) (maxResultSize? : Option Nat := none) : MetaM (
unless defEq do
trace[Meta.synthInstance] "{crossEmoji} result type{indentExpr resultType}\nis not definitionally equal to{indentExpr type}"
return defEq
match s.cache.synthInstance.find? (localInsts, type) with
match (← get).cache.synthInstance.find? (localInsts, type) with
| some result =>
trace[Meta.synthInstance] "result {result} (cached)"
if let some inst := result then
Expand Down

0 comments on commit eaae586

Please sign in to comment.