diff --git a/src/dstack/_internal/server/schemas/logs.py b/src/dstack/_internal/server/schemas/logs.py
index 0d6c0a02b0..8704473d41 100644
--- a/src/dstack/_internal/server/schemas/logs.py
+++ b/src/dstack/_internal/server/schemas/logs.py
@@ -12,5 +12,5 @@ class PollLogsRequest(CoreModel):
     start_time: Optional[datetime]
     end_time: Optional[datetime]
     descending: bool = False
-    limit: int = Field(100, ge=0, le=1000)
+    limit: int = Field(100, ge=1, le=1000)
     diagnose: bool = False
diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py
index 6c6bdcac3f..efd6135d21 100644
--- a/src/dstack/_internal/server/services/logs/filelog.py
+++ b/src/dstack/_internal/server/services/logs/filelog.py
@@ -14,6 +14,7 @@
 from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
 from dstack._internal.server.services.logs.base import (
     LogStorage,
+    LogStorageError,
     b64encode_raw_message,
     unix_time_ms_to_datetime,
 )
@@ -29,7 +30,8 @@ def __init__(self, root: Union[Path, str, None] = None) -> None:
         self.root = Path(root)
 
     def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
-        # TODO Respect request.limit to support pagination
+        if request.descending is True:
+            raise LogStorageError("FileLogStorage doesn't support descending order")
         log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
         log_file_path = self._get_log_file_path(
             project_name=project.name,
@@ -46,12 +48,12 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
                         continue
                     if request.end_time is None or log_event.timestamp < request.end_time:
                         logs.append(log_event)
+                        if len(logs) >= request.limit:
+                            break
                     else:
                         break
         except IOError:
             pass
-        if request.descending:
-            logs = list(reversed(logs))
         return JobSubmissionLogs(logs=logs)
 
     def write_logs(
diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py
index 8250cc18a5..06b09f45e6 100644
--- a/src/tests/_internal/server/services/test_logs.py
+++ b/src/tests/_internal/server/services/test_logs.py
@@ -55,6 +55,69 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path)
             '{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n'
         )
 
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+    async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_path: Path):
+        project = await create_project(session=session)
+        log_storage = FileLogStorage(tmp_path)
+
+        # Write more logs than the limit
+        log_storage.write_logs(
+            project=project,
+            run_name="test_run",
+            job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
+            runner_logs=[
+                RunnerLogEvent(timestamp=1696586513234, message=b"Log1"),
+                RunnerLogEvent(timestamp=1696586513235, message=b"Log2"),
+                RunnerLogEvent(timestamp=1696586513236, message=b"Log3"),
+                RunnerLogEvent(timestamp=1696586513237, message=b"Log4"),
+                RunnerLogEvent(timestamp=1696586513238, message=b"Log5"),
+            ],
+            job_logs=[],
+        )
+        logs = log_storage.poll_logs(
+            project,
+            PollLogsRequest(
+                run_name="test_run",
+                job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
+                start_time=None,
+                end_time=None,
+                limit=1000,
+                diagnose=True,
+            ),
+        ).logs
+        assert len(logs) == 5
+
+        # Test with limit smaller than total logs (ascending)
+        poll_request = PollLogsRequest(
+            run_name="test_run",
+            job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
+            limit=3,
+            diagnose=True,
+        )
+        job_submission_logs = log_storage.poll_logs(project, poll_request)
+
+        # Should return only the first 3 logs in ascending order
+        assert len(job_submission_logs.logs) == 3
+        assert job_submission_logs.logs[0].message == base64.b64encode(
+            "Log1".encode("utf-8")
+        ).decode("utf-8")
+        assert job_submission_logs.logs[1].message == base64.b64encode(
+            "Log2".encode("utf-8")
+        ).decode("utf-8")
+        assert job_submission_logs.logs[2].message == base64.b64encode(
+            "Log3".encode("utf-8")
+        ).decode("utf-8")
+
+        # Test with limit of 1
+        poll_request.limit = 1
+        poll_request.start_time = logs[3].timestamp
+        job_submission_logs = log_storage.poll_logs(project, poll_request)
+        assert len(job_submission_logs.logs) == 1
+        assert job_submission_logs.logs[0].message == base64.b64encode(
+            "Log5".encode("utf-8")
+        ).decode("utf-8")
+
 
 class TestCloudWatchLogStorage:
     FAKE_NOW = datetime(2023, 10, 6, 10, 1, 54, tzinfo=timezone.utc)