From 03504deaf753e66ef37b33ff47f47fb5a511ab15 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 17 Jul 2025 14:25:01 +0200 Subject: [PATCH] [chore]: Drop duplicate utility `split_chunks` --- .../_internal/core/backends/oci/resources.py | 10 +++--- src/dstack/_internal/utils/common.py | 31 ++++++------------- src/tests/_internal/utils/test_common.py | 20 ++++++------ 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/src/dstack/_internal/core/backends/oci/resources.py b/src/dstack/_internal/core/backends/oci/resources.py index 9494db7f25..eda17109d8 100644 --- a/src/dstack/_internal/core/backends/oci/resources.py +++ b/src/dstack/_internal/core/backends/oci/resources.py @@ -26,7 +26,7 @@ from dstack._internal.core.backends.oci.region import OCIRegionClient from dstack._internal.core.errors import BackendError from dstack._internal.core.models.instances import InstanceOffer -from dstack._internal.utils.common import split_chunks +from dstack._internal.utils.common import batched from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -667,21 +667,21 @@ def add_security_group_rules( security_group_id: str, rules: Iterable[SecurityRule], client: oci.core.VirtualNetworkClient ) -> None: rules_details = map(SecurityRule.to_sdk_add_rule_details, rules) - for chunk in split_chunks(rules_details, ADD_SECURITY_RULES_MAX_CHUNK_SIZE): + for batch in batched(rules_details, ADD_SECURITY_RULES_MAX_CHUNK_SIZE): client.add_network_security_group_security_rules( security_group_id, - oci.core.models.AddNetworkSecurityGroupSecurityRulesDetails(security_rules=chunk), + oci.core.models.AddNetworkSecurityGroupSecurityRulesDetails(security_rules=batch), ) def remove_security_group_rules( security_group_id: str, rule_ids: Iterable[str], client: oci.core.VirtualNetworkClient ) -> None: - for chunk in split_chunks(rule_ids, REMOVE_SECURITY_RULES_MAX_CHUNK_SIZE): + for batch in batched(rule_ids, REMOVE_SECURITY_RULES_MAX_CHUNK_SIZE): client.remove_network_security_group_security_rules( security_group_id, oci.core.models.RemoveNetworkSecurityGroupSecurityRulesDetails( - security_rule_ids=chunk + security_rule_ids=batch ), ) diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index 6832078029..56382abd38 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -225,27 +225,6 @@ def remove_prefix(text: str, prefix: str) -> str: T = TypeVar("T") -def split_chunks(iterable: Iterable[T], chunk_size: int) -> Iterable[List[T]]: - """ - Splits an iterable into chunks of at most `chunk_size` items. - - >>> list(split_chunks([1, 2, 3, 4, 5], 2)) - [[1, 2], [3, 4], [5]] - """ - - if chunk_size < 1: - raise ValueError(f"chunk_size should be a positive integer, not {chunk_size}") - - chunk = [] - for item in iterable: - chunk.append(item) - if len(chunk) == chunk_size: - yield chunk - chunk = [] - if chunk: - yield chunk - - MEMORY_UNITS = { "B": 1, "K": 2**10, @@ -283,7 +262,17 @@ def get_or_error(v: Optional[T]) -> T: return v +# TODO: drop after dropping Python 3.11 def batched(seq: Iterable[T], n: int) -> Iterable[List[T]]: + """ + Roughly equivalent to itertools.batched from Python 3.12. + + >>> list(batched([1, 2, 3, 4, 5], 2)) + [[1, 2], [3, 4], [5]] + """ + + if n < 1: + raise ValueError(f"n should be a positive integer, not {n}") it = iter(seq) return iter(lambda: list(itertools.islice(it, n)), []) diff --git a/src/tests/_internal/utils/test_common.py b/src/tests/_internal/utils/test_common.py index ed9ef99168..1764fedcff 100644 --- a/src/tests/_internal/utils/test_common.py +++ b/src/tests/_internal/utils/test_common.py @@ -1,10 +1,11 @@ from datetime import datetime, timedelta, timezone -from typing import Any, Iterable, List +from typing import Any, Iterable import pytest from freezegun import freeze_time from dstack._internal.utils.common import ( + batched, concat_url_path, format_duration_multiunit, local_time, @@ -12,7 +13,6 @@ parse_memory, pretty_date, sizeof_fmt, - split_chunks, ) @@ -138,9 +138,9 @@ def test_parses_memory(self, memory, as_units, expected): assert parse_memory(memory, as_untis=as_units) == expected -class TestSplitChunks: +class TestBatched: @pytest.mark.parametrize( - ("iterable", "chunk_size", "expected_chunks"), + ("iterable", "n", "expected_batches"), [ ([1, 2, 3, 4], 2, [[1, 2], [3, 4]]), ([1, 2, 3], 2, [[1, 2], [3]]), @@ -151,15 +151,15 @@ class TestSplitChunks: ((x for x in range(5)), 3, [[0, 1, 2], [3, 4]]), ], ) - def test_split_chunks( - self, iterable: Iterable[Any], chunk_size: int, expected_chunks: List[List[Any]] + def test_batched( + self, iterable: Iterable[Any], n: int, expected_batches: list[list[Any]] ) -> None: - assert list(split_chunks(iterable, chunk_size)) == expected_chunks + assert list(batched(iterable, n)) == expected_batches - @pytest.mark.parametrize("chunk_size", [0, -1]) - def test_raises_on_invalid_chunk_size(self, chunk_size: int) -> None: + @pytest.mark.parametrize("n", [0, -1]) + def test_raises_on_invalid_n(self, n: int) -> None: with pytest.raises(ValueError): - list(split_chunks([1, 2, 3], chunk_size)) + list(batched([1, 2, 3], n)) @pytest.mark.parametrize(