Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ docs/_build/

# Local E2E: cloned for building Spark Operator image (SPARK_OPERATOR_IMAGE_TAG=local)
spark-operator/

# Local issue tracking (development notes)
ISSUES_TO_SOLVE.md
12 changes: 12 additions & 0 deletions kubeflow/trainer/backends/container/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,18 @@ def wait_for_job_status(
) -> types.TrainJob:
import time

if polling_interval >= timeout:
raise ValueError(
f"polling_interval ({polling_interval}) must be strictly less than "
f"timeout ({timeout})"
)

if polling_interval <= 0 or timeout <= 0:
raise ValueError(
f"polling_interval ({polling_interval}) and timeout ({timeout}) "
f"must both be positive"
)

end = time.time() + timeout
Comment on lines +819 to 831
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider validating polling_interval > 0 (and timeout > 0) here as well: polling_interval=0 currently causes a tight loop (no sleep) until timeout, which can hammer the API and CPU; raising ValueError for non-positive values avoids this.

Copilot uses AI. Check for mistakes.
while time.time() < end:
tj = self.get_job(name)
Expand Down
58 changes: 35 additions & 23 deletions kubeflow/trainer/backends/container/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,49 +895,61 @@ def mock_create_with_status(*args, **kwargs):
config={"wait_status": constants.TRAINJOB_COMPLETE, "container_exit_code": 1},
expected_error=RuntimeError,
),
TestCase(
name="polling interval >= timeout error",
expected_status=FAILED,
config={"polling_interval": 5, "timeout": 5},
expected_error=ValueError,
),
],
)
def test_wait_for_job_status(container_backend, test_case):
"""Test waiting for job status."""
print("Executing test:", test_case.name)
try:
trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
runtime = container_backend.get_runtime(constants.DEFAULT_TRAINING_RUNTIME)
job_name = container_backend.train(runtime=runtime, trainer=trainer)
trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1)
runtime = container_backend.get_runtime(constants.DEFAULT_TRAINING_RUNTIME)
job_name = container_backend.train(runtime=runtime, trainer=trainer)

if test_case.name == "wait for complete":
container_id = container_backend._adapter.containers_created[0]["id"]
container_backend._adapter.set_container_status(
container_id, "exited", test_case.config["container_exit_code"]
)
if test_case.name == "wait for complete":
container_id = container_backend._adapter.containers_created[0]["id"]
container_backend._adapter.set_container_status(
container_id, "exited", test_case.config["container_exit_code"]
)

completed_job = container_backend.wait_for_job_status(
job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
)
completed_job = container_backend.wait_for_job_status(
job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
)

assert test_case.expected_status == SUCCESS
assert completed_job.status == constants.TRAINJOB_COMPLETE
assert test_case.expected_status == SUCCESS
assert completed_job.status == constants.TRAINJOB_COMPLETE

elif test_case.name == "wait timeout":
elif test_case.name == "wait timeout":
with pytest.raises(test_case.expected_error):
container_backend.wait_for_job_status(
job_name,
status={test_case.config["wait_status"]},
timeout=test_case.config["timeout"],
polling_interval=1,
)

elif test_case.name == "job fails":
container_id = container_backend._adapter.containers_created[0]["id"]
container_backend._adapter.set_container_status(
container_id, "exited", test_case.config["container_exit_code"]
)
elif test_case.name == "job fails":
container_id = container_backend._adapter.containers_created[0]["id"]
container_backend._adapter.set_container_status(
container_id, "exited", test_case.config["container_exit_code"]
)

container_backend.wait_for_job_status(
job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
)

elif test_case.name == "polling interval >= timeout error":
with pytest.raises(test_case.expected_error):
container_backend.wait_for_job_status(
job_name, status={test_case.config["wait_status"]}, timeout=5, polling_interval=1
job_name,
timeout=test_case.config["timeout"],
polling_interval=test_case.config["polling_interval"],
)

except Exception as e:
assert type(e) is test_case.expected_error
print("test execution complete")


Expand Down
11 changes: 9 additions & 2 deletions kubeflow/trainer/backends/kubernetes/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,16 @@ def wait_for_job_status(
if not status.issubset(job_statuses):
raise ValueError(f"Expected status {status} must be a subset of {job_statuses}")

if polling_interval > timeout:
if polling_interval >= timeout:
raise ValueError(
f"Polling interval {polling_interval} must be less than timeout: {timeout}"
f"polling_interval ({polling_interval}) must be strictly less than "
f"timeout ({timeout})"
)

if polling_interval <= 0 or timeout <= 0:
raise ValueError(
f"polling_interval ({polling_interval}) and timeout ({timeout}) "
f"must both be positive"
)

for _ in range(round(timeout / polling_interval)):
Comment on lines +463 to 475
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validation should also reject polling_interval <= 0 (and likely timeout <= 0) because the subsequent round(timeout / polling_interval) will raise ZeroDivisionError or behave incorrectly; add an explicit positive-value check before the division/sleep loop.

Copilot uses AI. Check for mistakes.
Expand Down
11 changes: 9 additions & 2 deletions kubeflow/trainer/backends/localprocess/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,16 @@ def wait_for_job_status(
if _job is None:
raise ValueError(f"No TrainJob with name {name}")

if polling_interval > timeout:
if polling_interval >= timeout:
raise ValueError(
f"Polling interval {polling_interval} must be less than timeout: {timeout}"
f"polling_interval ({polling_interval}) must be strictly less than "
f"timeout ({timeout})"
)

if polling_interval <= 0 or timeout <= 0:
raise ValueError(
f"polling_interval ({polling_interval}) and timeout ({timeout}) "
f"must both be positive"
)

for _ in range(round(timeout / polling_interval)):
Comment on lines +229 to 241
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new polling_interval >= timeout guard still allows polling_interval <= 0 (and/or timeout <= 0), which will lead to a ZeroDivisionError in round(timeout / polling_interval) or nonsensical waiting behavior; validate both values are positive before doing the division and sleeping.

Copilot uses AI. Check for mistakes.
Expand Down