Skip to content

Commit

Permalink
feat: structural recursion over nested datatypes (#4733)
Browse files Browse the repository at this point in the history
This now works:

```lean
inductive Tree where | node : List Tree → Tree

mutual
def Tree.size : Tree → Nat
  | node ts => list_size ts

def Tree.list_size : List Tree → Nat
  | [] => 0
  | t::ts => t.size + list_size ts
end
```

It is still out of scope to expect to be able to use nested recursion
(e.g. through `List.map` or `List.foldl`) here.

Depends on #4718.

---------

Co-authored-by: Tobias Grosser <[email protected]>
  • Loading branch information
nomeata and tobiasgrosser authored Jul 15, 2024
1 parent 3ab2c71 commit de96b6d
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/Lean/Elab/PreDefinition/Structural/BRecOn.lean
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def inferBRecOnFTypes (recArgInfos : Array RecArgInfo) (positions : Positions)
-- And return the types of of the next arguments
arrowDomainsN numTypeFormers brecOnType

let mut FTypes := Array.mkArray recArgInfos.size (Expr.sort 0)
let mut FTypes := Array.mkArray positions.numIndices (Expr.sort 0)
for packedFType in packedFTypes, poss in positions do
for pos in poss do
FTypes := FTypes.set! pos packedFType
Expand Down
48 changes: 40 additions & 8 deletions src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ def getRecArgInfo (fnName : Name) (numFixed : Nat) (xs : Array Expr) (i : Nat) :
throwError "its type {indInfo.name} does not have a recursor"
else if indInfo.isReflexive && !(← hasConst (mkBInductionOnName indInfo.name)) && !(← isInductivePredicate indInfo.name) then
throwError "its type {indInfo.name} is a reflexive inductive, but {mkBInductionOnName indInfo.name} does not exist and it is not an inductive predicate"
else if indInfo.isNested then
throwError "its type {indInfo.name} is a nested inductive, which is not yet supported"
else
let indArgs : Array Expr := xType.getAppArgs
let indParams : Array Expr := indArgs[0:indInfo.numParams]
Expand Down Expand Up @@ -174,8 +172,7 @@ def inductiveGroups (recArgInfos : Array RecArgInfo) : MetaM (Array IndGroupInst
Filters the `recArgInfos` by those that describe an argument that's part of the recursive inductive
group `group`.
Anticipating support for nested inductives this function has the ability to change the `recArgInfo`.
Because of nested inductives this function has the ability to change the `recArgInfo`.
Consider
```
inductive Tree where | node : List Tree → Tree
Expand All @@ -184,9 +181,44 @@ then when we look for arguments whose type is part of the group `Tree`, we want
the argument of type `List Tree`, even though that argument’s `RecArgInfo` refers to initially to
`List`.
-/
def argsInGroup (group : IndGroupInst) (_xs : Array Expr) (_value : Expr)
def argsInGroup (group : IndGroupInst) (xs : Array Expr) (value : Expr)
(recArgInfos : Array RecArgInfo) : MetaM (Array RecArgInfo) := do
recArgInfos.filterM (group.isDefEq ·.indGroupInst)

let nestedTypeFormers ← group.nestedTypeFormers

recArgInfos.filterMapM fun recArgInfo => do
-- Is this argument from the same mutual group of inductives?
if (← group.isDefEq recArgInfo.indGroupInst) then
return (.some recArgInfo)

-- Can this argument be understood as the auxillary type former of a nested inductive?
if nestedTypeFormers.isEmpty then return .none
lambdaTelescope value fun ys _ => do
let x := (xs++ys)[recArgInfo.recArgPos]!
for nestedTypeFormer in nestedTypeFormers, indIdx in [group.all.size : group.numMotives] do
let xType ← whnfD (← inferType x)
let (indIndices, _, type) ← forallMetaTelescope nestedTypeFormer
if (← isDefEqGuarded type xType) then
let indIndices ← indIndices.mapM instantiateMVars
if !indIndices.all Expr.isFVar then
-- throwError "indices are not variables{indentExpr xType}"
continue
if !indIndices.allDiff then
-- throwError "indices are not pairwise distinct{indentExpr xType}"
continue
-- TODO: Do we have to worry about the indices ending up in the fixed prefix here?
if let some (_index, _y) ← hasBadIndexDep? ys indIndices then
-- throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
continue
let indicesPos := indIndices.map fun index => match (xs++ys).indexOf? index with | some i => i.val | none => unreachable!
return .some
{ fnName := recArgInfo.fnName
numFixed := recArgInfo.numFixed
recArgPos := recArgInfo.recArgPos
indicesPos := indicesPos
indGroupInst := group
indIdx := indIdx }
return .none

def maxCombinationSize : Nat := 10

Expand Down Expand Up @@ -234,7 +266,7 @@ def tryAllArgs (fnNames : Array Name) (xs : Array Expr) (values : Array Expr)
-- TODO: Here we used to save and restore the state. But should the `try`-`catch`
-- not suffice?
let r ← k comb
trace[Elab.definition.structural] "tryTellArgs report:\n{report}"
trace[Elab.definition.structural] "tryAllArgs report:\n{report}"
return r
catch e =>
let m ← prettyParameterSet fnNames xs values comb
Expand All @@ -243,7 +275,7 @@ def tryAllArgs (fnNames : Array Name) (xs : Array Expr) (values : Array Expr)
report := report ++ m!"Too many possible combinations of parameters of type {group} (or " ++
m!"please indicate the recursive argument explicitly using `termination_by structural`).\n"
report := m!"failed to infer structural recursion:\n" ++ report
trace[Elab.definition.structural] "tryTellArgs:\n{report}"
trace[Elab.definition.structural] "tryAllArgs:\n{report}"
throwError report

end Lean.Elab.Structural
28 changes: 28 additions & 0 deletions src/Lean/Elab/PreDefinition/Structural/IndGroupInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,32 @@ def IndGroupInst.isDefEq (igi1 igi2 : IndGroupInst) : MetaM Bool := do
unless (← (igi1.params.zip igi2.params).allM (fun (e₁, e₂) => Meta.isDefEqGuarded e₁ e₂)) do return false
return true

/--
Figures out the nested type formers of an inductive group, with parameters instantiated
and indices still forall-abstracted.
For example given a nested inductive
```
inductive Tree α where | node : α → Vector (Tree α) n → Tree α
```
(where `n` is an index of `Vector`) and the instantiation `Tree Int` it will return
```
#[(n : Nat) → Vector (Tree Int) n]
```
-/
def IndGroupInst.nestedTypeFormers (igi : IndGroupInst) : MetaM (Array Expr) := do
if igi.numNested = 0 then return #[]
-- We extract this information from the motives of the recursor
let recName := mkRecName igi.all[0]!
let recInfo ← getConstInfoRec recName
assert! recInfo.numMotives = igi.numMotives
let aux := mkAppN (.const recName (0 :: igi.levels)) igi.params
let motives ← inferArgumentTypesN recInfo.numMotives aux
let auxMotives : Array Expr := motives[igi.all.size:]
auxMotives.mapM fun motive =>
forallTelescopeReducing motive fun xs _ => do
assert! xs.size > 0
mkForallFVars xs.pop (← inferType xs.back)

end Lean.Elab.Structural
8 changes: 6 additions & 2 deletions src/Lean/Elab/PreDefinition/Structural/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def getMutualFixedPrefix (preDefs : Array PreDefinition) : M Nat :=
private def elimMutualRecursion (preDefs : Array PreDefinition) (xs : Array Expr)
(recArgInfos : Array RecArgInfo) : M (Array PreDefinition) := do
let values ← preDefs.mapM (instantiateLambda ·.value xs)
let indInfo ← getConstInfoInduct recArgInfos[0]!.indName!
let indInfo ← getConstInfoInduct recArgInfos[0]!.indGroupInst.all[0]!
if ← isInductivePredicate indInfo.name then
-- Here we branch off to the IndPred construction, but only for non-mutual functions
unless preDefs.size = 1 do
Expand All @@ -108,14 +108,18 @@ private def elimMutualRecursion (preDefs : Array PreDefinition) (xs : Array Expr
return #[{ preDef with value := valueNew }]

-- Sort the (indices of the) definitions by their position in indInfo.all
let positions : Positions := .groupAndSort (·.indName!) recArgInfos indInfo.all.toArray
let positions : Positions := .groupAndSort (·.indIdx) recArgInfos (Array.range indInfo.numTypeFormers)
trace[Elab.definition.structural] "positions: {positions}"

-- Construct the common `.brecOn` arguments
let motives ← (Array.zip recArgInfos values).mapM fun (r, v) => mkBRecOnMotive r v
trace[Elab.definition.structural] "motives: {motives}"
let brecOnConst ← mkBRecOnConst recArgInfos positions motives
let FTypes ← inferBRecOnFTypes recArgInfos positions brecOnConst
trace[Elab.definition.structural] "FTypes: {FTypes}"
let FArgs ← (recArgInfos.zip (values.zip FTypes)).mapM fun (r, (v, t)) =>
mkBRecOnF recArgInfos positions r v t
trace[Elab.definition.structural] "FArgs: {FArgs}"
-- Assemble the individual `.brecOn` applications
let valuesNew ← (Array.zip recArgInfos values).mapIdxM fun i (r, v) =>
mkBrecOnApp positions i brecOnConst FArgs r v
Expand Down
5 changes: 5 additions & 0 deletions tests/lean/run/funind_proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ mutual
end

mutual
-- Since #4733 this function can be compiled using structural recursion,
-- but then the construction of the functional induction principle falls over
-- TODO: Fix funind, and then omit the `termination_by` here (or test both variants)
def replaceConst (a b : String) : Term → Term
| const c => if a == c then const b else const c
| app f cs => app f (replaceConstLst a b cs)
termination_by t => sizeOf t

def replaceConstLst (a b : String) : List Term → List Term
| [] => []
| c :: cs => replaceConst a b c :: replaceConstLst a b cs
termination_by ts => sizeOf ts
end


Expand Down
31 changes: 31 additions & 0 deletions tests/lean/run/nestedTypeFormers.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import Lean

open Lean Meta Elab
open Lean.Elab.Structural

/-!
Unit test for `IndGroupInst.nestedTypeFormers`
-/

inductive Tree (α : Type u) : Type u
| node : α → (Bool → Tree α) → List (Tree α) → Tree α


/-- info: [List (Tree Bool)] -/
#guard_msgs in
run_meta
let igi : IndGroupInst := {all := #[``Tree], levels := [0], params := #[.const ``Bool []], numNested := 1}
logInfo m!"{← igi.nestedTypeFormers}"

inductive Vec (α : Type u) : Nat → Type u where
| empty : Vec α 0
| succ : α → Vec α n → Vec α (n + 1)

inductive VTree (α : Type u) : Type u
| node : α → Vec (VTree α) 32 → VTree α

/-- info: [(a : Nat) → Vec (VTree Bool) a] -/
#guard_msgs in
run_meta
let igi : IndGroupInst := {all := #[``VTree], levels := [0], params := #[.const ``Bool []], numNested := 1}
logInfo m!"{← igi.nestedTypeFormers}"
25 changes: 2 additions & 23 deletions tests/lean/run/structuralMutual.lean
Original file line number Diff line number Diff line change
Expand Up @@ -318,27 +318,6 @@ info: MutualIndNonMutualFun.A.weird_size1.eq_1 (a : A) : a.self.weird_size1 = a.

end MutualIndNonMutualFun

namespace NestedWithTuple

inductive Tree where
| leaf
| node : (Tree × Tree) → Tree

-- Nested recursion does not work (yet)

/--
error: cannot use specified parameter for structural recursion:
its type NestedWithTuple.Tree is a nested inductive, which is not yet supported
-/
#guard_msgs in
def Tree.size : Tree → Nat
| leaf => 0
| node (t₁, t₂) => t₁.size + t₂.size
termination_by structural t => t

end NestedWithTuple


namespace DifferentTypes

-- Check error message when argument types are not mutually recursive
Expand Down Expand Up @@ -539,13 +518,13 @@ Too many possible combinations of parameters of type Nattish (or please indicate
Could not find a decreasing measure.
The arguments relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
Call from ManyCombinations.f to ManyCombinations.g at 571:15-29:
Call from ManyCombinations.f to ManyCombinations.g at 550:15-29:
#1 #2 #3 #4
#5 ? ? ? ?
#6 ? = ? ?
#7 ? ? = ?
#8 ? ? ? =
Call from ManyCombinations.g to ManyCombinations.f at 574:15-29:
Call from ManyCombinations.g to ManyCombinations.f at 553:15-29:
#5 #6 #7 #8
#1 _ _ _ _
#2 _ = _ _
Expand Down
104 changes: 104 additions & 0 deletions tests/lean/run/structuralOverNested.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
inductive Tree where | node : List Tree → Tree

mutual
def Tree.size : Tree → Nat
| node ts => list_size ts

def Tree.list_size : List Tree → Nat
| [] => 0
| t::ts => t.size + list_size ts
end

example : Tree.list_size (t :: ts) = t.size + Tree.list_size ts := rfl

-- If we only look at the nested type at a finite depth we don't need an auxillary definition:

def Tree.isList : Tree → Bool
| .node [] => true
| .node [t] => t.isList
| .node _ => false


-- A nested inductive type
-- the `Bool → RTree α` prevents well-founded recursion and
-- tests support for reflexive types

inductive RTree (α : Type u) : Type u
| node : α → (Bool → RTree α) → List (RTree α) → RTree α

-- only recurse on the non-nested component
def RTree.simple_size : RTree α → Nat
| .node _x t _ts => 1 + (t true).simple_size + (t false).simple_size

/--
info: theorem RTree.simple_size.eq_1.{u_1} : ∀ {α : Type u_1} (_x : α) (t : Bool → RTree α) (_ts : List (RTree α)),
(RTree.node _x t _ts).simple_size = 1 + (t true).simple_size + (t false).simple_size :=
fun {α} _x t _ts => Eq.refl (RTree.node _x t _ts).simple_size
-/
#guard_msgs in
#print RTree.simple_size.eq_1

-- set_option trace.Elab.definition.structural true

-- also recurse on the nested components
#guard_msgs in
mutual
def RTree.size : RTree α → Nat
| .node _ t ts => 1 + (t true).size + (t false).size + aux_size ts
def RTree.aux_size : List (RTree α) → Nat
| [] => 0
| t::ts => t.size + aux_size ts
end

/--
info: theorem RTree.aux_size.eq_2.{u_1} : ∀ {α : Type u_1} (t : RTree α) (ts : List (RTree α)),
RTree.aux_size (t :: ts) = t.size + RTree.aux_size ts :=
fun {α} t ts => Eq.refl (RTree.aux_size (t :: ts))
-/
#guard_msgs in
#print RTree.aux_size.eq_2

mutual
def RTree.map (f : α → β) : RTree α → RTree β
| .node x t ts => .node (f x) (fun b => (t b).map f) (map_aux f ts)
def RTree.map_aux (f : α → β) : List (RTree α) → List (RTree β)
| [] => []
| t::ts => t.map f :: map_aux f ts
end

/--
info: theorem RTree.map_aux.eq_2.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β) (t : RTree α) (ts : List (RTree α)),
RTree.map_aux f (t :: ts) = RTree.map f t :: RTree.map_aux f ts :=
fun {α} {β} f t ts => Eq.refl (RTree.map_aux f (t :: ts))
-/
#guard_msgs in
#print RTree.map_aux.eq_2


inductive Vec (α : Type u) : Nat → Bool → Type u where
| empty : Vec α 0 false
| succ : α → Vec α n b → Vec α (n + 1) true

-- Now an example with indices all over the place

inductive VTree (α : Type u) : Bool → Nat → Type u
| node (b : Bool) (n : Nat) : α → (List Bool → List Nat → Vec (VTree α true 5) n b) → VTree α (!b) (n+1)

mutual
def VTree.size : VTree α b n → Nat
| .node _ _ _ f => 1 + vec_size (f [] [])
-- We have to write `VTree α true 5` here, and cannot write `VTree α b' n'` here.
-- This seems to be reasonable, cf. the type of the motives of `VTree.rec`
def VTree.vec_size : Vec (VTree α true 5) n b → Nat
| .empty => 0
| .succ t ts => t.size + vec_size ts
end

/--
info: theorem VTree.size.eq_1.{u_1} : ∀ {α : Type u_1} (b_2 : Bool) (n_2 : Nat) (a : α)
(f : List Bool → List Nat → Vec (VTree α true 5) n_2 b_2),
(VTree.node b_2 n_2 a f).size = 1 + VTree.vec_size (f [] []) :=
fun {α} b_2 n_2 a f => Eq.refl (VTree.node b_2 n_2 a f).size
-/
#guard_msgs in
#print VTree.size.eq_1

0 comments on commit de96b6d

Please sign in to comment.