|
32 | 32 | JobSpec, |
33 | 33 | JobStatus, |
34 | 34 | JobTerminationReason, |
| 35 | + ProbeSpec, |
35 | 36 | Run, |
36 | 37 | RunSpec, |
37 | 38 | RunStatus, |
|
70 | 71 | from dstack._internal.server.services.runner import client |
71 | 72 | from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel |
72 | 73 | from dstack._internal.server.services.runs import ( |
| 74 | + is_job_ready, |
73 | 75 | run_model_to_run, |
74 | 76 | ) |
75 | 77 | from dstack._internal.server.services.secrets import get_project_secrets_mapping |
@@ -140,6 +142,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): |
140 | 142 | select(JobModel) |
141 | 143 | .where(JobModel.id == job_model.id) |
142 | 144 | .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) |
| 145 | + .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak)) |
143 | 146 | .execution_options(populate_existing=True) |
144 | 147 | ) |
145 | 148 | job_model = res.unique().scalar_one() |
@@ -382,52 +385,21 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel): |
382 | 385 | job_submission.age, |
383 | 386 | ) |
384 | 387 |
|
385 | | - if ( |
386 | | - initial_status != job_model.status |
387 | | - and job_model.status == JobStatus.RUNNING |
388 | | - and job_model.job_num == 0 # gateway connects only to the first node |
389 | | - and run.run_spec.configuration.type == "service" |
390 | | - ): |
391 | | - ssh_head_proxy: Optional[SSHConnectionParams] = None |
392 | | - ssh_head_proxy_private_key: Optional[str] = None |
393 | | - instance = common_utils.get_or_error(job_model.instance) |
394 | | - if instance.remote_connection_info is not None: |
395 | | - rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info) |
396 | | - if rci.ssh_proxy is not None: |
397 | | - ssh_head_proxy = rci.ssh_proxy |
398 | | - ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys) |
399 | | - ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private |
400 | | - try: |
401 | | - await services.register_replica( |
402 | | - session, |
403 | | - run_model.gateway_id, |
404 | | - run, |
405 | | - job_model, |
406 | | - ssh_head_proxy, |
407 | | - ssh_head_proxy_private_key, |
408 | | - ) |
409 | | - except GatewayError as e: |
410 | | - logger.warning( |
411 | | - "%s: failed to register service replica: %s, age=%s", |
412 | | - fmt(job_model), |
413 | | - e, |
414 | | - job_submission.age, |
415 | | - ) |
416 | | - job_model.status = JobStatus.TERMINATING |
417 | | - job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR |
418 | | - else: |
419 | | - for probe_num in range(len(job.job_spec.probes)): |
420 | | - session.add( |
421 | | - ProbeModel( |
422 | | - name=f"{job_model.job_name}-{probe_num}", |
423 | | - job=job_model, |
424 | | - probe_num=probe_num, |
425 | | - due=common_utils.get_current_datetime(), |
426 | | - success_streak=0, |
427 | | - active=True, |
428 | | - ) |
| 388 | + if initial_status != job_model.status and job_model.status == JobStatus.RUNNING: |
| 389 | + job_model.probes = [] |
| 390 | + for probe_num in range(len(job.job_spec.probes)): |
| 391 | + job_model.probes.append( |
| 392 | + ProbeModel( |
| 393 | + name=f"{job_model.job_name}-{probe_num}", |
| 394 | + probe_num=probe_num, |
| 395 | + due=common_utils.get_current_datetime(), |
| 396 | + success_streak=0, |
| 397 | + active=True, |
429 | 398 | ) |
| 399 | + ) |
430 | 400 |
|
| 401 | + if job_model.status == JobStatus.RUNNING: |
| 402 | + await _maybe_register_replica(session, run_model, run, job_model, job.job_spec.probes) |
431 | 403 | if job_model.status == JobStatus.RUNNING: |
432 | 404 | await _check_gpu_utilization(session, job_model, job) |
433 | 405 |
|
@@ -822,6 +794,55 @@ def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool: |
822 | 794 | ) |
823 | 795 |
|
824 | 796 |
|
| 797 | +async def _maybe_register_replica( |
| 798 | + session: AsyncSession, |
| 799 | + run_model: RunModel, |
| 800 | + run: Run, |
| 801 | + job_model: JobModel, |
| 802 | + probe_specs: Iterable[ProbeSpec], |
| 803 | +) -> None: |
| 804 | + """ |
| 805 | + Register the replica represented by this job to receive service requests if it is ready. |
| 806 | + """ |
| 807 | + |
| 808 | + if ( |
| 809 | + run.run_spec.configuration.type != "service" |
| 810 | + or job_model.registered |
| 811 | + or job_model.job_num != 0 # only the first job in the replica receives service requests |
| 812 | + or not is_job_ready(job_model.probes, probe_specs) |
| 813 | + ): |
| 814 | + return |
| 815 | + |
| 816 | + ssh_head_proxy: Optional[SSHConnectionParams] = None |
| 817 | + ssh_head_proxy_private_key: Optional[str] = None |
| 818 | + instance = common_utils.get_or_error(job_model.instance) |
| 819 | + if instance.remote_connection_info is not None: |
| 820 | + rci: RemoteConnectionInfo = RemoteConnectionInfo.__response__.parse_raw( |
| 821 | + instance.remote_connection_info |
| 822 | + ) |
| 823 | + if rci.ssh_proxy is not None: |
| 824 | + ssh_head_proxy = rci.ssh_proxy |
| 825 | + ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys) |
| 826 | + ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private |
| 827 | + try: |
| 828 | + await services.register_replica( |
| 829 | + session, |
| 830 | + run_model.gateway_id, |
| 831 | + run, |
| 832 | + job_model, |
| 833 | + ssh_head_proxy, |
| 834 | + ssh_head_proxy_private_key, |
| 835 | + ) |
| 836 | + except GatewayError as e: |
| 837 | + logger.warning( |
| 838 | + "%s: failed to register service replica: %s", |
| 839 | + fmt(job_model), |
| 840 | + e, |
| 841 | + ) |
| 842 | + job_model.status = JobStatus.TERMINATING |
| 843 | + job_model.termination_reason = JobTerminationReason.GATEWAY_ERROR |
| 844 | + |
| 845 | + |
825 | 846 | async def _check_gpu_utilization(session: AsyncSession, job_model: JobModel, job: Job) -> None: |
826 | 847 | policy = job.job_spec.utilization_policy |
827 | 848 | if policy is None: |
|
0 commit comments