From 869971128697adecb4eeea604458ad5fae8af87b Mon Sep 17 00:00:00 2001 From: Hongbo Zhang Date: Mon, 23 Mar 2026 14:43:21 +0800 Subject: [PATCH 1/2] perf(sorted_set): rewrite set algebra ops with split/join tree recursion Rewrite intersection, difference, and symmetric_difference to use the split/join approach instead of naive contains+add loops. Before: intersection/difference were O(n * log(m)) where n is the size of one set and m is the other, using contains() + add() per element which each do O(log m) work including tree rebalancing. After: All three operations use split_member + join/join2 tree recursion, achieving O(n * log(m/n + 1)) which is optimal for set operations on balanced trees. Also fix union to track size through recursion instead of doing a full O(n) tree traversal to recount after construction. Helper functions added: - split_member: like split but also reports if pivot was found - join2: join two trees where all left < all right (no pivot) - split_min: extract minimum element from a tree - tree_count: count nodes in a tree Co-Authored-By: Claude Opus 4.6 (1M context) --- sorted_set/set.mbt | 151 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 129 insertions(+), 22 deletions(-) diff --git a/sorted_set/set.mbt b/sorted_set/set.mbt index 12c72a0fd..89d708c86 100644 --- a/sorted_set/set.mbt +++ b/sorted_set/set.mbt @@ -68,6 +68,15 @@ fn[V] copy_tree(node : Node[V]?) -> Node[V]? { } } +///| +/// Counts the number of nodes in a tree. +fn[V] tree_count(node : Node[V]?) -> Int { + match node { + None => 0 + Some(n) => 1 + tree_count(n.left) + tree_count(n.right) + } +} + ///| fn[V] new_node( value : V, @@ -144,15 +153,17 @@ pub fn[V : Compare] SortedSet::union( self : SortedSet[V], src : SortedSet[V], ) -> SortedSet[V] { - fn aux(a : Node[V]?, b : Node[V]?) -> Node[V]? { + fn aux(a : Node[V]?, b : Node[V]?) 
-> (Node[V]?, Int) { match (a, b) { - (Some(_), None) => a - (None, Some(_)) => b + (Some(_), None) => (a, tree_count(a)) + (None, Some(_)) => (b, tree_count(b)) (Some({ value: va, left: la, right: ra, .. }), Some(_)) => { let (l, r) = split(b, va) - Some(join(aux(la, l), va, aux(ra, r))) + let (left, lsize) = aux(la, l) + let (right, rsize) = aux(ra, r) + (Some(join(left, va, right)), lsize + 1 + rsize) } - (None, None) => None + (None, None) => (None, 0) } } @@ -160,13 +171,8 @@ pub fn[V : Compare] SortedSet::union( (Some(_), Some(_)) => { let t1 = copy_tree(self.root) let t2 = copy_tree(src.root) - let t = aux(t1, t2) - let mut ct = 0 - let ret = { root: t, size: 0 } - // TODO: optimize this. Avoid counting the size of the set. - ret.each(_x => ct = ct + 1) - ret.size = ct - ret + let (t, ct) = aux(t1, t2) + { root: t, size: ct } } (Some(_), None) => { root: copy_tree(self.root), size: self.size } (None, Some(_)) => { root: copy_tree(src.root), size: src.size } @@ -193,6 +199,54 @@ fn[V : Compare] split(root : Node[V]?, value : V) -> (Node[V]?, Node[V]?) { } } +///| +/// Like `split`, but also returns whether the pivot value was found in the tree. +fn[V : Compare] split_member( + root : Node[V]?, + value : V, +) -> (Node[V]?, Bool, Node[V]?) { + match root { + None => (None, false, None) + Some(node) => { + let comp = value.compare(node.value) + if comp == 0 { + (node.left, true, node.right) + } else if comp < 0 { + let (l, found, r) = split_member(node.left, value) + (l, found, Some(join(r, node.value, node.right))) + } else { + let (l, found, r) = split_member(node.right, value) + (Some(join(node.left, node.value, l)), found, r) + } + } + } +} + +///| +/// Joins two trees where all elements in `left` are less than all elements in `right`. +fn[V] join2(left : Node[V]?, right : Node[V]?) -> Node[V]? 
{ + match (left, right) { + (None, _) => right + (_, None) => left + _ => { + let (min, right2) = split_min(right.unwrap()) + Some(join(left, min, right2)) + } + } +} + +///| +/// Removes and returns the minimum element from a non-empty tree. +fn[V] split_min(node : Node[V]) -> (V, Node[V]?) { + match node.left { + None => (node.value, node.right) + Some(left) => { + let (min, new_left) = split_min(left) + (min, Some(join(new_left, node.value, node.right))) + } + } +} + ///| fn[V] join(left : Node[V]?, value : V, right : Node[V]?) -> Node[V] { let (hl, hr) = (height(left), height(right)) @@ -262,9 +316,27 @@ pub fn[V : Compare] SortedSet::difference( self : SortedSet[V], src : SortedSet[V], ) -> SortedSet[V] { - let ret = new() - self.each(x => if !src.contains(x) { ret.add(x) }) - ret + fn aux(a : Node[V]?, b : Node[V]?) -> (Node[V]?, Int) { + match (a, b) { + (None, _) => (None, 0) + (Some(_), None) => (a, tree_count(a)) + (Some({ value: va, left: la, right: ra, .. }), Some(_)) => { + let (lb, found, rb) = split_member(b, va) + let (left, lsize) = aux(la, lb) + let (right, rsize) = aux(ra, rb) + if found { + (join2(left, right), lsize + rsize) + } else { + (Some(join(left, va, right)), lsize + 1 + rsize) + } + } + } + } + + let t1 = copy_tree(self.root) + let t2 = copy_tree(src.root) + let (t, ct) = aux(t1, t2) + { root: t, size: ct } } ///| @@ -294,10 +366,28 @@ pub fn[V : Compare] SortedSet::symmetric_difference( self : SortedSet[V], other : SortedSet[V], ) -> SortedSet[V] { - // TODO: Optimize this function to avoid creating two intermediate sets. - let set1 = self.difference(other) - let set2 = other.difference(self) - set1.union(set2) + fn aux(a : Node[V]?, b : Node[V]?) -> (Node[V]?, Int) { + match (a, b) { + (None, Some(_)) => (b, tree_count(b)) + (Some(_), None) => (a, tree_count(a)) + (Some({ value: va, left: la, right: ra, .. 
}), Some(_)) => { + let (lb, found, rb) = split_member(b, va) + let (left, lsize) = aux(la, lb) + let (right, rsize) = aux(ra, rb) + if found { + (join2(left, right), lsize + rsize) + } else { + (Some(join(left, va, right)), lsize + 1 + rsize) + } + } + (None, None) => (None, 0) + } + } + + let t1 = copy_tree(self.root) + let t2 = copy_tree(other.root) + let (t, ct) = aux(t1, t2) + { root: t, size: ct } } ///| @@ -307,9 +397,26 @@ pub fn[V : Compare] SortedSet::intersection( self : SortedSet[V], src : SortedSet[V], ) -> SortedSet[V] { - let ret = new() - self.each(x => if src.contains(x) { ret.add(x) }) - ret + fn aux(a : Node[V]?, b : Node[V]?) -> (Node[V]?, Int) { + match (a, b) { + (None, _) | (_, None) => (None, 0) + (Some({ value: va, left: la, right: ra, .. }), Some(_)) => { + let (lb, found, rb) = split_member(b, va) + let (left, lsize) = aux(la, lb) + let (right, rsize) = aux(ra, rb) + if found { + (Some(join(left, va, right)), lsize + 1 + rsize) + } else { + (join2(left, right), lsize + rsize) + } + } + } + } + + let t1 = copy_tree(self.root) + let t2 = copy_tree(src.root) + let (t, ct) = aux(t1, t2) + { root: t, size: ct } } ///| From 21fbe71980f40736849fe109adbcd5f1599f8a65 Mon Sep 17 00:00:00 2001 From: Hongbo Zhang Date: Mon, 23 Mar 2026 15:08:33 +0800 Subject: [PATCH 2/2] test(sorted_set): add size verification tests for set algebra operations Add tests that verify that the .length() of the results from intersection, difference, union, and symmetric_difference is correctly tracked through the split/join recursion. Also add tests for empty set edge cases in intersection and difference. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- sorted_set/set_test.mbt | 52 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/sorted_set/set_test.mbt b/sorted_set/set_test.mbt index a69ec7bc5..472ea48c2 100644 --- a/sorted_set/set_test.mbt +++ b/sorted_set/set_test.mbt @@ -343,3 +343,55 @@ test "@sorted_set.symmetric_difference/identical" { let result = set1.symmetric_difference(set2) inspect(result, content="@sorted_set.from_array([])") } + +///| +test "intersection with empty set" { + let set1 = @sorted_set.from_array([1, 2, 3]) + let empty : @sorted_set.SortedSet[Int] = @sorted_set.new() + inspect(set1.intersection(empty), content="@sorted_set.from_array([])") + inspect(empty.intersection(set1), content="@sorted_set.from_array([])") +} + +///| +test "intersection size is correct" { + let set1 = @sorted_set.from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + let set2 = @sorted_set.from_array([5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) + let result = set1.intersection(set2) + inspect(result, content="@sorted_set.from_array([5, 6, 7, 8, 9, 10])") + assert_eq(result.length(), 6) +} + +///| +test "difference with empty set" { + let set1 = @sorted_set.from_array([1, 2, 3]) + let empty : @sorted_set.SortedSet[Int] = @sorted_set.new() + inspect(set1.difference(empty), content="@sorted_set.from_array([1, 2, 3])") + inspect(empty.difference(set1), content="@sorted_set.from_array([])") +} + +///| +test "difference size is correct" { + let set1 = @sorted_set.from_array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + let set2 = @sorted_set.from_array([5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) + let result = set1.difference(set2) + inspect(result, content="@sorted_set.from_array([1, 2, 3, 4])") + assert_eq(result.length(), 4) +} + +///| +test "union size is correct" { + let set1 = @sorted_set.from_array([1, 2, 3, 4, 5]) + let set2 = @sorted_set.from_array([3, 4, 5, 6, 7]) + let result = set1.union(set2) + inspect(result, content="@sorted_set.from_array([1, 2, 3, 4, 5, 
6, 7])") + assert_eq(result.length(), 7) +} + +///| +test "symmetric_difference size is correct" { + let set1 = @sorted_set.from_array([1, 2, 3, 4, 5]) + let set2 = @sorted_set.from_array([3, 4, 5, 6, 7]) + let result = set1.symmetric_difference(set2) + inspect(result, content="@sorted_set.from_array([1, 2, 6, 7])") + assert_eq(result.length(), 4) +}