From 12860c6549ddc881cbf93471bc229c1b7d492025 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Fri, 6 Dec 2024 12:22:17 +1100 Subject: [PATCH] feat: getElem lemmas for Vector operations --- src/Init/Data/Array/Lemmas.lean | 36 +++++++-- src/Init/Data/Vector/Lemmas.lean | 132 +++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 8 deletions(-) diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 7aebc66b4b58..ea8a4f931c9a 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -785,11 +785,26 @@ theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat else simp [setIfInBounds, h] +theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (j : Nat) + (hj : j < (setIfInBounds a i v).size) : + (setIfInBounds a i v)[j]'hj = if i = j then v else a[j]'(by simpa using hj) := by + simp only [setIfInBounds] + split + · simp [getElem_set] + · simp only [size_setIfInBounds] at hj + rw [if_neg] + omega + @[simp] theorem getElem_setIfInBounds_eq (a : Array α) {i : Nat} (v : α) (h : _) : (setIfInBounds a i v)[i]'h = v := by simp at h simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq] +@[simp] theorem getElem_setIfInBounds_ne (a : Array α) {i : Nat} (v : α) {j : Nat} + (hj : j < (setIfInBounds a i v).size) (h : i ≠ j) : + (setIfInBounds a i v)[j]'hj = a[j]'(by simpa using hj) := by + simp [getElem_setIfInBounds, h] + @[simp] theorem getElem?_setIfInBounds_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setIfInBounds i v)[i]? = some v := by @@ -991,11 +1006,6 @@ theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a (h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h] -theorem getElem_setIfInBounds (a : Array α) (i : Nat) (v : α) (h : i < (setIfInBounds a i v).size) : - (setIfInBounds a i v)[i] = v := by - simp at h - simp only [setIfInBounds, h, ↓reduceDIte, getElem_set_eq] - theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) : (a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set] @@ -1861,8 +1871,6 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) := /-! ### swap -/ -open Fin - @[simp] theorem getElem_swap_right (a : Array α) {i j : Nat} {hi hj} : (a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by simp [swap_def, getElem_set] @@ -1881,7 +1889,7 @@ theorem getElem_swap' (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < a.s · simp_all only [getElem_swap_left] · split <;> simp_all -theorem getElem_swap (a : Array α) (i j : Nat) {hi hj}(k : Nat) (hk : k < (a.swap i j).size) : +theorem getElem_swap (a : Array α) (i j : Nat) {hi hj} (k : Nat) (hk : k < (a.swap i j).size) : (a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k]'(by simp_all) := by apply getElem_swap' @@ -1944,6 +1952,13 @@ theorem eraseIdx_eq_eraseIdxIfInBounds {a : Array α} {i : Nat} (h : i < a.size) (as.zip bs).size = min as.size bs.size := as.size_zipWith bs Prod.mk +@[simp] theorem getElem_zipWith (as : Array α) (bs : Array β) (f : α → β → γ) (i : Nat) + (hi : i < (as.zipWith bs f).size) : + (as.zipWith bs f)[i] = f (as[i]'(by simp at hi; omega)) (bs[i]'(by simp at hi; omega)) := by + cases as + cases bs + simp + /-! ### findSomeM?, findM?, findSome?, find? -/ @[simp] theorem findSomeM?_toList [Monad m] [LawfulMonad m] (p : α → m (Option β)) (as : Array α) : @@ -2244,6 +2259,11 @@ theorem foldr_map' (g : α → β) (f : α → α → α) (f' : β → β → β cases as simp +@[simp] theorem getElem_reverse (as : Array α) (i : Nat) (hi : i < as.reverse.size) : + (as.reverse)[i] = as[as.size - 1 - i]'(by simp at hi; omega) := by + cases as + simp [Array.getElem_reverse] + /-! ### findSomeRevM?, findRevM?, findSomeRev?, findRev? -/ @[simp] theorem findSomeRevM?_eq_findSomeM?_reverse diff --git a/src/Init/Data/Vector/Lemmas.lean b/src/Init/Data/Vector/Lemmas.lean index a3d71282a8f5..4f9bc37593ff 100644 --- a/src/Init/Data/Vector/Lemmas.lean +++ b/src/Init/Data/Vector/Lemmas.lean @@ -124,6 +124,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a (Vector.mk a h).eraseIdx! i = Vector.mk (a.eraseIdx i) (by simp [h, hi]) := by simp [Vector.eraseIdx!, hi] +@[simp] theorem cast_mk (a : Array α) (h : a.size = n) (h' : n = m) : + (Vector.mk a h).cast h' = Vector.mk a (by simp [h, h']) := rfl + @[simp] theorem extract_mk (a : Array α) (h : a.size = n) (start stop) : (Vector.mk a h).extract start stop = Vector.mk (a.extract start stop) (by simp [h]) := rfl @@ -194,6 +197,9 @@ theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a (a.eraseIdx! i).toArray = a.toArray.eraseIdx! i := by cases a; simp_all [Array.eraseIdx!] +@[simp] theorem toArray_cast (a : Vector α n) (h : n = m) : + (a.cast h).toArray = a.toArray := rfl + @[simp] theorem toArray_extract (a : Vector α n) (start stop) : (a.extract start stop).toArray = a.toArray.extract start stop := rfl @@ -253,6 +259,132 @@ theorem toList_inj {a b : Vector α n} (h : a.toList = b.toList) : a = b := by rcases b with ⟨⟨b⟩, hb⟩ simpa using h +/-! ### set -/ + +theorem getElem_set (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat) (hj : j < n) : + (a.set i x hi)[j] = if i = j then x else a[j] := by + cases a + split <;> simp_all [Array.getElem_set] + +@[simp] theorem getElem_set_eq (a : Vector α n) (i : Nat) (x : α) (hi : i < n) : + (a.set i x hi)[i] = x := by simp [getElem_set] + +@[simp] theorem getElem_set_ne (a : Vector α n) (i : Nat) (x : α) (hi : i < n) (j : Nat) + (hj : j < n) (h : i ≠ j) : (a.set i x hi)[j] = a[j] := by simp [getElem_set, h] + +/-! ### setIfInBounds -/ + +theorem getElem_setIfInBounds (a : Vector α n) (i : Nat) (x : α) (j : Nat) + (hj : j < n) : (a.setIfInBounds i x)[j] = if i = j then x else a[j] := by + cases a + split <;> simp_all [Array.getElem_setIfInBounds] + +@[simp] theorem getElem_setIfInBounds_eq (a : Vector α n) (i : Nat) (x : α) (hj : i < n) : + (a.setIfInBounds i x)[i] = x := by simp [getElem_setIfInBounds] + +@[simp] theorem getElem_setIfInBounds_ne (a : Vector α n) (i : Nat) (x : α) (j : Nat) + (hj : j < n) (h : i ≠ j) : (a.setIfInBounds i x)[j] = a[j] := by simp [getElem_setIfInBounds, h] + +/-! ### append -/ + +theorem getElem_append (a : Vector α n) (b : Vector α m) (i : Nat) (hi : i < n + m) : + (a ++ b)[i] = if h : i < n then a[i] else b[i - n] := by + rcases a with ⟨a, rfl⟩ + rcases b with ⟨b, rfl⟩ + simp [Array.getElem_append, hi] + +theorem getElem_append_left {a : Vector α n} {b : Vector α m} {i : Nat} (hi : i < n) : + (a ++ b)[i] = a[i] := by simp [getElem_append, hi] + +theorem getElem_append_right {a : Vector α n} {b : Vector α m} {i : Nat} (h : i < n + m) (hi : n ≤ i) : + (a ++ b)[i] = b[i - n] := by + rw [getElem_append, dif_neg (by omega)] + +/-! ### cast -/ + +@[simp] theorem getElem_cast (a : Vector α n) (h : n = m) (i : Nat) (hi : i < m) : + (a.cast h)[i] = a[i] := by + cases a + simp + +/-! ### extract -/ + +@[simp] theorem getElem_extract (a : Vector α n) (start stop) (i : Nat) (hi : i < min stop n - start) : + (a.extract start stop)[i] = a[start + i] := by + cases a + simp + +/-! ### map -/ + +@[simp] theorem getElem_map (f : α → β) (a : Vector α n) (i : Nat) (hi : i < n) : + (a.map f)[i] = f a[i] := by + cases a + simp + +/-! ### zipWith -/ + +@[simp] theorem getElem_zipWith (f : α → β → γ) (a : Vector α n) (b : Vector β n) (i : Nat) + (hi : i < n) : (zipWith a b f)[i] = f a[i] b[i] := by + cases a + cases b + simp + +/-! ### swap -/ + +theorem getElem_swap (a : Vector α n) (i j : Nat) {hi hj} (k : Nat) (hk : k < n) : + (a.swap i j hi hj)[k] = if k = i then a[j] else if k = j then a[i] else a[k] := by + cases a + simp_all [Array.getElem_swap] + +@[simp] theorem getElem_swap_right (a : Vector α n) {i j : Nat} {hi hj} : + (a.swap i j hi hj)[j]'(by simpa using hj) = a[i] := by + simp +contextual [getElem_swap] + +@[simp] theorem getElem_swap_left (a : Vector α n) {i j : Nat} {hi hj} : + (a.swap i j hi hj)[i]'(by simpa using hi) = a[j] := by + simp [getElem_swap] + +@[simp] theorem getElem_swap_of_ne (a : Vector α n) {i j : Nat} {hi hj} (hp : p < n) + (hi' : p ≠ i) (hj' : p ≠ j) : (a.swap i j hi hj)[p] = a[p] := by + simp_all [getElem_swap] + +@[simp] theorem swap_swap (a : Vector α n) {i j : Nat} {hi hj} : + (a.swap i j hi hj).swap i j hi hj = a := by + cases a + simp_all [Array.swap_swap] + +theorem swap_comm (a : Vector α n) {i j : Nat} {hi hj} : + a.swap i j hi hj = a.swap j i hj hi := by + cases a + simp only [swap_mk, mk.injEq] + rw [Array.swap_comm] + +/-! ### range -/ + +@[simp] theorem getElem_range (i : Nat) (hi : i < n) : (Vector.range n)[i] = i := by + simp [Vector.range] + +/-! ### take -/ + +@[simp] theorem getElem_take (a : Vector α n) (m : Nat) (hi : i < min n m) : + (a.take m)[i] = a[i] := by + cases a + simp + +/-! ### drop -/ + +@[simp] theorem getElem_drop (a : Vector α n) (m : Nat) (hi : i < n - m) : + (a.drop m)[i] = a[m + i] := by + cases a + simp + +/-! ### reverse -/ + +@[simp] theorem getElem_reverse (a : Vector α n) (i : Nat) (hi : i < n) : + (a.reverse)[i] = a[n - 1 - i] := by + rcases a with ⟨a, rfl⟩ + simp + /-! ### Decidable quantifiers. -/ theorem forall_zero_iff {P : Vector α 0 → Prop} :