Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve support for equations in cutsat #7203

Merged
merged 10 commits into from
Feb 24, 2025
126 changes: 36 additions & 90 deletions src/Init/Data/Int/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,12 @@ theorem le_unsat (ctx : Context) (p : Poly) : p.isUnsatLe → p.denote' ctx ≤
have := Int.lt_of_le_of_lt h₂ h₁
simp at this

theorem eq_norm (ctx : Context) (p₁ p₂ : Poly) (h : p₁.norm == p₂) : p₁.denote' ctx = 0 → p₂.denote' ctx = 0 := by
simp at h
replace h := congrArg (Poly.denote ctx) h
simp at h
simp [*]

def Poly.coeff (p : Poly) (x : Var) : Int :=
match p with
| .add a y p => bif x == y then a else coeff p x
Expand All @@ -864,17 +870,28 @@ private theorem dvd_of_eq' {a x p : Int} : a*x + p = 0 → a ∣ p := by
rw [Int.mul_comm, ← Int.neg_mul, Eq.comm, Int.mul_comm] at h
exact ⟨-x, h⟩

private def abs (x : Int) : Int :=
Int.ofNat x.natAbs

private theorem abs_dvd {a p : Int} (h : a ∣ p) : abs a ∣ p := by
cases a <;> simp [abs]
· simp at h; assumption
· simp [Int.negSucc_eq] at h; assumption

def dvd_of_eq_cert (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) : Bool :=
d₂ == p₁.coeff x && p₂ == p₁.insert (-d₂) x
let a := p₁.coeff x
d₂ == abs a && p₂ == p₁.insert (-a) x

theorem dvd_of_eq (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly)
: dvd_of_eq_cert x p₁ d₂ p₂ → p₁.denote' ctx = 0 → d₂ ∣ p₂.denote' ctx := by
simp [dvd_of_eq_cert]
intro h₁ h₂
have h := eq_add_coeff_insert ctx p₁ x
rw [← h₁, ← h₂] at h
rw [h]
apply dvd_of_eq'
rw [← h₂] at h
rw [h, h₁]
intro h₃
apply abs_dvd
apply dvd_of_eq' h₃

private theorem eq_dvd_subst' {a x p d b q : Int} : a*x + p = 0 → d ∣ b*x + q → a*d ∣ a*q - b*p := by
intro h₁ ⟨z, h₂⟩
Expand All @@ -892,7 +909,7 @@ def eq_dvd_subst_cert (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) (d₃ :
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
d₃ == a * d₂ &&
d₃ == abs (a * d₂) &&
p₃ == (q.mul a |>.combine (p.mul (-b)))

theorem eq_dvd_subst (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂ : Poly) (d₃ : Int) (p₃ : Poly)
Expand All @@ -913,124 +930,53 @@ theorem eq_dvd_subst (ctx : Context) (x : Var) (p₁ : Poly) (d₂ : Int) (p₂
rw [Int.add_comm] at h₁ h₂
have := eq_dvd_subst' h₁ h₂
rw [Int.sub_eq_add_neg, Int.add_comm] at this
apply abs_dvd
simp [this]

private theorem eq_eq_subst' {a x p b q : Int} : a*x + p = 0 → b*x + q = 0 → b*p - a*q = 0 := by
intro h₁ h₂
replace h₁ := congrArg (b*·) h₁; simp at h₁
replace h₂ := congrArg ((-a)*.) h₂; simp at h₂
rw [Int.add_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁
rw [← h₁]; clear h₁
replace h₂ := Int.neg_eq_of_add_eq_zero h₂; simp at h₂
rw [h₂]; clear h₂
rw [Int.mul_left_comm]
simp

def eq_eq_subst_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
p₃ == (p.mul b |>.combine (q.mul (-a)))
p₃ == (p₁.mul b |>.combine (p₂.mul (-a)))

theorem eq_eq_subst (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_eq_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by
simp [eq_eq_subst_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_eq_subst' h₁ h₂
rw [Int.sub_eq_add_neg] at this
simp [this]

private theorem eq_le_subst_nonneg' {a x p b q : Int} : a ≥ 0 → a*x + p = 0 → b*x + q ≤ 0 → a*q - b*p ≤ 0 := by
intro h h₁ h₂
replace h₁ := congrArg ((-b)*·) h₁; simp at h₁
rw [Int.add_comm, Int.mul_left_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁; simp at h₁
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h
rw [Int.mul_add, h₁] at h₂; clear h₁
simp at h₂
rw [Int.sub_eq_add_neg]
assumption
simp [*]

def eq_le_subst_nonneg_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
a ≥ 0 && p₃ == (q.mul a |>.combine (p.mul (-b)))
a ≥ 0 && p₃ == (p₂.mul a |>.combine (p₁.mul (-b)))

theorem eq_le_subst_nonneg (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_le_subst_nonneg_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
simp [eq_le_subst_nonneg_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro h
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_le_subst_nonneg' h h₁ h₂
rw [Int.sub_eq_add_neg, Int.add_comm] at this
simp [this]

private theorem eq_le_subst_nonpos' {a x p b q : Int} : a ≤ 0 → a*x + p = 0 → b*x + q ≤ 0 → b*p - a*q ≤ 0 := by
intro h h₁ h₂
replace h₁ := congrArg (b*·) h₁; simp at h₁
rw [Int.add_comm, Int.mul_left_comm] at h₁
replace h₁ := Int.neg_eq_of_add_eq_zero h₁; simp at h₁
replace h : (-a) ≥ 0 := by
have := Int.neg_le_neg h
simp at this
exact this
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h; simp at h₂; clear h
rw [h₁] at h₂
rw [Int.add_comm, ←Int.sub_eq_add_neg] at h₂
assumption
replace h₂ := Int.mul_le_mul_of_nonneg_left h₂ h
simp at h₂
simp [*]

def eq_le_subst_nonpos_cert (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly) : Bool :=
let a := p₁.coeff x
let b := p₂.coeff x
let p := p₁.insert (-a) x
let q := p₂.insert (-b) x
a ≤ 0 && p₃ == (p.mul b |>.combine (q.mul (-a)))
a ≤ 0 && p₃ == (p₁.mul b |>.combine (p₂.mul (-a)))

theorem eq_le_subst_nonpos (ctx : Context) (x : Var) (p₁ : Poly) (p₂ : Poly) (p₃ : Poly)
: eq_le_subst_nonpos_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
simp [eq_le_subst_nonpos_cert]
have eq₁ := eq_add_coeff_insert ctx p₁ x
have eq₂ := eq_add_coeff_insert ctx p₂ x
revert eq₁ eq₂
generalize p₁.coeff x = a
generalize p₂.coeff x = b
generalize p₁.insert (-a) x = p
generalize p₂.insert (-b) x = q
intro eq₁; simp [eq₁]; clear eq₁
intro eq₂; simp [eq₂]; clear eq₂
intro h
intro; subst p₃
intro h₁ h₂
rw [Int.add_comm] at h₁ h₂
have := eq_le_subst_nonpos' h h₁ h₂
rw [Int.sub_eq_add_neg] at this
simp [this]
simp [*]
replace h₂ := Int.mul_le_mul_of_nonpos_left h₂ h; simp at h₂; clear h
rw [← Int.neg_zero]
apply Int.neg_le_neg
rw [Int.mul_comm]
assumption

end Int.Linear

Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind/Arith/Cutsat.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
namespace Lean

builtin_initialize registerTraceClass `grind.cutsat
builtin_initialize registerTraceClass `grind.cutsat.subst
builtin_initialize registerTraceClass `grind.cutsat.eq
builtin_initialize registerTraceClass `grind.cutsat.assert
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd
Expand Down
81 changes: 77 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Arith/Cutsat/EqCnstr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,95 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
import Lean.Meta.Tactic.Grind.Arith.Cutsat.DvdCnstr

namespace Lean.Meta.Grind.Arith.Cutsat

def mkEqCnstr (p : Poly) (h : EqCnstrProof) : GoalM EqCnstr := do
return { p, h, id := (← mkCnstrId) }

def EqCnstr.norm (c : EqCnstr) : GoalM EqCnstr := do
let c ← if c.p.isSorted then
pure c
else
mkEqCnstr c.p.norm (.norm c)

/--
Selects the variable in the given linear polynomial whose coefficient has the smallest absolute value.
-/
def _root_.Int.Linear.Poly.pickVarToElim? (p : Poly) : Option (Int × Var) :=
match p with
| .num _ => none
| .add k x p => go k x p
where
go (k : Int) (x : Var) (p : Poly) : Int × Var :=
if k == 1 || k == -1 then
(k, x)
else match p with
| .num _ => (k, x)
| .add k' x' p =>
if k'.natAbs < k.natAbs then
go k' x' p
else
go k x p

/--
Given a polynomial `p`, returns `some (x, k, c)` if `p` contains the monomial `k*x`,
and `x` has been eliminated using the equality `c`.
-/
def _root_.Int.Linear.Poly.findVarToSubst (p : Poly) : GoalM (Option (Int × Var × EqCnstr)) := do
match p with
| .num _ => return none
| .add k x p =>
if let some c := (← get').elimEqs[x]! then
return some (k, x, c)
else
findVarToSubst p

partial def applySubsts (c : EqCnstr) : GoalM EqCnstr := do
let some (a, x, c₁) ← c.p.findVarToSubst | return c
trace[grind.cutsat.subst] "{← getVar x}, {← c.pp}, {← c₁.pp}"
let b := c₁.p.coeff x
let p := c.p.mul (-b) |>.combine (c₁.p.mul a)
let c ← mkEqCnstr p (.subst x c₁ c)
applySubsts c

def EqCnstr.assert (c : EqCnstr) : GoalM Unit := do
if (← isInconsistent) then return ()
trace[grind.cutsat.assert] "{← c.pp}"
let c ← c.norm
let c ← applySubsts c
-- TODO: check coeffsr
trace[grind.cutsat.eq] "{← c.pp}"
let some (k, x) := c.p.pickVarToElim? | c.throwUnexpected
-- TODO: eliminate `x` from lowers, uppers, and dvdCnstrs
-- TODO: reset `x`s occurrences
-- assert a divisibility constraint IF `|k| != 1`
if k.natAbs != 1 then
let p := c.p.insert (-k) x
let d := Int.ofNat k.natAbs
let c ← mkDvdCnstr d p (.ofEq x c)
c.assert
modify' fun s => { s with
elimEqs := s.elimEqs.set x (some c)
elimStack := x :: s.elimStack
}

@[export lean_process_cutsat_eq]
def processNewEqImpl (a b : Expr) : GoalM Unit := do
trace[grind.cutsat.eq] "{mkIntEq a b}"
-- TODO
return ()

@[export lean_process_new_cutsat_lit]
def processNewEqLitImpl (a k : Expr) : GoalM Unit := do
trace[grind.cutsat.eq] "{mkIntEq a k}"
-- TODO
return ()
def processNewEqLitImpl (a ke : Expr) : GoalM Unit := do
let some k ← getIntValue? ke | return ()
let some p := (← get').terms.find? { expr := a } | return ()
if k == 0 then
(← mkEqCnstr p (.expr (← mkEqProof a ke))).assert
else
-- TODO
return ()

/-- Different kinds of terms internalized by this module. -/
private inductive SupportedTermKind where
Expand Down
20 changes: 17 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ partial def DvdCnstr.toExprProof (c' : DvdCnstr) : ProofM Expr := c'.caching do
return mkApp10 (mkConst ``Int.Linear.dvd_solve_elim)
(← getContext) (toExpr c₁.d) (toExpr c₁.p) (toExpr c₂.d) (toExpr c₂.p) (toExpr c'.d) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
| .subst _c₁ _c₂ => throwError "NIY"
| .ofEq _c => throwError "NIY"
| .subst _x _c₁ _c₂ => throwError "NIY"
| .ofEq x c =>
return mkApp7 (mkConst ``Int.Linear.dvd_of_eq)
(← getContext) (toExpr x) (toExpr c.p) (toExpr c'.d) (toExpr c'.p)
reflBoolTrue (← c.toExprProof)

partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do
match c'.h with
Expand All @@ -56,7 +59,18 @@ partial def LeCnstr.toExprProof (c' : LeCnstr) : ProofM Expr := c'.caching do
(← getContext) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue
(← c₁.toExprProof) (← c₂.toExprProof)
| .subst _c₁ _c₂ => throwError "NIY"
| .subst _x _c₁ _c₂ => throwError "NIY"

partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := c'.caching do
match c'.h with
| .expr h =>
return h
| .norm c =>
return mkApp5 (mkConst ``Int.Linear.eq_norm) (← getContext) (toExpr c.p) (toExpr c'.p) reflBoolTrue (← c.toExprProof)
| .subst x c₁ c₂ =>
return mkApp8 (mkConst ``Int.Linear.eq_eq_subst)
(← getContext) (toExpr x) (toExpr c₁.p) (toExpr c₂.p) (toExpr c'.p)
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)

end
end Lean.Meta.Grind.Arith.Cutsat
8 changes: 4 additions & 4 deletions src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ inductive DvdCnstrProof where
| solveCombine (c₁ c₂ : DvdCnstr)
| solveElim (c₁ c₂ : DvdCnstr)
| elim (c : DvdCnstr)
| ofEq (c : EqCnstr)
| subst (c₁ : EqCnstr) (c₂ : DvdCnstr)
| ofEq (x : Var) (c : EqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : DvdCnstr)

structure LeCnstr where
p : Poly
Expand All @@ -48,7 +48,7 @@ inductive LeCnstrProof where
| norm (c : LeCnstr)
| divCoeffs (c : LeCnstr)
| combine (c₁ c₂ : LeCnstr)
| subst (c₁ : EqCnstr) (c₂ : LeCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : LeCnstr)
-- TODO: missing constructors

structure EqCnstr where
Expand All @@ -59,7 +59,7 @@ structure EqCnstr where
inductive EqCnstrProof where
| expr (h : Expr)
| norm (c : EqCnstr)
| subst (c₁ : EqCnstr) (c₂ : EqCnstr)
| subst (x : Var) (c₁ : EqCnstr) (c₂ : EqCnstr)
end

/-- State of the cutsat procedure. -/
Expand Down
10 changes: 4 additions & 6 deletions src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def EqCnstr.pp (c : EqCnstr) : GoalM MessageData := do
def EqCnstr.denoteExpr (c : EqCnstr) : GoalM Expr := do
return mkIntEq (← c.p.denoteExpr') (mkIntLit 0)

def EqCnstr.throwUnexpected (c : LeCnstr) : GoalM α := do
def EqCnstr.throwUnexpected (c : EqCnstr) : GoalM α := do
throwError "`grind` internal error, unexpected{indentD (← c.pp)}"

/-- Returns occurrences of `x`. -/
Expand Down Expand Up @@ -176,11 +176,9 @@ abbrev caching (id : Nat) (k : ProofM Expr) : ProofM Expr := do
modify fun s => { s with cache := s.cache.insert id h }
return h

abbrev DvdCnstr.caching (c : DvdCnstr) (k : ProofM Expr) : ProofM Expr :=
Cutsat.caching c.id k

abbrev LeCnstr.caching (c : LeCnstr) (k : ProofM Expr) : ProofM Expr :=
Cutsat.caching c.id k
abbrev DvdCnstr.caching (c : DvdCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k
abbrev LeCnstr.caching (c : LeCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k
abbrev EqCnstr.caching (c : EqCnstr) (k : ProofM Expr) : ProofM Expr := Cutsat.caching c.id k

abbrev withProofContext (x : ProofM Expr) : GoalM Expr := do
withLetDecl `ctx (mkApp (mkConst ``RArray) (mkConst ``Int)) (← toContextExpr) fun ctx => do
Expand Down
Loading
Loading