diff --git a/Cslib.lean b/Cslib.lean index 8905be5f9..be342da0d 100644 --- a/Cslib.lean +++ b/Cslib.lean @@ -1,7 +1,14 @@ module -- shake: keep-all -public import Cslib.Algorithms.Lean.MergeSort.MergeSort -public import Cslib.Algorithms.Lean.TimeM +public import Cslib.AlgorithmsTheory.Algorithms.ListInsertionSort +public import Cslib.AlgorithmsTheory.Algorithms.ListLinearSearch +public import Cslib.AlgorithmsTheory.Algorithms.ListOrderedInsert +public import Cslib.AlgorithmsTheory.Algorithms.MergeSort +public import Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +public import Cslib.AlgorithmsTheory.Lean.TimeM +public import Cslib.AlgorithmsTheory.Models.ListComparisonSearch +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Cslib.AlgorithmsTheory.QueryModel public import Cslib.Computability.Automata.Acceptors.Acceptor public import Cslib.Computability.Automata.Acceptors.OmegaAcceptor public import Cslib.Computability.Automata.DA.Basic diff --git a/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean b/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean new file mode 100644 index 000000000..86c85c3ee --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListInsertionSort.lean @@ -0,0 +1,100 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Algorithms.ListOrderedInsert +public import Mathlib + +@[expose] public section + +/-! +# Insertion sort in a list + +In this file we state and prove the correctness and complexity of insertion sort in lists under +the `SortOps` model. This insertionSort evaluates identically to the upstream version of +`List.insertionSort` +-- + +## Main Definitions + +- `insertionSort` : Insertion sort algorithm in the `SortOps` query model + +## Main results + +- `insertionSort_eval`: `insertionSort` evaluates identically to `List.insertionSort`. +- `insertionSort_permutation` : `insertionSort` outputs a permutation of the input list. +- `insertionSort_sorted` : `insertionSort` outputs a sorted list. +- `insertionSort_complexity` : `insertionSort` takes at most n * (n + 1) comparisons and + (n + 1) * (n + 2) list head-insertions. +-/ + +namespace Cslib + +namespace Algorithms + +open Prog + +/-- The insertionSort algorithms on lists with the `SortOps` query. -/ +def insertionSort (l : List α) : Prog (SortOps α) (List α) := + match l with + | [] => return [] + | x :: xs => do + let rest ← insertionSort xs + insertOrd x rest + +@[simp] +theorem insertionSort_eval (l : List α) (le : α → α → Prop) [DecidableRel le] : + (insertionSort l).eval (sortModel le) = l.insertionSort le := by + induction l with simp_all [insertionSort] + +theorem insertionSort_permutation (l : List α) (le : α → α → Prop) [DecidableRel le] : + ((insertionSort l).eval (sortModel le)).Perm l := by + simp [insertionSort_eval, List.perm_insertionSort] + +theorem insertionSort_sorted + (l : List α) (le : α → α → Prop) [DecidableRel le] [Std.Total le] [IsTrans α le] : + ((insertionSort l).eval (sortModel le)).Pairwise le := by + simpa using List.pairwise_insertionSort _ _ + +lemma insertionSort_length (l : List α) (le : α → α → Prop) [DecidableRel le] : + ((insertionSort l).eval (sortModel le)).length = l.length := by + simp + +lemma insertionSort_time_compares (head : α) (tail : List α) (le : α → α → Prop) [DecidableRel le] : + ((insertionSort (head :: tail)).time (sortModel le)).compares = + ((insertionSort tail).time (sortModel le)).compares + + ((insertOrd head (tail.insertionSort le)).time (sortModel le)).compares := by + simp [insertionSort] + +lemma insertionSort_time_inserts (head : α) (tail : List α) (le : α → α → Prop) [DecidableRel le] : + ((insertionSort (head :: tail)).time (sortModel le)).inserts = + ((insertionSort tail).time (sortModel le)).inserts + + ((insertOrd head (tail.insertionSort le)).time (sortModel le)).inserts := by + simp [insertionSort] + +theorem insertionSort_complexity (l : List α) (le : α → α → Prop) [DecidableRel le] : + ((insertionSort l).time (sortModel le)) + ≤ ⟨l.length * (l.length + 1), (l.length + 1) * (l.length + 2)⟩ := by + induction l with + | nil => + simp [insertionSort] + | cons head tail ih => + have h := insertOrd_complexity_upper_bound (tail.insertionSort le) head le + simp_all only [List.length_cons, List.length_insertionSort] + obtain ⟨ih₁,ih₂⟩ := ih + obtain ⟨h₁,h₂⟩ := h + refine ⟨?_, ?_⟩ + · clear h₂ + rw [insertionSort_time_compares] + nlinarith [ih₁, h₁] + · clear h₁ + rw [insertionSort_time_inserts] + nlinarith [ih₂, h₂] + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean b/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean new file mode 100644 index 000000000..0a1f5c3a9 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListLinearSearch.lean @@ -0,0 +1,83 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Models.ListComparisonSearch +public import Mathlib + +@[expose] public section + +/-! +# Linear search in a list + +In this file we state and prove the correctness and complexity of linear search in lists under +the `ListSearch` model. +-- + +## Main Definitions + +- `listLinearSearch` : Linear search algorithm in the `ListSearch` query model + +## Main results + +- `listLinearSearch_eval`: `insertOrd` evaluates identically to `List.contains`. +- `listLinearSearchM_time_complexity_upper_bound` : `linearSearch` takes at most `n` + comparison operations +- `listLinearSearchM_time_complexity_lower_bound` : There exist lists on which `linearSearch` needs + `n` comparisons +-/ +namespace Cslib + +namespace Algorithms + +open Prog + +open ListSearch in +/-- Linear Search in Lists on top of the `ListSearch` query model. -/ +def listLinearSearch (l : List α) (x : α) : Prog (ListSearch α) Bool := do + match l with + | [] => return false + | l :: ls => + let cmp : Bool ← compare (l :: ls) x + if cmp then + return true + else + listLinearSearch ls x + +@[simp, grind =] +lemma listLinearSearch_eval [BEq α] (l : List α) (x : α) : + (listLinearSearch l x).eval ListSearch.natCost = l.contains x := by + fun_induction l.elem x with simp_all [listLinearSearch] + +lemma listLinearSearchM_correct_true [BEq α] [LawfulBEq α] (l : List α) + {x : α} (x_mem_l : x ∈ l) : (listLinearSearch l x).eval ListSearch.natCost = true := by + simp [x_mem_l] + +lemma listLinearSearchM_correct_false [BEq α] [LawfulBEq α] (l : List α) + {x : α} (x_mem_l : x ∉ l) : (listLinearSearch l x).eval ListSearch.natCost = false := by + simp [x_mem_l] + +lemma listLinearSearchM_time_complexity_upper_bound [BEq α] (l : List α) (x : α) : + (listLinearSearch l x).time ListSearch.natCost ≤ l.length := by + fun_induction l.elem x with + | case1 => simp [listLinearSearch] + | case2 => simp_all [listLinearSearch] + | case3 => + simp_all [listLinearSearch] + grind + +-- This statement is wrong +lemma listLinearSearchM_time_complexity_lower_bound [DecidableEq α] [Nonempty α] : + ∃ l : List α, ∃ x : α, (listLinearSearch l x).time ListSearch.natCost = l.length := by + inhabit α + refine ⟨[], default, ?_⟩ + simp_all [ListSearch.natCost, listLinearSearch] + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean b/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean new file mode 100644 index 000000000..4a0ebfb93 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/ListOrderedInsert.lean @@ -0,0 +1,96 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Mathlib + +@[expose] public section + +/-! +# Ordered insertion in a list + +In this file we state and prove the correctness and complexity of ordered insertions in lists under +the `SortOps` model. This ordered insert is later used in `insertionSort` mirroring the structure +in upstream libraries for the pure lean code versions of these declarations. + +-- + +## Main Definitions + +- `insertOrd` : ordered insert algorithm in the `SortOps` query model + +## Main results + +- `insertOrd_eval`: `insertOrd` evaluates identically to `List.orderedInsert`. +- `insertOrd_complexity_upper_bound` : Shows that `insertOrd` takes at most `n` comparisons, + and `n + 1` list head-insertion operations. +- `insertOrd_sorted` : Applying `insertOrd` to a sorted list yields a sorted list. +-/ + +namespace Cslib +namespace Algorithms + +open Prog + +open SortOps + +/-- +Performs ordered insertion of `x` into a list `l` in the `SortOps` query model. +If `l` is sorted, then `x` is inserted into `l` such that the resultant list is also sorted. +-/ +def insertOrd (x : α) (l : List α) : Prog (SortOps α) (List α) := do + match l with + | [] => insertHead x l + | a :: as => + if (← cmpLE x a : Bool) then + insertHead x (a :: as) + else + let res ← insertOrd x as + insertHead a res + +@[simp] +lemma insertOrd_eval (x : α) (l : List α) (le : α → α → Prop) [DecidableRel le] : + (insertOrd x l).eval (sortModel le) = l.orderedInsert le x := by + induction l with + | nil => + simp [insertOrd, sortModel] + | cons head tail ih => + by_cases h_head : le x head + · simp [insertOrd, h_head] + · simp [insertOrd, h_head, ih] + +-- to upstream +@[simp] +lemma _root_.List.length_orderedInsert (x : α) (l : List α) [DecidableRel r] : + (l.orderedInsert r x).length = l.length + 1 := by + induction l <;> grind + +theorem insertOrd_complexity_upper_bound + (l : List α) (x : α) (le : α → α → Prop) [DecidableRel le] : + (insertOrd x l).time (sortModel le) ≤ ⟨l.length, l.length + 1⟩ := by + induction l with + | nil => + simp [insertOrd, sortModel] + | cons head tail ih => + obtain ⟨ih_compares, ih_inserts⟩ := ih + rw [insertOrd] + by_cases h_head : le x head + · simp [h_head] + · simp [h_head] + grind + +lemma insertOrd_sorted + (l : List α) (x : α) (le : α → α → Prop) [DecidableRel le] [Std.Total le] [IsTrans _ le] : + l.Pairwise le → ((insertOrd x l).eval (sortModel le)).Pairwise le := by + rw [insertOrd_eval] + exact List.Pairwise.orderedInsert _ _ + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean b/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean new file mode 100644 index 000000000..a2a235984 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Algorithms/MergeSort.lean @@ -0,0 +1,207 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Cslib.AlgorithmsTheory.Models.ListComparisonSort +public import Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +import all Cslib.AlgorithmsTheory.Lean.MergeSort.MergeSort +import all Init.Data.List.Sort.Basic +@[expose] public section + +/-! +# Merge sort in a list + +In this file we state and prove the correctness and complexity of merge sort in lists under +the `SortOps` model. +-- + +## Main Definitions +- `merge` : Merge algorithm for merging two sorted lists in the `SortOps` query model +- `mergeSort` : Merge sort algorithm in the `SortOps` query model + +## Main results + +- `mergeSort_eval`: `mergeSort` evaluates identically to the priva. +- `mergeSort_sorted` : `mergeSort` outputs a sorted list. +- `mergeSort_perm` : The output of `mergeSort` is a permutation of the input list +- `mergeSort_complexity` : `mergeSort` takes at most n * ⌈log n⌉ comparisons. +-/ +namespace Cslib.Algorithms + +open SortOpsCmp + +/-- Merge two sorted lists using comparisons in the query monad. -/ +@[simp] +def merge (x y : List α) : Prog (SortOpsCmp α) (List α) := do + match x,y with + | [], ys => return ys + | xs, [] => return xs + | x :: xs', y :: ys' => do + let cmp : Bool ← cmpLE x y + if cmp then + let rest ← merge xs' (y :: ys') + return (x :: rest) + else + let rest ← merge (x :: xs') ys' + return (y :: rest) + +lemma merge_timeComplexity (x y : List α) (le : α → α → Prop) [DecidableRel le] : + (merge x y).time (sortModelNat le) ≤ x.length + y.length := by + fun_induction List.merge x y (le · ·) with + | case1 => simp + | case2 => simp + | case3 x xs y ys hxy ihx => + suffices 1 + (merge xs (y :: ys)).time (sortModelNat le) ≤ xs.length + 1 + (ys.length + 1) by + simpa [hxy] + grind + | case4 x xs y ys hxy ihy => + suffices 1 + (merge (x :: xs) ys).time (sortModelNat le) ≤ xs.length + 1 + (ys.length + 1) by + simpa [hxy] + grind + +@[simp] +lemma merge_eval (x y : List α) (le : α → α → Prop) [DecidableRel le] : + (merge x y).eval (sortModelNat le) = List.merge x y (le · ·) := by + fun_induction List.merge with + | case1 => simp + | case2 => simp + | case3 x xs y ys ihx ihy => simp_all [merge] + | case4 x xs y ys hxy ihx => + rw [decide_eq_true_iff] at hxy + simp_all [merge, -not_le] + +lemma merge_length (x y : List α) (le : α → α → Prop) [DecidableRel le] : + ((merge x y).eval (sortModelNat le)).length = x.length + y.length := by + rw [merge_eval] + apply List.length_merge + +/-- +The `mergeSort` algorithm in the `SortOps` query model. It sorts the input list +according to the mergeSort algorithm. +-/ +def mergeSort (xs : List α) : Prog (SortOpsCmp α) (List α) := do + if xs.length < 2 then return xs + else + let half := xs.length / 2 + let left := xs.take half + let right := xs.drop half + let sortedLeft ← mergeSort left + let sortedRight ← mergeSort right + merge sortedLeft sortedRight + +/-- +The vanilla-lean version of `mergeSortNaive` that is extensionally equal to `mergeSort` +-/ +private def mergeSortNaive (xs : List α) (le : α → α → Prop) [DecidableRel le] : List α := + if xs.length < 2 then xs + else + let sortedLeft := mergeSortNaive (xs.take (xs.length/2)) le + let sortedRight := mergeSortNaive (xs.drop (xs.length/2)) le + List.merge sortedLeft sortedRight (le · ·) + +private proof_wanted mergeSortNaive_eq_mergeSort + [LinearOrder α] (xs : List α) (le : α → α → Prop) [DecidableRel le] : + mergeSortNaive xs le = xs.mergeSort + +private lemma mergeSortNaive_Perm (xs : List α) (le : α → α → Prop) [DecidableRel le] : + (mergeSortNaive xs le).Perm xs := by + fun_induction mergeSortNaive + · simp + · expose_names + rw [←(List.take_append_drop (x.length / 2) x)] + grw [List.merge_perm_append, ← ih1, ← ih2] + +@[simp] +private lemma mergeSort_eval (xs : List α) (le : α → α → Prop) [DecidableRel le] : + (mergeSort xs).eval (sortModelNat le) = mergeSortNaive xs le := by + fun_induction mergeSort with + | case1 xs h => + simp [h, mergeSortNaive, Prog.eval] + | case2 xs h n left right ihl ihr => + rw [mergeSortNaive, if_neg h] + have im := merge_eval left right + simp [ihl, ihr, merge_eval] + rfl + +private lemma mergeSortNaive_length (xs : List α) (le : α → α → Prop) [DecidableRel le] : + (mergeSortNaive xs le).length = xs.length := by + fun_induction mergeSortNaive with + | case1 xs h => + simp + | case2 xs h left right ihl ihr => + rw [List.length_merge] + convert congr($ihl + $ihr) + rw [← List.length_append] + simp + +lemma mergeSort_length (xs : List α) (le : α → α → Prop) [DecidableRel le] : + ((mergeSort xs).eval (sortModelNat le)).length = xs.length := by + rw [mergeSort_eval] + apply mergeSortNaive_length + +lemma merge_sorted_sorted + (xs ys : List α) (le : α → α → Prop) [DecidableRel le] [Std.Total le] [IsTrans _ le] + (hxs_mono : xs.Pairwise le) (hys_mono : ys.Pairwise le) : + ((merge xs ys).eval (sortModelNat le)).Pairwise le := by + rw [merge_eval] + grind [hxs_mono.merge hys_mono] + +private lemma mergeSortNaive_sorted + (xs : List α) (le : α → α → Prop) [DecidableRel le] [Std.Total le] [IsTrans _ le] : + (mergeSortNaive xs le).Pairwise le := by + fun_induction mergeSortNaive with + | case1 xs h => + match xs with | [] | [x] => simp + | case2 xs h left right ihl ihr => + simpa using ihl.merge ihr + +theorem mergeSort_sorted + (xs : List α) (le : α → α → Prop) [DecidableRel le] [Std.Total le] [IsTrans _ le] : + ((mergeSort xs).eval (sortModelNat le)).Pairwise le := by + rw [mergeSort_eval] + apply mergeSortNaive_sorted + +theorem mergeSort_perm (xs : List α) (le : α → α → Prop) [DecidableRel le] : + ((mergeSort xs).eval (sortModelNat le)).Perm xs := by + rw [mergeSort_eval] + apply mergeSortNaive_Perm + +section TimeComplexity + +open Cslib.Algorithms.Lean.TimeM + +-- TODO: reuse the work in `mergeSort_time_le`? +theorem mergeSort_complexity (xs : List α) (le : α → α → Prop) [DecidableRel le] : + (mergeSort xs).time (sortModelNat le) ≤ T (xs.length) := by + fun_induction mergeSort + · simp [T] + · expose_names + simp only [FreeM.bind_eq_bind, Prog.time_bind, mergeSort_eval] + grw [merge_timeComplexity, ih1, ih2, mergeSortNaive_length, mergeSortNaive_length] + set n := x.length + have hleft_len : left.length ≤ n / 2 := by + grind + have hright_len : right.length ≤ (n + 1) / 2 := by + have hright_eq : right.length = n - n / 2 := by + simp [right, n, half, List.length_drop] + rw [hright_eq] + grind + have htleft_len : T left.length ≤ T (n / 2) := T_monotone hleft_len + have htright_len : T right.length ≤ T ((n + 1) / 2) := T_monotone hright_len + grw [htleft_len, htright_len, hleft_len, hright_len] + have hs := some_algebra (n - 2) + have hsub1 : (n - 2) / 2 + 1 = n / 2 := by grind + have hsub2 : 1 + (1 + (n - 2)) / 2 = (n + 1) / 2 := by grind + have hsub3 : (n - 2) + 2 = n := by grind + have hsplit : n / 2 + (n + 1) / 2 = n := by grind + simpa [T, hsub1, hsub2, hsub3, hsplit, Nat.add_assoc, Nat.add_left_comm, Nat.add_comm] + using hs + +end TimeComplexity + +end Cslib.Algorithms diff --git a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean b/Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean similarity index 97% rename from Cslib/Algorithms/Lean/MergeSort/MergeSort.lean rename to Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean index 8ba55d461..081dbf1b7 100644 --- a/Cslib/Algorithms/Lean/MergeSort/MergeSort.lean +++ b/Cslib/AlgorithmsTheory/Lean/MergeSort/MergeSort.lean @@ -6,7 +6,7 @@ Authors: Sorrachai Yingchareonthawornhcai module -public import Cslib.Algorithms.Lean.TimeM +public import Cslib.AlgorithmsTheory.Lean.TimeM public import Mathlib.Data.Nat.Cast.Order.Ring public import Mathlib.Data.Nat.Lattice public import Mathlib.Data.Nat.Log @@ -158,6 +158,10 @@ private lemma some_algebra (n : ℕ) : /-- Upper bound function for merge sort time complexity: `T(n) = n * ⌈log₂ n⌉` -/ abbrev T (n : ℕ) : ℕ := n * clog 2 n +lemma T_monotone : Monotone T := by + intro i j h_ij + exact Nat.mul_le_mul h_ij (Nat.clog_monotone 2 h_ij) + /-- Solve the recurrence -/ theorem timeMergeSortRec_le (n : ℕ) : timeMergeSortRec n ≤ T n := by fun_induction timeMergeSortRec with diff --git a/Cslib/Algorithms/Lean/TimeM.lean b/Cslib/AlgorithmsTheory/Lean/TimeM.lean similarity index 100% rename from Cslib/Algorithms/Lean/TimeM.lean rename to Cslib/AlgorithmsTheory/Lean/TimeM.lean diff --git a/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean b/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean new file mode 100644 index 000000000..888f223bc --- /dev/null +++ b/Cslib/AlgorithmsTheory/Models/ListComparisonSearch.lean @@ -0,0 +1,54 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel +public import Mathlib + +@[expose] public section + +/-! +# Query Type for Comparison Search in Lists + +In this file we define a query type `ListSearch` for comparison based searching in Lists, +whose sole query `compare` compares the head of the list with a given argument. It +further defines a model `ListSearch.natCost` for this query. + +-- +## Definitions + +- `ListSearch`: A query type for comparison based search in lists. +- `ListSearch.natCost`: A model for this query with costs in `ℕ`. + +-/ + +namespace Cslib + +namespace Algorithms + +open Prog + +/-- +A query type for searching elements in list. It supports exactly one query +`compare l val` which returns `true` if the head of the list `l` is equal to `val` +and returns `false` otherwise. +-/ +inductive ListSearch (α : Type*) : Type → Type _ where + | compare (a : List α) (val : α) : ListSearch α Bool + + +/-- A model of the `ListSearch` query type that assigns the cost as the number of queries. -/ +@[simps] +def ListSearch.natCost [BEq α] : Model (ListSearch α) ℕ where + evalQuery + | .compare l x => some x == l.head? + cost + | .compare _ _ => 1 + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean b/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean new file mode 100644 index 000000000..5aad6b123 --- /dev/null +++ b/Cslib/AlgorithmsTheory/Models/ListComparisonSort.lean @@ -0,0 +1,146 @@ +/- +Copyright (c) 2026 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas, Eric WIeser +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel + +@[expose] public section + +/-! +# Query Type for Comparison Search in Lists + +In this file we define two query types `SortOps` which is suitable for insertion sort, and +`SortOps`for comparison based searching in Lists. We define a model `sortModel` for `SortOps` +which uses a custom cost structure `SortOpsCost`. We define a model `sortModelCmp` for `SortOpsCmp` +which defines a `ℕ` based cost structure. +-- +## Definitions + +- `SortOps`: A query type for comparison based sorting in lists which includes queries for + comparison and head-insertion into Lists. This is a suitable query for ordered insertion + and insertion sort. +- `SortOpsCmp`: A query type for comparison based sorting that only includes a comparison query. + This is more suitable for comparison based sorts for which it is only desirable to count + comparisons + +-/ +namespace Cslib + +namespace Algorithms + +open Prog + +/-- +A model for comparison sorting on lists. +-/ +inductive SortOps (α : Type) : Type → Type where + /-- `cmpLE x y` is intended to return `true` if `x ≤ y` and `false` otherwise. + The specific order relation depends on the model provided for this typ. e-/ + | cmpLE (x : α) (y : α) : SortOps α Bool + /-- `insertHead l x` is intended to return `x :: l`. -/ + | insertHead (x : α) (l : List α) : SortOps α (List α) + +open SortOps + +section SortOpsCostModel + +/-- +A cost type for counting the operations of `SortOps` with separate fields for +counting calls to `cmpLT` and `insertHead` +-/ +@[ext, grind] +structure SortOpsCost where + /-- `compares` counts the number of calls to `cmpLT` -/ + compares : ℕ + /-- `inserts` counts the number of calls to `insertHead` -/ + inserts : ℕ + +/-- Equivalence between SortOpsCost and a product type. -/ +def SortOpsCost.equivProd : SortOpsCost ≃ (ℕ × ℕ) where + toFun sortOps := (sortOps.compares, sortOps.inserts) + invFun pair := ⟨pair.1, pair.2⟩ + left_inv _ := rfl + right_inv _ := rfl + +namespace SortOpsCost + +@[simps, grind] +instance : Zero SortOpsCost := ⟨0, 0⟩ + +@[simps] +instance : LE SortOpsCost where + le soc₁ soc₂ := soc₁.compares ≤ soc₂.compares ∧ soc₁.inserts ≤ soc₂.inserts + +instance : LT SortOpsCost where + lt soc₁ soc₂ := soc₁ ≤ soc₂ ∧ ¬soc₂ ≤ soc₁ + +@[grind] +instance : PartialOrder SortOpsCost := + fast_instance% SortOpsCost.equivProd.injective.partialOrder _ .rfl .rfl + +@[simps] +instance : Add SortOpsCost where + add soc₁ soc₂ := ⟨soc₁.compares + soc₂.compares, soc₁.inserts + soc₂.inserts⟩ + +@[simps] +instance : SMul ℕ SortOpsCost where + smul n soc := ⟨n • soc.compares, n • soc.inserts⟩ + +instance : AddCommMonoid SortOpsCost := + fast_instance% + SortOpsCost.equivProd.injective.addCommMonoid _ rfl (fun _ _ => rfl) (fun _ _ => rfl) + +end SortOpsCost + +/-- +A model of `SortOps` that uses `SortOpsCost` as the cost type for operations. + +While this accepts any decidable relation `le`, most sorting algorithms are only well-behaved in the +presence of `[Std.Total le] [IsTrans _ le]`. +-/ +@[simps, grind] +def sortModel {α : Type} (le : α → α → Prop) [DecidableRel le] : Model (SortOps α) SortOpsCost where + evalQuery + | .cmpLE x y => decide (le x y) + | .insertHead x l => x :: l + cost + | .cmpLE _ _ => ⟨1,0⟩ + | .insertHead _ _ => ⟨0,1⟩ + +end SortOpsCostModel + +section NatModel + +/-- +A model for comparison sorting on lists with only the comparison operation. This +is used in mergeSort. +-/ +inductive SortOpsCmp.{u} (α : Type u) : Type → Type _ where + /-- `cmpLE x y` is intended to return `true` if `x ≤ y` and `false` otherwise. + The specific order relation depends on the model provided for this type. -/ + | cmpLE (x : α) (y : α) : SortOpsCmp α Bool + +/-- +A model of `SortOps` that uses `ℕ` as the type for the cost of operations. In this model, +both comparisons and insertions are counted in a single `ℕ` parameter. + +While this accepts any decidable relation `le`, most sorting algorithms are only well-behaved in the +presence of `[Std.Total le] [IsTrans _ le]`. +-/ +@[simps] +def sortModelNat {α : Type*} + (le : α → α → Prop) [DecidableRel le] : Model (SortOpsCmp α) ℕ where + evalQuery + | .cmpLE x y => decide (le x y) + cost + | .cmpLE _ _ => 1 + +end NatModel + +end Algorithms + +end Cslib diff --git a/Cslib/AlgorithmsTheory/QueryModel.lean b/Cslib/AlgorithmsTheory/QueryModel.lean new file mode 100644 index 000000000..319807c05 --- /dev/null +++ b/Cslib/AlgorithmsTheory/QueryModel.lean @@ -0,0 +1,150 @@ +/- +Copyright (c) 2025 Tanner Duve. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Tanner Duve, Shreyas Srinivas, Eric Wieser +-/ + +module + +public import Mathlib +public import Cslib.Foundations.Control.Monad.Free.Fold +public import Cslib.AlgorithmsTheory.Lean.TimeM + +@[expose] public section + +/- +# Query model + +This file defines a simple query language modeled as a free monad over a +parametric type of query operations. + +## Main definitions + +- `Model Q c`: A model type for a query type `Q : Type u → Type u` and cost type `c` +- `Prog Q α`: The type of programs of query type `Q` and return type `α`. + This is a free monad under the hood +- `Prog.eval`, `Prog.time`: concrete execution semantics of a `Prog Q α` for a given model of `Q` + +## How to set up an algorithm + +This model is a lightweight framework for specifying and verifying both the correctness +and complexity of algorithms in lean. To specify an algorithm, one must: +1. Define an inductive type of queries. This type must at least one index parameter + which determines the output type of the query. Additionally, it helps to have a parameter `α` + on which the index type depends. This way, any instance parameters of `α` can be used easily + for the output types. The signatures of `Model.evalQuery` and `Model.cost` are fixed. + So you can't supply instances for the index type there. +2. Define a record of the `Model Q C` structure that specifies the evaluation and time (cost) of + each query +3. Write your algorithm as a monadic program in `Prog Q α`. With sufficient type anotations + each query `q : Q` is automatically lifted into `Prog Q α`. + +## Tags +query model, free monad, time complexity, Prog +-/ + +namespace Cslib + +namespace Algorithms + +/-- +A model type for a query type `QType` and cost type `Cost`. It consists of +two fields, which respectively define the evaluation and cost of a query. +-/ +structure Model (QType : Type u → Type v) (Cost : Type w) where + /-- Evaluates a query `q : Q ι` to return a result of type `ι`. -/ + evalQuery : QType ι → ι + /-- Counts the operational cost of a query `q : Q ι` to return a result of type `Cost`. + The cost could represent any desired complexity measure, + including but not limited to time complexity. -/ + cost : QType ι → Cost + + +open Cslib.Algorithms.Lean in +/-- lift `Model.cost` to `TimeM Cost ι` -/ +abbrev Model.timeQuery + (M : Model Q Cost) (x : Q ι) : TimeM Cost ι := + TimeM.mk (M.evalQuery x) (M.cost x) + +/-- +A program is defined as a Free Monad over a Query type `Q` which operates on a base type `α` +which can determine the input and output types of a query. +-/ +abbrev Prog Q α := FreeM Q α + +/-- +The evaluation function of a program `P : Prog Q α` given a model `M : Model Q α` of `Q` +-/ +def Prog.eval + (P : Prog Q α) (M : Model Q Cost) : α := + Id.run <| P.liftM fun x => pure (M.evalQuery x) + +@[simp, grind =] +theorem Prog.eval_pure (a : α) (M : Model Q Cost) : + Prog.eval (FreeM.pure a) M = a := + rfl + +@[simp, grind =] +theorem Prog.eval_bind + (x : Prog Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.eval (FreeM.bind x f) M = Prog.eval (f (x.eval M)) M := by + simp [Prog.eval] + +@[simp, grind =] +theorem Prog.eval_liftBind + (x : Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.eval (FreeM.liftBind x f) M = Prog.eval (f <| M.evalQuery x) M := by + simp [Prog.eval] + +/-- +The cost function of a program `P : Prog Q α` given a model `M : Model Q α` of `Q`. +The most common use case of this function is to compute time-complexity, hence the name. + +In practice this is only well-behaved in the presence of `AddCommMonoid Cost`. +-/ +def Prog.time [AddZero Cost] + (P : Prog Q α) (M : Model Q Cost) : Cost := + (P.liftM M.timeQuery).time + +@[simp, grind =] +lemma Prog.time_pure [AddZero Cost] (a : α) (M : Model Q Cost) : + Prog.time (FreeM.pure a) M = 0 := by + simp [time] + +@[simp, grind =] +theorem Prog.time_liftBind [AddZero Cost] + (x : Q α) (f : α → Prog Q β) (M : Model Q Cost) : + Prog.time (FreeM.liftBind x f) M = M.cost x + Prog.time (f <| M.evalQuery x) M := by + simp [Prog.time] + +@[simp, grind =] +lemma Prog.time_bind [AddCommMonoid Cost] (M : Model Q Cost) + (op : Prog Q ι) (cont : ι → Prog Q α) : + Prog.time (op.bind cont) M = + Prog.time op M + Prog.time (cont (Prog.eval op M)) M := by + simp only [eval, time] + induction op with + | pure a => + simp + | liftBind op cont' ih => + specialize ih (M.evalQuery op) + simp_all [add_assoc] + +section Reduction + +/-- A reduction structure from query type `Q₁` to query type `Q₂`. -/ +structure Reduction (Q₁ Q₂ : Type u → Type u) where + /-- `reduce (q : Q₁ α)` is a program `P : Prog Q₂ α` that is intended to + implement `q` in the query type `Q₂` -/ + reduce : Q₁ α → Prog Q₂ α + +/-- +`Prog.reduceProg` takes a reduction structure from a query `Q₁` to `Q₂` and extends its +`reduce` function to programs on the query type `Q₁`. +-/ +abbrev Prog.reduceProg (P : Prog Q₁ α) (red : Reduction Q₁ Q₂) : Prog Q₂ α := + P.liftM red.reduce + +end Reduction + +end Cslib.Algorithms diff --git a/Cslib/Foundations/Control/Monad/Free.lean b/Cslib/Foundations/Control/Monad/Free.lean index 9cf40c322..d27476d40 100644 --- a/Cslib/Foundations/Control/Monad/Free.lean +++ b/Cslib/Foundations/Control/Monad/Free.lean @@ -96,7 +96,7 @@ variable {F : Type u → Type v} {ι : Type u} {α : Type w} {β : Type w'} {γ instance : Pure (FreeM F) where pure := .pure -@[simp] +@[simp, grind =] theorem pure_eq_pure : (pure : α → FreeM F α) = FreeM.pure := rfl /-- Bind operation for the `FreeM` monad. -/ @@ -115,7 +115,7 @@ protected theorem bind_assoc (x : FreeM F α) (f : α → FreeM F β) (g : β instance : Bind (FreeM F) where bind := .bind -@[simp] +@[simp, grind =] theorem bind_eq_bind {α β : Type w} : Bind.bind = (FreeM.bind : FreeM F α → _ → FreeM F β) := rfl /-- Map a function over a `FreeM` monad. -/ @@ -154,14 +154,21 @@ lemma map_lift (f : ι → α) (op : F ι) : map f (lift op : FreeM F ι) = liftBind op (fun z => (.pure (f z) : FreeM F α)) := rfl /-- `.pure a` followed by `bind` collapses immediately. -/ -@[simp] +@[simp, grind =] lemma pure_bind (a : α) (f : α → FreeM F β) : (.pure a : FreeM F α).bind f = f a := rfl -@[simp] +@[simp, grind =] +lemma pure_bind' {α β} (a : α) (f : α → FreeM F β) : (.pure a : FreeM F α) >>= f = f a := + pure_bind a f + +@[simp, grind =] lemma bind_pure : ∀ x : FreeM F α, x.bind (.pure) = x | .pure a => rfl | liftBind op k => by simp [FreeM.bind, bind_pure] +@[simp, grind =] +lemma bind_pure' : ∀ x : FreeM F α, x >>= .pure = x := bind_pure + @[simp] lemma bind_pure_comp (f : α → β) : ∀ x : FreeM F α, x.bind (.pure ∘ f) = map f x | .pure a => rfl @@ -223,6 +230,9 @@ lemma liftM_bind [LawfulMonad m] rw [FreeM.bind, liftM_liftBind, liftM_liftBind, bind_assoc] simp_rw [ih] +instance {Q α} : CoeOut (Q α) (FreeM Q α) where + coe := FreeM.lift + /-- A predicate stating that `interp : FreeM F α → m α` is an interpreter for the effect handler `handler : ∀ {α}, F α → m α`. diff --git a/CslibTests.lean b/CslibTests.lean index 73292aef3..c1c44021a 100644 --- a/CslibTests.lean +++ b/CslibTests.lean @@ -11,4 +11,6 @@ public import CslibTests.HasFresh public import CslibTests.ImportWithMathlib public import CslibTests.LTS public import CslibTests.LambdaCalculus +public import CslibTests.QueryModel.ProgExamples +public import CslibTests.QueryModel.QueryExamples public import CslibTests.Reduction diff --git a/CslibTests/QueryModel/ProgExamples.lean b/CslibTests/QueryModel/ProgExamples.lean new file mode 100644 index 000000000..18c70e985 --- /dev/null +++ b/CslibTests/QueryModel/ProgExamples.lean @@ -0,0 +1,122 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel + +@[expose] public section + +namespace Cslib + +namespace Algorithms + +namespace Prog + +section ProgExamples + +inductive Arith (α : Type u) : Type u → Type _ where + | add (x y : α) : Arith α α + | mul (x y : α) : Arith α α + | neg (x : α) : Arith α α + | zero : Arith α α + | one : Arith α α + +def Arith.natCost [Ring α] : Model (Arith α) ℕ where + evalQuery + | .add x y => x + y + | .mul x y => x * y + | .neg x => -x + | .zero => 0 + | .one => 1 + cost _ := 1 + +open Arith in +def ex1 : Prog (Arith α) α := do + let mut x : α ← @zero α + let mut y ← @one α + let z ← (add x y) + let w ← @neg α (← add z y) + add w z + +/-- The array version of the sort operations. -/ +inductive VecSortOps.{u} (α : Type u) : Type u → Type _ where + | swap (a : Vector α n) (i j : Fin n) : VecSortOps α (Vector α n) + -- Note that we have to ULift the result to fit this in the same universe as the other types. + -- We can avoid this only by forcing everything to be in `Type 0`. + | cmp (a : Vector α n) (i j : Fin n) : VecSortOps α (ULift Bool) + | write (a : Vector α n) (i : Fin n) (x : α) : VecSortOps α (Vector α n) + | read (a : Vector α n) (i : Fin n) : VecSortOps α α + | push (a : Vector α n) (elem : α) : VecSortOps α (Vector α (n + 1)) + +/-- The typical means of evaluating a `VecSortOps`. -/ +@[simp] +def VecSortOps.eval [BEq α] : VecSortOps α β → β + | .write v i x => v.set i x + | .cmp l i j => .up <| l[i] == l[j] + | .read l i => l[i] + | .swap l i j => l.swap i j + | .push a elem => a.push elem + +@[simps] +def VecSortOps.worstCase [DecidableEq α] : Model (VecSortOps α) ℕ where + evalQuery := VecSortOps.eval + cost + | .write _ _ _ => 1 + | .read _ _ => 1 + | .cmp _ _ _ => 1 + | .swap _ _ _ => 1 + | .push _ _ => 2 -- amortized over array insertion and resizing by doubling + +@[simps] +def VecSortOps.cmpSwap [DecidableEq α] : Model (VecSortOps α) ℕ where + evalQuery := VecSortOps.eval + cost + | .cmp _ _ _ => 1 + | .swap _ _ _ => 1 + | _ => 0 + +open VecSortOps in +def simpleExample (v : Vector ℤ n) (i k : Fin n) : + Prog (VecSortOps ℤ) (Vector ℤ (n + 1)) := do + let b : Vector ℤ n ← write v i 10 + let mut c : Vector ℤ n ← swap b i k + let elem ← read c i + push c elem + +inductive VecSearch (α : Type u) : Type → Type _ where + | compare (a : Vector α n) (i : ℕ) (val : α) : VecSearch α Bool + +@[simps] +def VecSearch.nat [DecidableEq α] : Model (VecSearch α) ℕ where + evalQuery + | .compare l i x => l[i]? == some x + cost + | .compare _ _ _ => 1 + +open VecSearch in +def linearSearchAux (v : Vector α n) + (x : α) (acc : Bool) (index : ℕ) : Prog (VecSearch α) Bool := do + if h : index ≥ n then + return acc + else + let cmp_res : Bool ← compare v index x + if cmp_res then + return true + else + linearSearchAux v x false (index + 1) + +open VecSearch in +def linearSearch (v : Vector α n) (x : α) : Prog (VecSearch α) Bool:= + linearSearchAux v x false 0 + +end ProgExamples + +end Prog + +end Algorithms + +end Cslib diff --git a/CslibTests/QueryModel/QueryExamples.lean b/CslibTests/QueryModel/QueryExamples.lean new file mode 100644 index 000000000..6d9b11c41 --- /dev/null +++ b/CslibTests/QueryModel/QueryExamples.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2025 Shreyas Srinivas. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Shreyas Srinivas +-/ + +module + +public import Cslib.AlgorithmsTheory.QueryModel + + +@[expose] public section + +namespace Cslib + +namespace Algorithms + +section Examples + +/-- +ListOps provides an example of list query type equipped with a `find` query. +The complexity of this query depends on the search algorithm used. This means +we can define two separate models for modelling situations where linear search +or binary search is used. +-/ +inductive ListOps (α : Type u) : Type u → Type _ where + | get (l : List α) (i : Fin l.length) : ListOps α α + | find (l : List α) (elem : α) : ListOps α (ULift ℕ) + | write (l : List α) (i : Fin l.length) (x : α) : ListOps α (List α) + +/-- The typical means of evaluating a `ListOps`. -/ +@[simp] +def ListOps.eval [BEq α] : ListOps α ι → ι + | .write l i x => l.set i x + | .find l elem => l.findIdx (· == elem) + | .get l i => l[i] + +@[simps] +def ListOps.linSearchWorstCase [DecidableEq α] : Model (ListOps α) ℕ where + evalQuery := ListOps.eval + cost + | .write l _ _ => l.length + | .find l _ => l.length + | .get l _ => l.length + +def ListOps.binSearchWorstCase [BEq α] : Model (ListOps α) ℕ where + evalQuery := ListOps.eval + cost + | .find l _ => 1 + Nat.log 2 (l.length) + | .write l _ _ => l.length + | .get l _ => l.length + +inductive ArrayOps (α : Type u) : Type u → Type _ where + | get (l : Array α) (i : Fin l.size) : ArrayOps α α + | find (l : Array α) (x : α) : ArrayOps α (ULift ℕ) + | write (l : Array α) (i : Fin l.size) (x : α) : ArrayOps α (Array α) + +/-- The typical means of evaluating a `ListOps`. -/ +@[simp] +def ArrayOps.eval [BEq α] : ArrayOps α ι → ι + | .write l i x => l.set i x + | .find l elem => l.findIdx (· == elem) + | .get l i => l[i] + +@[simps] +def ArrayOps.binSearchWorstCase [BEq α] : Model (ArrayOps α) ℕ where + evalQuery := ArrayOps.eval + cost + | .find l _ => 1 + Nat.log 2 (l.size) + | .write _ _ _ => 1 + | .get _ _ => 1 + +end Examples + +end Algorithms + +end Cslib