Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions kubeflow/trainer/backends/kubernetes/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def train(
name = None
trainer_overrides = {}
runtime_patches = None
active_deadline_seconds = None

if options:
for option in options:
Expand All @@ -302,6 +303,7 @@ def train(
spec_section = job_spec.get("spec", {})
trainer_overrides = spec_section.get("trainer", {})
runtime_patches = spec_section.get("runtimePatches")
active_deadline_seconds = spec_section.get("activeDeadlineSeconds")

# Generate unique name for the TrainJob if not provided
train_job_name = name or (
Expand All @@ -316,6 +318,7 @@ def train(
trainer=trainer,
trainer_overrides=trainer_overrides,
runtime_patches=runtime_patches,
active_deadline_seconds=active_deadline_seconds,
)

# Build the TrainJob.
Expand Down Expand Up @@ -756,6 +759,7 @@ def _get_trainjob_spec(
| None = None,
trainer_overrides: dict[str, Any] | None = None,
runtime_patches: list[dict[str, Any]] | None = None,
active_deadline_seconds: int | None = None,
) -> models.TrainerV1alpha1TrainJobSpec:
"""Get TrainJob spec from the given parameters."""

Expand Down Expand Up @@ -806,6 +810,7 @@ def _get_trainjob_spec(
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
trainer=trainer_cr if trainer_cr != models.TrainerV1alpha1Trainer() else None,
runtimePatches=runtime_patch_models,
activeDeadlineSeconds=active_deadline_seconds,
)

# Add initializer if users define it.
Expand Down
33 changes: 33 additions & 0 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import kubeflow.trainer.backends.kubernetes.utils as utils
from kubeflow.trainer.constants import constants
from kubeflow.trainer.options import (
ActiveDeadlineSeconds,
Annotations,
JobSetSpecPatch,
JobSetTemplatePatch,
Expand Down Expand Up @@ -319,6 +320,7 @@ def get_train_job(
labels: dict[str, str] | None = None,
annotations: dict[str, str] | None = None,
runtime_patches: list[models.TrainerV1alpha1RuntimePatch] | None = None,
active_deadline_seconds: int | None = None,
) -> models.TrainerV1alpha1TrainJob:
"""
Create a mock TrainJob object with optional trainer configurations.
Expand All @@ -335,6 +337,7 @@ def get_train_job(
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
trainer=train_job_trainer,
runtimePatches=runtime_patches,
activeDeadlineSeconds=active_deadline_seconds,
),
)

Expand Down Expand Up @@ -1305,6 +1308,36 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
],
),
),
TestCase(
name="train with active deadline seconds",
expected_status=SUCCESS,
config={
"options": [
ActiveDeadlineSeconds(seconds=3600),
],
},
expected_output=get_train_job(
runtime_name=TORCH_RUNTIME,
train_job_name=BASIC_TRAIN_JOB_NAME,
active_deadline_seconds=3600,
),
),
TestCase(
name="train with active deadline seconds and labels",
expected_status=SUCCESS,
config={
"options": [
ActiveDeadlineSeconds(seconds=600),
Labels({"team": "ml-platform"}),
],
},
expected_output=get_train_job(
runtime_name=TORCH_RUNTIME,
train_job_name=BASIC_TRAIN_JOB_NAME,
active_deadline_seconds=600,
labels={"team": "ml-platform"},
),
),
TestCase(
name="train with metadata labels and runtime patches",
expected_status=SUCCESS,
Expand Down
2 changes: 2 additions & 0 deletions kubeflow/trainer/options/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from kubeflow.trainer.options.common import Name
from kubeflow.trainer.options.kubernetes import (
ActiveDeadlineSeconds,
Annotations,
ContainerPatch,
JobSetSpecPatch,
Expand All @@ -43,6 +44,7 @@
# Common options (all backends)
"Name",
# Kubernetes options
"ActiveDeadlineSeconds",
"Annotations",
"ContainerPatch",
"JobSetSpecPatch",
Expand Down
53 changes: 53 additions & 0 deletions kubeflow/trainer/options/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,56 @@ def __call__(
spec = job_spec.setdefault("spec", {})
trainer_spec = spec.setdefault("trainer", {})
trainer_spec["args"] = self.args


@dataclass
class ActiveDeadlineSeconds:
"""Set the active deadline on the TrainJob (.spec.activeDeadlineSeconds).

Specifies the duration in seconds relative to the TrainJob start time
that the TrainJob may be active before the system tries to terminate it.
Once reached, all running Pods are terminated and the TrainJob status
becomes Failed with reason: DeadlineExceeded.

The deadline timer resets when the TrainJob is resumed from suspension.

Supported backends:
- Kubernetes

Args:
seconds: Duration in seconds. Must be a positive integer (minimum 1).
"""

seconds: int

def __post_init__(self):
"""Validate the active deadline seconds configuration."""
if type(self.seconds) is not int or self.seconds < 1:
raise ValueError("activeDeadlineSeconds must be a positive integer (minimum 1)")

def __call__(
self,
job_spec: dict[str, Any],
trainer: CustomTrainer | BuiltinTrainer | None,
backend: RuntimeBackend,
) -> None:
"""Apply active deadline seconds to the job specification.

Args:
job_spec: Job specification dictionary to modify.
trainer: Optional trainer instance for context.
backend: Backend instance for validation.

Raises:
ValueError: If backend does not support active deadline.
"""
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

if not isinstance(backend, KubernetesBackend):
raise ValueError(
f"ActiveDeadlineSeconds option is not compatible with {type(backend).__name__}. "
f"Supported backends: KubernetesBackend"
)

spec = job_spec.setdefault("spec", {})
spec["activeDeadlineSeconds"] = self.seconds
30 changes: 30 additions & 0 deletions kubeflow/trainer/options/kubernetes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
from kubeflow.trainer.options import (
ActiveDeadlineSeconds,
Annotations,
ContainerPatch,
JobSetSpecPatch,
Expand Down Expand Up @@ -69,6 +70,7 @@ class TestKubernetesOptionBackendValidation:
(Annotations, {"description": "test job"}),
(TrainerCommand, ["python", "train.py"]),
(TrainerArgs, ["--epochs", "10"]),
(ActiveDeadlineSeconds, 3600),
],
)
def test_kubernetes_options_reject_wrong_backend(
Expand All @@ -79,6 +81,8 @@ def test_kubernetes_options_reject_wrong_backend(
option = option_class(command=option_args)
elif option_class == TrainerArgs:
option = option_class(args=option_args)
elif option_class == ActiveDeadlineSeconds:
option = option_class(seconds=option_args)
else:
option = option_class(option_args)

Expand Down Expand Up @@ -129,6 +133,11 @@ class TestKubernetesOptionApplication:
["--epochs", "10"],
{"spec": {"trainer": {"args": ["--epochs", "10"]}}},
),
(
ActiveDeadlineSeconds,
3600,
{"spec": {"activeDeadlineSeconds": 3600}},
),
],
)
def test_option_application(
Expand All @@ -139,6 +148,8 @@ def test_option_application(
option = option_class(command=option_args)
elif option_class == TrainerArgs:
option = option_class(args=option_args)
elif option_class == ActiveDeadlineSeconds:
option = option_class(seconds=option_args)
else:
option = option_class(option_args)

Expand Down Expand Up @@ -415,3 +426,22 @@ def test_runtime_patch_with_jobset_metadata(self, mock_kubernetes_backend):
},
},
}


class TestActiveDeadlineSeconds:
"""Test ActiveDeadlineSeconds validation."""

@pytest.mark.parametrize(
"seconds,expected_error",
[
(0, "activeDeadlineSeconds must be a positive integer (minimum 1)"),
(-1, "activeDeadlineSeconds must be a positive integer (minimum 1)"),
(-100, "activeDeadlineSeconds must be a positive integer (minimum 1)"),
(True, "activeDeadlineSeconds must be a positive integer (minimum 1)"),
],
)
def test_active_deadline_seconds_rejects_invalid_values(self, seconds, expected_error):
"""Test ActiveDeadlineSeconds rejects non-positive values."""
with pytest.raises(ValueError) as exc_info:
ActiveDeadlineSeconds(seconds=seconds)
assert expected_error in str(exc_info.value)
Comment thread
XploY04 marked this conversation as resolved.
Loading