diff --git a/frontend/src/pages/Runs/List/index.tsx b/frontend/src/pages/Runs/List/index.tsx index 6ac11f990d..4f90e5430d 100644 --- a/frontend/src/pages/Runs/List/index.tsx +++ b/frontend/src/pages/Runs/List/index.tsx @@ -48,7 +48,7 @@ export const RunList: React.FC = () => { const { data, isLoading, refreshList, isLoadingMore } = useInfiniteScroll({ useLazyQuery: useLazyGetRunsQuery, - args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE }, + args: { ...filteringRequestParams, limit: DEFAULT_TABLE_PAGE_SIZE, job_submissions_limit: 1 }, getPaginationParams: (lastRun) => ({ prev_submitted_at: lastRun.submitted_at }), }); diff --git a/frontend/src/types/run.d.ts b/frontend/src/types/run.d.ts index eae9ebacc4..2e613defb4 100644 --- a/frontend/src/types/run.d.ts +++ b/frontend/src/types/run.d.ts @@ -7,6 +7,7 @@ declare type TRunsRequestParams = { prev_run_id?: string; limit?: number; ascending?: boolean; + job_submissions_limit?: number; }; declare type TDeleteRunsRequestParams = { diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 49691eb504..84a66f6b70 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -556,6 +556,10 @@ def _get_error(termination_reason: Optional[RunTerminationReason]) -> Optional[s @root_validator def _status_message(cls, values) -> Dict: + # FIXME: status_message should not require all job submissions for status calculation + # since it's very expensive and is not required for anything else. + # May return a different status if not all job submissions requested. + # TODO: Calculate status_message by looking at job models directly instead of job submissions. 
try: status = values["status"] jobs: List[Job] = values["jobs"] diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index c6a4b60f80..99bce78073 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -54,6 +54,8 @@ async def list_runs( repo_id=body.repo_id, username=body.username, only_active=body.only_active, + include_jobs=body.include_jobs, + job_submissions_limit=body.job_submissions_limit, prev_submitted_at=body.prev_submitted_at, prev_run_id=body.prev_run_id, limit=body.limit, diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index 8ae875df09..8447243715 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -9,12 +9,24 @@ class ListRunsRequest(CoreModel): - project_name: Optional[str] - repo_id: Optional[str] - username: Optional[str] + project_name: Optional[str] = None + repo_id: Optional[str] = None + username: Optional[str] = None only_active: bool = False - prev_submitted_at: Optional[datetime] - prev_run_id: Optional[UUID] + include_jobs: bool = Field( + True, + description=("Whether to include `jobs` in the response"), + ) + job_submissions_limit: Optional[int] = Field( + None, + ge=0, + description=( + "Limit number of job submissions returned per job to avoid large responses. " + "Drops older job submissions. 
No effect with `include_jobs: false`" + ), + ) + prev_submitted_at: Optional[datetime] = None + prev_run_id: Optional[UUID] = None limit: int = Field(100, ge=0, le=100) ascending: bool = False diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 79d9e0a209..5462d6b018 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -105,6 +105,8 @@ async def list_user_runs( repo_id: Optional[str], username: Optional[str], only_active: bool, + include_jobs: bool, + job_submissions_limit: Optional[int], prev_submitted_at: Optional[datetime], prev_run_id: Optional[uuid.UUID], limit: int, @@ -148,7 +150,14 @@ async def list_user_runs( runs = [] for r in run_models: try: - runs.append(run_model_to_run(r, return_in_api=True)) + runs.append( + run_model_to_run( + r, + return_in_api=True, + include_jobs=include_jobs, + job_submissions_limit=job_submissions_limit, + ) + ) except pydantic.ValidationError: pass if len(run_models) > len(runs): @@ -652,46 +661,26 @@ async def delete_runs( def run_model_to_run( run_model: RunModel, - include_job_submissions: bool = True, + include_jobs: bool = True, + job_submissions_limit: Optional[int] = None, return_in_api: bool = False, include_sensitive: bool = False, ) -> Run: jobs: List[Job] = [] - run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)) - for replica_num, replica_submissions in itertools.groupby( - run_jobs, key=lambda j: j.replica_num - ): - for job_num, job_submissions in itertools.groupby( - replica_submissions, key=lambda j: j.job_num - ): - submissions = [] - job_model = None - for job_model in job_submissions: - if include_job_submissions: - job_submission = job_model_to_job_submission(job_model) - if return_in_api: - # Set default non-None values for 0.18 backward-compatibility - # Remove in 0.19 - if job_submission.job_provisioning_data is not None: - if 
job_submission.job_provisioning_data.hostname is None: - job_submission.job_provisioning_data.hostname = "" - if job_submission.job_provisioning_data.ssh_port is None: - job_submission.job_provisioning_data.ssh_port = 22 - submissions.append(job_submission) - if job_model is not None: - # Use the spec from the latest submission. Submissions can have different specs - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) - if not include_sensitive: - _remove_job_spec_sensitive_info(job_spec) - jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + if include_jobs: + jobs = _get_run_jobs_with_submissions( + run_model=run_model, + job_submissions_limit=job_submissions_limit, + return_in_api=return_in_api, + include_sensitive=include_sensitive, + ) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) latest_job_submission = None - if include_job_submissions: + if len(jobs) > 0 and len(jobs[0].job_submissions) > 0: # TODO(egor-s): does it make sense with replicas and multi-node? 
- if jobs: - latest_job_submission = jobs[0].job_submissions[-1] + latest_job_submission = jobs[0].job_submissions[-1] service_spec = None if run_model.service_spec is not None: @@ -716,6 +705,47 @@ def run_model_to_run( return run +def _get_run_jobs_with_submissions( + run_model: RunModel, + job_submissions_limit: Optional[int], + return_in_api: bool = False, + include_sensitive: bool = False, +) -> List[Job]: + jobs: List[Job] = [] + run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)) + for replica_num, replica_submissions in itertools.groupby( + run_jobs, key=lambda j: j.replica_num + ): + for job_num, job_models in itertools.groupby(replica_submissions, key=lambda j: j.job_num): + submissions = [] + job_model = None + if job_submissions_limit is not None: + if job_submissions_limit == 0: + # Take latest job submission to return its job_spec + job_models = list(job_models)[-1:] + else: + job_models = list(job_models)[-job_submissions_limit:] + for job_model in job_models: + if job_submissions_limit != 0: + job_submission = job_model_to_job_submission(job_model) + if return_in_api: + # Set default non-None values for 0.18 backward-compatibility + # Remove in 0.19 + if job_submission.job_provisioning_data is not None: + if job_submission.job_provisioning_data.hostname is None: + job_submission.job_provisioning_data.hostname = "" + if job_submission.job_provisioning_data.ssh_port is None: + job_submission.job_provisioning_data.ssh_port = 22 + submissions.append(job_submission) + if job_model is not None: + # Use the spec from the latest submission. 
Submissions can have different specs + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + if not include_sensitive: + _remove_job_spec_sensitive_info(job_spec) + jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + return jobs + + async def _get_pool_offers( session: AsyncSession, project: ProjectModel, diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 1a4e0e1e2e..feeacbe63f 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -748,6 +748,7 @@ def list(self, all: bool = False, limit: Optional[int] = None) -> List[Run]: repo_id=None, only_active=only_active, limit=limit or 100, + job_submissions_limit=1, # no need to return more than 1 submission per job ) if only_active and len(runs) == 0: runs = self._api_client.runs.list( diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 2c85792eb2..4e80321ca4 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -33,12 +33,16 @@ def list( prev_run_id: Optional[UUID] = None, limit: int = 100, ascending: bool = False, + include_jobs: bool = True, + job_submissions_limit: Optional[int] = None, ) -> List[Run]: body = ListRunsRequest( project_name=project_name, repo_id=repo_id, username=username, only_active=only_active, + include_jobs=include_jobs, + job_submissions_limit=job_submissions_limit, prev_submitted_at=prev_submitted_at, prev_run_id=prev_run_id, limit=limit, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 96adb6499d..eafd956e5a 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -707,6 +707,108 @@ async def test_lists_runs_pagination( assert len(response2_json) == 1 assert response2_json[0]["id"] == str(run2.id) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def 
test_limits_job_submissions( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_submitted_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + submitted_at=run_submitted_at, + ) + run_spec = RunSpec.parse_raw(run.run_spec) + await create_job( + session=session, + run=run, + submitted_at=run_submitted_at, + last_processed_at=run_submitted_at, + ) + job2 = await create_job( + session=session, + run=run, + submitted_at=run_submitted_at, + last_processed_at=run_submitted_at, + ) + job2_spec = JobSpec.parse_raw(job2.job_spec_data) + response = await client.post( + "/api/runs/list", + headers=get_auth_headers(user.token), + json={"job_submissions_limit": 1}, + ) + assert response.status_code == 200, response.json() + assert response.json() == [ + { + "id": str(run.id), + "project_name": project.name, + "user": user.name, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "status": "submitted", + "status_message": "submitted", + "run_spec": run_spec.dict(), + "jobs": [ + { + "job_spec": job2_spec.dict(), + "job_submissions": [ + { + "id": str(job2.id), + "submission_num": 0, + "deployment_num": 0, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "finished_at": None, + "inactivity_secs": None, + "status": "submitted", + "status_message": "submitted", + "termination_reason": None, + "termination_reason_message": None, + "error": None, + "exit_status": None, + "job_provisioning_data": None, + "job_runtime_data": None, + } + ], + } + ], + 
"latest_job_submission": { + "id": str(job2.id), + "submission_num": 0, + "deployment_num": 0, + "submitted_at": run_submitted_at.isoformat(), + "last_processed_at": run_submitted_at.isoformat(), + "finished_at": None, + "inactivity_secs": None, + "status": "submitted", + "status_message": "submitted", + "termination_reason_message": None, + "termination_reason": None, + "error": None, + "exit_status": None, + "job_provisioning_data": None, + "job_runtime_data": None, + }, + "cost": 0, + "service": None, + "deployment_num": 0, + "termination_reason": None, + "error": None, + "deleted": False, + }, + ] + class TestGetRun: @pytest.mark.asyncio