Skip to content

Commit 31bdec8

Browse files
committed
Use new helpers get_{run,job}_spec
1 parent ea55b8e commit 31bdec8

10 files changed

Lines changed: 36 additions & 29 deletions

File tree

src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
emit_job_status_change_event,
6060
get_job_provisioning_data,
6161
get_job_runtime_data,
62+
get_job_spec,
6263
)
6364
from dstack._internal.server.services.locking import get_locker
6465
from dstack._internal.server.services.logging import fmt
@@ -797,7 +798,7 @@ async def _detach_volumes_from_job_instance(
797798
jpd: JobProvisioningData,
798799
run_termination_reason: Optional[RunTerminationReason],
799800
) -> _VolumeDetachResult:
800-
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
801+
job_spec = get_job_spec(job_model)
801802
backend = await backends_services.get_project_backend_by_type(
802803
project=instance_model.project,
803804
backend_type=jpd.backend,

src/dstack/_internal/server/background/scheduled_tasks/probes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlalchemy.orm import joinedload
1313

1414
from dstack._internal.core.errors import SSHError
15-
from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec
15+
from dstack._internal.core.models.runs import JobStatus, ProbeSpec
1616
from dstack._internal.core.services.ssh.tunnel import (
1717
SSH_DEFAULT_OPTIONS,
1818
IPSocket,
@@ -21,6 +21,7 @@
2121
)
2222
from dstack._internal.server.db import get_db, get_session_ctx
2323
from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
24+
from dstack._internal.server.services.jobs import get_job_spec
2425
from dstack._internal.server.services.locking import get_locker
2526
from dstack._internal.server.services.logging import fmt
2627
from dstack._internal.server.services.ssh import container_ssh_tunnel
@@ -71,7 +72,7 @@ async def process_probes():
7172
if probe.job.status != JobStatus.RUNNING:
7273
probe.active = False
7374
else:
74-
job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data)
75+
job_spec = get_job_spec(probe.job)
7576
probe_spec = job_spec.probes[probe.probe_num]
7677
if probe_spec.until_ready and probe.success_streak >= probe_spec.ready_after:
7778
probe.active = False
@@ -148,7 +149,7 @@ async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClie
148149
**SSH_DEFAULT_OPTIONS,
149150
"ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
150151
}
151-
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
152+
job_spec = get_job_spec(job)
152153
with TemporaryDirectory() as temp_dir:
153154
app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
154155
async with container_ssh_tunnel(

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
JobTerminationReason,
3535
ProbeSpec,
3636
Run,
37-
RunSpec,
3837
RunStatus,
3938
)
4039
from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
@@ -67,6 +66,7 @@
6766
find_job,
6867
get_job_attached_volumes,
6968
get_job_runtime_data,
69+
get_job_spec,
7070
is_master_job,
7171
job_model_to_job_submission,
7272
switch_job_status,
@@ -82,6 +82,7 @@
8282
from dstack._internal.server.services.runner import client
8383
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
8484
from dstack._internal.server.services.runs import (
85+
get_run_spec,
8586
is_job_ready,
8687
run_model_to_run,
8788
)
@@ -732,7 +733,7 @@ def _process_provisioning_with_shim(
732733
Returns:
733734
is successful
734735
"""
735-
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
736+
job_spec = get_job_spec(job_model)
736737

737738
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
738739

@@ -980,7 +981,7 @@ def _terminate_if_inactivity_duration_exceeded(
980981
job_model: JobModel,
981982
no_connections_secs: Optional[int],
982983
) -> None:
983-
conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration
984+
conf = get_run_spec(run_model).configuration
984985
if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance(
985986
conf.inactivity_duration, int
986987
):

src/dstack/_internal/server/background/scheduled_tasks/runs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
1414
from dstack._internal.core.models.runs import (
1515
Job,
16-
JobSpec,
1716
JobStatus,
1817
JobTerminationReason,
1918
Run,
@@ -33,6 +32,7 @@
3332
from dstack._internal.server.services import events
3433
from dstack._internal.server.services.jobs import (
3534
find_job,
35+
get_job_spec,
3636
get_job_specs_from_run_spec,
3737
group_jobs_by_replica_latest,
3838
is_master_job,
@@ -527,7 +527,7 @@ async def _handle_run_replicas(
527527
if job.status.is_finished():
528528
continue
529529
try:
530-
job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
530+
job_spec = get_job_spec(job)
531531
existing_group_names.add(job_spec.replica_group)
532532
except Exception:
533533
continue
@@ -643,7 +643,7 @@ async def _update_jobs_to_new_deployment_in_place(
643643
replica_group_name = None
644644

645645
if replicas:
646-
job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
646+
job_spec = get_job_spec(job_models[0])
647647
replica_group_name = job_spec.replica_group
648648

649649
# FIXME: Handle getting image configuration errors or skip it.
@@ -658,7 +658,7 @@ async def _update_jobs_to_new_deployment_in_place(
658658
)
659659
can_update_all_jobs = True
660660
for old_job_model, new_job_spec in zip(job_models, new_job_specs):
661-
old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
661+
old_job_spec = get_job_spec(old_job_model)
662662
if new_job_spec != old_job_spec:
663663
can_update_all_jobs = False
664664
break

src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from dstack._internal.server.services.jobs import (
4040
get_job_provisioning_data,
4141
get_job_runtime_data,
42+
get_job_spec,
4243
switch_job_status,
4344
)
4445
from dstack._internal.server.services.locking import get_locker
@@ -356,7 +357,7 @@ async def _detach_volumes_from_job_instance(
356357
instance_model: InstanceModel,
357358
volume_models: list[VolumeModel],
358359
) -> bool:
359-
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
360+
job_spec = get_job_spec(job_model)
360361
backend = await backends_services.get_project_backend_by_type(
361362
project=project,
362363
backend_type=jpd.backend,

src/dstack/_internal/server/services/prometheus/custom_metrics.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sqlalchemy.orm import aliased, joinedload
1414

1515
from dstack._internal.core.models.instances import InstanceStatus
16-
from dstack._internal.core.models.runs import JobStatus, RunSpec, RunStatus
16+
from dstack._internal.core.models.runs import JobStatus, RunStatus
1717
from dstack._internal.server.models import (
1818
InstanceModel,
1919
JobMetricsPoint,
@@ -25,6 +25,7 @@
2525
)
2626
from dstack._internal.server.services.instances import get_instance_offer
2727
from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
28+
from dstack._internal.server.services.runs import get_run_spec
2829
from dstack._internal.utils.common import get_current_datetime
2930

3031

@@ -152,7 +153,7 @@ async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]:
152153
price = jrd.offer.price
153154
gpus = resources.gpus
154155
cpus = resources.cpus
155-
run_spec = RunSpec.__response__.parse_raw(job.run.run_spec)
156+
run_spec = get_run_spec(job.run)
156157
labels = {
157158
"dstack_project_name": job.project.name,
158159
"dstack_user_name": job.run.user.name,

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
from dstack._internal.core.models.instances import SSHConnectionParams
1313
from dstack._internal.core.models.runs import (
1414
JobProvisioningData,
15-
JobSpec,
1615
JobStatus,
17-
RunSpec,
1816
RunStatus,
1917
ServiceSpec,
2018
get_service_port,
@@ -32,6 +30,8 @@
3230
from dstack._internal.proxy.lib.repo import BaseProxyRepo
3331
from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel
3432
from dstack._internal.server.services.instances import get_instance_remote_connection_info
33+
from dstack._internal.server.services.jobs import get_job_spec
34+
from dstack._internal.server.services.runs import get_run_spec
3535
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
3636
from dstack._internal.utils.common import get_or_error
3737

@@ -68,7 +68,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
6868
if not len(jobs):
6969
return None
7070
run = jobs[0].run
71-
run_spec = RunSpec.__response__.parse_raw(run.run_spec)
71+
run_spec = get_run_spec(run)
7272
if not isinstance(run_spec.configuration, ServiceConfiguration):
7373
return None
7474
replicas = []
@@ -108,7 +108,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
108108
if rci is not None and rci.ssh_proxy is not None:
109109
ssh_head_proxy = rci.ssh_proxy
110110
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
111-
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
111+
job_spec = get_job_spec(job)
112112
replica = Replica(
113113
id=job.id.hex,
114114
app_port=get_service_port(job_spec, run_spec.configuration),

src/dstack/_internal/server/services/runs/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from dstack._internal.core.models.runs import (
2525
ApplyRunPlanInput,
2626
Job,
27-
JobSpec,
2827
JobStatus,
2928
JobSubmission,
3029
JobTerminationReason,
@@ -54,6 +53,7 @@
5453
check_can_attach_job_volumes,
5554
delay_job_instance_termination,
5655
get_job_configured_volumes,
56+
get_job_spec,
5757
get_jobs_from_run_spec,
5858
job_model_to_job_submission,
5959
remove_job_spec_sensitive_info,
@@ -835,7 +835,7 @@ def _get_run_jobs_with_submissions(
835835
submissions.append(job_submission)
836836
if job_model is not None:
837837
# Use the spec from the latest submission. Submissions can have different specs
838-
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
838+
job_spec = get_job_spec(job_model)
839839
if not include_sensitive:
840840
remove_job_spec_sensitive_info(job_spec)
841841
jobs.append(Job(job_spec=job_spec, job_submissions=submissions))
@@ -861,7 +861,7 @@ def _get_run_status_message(run_model: RunModel) -> str:
861861
if run_model.status in [RunStatus.SUBMITTED, RunStatus.PENDING]:
862862
# Show `retrying` if any job caused the run to retry
863863
for job_models in job_models_grouped_by_job:
864-
last_job_spec = JobSpec.__response__.parse_raw(job_models[-1].job_spec_data)
864+
last_job_spec = get_job_spec(job_models[-1])
865865
retry_on_events = last_job_spec.retry.on_events if last_job_spec.retry else []
866866
last_job_termination_reason = _get_last_job_termination_reason(job_models)
867867
if (

src/dstack/_internal/server/services/runs/replicas.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
from sqlalchemy.ext.asyncio import AsyncSession
55

66
from dstack._internal.core.models.configurations import ReplicaGroup
7-
from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec
7+
from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec
88
from dstack._internal.server.models import JobModel, RunModel
99
from dstack._internal.server.services import events
1010
from dstack._internal.server.services.jobs import (
11+
get_job_spec,
1112
get_jobs_from_run_spec,
1213
group_jobs_by_replica_latest,
1314
switch_job_status,
1415
)
1516
from dstack._internal.server.services.logging import fmt
1617
from dstack._internal.server.services.runs import (
1718
create_job_model_for_new_submission,
19+
get_run_spec,
1820
logger,
1921
)
2022
from dstack._internal.server.services.secrets import get_project_secrets_mapping
@@ -30,8 +32,8 @@ async def retry_run_replica_jobs(
3032
)
3133

3234
# Determine replica group from existing job
33-
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
34-
job_spec = JobSpec.__response__.parse_raw(latest_jobs[0].job_spec_data)
35+
run_spec = get_run_spec(run_model)
36+
job_spec = get_job_spec(latest_jobs[0])
3537
replica_group_name = job_spec.replica_group
3638

3739
new_jobs = await get_jobs_from_run_spec(
@@ -86,7 +88,7 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
8688
)
8789

8890
active_replicas, inactive_replicas = build_replica_lists(run_model)
89-
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
91+
run_spec = get_run_spec(run_model)
9092

9193
if replicas_diff < 0:
9294
scale_down_replicas(session, active_replicas, abs(replicas_diff))
@@ -259,7 +261,7 @@ async def scale_run_replicas_per_group(
259261
run_model=run_model,
260262
group=group,
261263
replicas_diff=group_diff,
262-
run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
264+
run_spec=get_run_spec(run_model),
263265
active_replicas=active_replicas,
264266
inactive_replicas=inactive_replicas,
265267
)
@@ -300,7 +302,7 @@ async def scale_run_replicas_for_group(
300302

301303

302304
def job_belongs_to_group(job: JobModel, group_name: str) -> bool:
303-
job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
305+
job_spec = get_job_spec(job)
304306
return job_spec.replica_group == group_name
305307

306308

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
RouterType,
3333
SGLangServiceRouterConfig,
3434
)
35-
from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec
35+
from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec
3636
from dstack._internal.core.models.services import OpenAIChatModel
3737
from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE
3838
from dstack._internal.server import settings
@@ -322,7 +322,7 @@ async def register_replica(
322322
async with conn.client() as client:
323323
await client.register_replica(
324324
run=run,
325-
job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
325+
job_spec=jobs_services.get_job_spec(job_model),
326326
job_submission=job_submission,
327327
instance_project_ssh_private_key=instance_project_ssh_private_key,
328328
ssh_head_proxy=ssh_head_proxy,

0 commit comments

Comments
 (0)