diff --git a/contributing/LOCKING.md b/contributing/LOCKING.md index 54ee31991c..e23fb41f9e 100644 --- a/contributing/LOCKING.md +++ b/contributing/LOCKING.md @@ -108,3 +108,31 @@ Note that: * This pattern works assuming that Postgres is using default isolation level Read Committed. By the time a transaction acquires the advisory lock, all other transactions that can take the name have committed, so their changes can be seen and a unique name is taken. * SQLite needs a commit before selecting taken names due to Snapshot Isolation as noted above. + +**Use `AsyncExitStack`** + +In-memory locking typically requires holding the lock for a long time (until commit). +Using lock context managers for in-memory locking is often hard because the lock is tied to a block: + +```python +if something: + # Can't do this because the lock will be released before commit. How to lock? + async with get_locker(get_db().dialect_name).lock_ctx(...): + # ... +# ... +await session.commit() +``` + +Use [`contextlib.AsyncExitStack`](https://docs.python.org/3/library/contextlib.html#contextlib.AsyncExitStack): + +```python +async with AsyncExitStack() as exit_stack: + if something: + # The lock will be released only on stack exit, so it's ok. + await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).lock_ctx(...) + ) + # ... + # ... 
+ await session.commit() +``` diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index a1c799990d..9f281a75b2 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -2,6 +2,7 @@ import itertools import math import uuid +from contextlib import AsyncExitStack from datetime import datetime, timedelta from typing import List, Optional, Union @@ -49,7 +50,13 @@ from dstack._internal.core.services.profiles import get_termination from dstack._internal.server import settings from dstack._internal.server.background.tasks.process_compute_groups import ComputeGroupStatus -from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.db import ( + get_db, + get_session_ctx, + is_db_postgres, + is_db_sqlite, + sqlite_commit, +) from dstack._internal.server.models import ( ComputeGroupModel, FleetModel, @@ -170,14 +177,21 @@ async def _process_next_submitted_job(): lockset.add(job_model.id) job_model_id = job_model.id try: - await _process_submitted_job(session=session, job_model=job_model) + async with AsyncExitStack() as exit_stack: + await _process_submitted_job( + exit_stack=exit_stack, + session=session, + job_model=job_model, + ) finally: lockset.difference_update([job_model_id]) global last_processed_at last_processed_at = common_utils.get_current_datetime() -async def _process_submitted_job(session: AsyncSession, job_model: JobModel): +async def _process_submitted_job( + exit_stack: AsyncExitStack, session: AsyncSession, job_model: JobModel +): # Refetch to load related attributes. res = await session.execute( select(JobModel) @@ -258,25 +272,16 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # Then, the job runs on the assigned instance or a new instance is provisioned. 
# This is needed to avoid holding instances lock for a long time. if not job_model.instance_assigned: - # If another job freed the instance but is still trying to detach volumes, - # do not provision on it to prevent attaching volumes that are currently detaching. - detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session) - - fleet_filters = [ - FleetModel.project_id == project.id, - FleetModel.deleted == False, - ] - if run_model.fleet is not None: - fleet_filters.append(FleetModel.id == run_model.fleet_id) - if run_spec.merged_profile.fleets is not None: - fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets)) - - instance_filters = [ - InstanceModel.deleted == False, - InstanceModel.id.not_in(detaching_instances_ids), - ] - - fleet_models_with_instances, fleet_models_without_instances = await _select_fleet_models( + fleet_filters, instance_filters = await _get_candidate_fleet_models_filters( + session=session, + project=project, + run_model=run_model, + run_spec=run_spec, + ) + ( + fleet_models_with_instances, + fleet_models_without_instances, + ) = await _select_fleet_models_with_filters( session=session, fleet_filters=fleet_filters, instance_filters=instance_filters, @@ -286,66 +291,61 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): [i.id for i in f.instances] for f in fleet_models_with_instances ) ) - if get_db().dialect_name == "sqlite": - # Start new transaction to see committed changes after lock - await session.commit() - - async with get_locker(get_db().dialect_name).lock_ctx( - InstanceModel.__tablename__, instances_ids - ): - if get_db().dialect_name == "sqlite": - fleets_with_instances_ids = [f.id for f in fleet_models_with_instances] - fleet_models_with_instances = await _refetch_fleet_models_with_instances( - session=session, - fleets_ids=fleets_with_instances_ids, - instances_ids=instances_ids, - fleet_filters=fleet_filters, - instance_filters=instance_filters, - ) - 
fleet_models = fleet_models_with_instances + fleet_models_without_instances - fleet_model, fleet_instances_with_offers = await _find_optimal_fleet_with_offers( - project=project, - fleet_models=fleet_models, - run_model=run_model, - run_spec=run.run_spec, - job=job, - master_job_provisioning_data=master_job_provisioning_data, - volumes=volumes, - ) - if fleet_model is None: - if run_spec.merged_profile.fleets is not None: - # Run cannot create new fleets when fleets are specified - logger.debug("%s: failed to use specified fleets", fmt(job_model)) - job_model.status = JobStatus.TERMINATING - job_model.termination_reason = ( - JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY - ) - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return - if FeatureFlags.AUTOCREATED_FLEETS_DISABLED: - logger.debug("%s: no fleet found", fmt(job_model)) - job_model.status = JobStatus.TERMINATING - job_model.termination_reason = ( - JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY - ) - job_model.termination_reason_message = "Failed to find fleet" - job_model.last_processed_at = common_utils.get_current_datetime() - await session.commit() - return - instance = await _assign_job_to_fleet_instance( + await sqlite_commit(session) + await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids) + ) + if is_db_sqlite(): + fleets_with_instances_ids = [f.id for f in fleet_models_with_instances] + fleet_models_with_instances = await _refetch_fleet_models_with_instances( session=session, - instances_with_offers=fleet_instances_with_offers, - job_model=job_model, - multinode=multinode, + fleets_ids=fleets_with_instances_ids, + instances_ids=instances_ids, + fleet_filters=fleet_filters, + instance_filters=instance_filters, ) - job_model.fleet = fleet_model - job_model.instance_assigned = True - job_model.last_processed_at = common_utils.get_current_datetime() - if 
len(instances_ids) > 0: + fleet_models = fleet_models_with_instances + fleet_models_without_instances + fleet_model, fleet_instances_with_offers = await _find_optimal_fleet_with_offers( + project=project, + fleet_models=fleet_models, + run_model=run_model, + run_spec=run.run_spec, + job=job, + master_job_provisioning_data=master_job_provisioning_data, + volumes=volumes, + ) + if fleet_model is None: + if run_spec.merged_profile.fleets is not None: + # Run cannot create new fleets when fleets are specified + logger.debug("%s: failed to use specified fleets", fmt(job_model)) + job_model.status = JobStatus.TERMINATING + job_model.termination_reason = ( + JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + ) + job_model.termination_reason_message = "Failed to use specified fleets" + job_model.last_processed_at = common_utils.get_current_datetime() await session.commit() return - # If no instances were locked, we can proceed in the same transaction. + if FeatureFlags.AUTOCREATED_FLEETS_DISABLED: + logger.debug("%s: no fleet found", fmt(job_model)) + job_model.status = JobStatus.TERMINATING + job_model.termination_reason = ( + JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + ) + job_model.termination_reason_message = "Failed to find fleet" + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + return + instance = await _assign_job_to_fleet_instance( + session=session, + fleet_model=fleet_model, + instances_with_offers=fleet_instances_with_offers, + job_model=job_model, + multinode=multinode, + ) + job_model.last_processed_at = common_utils.get_current_datetime() + await session.commit() + return # TODO: Volume attachment for compute groups is not yet supported since # currently supported compute groups (e.g. Runpod) don't need explicit volume attachment. 
@@ -380,6 +380,17 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): ): jobs_to_provision = replica_jobs + master_instance_provisioning_data = ( + await _fetch_fleet_with_master_instance_provisioning_data( + exit_stack=exit_stack, + session=session, + fleet_model=fleet_model, + job=job, + ) + ) + master_provisioning_data = ( + master_job_provisioning_data or master_instance_provisioning_data + ) run_job_result = await _run_jobs_on_new_instances( project=project, fleet_model=fleet_model, @@ -388,7 +399,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): jobs=jobs_to_provision, project_ssh_public_key=project.ssh_public_key, project_ssh_private_key=project.ssh_private_key, - master_job_provisioning_data=master_job_provisioning_data, + master_job_provisioning_data=master_provisioning_data, volumes=volumes, ) if run_job_result is None: @@ -401,6 +412,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): if fleet_model is None: fleet_model = await _create_fleet_model_for_job( + exit_stack=exit_stack, session=session, project=project, run=run, @@ -472,9 +484,9 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): volumes_ids = sorted([v.id for vs in volume_models for v in vs]) if need_volume_attachment: - # TODO: Lock instances for attaching volumes? # Take lock to prevent attaching volumes that are to be deleted. # If the volume was deleted before the lock, the volume will fail to attach and the job will fail. + # TODO: Lock instances for attaching volumes? 
await session.execute( select(VolumeModel) .where(VolumeModel.id.in_(volumes_ids)) @@ -482,22 +494,46 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): .order_by(VolumeModel.id) # take locks in order .with_for_update(key_share=True, of=VolumeModel) ) - async with get_locker(get_db().dialect_name).lock_ctx( - VolumeModel.__tablename__, volumes_ids - ): - if len(volume_models) > 0: - assert instance is not None - await _attach_volumes( - session=session, - project=project, - job_model=job_model, - instance=instance, - volume_models=volume_models, - ) - await session.commit() + await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).lock_ctx(VolumeModel.__tablename__, volumes_ids) + ) + if len(volume_models) > 0: + assert instance is not None + await _attach_volumes( + session=session, + project=project, + job_model=job_model, + instance=instance, + volume_models=volume_models, + ) + await session.commit() + + +async def _get_candidate_fleet_models_filters( + session: AsyncSession, + project: ProjectModel, + run_model: RunModel, + run_spec: RunSpec, +) -> tuple[list, list]: + # If another job freed the instance but is still trying to detach volumes, + # do not provision on it to prevent attaching volumes that are currently detaching. 
+ detaching_instances_ids = await get_instances_ids_with_detaching_volumes(session) + fleet_filters = [ + FleetModel.project_id == project.id, + FleetModel.deleted == False, + ] + if run_model.fleet is not None: + fleet_filters.append(FleetModel.id == run_model.fleet_id) + if run_spec.merged_profile.fleets is not None: + fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets)) + instance_filters = [ + InstanceModel.deleted == False, + InstanceModel.id.not_in(detaching_instances_ids), + ] + return fleet_filters, instance_filters -async def _select_fleet_models( +async def _select_fleet_models_with_filters( session: AsyncSession, fleet_filters: list, instance_filters: list ) -> tuple[list[FleetModel], list[FleetModel]]: # Selecting fleets in two queries since Postgres does not allow @@ -511,6 +547,7 @@ async def _select_fleet_models( .options(contains_eager(FleetModel.instances)) .order_by(InstanceModel.id) # take locks in order .with_for_update(key_share=True, of=InstanceModel) + .execution_options(populate_existing=True) ) fleet_models_with_instances = list(res.unique().scalars().all()) fleet_models_with_instances_ids = [f.id for f in fleet_models_with_instances] @@ -527,9 +564,8 @@ async def _select_fleet_models( not_(and_(*instance_filters)), ) ) - # Load empty list of instances so that downstream code - # knows this fleet has no instances eligible for offers. 
.options(noload(FleetModel.instances)) + .execution_options(populate_existing=True) ) fleet_models_without_instances = list(res.unique().scalars().all()) return fleet_models_with_instances, fleet_models_without_instances @@ -554,7 +590,6 @@ async def _refetch_fleet_models_with_instances( *instance_filters, ) .options(contains_eager(FleetModel.instances)) - .execution_options(populate_existing=True) ) fleet_models = list(res.unique().scalars().all()) return fleet_models @@ -598,11 +633,20 @@ async def _find_optimal_fleet_with_offers( ] = [] for candidate_fleet_model in fleet_models: candidate_fleet = fleet_model_to_fleet(candidate_fleet_model) + if ( + job.job_spec.jobs_per_replica > 1 + and candidate_fleet.spec.configuration.placement != InstanceGroupPlacement.CLUSTER + ): + # Limit multinode runs to cluster fleets to guarantee best connectivity. + continue + fleet_instances_with_pool_offers = _get_fleet_instances_with_pool_offers( fleet_model=candidate_fleet_model, run_spec=run_spec, job=job, - master_job_provisioning_data=master_job_provisioning_data, + # No need to pass master_job_provisioning_data for master job + # as all pool offers are suitable. + master_job_provisioning_data=None, volumes=volumes, ) fleet_has_pool_capacity = nodes_required_num <= len(fleet_instances_with_pool_offers) @@ -620,6 +664,11 @@ async def _find_optimal_fleet_with_offers( except ValueError: fleet_backend_offers = [] else: + # Master job offers must be in the same cluster as existing instances. + master_instance_provisioning_data = _get_fleet_master_instance_provisioning_data( + fleet_model=candidate_fleet_model, + fleet_spec=candidate_fleet.spec, + ) # Handle multinode for old jobs that don't have requirements.multinode set. # TODO: Drop multinode param. 
multinode = requirements.multinode or job.job_spec.jobs_per_replica > 1 @@ -629,7 +678,7 @@ async def _find_optimal_fleet_with_offers( requirements=requirements, exclude_not_available=True, multinode=multinode, - master_job_provisioning_data=master_job_provisioning_data, + master_job_provisioning_data=master_instance_provisioning_data, volumes=volumes, privileged=job.job_spec.privileged, instance_mounts=check_run_spec_requires_instance_mounts(run_spec), @@ -685,6 +734,70 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int: return nodes_required_num +def _get_fleet_master_instance_provisioning_data( + fleet_model: FleetModel, + fleet_spec: FleetSpec, +) -> Optional[JobProvisioningData]: + master_instance_provisioning_data = None + if fleet_spec.configuration.placement == InstanceGroupPlacement.CLUSTER: + # Offers for master jobs must be in the same cluster as existing instances. + fleet_instance_models = [im for im in fleet_model.instances if not im.deleted] + if len(fleet_instance_models) > 0: + master_instance_model = fleet_instance_models[0] + master_instance_provisioning_data = JobProvisioningData.__response__.parse_raw( + master_instance_model.job_provisioning_data + ) + return master_instance_provisioning_data + + +async def _fetch_fleet_with_master_instance_provisioning_data( + exit_stack: AsyncExitStack, + session: AsyncSession, + fleet_model: Optional[FleetModel], + job: Job, +) -> Optional[JobProvisioningData]: + master_instance_provisioning_data = None + if job.job_spec.job_num == 0 and fleet_model is not None: + fleet = fleet_model_to_fleet(fleet_model) + if fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER: + # To avoid violating fleet placement cluster during master provisioning, + # we must lock empty fleets and respect existing instances in non-empty fleets. + # On SQLite always take the lock during master provisioning for simplicity. 
+ await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).lock_ctx( + FleetModel.__tablename__, [fleet_model.id] + ) + ) + await sqlite_commit(session) + res = await session.execute( + select(FleetModel) + .outerjoin(FleetModel.instances) + .where( + FleetModel.id == fleet_model.id, + InstanceModel.id.is_(None), + ) + .with_for_update(key_share=True, of=FleetModel) + .execution_options(populate_existing=True) + .options(noload(FleetModel.instances)) + ) + empty_fleet_model = res.unique().scalar() + if empty_fleet_model is not None: + fleet_model = empty_fleet_model + else: + res = await session.execute( + select(FleetModel) + .where(FleetModel.id == fleet_model.id) + .options(joinedload(FleetModel.instances)) + .execution_options(populate_existing=True) + ) + fleet_model = res.unique().scalar_one() + master_instance_provisioning_data = _get_fleet_master_instance_provisioning_data( + fleet_model=fleet_model, + fleet_spec=fleet.spec, + ) + return master_instance_provisioning_data + + def _run_can_fit_into_fleet(run_spec: RunSpec, fleet: Fleet) -> bool: """ Returns `False` if the run cannot fit into fleet for sure. 
@@ -760,10 +873,13 @@ def _get_fleet_instances_with_pool_offers( async def _assign_job_to_fleet_instance( session: AsyncSession, - instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]], + fleet_model: Optional[FleetModel], job_model: JobModel, + instances_with_offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]], multinode: bool, ) -> Optional[InstanceModel]: + job_model.fleet = fleet_model + job_model.instance_assigned = True if len(instances_with_offers) == 0: return None @@ -949,6 +1065,7 @@ def _can_create_new_instance_in_fleet(fleet: Fleet) -> bool: async def _create_fleet_model_for_job( + exit_stack: AsyncExitStack, session: AsyncSession, project: ProjectModel, run: Run, @@ -957,16 +1074,19 @@ async def _create_fleet_model_for_job( if run.run_spec.configuration.type == "task" and run.run_spec.configuration.nodes > 1: placement = InstanceGroupPlacement.CLUSTER nodes = _get_nodes_required_num_for_run(run.run_spec) lock_namespace = f"fleet_names_{project.name}" - # TODO: Lock fleet names on SQLite. - # Needs some refactoring so that the lock is released after commit.
- if get_db().dialect_name == "postgresql": + if is_db_sqlite(): + # Start new transaction to see committed changes after lock + await session.commit() + elif is_db_postgres(): await session.execute( select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) ) + await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).get_lockset(lock_namespace)[0] + ) fleet_name = await generate_fleet_name(session=session, project=project) - spec = FleetSpec( configuration=FleetConfiguration( name=fleet_name, diff --git a/src/dstack/_internal/server/db.py b/src/dstack/_internal/server/db.py index 084630add1..c9ed8d5280 100644 --- a/src/dstack/_internal/server/db.py +++ b/src/dstack/_internal/server/db.py @@ -103,6 +103,23 @@ async def new_func(*args, **kwargs): return new_func +def is_db_sqlite() -> bool: + return get_db().dialect_name == "sqlite" + + +def is_db_postgres() -> bool: + return get_db().dialect_name == "postgresql" + + +async def sqlite_commit(session: AsyncSession): + """ + Commit an sqlite transaction. + Should be used before taking locks in active sessions to see committed changes. 
+ """ + if is_db_sqlite(): + await session.commit() + + def _run_alembic_upgrade(connection): alembic_cfg = config.Config() alembic_cfg.set_main_option("script_location", settings.ALEMBIC_MIGRATIONS_LOCATION) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 0e3aaf2d4b..7ee0bbfab2 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -45,7 +45,7 @@ from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models -from dstack._internal.server.db import get_db +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( FleetModel, InstanceModel, @@ -675,14 +675,13 @@ async def _create_fleet( spec: FleetSpec, ) -> Fleet: lock_namespace = f"fleet_names_{project.name}" - if get_db().dialect_name == "sqlite": + if is_db_sqlite(): # Start new transaction to see committed changes after lock await session.commit() - elif get_db().dialect_name == "postgresql": + elif is_db_postgres(): await session.execute( select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) ) - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) async with lock: if spec.configuration.name is not None: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index f47b192999..afad2831b7 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -38,7 +38,7 @@ ) from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.server import settings -from dstack._internal.server.db import get_db +from dstack._internal.server.db import get_db, is_db_postgres, 
is_db_sqlite from dstack._internal.server.models import ( GatewayComputeModel, GatewayModel, @@ -148,14 +148,13 @@ async def create_gateway( ) lock_namespace = f"gateway_names_{project.name}" - if get_db().dialect_name == "sqlite": + if is_db_sqlite(): # Start new transaction to see committed changes after lock await session.commit() - elif get_db().dialect_name == "postgresql": + elif is_db_postgres(): await session.execute( select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) ) - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) async with lock: if configuration.name is None: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index ed64aa7219..870b378a8e 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -58,7 +58,7 @@ from dstack._internal.core.services import validate_dstack_resource_name from dstack._internal.core.services.diff import diff_models from dstack._internal.server import settings -from dstack._internal.server.db import get_db +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( FleetModel, JobModel, @@ -510,14 +510,13 @@ async def submit_run( ) lock_namespace = f"run_names_{project.name}" - if get_db().dialect_name == "sqlite": + if is_db_sqlite(): # Start new transaction to see committed changes after lock await session.commit() - elif get_db().dialect_name == "postgresql": + elif is_db_postgres(): await session.execute( select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) ) - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) async with lock: # FIXME: delete_runs commits, so Postgres lock is released too early. 
diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index f52f1f064f..fa3471192d 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -24,7 +24,7 @@ VolumeStatus, ) from dstack._internal.core.services import validate_dstack_resource_name -from dstack._internal.server.db import get_db +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( InstanceModel, ProjectModel, @@ -215,14 +215,13 @@ async def create_volume( _validate_volume_configuration(configuration) lock_namespace = f"volume_names_{project.name}" - if get_db().dialect_name == "sqlite": + if is_db_sqlite(): # Start new transaction to see committed changes after lock await session.commit() - elif get_db().dialect_name == "postgresql": + elif is_db_postgres(): await session.execute( select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) ) - lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) async with lock: if configuration.name is not None: diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 545349e585..c92b8a2301 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -9,7 +9,7 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode from dstack._internal.core.models.configurations import TaskConfiguration -from dstack._internal.core.models.fleets import FleetNodesSpec +from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.health import HealthStatus from dstack._internal.core.models.instances import ( 
InstanceAvailability, @@ -546,6 +546,7 @@ async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session: ) offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128) fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=None) fleet = await create_fleet(session=session, project=project, spec=fleet_spec) instance = await create_instance( @@ -1189,6 +1190,59 @@ async def test_provisions_compute_group(self, test_db, session: AsyncSession): res = await session.execute(select(ComputeGroupModel)) assert res.scalar() is not None + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_provisioning_master_job_respects_cluster_placement_in_non_empty_fleet( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None) + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + backend=BackendType.AWS, + job_provisioning_data=get_job_provisioning_data(region="eu-west-1"), + ) + configuration = TaskConfiguration(image="debian", nodes=2) + run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration) + run = await create_run( + session=session, + run_name="run", + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + job = await create_job( + session=session, + run=run, + fleet=fleet, + instance_assigned=True, + ) + with patch("dstack._internal.server.services.backends.get_project_backends") 
as m: + backend_mock = Mock() + m.return_value = [backend_mock] + backend_mock.TYPE = BackendType.AWS + offer1 = get_instance_offer_with_availability(region="eu-west-2") + offer2 = get_instance_offer_with_availability(region="eu-west-1") + backend_mock.compute.return_value.get_offers.return_value = [offer1, offer2] + backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data() + await process_submitted_jobs() + m.assert_called_once() + backend_mock.compute.return_value.get_offers.assert_called_once() + backend_mock.compute.return_value.run_job.assert_called_once() + selected_offer = backend_mock.compute.return_value.run_job.call_args[0][2] + assert selected_offer.region == "eu-west-1" + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + @pytest.mark.parametrize( ["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],