Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/dstack/_internal/core/backends/oci/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from dstack._internal.core.backends.oci.region import OCIRegionClient
from dstack._internal.core.errors import BackendError
from dstack._internal.core.models.instances import InstanceOffer
from dstack._internal.utils.common import split_chunks
from dstack._internal.utils.common import batched
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -667,21 +667,21 @@ def add_security_group_rules(
security_group_id: str, rules: Iterable[SecurityRule], client: oci.core.VirtualNetworkClient
) -> None:
rules_details = map(SecurityRule.to_sdk_add_rule_details, rules)
for chunk in split_chunks(rules_details, ADD_SECURITY_RULES_MAX_CHUNK_SIZE):
for batch in batched(rules_details, ADD_SECURITY_RULES_MAX_CHUNK_SIZE):
client.add_network_security_group_security_rules(
security_group_id,
oci.core.models.AddNetworkSecurityGroupSecurityRulesDetails(security_rules=chunk),
oci.core.models.AddNetworkSecurityGroupSecurityRulesDetails(security_rules=batch),
)


def remove_security_group_rules(
security_group_id: str, rule_ids: Iterable[str], client: oci.core.VirtualNetworkClient
) -> None:
for chunk in split_chunks(rule_ids, REMOVE_SECURITY_RULES_MAX_CHUNK_SIZE):
for batch in batched(rule_ids, REMOVE_SECURITY_RULES_MAX_CHUNK_SIZE):
client.remove_network_security_group_security_rules(
security_group_id,
oci.core.models.RemoveNetworkSecurityGroupSecurityRulesDetails(
security_rule_ids=chunk
security_rule_ids=batch
),
)

Expand Down
31 changes: 10 additions & 21 deletions src/dstack/_internal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,27 +225,6 @@ def remove_prefix(text: str, prefix: str) -> str:
T = TypeVar("T")


def split_chunks(iterable: Iterable[T], chunk_size: int) -> Iterable[List[T]]:
"""
Splits an iterable into chunks of at most `chunk_size` items.

>>> list(split_chunks([1, 2, 3, 4, 5], 2))
[[1, 2], [3, 4], [5]]
"""

if chunk_size < 1:
raise ValueError(f"chunk_size should be a positive integer, not {chunk_size}")

chunk = []
for item in iterable:
chunk.append(item)
if len(chunk) == chunk_size:
yield chunk
chunk = []
if chunk:
yield chunk


MEMORY_UNITS = {
"B": 1,
"K": 2**10,
Expand Down Expand Up @@ -283,7 +262,17 @@ def get_or_error(v: Optional[T]) -> T:
return v


# TODO: drop after dropping Python 3.11
def batched(seq: Iterable[T], n: int) -> Iterable[List[T]]:
"""
Roughly equivalent to itertools.batched from Python 3.12.

>>> list(batched([1, 2, 3, 4, 5], 2))
[[1, 2], [3, 4], [5]]
"""

if n < 1:
raise ValueError(f"n should be a positive integer, not {n}")
it = iter(seq)
return iter(lambda: list(itertools.islice(it, n)), [])

Expand Down
20 changes: 10 additions & 10 deletions src/tests/_internal/utils/test_common.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable, List
from typing import Any, Iterable

import pytest
from freezegun import freeze_time

from dstack._internal.utils.common import (
batched,
concat_url_path,
format_duration_multiunit,
local_time,
make_proxy_url,
parse_memory,
pretty_date,
sizeof_fmt,
split_chunks,
)


Expand Down Expand Up @@ -138,9 +138,9 @@ def test_parses_memory(self, memory, as_units, expected):
assert parse_memory(memory, as_untis=as_units) == expected


class TestSplitChunks:
class TestBatched:
@pytest.mark.parametrize(
("iterable", "chunk_size", "expected_chunks"),
("iterable", "n", "expected_batches"),
[
([1, 2, 3, 4], 2, [[1, 2], [3, 4]]),
([1, 2, 3], 2, [[1, 2], [3]]),
Expand All @@ -151,15 +151,15 @@ class TestSplitChunks:
((x for x in range(5)), 3, [[0, 1, 2], [3, 4]]),
],
)
def test_split_chunks(
self, iterable: Iterable[Any], chunk_size: int, expected_chunks: List[List[Any]]
def test_batched(
self, iterable: Iterable[Any], n: int, expected_batches: list[list[Any]]
) -> None:
assert list(split_chunks(iterable, chunk_size)) == expected_chunks
assert list(batched(iterable, n)) == expected_batches

@pytest.mark.parametrize("chunk_size", [0, -1])
def test_raises_on_invalid_chunk_size(self, chunk_size: int) -> None:
@pytest.mark.parametrize("n", [0, -1])
def test_raises_on_invalid_n(self, n: int) -> None:
with pytest.raises(ValueError):
list(split_chunks([1, 2, 3], chunk_size))
list(batched([1, 2, 3], n))


@pytest.mark.parametrize(
Expand Down
Loading