Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions nemo/utils/sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,84 @@
PACKING_ALGOS = ["first_fit_decreasing", "first_fit_shuffle"]


class _SegmentTree:
    """Max segment tree over bin remaining-capacities for first-fit queries.

    Each leaf stores the remaining capacity of one bin. Every internal node
    stores the maximum of its two children, so the *leftmost* bin with
    remaining capacity >= s is found with a single root-to-leaf descent.

    Leaves belonging to bins that have not been opened yet hold 0; the public
    methods only ever report or modify bins that have actually been opened.
    """

    __slots__ = ("_n", "_tree", "_num_bins")

    def __init__(self, capacity: int) -> None:
        """Create a tree with room for up to *capacity* bins, none of them opened."""
        self._n = capacity
        self._tree = [0] * (4 * capacity)
        self._num_bins = 0

    # -- internal helpers -----------------------------------------------------

    def _push_up(self, node: int) -> None:
        """Recompute *node* as the maximum of its two children."""
        self._tree[node] = max(self._tree[2 * node], self._tree[2 * node + 1])

    def _update(self, node: int, lo: int, hi: int, idx: int, val: int) -> None:
        """Set leaf *idx* to *val* and refresh the maxima on the path back to *node*."""
        if lo == hi:
            self._tree[node] = val
            return
        mid = (lo + hi) // 2
        if idx <= mid:
            self._update(2 * node, lo, mid, idx, val)
        else:
            self._update(2 * node + 1, mid + 1, hi, idx, val)
        self._push_up(node)

    def _query(self, node: int, lo: int, hi: int, s: int) -> int:
        """Return the leftmost leaf under *node* with value >= *s*.

        The caller must have checked that the value stored at *node* is >= *s*.
        """
        if lo == hi:
            return lo
        mid = (lo + hi) // 2
        # Prefer the left subtree whenever it can satisfy the request: that is
        # what makes the answer the *first* fitting bin.
        if self._tree[2 * node] >= s:
            return self._query(2 * node, lo, mid, s)
        return self._query(2 * node + 1, mid + 1, hi, s)

    def _get_leaf(self, idx: int) -> int:
        """Return the value currently stored in leaf *idx*."""
        node, lo, hi = 1, 0, self._n - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if idx <= mid:
                node, hi = 2 * node, mid
            else:
                node, lo = 2 * node + 1, mid + 1
        return self._tree[node]

    # -- public API -----------------------------------------------------------

    def open_bin(self, remaining: int) -> int:
        """Open a new bin with *remaining* free capacity. Returns its 0-based index.

        Raises:
            IndexError: If every slot of the tree already holds an opened bin.
        """
        idx = self._num_bins
        if idx >= self._n:
            # Without this check the descent would overwrite the last leaf
            # (or never terminate for a zero-capacity tree).
            raise IndexError(f"cannot open bin {idx}: tree capacity is {self._n}")
        self._num_bins += 1
        self._update(1, 0, self._n - 1, idx, remaining)
        return idx

    def query(self, s: int) -> int:
        """Return the index of the leftmost opened bin with remaining capacity >= *s*, or -1."""
        if self._num_bins == 0 or self._tree[1] < s:
            return -1
        idx = self._query(1, 0, self._n - 1, s)
        # Unopened leaves hold 0 and therefore satisfy ``>= s`` when s <= 0.
        # Opened bins occupy the lowest indices and the descent returns the
        # leftmost match, so landing on an unopened leaf means no opened bin fits.
        return idx if idx < self._num_bins else -1

    def update(self, idx: int, amount: int) -> None:
        """Decrease the remaining capacity of bin *idx* by *amount*.

        Raises:
            IndexError: If *idx* does not refer to an opened bin.
        """
        if not 0 <= idx < self._num_bins:
            raise IndexError(f"bin index {idx} out of range for {self._num_bins} opened bins")
        new_val = self._get_leaf(idx) - amount
        self._update(1, 0, self._n - 1, idx, new_val)


def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int:
"""
Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'.

.. deprecated::
This O(n) linear-scan implementation is kept for reference. ``first_fit``
now uses :class:`_SegmentTree` internally for O(log n) queries.

Args:
bins: A list of lists, where each inner list represents a bin and contains the current elements in that bin.
s: The size of the sequence to be placed in a bin.
Expand All @@ -41,25 +115,57 @@ def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> in
return -1


def first_fit(seqlens: List[int], pack_size: int, backend: str = "segment_tree") -> List[List[int]]:
    """
    Packs sequences of varying lengths into bins using the First-Fit algorithm.

    Args:
        seqlens: A list of integers, representing the lengths of the sequences to be packed.
        pack_size: The maximum capacity of each bin.
        backend: The search backend to use for finding the first fitting bin.
            ``"segment_tree"`` (default) uses a segment tree for O(log n) queries.
            ``"naive"`` uses the original O(n) linear scan.

    Returns:
        A list of lists, where each inner list represents a bin and contains the
        lengths of the sequences assigned to that bin, in the order they were placed.

    Raises:
        ValueError: If ``backend`` is neither ``"segment_tree"`` nor ``"naive"``.
    """
    if backend == "segment_tree":
        return _first_fit_segment_tree(seqlens, pack_size)
    if backend == "naive":
        return _first_fit_naive(seqlens, pack_size)
    raise ValueError(f"Unknown backend {backend!r}, expected 'segment_tree' or 'naive'")


def _first_fit_naive(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """Pack ``seqlens`` first-fit, locating each bin with ``find_first_bin_that_fits``.

    Returns the bins in creation order; each bin lists the sequence lengths
    placed in it.
    """
    bins: List[List[int]] = []
    for length in seqlens:
        target = find_first_bin_that_fits(bins, length, pack_size)
        if target != -1:
            bins[target].append(length)
            continue
        # No existing bin has room: start a new one holding this sequence.
        bins.append([length])
    return bins


def _first_fit_segment_tree(seqlens: List[int], pack_size: int) -> List[List[int]]:
    """Pack ``seqlens`` first-fit, locating each bin through a ``_SegmentTree``.

    Returns the bins in creation order; each bin lists the sequence lengths
    placed in it.
    """
    packed: List[List[int]] = []
    if not seqlens:
        return packed

    # Each sequence opens at most one bin, so len(seqlens) slots always suffice.
    capacities = _SegmentTree(len(seqlens))

    for length in seqlens:
        slot = capacities.query(length)
        if slot != -1:
            packed[slot].append(length)
            capacities.update(slot, length)
            continue
        # Nothing fits: open a bin whose free space already accounts for this sequence.
        capacities.open_bin(pack_size - length)
        packed.append([length])
    return packed


Expand Down
100 changes: 100 additions & 0 deletions tests/utils/test_first_fit_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import time

import pytest

from nemo.utils.sequence_packing_utils import first_fit


class TestFirstFitBackendConsistency:
    """Check that ``first_fit`` returns the same packing for the 'naive' and 'segment_tree' backends."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "seqlens, pack_size",
        [
            pytest.param([], 10, id="empty"),
            pytest.param([5], 10, id="single"),
            pytest.param([10], 10, id="exact_pack_size"),
            pytest.param([3, 7], 10, id="two_fit_one_bin"),
            pytest.param([6, 6], 10, id="overflow_new_bin"),
            pytest.param([10, 10, 10], 10, id="one_per_bin"),
            pytest.param([5, 3, 7, 2, 4], 10, id="mixed_small"),
            pytest.param([1, 1, 1, 1, 1], 3, id="all_ones"),
            pytest.param([3, 3, 3, 3], 5, id="uniform_3"),
            pytest.param([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 10, id="ascending"),
            pytest.param([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], 10, id="descending"),
            pytest.param([1] * 100, 10, id="100_ones"),
            pytest.param([3, 7, 2, 8, 1, 9, 4, 6, 5, 10], 15, id="mixed_large_pack"),
        ],
    )
    def test_backends_match(self, seqlens, pack_size):
        """Both backends must yield the same bins for each hand-picked case."""
        expected = first_fit(seqlens, pack_size, backend="naive")
        actual = first_fit(seqlens, pack_size, backend="segment_tree")
        assert expected == actual

    @pytest.mark.unit
    def test_backends_match_random_large(self):
        """Compare backends on 5000 random sequences."""
        pack_size = 1024
        rng = random.Random(12345)  # fixed seed keeps the input reproducible
        seqlens = [rng.randint(1, 500) for _ in range(5000)]
        expected = first_fit(seqlens, pack_size, backend="naive")
        actual = first_fit(seqlens, pack_size, backend="segment_tree")
        assert expected == actual

    @pytest.mark.unit
    def test_invalid_backend_raises(self):
        """An unrecognised backend name is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown backend"):
            first_fit([1, 2, 3], 10, backend="invalid")


class TestFirstFitBackendPerformance:
    """Benchmark naive vs segment_tree to confirm the speedup."""

    @pytest.mark.unit
    def test_segment_tree_faster_than_naive(self):
        """The segment-tree backend must be more than 2x faster than the naive one."""
        rng = random.Random(42)  # fixed seed keeps the workload reproducible
        seqlens = [rng.randint(1, 500) for _ in range(10000)]
        pack_size = 1024

        t0 = time.perf_counter()
        first_fit(seqlens, pack_size, backend="naive")
        naive_time = time.perf_counter() - t0

        # Scheduling noise can only inflate a wall-clock sample. Inflating the
        # naive sample merely raises the computed speedup, but inflating the
        # segment-tree sample can fail the assertion spuriously, so keep the
        # fastest of a few segment-tree runs.
        st_samples = []
        for _ in range(3):
            t0 = time.perf_counter()
            first_fit(seqlens, pack_size, backend="segment_tree")
            st_samples.append(time.perf_counter() - t0)
        st_time = min(st_samples)

        # Guard the division in case the clock reports no elapsed time.
        speedup = naive_time / max(st_time, 1e-9)
        print(f"\nnaive: {naive_time:.3f}s | segment_tree: {st_time:.3f}s | speedup: {speedup:.1f}x")
        assert speedup > 2, f"Expected significant speedup, got only {speedup:.1f}x"
Loading