Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/schemas/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ class PollLogsRequest(CoreModel):
start_time: Optional[datetime]
end_time: Optional[datetime]
descending: bool = False
limit: int = Field(100, ge=0, le=1000)
limit: int = Field(100, ge=1, le=1000)
diagnose: bool = False
8 changes: 5 additions & 3 deletions src/dstack/_internal/server/services/logs/filelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dstack._internal.server.schemas.runner import LogEvent as RunnerLogEvent
from dstack._internal.server.services.logs.base import (
LogStorage,
LogStorageError,
b64encode_raw_message,
unix_time_ms_to_datetime,
)
Expand All @@ -29,7 +30,8 @@ def __init__(self, root: Union[Path, str, None] = None) -> None:
self.root = Path(root)

def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmissionLogs:
# TODO Respect request.limit to support pagination
if request.descending is True:
raise LogStorageError("FileLogStorage doesn't support descending order")
log_producer = LogProducer.RUNNER if request.diagnose else LogProducer.JOB
log_file_path = self._get_log_file_path(
project_name=project.name,
Expand All @@ -46,12 +48,12 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
continue
if request.end_time is None or log_event.timestamp < request.end_time:
logs.append(log_event)
if len(logs) >= request.limit:
break
else:
break
except IOError:
pass
if request.descending:
logs = list(reversed(logs))
return JobSubmissionLogs(logs=logs)

def write_logs(
Expand Down
63 changes: 63 additions & 0 deletions src/tests/_internal/server/services/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,69 @@ async def test_writes_logs(self, test_db, session: AsyncSession, tmp_path: Path)
'{"timestamp": "2023-10-06T10:01:53.235000+00:00", "log_source": "stdout", "message": "V29ybGQ="}\n'
)

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_poll_logs_with_limit(self, test_db, session: AsyncSession, tmp_path: Path):
project = await create_project(session=session)
log_storage = FileLogStorage(tmp_path)

# Write more logs than the limit
log_storage.write_logs(
project=project,
run_name="test_run",
job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
runner_logs=[
RunnerLogEvent(timestamp=1696586513234, message=b"Log1"),
RunnerLogEvent(timestamp=1696586513235, message=b"Log2"),
RunnerLogEvent(timestamp=1696586513236, message=b"Log3"),
RunnerLogEvent(timestamp=1696586513237, message=b"Log4"),
RunnerLogEvent(timestamp=1696586513238, message=b"Log5"),
],
job_logs=[],
)
logs = log_storage.poll_logs(
project,
PollLogsRequest(
run_name="test_run",
job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
start_time=None,
end_time=None,
limit=1000,
diagnose=True,
),
).logs
assert len(logs) == 5

# Test with limit smaller than total logs (ascending)
poll_request = PollLogsRequest(
run_name="test_run",
job_submission_id=UUID("1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e"),
limit=3,
diagnose=True,
)
job_submission_logs = log_storage.poll_logs(project, poll_request)

# Should return only the first 3 logs in ascending order
assert len(job_submission_logs.logs) == 3
assert job_submission_logs.logs[0].message == base64.b64encode(
"Log1".encode("utf-8")
).decode("utf-8")
assert job_submission_logs.logs[1].message == base64.b64encode(
"Log2".encode("utf-8")
).decode("utf-8")
assert job_submission_logs.logs[2].message == base64.b64encode(
"Log3".encode("utf-8")
).decode("utf-8")

# Test with limit of 1
poll_request.limit = 1
poll_request.start_time = logs[3].timestamp
job_submission_logs = log_storage.poll_logs(project, poll_request)
assert len(job_submission_logs.logs) == 1
assert job_submission_logs.logs[0].message == base64.b64encode(
"Log5".encode("utf-8")
).decode("utf-8")


class TestCloudWatchLogStorage:
FAKE_NOW = datetime(2023, 10, 6, 10, 1, 54, tzinfo=timezone.utc)
Expand Down
Loading