Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions LeanColls/Classes/Ops.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ export Fold (fold foldM)

namespace Fold

instance [Fold C τ] : ForIn m C τ where
variable [Fold C τ]

instance : ForIn m C τ where
forIn := fun {β} _ c acc f => do
let res ← Fold.foldM (m := ExceptT β m)
c (fun x acc =>
Expand All @@ -129,7 +131,7 @@ instance [Fold C τ] : ForIn m C τ where
| .ok a => pure a
| .error a => pure a

def find (f : τ → Bool) [Fold C τ] (cont : C) : Option τ :=
def find (f : τ → Bool) (cont : C) : Option τ :=
match
Fold.foldM cont (fun () x =>
if f x then .error x else .ok ()
Expand All @@ -138,7 +140,7 @@ def find (f : τ → Bool) [Fold C τ] (cont : C) : Option τ :=
| Except.ok () => none
| Except.error x => some x

def any (f : τ → Bool) [Fold C τ] (cont : C) : Bool :=
def any (f : τ → Bool) (cont : C) : Bool :=
match
Fold.foldM cont (fun () x =>
if f x then .error () else .ok ()
Expand All @@ -147,7 +149,7 @@ def any (f : τ → Bool) [Fold C τ] (cont : C) : Bool :=
| Except.ok () => false
| Except.error () => true

def all (f : τ → Bool) [Fold C τ] (cont : C) : Bool :=
def all (f : τ → Bool) (cont : C) : Bool :=
match
Fold.foldM cont (fun () x =>
if f x then .ok () else .error ()
Expand All @@ -159,6 +161,73 @@ def all (f : τ → Bool) [Fold C τ] (cont : C) : Bool :=
instance (priority := low) [Fold C τ] [BEq τ] : Membership τ C where
mem x c := any (· == x) c

/-- Correctness of `Fold` with respect to `ToList` -/
class ToList (C τ) [Fold C τ] [ToList C τ] : Prop where
fold_eq_fold_toList : ∀ (c : C) (f) (init : β), ∃ L,
List.Perm L (toList c) ∧ fold c f init = List.foldl f init L
foldM_eq_foldM_toList : [Monad m] → ∀ (c : C) (f) (init : β), ∃ L,
List.Perm L (toList c) ∧ foldM (m := m) c f init = List.foldlM f init L

theorem any_eq_any_toList [LeanColls.ToList C τ] [ToList C τ]
(f : τ → Bool) (c : C)
: any f c = List.any (toList c) f := by
unfold any
generalize hf' : (fun _ _ => _) = f'
suffices foldM c f' () = Except.error () ↔ List.any (toList c) f by
rw [eq_comm]; split
· rw [Bool.eq_false_iff]; aesop
· aesop
have ⟨L,perm,h⟩ := ToList.foldM_eq_foldM_toList c f' ()
rw [h]; clear h
simp_rw [List.any_eq_true, ← perm.mem_iff]; clear perm c
subst hf'
induction L with
| nil => simp_all [pure, Except.pure]
| cons hd tl ih =>
simp [bind, Except.bind]
by_cases f hd = true <;> simp_all

theorem all_eq_all_toList [LeanColls.ToList C τ] [ToList C τ]
(f : τ → Bool) (c : C)
: all f c = List.all (toList c) f := by
unfold all
generalize hf' : (fun _ _ => _) = f'
suffices foldM c f' () = Except.ok () ↔ List.all (toList c) f by
rw [eq_comm]; split
· aesop
· rw [Bool.eq_false_iff]; aesop
have ⟨L,perm,h⟩ := ToList.foldM_eq_foldM_toList c f' ()
rw [h]; clear h
simp_rw [List.all_eq_true, ← perm.mem_iff]; clear perm c
subst hf'
induction L with
| nil => simp_all [pure, Except.pure]
| cons hd tl ih =>
simp [bind, Except.bind]
by_cases f hd = true <;> simp_all

@[simp]
theorem any_iff_exists [Membership τ C] [LeanColls.ToList C τ] [ToList C τ] [Mem.ToList C τ]
(f : τ → Bool) (c : C)
: any f c ↔ ∃ x ∈ c, f x := by
rw [any_eq_any_toList]
simp [Mem.ToList.mem_iff_mem_toList]

@[simp]
theorem all_iff_exists [Membership τ C] [LeanColls.ToList C τ] [ToList C τ] [Mem.ToList C τ]
(f : τ → Bool) (c : C)
: all f c ↔ ∀ x ∈ c, f x := by
rw [all_eq_all_toList]
simp [Mem.ToList.mem_iff_mem_toList]

instance [Fold C τ] [BEq τ] [LeanColls.ToList C τ]
[ToList C τ] [LawfulBEq τ] : Mem.ToList C τ where
mem_iff_mem_toList := by
intro x c
conv => lhs; simp [Membership.mem]
rw [any_eq_any_toList]
simp only [List.any_eq_true, beq_iff_eq, exists_eq_right]

end Fold


Expand Down
165 changes: 164 additions & 1 deletion LeanColls/Data/Array.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: James Gallicchio
import Std.Data.Array.Lemmas
import Std.Data.List.Lemmas
import Mathlib.Data.Array.Lemmas
import Mathlib.Tactic.Ring

import LeanColls.Classes.Seq
import LeanColls.Data.List
Expand Down Expand Up @@ -134,17 +135,179 @@ abbrev NArray (α : Type u) (n : Nat) := FixSize (Array α) n

namespace ByteArray

instance : ToList ByteArray UInt8 where
toList := toList

instance : Fold ByteArray UInt8 where
fold arr := arr.foldl
foldM arr := arr.foldlM

instance : Fold.ToList ByteArray UInt8 where
-- JG: silly me, thinking someone had proven even a single theorem about ByteArray.foldl
fold_eq_fold_toList := by sorry
foldM_eq_foldM_toList := by sorry

instance : Seq ByteArray UInt8 where
size := size
get := get
set := set
empty := empty
insert := push
toList := toList
ofFn := ofFn
snoc := push
-- TODO(JG): implement bytearray append directly

-- JG: again, silly me, thinking anyone has proven anything about ByteArray at all
/-
instance : LawfulSeq ByteArray UInt8 where
toList_append := sorry
toMultiset_empty := sorry
toMultiset_insert := sorry
toMultiset_singleton := by
simp [LeanColls.singleton, ToMultiset.toMultiset, LeanColls.toList]
sorry -- ... yikes..
size_def := by
intros; simp [LeanColls.size, LeanColls.toList]
sorry -- ...
-/

end ByteArray

namespace FloatArray

instance : ToList FloatArray Float where
toList := toList

instance : Fold FloatArray Float where
fold arr := arr.foldl
foldM arr := arr.foldlM

def append (A1 A2 : FloatArray) : FloatArray :=
aux A1 0
where aux (acc : FloatArray) (i : Nat) : FloatArray :=
if _h : i < A2.size then
aux (acc.push A2[i]) (i+1)
else
acc
termination_by aux _ i => A2.size - i

instance : Append FloatArray where
append := append

def ofFn (f : Fin n → Float) : FloatArray :=
aux (FloatArray.mkEmpty n) 0
where aux (acc : FloatArray) (i : Nat) : FloatArray :=
if h : i < n then
aux (acc.push (f ⟨i,h⟩)) (i+1)
else
acc
termination_by aux i _ => n - i

instance : Seq FloatArray Float where
size := size
get := get
set := set
empty := empty
insert := push
ofFn := ofFn
snoc := push
append := append

theorem ext (A1 A2 : FloatArray) : A1.data = A2.data → A1 = A2 := by
cases A1; cases A2; simp

theorem ext_iff (A1 A2 : FloatArray) : A1 = A2 ↔ A1.data = A2.data := by
constructor
· rintro rfl; rfl
· apply ext

theorem append.aux_eq_acc_append (A2 : FloatArray) (acc : FloatArray) (i : Nat)
: (aux A2 acc i).data = acc.data ++ (aux A2 FloatArray.empty i).data := by
if i < A2.size then
let j := size A2 - i
have : i = size A2 - j := by
simp; rw [Nat.sub_sub_self]; apply Nat.lt_succ.mp (Nat.le.step _); assumption
rw [this]; clear this
have : j ≤ size A2 := by simp
generalize j = j' at *; clear j; clear! i
induction j' generalizing acc with
| zero => unfold aux; simp [empty, mkEmpty]
| succ j ih =>
unfold aux
have : size A2 - j.succ < size A2 := by omega
have : size A2 - j.succ + 1 = size A2 - j := by omega
simp [*]
have : j ≤ size A2 := Nat.le_of_lt ‹_›
conv => lhs; rw [ih _ ‹_›]
conv => rhs; rw [ih _ ‹_›]
rw [←Array.append_assoc]
congr 1
else
unfold aux
simp [*, empty, mkEmpty]

theorem append.aux_fromEmpty (A : FloatArray)
: (aux A FloatArray.empty 0).data = A.data := by
rcases A with ⟨A⟩
suffices ∀ i, i ≤ A.size → (aux ⟨A⟩ empty (A.size - i)).data = ⟨A.data.drop (A.size - i)⟩ by
have := this A.size (Nat.le_refl _)
simpa using this
intro i hi
induction i with
| zero => unfold aux; simp [size, empty, mkEmpty]; rfl
| succ i ih =>
specialize ih (Nat.le_of_lt hi)
unfold aux
have : A.size - i.succ < A.size := by omega
have : A.size - i.succ + 1 = A.size - i := by omega
simp [size, *]
rw [append.aux_eq_acc_append, ih, Array.ext'_iff]
simp [push, empty, mkEmpty]
have : A.size - i.succ < A.data.length := by omega
conv => rhs; rw [List.drop_eq_get_cons ‹_›]
congr 2
omega

@[simp] theorem data_append (A1 A2 : FloatArray) : (A1 ++ A2).data = A1.data ++ A2.data := by
rcases A1 with ⟨A1⟩
rcases A2 with ⟨A2⟩
simp [instAppendFloatArray, instHAppend, append]
rw [append.aux_eq_acc_append, append.aux_fromEmpty]
rfl


theorem ofFn.aux_spec (f : Fin n → Float) (acc : FloatArray) (i : Nat) (hi : i ≤ n)
: aux f acc i = acc ++ ⟨Array.ofFn (fun (j : Fin (n-i)) => f (j.addNat i |>.cast (by omega)))⟩ := by
have hi' := hi
revert hi acc
apply Nat.decreasingInduction' (n := n) (P := fun i => ∀ acc (hi : i ≤ n), aux f acc i = _)
· intro j jlt _ilej ih acc hi
unfold aux
simp only [jlt, dite_true]
rw [ih _ jlt, ext_iff]
rw [Array.ext'_iff]
simp [push]
rw [←List.ofFn_def, ←List.ofFn_def]
have := List.ofFn_succ (fun (x : Fin (n - j.succ + 1)) => f (x.addNat j |>.cast (by omega)))
convert this.symm using 1 <;> clear this
· simp; constructor
· congr; simp
· funext x
congr 1; simp [Fin.eq_iff_veq]; ring
· apply List.ext_get
· simp; omega
· intro x h1 h2; simp
· assumption
· intro acc _
unfold aux
simp; cases acc; rw [ext_iff]; simp
suffices Array.ofFn _ = #[] by
rw [this]; simp
rw [Array.ext'_iff]
simp
apply List.eq_nil_of_length_eq_zero
simp

-- Not even going to try to write out the lawfulseq instance...

end FloatArray