Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from datetime import datetime, timedelta
from typing import List, Optional, Tuple

from sqlalchemy import and_, or_, select
from sqlalchemy import and_, not_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, load_only, selectinload
from sqlalchemy.orm import contains_eager, joinedload, load_only, noload, selectinload

from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport
Expand Down Expand Up @@ -250,8 +250,8 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
]
if run_model.fleet is not None:
fleet_filters.append(FleetModel.id == run_model.fleet_id)
if run_spec.configuration.fleets is not None:
fleet_filters.append(FleetModel.name.in_(run_spec.configuration.fleets))
if run_spec.merged_profile.fleets is not None:
fleet_filters.append(FleetModel.name.in_(run_spec.merged_profile.fleets))

instance_filters = [
InstanceModel.deleted == False,
Expand All @@ -269,9 +269,6 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
[i.id for i in f.instances] for f in fleet_models_with_instances
)
)
fleet_models = fleet_models_with_instances + fleet_models_without_instances
fleets_ids = [f.id for f in fleet_models]

if get_db().dialect_name == "sqlite":
# Start new transaction to see committed changes after lock
await session.commit()
Expand All @@ -280,13 +277,15 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
InstanceModel.__tablename__, instances_ids
):
if get_db().dialect_name == "sqlite":
fleet_models = await _refetch_fleet_models(
fleets_with_instances_ids = [f.id for f in fleet_models_with_instances]
fleet_models_with_instances = await _refetch_fleet_models_with_instances(
session=session,
fleets_ids=fleets_ids,
fleets_ids=fleets_with_instances_ids,
instances_ids=instances_ids,
fleet_filters=fleet_filters,
instance_filters=instance_filters,
)
fleet_models = fleet_models_with_instances + fleet_models_without_instances
fleet_model, fleet_instances_with_offers = _find_optimal_fleet_with_offers(
fleet_models=fleet_models,
run_model=run_model,
Expand All @@ -295,7 +294,7 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
master_job_provisioning_data=master_job_provisioning_data,
volumes=volumes,
)
if fleet_model is None and run_spec.configuration.fleets is not None:
if fleet_model is None and run_spec.merged_profile.fleets is not None:
# Run cannot create new fleets when fleets are specified
logger.debug("%s: failed to use specified fleets", fmt(job_model))
job_model.status = JobStatus.TERMINATING
Expand Down Expand Up @@ -443,14 +442,21 @@ async def _select_fleet_models(
*fleet_filters,
FleetModel.id.not_in(fleet_models_with_instances_ids),
)
.where(InstanceModel.id.is_(None))
.options(contains_eager(FleetModel.instances)) # loading empty relation
.where(
or_(
InstanceModel.id.is_(None),
not_(and_(*instance_filters)),
)
)
# Load empty list of instances so that downstream code
# knows this fleet has no instances eligible for offers.
.options(noload(FleetModel.instances))
)
fleet_models_without_instances = list(res.unique().scalars().all())
return fleet_models_with_instances, fleet_models_without_instances


async def _refetch_fleet_models(
async def _refetch_fleet_models_with_instances(
session: AsyncSession,
fleets_ids: list[uuid.UUID],
instances_ids: list[uuid.UUID],
Expand All @@ -465,13 +471,8 @@ async def _refetch_fleet_models(
*fleet_filters,
)
.where(
or_(
InstanceModel.id.is_(None),
and_(
InstanceModel.id.in_(instances_ids),
*instance_filters,
),
)
InstanceModel.id.in_(instances_ids),
*instance_filters,
)
.options(contains_eager(FleetModel.instances))
.execution_options(populate_existing=True)
Expand Down Expand Up @@ -538,7 +539,7 @@ def _find_optimal_fleet_with_offers(
fleet_priority,
)
)
if run_spec.configuration.fleets is None and all(
if run_spec.merged_profile.fleets is None and all(
t[2] == 0 for t in candidate_fleets_with_offers
):
# If fleets are not specified and no fleets have available offers, create a new fleet.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,46 @@ async def test_assigns_job_to_elastic_empty_fleet_if_fleets_specified(
assert job.instance_id is None
assert job.fleet_id == fleet.id

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_assigns_job_to_elastic_non_empty_busy_fleet_if_fleets_specified(
self, test_db, session: AsyncSession
):
project = await create_project(session)
user = await create_user(session)
repo = await create_repo(session=session, project_id=project.id)
fleet_spec = get_fleet_spec()
fleet_spec.configuration.nodes = Range(min=1, max=2)
fleet = await create_fleet(session=session, project=project, spec=fleet_spec, name="fleet")
await create_instance(
session=session,
project=project,
fleet=fleet,
instance_num=0,
status=InstanceStatus.BUSY,
total_blocks=1,
busy_blocks=1,
)
run_spec = get_run_spec(repo_id=repo.name)
run_spec.configuration.fleets = [fleet.name]
run = await create_run(
session=session,
project=project,
repo=repo,
user=user,
run_spec=run_spec,
)
job = await create_job(
session=session,
run=run,
instance_assigned=False,
)
await process_submitted_jobs()
await session.refresh(job)
assert job.instance_assigned
assert job.instance_id is None
assert job.fleet_id == fleet.id

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_creates_new_instance_in_existing_empty_fleet(
Expand Down
Loading