From 418cc96d638f7ffdb322ab43033cfe57b50d6005 Mon Sep 17 00:00:00 2001 From: peterschmidt85 Date: Wed, 25 Jun 2025 20:34:10 +0200 Subject: [PATCH] [Bug]: Use a unique token for log pagination instead of a timestamp #2833 --- frontend/src/hooks/useInfiniteScroll.ts | 2 +- .../src/pages/Runs/Details/Logs/index.tsx | 95 +- .../Runs/Details/Logs/styles.module.scss | 30 +- frontend/src/services/project.ts | 15 +- frontend/src/types/log.d.ts | 30 +- .../_internal/core/compatibility/logs.py | 15 + src/dstack/_internal/core/models/logs.py | 3 +- src/dstack/_internal/server/schemas/logs.py | 11 +- .../_internal/server/services/logs/aws.py | 76 +- .../_internal/server/services/logs/filelog.py | 62 +- .../_internal/server/services/logs/gcp.py | 33 +- src/dstack/api/_public/runs.py | 11 +- src/dstack/api/server/_logs.py | 6 +- .../_internal/server/routers/test_logs.py | 6 +- .../_internal/server/services/test_logs.py | 1114 ++++++++++++++--- 15 files changed, 1196 insertions(+), 313 deletions(-) create mode 100644 src/dstack/_internal/core/compatibility/logs.py diff --git a/frontend/src/hooks/useInfiniteScroll.ts b/frontend/src/hooks/useInfiniteScroll.ts index f6db2fb8f3..3a3813ff92 100644 --- a/frontend/src/hooks/useInfiniteScroll.ts +++ b/frontend/src/hooks/useInfiniteScroll.ts @@ -26,7 +26,7 @@ export const useInfiniteScroll = ({ const [data, setData] = useState>([]); const scrollElement = useRef(document.documentElement); const isLoadingRef = useRef(false); - const lastRequestParams = useRef(undefined); + const lastRequestParams = useRef(undefined); const [disabledMore, setDisabledMore] = useState(false); const { limit, ...argsProp } = args; const lastArgsProps = useRef>(null); diff --git a/frontend/src/pages/Runs/Details/Logs/index.tsx b/frontend/src/pages/Runs/Details/Logs/index.tsx index 721b0527de..61cce49eb3 100644 --- a/frontend/src/pages/Runs/Details/Logs/index.tsx +++ b/frontend/src/pages/Runs/Details/Logs/index.tsx @@ -8,7 +8,7 @@ import { Terminal } from '@xterm/xterm'; import { Container, Header, ListEmptyMessage, Loader, TextContent } from 'components'; import { useAppSelector } from 'hooks'; -import { useGetProjectLogsQuery } from 'services/project'; +import { useLazyGetProjectLogsQuery } from 'services/project'; import { selectSystemMode } from 'App/slice'; @@ -22,10 +22,50 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub const { t } = useTranslation(); const appliedTheme = useAppSelector(selectSystemMode); - const terminalInstance = useRef(new Terminal()); - + const terminalInstance = useRef(new Terminal({scrollback: 10000000})); const fitAddonInstance = useRef(new FitAddon()); const [logsData, setLogsData] = useState([]); + const [isLoading, setIsLoading] = useState(false); + + const [getProjectLogs] = useLazyGetProjectLogsQuery(); + + const writeDataToTerminal = (logs: ILogItem[]) => { + logs.forEach((logItem) => { + terminalInstance.current.write(logItem.message); + }); + + fitAddonInstance.current.fit(); + }; + + const getNextLogItems = (nextToken?: string) => { + setIsLoading(true); + + if (!jobSubmissionId) { + return; + } + + getProjectLogs({ + project_name: projectName, + run_name: runName, + descending: false, + job_submission_id: jobSubmissionId ?? '', + next_token: nextToken, + limit: LIMIT_LOG_ROWS, + }) + .unwrap() + .then((response) => { + setLogsData((old) => [...old, ...response.logs]); + + writeDataToTerminal(response.logs); + + if (response.next_token) { + getNextLogItems(response.next_token); + } else { + setIsLoading(false); + } + }) + .catch(() => setIsLoading(false)); + }; useEffect(() => { if (appliedTheme === Mode.Light) { @@ -45,6 +85,8 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub useEffect(() => { terminalInstance.current.loadAddon(fitAddonInstance.current); + getNextLogItems(); + const onResize = () => { fitAddonInstance.current.fit(); }; @@ -56,50 +98,27 @@ export const Logs: React.FC = ({ className, projectName, runName, jobSub }; }, []); - const { - data: fetchData, - isLoading, - isFetching: isFetchingLogs, - } = useGetProjectLogsQuery( - { - project_name: projectName, - run_name: runName, - descending: true, - job_submission_id: jobSubmissionId ?? '', - limit: LIMIT_LOG_ROWS, - }, - { - skip: !jobSubmissionId, - }, - ); - - useEffect(() => { - if (fetchData) { - const reversed = [...fetchData].reverse(); - setLogsData((old) => [...reversed, ...old]); - } - }, [fetchData]); - useEffect(() => { const element = document.getElementById('terminal'); - if (logsData.length && terminalInstance.current && element) { + if (terminalInstance.current && element) { terminalInstance.current.open(element); - - logsData.forEach((logItem) => { - terminalInstance.current.write(logItem.message); - }); - - fitAddonInstance.current.fit(); } - }, [logsData]); + }, []); return (
- {t('projects.run.log')}}> + +
+ {t('projects.run.log')} + +
+ + } + > - - {!isLoading && !logsData.length && ( ['Projects'], }), - getProjectLogs: builder.query({ + getProjectLogs: builder.query({ query: ({ project_name, ...body }) => { return { url: API.PROJECTS.LOGS(project_name), @@ -84,11 +84,17 @@ export const projectApi = createApi({ keepUnusedDataFor: 0, providesTags: () => ['ProjectLogs'], - transformResponse: (response: { logs: ILogItem[] }) => - response.logs.map((logItem) => ({ + transformResponse: (response: { logs: ILogItem[]; next_token: string }) => { + const logs = response.logs.map((logItem) => ({ ...logItem, message: base64ToArrayBuffer(logItem.message as string), - })), + })); + + return { + ...response, + logs, + }; + }, }), getProjectRepos: builder.query({ @@ -111,5 +117,6 @@ export const { useUpdateProjectMembersMutation, useDeleteProjectsMutation, useGetProjectLogsQuery, + useLazyGetProjectLogsQuery, useGetProjectReposQuery, } = projectApi; diff --git a/frontend/src/types/log.d.ts b/frontend/src/types/log.d.ts index ec19243738..eec182c1ec 100644 --- a/frontend/src/types/log.d.ts +++ b/frontend/src/types/log.d.ts @@ -1,16 +1,22 @@ declare interface ILogItem { - log_source: 'stdout' | 'stderr' - timestamp: string, - message: string | Uint8Array, + log_source: 'stdout' | 'stderr'; + timestamp: string; + message: string | Uint8Array; } declare type TRequestLogsParams = { - project_name: IProject['project_name'], - run_name: IRun['run_name'], - job_submission_id: string - start_time?: DateTime, - end_time?: DateTime, - descending?: boolean, - limit?: number - diagnose?: boolean -} + project_name: IProject['project_name']; + run_name: IRun['run_name']; + job_submission_id: string; + start_time?: DateTime; + end_time?: DateTime; + descending?: boolean; + limit?: number; + diagnose?: boolean; + next_token?: string; +}; + +declare type TResponseLogsParams = { + logs: ILogItem[]; + next_token?: string; +}; diff --git a/src/dstack/_internal/core/compatibility/logs.py b/src/dstack/_internal/core/compatibility/logs.py new file mode 100644 index 0000000000..d6c2d3b3bb --- /dev/null +++ b/src/dstack/_internal/core/compatibility/logs.py @@ -0,0 +1,15 @@ +from typing import Dict, Optional + +from dstack._internal.server.schemas.logs import PollLogsRequest + + +def get_poll_logs_excludes(request: PollLogsRequest) -> Optional[Dict]: + """ + Returns exclude mapping to exclude certain fields from the request. + Use this method to exclude new fields when they are not set to keep + clients backward-compatibility with older servers. + """ + excludes = {} + if request.next_token is None: + excludes["next_token"] = True + return excludes if excludes else None diff --git a/src/dstack/_internal/core/models/logs.py b/src/dstack/_internal/core/models/logs.py index d93dd74a53..7dd5cb87de 100644 --- a/src/dstack/_internal/core/models/logs.py +++ b/src/dstack/_internal/core/models/logs.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import List +from typing import List, Optional from dstack._internal.core.models.common import CoreModel @@ -23,3 +23,4 @@ class LogEvent(CoreModel): class JobSubmissionLogs(CoreModel): logs: List[LogEvent] + next_token: Optional[str] diff --git a/src/dstack/_internal/server/schemas/logs.py b/src/dstack/_internal/server/schemas/logs.py index 0d6c0a02b0..267f5612fa 100644 --- a/src/dstack/_internal/server/schemas/logs.py +++ b/src/dstack/_internal/server/schemas/logs.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Optional -from pydantic import UUID4, Field +from pydantic import UUID4, Field, validator from dstack._internal.core.models.common import CoreModel @@ -12,5 +12,14 @@ class PollLogsRequest(CoreModel): start_time: Optional[datetime] end_time: Optional[datetime] descending: bool = False + next_token: Optional[str] = None limit: int = Field(100, ge=0, le=1000) diagnose: bool = False + + @validator("descending") + @classmethod + def validate_descending(cls, v): + # Descending is not supported until we migrate from base64-encoded logs to plain text logs. + if v is True: + raise ValueError("descending: true is not supported") + return v diff --git a/src/dstack/_internal/server/services/logs/aws.py b/src/dstack/_internal/server/services/logs/aws.py index 054bbc5883..92155763c3 100644 --- a/src/dstack/_internal/server/services/logs/aws.py +++ b/src/dstack/_internal/server/services/logs/aws.py @@ -78,14 +78,22 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi project.name, request.run_name, request.job_submission_id, log_producer ) cw_events: List[_CloudWatchLogEvent] + next_token: Optional[str] = None with self._wrap_boto_errors(): try: - cw_events = self._get_log_events(stream, request) + cw_events, next_token = self._get_log_events(stream, request) except botocore.exceptions.ClientError as e: if not self._is_resource_not_found_exception(e): raise - logger.debug("Stream %s not found, returning dummy response", stream) - cw_events = [] + # Check if the group exists to distinguish between group not found vs stream not found + try: + self._check_group_exists(self._group) + # Group exists, so the error must be due to missing stream + logger.debug("Stream %s not found, returning dummy response", stream) + cw_events = [] + except LogStorageError: + # Group doesn't exist, re-raise the LogStorageError + raise logs = [ LogEvent( timestamp=unix_time_ms_to_datetime(cw_event["timestamp"]), @@ -94,51 +102,43 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi ) for cw_event in cw_events ] - return JobSubmissionLogs(logs=logs) + return JobSubmissionLogs(logs=logs, next_token=next_token if len(logs) > 0 else None) - def _get_log_events(self, stream: str, request: PollLogsRequest) -> List[_CloudWatchLogEvent]: - limit = request.limit + def _get_log_events( + self, stream: str, request: PollLogsRequest + ) -> Tuple[List[_CloudWatchLogEvent], Optional[str]]: + start_from_head = not request.descending parameters = { "logGroupName": self._group, "logStreamName": stream, - "limit": limit, + "limit": request.limit, + "startFromHead": start_from_head, } - start_from_head = not request.descending - parameters["startFromHead"] = start_from_head + if request.start_time: - # XXX: Since callers use start_time/end_time for pagination, one millisecond is added - # to avoid an infinite loop because startTime boundary is inclusive. parameters["startTime"] = datetime_to_unix_time_ms(request.start_time) + 1 + if request.end_time: - # No need to substract one millisecond in this case, though, seems that endTime is - # exclusive, that is, time interval boundaries are [startTime, entTime) parameters["endTime"] = datetime_to_unix_time_ms(request.end_time) - # "Partially full or empty pages don't necessarily mean that pagination is finished. - # As long as the nextBackwardToken or nextForwardToken returned is NOT equal to the - # nextToken that you passed into the API call, there might be more log events available." - events: List[_CloudWatchLogEvent] = [] - next_token: Optional[str] = None + elif start_from_head: + # When startFromHead=true and no endTime is provided, set endTime to "now" + # to prevent infinite pagination as new logs arrive faster than we can read them + parameters["endTime"] = datetime_to_unix_time_ms(datetime.now(timezone.utc)) + + if request.next_token: + parameters["nextToken"] = request.next_token + + response = self._client.get_log_events(**parameters) + + events = response.get("events", []) next_token_key = "nextForwardToken" if start_from_head else "nextBackwardToken" - # Limit max tries to avoid a possible infinite loop if the API is misbehaving - tries_left = 10 - while tries_left: - if next_token is not None: - parameters["nextToken"] = next_token - response = self._client.get_log_events(**parameters) - if start_from_head: - events.extend(response["events"]) - else: - # Regardless of the startFromHead value log events are arranged in - # chronological order, from earliest to latest. - events.extend(reversed(response["events"])) - if len(events) >= limit: - return events[:limit] - if response[next_token_key] == next_token: - return events - next_token = response[next_token_key] - tries_left -= 1 - logger.warning("too many requests to stream %s, returning partial response", stream) - return events + next_token = response.get(next_token_key) + + # TODO: The code below is not going to be used until we migrate from base64-encoded logs to plain text logs. + if request.descending: + events = list(reversed(events)) + + return events, next_token def write_logs( self, diff --git a/src/dstack/_internal/server/services/logs/filelog.py b/src/dstack/_internal/server/services/logs/filelog.py index 6c6bdcac3f..905ee3527b 100644 --- a/src/dstack/_internal/server/services/logs/filelog.py +++ b/src/dstack/_internal/server/services/logs/filelog.py @@ -14,6 +14,7 @@ from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent from dstack._internal.server.services.logs.base import ( LogStorage, + LogStorageError, b64encode_raw_message, unix_time_ms_to_datetime, ) @@ -29,7 +30,9 @@ def __init__(self, root: Union[Path, str, None] = None) -> None: self.root = Path(root) def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: - # TODO Respect request.limit to support pagination + if request.descending: + raise LogStorageError("descending: true is not supported") + log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB log_file_path = self._get_log_file_path( project_name=project.name, @@ -37,22 +40,53 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi job_submission_id=request.job_submission_id, producer=log_producer, ) + + start_line = 0 + if request.next_token: + try: + start_line = int(request.next_token) + if start_line < 0: + raise LogStorageError( + f"Invalid next_token: {request.next_token}. Must be a non-negative integer." + ) + except ValueError: + raise LogStorageError( + f"Invalid next_token: {request.next_token}. Must be a valid integer." + ) + logs = [] + next_token = None + current_line = 0 + try: with open(log_file_path) as f: - for line in f: - log_event = LogEvent.__response__.parse_raw(line) - if request.start_time and log_event.timestamp <= request.start_time: - continue - if request.end_time is None or log_event.timestamp < request.end_time: - logs.append(log_event) - else: - break - except IOError: - pass - if request.descending: - logs = list(reversed(logs)) - return JobSubmissionLogs(logs=logs) + lines = f.readlines() + + for i, line in enumerate(lines): + if current_line < start_line: + current_line += 1 + continue + + log_event = LogEvent.__response__.parse_raw(line) + current_line += 1 + + if request.start_time and log_event.timestamp <= request.start_time: + continue + if request.end_time is not None and log_event.timestamp >= request.end_time: + break + + logs.append(log_event) + + if len(logs) >= request.limit: + # Only set next_token if there are more lines to read + if current_line < len(lines): + next_token = str(current_line) + break + + except IOError as e: + raise LogStorageError(f"Failed to read log file {log_file_path}: {e}") + + return JobSubmissionLogs(logs=logs, next_token=next_token) def write_logs( self, diff --git a/src/dstack/_internal/server/services/logs/gcp.py b/src/dstack/_internal/server/services/logs/gcp.py index 26a3a39930..ac228e19e1 100644 --- a/src/dstack/_internal/server/services/logs/gcp.py +++ b/src/dstack/_internal/server/services/logs/gcp.py @@ -1,5 +1,4 @@ -import time -from typing import Iterable, List +from typing import List from uuid import UUID from dstack._internal.core.errors import ServerClientError @@ -25,7 +24,8 @@ try: import google.api_core.exceptions import google.auth.exceptions - from google.cloud import logging + from google.cloud import logging_v2 + from google.cloud.logging_v2.types import ListLogEntriesRequest except ImportError: GCP_LOGGING_AVAILABLE = False @@ -50,7 +50,7 @@ class GCPLogStorage(LogStorage): def __init__(self, project_id: str): try: - self.client = logging.Client(project=project_id) + self.client = logging_v2.Client(project=project_id) self.logger = self.client.logger(name=self.LOG_NAME) self.logger.list_entries(max_results=1) # Python client doesn't seem to support dry_run, @@ -64,6 +64,7 @@ def __init__(self, project_id: str): raise LogStorageError("Insufficient permissions") def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs: + # TODO: GCP may return logs in random order when events have the same timestamp. producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB stream_name = self._get_stream_name( project_name=project.name, @@ -78,23 +79,27 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi log_filters.append(f'timestamp < "{request.end_time.isoformat()}"') log_filter = " AND ".join(log_filters) - order_by = logging.DESCENDING if request.descending else logging.ASCENDING + order_by = logging_v2.DESCENDING if request.descending else logging_v2.ASCENDING try: - entries: Iterable[logging.LogEntry] = self.logger.list_entries( - filter_=log_filter, + # Use low-level API to get access to next_page_token + request_obj = ListLogEntriesRequest( + resource_names=[f"projects/{self.client.project}"], + filter=log_filter, order_by=order_by, - max_results=request.limit, - # Specify max possible page_size (<=1000) to reduce number of API calls. page_size=request.limit, + page_token=request.next_token, ) + response = self.client._logging_api._gapic_api.list_log_entries(request=request_obj) + logs = [ LogEvent( timestamp=entry.timestamp, - message=entry.payload["message"], + message=entry.json_payload.get("message"), log_source=LogEventSource.STDOUT, ) - for entry in entries + for entry in response.entries ] + next_token = response.next_page_token or None except google.api_core.exceptions.ResourceExhausted as e: logger.warning("GCP Logging exception: %s", repr(e)) # GCP Logging has severely low quota of 60 reads/min for entries.list @@ -102,11 +107,7 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi "GCP Logging read request limit exceeded." " It's recommended to increase default entries.list request quota from 60 per minute." ) - # We intentionally make reading logs slow to prevent hitting GCP quota. - # This doesn't help with many concurrent clients but - # should help with one client reading all logs sequentially. - time.sleep(1) - return JobSubmissionLogs(logs=logs) + return JobSubmissionLogs(logs=logs, next_token=next_token if len(logs) > 0 else None) def write_logs( self, diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 9069154705..da16fc3c6a 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -204,25 +204,26 @@ def logs( job = self._find_job(replica_num=replica_num, job_num=job_num) if job is None: return [] - next_start_time = start_time + next_token = None while True: resp = self._api_client.logs.poll( project_name=self._project, body=PollLogsRequest( run_name=self.name, job_submission_id=job.job_submissions[-1].id, - start_time=next_start_time, + start_time=start_time, end_time=None, descending=False, limit=1000, diagnose=diagnose, + next_token=next_token, ), ) - if len(resp.logs) == 0: - return [] for log in resp.logs: yield base64.b64decode(log.message) - next_start_time = resp.logs[-1].timestamp + next_token = resp.next_token + if next_token is None: + break def refresh(self): """ diff --git a/src/dstack/api/server/_logs.py b/src/dstack/api/server/_logs.py index b82d7017d7..7cdfc246f7 100644 --- a/src/dstack/api/server/_logs.py +++ b/src/dstack/api/server/_logs.py @@ -1,5 +1,6 @@ from pydantic import parse_obj_as +from dstack._internal.core.compatibility.logs import get_poll_logs_excludes from dstack._internal.core.models.logs import JobSubmissionLogs from dstack._internal.server.schemas.logs import PollLogsRequest from dstack.api.server._group import APIClientGroup @@ -7,5 +8,8 @@ class LogsAPIClient(APIClientGroup): def poll(self, project_name: str, body: PollLogsRequest) -> JobSubmissionLogs: - resp = self._request(f"/api/project/{project_name}/logs/poll", body=body.json()) + resp = self._request( + f"/api/project/{project_name}/logs/poll", + body=body.json(exclude=get_poll_logs_excludes(body)), + ) return parse_obj_as(JobSubmissionLogs.__response__, resp.json()) diff --git a/src/tests/_internal/server/routers/test_logs.py b/src/tests/_internal/server/routers/test_logs.py index 688087b414..11f0da8daf 100644 --- a/src/tests/_internal/server/routers/test_logs.py +++ b/src/tests/_internal/server/routers/test_logs.py @@ -74,7 +74,8 @@ async def test_returns_logs( "log_source": "stdout", "message": "!", }, - ] + ], + "next_token": None, } response = await client.post( f"/api/project/{project.name}/logs/poll", @@ -94,5 +95,6 @@ async def test_returns_logs( "log_source": "stdout", "message": "!", }, - ] + ], + "next_token": None, } diff --git a/src/tests/_internal/server/services/test_logs.py b/src/tests/_internal/server/services/test_logs.py index 8250cc18a5..19769a3602 100644 --- a/src/tests/_internal/server/services/test_logs.py +++ b/src/tests/_internal/server/services/test_logs.py @@ -1,5 +1,4 @@ import base64 -import itertools import logging from datetime import datetime, timedelta, timezone from pathlib import Path @@ -11,6 +10,7 @@ import pytest import pytest_asyncio from freezegun import freeze_time +from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.logs import LogEvent, LogEventSource @@ -55,6 +55,581 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path) '{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n' ) + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_basic(self, test_db, session: AsyncSession, tmp_path: Path): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write test logs + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=b"Log1"), + RunnerLogEvent(timestamp=1696586513235, message=b"Log2"), + RunnerLogEvent(timestamp=1696586513236, message=b"Log3"), + ], + job_logs=[], + ) + + # Test basic polling without pagination + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=10, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 3 + assert job_submission_logs.next_token is None # No more logs, so no next_token + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_with_next_token_pagination( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write test logs + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=b"Log1"), + RunnerLogEvent(timestamp=1696586513235, message=b"Log2"), + RunnerLogEvent(timestamp=1696586513236, message=b"Log3"), + RunnerLogEvent(timestamp=1696586513237, message=b"Log4"), + RunnerLogEvent(timestamp=1696586513238, message=b"Log5"), + ], + job_logs=[], + ) + + # First page: get 2 logs + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=2, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 2 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log1".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.logs[1].message == base64.b64encode( + "Log2".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.next_token == "2" # Next line to read + + # Second page: use next_token + poll_request.next_token = job_submission_logs.next_token + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 2 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log3".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.logs[1].message == base64.b64encode( + "Log4".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.next_token == "4" # Next line to read + + # Third page: get remaining log + poll_request.next_token = job_submission_logs.next_token + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 1 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log5".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.next_token is None # No more logs + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_with_start_from_specific_line( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write test logs + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=b"Log1"), + RunnerLogEvent(timestamp=1696586513235, message=b"Log2"), + RunnerLogEvent(timestamp=1696586513236, message=b"Log3"), + ], + job_logs=[], + ) + + # Start from line 1 (second log) + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + next_token="1", + limit=10, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 2 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log2".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.logs[1].message == base64.b64encode( + "Log3".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.next_token is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_invalid_next_token_raises_error( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Test with non-integer next_token + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + next_token="invalid", + limit=10, + diagnose=True, + ) + with pytest.raises( + LogStorageError, match="Invalid next_token: invalid. Must be a valid integer." + ): + log_storage.poll_logs(project, poll_request) + + # Test with negative next_token + poll_request.next_token = "-1" + with pytest.raises( + LogStorageError, match="Invalid next_token: -1. Must be a non-negative integer." + ): + log_storage.poll_logs(project, poll_request) + + # Test with float next_token + poll_request.next_token = "1.5" + with pytest.raises( + LogStorageError, match="Invalid next_token: 1.5. Must be a valid integer." + ): + log_storage.poll_logs(project, poll_request) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_descending_raises_error( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Test that descending=True raises LogStorageError + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=10, + diagnose=True, + # Note: This bypasses schema validation for testing the implementation + ) + poll_request.descending = True # Set directly to bypass validation + + with pytest.raises(LogStorageError, match="descending: true is not supported"): + log_storage.poll_logs(project, poll_request) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_file_not_found_raises_error( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Test with non-existent log file + poll_request = PollLogsRequest( + run_name="nonexistent_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=10, + diagnose=True, + ) + + with pytest.raises( + LogStorageError, match="Failed to read log file .* No such file or directory" + ): + log_storage.poll_logs(project, poll_request) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_with_time_filtering_and_pagination( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write test logs with different timestamps + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent( + timestamp=1696586513234, message=b"Log1" + ), # 2023-10-06T10:01:53.234 + RunnerLogEvent( + timestamp=1696586513235, message=b"Log2" + ), # 2023-10-06T10:01:53.235 + RunnerLogEvent( + timestamp=1696586513236, message=b"Log3" + ), # 2023-10-06T10:01:53.236 + RunnerLogEvent( + timestamp=1696586513237, message=b"Log4" + ), # 2023-10-06T10:01:53.237 + ], + job_logs=[], + ) + + # Filter logs after 2023-10-06T10:01:53.235 with pagination + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + start_time=datetime(2023, 10, 6, 10, 1, 53, 235000, timezone.utc), + limit=1, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + # Should get Log3 first (timestamp > 235) + assert len(job_submission_logs.logs) == 1 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log3".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.next_token == "3" + + # Get next page + poll_request.next_token = job_submission_logs.next_token + job_submission_logs = log_storage.poll_logs(project, poll_request) + + # Should get Log4 + assert len(job_submission_logs.logs) == 1 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log4".encode("utf-8") + ).decode("utf-8") + # Should not have next_token since we reached end of file + assert job_submission_logs.next_token is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_empty_file_returns_empty_list( + self, test_db, session: AsyncSession, tmp_path: Path + ): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Create empty log file + log_file_path = ( + tmp_path + / "projects" + / project.name + / "logs" + / "test_run" + / "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e" + / "runner.log" + ) + log_file_path.parent.mkdir(parents=True, exist_ok=True) + log_file_path.write_text("") + + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=10, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + assert len(job_submission_logs.logs) == 0 + assert job_submission_logs.next_token is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_next_token_pagination_complete_workflow( + self, test_db, session: AsyncSession, tmp_path: Path + ): + """Test complete pagination workflow using next_token""" + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write 10 logs + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513000 + i, message=f"Log{i + 1}".encode()) + for i in range(10) + ], + job_logs=[], + ) + + # First page: get 3 logs + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=3, + diagnose=True, + ) + page1 = log_storage.poll_logs(project, poll_request) + + assert len(page1.logs) == 3 + assert page1.logs[0].message == base64.b64encode("Log1".encode()).decode() + assert page1.logs[1].message == base64.b64encode("Log2".encode()).decode() + assert page1.logs[2].message == base64.b64encode("Log3".encode()).decode() + assert page1.next_token == "3" # Next line to read + + # Second page: use next_token + poll_request.next_token = page1.next_token + page2 = log_storage.poll_logs(project, poll_request) + + assert len(page2.logs) == 3 + assert page2.logs[0].message == base64.b64encode("Log4".encode()).decode() + assert page2.logs[1].message == base64.b64encode("Log5".encode()).decode() + assert page2.logs[2].message == base64.b64encode("Log6".encode()).decode() + assert page2.next_token == "6" + + # Third page: get more logs + poll_request.next_token = page2.next_token + page3 = log_storage.poll_logs(project, poll_request) + + assert len(page3.logs) == 3 + assert page3.logs[0].message == base64.b64encode("Log7".encode()).decode() + assert page3.logs[1].message == base64.b64encode("Log8".encode()).decode() + assert page3.logs[2].message == base64.b64encode("Log9".encode()).decode() + assert page3.next_token == "9" + + # Fourth page: get last log + poll_request.next_token = page3.next_token + page4 = log_storage.poll_logs(project, poll_request) + + assert len(page4.logs) == 1 + assert page4.logs[0].message == base64.b64encode("Log10".encode()).decode() + assert page4.next_token is None # No more logs + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_next_token_with_time_filtering( + self, test_db, session: AsyncSession, tmp_path: Path + ): + """Test next_token behavior with time filtering""" + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write logs with different timestamps + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513000, message=b"Log1"), # Before filter + RunnerLogEvent(timestamp=1696586513100, message=b"Log2"), # Before filter + RunnerLogEvent(timestamp=1696586513200, message=b"Log3"), # After filter + RunnerLogEvent(timestamp=1696586513300, message=b"Log4"), # After filter + RunnerLogEvent(timestamp=1696586513400, message=b"Log5"), # After filter + ], + job_logs=[], + ) + + # Filter logs after timestamp 150 with pagination + start_time = datetime.fromtimestamp(1696586513.150, tz=timezone.utc) + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + start_time=start_time, + limit=2, + diagnose=True, + ) + + page1 = log_storage.poll_logs(project, poll_request) + assert len(page1.logs) == 2 + assert page1.logs[0].message == base64.b64encode("Log3".encode()).decode() + assert page1.logs[1].message == base64.b64encode("Log4".encode()).decode() + assert page1.next_token == "4" + + # Get next page + poll_request.next_token = page1.next_token + page2 = log_storage.poll_logs(project, poll_request) + assert len(page2.logs) == 1 + assert page2.logs[0].message == base64.b64encode("Log5".encode()).decode() + assert page2.next_token is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_next_token_edge_cases(self, test_db, session: AsyncSession, tmp_path: Path): + """Test edge cases for next_token behavior""" + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write exactly one log + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513000, message=b"OnlyLog"), + ], + job_logs=[], + ) + + # Request with limit higher than available logs + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=10, + diagnose=True, + ) + result = log_storage.poll_logs(project, poll_request) + + assert len(result.logs) == 1 + assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() + assert result.next_token is None # No more logs available + + # Request with limit equal to available logs + poll_request.limit = 1 + result = log_storage.poll_logs(project, poll_request) + + assert len(result.logs) == 1 + assert result.logs[0].message == base64.b64encode("OnlyLog".encode()).decode() + assert result.next_token is None # No more logs available + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_next_token_beyond_file_end( + self, test_db, session: AsyncSession, tmp_path: Path + ): + """Test next_token that points beyond the end of file""" + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write 3 logs + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513000, message=b"Log1"), + RunnerLogEvent(timestamp=1696586513100, message=b"Log2"), + RunnerLogEvent(timestamp=1696586513200, message=b"Log3"), + ], + job_logs=[], + ) + + # Use next_token that points beyond the file + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + next_token="10", # Points beyond the 3 logs in file + limit=5, + diagnose=True, + ) + result = log_storage.poll_logs(project, poll_request) + + assert len(result.logs) == 0 + assert result.next_token is None + + +class TestPollLogsRequestValidation: + def test_descending_true_not_supported(self): + """Test that descending: true raises a validation error.""" + with pytest.raises(ValidationError, match="descending: true is not supported"): + PollLogsRequest( + run_name="test-run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + descending=True, + ) + + def test_descending_false_is_supported(self): + """Test that descending: false works correctly.""" + request = PollLogsRequest( + run_name="test-run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + descending=False, + ) + assert request.descending is False + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_path: Path): + project = await create_project(session=session) + log_storage = FileLogStorage(tmp_path) + + # Write more logs than the limit + log_storage.write_logs( + project=project, + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + runner_logs=[ + RunnerLogEvent(timestamp=1696586513234, message=b"Log1"), + RunnerLogEvent(timestamp=1696586513235, message=b"Log2"), + RunnerLogEvent(timestamp=1696586513236, message=b"Log3"), + RunnerLogEvent(timestamp=1696586513237, message=b"Log4"), + RunnerLogEvent(timestamp=1696586513238, message=b"Log5"), + ], + job_logs=[], + ) + logs = log_storage.poll_logs( + project, + PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + start_time=None, + end_time=None, + limit=1000, + diagnose=True, + ), + ).logs + assert len(logs) == 5 + + # Test with limit smaller than total logs + poll_request = PollLogsRequest( + run_name="test_run", + job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"), + limit=3, + diagnose=True, + ) + job_submission_logs = log_storage.poll_logs(project, poll_request) + + # Should return only the first 3 logs and provide next_token + assert len(job_submission_logs.logs) == 3 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log1".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.logs[1].message == base64.b64encode( + "Log2".encode("utf-8") + ).decode("utf-8") + assert job_submission_logs.logs[2].message == base64.b64encode( + "Log3".encode("utf-8") + ).decode("utf-8") + # Should have next_token pointing to line 3 (fourth log) + assert job_submission_logs.next_token == "3" + + # Test with limit of 1 and time filtering + poll_request.limit = 1 + poll_request.start_time = logs[3].timestamp + job_submission_logs = log_storage.poll_logs(project, poll_request) + assert len(job_submission_logs.logs) == 1 + assert job_submission_logs.logs[0].message == base64.b64encode( + "Log5".encode("utf-8") + ).decode("utf-8") + # Should not have next_token since we reached end of file + assert job_submission_logs.next_token is None + class TestCloudWatchLogStorage: FAKE_NOW = datetime(2023, 10, 6, 10, 1, 54, tzinfo=timezone.utc) @@ -171,34 +746,6 @@ def test_ensure_stream_exists_cached_forced( logGroupName="test-group", logStreamName="test-stream" ) - @pytest.mark.asyncio - async def test_poll_logs_non_empty_response( - self, - project: ProjectModel, - log_storage: CloudWatchLogStorage, - mock_client: Mock, - poll_logs_request: PollLogsRequest, - ): - mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, - ] - poll_logs_request.limit = 2 - job_submission_logs = log_storage.poll_logs(project, poll_logs_request) - - assert job_submission_logs.logs == [ - LogEvent( - timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), - log_source=LogEventSource.STDOUT, - message="SGVsbG8=", - ), - LogEvent( - timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), - log_source=LogEventSource.STDOUT, - message="V29ybGQ=", - ), - ] - @pytest.mark.asyncio @pytest.mark.parametrize("descending", [False, True]) async def test_poll_logs_empty_response( @@ -207,90 +754,32 @@ async def test_poll_logs_empty_response( log_storage: CloudWatchLogStorage, mock_client: Mock, poll_logs_request: PollLogsRequest, - descending: bool, - ): - mock_client.get_log_events.return_value["events"] = [] - poll_logs_request.descending = descending - job_submission_logs = log_storage.poll_logs(project, poll_logs_request) - - assert job_submission_logs.logs == [] - assert mock_client.get_log_events.call_count == 2 - - @pytest.mark.asyncio - async def test_poll_logs_descending_non_empty_response_on_first_call( - self, - project: ProjectModel, - log_storage: CloudWatchLogStorage, - mock_client: Mock, - poll_logs_request: PollLogsRequest, - ): - mock_client.get_log_events.return_value["events"] = [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, - ] - poll_logs_request.descending = True - poll_logs_request.limit = 2 - job_submission_logs = log_storage.poll_logs(project, poll_logs_request) - - assert job_submission_logs.logs == [ - LogEvent( - timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), - log_source=LogEventSource.STDOUT, - message="V29ybGQ=", - ), - LogEvent( - timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), - log_source=LogEventSource.STDOUT, - message="SGVsbG8=", - ), - ] - - @pytest.mark.asyncio - async def test_poll_logs_descending_some_responses_are_empty( - self, - project: ProjectModel, - log_storage: CloudWatchLogStorage, - mock_client: Mock, - poll_logs_request: PollLogsRequest, - ): - # The first two calls return empty event lists, though the token is not the same, meaning - # there are more events, see: https://github.com/dstackai/dstack/issues/1647 - # As the third call returns less events than requested (2 < 3), we continue to poll until - # accumulate enough events (2 + 2) and return exactly the requested number of events (3), - # see: https://github.com/dstackai/dstack/issues/2500 - mock_client.get_log_events.side_effect = [ - { - "events": [], - "nextBackwardToken": "bwd1", - "nextForwardToken": "fwd", - }, - { - "events": [], - "nextBackwardToken": "bwd2", - "nextForwardToken": "fwd", - }, - { - "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - {"timestamp": 1696586513235, "message": "V29ybGQ="}, - ], - "nextBackwardToken": "bwd3", - "nextForwardToken": "fwd", - }, - { - "events": [], - "nextBackwardToken": "bwd4", - "nextForwardToken": "fwd", - }, - { - "events": [ - {"timestamp": 1696586513232, "message": "aW5pdCAx"}, - {"timestamp": 1696586513233, "message": "aW5pdCAy"}, - ], - "nextBackwardToken": "bwd5", - "nextForwardToken": "fwd", - }, - ] + descending: bool, + ): + mock_client.get_log_events.return_value["events"] = [] + poll_logs_request.descending = descending + job_submission_logs = log_storage.poll_logs(project, poll_logs_request) + + assert job_submission_logs.logs == [] + assert mock_client.get_log_events.call_count == 1 + + @pytest.mark.asyncio + async def test_poll_logs_descending_some_responses_are_empty( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + # Test that the current implementation returns the events from a single API call + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ], + "nextBackwardToken": "bwd3", + "nextForwardToken": "fwd", + } poll_logs_request.descending = True poll_logs_request.limit = 3 job_submission_logs = log_storage.poll_logs(project, poll_logs_request) @@ -306,13 +795,8 @@ async def test_poll_logs_descending_some_responses_are_empty( log_source=LogEventSource.STDOUT, message="SGVsbG8=", ), - LogEvent( - timestamp=datetime(2023, 10, 6, 10, 1, 53, 233000, tzinfo=timezone.utc), - log_source=LogEventSource.STDOUT, - message="aW5pdCAy", - ), ] - assert mock_client.get_log_events.call_count == 5 + assert mock_client.get_log_events.call_count == 1 @pytest.mark.asyncio async def test_poll_logs_descending_empty_response_with_same_token( @@ -322,34 +806,17 @@ async def test_poll_logs_descending_empty_response_with_same_token( mock_client: Mock, poll_logs_request: PollLogsRequest, ): - # The first two calls return empty event lists with the same token, meaning we reached - # the end. - # https://github.com/dstackai/dstack/issues/1647 - mock_client.get_log_events.side_effect = [ - { - "events": [], - "nextBackwardToken": "bwd", - "nextForwardToken": "fwd", - }, - { - "events": [], - "nextBackwardToken": "bwd", - "nextForwardToken": "fwd", - }, - # We should not reach this response - { - "events": [ - {"timestamp": 1696586513234, "message": "SGVsbG8="}, - ], - "nextBackwardToken": "bwd2", - "nextForwardToken": "fwd", - }, - ] + # Test empty response from a single API call + mock_client.get_log_events.return_value = { + "events": [], + "nextBackwardToken": "bwd", + "nextForwardToken": "fwd", + } poll_logs_request.descending = True job_submission_logs = log_storage.poll_logs(project, poll_logs_request) assert job_submission_logs.logs == [] - assert mock_client.get_log_events.call_count == 2 + assert mock_client.get_log_events.call_count == 1 @pytest.mark.asyncio async def test_poll_logs_descending_empty_response_max_tries( @@ -359,24 +826,17 @@ async def test_poll_logs_descending_empty_response_max_tries( mock_client: Mock, poll_logs_request: PollLogsRequest, ): - # Test for a circuit breaker when the API returns empty results on each call, but the - # token is different on each call. - # https://github.com/dstackai/dstack/issues/1647 - counter = itertools.count() - - def _response_producer(*args, **kwargs): - return { - "events": [], - "nextBackwardToken": f"bwd{next(counter)}", - "nextForwardToken": "fwd", - } - - mock_client.get_log_events.side_effect = _response_producer + # Test empty response from a single API call + mock_client.get_log_events.return_value = { + "events": [], + "nextBackwardToken": "bwd1", + "nextForwardToken": "fwd", + } poll_logs_request.descending = True job_submission_logs = log_storage.poll_logs(project, poll_logs_request) assert job_submission_logs.logs == [] - assert mock_client.get_log_events.call_count == 10 + assert mock_client.get_log_events.call_count == 1 @pytest.mark.asyncio async def test_poll_logs_request_params_asc_no_diag_no_dates( @@ -390,13 +850,15 @@ async def test_poll_logs_request_params_asc_no_diag_no_dates( poll_logs_request.limit = 5 poll_logs_request.diagnose = False log_storage.poll_logs(project, poll_logs_request) - assert mock_client.get_log_events.call_count == 2 + assert mock_client.get_log_events.call_count == 1 mock_client.get_log_events.assert_called_with( logGroupName="test-group", logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/job", limit=5, startFromHead=True, - nextToken="fwd", + endTime=mock_client.get_log_events.call_args.kwargs[ + "endTime" + ], # endTime is set to "now" ) @pytest.mark.asyncio @@ -407,8 +869,7 @@ async def test_poll_logs_request_params_desc_diag_with_dates( mock_client: Mock, poll_logs_request: PollLogsRequest, ): - # Ensure the first response has events to avoid triggering a workaround for - # https://github.com/dstackai/dstack/issues/1647 + # Ensure the response has events mock_client.get_log_events.return_value["events"] = [ {"timestamp": 1696586513234, "message": "SGVsbG8="} ] @@ -420,15 +881,14 @@ async def test_poll_logs_request_params_desc_diag_with_dates( poll_logs_request.limit = 10 poll_logs_request.diagnose = True log_storage.poll_logs(project, poll_logs_request) - assert mock_client.get_log_events.call_count == 2 + assert mock_client.get_log_events.call_count == 1 mock_client.get_log_events.assert_called_with( logGroupName="test-group", logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/runner", limit=10, startFromHead=False, - startTime=1696586513235, + startTime=1696586513235, # start_time + 1ms endTime=1696672913234, - nextToken="bwd", ) @pytest.mark.asyncio @@ -792,3 +1252,329 @@ def _delta_ms(**kwargs: int) -> int: for c in mock_client.put_log_events.call_args_list ] assert actual == expected + + @pytest.mark.asyncio + async def test_poll_logs_non_empty_response( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + mock_client.get_log_events.return_value["events"] = [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ] + poll_logs_request.limit = 2 + job_submission_logs = log_storage.poll_logs(project, poll_logs_request) + + assert job_submission_logs.logs == [ + LogEvent( + timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), + log_source=LogEventSource.STDOUT, + message="SGVsbG8=", + ), + LogEvent( + timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), + log_source=LogEventSource.STDOUT, + message="V29ybGQ=", + ), + ] + + @pytest.mark.asyncio + async def test_poll_logs_descending_non_empty_response_on_first_call( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + mock_client.get_log_events.return_value["events"] = [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ] + poll_logs_request.descending = True + poll_logs_request.limit = 2 + job_submission_logs = log_storage.poll_logs(project, poll_logs_request) + + assert job_submission_logs.logs == [ + LogEvent( + timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc), + log_source=LogEventSource.STDOUT, + message="V29ybGQ=", + ), + LogEvent( + timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc), + log_source=LogEventSource.STDOUT, + message="SGVsbG8=", + ), + ] + + @pytest.mark.asyncio + async def test_next_token_ascending_pagination( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test next_token behavior for ascending pagination""" + # Setup response with nextForwardToken + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ], + "nextBackwardToken": "bwd", + "nextForwardToken": "fwd123", + } + + poll_logs_request.descending = False + poll_logs_request.limit = 2 + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 2 + assert result.next_token == "fwd123" # Should return nextForwardToken + + # Verify API was called with correct parameters + mock_client.get_log_events.assert_called_once_with( + logGroupName="test-group", + logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/job", + limit=2, + startFromHead=True, + endTime=mock_client.get_log_events.call_args.kwargs["endTime"], # endTime is auto-set + ) + + @pytest.mark.asyncio + async def test_next_token_descending_pagination( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test next_token behavior for descending pagination""" + # Setup response with nextBackwardToken + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ], + "nextBackwardToken": "bwd456", + "nextForwardToken": "fwd", + } + + poll_logs_request.descending = True + poll_logs_request.limit = 2 + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 2 + # Events should be reversed for descending order + assert result.logs[0].message == "V29ybGQ=" + assert result.logs[1].message == "SGVsbG8=" + assert result.next_token == "bwd456" # Should return nextBackwardToken + + # Verify API was called with correct parameters + mock_client.get_log_events.assert_called_once_with( + logGroupName="test-group", + logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/job", + limit=2, + startFromHead=False, + ) + + @pytest.mark.asyncio + async def test_next_token_provided_in_request( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test that provided next_token is passed to CloudWatch API""" + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + ], + "nextBackwardToken": "bwd", + "nextForwardToken": "new_fwd", + } + + poll_logs_request.next_token = "existing_token_123" + poll_logs_request.descending = False + poll_logs_request.limit = 1 + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 1 + assert result.next_token == "new_fwd" + + # Verify API was called with the provided next_token + mock_client.get_log_events.assert_called_once_with( + logGroupName="test-group", + logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/job", + limit=1, + startFromHead=True, + nextToken="existing_token_123", + endTime=mock_client.get_log_events.call_args.kwargs["endTime"], + ) + + @pytest.mark.asyncio + async def test_next_token_none_when_no_logs( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test that next_token is None when no logs are returned""" + mock_client.get_log_events.return_value = { + "events": [], + "nextBackwardToken": "bwd", + "nextForwardToken": "fwd", + } + + poll_logs_request.limit = 10 + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 0 + assert result.next_token is None # Should be None when no logs returned + + @pytest.mark.asyncio + async def test_next_token_with_time_filtering( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test next_token behavior with time filtering""" + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + ], + "nextBackwardToken": "bwd_with_time", + "nextForwardToken": "fwd_with_time", + } + + poll_logs_request.start_time = datetime(2023, 10, 6, 10, 1, 53, 234000, timezone.utc) + poll_logs_request.end_time = datetime(2023, 10, 7, 10, 1, 53, 234000, timezone.utc) + poll_logs_request.next_token = "time_token" + poll_logs_request.descending = True + poll_logs_request.diagnose = True + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 1 + assert result.next_token == "bwd_with_time" + + # Verify API was called with time filters and next_token + mock_client.get_log_events.assert_called_once_with( + logGroupName="test-group", + logStreamName="test-proj/test-run/1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e/runner", + limit=100, + startFromHead=False, + startTime=1696586513235, # start_time + 1ms + endTime=1696672913234, + nextToken="time_token", + ) + + @pytest.mark.asyncio + async def test_next_token_missing_in_cloudwatch_response( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test behavior when CloudWatch doesn't return next tokens""" + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + ], + # No nextBackwardToken or nextForwardToken in response + } + + poll_logs_request.descending = False + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 1 + assert result.next_token is None # Should be None when no token in response + + @pytest.mark.asyncio + async def test_next_token_empty_string_in_cloudwatch_response( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test behavior when CloudWatch returns empty string tokens""" + mock_client.get_log_events.return_value = { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + ], + "nextBackwardToken": "", + "nextForwardToken": "", + } + + poll_logs_request.descending = False + result = log_storage.poll_logs(project, poll_logs_request) + + assert len(result.logs) == 1 + assert result.next_token == "" # Should return empty string if that's what AWS returns + + @pytest.mark.asyncio + async def test_next_token_pagination_workflow( + self, + project: ProjectModel, + log_storage: CloudWatchLogStorage, + mock_client: Mock, + poll_logs_request: PollLogsRequest, + ): + """Test complete pagination workflow with next_token""" + # First call - returns some logs with next_token + mock_client.get_log_events.side_effect = [ + { + "events": [ + {"timestamp": 1696586513234, "message": "SGVsbG8="}, + {"timestamp": 1696586513235, "message": "V29ybGQ="}, + ], + "nextBackwardToken": "bwd", + "nextForwardToken": "token_page2", + }, + # Second call - returns final logs without next_token + { + "events": [ + {"timestamp": 1696586513236, "message": "IQ=="}, + ], + "nextBackwardToken": "final_bwd", + "nextForwardToken": "final_fwd", + }, + ] + + # First page + poll_logs_request.limit = 2 + poll_logs_request.descending = False + page1 = log_storage.poll_logs(project, poll_logs_request) + + assert len(page1.logs) == 2 + assert page1.logs[0].message == "SGVsbG8=" + assert page1.logs[1].message == "V29ybGQ=" + assert page1.next_token == "token_page2" + + # Second page using next_token + poll_logs_request.next_token = page1.next_token + page2 = log_storage.poll_logs(project, poll_logs_request) + + assert len(page2.logs) == 1 + assert page2.logs[0].message == "IQ==" + assert page2.next_token == "final_fwd" + + # Verify both API calls + assert mock_client.get_log_events.call_count == 2 + + # First call should not have nextToken + first_call = mock_client.get_log_events.call_args_list[0] + assert "nextToken" not in first_call.kwargs + + # Second call should have nextToken + second_call = mock_client.get_log_events.call_args_list[1] + assert second_call.kwargs["nextToken"] == "token_page2"