diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index a51863a8a0..81b7d27acd 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -5,9 +5,9 @@ from datetime import datetime, timedelta from typing import List, Optional, Tuple -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, not_, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import contains_eager, joinedload, load_only, selectinload +from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload from dstack._internal.core.backends.base.backend import Backend from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport @@ -250,8 +250,8 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): ] if run_model.fleet is not None: fleet_filters.append(FleetModel.id == run_model.fleet_id) - if run_spec.configuration.fleets is not None: - fleet_filters.append(FleetModel.name.in_(run_spec.configuration.fleets)) + if run_spec.merged_profile.fleets is not None: + fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets)) instance_filters = [ InstanceModel.deleted == False, @@ -269,9 +269,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): [i.id for i in f.instances] for f in fleet_models_with_instances ) ) - fleet_models = fleet_models_with_instances + fleet_models_without_instances - fleets_ids = [f.id for f in fleet_models] - if get_db().dialect_name == "sqlite": # Start new transaction to see committed changes after lock await session.commit() @@ -280,13 +277,15 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): InstanceModel.__tablename__, instances_ids ): if get_db().dialect_name == "sqlite": - fleet_models = await _refetch_fleet_models( + fleets_with_instances_ids = [f.id for f in fleet_models_with_instances] + fleet_models_with_instances = await _refetch_fleet_models_with_instances( session=session, - fleets_ids=fleets_ids, + fleets_ids=fleets_with_instances_ids, instances_ids=instances_ids, fleet_filters=fleet_filters, instance_filters=instance_filters, ) + fleet_models = fleet_models_with_instances + fleet_models_without_instances fleet_model, fleet_instances_with_offers = _find_optimal_fleet_with_offers( fleet_models=fleet_models, run_model=run_model, @@ -295,7 +294,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): master_job_provisioning_data=master_job_provisioning_data, volumes=volumes, ) - if fleet_model is None and run_spec.configuration.fleets is not None: + if fleet_model is None and run_spec.merged_profile.fleets is not None: # Run cannot create new fleets when fleets are specified logger.debug("%s: failed to use specified fleets", fmt(job_model)) job_model.status = JobStatus.TERMINATING @@ -443,14 +442,21 @@ async def _select_fleet_models( *fleet_filters, FleetModel.id.not_in(fleet_models_with_instances_ids), ) - .where(InstanceModel.id.is_(None)) - .options(contains_eager(FleetModel.instances)) # loading empty relation + .where( + or_( + InstanceModel.id.is_(None), + not_(and_(*instance_filters)), + ) + ) + # Load empty list of instances so that downstream code + # knows this fleet has no instances eligible for offers. + .options(noload(FleetModel.instances)) ) fleet_models_without_instances = list(res.unique().scalars().all()) return fleet_models_with_instances, fleet_models_without_instances -async def _refetch_fleet_models( +async def _refetch_fleet_models_with_instances( session: AsyncSession, fleets_ids: list[uuid.UUID], instances_ids: list[uuid.UUID], @@ -465,13 +471,8 @@ async def _refetch_fleet_models( *fleet_filters, ) .where( - or_( - InstanceModel.id.is_(None), - and_( - InstanceModel.id.in_(instances_ids), - *instance_filters, - ), - ) + InstanceModel.id.in_(instances_ids), + *instance_filters, ) .options(contains_eager(FleetModel.instances)) .execution_options(populate_existing=True) @@ -538,7 +539,7 @@ def _find_optimal_fleet_with_offers( fleet_priority, ) ) - if run_spec.configuration.fleets is None and all( + if run_spec.merged_profile.fleets is None and all( t[2] == 0 for t in candidate_fleets_with_offers ): # If fleets are not specified and no fleets have available offers, create a new fleet. diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 2bc226dded..b64cf2c56c 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -808,6 +808,46 @@ async def test_assigns_job_to_elastic_empty_fleet_if_fleets_specified( assert job.instance_id is None assert job.fleet_id == fleet.id + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_assigns_job_to_elastic_non_empty_busy_fleet_if_fleets_specified( + self, test_db, session: AsyncSession + ): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.nodes = Range(min=1, max=2) + fleet = await create_fleet(session=session, project=project, spec=fleet_spec, name="fleet") + await create_instance( + session=session, + project=project, + fleet=fleet, + instance_num=0, + status=InstanceStatus.BUSY, + total_blocks=1, + busy_blocks=1, + ) + run_spec = get_run_spec(repo_id=repo.name) + run_spec.configuration.fleets = [fleet.name] + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + job = await create_job( + session=session, + run=run, + instance_assigned=False, + ) + await process_submitted_jobs() + await session.refresh(job) + assert job.instance_assigned + assert job.instance_id is None + assert job.fleet_id == fleet.id + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_creates_new_instance_in_existing_empty_fleet(