diff --git a/LeanColls/Classes/Ops.lean b/LeanColls/Classes/Ops.lean index 0b5e8a728..529648be8 100644 --- a/LeanColls/Classes/Ops.lean +++ b/LeanColls/Classes/Ops.lean @@ -118,7 +118,9 @@ export Fold (fold foldM) namespace Fold -instance [Fold C τ] : ForIn m C τ where +variable [Fold C τ] + +instance : ForIn m C τ where forIn := fun {β} _ c acc f => do let res ← Fold.foldM (m := ExceptT β m) c (fun x acc => @@ -129,7 +131,7 @@ instance [Fold C τ] : ForIn m C τ where | .ok a => pure a | .error a => pure a -def find (f : τ → Bool) [Fold C τ] (cont : C) : Option τ := +def find (f : τ → Bool) (cont : C) : Option τ := match Fold.foldM cont (fun () x => if f x then .error x else .ok () @@ -138,7 +140,7 @@ def find (f : τ → Bool) [Fold C τ] (cont : C) : Option τ := | Except.ok () => none | Except.error x => some x -def any (f : τ → Bool) [Fold C τ] (cont : C) : Bool := +def any (f : τ → Bool) (cont : C) : Bool := match Fold.foldM cont (fun () x => if f x then .error () else .ok () @@ -147,7 +149,7 @@ def any (f : τ → Bool) [Fold C τ] (cont : C) : Bool := | Except.ok () => false | Except.error () => true -def all (f : τ → Bool) [Fold C τ] (cont : C) : Bool := +def all (f : τ → Bool) (cont : C) : Bool := match Fold.foldM cont (fun () x => if f x then .ok () else .error () @@ -159,6 +161,73 @@ def all (f : τ → Bool) [Fold C τ] (cont : C) : Bool := instance (priority := low) [Fold C τ] [BEq τ] : Membership τ C where mem x c := any (· == x) c +/-- Correctness of `Fold` with respect to `ToList` -/ +class ToList (C τ) [Fold C τ] [ToList C τ] : Prop where + fold_eq_fold_toList : ∀ (c : C) (f) (init : β), ∃ L, + List.Perm L (toList c) ∧ fold c f init = List.foldl f init L + foldM_eq_foldM_toList : [Monad m] → ∀ (c : C) (f) (init : β), ∃ L, + List.Perm L (toList c) ∧ foldM (m := m) c f init = List.foldlM f init L + +theorem any_eq_any_toList [LeanColls.ToList C τ] [ToList C τ] + (f : τ → Bool) (c : C) + : any f c = List.any (toList c) f := by + unfold any + generalize hf' : (fun _ _ => _) = f' + suffices foldM c f' () = Except.error () ↔ List.any (toList c) f by + rw [eq_comm]; split + · rw [Bool.eq_false_iff]; aesop + · aesop + have ⟨L,perm,h⟩ := ToList.foldM_eq_foldM_toList c f' () + rw [h]; clear h + simp_rw [List.any_eq_true, ← perm.mem_iff]; clear perm c + subst hf' + induction L with + | nil => simp_all [pure, Except.pure] + | cons hd tl ih => + simp [bind, Except.bind] + by_cases f hd = true <;> simp_all + +theorem all_eq_all_toList [LeanColls.ToList C τ] [ToList C τ] + (f : τ → Bool) (c : C) + : all f c = List.all (toList c) f := by + unfold all + generalize hf' : (fun _ _ => _) = f' + suffices foldM c f' () = Except.ok () ↔ List.all (toList c) f by + rw [eq_comm]; split + · aesop + · rw [Bool.eq_false_iff]; aesop + have ⟨L,perm,h⟩ := ToList.foldM_eq_foldM_toList c f' () + rw [h]; clear h + simp_rw [List.all_eq_true, ← perm.mem_iff]; clear perm c + subst hf' + induction L with + | nil => simp_all [pure, Except.pure] + | cons hd tl ih => + simp [bind, Except.bind] + by_cases f hd = true <;> simp_all + +@[simp] +theorem any_iff_exists [Membership τ C] [LeanColls.ToList C τ] [ToList C τ] [Mem.ToList C τ] + (f : τ → Bool) (c : C) + : any f c ↔ ∃ x ∈ c, f x := by + rw [any_eq_any_toList] + simp [Mem.ToList.mem_iff_mem_toList] + +@[simp] +theorem all_iff_exists [Membership τ C] [LeanColls.ToList C τ] [ToList C τ] [Mem.ToList C τ] + (f : τ → Bool) (c : C) + : all f c ↔ ∀ x ∈ c, f x := by + rw [all_eq_all_toList] + simp [Mem.ToList.mem_iff_mem_toList] + +instance [Fold C τ] [BEq τ] [LeanColls.ToList C τ] + [ToList C τ] [LawfulBEq τ] : Mem.ToList C τ where + mem_iff_mem_toList := by + intro x c + conv => lhs; simp [Membership.mem] + rw [any_eq_any_toList] + simp only [List.any_eq_true, beq_iff_eq, exists_eq_right] + end Fold diff --git a/LeanColls/Data/Array.lean b/LeanColls/Data/Array.lean index a1edc6929..52a21dfd0 100644 --- a/LeanColls/Data/Array.lean +++ b/LeanColls/Data/Array.lean @@ -6,6 +6,7 @@ Authors: James Gallicchio import Std.Data.Array.Lemmas import Std.Data.List.Lemmas import Mathlib.Data.Array.Lemmas +import Mathlib.Tactic.Ring import LeanColls.Classes.Seq import LeanColls.Data.List @@ -134,17 +135,179 @@ abbrev NArray (α : Type u) (n : Nat) := FixSize (Array α) n namespace ByteArray +instance : ToList ByteArray UInt8 where + toList := toList + instance : Fold ByteArray UInt8 where fold arr := arr.foldl foldM arr := arr.foldlM +instance : Fold.ToList ByteArray UInt8 where + -- JG: silly me, thinking someone had proven even a single theorem about ByteArray.foldl + fold_eq_fold_toList := by sorry + foldM_eq_foldM_toList := by sorry + instance : Seq ByteArray UInt8 where size := size get := get set := set empty := empty insert := push - toList := toList ofFn := ofFn + snoc := push + -- TODO(JG): implement bytearray append directly + +-- JG: again, silly me, thinking anyone has proven anything about ByteArray at all +/- +instance : LawfulSeq ByteArray UInt8 where + toList_append := sorry + toMultiset_empty := sorry + toMultiset_insert := sorry + toMultiset_singleton := by + simp [LeanColls.singleton, ToMultiset.toMultiset, LeanColls.toList] + sorry -- ... yikes.. + size_def := by + intros; simp [LeanColls.size, LeanColls.toList] + sorry -- ... +-/ end ByteArray + +namespace FloatArray + +instance : ToList FloatArray Float where + toList := toList + +instance : Fold FloatArray Float where + fold arr := arr.foldl + foldM arr := arr.foldlM + +def append (A1 A2 : FloatArray) : FloatArray := + aux A1 0 +where aux (acc : FloatArray) (i : Nat) : FloatArray := + if _h : i < A2.size then + aux (acc.push A2[i]) (i+1) + else + acc +termination_by aux _ i => A2.size - i + +instance : Append FloatArray where + append := append + +def ofFn (f : Fin n → Float) : FloatArray := + aux (FloatArray.mkEmpty n) 0 +where aux (acc : FloatArray) (i : Nat) : FloatArray := + if h : i < n then + aux (acc.push (f ⟨i,h⟩)) (i+1) + else + acc +termination_by aux i _ => n - i + +instance : Seq FloatArray Float where + size := size + get := get + set := set + empty := empty + insert := push + ofFn := ofFn + snoc := push + append := append + +theorem ext (A1 A2 : FloatArray) : A1.data = A2.data → A1 = A2 := by + cases A1; cases A2; simp + +theorem ext_iff (A1 A2 : FloatArray) : A1 = A2 ↔ A1.data = A2.data := by + constructor + · rintro rfl; rfl + · apply ext + +theorem append.aux_eq_acc_append (A2 : FloatArray) (acc : FloatArray) (i : Nat) + : (aux A2 acc i).data = acc.data ++ (aux A2 FloatArray.empty i).data := by + if i < A2.size then + let j := size A2 - i + have : i = size A2 - j := by + simp; rw [Nat.sub_sub_self]; apply Nat.lt_succ.mp (Nat.le.step _); assumption + rw [this]; clear this + have : j ≤ size A2 := by simp + generalize j = j' at *; clear j; clear! i + induction j' generalizing acc with + | zero => unfold aux; simp [empty, mkEmpty] + | succ j ih => + unfold aux + have : size A2 - j.succ < size A2 := by omega + have : size A2 - j.succ + 1 = size A2 - j := by omega + simp [*] + have : j ≤ size A2 := Nat.le_of_lt ‹_› + conv => lhs; rw [ih _ ‹_›] + conv => rhs; rw [ih _ ‹_›] + rw [←Array.append_assoc] + congr 1 + else + unfold aux + simp [*, empty, mkEmpty] + +theorem append.aux_fromEmpty (A : FloatArray) + : (aux A FloatArray.empty 0).data = A.data := by + rcases A with ⟨A⟩ + suffices ∀ i, i ≤ A.size → (aux ⟨A⟩ empty (A.size - i)).data = ⟨A.data.drop (A.size - i)⟩ by + have := this A.size (Nat.le_refl _) + simpa using this + intro i hi + induction i with + | zero => unfold aux; simp [size, empty, mkEmpty]; rfl + | succ i ih => + specialize ih (Nat.le_of_lt hi) + unfold aux + have : A.size - i.succ < A.size := by omega + have : A.size - i.succ + 1 = A.size - i := by omega + simp [size, *] + rw [append.aux_eq_acc_append, ih, Array.ext'_iff] + simp [push, empty, mkEmpty] + have : A.size - i.succ < A.data.length := by omega + conv => rhs; rw [List.drop_eq_get_cons ‹_›] + congr 2 + omega + +@[simp] theorem data_append (A1 A2 : FloatArray) : (A1 ++ A2).data = A1.data ++ A2.data := by + rcases A1 with ⟨A1⟩ + rcases A2 with ⟨A2⟩ + simp [instAppendFloatArray, instHAppend, append] + rw [append.aux_eq_acc_append, append.aux_fromEmpty] + rfl + + +theorem ofFn.aux_spec (f : Fin n → Float) (acc : FloatArray) (i : Nat) (hi : i ≤ n) + : aux f acc i = acc ++ ⟨Array.ofFn (fun (j : Fin (n-i)) => f (j.addNat i |>.cast (by omega)))⟩ := by + have hi' := hi + revert hi acc + apply Nat.decreasingInduction' (n := n) (P := fun i => ∀ acc (hi : i ≤ n), aux f acc i = _) + · intro j jlt _ilej ih acc hi + unfold aux + simp only [jlt, dite_true] + rw [ih _ jlt, ext_iff] + rw [Array.ext'_iff] + simp [push] + rw [←List.ofFn_def, ←List.ofFn_def] + have := List.ofFn_succ (fun (x : Fin (n - j.succ + 1)) => f (x.addNat j |>.cast (by omega))) + convert this.symm using 1 <;> clear this + · simp; constructor + · congr; simp + · funext x + congr 1; simp [Fin.eq_iff_veq]; ring + · apply List.ext_get + · simp; omega + · intro x h1 h2; simp + · assumption + · intro acc _ + unfold aux + simp; cases acc; rw [ext_iff]; simp + suffices Array.ofFn _ = #[] by + rw [this]; simp + rw [Array.ext'_iff] + simp + apply List.eq_nil_of_length_eq_zero + simp + +-- Not even going to try to write out the lawfulseq instance... + +end FloatArray