Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from dstack._internal.utils.path import is_absolute_posix_path
from dstack.api._public.repos import get_ssh_keypair
from dstack.api._public.runs import Run
from dstack.api.server import APIClient
from dstack.api.utils import load_profile

_KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
Expand Down Expand Up @@ -222,6 +223,9 @@ def apply_configuration(
format_date=local_time,
)
)

_warn_fleet_autocreated(self.api.client, run)

console.print(
f"\n[code]{run.name}[/] provisioning completed [secondary]({run.status.value})[/]"
)
Expand Down Expand Up @@ -865,3 +869,16 @@ def render_run_spec_diff(old_spec: RunSpec, new_spec: RunSpec) -> Optional[str]:
item = NestedListItem(spec_field.replace("_", " ").capitalize())
nested_list.children.append(item)
return nested_list.render()


def _warn_fleet_autocreated(api: APIClient, run: Run):
if run._run.fleet is None:
return
fleet = api.fleets.get(project_name=run._project, name=run._run.fleet.name)
if not fleet.spec.autocreated:
return
warn(
f"\nNo existing fleet matched, so the run created a new fleet [code]{fleet.name}[/code].\n"
"Future dstack versions won't create fleets automatically.\n"
"Create a fleet explicitly: https://dstack.ai/docs/concepts/fleets/"
)
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,16 @@ def is_finished(self):
return self in self.finished_statuses()


class RunFleet(CoreModel):
id: UUID4
name: str


class Run(CoreModel):
id: UUID4
project_name: str
user: str
fleet: Optional[RunFleet] = None
submitted_at: datetime
last_processed_at: datetime
status: RunStatus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from dstack._internal.server.background.tasks.common import get_provisioning_timeout
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
FleetModel,
InstanceModel,
JobModel,
ProbeModel,
Expand Down Expand Up @@ -151,6 +152,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
.options(joinedload(RunModel.project))
.options(joinedload(RunModel.user))
.options(joinedload(RunModel.repo))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(joinedload(RunModel.jobs))
)
run_model = res.unique().scalar_one()
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/background/tasks/process_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
FleetModel,
InstanceModel,
JobModel,
ProjectModel,
Expand Down Expand Up @@ -145,6 +146,7 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
.execution_options(populate_existing=True)
.options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
.options(joinedload(RunModel.user).load_only(UserModel.name))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(
selectinload(RunModel.jobs)
.joinedload(JobModel.instance)
Expand Down
16 changes: 16 additions & 0 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
JobTerminationReason,
ProbeSpec,
Run,
RunFleet,
RunPlan,
RunSpec,
RunStatus,
Expand All @@ -58,6 +59,7 @@
from dstack._internal.server import settings
from dstack._internal.server.db import get_db
from dstack._internal.server.models import (
FleetModel,
JobModel,
ProbeModel,
ProjectModel,
Expand Down Expand Up @@ -227,6 +229,7 @@ async def list_projects_run_models(
select(RunModel)
.where(*filters)
.options(joinedload(RunModel.user).load_only(UserModel.name))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(selectinload(RunModel.jobs).joinedload(JobModel.probes))
.order_by(*order_by)
.limit(limit)
Expand Down Expand Up @@ -269,6 +272,7 @@ async def get_run_by_name(
RunModel.deleted == False,
)
.options(joinedload(RunModel.user))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(selectinload(RunModel.jobs).joinedload(JobModel.probes))
)
run_model = res.scalar()
Expand All @@ -289,6 +293,7 @@ async def get_run_by_id(
RunModel.id == run_id,
)
.options(joinedload(RunModel.user))
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
.options(selectinload(RunModel.jobs).joinedload(JobModel.probes))
)
run_model = res.scalar()
Expand Down Expand Up @@ -709,10 +714,12 @@ def run_model_to_run(

status_message = _get_run_status_message(run_model)
error = _get_run_error(run_model)
fleet = _get_run_fleet(run_model)
run = Run(
id=run_model.id,
project_name=run_model.project.name,
user=run_model.user.name,
fleet=fleet,
submitted_at=run_model.submitted_at,
last_processed_at=run_model.last_processed_at,
status=run_model.status,
Expand Down Expand Up @@ -821,6 +828,15 @@ def _get_run_error(run_model: RunModel) -> Optional[str]:
return run_model.termination_reason.to_error()


def _get_run_fleet(run_model: RunModel) -> Optional[RunFleet]:
if run_model.fleet is None:
return None
return RunFleet(
id=run_model.fleet.id,
name=run_model.fleet.name,
)


async def _get_pool_offers(
session: AsyncSession,
project: ProjectModel,
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ async def create_run(
project: ProjectModel,
repo: RepoModel,
user: UserModel,
fleet: Optional[FleetModel] = None,
run_name: str = "test-run",
status: RunStatus = RunStatus.SUBMITTED,
termination_reason: Optional[RunTerminationReason] = None,
Expand All @@ -310,6 +311,7 @@ async def create_run(
project_id=project.id,
repo_id=repo.id,
user_id=user.id,
fleet_id=fleet.id if fleet else None,
submitted_at=submitted_at,
run_name=run_name,
status=status,
Expand Down
14 changes: 14 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from dstack._internal.server.services.runs import run_model_to_run
from dstack._internal.server.testing.common import (
create_backend,
create_fleet,
create_gateway,
create_gateway_compute,
create_instance,
Expand Down Expand Up @@ -337,6 +338,7 @@ def get_dev_env_run_dict(
"id": run_id,
"project_name": project_name,
"user": username,
"fleet": None,
"submitted_at": submitted_at,
"last_processed_at": last_processed_at,
"status": "submitted",
Expand Down Expand Up @@ -558,6 +560,7 @@ async def test_returns_40x_if_not_authenticated(
async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncClient):
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
fleet = await create_fleet(session=session, project=project)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
Expand All @@ -571,6 +574,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
project=project,
repo=repo,
user=user,
fleet=fleet,
submitted_at=run1_submitted_at,
)
run1_spec = RunSpec.parse_raw(run1.run_spec)
Expand All @@ -587,6 +591,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
project=project,
repo=repo,
user=user,
fleet=fleet,
submitted_at=run2_submitted_at,
)
run2_spec = RunSpec.parse_raw(run2.run_spec)
Expand All @@ -601,6 +606,10 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
"id": str(run1.id),
"project_name": project.name,
"user": user.name,
"fleet": {
"id": str(fleet.id),
"name": fleet.name,
},
"submitted_at": run1_submitted_at.isoformat(),
"last_processed_at": run1_submitted_at.isoformat(),
"status": "submitted",
Expand Down Expand Up @@ -660,6 +669,10 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli
"id": str(run2.id),
"project_name": project.name,
"user": user.name,
"fleet": {
"id": str(fleet.id),
"name": fleet.name,
},
"submitted_at": run2_submitted_at.isoformat(),
"last_processed_at": run2_submitted_at.isoformat(),
"status": "submitted",
Expand Down Expand Up @@ -784,6 +797,7 @@ async def test_limits_job_submissions(
"id": str(run.id),
"project_name": project.name,
"user": user.name,
"fleet": None,
"submitted_at": run_submitted_at.isoformat(),
"last_processed_at": run_submitted_at.isoformat(),
"status": "submitted",
Expand Down
Loading