Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions kubeflow/trainer/backends/localprocess/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
Unit tests for the LocalProcessBackend class in the Kubeflow Trainer SDK.
"""

from datetime import datetime
from unittest.mock import Mock, patch

import pytest

from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
from kubeflow.trainer.backends.localprocess.constants import LOCAL_RUNTIME_IMAGE
from kubeflow.trainer.backends.localprocess.job import LocalJob
from kubeflow.trainer.backends.localprocess.types import (
LocalBackendJobs,
LocalBackendStep,
LocalProcessBackendConfig,
LocalRuntimeTrainer,
)
Expand Down Expand Up @@ -552,3 +556,87 @@ def test_get_job_status(local_backend, test_case):

status = local_backend._LocalProcessBackend__get_job_status(job)
assert status == test_case.expected_output


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="follow=False yields individual lines from stdout snapshot",
            expected_status=SUCCESS,
            config={"stdout": "line one\nline two\nline three\n"},
            expected_output=["line one", "line two", "line three"],
        ),
        TestCase(
            name="follow=False with empty stdout yields nothing",
            expected_status=SUCCESS,
            config={"stdout": ""},
            expected_output=[],
        ),
    ],
)
def test_get_job_logs_follow_false(local_backend, test_case):
    """Test that get_job_logs(follow=False) yields individual log lines without side effects."""
    # Build a finished job whose captured stdout comes from the test case.
    job = LocalJob(name="test-logs-follow-false-train", command=["echo", "hi"])
    job._stdout = test_case.config["stdout"]
    job._status = constants.TRAINJOB_COMPLETE

    # Register the job with the backend under a single "train" step.
    tracked = LocalBackendJobs(
        name="test-logs-follow-false", runtime=None, created=datetime.now()
    )
    tracked.steps.append(LocalBackendStep(step_name="train", job=job))
    local_backend._LocalProcessBackend__local_jobs.append(tracked)

    observed = list(local_backend.get_job_logs("test-logs-follow-false", step="train"))
    assert observed == test_case.expected_output


def test_get_job_logs_follow_false_is_eager_snapshot(local_backend):
    """Test that get_job_logs(follow=False) captures stdout at call time, not at iteration time."""
    job = LocalJob(name="test-logs-eager-snapshot-train", command=["echo", "hi"])
    job._stdout = "line one\nline two\n"
    job._status = constants.TRAINJOB_COMPLETE
    step = LocalBackendStep(step_name="train", job=job)
    backend_job = LocalBackendJobs(
        name="test-logs-eager-snapshot", runtime=None, created=datetime.now()
    )
    backend_job.steps.append(step)
    local_backend._LocalProcessBackend__local_jobs.append(backend_job)

    # Obtain the iterator before mutating stdout.
    iterator = local_backend.get_job_logs("test-logs-eager-snapshot", step="train")
    # Mutate stdout after the call — snapshot must NOT include this new line.
    job._stdout += "line three\n"
    assert list(iterator) == ["line one", "line two"]


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="follow=True streams lines to caller without printing to stdout",
            expected_status=SUCCESS,
            config={"chunks": ["stream one\nstream two\n", "stream three\n"]},
            expected_output=["stream one", "stream two", "stream three"],
        ),
        TestCase(
            name="follow=True with empty stream yields nothing",
            expected_status=SUCCESS,
            config={"chunks": []},
            expected_output=[],
        ),
    ],
)
def test_get_job_logs_follow_true(local_backend, test_case, capsys):
    """Test that get_job_logs(follow=True) streams lines to caller without printing to stdout."""
    job = LocalJob(name="test-logs-follow-true-train", command=["echo", "hi"])
    chunks = test_case.config["chunks"]
    # Replace the live streaming generator with the canned chunks for this case.
    with patch.object(job, "stream_logs", return_value=iter(chunks)):
        step = LocalBackendStep(step_name="train", job=job)
        backend_job = LocalBackendJobs(
            name="test-logs-follow-true", runtime=None, created=datetime.now()
        )
        backend_job.steps.append(step)
        local_backend._LocalProcessBackend__local_jobs.append(backend_job)
        lines = list(
            local_backend.get_job_logs("test-logs-follow-true", follow=True, step="train")
        )
    captured = capsys.readouterr()
    assert captured.out == "", "logs(follow=True) must not print to stdout"
    assert lines == test_case.expected_output
29 changes: 21 additions & 8 deletions kubeflow/trainer/backends/localprocess/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from datetime import datetime
import logging
import os
Expand Down Expand Up @@ -149,17 +150,29 @@ def cancel(self):
def returncode(self):
return self._returncode

def logs(self, follow: bool = False) -> Iterator[str]:
    """Return log lines from the job's stdout.

    Args:
        follow: If True, stream lines in real-time as the job runs.
            If False, return an eager snapshot of current stdout lines.

    Returns:
        An iterator of individual log lines without trailing newline characters.
    """
    if not follow:
        # Take an eager snapshot of the current stdout contents under the
        # lock so a concurrent writer cannot interleave with the read.
        with self._lock:
            snapshot = self._stdout.splitlines()
        return iter(snapshot)

    # For streaming behavior, delegate to a separate generator helper so the
    # snapshot above is taken eagerly at call time rather than lazily.
    return self._follow_logs()

def _follow_logs(self) -> Iterator[str]:
    """Generator that yields new output lines as they come in, line by line."""
    # NOTE(review): assumes stream_logs() yields newline-terminated chunks;
    # a chunk split mid-line would be emitted as two separate lines.
    for chunk in self.stream_logs():
        yield from chunk.splitlines()

def stream_logs(self):
"""Generator that yields new output lines as they come in."""
Expand Down