Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions contributing/LOCKING.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,31 @@ Note that:

* This pattern works assuming that Postgres is using default isolation level Read Committed. By the time a transaction acquires the advisory lock, all other transactions that can take the name have committed, so their changes can be seen and a unique name is taken.
* SQLite needs a commit before selecting taken names due to Snapshot Isolation as noted above.

**Use `AsyncExitStack`**

In-memory locking typically requires holding the lock for a long time (until the transaction commits).
Using lock context managers for in-memory locking is often hard because the lock is tied to a block:

```python
if something:
# Can't do this because the lock will be released before commit. How to lock?
async with get_locker(get_db().dialect_name).lock_ctx(...):
# ...
# ...
await session.commit()
```

Use [`contextlib.AsyncExitStack`](https://docs.python.org/3/library/contextlib.html#contextlib.AsyncExitStack):

```python
async with AsyncExitStack() as exit_stack:
if something:
# The lock will be released only on stack exit, so it's ok.
await exit_stack.enter_async_context(
get_locker(get_db().dialect_name).lock_ctx(...)
)
# ...
# ...
await session.commit()
```
329 changes: 224 additions & 105 deletions src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ async def new_func(*args, **kwargs):
return new_func


def is_db_sqlite() -> bool:
    """Return True if the server database dialect is SQLite."""
    dialect = get_db().dialect_name
    return dialect == "sqlite"


def is_db_postgres() -> bool:
    """Return True if the server database dialect is PostgreSQL."""
    dialect = get_db().dialect_name
    return dialect == "postgresql"


async def sqlite_commit(session: AsyncSession):
    """
    Commit the session's transaction, but only when running on SQLite.

    Should be used before taking locks in active sessions: due to SQLite's
    snapshot isolation, a new transaction is needed to observe changes
    committed by other transactions. On other dialects this is a no-op.
    """
    if not is_db_sqlite():
        return
    await session.commit()


def _run_alembic_upgrade(connection):
alembic_cfg = config.Config()
alembic_cfg.set_main_option("script_location", settings.ALEMBIC_MIGRATIONS_LOCATION)
Expand Down
7 changes: 3 additions & 4 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from dstack._internal.core.models.users import GlobalRole
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models
from dstack._internal.server.db import get_db
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
FleetModel,
InstanceModel,
Expand Down Expand Up @@ -675,14 +675,13 @@ async def _create_fleet(
spec: FleetSpec,
) -> Fleet:
lock_namespace = f"fleet_names_{project.name}"
if get_db().dialect_name == "sqlite":
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
elif is_db_postgres():
await session.execute(
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
)

lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
async with lock:
if spec.configuration.name is not None:
Expand Down
7 changes: 3 additions & 4 deletions src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.server import settings
from dstack._internal.server.db import get_db
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
GatewayComputeModel,
GatewayModel,
Expand Down Expand Up @@ -148,14 +148,13 @@ async def create_gateway(
)

lock_namespace = f"gateway_names_{project.name}"
if get_db().dialect_name == "sqlite":
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
elif is_db_postgres():
await session.execute(
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
)

lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
async with lock:
if configuration.name is None:
Expand Down
7 changes: 3 additions & 4 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.core.services.diff import diff_models
from dstack._internal.server import settings
from dstack._internal.server.db import get_db
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
FleetModel,
JobModel,
Expand Down Expand Up @@ -510,14 +510,13 @@ async def submit_run(
)

lock_namespace = f"run_names_{project.name}"
if get_db().dialect_name == "sqlite":
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
elif is_db_postgres():
await session.execute(
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
)

lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
async with lock:
# FIXME: delete_runs commits, so Postgres lock is released too early.
Expand Down
7 changes: 3 additions & 4 deletions src/dstack/_internal/server/services/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
VolumeStatus,
)
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.server.db import get_db
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
InstanceModel,
ProjectModel,
Expand Down Expand Up @@ -215,14 +215,13 @@ async def create_volume(
_validate_volume_configuration(configuration)

lock_namespace = f"volume_names_{project.name}"
if get_db().dialect_name == "sqlite":
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
elif get_db().dialect_name == "postgresql":
elif is_db_postgres():
await session.execute(
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
)

lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
async with lock:
if configuration.name is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import NetworkMode
from dstack._internal.core.models.configurations import TaskConfiguration
from dstack._internal.core.models.fleets import FleetNodesSpec
from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement
from dstack._internal.core.models.health import HealthStatus
from dstack._internal.core.models.instances import (
InstanceAvailability,
Expand Down Expand Up @@ -546,6 +546,7 @@ async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session:
)
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
fleet_spec = get_fleet_spec()
fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=None)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
instance = await create_instance(
Expand Down Expand Up @@ -1189,6 +1190,59 @@ async def test_provisions_compute_group(self, test_db, session: AsyncSession):
res = await session.execute(select(ComputeGroupModel))
assert res.scalar() is not None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_provisioning_master_job_respects_cluster_placement_in_non_empty_fleet(
    self, test_db, session: AsyncSession
):
    # A cluster-placement fleet already holds a busy instance in eu-west-1.
    # The master job of a new multi-node run must be provisioned in that same
    # region, even when an offer from another region is listed first.
    project = await create_project(session)
    user = await create_user(session)
    repo = await create_repo(session=session, project_id=project.id)

    cluster_fleet_spec = get_fleet_spec()
    cluster_fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
    cluster_fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None)
    fleet = await create_fleet(session=session, project=project, spec=cluster_fleet_spec)
    await create_instance(
        session=session,
        project=project,
        fleet=fleet,
        status=InstanceStatus.BUSY,
        backend=BackendType.AWS,
        job_provisioning_data=get_job_provisioning_data(region="eu-west-1"),
    )

    run_spec = get_run_spec(
        run_name="run",
        repo_id=repo.name,
        configuration=TaskConfiguration(image="debian", nodes=2),
    )
    run = await create_run(
        session=session,
        run_name="run",
        project=project,
        repo=repo,
        user=user,
        run_spec=run_spec,
    )
    job = await create_job(
        session=session,
        run=run,
        fleet=fleet,
        instance_assigned=True,
    )

    with patch("dstack._internal.server.services.backends.get_project_backends") as m:
        backend_mock = Mock()
        m.return_value = [backend_mock]
        backend_mock.TYPE = BackendType.AWS
        compute_mock = backend_mock.compute.return_value
        # eu-west-2 offer comes first; the fleet's existing instance must win.
        compute_mock.get_offers.return_value = [
            get_instance_offer_with_availability(region="eu-west-2"),
            get_instance_offer_with_availability(region="eu-west-1"),
        ]
        compute_mock.run_job.return_value = get_job_provisioning_data()
        await process_submitted_jobs()
        m.assert_called_once()
        compute_mock.get_offers.assert_called_once()
        compute_mock.run_job.assert_called_once()
        # The third positional argument of run_job is the selected offer —
        # presumably (project, job, offer, ...); verify against run_job signature.
        selected_offer = compute_mock.run_job.call_args[0][2]
        assert selected_offer.region == "eu-west-1"
        await session.refresh(job)
        assert job.status == JobStatus.PROVISIONING


@pytest.mark.parametrize(
["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],
Expand Down