Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: getElem lemmas for Vector operations #6324

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,26 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat
else
simp [setIfInBounds, h]

theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (j : Nat)
(hj : j < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[j]'hj = if i = j then v else a[j]'(by simpa using hj) := by
simp only [setIfInBounds]
split
· simp [getElem_set]
· simp only [size_setIfInBounds] at hj
rw [if_neg]
omega

@[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) :
(setIfInBounds a i v)[i]'h = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]

@[simp] theorem getElem_setIfInBounds_ne (a : Array α) {i : Nat} (v : α) {j : Nat}
(hj : j < (setIfInBounds a i v).size) (h : i ≠ j) :
(setIfInBounds a i v)[j]'hj = a[j]'(by simpa using hj) := by
simp [getElem_setIfInBounds, h]

@[simp]
theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) :
(a.setIfInBounds i v)[i]? = some v := by
Expand Down Expand Up @@ -991,11 +1006,6 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a
(h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by
simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h]

theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) :
(setIfInBounds a i v)[i] = v := by
simp at h
simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq]

theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) :
(a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set]

Expand Down Expand Up @@ -1861,8 +1871,6 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) :=

/-! ### swap -/

open Fin

@[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
simp [swap_def, getElem_set]
Expand All @@ -1881,7 +1889,7 @@ theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.s
· simp_all only [getElem_swap_left]
· split <;> simp_all

theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) :
theorem getElem_swap (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < (a.swap i j).size) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by
apply getElem_swap'

Expand Down Expand Up @@ -1944,6 +1952,13 @@ theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size)
(as.zip bs).size = min as.size bs.size :=
as.size_zipWith bs Prod.mk

@[simp] theorem getElem_zipWith (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat)
(hi : i < (as.zipWith bs f).size) :
(as.zipWith bs f)[i] = f (as[i]'(by simp at hi; omega)) (bs[i]'(by simp at hi; omega)) := by
cases as
cases bs
simp

/-! ### findSomeM?, findM?, findSome?, find? -/

@[simp] theorem findSomeM?_toList [Monad m] [LawfulMonad m] (p : α → m (Option β)) (as : Array α) :
Expand Down Expand Up @@ -2244,6 +2259,11 @@ theorem foldr_map' (g : α → β) (f : α → α → α) (f' : β → β → β
cases as
simp

@[simp] theorem getElem_reverse (as : Array α) (i : Nat) (hi : i < as.reverse.size) :
(as.reverse)[i] = as[as.size - 1 - i]'(by simp at hi; omega) := by
cases as
simp [Array.getElem_reverse]

/-! ### findSomeRevM?, findRevM?, findSomeRev?, findRev? -/

@[simp] theorem findSomeRevM?_eq_findSomeM?_reverse
Expand Down
132 changes: 132 additions & 0 deletions src/Init/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
(Vector.mk a h).eraseIdx! i = Vector.mk (a.eraseIdx i) (by simp [h, hi]) := by
simp [Vector.eraseIdx!, hi]

@[simp] theorem cast_mk (a : Array α) (h : a.size = n) (h' : n = m) :
(Vector.mk a h).cast h' = Vector.mk a (by simp [h, h']) := rfl

@[simp] theorem extract_mk (a : Array α) (h : a.size = n) (start stop) :
(Vector.mk a h).extract start stop = Vector.mk (a.extract start stop) (by simp [h]) := rfl

Expand Down Expand Up @@ -194,6 +197,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a
(a.eraseIdx! i).toArray = a.toArray.eraseIdx! i := by
cases a; simp_all [Array.eraseIdx!]

@[simp] theorem toArray_cast (a : Vector α n) (h : n = m) :
(a.cast h).toArray = a.toArray := rfl

@[simp] theorem toArray_extract (a : Vector α n) (start stop) :
(a.extract start stop).toArray = a.toArray.extract start stop := rfl

Expand Down Expand Up @@ -253,6 +259,132 @@ theorem toList_inj {a b : Vector α n} (h : a.toList = b.toList) : a = b := by
rcases b with ⟨⟨b⟩, hb⟩
simpa using h

/-! ### set -/

theorem getElem_set (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat) (hj : j < n) :
(a.set i x hi)[j] = if i = j then x else a[j] := by
cases a
split <;> simp_all [Array.getElem_set]

@[simp] theorem getElem_set_eq (a : Vector α n) (i : Nat) (x : α) (hi : i < n) :
(a.set i x hi)[i] = x := by simp [getElem_set]

@[simp] theorem getElem_set_ne (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat)
(hj : j < n) (h : i ≠ j) : (a.set i x hi)[j] = a[j] := by simp [getElem_set, h]

/-! ### setIfInBounds -/

theorem getElem_setIfInBounds (a : Vector α n) (i : Nat) (x : α) (j : Nat)
(hj : j < n) : (a.setIfInBounds i x)[j] = if i = j then x else a[j] := by
cases a
split <;> simp_all [Array.getElem_setIfInBounds]

@[simp] theorem getElem_setIfInBounds_eq (a : Vector α n) (i : Nat) (x : α) (hj : i < n) :
(a.setIfInBounds i x)[i] = x := by simp [getElem_setIfInBounds]

@[simp] theorem getElem_setIfInBounds_ne (a : Vector α n) (i : Nat) (x : α) (j : Nat)
(hj : j < n) (h : i ≠ j) : (a.setIfInBounds i x)[j] = a[j] := by simp [getElem_setIfInBounds, h]

/-! ### append -/

theorem getElem_append (a : Vector α n) (b : Vector α m) (i : Nat) (hi : i < n + m) :
(a ++ b)[i] = if h : i < n then a[i] else b[i - n] := by
rcases a with ⟨a, rfl⟩
rcases b with ⟨b, rfl⟩
simp [Array.getElem_append, hi]

theorem getElem_append_left {a : Vector α n} {b : Vector α m} {i : Nat} (hi : i < n) :
(a ++ b)[i] = a[i] := by simp [getElem_append, hi]

theorem getElem_append_right {a : Vector α n} {b : Vector α m} {i : Nat} (h : i < n + m) (hi : n ≤ i) :
(a ++ b)[i] = b[i - n] := by
rw [getElem_append, dif_neg (by omega)]

/-! ### cast -/

@[simp] theorem getElem_cast (a : Vector α n) (h : n = m) (i : Nat) (hi : i < m) :
(a.cast h)[i] = a[i] := by
cases a
simp

/-! ### extract -/

@[simp] theorem getElem_extract (a : Vector α n) (start stop) (i : Nat) (hi : i < min stop n - start) :
(a.extract start stop)[i] = a[start + i] := by
cases a
simp

/-! ### map -/

@[simp] theorem getElem_map (f : α → β) (a : Vector α n) (i : Nat) (hi : i < n) :
(a.map f)[i] = f a[i] := by
cases a
simp

/-! ### zipWith -/

@[simp] theorem getElem_zipWith (f : α → β → γ) (a : Vector α n) (b : Vector β n) (i : Nat)
(hi : i < n) : (zipWith a b f)[i] = f a[i] b[i] := by
cases a
cases b
simp

/-! ### swap -/

theorem getElem_swap (a : Vector α n) (i j : Nat) {hi hj} (k : Nat) (hk : k < n) :
(a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k] := by
cases a
simp_all [Array.getElem_swap]

@[simp] theorem getElem_swap_right (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by
simp +contextual [getElem_swap]

@[simp] theorem getElem_swap_left (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by
simp [getElem_swap]

@[simp] theorem getElem_swap_of_ne (a : Vector α n) {i j : Nat} {hi hj} (hp : p < n)
(hi' : p ≠ i) (hj' : p ≠ j) : (a.swap i j hi hj)[p] = a[p] := by
simp_all [getElem_swap]

@[simp] theorem swap_swap (a : Vector α n) {i j : Nat} {hi hj} :
(a.swap i j hi hj).swap i j hi hj = a := by
cases a
simp_all [Array.swap_swap]

theorem swap_comm (a : Vector α n) {i j : Nat} {hi hj} :
a.swap i j hi hj = a.swap j i hj hi := by
cases a
simp only [swap_mk, mk.injEq]
rw [Array.swap_comm]

/-! ### range -/

@[simp] theorem getElem_range (i : Nat) (hi : i < n) : (Vector.range n)[i] = i := by
simp [Vector.range]

/-! ### take -/

@[simp] theorem getElem_take (a : Vector α n) (m : Nat) (hi : i < min n m) :

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be i < min m n not i < min n m. The latter works, but because the size of Vector.take v m is min m n, using i < min n m makes it less useful in automation (because simp can't necessarily find the proof).

(a.take m)[i] = a[i] := by
cases a
simp

/-! ### drop -/

@[simp] theorem getElem_drop (a : Vector α n) (m : Nat) (hi : i < n - m) :
(a.drop m)[i] = a[m + i] := by
cases a
simp

/-! ### reverse -/

@[simp] theorem getElem_reverse (a : Vector α n) (i : Nat) (hi : i < n) :
(a.reverse)[i] = a[n - 1 - i] := by
rcases a with ⟨a, rfl⟩
simp

/-! ### Decidable quantifiers. -/

theorem forall_zero_iff {P : Vector α 0 → Prop} :
Expand Down
Loading