From 611adef18c398d396a43f050549cbd65c6dcca28 Mon Sep 17 00:00:00 2001 From: Sabari Narayana Date: Fri, 27 Feb 2026 15:31:14 +0530 Subject: [PATCH 01/11] docs: add KEP-2839 Dynamic LLM Trainer Framework proposal --- .../2839-dynamic-llm-trainer/README.md | 620 ++++++++++++++++++ 1 file changed, 620 insertions(+) create mode 100644 docs/proposals/2839-dynamic-llm-trainer/README.md diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md new file mode 100644 index 0000000000..39fd1caf98 --- /dev/null +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -0,0 +1,620 @@ +# KEP-2839: Dynamic LLM Trainer Framework + +**Authors**: NarayanaSabari + +**Status**: Provisional + +**Creation date**: 2026-02-27 + +**Tracking issue**: [kubeflow/trainer#2839](https://github.com/kubeflow/trainer/issues/2839) + +**Upstream KEP**: [KEP-2401: Kubeflow LLM Trainer V2](../2401-llm-trainer-v2/README.md) + +## Table of Contents + + +- [KEP-2839: Dynamic LLM Trainer Framework](#kep-2839-dynamic-llm-trainer-framework) + - [Table of Contents](#table-of-contents) + - [Summary](#summary) + - [Goals](#goals) + - [Non-Goals](#non-goals) + - [Design Details](#design-details) + - [SDK: LLMBackend Interface](#sdk-llmbackend-interface) + - [SDK: Backend Registry](#sdk-backend-registry) + - [SDK: BuiltinTrainer Change](#sdk-builtintrainer-change) + - [SDK: TorchTune Backend (Refactored)](#sdk-torchtune-backend-refactored) + - [SDK: TRL Backend](#sdk-trl-backend) + - [Go Control Plane: LLMBackendStrategy](#go-control-plane-llmbackendstrategy) + - [Go Control Plane: Strategy Dispatch](#go-control-plane-strategy-dispatch) + - [Go Control Plane: Constants](#go-control-plane-constants) + - [Container Images](#container-images) + - [ClusterTrainingRuntimes](#clustertrainingruntimes) + - [Helm Chart Changes](#helm-chart-changes) + - [Test Plan](#test-plan) + - [Risks and Mitigations](#risks-and-mitigations) + + +## Summary + +Decouple the 
`BuiltinTrainer` from TorchTune by introducing a pluggable `LLMBackend` +interface in the SDK and a corresponding `LLMBackendStrategy` in the Go control plane. +TorchTune becomes the first backend implementation (preserving backward compatibility), +and TRL is added as the first new backend with SFT/DPO support. + +This builds on [KEP-2401](../2401-llm-trainer-v2/README.md) and the community consensus +on "Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752). +TorchTune stopped adding features in July 2025 +([pytorch/torchtune#2883](https://github.com/pytorch/torchtune/issues/2883)). + +## Goals + +1. Define an `LLMBackend` abstract interface in the Python SDK. +2. Implement a backend registry with `@register_backend` decorator. +3. Refactor `TorchTuneConfig` to implement `LLMBackend` with zero breaking changes. +4. Implement `TRLConfig` backend supporting SFT and DPO. +5. Create TRL container image and `ClusterTrainingRuntime` manifests. +6. Generalize the Go Torch plugin to dispatch via `LLMBackendStrategy` instead of + hardcoded TorchTune command-sniffing. +7. Support external (out-of-tree) backend registration. + +## Non-Goals + +1. Unsloth or LlamaFactory backends (future work). +2. CRD schema changes -- operates within existing `.spec.trainer.command`/`.spec.trainer.args`. +3. New Kubernetes resource topologies (e.g., launcher/worker patterns). +4. Go-side distributed training plugins per backend (backends use existing torchrun infra). + +## Design Details + +### SDK: LLMBackend Interface + +New file: `kubeflow/trainer/types/backends/__init__.py` + +```python +import abc +from dataclasses import dataclass + + +@dataclass +class LLMBackend(abc.ABC): + + @abc.abstractmethod + def to_command(self) -> tuple[str, ...]: + """Container entrypoint command.""" + ... + + @abc.abstractmethod + def to_args(self) -> list[str]: + """CLI arguments for .spec.trainer.args.""" + ... 
+ + @abc.abstractmethod + def framework(self) -> str: + """Framework identifier matching trainer.kubeflow.org/framework label.""" + ... + + def validate(self) -> None: + """Optional config validation. Raise ValueError on invalid config.""" + pass + + @property + def num_nodes(self) -> int | None: + return None + + @property + def resources_per_node(self) -> dict | None: + return None +``` + +### SDK: Backend Registry + +New file: `kubeflow/trainer/types/backends/registry.py` + +```python +from collections.abc import Callable + +_BACKEND_REGISTRY: dict[str, type] = {} + + +def register_backend(name: str) -> Callable: + def decorator(cls): + if name in _BACKEND_REGISTRY: + raise ValueError( + f"Backend '{name}' already registered by {_BACKEND_REGISTRY[name].__name__}." + ) + _BACKEND_REGISTRY[name] = cls + return cls + return decorator + + +def get_registered_backends() -> dict[str, type]: + return dict(_BACKEND_REGISTRY) + + +def get_backend(name: str) -> type | None: + return _BACKEND_REGISTRY.get(name) +``` + +### SDK: BuiltinTrainer Change + +In `kubeflow/trainer/types/types.py`: + +```python +# BEFORE +@dataclass +class BuiltinTrainer: + config: TorchTuneConfig + +# AFTER +@dataclass +class BuiltinTrainer: + config: LLMBackend +``` + +`TorchTuneConfig` implements `LLMBackend`, so existing +`BuiltinTrainer(config=TorchTuneConfig(...))` code is unchanged. + +The SDK's `KubernetesBackend` dispatch becomes generic: + +```python +def _get_trainer_spec(self, trainer: BuiltinTrainer) -> dict: + backend = trainer.config + backend.validate() + return { + "command": list(backend.to_command()), + "args": backend.to_args(), + } +``` + +This replaces the current `get_args_using_torchtune_config()`. 
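As an illustration of the generic dispatch described above, the following is a minimal, self-contained sketch rather than the SDK's actual code. `EchoBackend` is a hypothetical backend invented for this example, the `LLMBackend` class is trimmed to the three members the dispatch touches, and `get_trainer_spec` is a free-function stand-in for the proposed `_get_trainer_spec` method.

```python
import abc
from dataclasses import dataclass


class LLMBackend(abc.ABC):
    """Trimmed-down stand-in for the proposed interface (assumption)."""

    @abc.abstractmethod
    def to_command(self) -> tuple[str, ...]:
        """Container entrypoint command."""

    @abc.abstractmethod
    def to_args(self) -> list[str]:
        """CLI arguments for the trainer container."""

    def validate(self) -> None:
        """Optional hook; raise ValueError on invalid config."""


@dataclass
class EchoBackend(LLMBackend):
    """Hypothetical backend, used only to exercise the dispatch."""

    message: str = "hello"

    def to_command(self) -> tuple[str, ...]:
        return ("echo",)

    def to_args(self) -> list[str]:
        return [self.message]

    def validate(self) -> None:
        if not self.message:
            raise ValueError("message must not be empty")


@dataclass
class BuiltinTrainer:
    config: LLMBackend


def get_trainer_spec(trainer: BuiltinTrainer) -> dict:
    # The dispatch only calls interface methods, so it never needs to
    # know which concrete backend it was handed.
    backend = trainer.config
    backend.validate()
    return {
        "command": list(backend.to_command()),
        "args": backend.to_args(),
    }


spec = get_trainer_spec(BuiltinTrainer(config=EchoBackend(message="hi")))
```

Because the dispatch depends only on the interface, adding a backend under this design should not require touching `get_trainer_spec` at all.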
+ +### SDK: TorchTune Backend (Refactored) + +New file: `kubeflow/trainer/types/backends/torchtune.py` + +```python +@register_backend("torchtune") +@dataclass +class TorchTuneConfig(LLMBackend): + dtype: DataType | None = None + batch_size: int | None = None + epochs: int | None = None + loss: Loss | None = None + _num_nodes: int | None = field(default=None, repr=True) + peft_config: LoraConfig | None = None + dataset_preprocess_config: TorchTuneInstructDataset | None = None + _resources_per_node: dict | None = field(default=None, repr=True) + + def to_command(self) -> tuple[str, ...]: + return ("tune", "run") + + def to_args(self) -> list[str]: + args = [] + if self.dtype is not None: + args.append(f"dtype={self.dtype.value}") + if self.batch_size is not None: + args.append(f"batch_size={self.batch_size}") + if self.epochs is not None: + args.append(f"epochs={self.epochs}") + if self.loss is not None: + args.append(f"loss={self.loss.value}") + if self.peft_config is not None: + args.extend(_get_args_from_peft_config(self.peft_config)) + if self.dataset_preprocess_config is not None: + args.extend( + _get_args_from_dataset_preprocess_config(self.dataset_preprocess_config) + ) + return args + + def framework(self) -> str: + return "torchtune" + + @property + def num_nodes(self) -> int | None: + return self._num_nodes + + @property + def resources_per_node(self) -> dict | None: + return self._resources_per_node +``` + +Helper functions `_get_args_from_peft_config()` and +`_get_args_from_dataset_preprocess_config()` are extracted from the current +`get_args_using_torchtune_config()` in `backends/kubernetes/utils.py` unchanged. 
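The override-building pattern in `to_args()` above can be sketched in isolation. `OverrideConfig` is a hypothetical stand-in with plain field types (the real `TorchTuneConfig` uses enums such as `DataType` and `Loss`); the point is only that unset fields produce no argument, so defaults from the runtime are left untouched.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class OverrideConfig:
    """Hypothetical, simplified stand-in for TorchTuneConfig."""

    dtype: str | None = None
    batch_size: int | None = None
    epochs: int | None = None

    def to_args(self) -> list[str]:
        # One key=value override per field that was explicitly set;
        # fields left as None are skipped entirely.
        overrides = {
            "dtype": self.dtype,
            "batch_size": self.batch_size,
            "epochs": self.epochs,
        }
        return [f"{key}={value}" for key, value in overrides.items() if value is not None]


only_batch = OverrideConfig(batch_size=4).to_args()
mixed = OverrideConfig(dtype="bf16", epochs=2).to_args()
```

Keeping the field order fixed makes the generated argument list deterministic, which is what a backward-compatibility comparison against the existing helper would rely on.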
+ +### SDK: TRL Backend + +New file: `kubeflow/trainer/types/backends/trl.py` + +```python +class TRLTrainerType(Enum): + SFT = "sft" + DPO = "dpo" + PPO = "ppo" + ORPO = "orpo" + KTO = "kto" + + +@dataclass +class TRLPeftConfig: + r: int = 16 + lora_alpha: int = 32 + lora_dropout: float = 0.05 + target_modules: list[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) + use_rslora: bool = False + use_dora: bool = False + + +@dataclass +class TRLSFTConfig: + max_seq_length: int = 2048 + packing: bool = False + dataset_text_field: str | None = None + + +@dataclass +class TRLDPOConfig: + beta: float = 0.1 + max_length: int = 1024 + max_prompt_length: int = 512 + loss_type: str = "sigmoid" + + +@register_backend("trl") +@dataclass +class TRLConfig(LLMBackend): + trainer_type: TRLTrainerType = TRLTrainerType.SFT + model_name_or_path: str = "/workspace/model" + learning_rate: float = 2e-5 + num_train_epochs: int = 3 + per_device_train_batch_size: int = 4 + gradient_accumulation_steps: int = 1 + bf16: bool = True + fp16: bool = False + peft_config: TRLPeftConfig | None = None + sft_config: TRLSFTConfig | None = None + dpo_config: TRLDPOConfig | None = None + _num_nodes: int | None = None + _resources_per_node: dict | None = None + output_dir: str = "/workspace/output" + + def to_command(self) -> tuple[str, ...]: + return ("python", "-m", "trl") + + def to_args(self) -> list[str]: + args = [self.trainer_type.value] + args.extend(["--model_name_or_path", self.model_name_or_path]) + args.extend(["--output_dir", self.output_dir]) + args.extend(["--learning_rate", str(self.learning_rate)]) + args.extend(["--num_train_epochs", str(self.num_train_epochs)]) + args.extend(["--per_device_train_batch_size", str(self.per_device_train_batch_size)]) + args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)]) + if self.bf16: + args.append("--bf16") + if self.fp16: + args.append("--fp16") + if self.peft_config is not None: + args.append("--use_peft") + 
args.extend(["--lora_r", str(self.peft_config.r)]) + args.extend(["--lora_alpha", str(self.peft_config.lora_alpha)]) + args.extend(["--lora_dropout", str(self.peft_config.lora_dropout)]) + if self.peft_config.target_modules: + args.extend(["--lora_target_modules", *self.peft_config.target_modules]) + if self.trainer_type == TRLTrainerType.SFT and self.sft_config: + args.extend(["--max_seq_length", str(self.sft_config.max_seq_length)]) + if self.sft_config.packing: + args.append("--packing") + if self.sft_config.dataset_text_field: + args.extend(["--dataset_text_field", self.sft_config.dataset_text_field]) + if self.trainer_type == TRLTrainerType.DPO and self.dpo_config: + args.extend(["--beta", str(self.dpo_config.beta)]) + args.extend(["--max_length", str(self.dpo_config.max_length)]) + args.extend(["--max_prompt_length", str(self.dpo_config.max_prompt_length)]) + args.extend(["--loss_type", self.dpo_config.loss_type]) + return args + + def framework(self) -> str: + return "trl" + + def validate(self) -> None: + if self.bf16 and self.fp16: + raise ValueError("Cannot enable both bf16 and fp16.") + if self.trainer_type == TRLTrainerType.DPO and self.sft_config is not None: + raise ValueError("sft_config should not be set when trainer_type is DPO.") + if self.trainer_type == TRLTrainerType.SFT and self.dpo_config is not None: + raise ValueError("dpo_config should not be set when trainer_type is SFT.") + + @property + def num_nodes(self) -> int | None: + return self._num_nodes + + @property + def resources_per_node(self) -> dict | None: + return self._resources_per_node +``` + +### Go Control Plane: LLMBackendStrategy + +New file: `pkg/runtime/framework/plugins/torch/backend.go` + +Replace the current command-sniffing in `torch.go` and `torchtune.go` with a strategy +interface: + +```go +type LLMBackendStrategy interface { + Name() string + EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error + Validate(info *runtime.Info, oldObj, newObj 
*trainer.TrainJob) (admission.Warnings, field.ErrorList) +} +``` + +`TorchTuneStrategy` wraps the existing `torchtune.go` logic (getRecipeAndConfig, +extractOverridesFromRuntime, validateTorchTune) with no behavioral changes: + +```go +type TorchTuneStrategy struct{} + +func (s *TorchTuneStrategy) Name() string { return "torchtune" } + +func (s *TorchTuneStrategy) EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error { + // Existing torchtune.go logic: rdzv_endpoint, recipe, config, overrides + return nil +} + +func (s *TorchTuneStrategy) Validate(info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + return validateTorchTune(info, newObj) +} +``` + +`TRLStrategy` is minimal -- TRL config is fully constructed by the SDK: + +```go +type TRLStrategy struct{} + +func (s *TRLStrategy) Name() string { return "trl" } + +func (s *TRLStrategy) EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error { + // Inject rendezvous endpoint for multi-node. SDK provides full args. + return nil +} + +func (s *TRLStrategy) Validate(info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + return nil, nil +} +``` + +### Go Control Plane: Strategy Dispatch + +The Torch plugin struct holds registered strategies and dispatches via the runtime's +framework label instead of command-sniffing: + +```go +type Torch struct { + backends map[string]LLMBackendStrategy +} + +func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) (framework.Plugin, error) { + return &Torch{ + backends: map[string]LLMBackendStrategy{ + "torchtune": &TorchTuneStrategy{}, + "trl": &TRLStrategy{}, + }, + }, nil +} + +func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { + // Common torch distributed setup (PET_NUM_NODES, PET_NPROC_PER_NODE, etc.) + // ... 
+ + framework := info.Labels[constants.RuntimeFrameworkLabel] + if strategy, ok := t.backends[framework]; ok { + return strategy.EnforceCommand(info, trainJob) + } + + // Default: standard torchrun path + // ... +} + +func (t *Torch) Validate(ctx context.Context, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + // Common torch validation ... + + framework := info.Labels[constants.RuntimeFrameworkLabel] + if strategy, ok := t.backends[framework]; ok { + warnings, errs := strategy.Validate(info, oldObj, newObj) + // append ... + } + // ... +} +``` + +### Go Control Plane: Constants + +```go +// pkg/constants/constants.go + +TRLEntrypoint = []string{"python", "-m", "trl"} +TRLFrameworkLabel = "trl" +RuntimeFrameworkLabel = "trainer.kubeflow.org/framework" +``` + +### Container Images + +```dockerfile +# cmd/trainers/trl/Dockerfile +FROM pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime +WORKDIR /workspace +RUN apt update && apt-get install -y --no-install-recommends build-essential \ + && rm -rf /var/lib/apt/lists/* +COPY cmd/trainers/trl/requirements.txt . +RUN pip install -r requirements.txt +``` + +``` +# cmd/trainers/trl/requirements.txt +trl>=0.15.0 +transformers>=4.48.0 +datasets>=3.0.0 +accelerate>=1.2.0 +peft>=0.14.0 +bitsandbytes>=0.41.1 +``` + +Published as `ghcr.io/kubeflow/trainer/trl-trainer`. 
+ +### ClusterTrainingRuntimes + +Example: `manifests/base/runtimes/trl/llama3_2/llama3_2_1B.yaml` + +```yaml +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: trl-llama3.2-1b + labels: + trainer.kubeflow.org/framework: trl +spec: + mlPolicy: + numNodes: 1 + torch: + numProcPerNode: auto + template: + spec: + replicatedJobs: + - name: dataset-initializer + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer + spec: + template: + spec: + containers: + - name: dataset-initializer + image: ghcr.io/kubeflow/trainer/dataset-initializer + env: + - name: STORAGE_URI + value: hf://tatsu-lab/alpaca + volumeMounts: + - mountPath: /workspace/dataset + name: workspace + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace + - name: model-initializer + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: model-initializer + spec: + template: + spec: + containers: + - name: model-initializer + image: ghcr.io/kubeflow/trainer/model-initializer + env: + - name: STORAGE_URI + value: hf://meta-llama/Llama-3.2-1B-Instruct + volumeMounts: + - mountPath: /workspace/model + name: workspace + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace + - name: node + dependsOn: + - dataset-initializer + - model-initializer + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: trainer + spec: + template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/trl-trainer + command: + - python + - -m + - trl + args: + - sft + - --model_name_or_path + - /workspace/model + - --output_dir + - /workspace/output + - --dataset_name + - /workspace/dataset + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - mountPath: /workspace + name: workspace + volumes: + - name: workspace + persistentVolumeClaim: + claimName: workspace +``` + +Directory structure: + +``` +manifests/base/runtimes/ +├── 
torchtune/ # Existing (unchanged) +└── trl/ # NEW + ├── kustomization.yaml + ├── llama3_2/ + │ ├── llama3_2_1B.yaml + │ └── llama3_2_3B.yaml + └── qwen2_5/ + └── qwen2_5_1.5B.yaml +``` + +### Helm Chart Changes + +```yaml +# charts/kubeflow-trainer/values.yaml (additions) +runtimes: + trlDistributed: + image: + registry: ghcr.io + repository: kubeflow/trainer/trl-trainer + tag: "" + llama3_2_1B: + enabled: false + llama3_2_3B: + enabled: false + qwen2_5_1_5B: + enabled: false +``` + +## Test Plan + +**Unit tests**: +- SDK backend registry: registration, duplicate detection, lookup. +- `TorchTuneConfig` backward compat: `to_args()` identical to current `get_args_using_torchtune_config()`. +- `TRLConfig`: `to_args()`, `to_command()`, `validate()` for SFT, DPO, error cases. +- Go Torch plugin: strategy dispatch, `TorchTuneStrategy` (existing cases), `TRLStrategy`. + +**Integration tests**: +- SDK creates valid TRL TrainJob CR that controller reconciles into JobSet. +- `client.list_runtimes()` returns both TorchTune and TRL runtimes. +- Existing TorchTune examples execute unchanged. + +**E2E tests**: +- TRL SFT with Llama 3.2 1B on Alpaca (GPU). +- TRL DPO with preference dataset (GPU). +- TorchTune regression. 
+ +## Risks and Mitigations + +| Risk | Mitigation | +|------|------------| +| TRL CLI changes across versions | Pin version in requirements.txt; version compat tests | +| TRL uses accelerate, not torchrun | TRL supports torchrun-compatible launch; validate in E2E | +| SDK type widening affects static analysis | TorchTuneConfig is a subtype of LLMBackend; passes type checks | +| Scope creep from adding backends | Scoped to TorchTune + TRL only | \ No newline at end of file From 7622915395f43a416bbcfb87576725a2e4407441 Mon Sep 17 00:00:00 2001 From: Sabari Narayana Date: Mon, 2 Mar 2026 13:08:08 +0530 Subject: [PATCH 02/11] docs: rewrite KEP-2839 as HLD-only proposal, remove LLD sections - Strip all Low-Level Design content (code interfaces, strategies, Dockerfile, runtime YAML, Helm chart details) - Fix 10 technical inaccuracies found during audit: - TRL CLI entry point (trl sft, not python -m trl) - Multi-node env vars (standard + PET variants) - Correct enforceTorchTunePolicy inline location - dependsOn YAML format, volume handling pattern - TRLTrainerType enum values (SFT/DPO/KTO/GRPO) - Container name 'node' not 'trainer' - PET env var naming conventions - KEP now covers: Summary, Goals, Non-Goals, Current State Analysis, High-Level Design, Test Plan, Risks, Phases --- .../2839-dynamic-llm-trainer/README.md | 703 ++++-------------- 1 file changed, 155 insertions(+), 548 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index 39fd1caf98..cc14c3df9a 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -18,22 +18,19 @@ - [Summary](#summary) - [Goals](#goals) - [Non-Goals](#non-goals) - - [Design Details](#design-details) - - [SDK: LLMBackend Interface](#sdk-llmbackend-interface) - - [SDK: Backend Registry](#sdk-backend-registry) - - [SDK: BuiltinTrainer Change](#sdk-builtintrainer-change) - - [SDK: TorchTune Backend 
(Refactored)](#sdk-torchtune-backend-refactored) - - [SDK: TRL Backend](#sdk-trl-backend) - - [Go Control Plane: LLMBackendStrategy](#go-control-plane-llmbackendstrategy) - - [Go Control Plane: Strategy Dispatch](#go-control-plane-strategy-dispatch) - - [Go Control Plane: Constants](#go-control-plane-constants) - - [Container Images](#container-images) - - [ClusterTrainingRuntimes](#clustertrainingruntimes) - - [Helm Chart Changes](#helm-chart-changes) - - [Test Plan](#test-plan) + - [Current State Analysis](#current-state-analysis) + - [How TorchTune Is Wired Today](#how-torchtune-is-wired-today) + - [SDK Coupling](#sdk-coupling) + - [Why This Must Change](#why-this-must-change) + - [High-Level Design](#high-level-design) + - [Architecture Overview](#architecture-overview) + - [Component Interaction Flow](#component-interaction-flow) + - [What Changes vs What Stays](#what-changes-vs-what-stays) - [Risks and Mitigations](#risks-and-mitigations) +--- + ## Summary Decouple the `BuiltinTrainer` from TorchTune by introducing a pluggable `LLMBackend` @@ -60,561 +57,171 @@ TorchTune stopped adding features in July 2025 ## Non-Goals 1. Unsloth or LlamaFactory backends (future work). -2. CRD schema changes -- operates within existing `.spec.trainer.command`/`.spec.trainer.args`. +2. CRD schema changes — operates within existing `.spec.trainer.command`/`.spec.trainer.args`. 3. New Kubernetes resource topologies (e.g., launcher/worker patterns). 4. Go-side distributed training plugins per backend (backends use existing torchrun infra). -## Design Details - -### SDK: LLMBackend Interface +--- -New file: `kubeflow/trainer/types/backends/__init__.py` +## Current State Analysis -```python -import abc -from dataclasses import dataclass +### How TorchTune Is Wired Today +The Torch plugin (`pkg/runtime/framework/plugins/torch/torch.go`) is the only ML policy +plugin that handles LLM fine-tuning. 
It hardcodes TorchTune support via **command-sniffing**: -@dataclass -class LLMBackend(abc.ABC): - - @abc.abstractmethod - def to_command(self) -> tuple[str, ...]: - """Container entrypoint command.""" - ... - - @abc.abstractmethod - def to_args(self) -> list[str]: - """CLI arguments for .spec.trainer.args.""" - ... - - @abc.abstractmethod - def framework(self) -> str: - """Framework identifier matching trainer.kubeflow.org/framework label.""" - ... - - def validate(self) -> None: - """Optional config validation. Raise ValueError on invalid config.""" - pass - - @property - def num_nodes(self) -> int | None: - return None - - @property - def resources_per_node(self) -> dict | None: - return None +```go +// torch.go:149 — the branching point +if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + // Standard torchrun path: inject PET_MASTER_ADDR, PET_MASTER_PORT +} else { + // TorchTune path: mutate command with recipe, config, rdzv_endpoint +} ``` -### SDK: Backend Registry - -New file: `kubeflow/trainer/types/backends/registry.py` - -```python -from collections.abc import Callable - -_BACKEND_REGISTRY: dict[str, type] = {} - - -def register_backend(name: str) -> Callable: - def decorator(cls): - if name in _BACKEND_REGISTRY: - raise ValueError( - f"Backend '{name}' already registered by {_BACKEND_REGISTRY[name].__name__}." - ) - _BACKEND_REGISTRY[name] = cls - return cls - return decorator - - -def get_registered_backends() -> dict[str, type]: - return dict(_BACKEND_REGISTRY) +`constants.TorchTuneEntrypoint` is `[]string{"tune", "run"}`. When the trainer command +matches this, the plugin enters the TorchTune branch (torch.go:159-183) which: +1. Builds the rendezvous endpoint: `--rdzv_endpoint={name}-node-0-0.{name}:29500` +2. Calls `getRecipeAndConfig()` (torchtune.go:78) to select a recipe/config pair + from a matrix of `numNodes × numGPUs × LoRA/QLoRA` combinations. +3. 
Calls `extractOverridesFromRuntime()` (torchtune.go:131) to pull immutable config + overrides from the ClusterTrainingRuntime's node container command. +4. Appends all of this to `trainJob.Spec.Trainer.Command`. -def get_backend(name: str) -> type | None: - return _BACKEND_REGISTRY.get(name) -``` +The validation path (torch.go:88) also sniffs the same entrypoint to decide whether +to run `validateTorchTune()`. -### SDK: BuiltinTrainer Change +### SDK Coupling -In `kubeflow/trainer/types/types.py`: +In the Python SDK (`kubeflow/sdk` repo), `BuiltinTrainer` has a single field: ```python -# BEFORE -@dataclass -class BuiltinTrainer: - config: TorchTuneConfig - -# AFTER @dataclass class BuiltinTrainer: - config: LLMBackend -``` - -`TorchTuneConfig` implements `LLMBackend`, so existing -`BuiltinTrainer(config=TorchTuneConfig(...))` code is unchanged. - -The SDK's `KubernetesBackend` dispatch becomes generic: - -```python -def _get_trainer_spec(self, trainer: BuiltinTrainer) -> dict: - backend = trainer.config - backend.validate() - return { - "command": list(backend.to_command()), - "args": backend.to_args(), - } -``` - -This replaces the current `get_args_using_torchtune_config()`. 
- -### SDK: TorchTune Backend (Refactored) - -New file: `kubeflow/trainer/types/backends/torchtune.py` - -```python -@register_backend("torchtune") -@dataclass -class TorchTuneConfig(LLMBackend): - dtype: DataType | None = None - batch_size: int | None = None - epochs: int | None = None - loss: Loss | None = None - _num_nodes: int | None = field(default=None, repr=True) - peft_config: LoraConfig | None = None - dataset_preprocess_config: TorchTuneInstructDataset | None = None - _resources_per_node: dict | None = field(default=None, repr=True) - - def to_command(self) -> tuple[str, ...]: - return ("tune", "run") - - def to_args(self) -> list[str]: - args = [] - if self.dtype is not None: - args.append(f"dtype={self.dtype.value}") - if self.batch_size is not None: - args.append(f"batch_size={self.batch_size}") - if self.epochs is not None: - args.append(f"epochs={self.epochs}") - if self.loss is not None: - args.append(f"loss={self.loss.value}") - if self.peft_config is not None: - args.extend(_get_args_from_peft_config(self.peft_config)) - if self.dataset_preprocess_config is not None: - args.extend( - _get_args_from_dataset_preprocess_config(self.dataset_preprocess_config) - ) - return args - - def framework(self) -> str: - return "torchtune" - - @property - def num_nodes(self) -> int | None: - return self._num_nodes - - @property - def resources_per_node(self) -> dict | None: - return self._resources_per_node -``` - -Helper functions `_get_args_from_peft_config()` and -`_get_args_from_dataset_preprocess_config()` are extracted from the current -`get_args_using_torchtune_config()` in `backends/kubernetes/utils.py` unchanged. 
- -### SDK: TRL Backend - -New file: `kubeflow/trainer/types/backends/trl.py` - -```python -class TRLTrainerType(Enum): - SFT = "sft" - DPO = "dpo" - PPO = "ppo" - ORPO = "orpo" - KTO = "kto" - - -@dataclass -class TRLPeftConfig: - r: int = 16 - lora_alpha: int = 32 - lora_dropout: float = 0.05 - target_modules: list[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - use_rslora: bool = False - use_dora: bool = False - - -@dataclass -class TRLSFTConfig: - max_seq_length: int = 2048 - packing: bool = False - dataset_text_field: str | None = None - - -@dataclass -class TRLDPOConfig: - beta: float = 0.1 - max_length: int = 1024 - max_prompt_length: int = 512 - loss_type: str = "sigmoid" - - -@register_backend("trl") -@dataclass -class TRLConfig(LLMBackend): - trainer_type: TRLTrainerType = TRLTrainerType.SFT - model_name_or_path: str = "/workspace/model" - learning_rate: float = 2e-5 - num_train_epochs: int = 3 - per_device_train_batch_size: int = 4 - gradient_accumulation_steps: int = 1 - bf16: bool = True - fp16: bool = False - peft_config: TRLPeftConfig | None = None - sft_config: TRLSFTConfig | None = None - dpo_config: TRLDPOConfig | None = None - _num_nodes: int | None = None - _resources_per_node: dict | None = None - output_dir: str = "/workspace/output" - - def to_command(self) -> tuple[str, ...]: - return ("python", "-m", "trl") - - def to_args(self) -> list[str]: - args = [self.trainer_type.value] - args.extend(["--model_name_or_path", self.model_name_or_path]) - args.extend(["--output_dir", self.output_dir]) - args.extend(["--learning_rate", str(self.learning_rate)]) - args.extend(["--num_train_epochs", str(self.num_train_epochs)]) - args.extend(["--per_device_train_batch_size", str(self.per_device_train_batch_size)]) - args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)]) - if self.bf16: - args.append("--bf16") - if self.fp16: - args.append("--fp16") - if self.peft_config is not None: - args.append("--use_peft") - 
args.extend(["--lora_r", str(self.peft_config.r)]) - args.extend(["--lora_alpha", str(self.peft_config.lora_alpha)]) - args.extend(["--lora_dropout", str(self.peft_config.lora_dropout)]) - if self.peft_config.target_modules: - args.extend(["--lora_target_modules", *self.peft_config.target_modules]) - if self.trainer_type == TRLTrainerType.SFT and self.sft_config: - args.extend(["--max_seq_length", str(self.sft_config.max_seq_length)]) - if self.sft_config.packing: - args.append("--packing") - if self.sft_config.dataset_text_field: - args.extend(["--dataset_text_field", self.sft_config.dataset_text_field]) - if self.trainer_type == TRLTrainerType.DPO and self.dpo_config: - args.extend(["--beta", str(self.dpo_config.beta)]) - args.extend(["--max_length", str(self.dpo_config.max_length)]) - args.extend(["--max_prompt_length", str(self.dpo_config.max_prompt_length)]) - args.extend(["--loss_type", self.dpo_config.loss_type]) - return args - - def framework(self) -> str: - return "trl" - - def validate(self) -> None: - if self.bf16 and self.fp16: - raise ValueError("Cannot enable both bf16 and fp16.") - if self.trainer_type == TRLTrainerType.DPO and self.sft_config is not None: - raise ValueError("sft_config should not be set when trainer_type is DPO.") - if self.trainer_type == TRLTrainerType.SFT and self.dpo_config is not None: - raise ValueError("dpo_config should not be set when trainer_type is SFT.") - - @property - def num_nodes(self) -> int | None: - return self._num_nodes - - @property - def resources_per_node(self) -> dict | None: - return self._resources_per_node -``` - -### Go Control Plane: LLMBackendStrategy - -New file: `pkg/runtime/framework/plugins/torch/backend.go` - -Replace the current command-sniffing in `torch.go` and `torchtune.go` with a strategy -interface: - -```go -type LLMBackendStrategy interface { - Name() string - EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error - Validate(info *runtime.Info, oldObj, newObj 
*trainer.TrainJob) (admission.Warnings, field.ErrorList) -} -``` - -`TorchTuneStrategy` wraps the existing `torchtune.go` logic (getRecipeAndConfig, -extractOverridesFromRuntime, validateTorchTune) with no behavioral changes: - -```go -type TorchTuneStrategy struct{} - -func (s *TorchTuneStrategy) Name() string { return "torchtune" } - -func (s *TorchTuneStrategy) EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error { - // Existing torchtune.go logic: rdzv_endpoint, recipe, config, overrides - return nil -} - -func (s *TorchTuneStrategy) Validate(info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { - return validateTorchTune(info, newObj) -} -``` - -`TRLStrategy` is minimal -- TRL config is fully constructed by the SDK: - -```go -type TRLStrategy struct{} - -func (s *TRLStrategy) Name() string { return "trl" } - -func (s *TRLStrategy) EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob) error { - // Inject rendezvous endpoint for multi-node. SDK provides full args. - return nil -} - -func (s *TRLStrategy) Validate(info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { - return nil, nil -} -``` - -### Go Control Plane: Strategy Dispatch - -The Torch plugin struct holds registered strategies and dispatches via the runtime's -framework label instead of command-sniffing: - -```go -type Torch struct { - backends map[string]LLMBackendStrategy -} - -func New(ctx context.Context, client client.Client, indexer client.FieldIndexer) (framework.Plugin, error) { - return &Torch{ - backends: map[string]LLMBackendStrategy{ - "torchtune": &TorchTuneStrategy{}, - "trl": &TRLStrategy{}, - }, - }, nil -} - -func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { - // Common torch distributed setup (PET_NUM_NODES, PET_NPROC_PER_NODE, etc.) - // ... 
- - framework := info.Labels[constants.RuntimeFrameworkLabel] - if strategy, ok := t.backends[framework]; ok { - return strategy.EnforceCommand(info, trainJob) - } - - // Default: standard torchrun path - // ... -} - -func (t *Torch) Validate(ctx context.Context, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { - // Common torch validation ... - - framework := info.Labels[constants.RuntimeFrameworkLabel] - if strategy, ok := t.backends[framework]; ok { - warnings, errs := strategy.Validate(info, oldObj, newObj) - // append ... - } - // ... -} -``` - -### Go Control Plane: Constants - -```go -// pkg/constants/constants.go - -TRLEntrypoint = []string{"python", "-m", "trl"} -TRLFrameworkLabel = "trl" -RuntimeFrameworkLabel = "trainer.kubeflow.org/framework" -``` - -### Container Images - -```dockerfile -# cmd/trainers/trl/Dockerfile -FROM pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime -WORKDIR /workspace -RUN apt update && apt-get install -y --no-install-recommends build-essential \ - && rm -rf /var/lib/apt/lists/* -COPY cmd/trainers/trl/requirements.txt . -RUN pip install -r requirements.txt -``` - -``` -# cmd/trainers/trl/requirements.txt -trl>=0.15.0 -transformers>=4.48.0 -datasets>=3.0.0 -accelerate>=1.2.0 -peft>=0.14.0 -bitsandbytes>=0.41.1 -``` - -Published as `ghcr.io/kubeflow/trainer/trl-trainer`. 
- -### ClusterTrainingRuntimes - -Example: `manifests/base/runtimes/trl/llama3_2/llama3_2_1B.yaml` - -```yaml -apiVersion: trainer.kubeflow.org/v1alpha1 -kind: ClusterTrainingRuntime -metadata: - name: trl-llama3.2-1b - labels: - trainer.kubeflow.org/framework: trl -spec: - mlPolicy: - numNodes: 1 - torch: - numProcPerNode: auto - template: - spec: - replicatedJobs: - - name: dataset-initializer - template: - metadata: - labels: - trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer - spec: - template: - spec: - containers: - - name: dataset-initializer - image: ghcr.io/kubeflow/trainer/dataset-initializer - env: - - name: STORAGE_URI - value: hf://tatsu-lab/alpaca - volumeMounts: - - mountPath: /workspace/dataset - name: workspace - volumes: - - name: workspace - persistentVolumeClaim: - claimName: workspace - - name: model-initializer - template: - metadata: - labels: - trainer.kubeflow.org/trainjob-ancestor-step: model-initializer - spec: - template: - spec: - containers: - - name: model-initializer - image: ghcr.io/kubeflow/trainer/model-initializer - env: - - name: STORAGE_URI - value: hf://meta-llama/Llama-3.2-1B-Instruct - volumeMounts: - - mountPath: /workspace/model - name: workspace - volumes: - - name: workspace - persistentVolumeClaim: - claimName: workspace - - name: node - dependsOn: - - dataset-initializer - - model-initializer - template: - metadata: - labels: - trainer.kubeflow.org/trainjob-ancestor-step: trainer - spec: - template: - spec: - containers: - - name: node - image: ghcr.io/kubeflow/trainer/trl-trainer - command: - - python - - -m - - trl - args: - - sft - - --model_name_or_path - - /workspace/model - - --output_dir - - /workspace/output - - --dataset_name - - /workspace/dataset - resources: - limits: - nvidia.com/gpu: 2 - volumeMounts: - - mountPath: /workspace - name: workspace - volumes: - - name: workspace - persistentVolumeClaim: - claimName: workspace -``` - -Directory structure: - -``` -manifests/base/runtimes/ -├── 
torchtune/ # Existing (unchanged) -└── trl/ # NEW - ├── kustomization.yaml - ├── llama3_2/ - │ ├── llama3_2_1B.yaml - │ └── llama3_2_3B.yaml - └── qwen2_5/ - └── qwen2_5_1.5B.yaml -``` - -### Helm Chart Changes - -```yaml -# charts/kubeflow-trainer/values.yaml (additions) -runtimes: - trlDistributed: - image: - registry: ghcr.io - repository: kubeflow/trainer/trl-trainer - tag: "" - llama3_2_1B: - enabled: false - llama3_2_3B: - enabled: false - qwen2_5_1_5B: - enabled: false -``` - -## Test Plan - -**Unit tests**: -- SDK backend registry: registration, duplicate detection, lookup. -- `TorchTuneConfig` backward compat: `to_args()` identical to current `get_args_using_torchtune_config()`. -- `TRLConfig`: `to_args()`, `to_command()`, `validate()` for SFT, DPO, error cases. -- Go Torch plugin: strategy dispatch, `TorchTuneStrategy` (existing cases), `TRLStrategy`. - -**Integration tests**: -- SDK creates valid TRL TrainJob CR that controller reconciles into JobSet. -- `client.list_runtimes()` returns both TorchTune and TRL runtimes. -- Existing TorchTune examples execute unchanged. - -**E2E tests**: -- TRL SFT with Llama 3.2 1B on Alpaca (GPU). -- TRL DPO with preference dataset (GPU). -- TorchTune regression. + config: TorchTuneConfig # No other option +``` + +The `KubernetesBackend.train()` method calls `get_args_using_torchtune_config()` in +`backends/kubernetes/utils.py` to translate the config into CLI args. There is no +abstraction — adding a new backend means modifying this function and the type annotation. + +### Why This Must Change + +- **TorchTune stopped adding features** in July 2025. The project is in maintenance mode. +- **The command-sniffing pattern doesn't scale.** Each new backend would require another + `slices.Equal` check, another branch in `EnforceMLPolicy`, and another branch in `Validate`. +- **Community consensus on Plan 3** (pluggable framework) from #2752 was unanimous. 
+- **TRL is actively maintained** by HuggingFace with native CLI support (`trl sft`, `trl dpo`, etc.) + and built-in accelerate integration for multi-GPU/multi-node. + +--- + +## High-Level Design + +### Architecture Overview + +The change is a **localized refactor** of two coupling points. No new CRDs, no new +controllers, no changes to the plugin framework itself. + +``` + BEFORE AFTER + ┌──────────────┐ ┌──────────────┐ + SDK │BuiltinTrainer│ │BuiltinTrainer│ + │ config: │ │ config: │ + │ TorchTune │ │ LLMBackend │ + │ Config │ │ (abstract) │ + └──────┬───────┘ └──────┬───────┘ + │ │ + │ to_args() │ to_command() / to_args() + ▼ ▼ + get_args_using_ backend.to_command() + torchtune_config() backend.to_args() + │ │ + │ creates TrainJob CR │ creates TrainJob CR + ▼ ▼ + ┌────────────────────────────────────────────────────────────────────────┐ + │ Kubernetes API │ + └────────────────────────────────┬───────────────────────────────────────┘ + │ + Go ▼ + Torch ┌─────────────────────────────────┐ + Plugin │ EnforceMLPolicy() │ + │ │ + BEFORE: │ if cmd == ["tune","run"]: │ + │ → TorchTune branch │ + │ else: │ + │ → torchrun branch │ + │ │ + AFTER: │ // common: PET env vars │ + │ label = info.Labels[framework] │ + │ if strategy = backends[label]: │ + │ → strategy.EnforceCommand() │ + │ else: │ + │ → default torchrun branch │ + └─────────────────────────────────┘ +``` + +### Component Interaction Flow + +End-to-end for a TRL SFT job: + +``` +1. User: TrainerClient.train(builtin_trainer=BuiltinTrainer(config=TRLConfig( + trainer_type=TRLTrainerType.SFT, ...))) + +2. SDK: TRLConfig.validate() → ok + TRLConfig.to_command() → ("trl",) + TRLConfig.to_args() → ["sft", "--model_name_or_path", "/workspace/model", ...] + Build TrainJob CR with: + runtimeRef: { name: "trl-llama3.2-1b" } + trainer: { command: ["trl"], args: ["sft", ...] } + +3. K8s: Webhook validates TrainJob + Torch plugin Validate() → label=trl → TRLStrategy.Validate() → ok + +4. 
Go: TrainJob controller reconciles: + Torch EnforceMLPolicy(): + a) Common: set PET_NNODES, PET_NPROC_PER_NODE, PET_NODE_RANK + b) Label "trl" → TRLStrategy.EnforceCommand(): + inject PET_MASTER_ADDR, PET_MASTER_PORT + inject MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK (accelerate-compatible) + c) Add container port + +5. K8s: Controller SSA → JobSet → ReplicatedJobs → Pods + Init: dataset-initializer downloads dataset + Init: model-initializer downloads model + Main: trl sft --model_name_or_path /workspace/model ... +``` + +### What Changes vs What Stays + +| Component | Changes? | Details | +|-----------|----------|---------| +| CRD schemas | **No** | No new fields, no new types | +| Plugin framework interfaces | **No** | Same 7 interfaces | +| Controller reconciliation | **No** | Same SSA flow | +| Webhooks | **No** | Same validation hooks (Torch plugin gains strategy dispatch) | +| Torch plugin (common path) | **No** | PET env var injection stays | +| Torch plugin (TorchTune path) | **Refactor** | Extract inline code → `TorchTuneStrategy` | +| Torch plugin (dispatch) | **New** | Label-based strategy lookup replaces command-sniffing | +| TRL strategy | **New** | `TRLStrategy` for TRL-specific env vars | +| SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `LLMBackend` | +| SDK `TorchTuneConfig` | **Implement** | Implements `LLMBackend` (backward compatible) | +| SDK `TRLConfig` | **New** | New backend class | +| SDK registry | **New** | `@register_backend` decorator | +| Container images | **New** | `trl-trainer` image | +| ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | + +--- ## Risks and Mitigations | Risk | Mitigation | |------|------------| -| TRL CLI changes across versions | Pin version in requirements.txt; version compat tests | -| TRL uses accelerate, not torchrun | TRL supports torchrun-compatible launch; validate in E2E | +| TRL CLI changes across versions | Pin version range in requirements.txt; version compat tests | +| TRL 
uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; accelerate reads `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`; validated in E2E | +| Multi-node TRL untested at scale | Phase 1 scoped to single-node multi-GPU; multi-node added in Phase 2 with dedicated E2E | | SDK type widening affects static analysis | TorchTuneConfig is a subtype of LLMBackend; passes type checks | -| Scope creep from adding backends | Scoped to TorchTune + TRL only | \ No newline at end of file +| Scope creep from adding backends | Scoped to TorchTune + TRL only | +| `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `RuntimeFrameworkLabel` constant; existing manifests already use the label | \ No newline at end of file From 37f64fe37d732fcd53a767c182e84aff2db486e5 Mon Sep 17 00:00:00 2001 From: Sabari Narayana Date: Mon, 2 Mar 2026 14:05:23 +0530 Subject: [PATCH 03/11] updated KEP for TRL Signed-off-by: Sabari Narayana --- .../2839-dynamic-llm-trainer/README.md | 623 ++++++++++++++++++ 1 file changed, 623 insertions(+) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index cc14c3df9a..13e676cc2b 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -26,6 +26,19 @@ - [Architecture Overview](#architecture-overview) - [Component Interaction Flow](#component-interaction-flow) - [What Changes vs What Stays](#what-changes-vs-what-stays) + - [Design Details](#design-details) + - [Python SDK: `LLMBackend` Interface](#python-sdk-llmbackend-interface) + - [Python SDK: Backend Registry](#python-sdk-backend-registry) + - [Python SDK: `TRLConfig`](#python-sdk-trlconfig) + - [Python SDK: Integration into `KubernetesBackend`](#python-sdk-integration-into-kubernetesbackend) + - [Go Control Plane: `LLMBackendStrategy` Interface](#go-control-plane-llmbackendstrategy-interface) + - [Go Control Plane: 
`TorchTuneStrategy`](#go-control-plane-torchtunestrategy) + - [Go Control Plane: `TRLStrategy`](#go-control-plane-trlstrategy) + - [Go Control Plane: Refactored Torch Plugin Dispatch](#go-control-plane-refactored-torch-plugin-dispatch) + - [Go Control Plane: New Constant](#go-control-plane-new-constant) + - [TRL Container Image](#trl-container-image) + - [TRL `ClusterTrainingRuntime` Manifest](#trl-clustertrainingruntime-manifest) + - [SDK Usage Example](#sdk-usage-example) - [Risks and Mitigations](#risks-and-mitigations) @@ -215,6 +228,616 @@ End-to-end for a TRL SFT job: --- +## Design Details + +### Python SDK: `LLMBackend` Interface + +Today `BuiltinTrainer.config` is typed as `TorchTuneConfig` directly. This introduces an +abstract base class that every backend must implement. + +```python +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +class LLMBackend(ABC): + """Abstract base for all LLM training backends. + + Each implementation translates its config into a (command, args) pair + that the Kubernetes backend writes into the TrainJob CR. + """ + + @abstractmethod + def to_command(self) -> tuple[str, ...]: + """Return the container entrypoint command. + + Examples: + TorchTune: ("tune", "run") + TRL: ("trl",) + """ + ... + + @abstractmethod + def to_args(self, initializer: "Initializer | None" = None) -> list[str]: + """Return the CLI arguments for the entrypoint. + + Args: + initializer: Optional initializer config for resolving dataset/model paths. + + Returns: + List of string arguments (e.g. ["sft", "--model_name_or_path", "/workspace/model"]). + """ + ... + + @abstractmethod + def validate(self) -> None: + """Raise ValueError if the config is invalid.""" + ... 
+``` + +`BuiltinTrainer` widens its type annotation: + +```python +@dataclass +class BuiltinTrainer: + """Builtin Trainer configuration.""" + config: LLMBackend # was: TorchTuneConfig +``` + +`TorchTuneConfig` implements `LLMBackend` with no field changes — backward compatible: + +```python +@dataclass +class TorchTuneConfig(LLMBackend): + dtype: DataType | None = None + batch_size: int | None = None + epochs: int | None = None + loss: Loss | None = None + num_nodes: int | None = None + peft_config: LoraConfig | None = None + dataset_preprocess_config: TorchTuneInstructDataset | None = None + resources_per_node: dict | None = None + + def to_command(self) -> tuple[str, ...]: + return ("tune", "run") + + def to_args(self, initializer=None) -> list[str]: + # Existing get_args_using_torchtune_config() logic moves here + ... + + def validate(self) -> None: + ... +``` + +### Python SDK: Backend Registry + +A decorator-based registry enables out-of-tree backends (community requirement from #2752): + +```python +_BACKEND_REGISTRY: dict[str, type[LLMBackend]] = {} + + +def register_backend(name: str): + """Register an LLMBackend implementation under a framework name. + + Usage: + @register_backend("trl") + class TRLConfig(LLMBackend): + ... + """ + def decorator(cls: type[LLMBackend]) -> type[LLMBackend]: + if not issubclass(cls, LLMBackend): + raise TypeError(f"{cls.__name__} must subclass LLMBackend") + _BACKEND_REGISTRY[name] = cls + return cls + return decorator + + +def get_backend(name: str) -> type[LLMBackend]: + """Look up a registered backend by name.""" + if name not in _BACKEND_REGISTRY: + raise KeyError( + f"Unknown backend '{name}'. Registered: {list(_BACKEND_REGISTRY)}" + ) + return _BACKEND_REGISTRY[name] +``` + +Built-in backends register themselves at import time: + +```python +@register_backend("torchtune") +class TorchTuneConfig(LLMBackend): + ... + +@register_backend("trl") +class TRLConfig(LLMBackend): + ... 
+``` + +### Python SDK: `TRLConfig` + +```python +from enum import Enum + + +class TRLTrainerType(Enum): + """Training algorithms available via the TRL CLI.""" + SFT = "sft" + DPO = "dpo" + KTO = "kto" + GRPO = "grpo" + + +@dataclass +@register_backend("trl") +class TRLConfig(LLMBackend): + """TRL LLM Trainer configuration. + + Args: + trainer_type: Training algorithm (SFT, DPO, KTO, GRPO). + model_name_or_path: HuggingFace model ID or local path. + dataset_name: HuggingFace dataset ID or local path. + num_nodes: Number of training nodes. + resources_per_node: Resource requirements dict. + learning_rate: Learning rate. + num_train_epochs: Number of training epochs. + per_device_train_batch_size: Batch size per device. + gradient_checkpointing: Enable gradient checkpointing. + bf16: Use bfloat16 precision. + use_peft: Enable LoRA via PEFT. + lora_r: LoRA rank. + lora_alpha: LoRA alpha. + lora_target_modules: Comma-separated target modules for LoRA. + extra_args: Additional CLI arguments passed through verbatim. + """ + + trainer_type: TRLTrainerType = TRLTrainerType.SFT + model_name_or_path: str | None = None + dataset_name: str | None = None + num_nodes: int | None = None + resources_per_node: dict | None = None + learning_rate: float | None = None + num_train_epochs: int | None = None + per_device_train_batch_size: int | None = None + gradient_checkpointing: bool = True + bf16: bool = True + use_peft: bool = False + lora_r: int | None = None + lora_alpha: int | None = None + lora_target_modules: str | None = None + extra_args: dict[str, str] | None = None + + def to_command(self) -> tuple[str, ...]: + return ("trl",) + + def to_args(self, initializer=None) -> list[str]: + args = [self.trainer_type.value] # subcommand: "sft", "dpo", etc. 
+ + # Model path: prefer initializer workspace, fall back to config + model_path = self.model_name_or_path + if initializer and initializer.model: + model_path = "/workspace/model" + if model_path: + args.extend(["--model_name_or_path", model_path]) + + # Dataset: prefer initializer workspace, fall back to config + dataset = self.dataset_name + if initializer and initializer.dataset: + dataset = "/workspace/dataset" + if dataset: + args.extend(["--dataset_name", dataset]) + + if self.learning_rate is not None: + args.extend(["--learning_rate", str(self.learning_rate)]) + if self.num_train_epochs is not None: + args.extend(["--num_train_epochs", str(self.num_train_epochs)]) + if self.per_device_train_batch_size is not None: + args.extend(["--per_device_train_batch_size", str(self.per_device_train_batch_size)]) + if self.gradient_checkpointing: + args.append("--gradient_checkpointing") + if self.bf16: + args.append("--bf16") + if self.use_peft: + args.append("--use_peft") + if self.lora_r is not None: + args.extend(["--lora_r", str(self.lora_r)]) + if self.lora_alpha is not None: + args.extend(["--lora_alpha", str(self.lora_alpha)]) + if self.lora_target_modules: + args.extend(["--lora_target_modules", self.lora_target_modules]) + + # Pass-through extra args + if self.extra_args: + for k, v in self.extra_args.items(): + args.extend([f"--{k}", v]) + + return args + + def validate(self) -> None: + if self.use_peft and self.lora_r is None: + raise ValueError("lora_r is required when use_peft=True") +``` + +### Python SDK: Integration into `KubernetesBackend` + +The current `get_trainer_cr_from_builtin_trainer()` hardcodes `isinstance(trainer.config, TorchTuneConfig)`. 
+This changes to use the `LLMBackend` interface: + +```python +# backends/kubernetes/utils.py (modified) + +def get_trainer_cr_from_builtin_trainer( + runtime: types.Runtime, + trainer: types.BuiltinTrainer, + initializer: types.Initializer | None = None, +) -> models.TrainerV1alpha1Trainer: + config = trainer.config + if not isinstance(config, LLMBackend): + raise ValueError(f"BuiltinTrainer config must implement LLMBackend, got: {type(config)}") + + config.validate() + + trainer_cr = models.TrainerV1alpha1Trainer() + if hasattr(config, "num_nodes") and config.num_nodes: + trainer_cr.num_nodes = config.num_nodes + if hasattr(config, "resources_per_node") and config.resources_per_node: + trainer_cr.resources_per_node = get_resources_per_node(config.resources_per_node) + + trainer_cr.command = list(config.to_command()) + trainer_cr.args = config.to_args(initializer) + return trainer_cr +``` + +### Go Control Plane: `LLMBackendStrategy` Interface + +Inside the Torch plugin package, a strategy interface replaces the inline if/else: + +```go +// pkg/runtime/framework/plugins/torch/strategy.go + +package torch + +import ( + "k8s.io/apimachinery/pkg/util/validation/field" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/runtime" +) + +// LLMBackendStrategy defines backend-specific behavior for the Torch plugin. +// Each strategy handles the portion of EnforceMLPolicy and Validate that differs +// between backends (e.g., command mutation, env var injection, validation rules). +type LLMBackendStrategy interface { + // EnforceCommand mutates the trainer container's command, args, and env vars + // with backend-specific values (e.g., rendezvous args for TorchTune, + // accelerate env vars for TRL). 
+ EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob, container *runtime.Container) error + + // Validate performs backend-specific validation on the TrainJob. + Validate(runtimeInfo *runtime.Info, trainJob *trainer.TrainJob) (admission.Warnings, field.ErrorList) +} +``` + +### Go Control Plane: `TorchTuneStrategy` + +Extracts the existing inline code from `torch.go:159-183` and `torchtune.go`: + +```go +// pkg/runtime/framework/plugins/torch/torchtune_strategy.go + +type TorchTuneStrategy struct{} + +func (s *TorchTuneStrategy) EnforceCommand( + info *runtime.Info, + trainJob *trainer.TrainJob, + container *runtime.Container, +) error { + // Moved from torch.go:159-183 + // 1. Build rendezvous endpoint args + // 2. Call getRecipeAndConfig() for recipe/config selection + // 3. Call extractOverridesFromRuntime() for immutable overrides + // 4. Append to trainJob.Spec.Trainer.Command + return nil +} + +func (s *TorchTuneStrategy) Validate( + runtimeInfo *runtime.Info, + trainJob *trainer.TrainJob, +) (admission.Warnings, field.ErrorList) { + // Calls existing validateTorchTune() + return validateTorchTune(runtimeInfo, trainJob) +} +``` + +### Go Control Plane: `TRLStrategy` + +```go +// pkg/runtime/framework/plugins/torch/trl_strategy.go + +type TRLStrategy struct{} + +func (s *TRLStrategy) EnforceCommand( + info *runtime.Info, + trainJob *trainer.TrainJob, + container *runtime.Container, +) error { + trainerPS := info.FindPodSetByAncestor(constants.AncestorTrainer) + numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) + masterAddr := fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name) + masterPort := fmt.Sprintf("%d", constants.ContainerTrainerPort) + worldSize := fmt.Sprintf("%d", numNodes * numProcPerNode) // NOTE: numProcPerNode is not defined in this sketch; it must be passed in or derived from the Torch ML policy, as in the common EnforceMLPolicy path + + // TRL uses accelerate, which reads standard env vars (not PET_* variants). + // Inject both sets for compatibility.
+ apply.UpsertEnvVars(&container.Env, + // PET env vars (for torchrun compatibility) + *corev1ac.EnvVar().WithName(constants.TorchEnvMasterAddr).WithValue(masterAddr), + *corev1ac.EnvVar().WithName(constants.TorchEnvMasterPort).WithValue(masterPort), + // Standard env vars (for accelerate/TRL) + *corev1ac.EnvVar().WithName("MASTER_ADDR").WithValue(masterAddr), + *corev1ac.EnvVar().WithName("MASTER_PORT").WithValue(masterPort), + *corev1ac.EnvVar().WithName("WORLD_SIZE").WithValue(worldSize), + *corev1ac.EnvVar().WithName("RANK").WithValueFrom( + corev1ac.EnvVarSource().WithFieldRef( + corev1ac.ObjectFieldSelector().WithFieldPath(constants.JobCompletionIndexFieldPath), + ), + ), + ) + return nil +} + +func (s *TRLStrategy) Validate( + runtimeInfo *runtime.Info, + trainJob *trainer.TrainJob, +) (admission.Warnings, field.ErrorList) { + // TRL validation: check that trainer_type subcommand is valid, etc. + return nil, nil +} +``` + +### Go Control Plane: Refactored Torch Plugin Dispatch + +The `Torch` struct gains a `backends` map, and `EnforceMLPolicy` dispatches by label: + +```go +// pkg/runtime/framework/plugins/torch/torch.go (modified) + +type Torch struct { + backends map[string]LLMBackendStrategy +} + +func New(ctx context.Context, c client.Client, fi client.FieldIndexer) (framework.Plugin, error) { + return &Torch{ + backends: map[string]LLMBackendStrategy{ + "torchtune": &TorchTuneStrategy{}, + "trl": &TRLStrategy{}, + }, + }, nil +} +``` + +The dispatch logic in `EnforceMLPolicy` changes from command-sniffing to label lookup: + +```go +func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { + // ... (existing common logic: numNodes, numProcPerNode, PET_NNODES, + // PET_NPROC_PER_NODE, PET_NODE_RANK — unchanged) ... 
+ + // NEW: label-based dispatch replaces command-sniffing + framework := info.Labels[constants.RuntimeFrameworkLabel] // "trainer.kubeflow.org/framework" + if strategy, ok := t.backends[framework]; ok { + if err := strategy.EnforceCommand(info, trainJob, trainerContainer); err != nil { + return err + } + } else { + // Default: standard torchrun path (PET_MASTER_ADDR, PET_MASTER_PORT) + apply.UpsertEnvVars(&trainerContainer.Env, + *corev1ac.EnvVar().WithName(constants.TorchEnvMasterAddr).WithValue(...), + *corev1ac.EnvVar().WithName(constants.TorchEnvMasterPort).WithValue(...), + ) + } + + // ... (existing: add container port) ... + return nil +} +``` + +The same pattern applies to `Validate`: + +```go +func (t *Torch) Validate(ctx context.Context, runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { + // ... (existing common validation: numProcPerNode, reserved envs) ... + + // NEW: label-based dispatch replaces command-sniffing + framework := runtimeInfo.Labels[constants.RuntimeFrameworkLabel] + if strategy, ok := t.backends[framework]; ok { + warnings, errs := strategy.Validate(runtimeInfo, newObj) + allErrs = append(allErrs, errs...) + return warnings, allErrs + } + return nil, allErrs +} +``` + +### Go Control Plane: New Constant + +```go +// pkg/constants/constants.go (addition) + +// RuntimeFrameworkLabel is the label on ClusterTrainingRuntime manifests +// that identifies which LLM framework the runtime belongs to. +// Existing manifests already use this label (e.g., "torchtune"). 
+const RuntimeFrameworkLabel string = "trainer.kubeflow.org/framework" +``` + +### TRL Container Image + +A minimal Dockerfile for the TRL trainer image: + +```dockerfile +FROM python:3.11-slim + +RUN pip install --no-cache-dir \ + "trl>=0.15.0,<1.0.0" \ + "torch>=2.5.0" \ + "peft>=0.8.0" + +ENTRYPOINT ["trl"] +``` + +The image is published as `ghcr.io/kubeflow/trainer/trl-trainer` alongside the existing +`ghcr.io/kubeflow/trainer/torchtune-trainer`. + +### TRL `ClusterTrainingRuntime` Manifest + +Example runtime for Llama 3.2 1B SFT with TRL (modeled on the existing TorchTune runtime): + +```yaml +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: trl-llama3.2-1b + labels: + trainer.kubeflow.org/framework: trl +spec: + mlPolicy: + numNodes: 1 + torch: + numProcPerNode: auto + template: + spec: + volumeClaimPolicies: + - templates: + - metadata: + name: initializer + spec: + accessModes: ["ReadWriteOnce"] + resources: + requests: + storage: 20Gi + replicatedJobs: + - name: dataset-initializer + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer + spec: + template: + spec: + containers: + - name: dataset-initializer + image: ghcr.io/kubeflow/trainer/dataset-initializer + env: + - name: STORAGE_URI + value: hf://tatsu-lab/alpaca + volumeMounts: + - mountPath: /workspace + name: initializer + - name: model-initializer + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: model-initializer + spec: + template: + spec: + containers: + - name: model-initializer + image: ghcr.io/kubeflow/trainer/model-initializer + env: + - name: STORAGE_URI + value: hf://meta-llama/Llama-3.2-1B-Instruct + volumeMounts: + - name: initializer + mountPath: /workspace + - name: node + dependsOn: + - name: dataset-initializer + status: Complete + - name: model-initializer + status: Complete + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: trainer + spec:
+ template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/trl-trainer + command: + - trl + args: + - sft + - --model_name_or_path + - /workspace/model + - --dataset_name + - /workspace/dataset + - --output_dir + - /workspace/output + - --gradient_checkpointing + - --bf16 + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - mountPath: /workspace + name: initializer +``` + +### SDK Usage Example + +End-to-end TRL SFT fine-tuning from the Python SDK: + +```python +from kubeflow.trainer import TrainerClient, types + +client = TrainerClient() + +client.train( + runtime="trl-llama3.2-1b", + initializer=types.Initializer( + model=types.HuggingFaceModelInitializer( + storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct", + ), + dataset=types.HuggingFaceDatasetInitializer( + storage_uri="hf://tatsu-lab/alpaca", + ), + ), + trainer=types.BuiltinTrainer( + config=types.TRLConfig( + trainer_type=types.TRLTrainerType.SFT, + num_train_epochs=3, + per_device_train_batch_size=4, + learning_rate=2e-5, + bf16=True, + gradient_checkpointing=True, + use_peft=True, + lora_r=16, + lora_alpha=32, + ), + ), +) +``` + +For DPO, only `trainer_type` and dataset change: + +```python +client.train( + runtime="trl-llama3.2-1b", + trainer=types.BuiltinTrainer( + config=types.TRLConfig( + trainer_type=types.TRLTrainerType.DPO, + learning_rate=1e-6, + ), + ), +) +``` + +--- + ## Risks and Mitigations | Risk | Mitigation | From f75513615f2ece1c30e012b8391cbedf38c55fdf Mon Sep 17 00:00:00 2001 From: Sabari Narayana Date: Mon, 2 Mar 2026 14:46:35 +0530 Subject: [PATCH 04/11] updated kep for dpo example --- docs/proposals/2839-dynamic-llm-trainer/README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index 13e676cc2b..a363de617d 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ 
b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -822,15 +822,25 @@ client.train( ) ``` -For DPO, only `trainer_type` and dataset change: +For DPO, the `trainer_type` changes and the dataset must be a preference dataset +with chosen/rejected pairs: ```python client.train( runtime="trl-llama3.2-1b", + initializer=types.Initializer( + model=types.HuggingFaceModelInitializer( + storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct", + ), + dataset=types.HuggingFaceDatasetInitializer( + storage_uri="hf://argilla/ultrafeedback-binarized-preferences", + ), + ), trainer=types.BuiltinTrainer( config=types.TRLConfig( trainer_type=types.TRLTrainerType.DPO, learning_rate=1e-6, + bf16=True, ), ), ) From 27b389471482a4b7037397b0c59c55bfd6d15275 Mon Sep 17 00:00:00 2001 From: Sabari Narayana Date: Tue, 10 Mar 2026 12:27:15 +0530 Subject: [PATCH 05/11] =?UTF-8?q?docs:=20simplify=20KEP-2839=20=E2=80=94?= =?UTF-8?q?=20drop=20registry,=20use=20ClassVar=20command,=20add=20KEP-285?= =?UTF-8?q?=20alignment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove @register_backend decorator and backend registry (YAGNI with 2 backends) - Change to_command() method to command: ClassVar[tuple[str, ...]] - Move num_nodes/resources_per_node to LLMBackend base class - Add Relationship to KEP-285 section for config-driven vs function-based trainers - Simplify KubernetesBackend integration (no hasattr checks) - Remove stale Phase 1/Phase 2 references from Risks table - Goals reduced from 7 to 5 --- .../2839-dynamic-llm-trainer/README.md | 162 +++++++----------- 1 file changed, 60 insertions(+), 102 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index a363de617d..e4b72e814d 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -22,13 +22,13 @@ - [How TorchTune Is Wired 
Today](#how-torchtune-is-wired-today) - [SDK Coupling](#sdk-coupling) - [Why This Must Change](#why-this-must-change) + - [Relationship to KEP-285 (Specialized Trainer Abstractions)](#relationship-to-kep-285-specialized-trainer-abstractions) - [High-Level Design](#high-level-design) - [Architecture Overview](#architecture-overview) - [Component Interaction Flow](#component-interaction-flow) - [What Changes vs What Stays](#what-changes-vs-what-stays) - [Design Details](#design-details) - [Python SDK: `LLMBackend` Interface](#python-sdk-llmbackend-interface) - - [Python SDK: Backend Registry](#python-sdk-backend-registry) - [Python SDK: `TRLConfig`](#python-sdk-trlconfig) - [Python SDK: Integration into `KubernetesBackend`](#python-sdk-integration-into-kubernetesbackend) - [Go Control Plane: `LLMBackendStrategy` Interface](#go-control-plane-llmbackendstrategy-interface) @@ -49,7 +49,10 @@ Decouple the `BuiltinTrainer` from TorchTune by introducing a pluggable `LLMBackend` interface in the SDK and a corresponding `LLMBackendStrategy` in the Go control plane. TorchTune becomes the first backend implementation (preserving backward compatibility), -and TRL is added as the first new backend with SFT/DPO support. +and TRL is added as the first new backend with SFT/DPO support. Config-driven backends +sit alongside [KEP-285](https://github.com/kubeflow/sdk/pull/308)'s function-based +trainers as Tier 2 extensions; see +[Relationship to KEP-285](#relationship-to-kep-285-specialized-trainer-abstractions). This builds on [KEP-2401](../2401-llm-trainer-v2/README.md) and the community consensus on "Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752). @@ -58,14 +61,12 @@ TorchTune stopped adding features in July 2025 ## Goals -1. Define an `LLMBackend` abstract interface in the Python SDK. -2. Implement a backend registry with `@register_backend` decorator. -3. Refactor `TorchTuneConfig` to implement `LLMBackend` with zero breaking changes. -4. 
Implement `TRLConfig` backend supporting SFT and DPO. -5. Create TRL container image and `ClusterTrainingRuntime` manifests. -6. Generalize the Go Torch plugin to dispatch via `LLMBackendStrategy` instead of +1. Define an `LLMBackend` abstract interface in the Python SDK for config-driven trainers. +2. Refactor `TorchTuneConfig` to implement `LLMBackend` with zero breaking changes. +3. Implement `TRLConfig` backend supporting SFT and DPO. +4. Create TRL container image and `ClusterTrainingRuntime` manifests. +5. Generalize the Go Torch plugin to dispatch via `LLMBackendStrategy` instead of hardcoded TorchTune command-sniffing. -7. Support external (out-of-tree) backend registration. ## Non-Goals @@ -130,6 +131,29 @@ abstraction — adding a new backend means modifying this function and the type --- +## Relationship to KEP-285 (Specialized Trainer Abstractions) + +[KEP-285](https://github.com/kubeflow/sdk/pull/308) introduces a `BaseTrainer` ABC for +function-based trainers (`TorchTrainer`, `JAXTrainer`, etc.) and a `RuntimeConfig` +dataclass. This KEP is complementary — it addresses **config-driven trainers** where the +framework's own CLI is the entrypoint (e.g., `trl sft ...`, `tune run ...`), not a +user-supplied Python function. + +In KEP-285's terminology, `LLMBackend` implementations are **Tier 2 config-driven +trainers**. If KEP-285 merges first, `LLMBackend` configs can be passed through +KEP-285's `TorchTrainer` instead of `BuiltinTrainer` — the interface is the same +(`command` class var / `to_args()`), only the entry point changes. + +**Shared design points**: + +- Both use `trainer.kubeflow.org/framework` as the dispatch key — KEP-285 for SDK + runtime auto-discovery, this KEP for Go strategy dispatch. +- Both KEPs are compatible with either keeping or deprecating `BuiltinTrainer`. +- If the framework label is promoted to a Runtime API spec field (as discussed in the + KEP-285 review), both KEPs benefit with no changes. 
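As a rough illustration of the `command` class variable / `to_args()` contract described in this section, the sketch below shows how a config-driven backend could be turned into the `command` and `args` fields of a TrainJob trainer spec. The names `BackendSketch`, `EchoBackend`, and `build_trainer_fields` are hypothetical and do not exist in the Kubeflow SDK. The sketch also assumes the base class is itself a dataclass, so that shared fields such as `num_nodes` become constructor parameters of every subclass.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar, Optional


@dataclass
class BackendSketch(ABC):
    """Hypothetical stand-in for the config-driven backend interface."""

    # ClassVar annotations are skipped by @dataclass, so `command` is a
    # class-level constant and never an __init__ parameter.
    command: ClassVar[tuple[str, ...]]

    # Shared field; inherited as a dataclass field by subclasses.
    num_nodes: Optional[int] = None

    @abstractmethod
    def to_args(self) -> list[str]:
        """Return CLI arguments for the entrypoint."""


@dataclass
class EchoBackend(BackendSketch):
    """Toy backend whose entrypoint is `echo`; not a real Kubeflow class."""

    command: ClassVar[tuple[str, ...]] = ("echo",)
    message: str = "hello"

    def to_args(self) -> list[str]:
        return ["--message", self.message]


def build_trainer_fields(config: BackendSketch) -> dict:
    """Copy command/args (and num_nodes, if set) into a plain dict."""
    fields = {"command": list(config.command), "args": config.to_args()}
    if config.num_nodes:
        fields["numNodes"] = config.num_nodes
    return fields
```

Because the caller only touches `config.command` and `config.to_args()`, swapping `EchoBackend` for another subclass requires no change to `build_trainer_fields`, which is the property the interface is meant to provide.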
+ +--- + ## High-Level Design ### Architecture Overview @@ -145,11 +169,11 @@ controllers, no changes to the plugin framework itself. │ TorchTune │ │ LLMBackend │ │ Config │ │ (abstract) │ └──────┬───────┘ └──────┬───────┘ - │ │ - │ to_args() │ to_command() / to_args() - ▼ ▼ - get_args_using_ backend.to_command() - torchtune_config() backend.to_args() + │ │ + │ to_args() │ config.command / to_args() + ▼ ▼ + get_args_using_ config.command + torchtune_config() config.to_args() │ │ │ creates TrainJob CR │ creates TrainJob CR ▼ ▼ @@ -184,8 +208,8 @@ End-to-end for a TRL SFT job: trainer_type=TRLTrainerType.SFT, ...))) 2. SDK: TRLConfig.validate() → ok - TRLConfig.to_command() → ("trl",) - TRLConfig.to_args() → ["sft", "--model_name_or_path", "/workspace/model", ...] + TRLConfig.command → ("trl",) + TRLConfig.to_args() → ["sft", "--model_name_or_path", "/workspace/model", ...] Build TrainJob CR with: runtimeRef: { name: "trl-llama3.2-1b" } trainer: { command: ["trl"], args: ["sft", ...] } @@ -222,7 +246,6 @@ End-to-end for a TRL SFT job: | SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `LLMBackend` | | SDK `TorchTuneConfig` | **Implement** | Implements `LLMBackend` (backward compatible) | | SDK `TRLConfig` | **New** | New backend class | -| SDK registry | **New** | `@register_backend` decorator | | Container images | **New** | `trl-trainer` image | | ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | @@ -233,40 +256,31 @@ End-to-end for a TRL SFT job: ### Python SDK: `LLMBackend` Interface Today `BuiltinTrainer.config` is typed as `TorchTuneConfig` directly. This introduces an -abstract base class that every backend must implement. +abstract base class that every config-driven backend must implement. ```python from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import ClassVar class LLMBackend(ABC): - """Abstract base for all LLM training backends. + """Abstract base for config-driven LLM training backends. 
Each implementation translates its config into a (command, args) pair that the Kubernetes backend writes into the TrainJob CR. """ - @abstractmethod - def to_command(self) -> tuple[str, ...]: - """Return the container entrypoint command. + # Subclasses set this to their CLI entrypoint, e.g. ("tune", "run") or ("trl",) + command: ClassVar[tuple[str, ...]] - Examples: - TorchTune: ("tune", "run") - TRL: ("trl",) - """ - ... + # Common fields shared by all backends + num_nodes: int | None = None + resources_per_node: dict | None = None @abstractmethod def to_args(self, initializer: "Initializer | None" = None) -> list[str]: - """Return the CLI arguments for the entrypoint. - - Args: - initializer: Optional initializer config for resolving dataset/model paths. - - Returns: - List of string arguments (e.g. ["sft", "--model_name_or_path", "/workspace/model"]). - """ + """Return CLI arguments for the entrypoint.""" ... @abstractmethod @@ -289,17 +303,14 @@ class BuiltinTrainer: ```python @dataclass class TorchTuneConfig(LLMBackend): + command = ("tune", "run") + dtype: DataType | None = None batch_size: int | None = None epochs: int | None = None loss: Loss | None = None - num_nodes: int | None = None peft_config: LoraConfig | None = None dataset_preprocess_config: TorchTuneInstructDataset | None = None - resources_per_node: dict | None = None - - def to_command(self) -> tuple[str, ...]: - return ("tune", "run") def to_args(self, initializer=None) -> list[str]: # Existing get_args_using_torchtune_config() logic moves here @@ -309,51 +320,6 @@ class TorchTuneConfig(LLMBackend): ... ``` -### Python SDK: Backend Registry - -A decorator-based registry enables out-of-tree backends (community requirement from #2752): - -```python -_BACKEND_REGISTRY: dict[str, type[LLMBackend]] = {} - - -def register_backend(name: str): - """Register an LLMBackend implementation under a framework name. - - Usage: - @register_backend("trl") - class TRLConfig(LLMBackend): - ... 
- """ - def decorator(cls: type[LLMBackend]) -> type[LLMBackend]: - if not issubclass(cls, LLMBackend): - raise TypeError(f"{cls.__name__} must subclass LLMBackend") - _BACKEND_REGISTRY[name] = cls - return cls - return decorator - - -def get_backend(name: str) -> type[LLMBackend]: - """Look up a registered backend by name.""" - if name not in _BACKEND_REGISTRY: - raise KeyError( - f"Unknown backend '{name}'. Registered: {list(_BACKEND_REGISTRY)}" - ) - return _BACKEND_REGISTRY[name] -``` - -Built-in backends register themselves at import time: - -```python -@register_backend("torchtune") -class TorchTuneConfig(LLMBackend): - ... - -@register_backend("trl") -class TRLConfig(LLMBackend): - ... -``` - ### Python SDK: `TRLConfig` ```python @@ -369,7 +335,6 @@ class TRLTrainerType(Enum): @dataclass -@register_backend("trl") class TRLConfig(LLMBackend): """TRL LLM Trainer configuration. @@ -377,8 +342,6 @@ class TRLConfig(LLMBackend): trainer_type: Training algorithm (SFT, DPO, KTO, GRPO). model_name_or_path: HuggingFace model ID or local path. dataset_name: HuggingFace dataset ID or local path. - num_nodes: Number of training nodes. - resources_per_node: Resource requirements dict. learning_rate: Learning rate. num_train_epochs: Number of training epochs. per_device_train_batch_size: Batch size per device. @@ -391,11 +354,11 @@ class TRLConfig(LLMBackend): extra_args: Additional CLI arguments passed through verbatim. 
""" + command = ("trl",) + trainer_type: TRLTrainerType = TRLTrainerType.SFT model_name_or_path: str | None = None dataset_name: str | None = None - num_nodes: int | None = None - resources_per_node: dict | None = None learning_rate: float | None = None num_train_epochs: int | None = None per_device_train_batch_size: int | None = None @@ -407,9 +370,6 @@ class TRLConfig(LLMBackend): lora_target_modules: str | None = None extra_args: dict[str, str] | None = None - def to_command(self) -> tuple[str, ...]: - return ("trl",) - def to_args(self, initializer=None) -> list[str]: args = [self.trainer_type.value] # subcommand: "sft", "dpo", etc. @@ -461,7 +421,7 @@ class TRLConfig(LLMBackend): ### Python SDK: Integration into `KubernetesBackend` The current `get_trainer_cr_from_builtin_trainer()` hardcodes `isinstance(trainer.config, TorchTuneConfig)`. -This changes to use the `LLMBackend` interface: +This changes to use the `LLMBackend` interface directly: ```python # backends/kubernetes/utils.py (modified) @@ -472,18 +432,15 @@ def get_trainer_cr_from_builtin_trainer( initializer: types.Initializer | None = None, ) -> models.TrainerV1alpha1Trainer: config = trainer.config - if not isinstance(config, LLMBackend): - raise ValueError(f"BuiltinTrainer config must implement LLMBackend, got: {type(config)}") - config.validate() trainer_cr = models.TrainerV1alpha1Trainer() - if hasattr(config, "num_nodes") and config.num_nodes: + if config.num_nodes: trainer_cr.num_nodes = config.num_nodes - if hasattr(config, "resources_per_node") and config.resources_per_node: + if config.resources_per_node: trainer_cr.resources_per_node = get_resources_per_node(config.resources_per_node) - trainer_cr.command = list(config.to_command()) + trainer_cr.command = list(config.command) trainer_cr.args = config.to_args(initializer) return trainer_cr ``` @@ -854,7 +811,8 @@ client.train( |------|------------| | TRL CLI changes across versions | Pin version range in requirements.txt; version compat 
tests |
 | TRL uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; accelerate reads `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`; validated in E2E |
-| Multi-node TRL untested at scale | Phase 1 scoped to single-node multi-GPU; multi-node added in Phase 2 with dedicated E2E |
+| Multi-node TRL untested at scale | Initial implementation scoped to single-node multi-GPU; multi-node validated with dedicated E2E before GA |
 | SDK type widening affects static analysis | TorchTuneConfig is a subtype of LLMBackend; passes type checks |
 | Scope creep from adding backends | Scoped to TorchTune + TRL only |
-| `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `RuntimeFrameworkLabel` constant; existing manifests already use the label |
\ No newline at end of file
+| `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `RuntimeFrameworkLabel` constant; existing manifests already use the label |
+| KEP-285 `BaseTrainer` hierarchy merges before this KEP | `LLMBackend` is a separate ABC for config-driven trainers; if `BuiltinTrainer` is deprecated, `LLMBackend` implementations migrate to a config-driven Tier 2 trainer with minimal changes |
\ No newline at end of file

From 158c7adf340f02a8bb08cf1d71e0246378247126 Mon Sep 17 00:00:00 2001
From: Sabari
Date: Sat, 28 Mar 2026 12:29:45 +0530
Subject: [PATCH 06/11] docs: redesign KEP-2839 to align with KEP-285 BaseTrainer hierarchy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace standalone LLMBackend ABC with ConfigTrainer that integrates into
KEP-285's BaseTrainer type hierarchy, directly answering open questions
from maintainers about how config-driven trainers fit alongside
function-based trainers.
Key changes:
- LLMBackend → ConfigTrainer(BaseTrainer) for unified type hierarchy
- LLMBackendStrategy → FrameworkStrategy (matches framework label convention)
- TorchTuneConfig → TorchTuneTrainer with backward-compatible alias
- TRLConfig → TRLTrainer with runtime auto-discovery support
- Added detailed KEP-285 relationship section with maintainer references
- Added implementation history and KEP.yaml-style metadata

Tracking issue: kubeflow/trainer#2839
---
 .../2839-dynamic-llm-trainer/README.md | 950 ++++++++++++------
 1 file changed, 635 insertions(+), 315 deletions(-)

diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md
index e4b72e814d..ecaf0c5552 100644
--- a/docs/proposals/2839-dynamic-llm-trainer/README.md
+++ b/docs/proposals/2839-dynamic-llm-trainer/README.md
@@ -1,114 +1,229 @@
 # KEP-2839: Dynamic LLM Trainer Framework
 
-**Authors**: NarayanaSabari
-
-**Status**: Provisional
-
-**Creation date**: 2026-02-27
-
-**Tracking issue**: [kubeflow/trainer#2839](https://github.com/kubeflow/trainer/issues/2839)
-
-**Upstream KEP**: [KEP-2401: Kubeflow LLM Trainer V2](../2401-llm-trainer-v2/README.md)
+| | |
+| -------------- | ------------------------------------------------------------ |
+| **Authors** | @NarayanaSabari |
+| **Status** | Provisional |
+| **Created** | 2026-02-27 |
+| **Updated** | 2026-03-28 |
+| **Reviewers** | @tariq-hasan, @andreyvelich, @Electronic-Waste |
+| **Tracking** | [kubeflow/trainer#2839](https://github.com/kubeflow/trainer/issues/2839) |
 
 ## Table of Contents
 
-- [KEP-2839: Dynamic LLM Trainer Framework](#kep-2839-dynamic-llm-trainer-framework)
-  - [Table of Contents](#table-of-contents)
-  - [Summary](#summary)
-  - [Goals](#goals)
-  - [Non-Goals](#non-goals)
-  - [Current State Analysis](#current-state-analysis)
-    - [How TorchTune Is Wired Today](#how-torchtune-is-wired-today)
-    - [SDK Coupling](#sdk-coupling)
-    - [Why This Must Change](#why-this-must-change)
-  -
[Relationship to KEP-285 (Specialized Trainer Abstractions)](#relationship-to-kep-285-specialized-trainer-abstractions) - - [High-Level Design](#high-level-design) - - [Architecture Overview](#architecture-overview) - - [Component Interaction Flow](#component-interaction-flow) - - [What Changes vs What Stays](#what-changes-vs-what-stays) - - [Design Details](#design-details) - - [Python SDK: `LLMBackend` Interface](#python-sdk-llmbackend-interface) - - [Python SDK: `TRLConfig`](#python-sdk-trlconfig) - - [Python SDK: Integration into `KubernetesBackend`](#python-sdk-integration-into-kubernetesbackend) - - [Go Control Plane: `LLMBackendStrategy` Interface](#go-control-plane-llmbackendstrategy-interface) - - [Go Control Plane: `TorchTuneStrategy`](#go-control-plane-torchtunestrategy) - - [Go Control Plane: `TRLStrategy`](#go-control-plane-trlstrategy) - - [Go Control Plane: Refactored Torch Plugin Dispatch](#go-control-plane-refactored-torch-plugin-dispatch) - - [Go Control Plane: New Constant](#go-control-plane-new-constant) - - [TRL Container Image](#trl-container-image) - - [TRL `ClusterTrainingRuntime` Manifest](#trl-clustertrainingruntime-manifest) - - [SDK Usage Example](#sdk-usage-example) - - [Risks and Mitigations](#risks-and-mitigations) +- [Summary](#summary) +- [Motivation](#motivation) + - [Background](#background) + - [Why This Must Change](#why-this-must-change) +- [Goals](#goals) +- [Non-Goals](#non-goals) +- [Relationship to KEP-285 (Specialized Trainer Abstractions)](#relationship-to-kep-285-specialized-trainer-abstractions) + - [The ConfigTrainer vs FuncTrainer Question](#the-configtrainer-vs-functrainer-question) + - [Unified Type Hierarchy](#unified-type-hierarchy) + - [Shared Design Points](#shared-design-points) +- [Current State Analysis](#current-state-analysis) + - [SDK Coupling](#sdk-coupling) + - [Go Control Plane: Command-Sniffing](#go-control-plane-command-sniffing) +- [High-Level Design](#high-level-design) + - [Architecture 
Overview](#architecture-overview) + - [Component Interaction Flow](#component-interaction-flow) + - [What Changes vs What Stays](#what-changes-vs-what-stays) +- [Design Details](#design-details) + - [Python SDK: ConfigTrainer Base Class](#python-sdk-configtrainer-base-class) + - [Python SDK: TorchTuneTrainer (Refactored)](#python-sdk-torchtunetrainer-refactored) + - [Python SDK: TRLTrainer](#python-sdk-trltrainer) + - [Python SDK: TrainerClient Integration](#python-sdk-trainerclient-integration) + - [Python SDK: Backward Compatibility](#python-sdk-backward-compatibility) + - [Go Control Plane: FrameworkStrategy Interface](#go-control-plane-frameworkstrategy-interface) + - [Go Control Plane: TorchTuneStrategy](#go-control-plane-torchtunestrategy) + - [Go Control Plane: TRLStrategy](#go-control-plane-trlstrategy) + - [Go Control Plane: Refactored Torch Plugin Dispatch](#go-control-plane-refactored-torch-plugin-dispatch) + - [Go Control Plane: New Constant](#go-control-plane-new-constant) + - [TRL Container Image](#trl-container-image) + - [TRL ClusterTrainingRuntime Manifests](#trl-clustertrainingruntime-manifests) +- [User-Facing API Examples](#user-facing-api-examples) + - [TRL SFT Fine-Tuning](#trl-sft-fine-tuning) + - [TRL DPO Alignment](#trl-dpo-alignment) + - [TorchTune (Backward Compatible)](#torchtune-backward-compatible) + - [Backward Compatible: BuiltinTrainer Still Works](#backward-compatible-builtintrainer-still-works) +- [Implementation Plan](#implementation-plan) +- [Test Plan](#test-plan) +- [Risks and Mitigations](#risks-and-mitigations) +- [Implementation History](#implementation-history) --- ## Summary -Decouple the `BuiltinTrainer` from TorchTune by introducing a pluggable `LLMBackend` -interface in the SDK and a corresponding `LLMBackendStrategy` in the Go control plane. -TorchTune becomes the first backend implementation (preserving backward compatibility), -and TRL is added as the first new backend with SFT/DPO support. 
Config-driven backends -sit alongside [KEP-285](https://github.com/kubeflow/sdk/pull/308)'s function-based -trainers as Tier 2 extensions; see -[Relationship to KEP-285](#relationship-to-kep-285-specialized-trainer-abstractions). +This KEP introduces a **pluggable config-driven trainer framework** for LLM fine-tuning +in Kubeflow Trainer. It decouples the SDK and Go control plane from TorchTune by +introducing: + +1. A `ConfigTrainer` base class in the Python SDK that sits within + [KEP-285](https://github.com/kubeflow/sdk/pull/308)'s `BaseTrainer` hierarchy as + the foundation for all **config-driven trainers** (where the framework's own CLI is + the entrypoint, not a user-supplied Python function). + +2. A `FrameworkStrategy` interface in the Go Torch plugin that replaces hardcoded + command-sniffing with label-based dispatch via `trainer.kubeflow.org/framework`. + +3. **TRL** as the first new backend with SFT and DPO support, alongside TorchTune + refactored as a backward-compatible implementation. + +This builds on [KEP-2401](../2401-llm-trainer-v2/README.md), the community consensus on +"Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752), and aligns with +the `BaseTrainer` hierarchy being designed in +[KEP-285](https://github.com/kubeflow/sdk/pull/308). + +--- + +## Motivation + +### Background + +Kubeflow Trainer V2 introduced LLM fine-tuning support through +[KEP-2401](../2401-llm-trainer-v2/README.md), using TorchTune as the backend. The +implementation was successful for its scope, but the architecture hardcodes TorchTune +at two coupling points: + +- **SDK**: `BuiltinTrainer.config` is typed as `TorchTuneConfig` with no abstraction. +- **Go Torch plugin**: `EnforceMLPolicy()` uses command-sniffing + (`slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint)`) to + decide between the torchrun path and the TorchTune path. 
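To make the second coupling point concrete, here is a small Python sketch of the idea only (the actual plugin is written in Go). It contrasts command-sniffing with a lookup keyed on the `trainer.kubeflow.org/framework` label. `STRATEGIES`, `pick_path_by_command`, and `pick_path_by_label` are illustrative names and are not part of the Trainer code base; the string return values merely stand in for the code path that would be taken.

```python
# Label named in this KEP as the dispatch key.
FRAMEWORK_LABEL = "trainer.kubeflow.org/framework"
# Mirrors the ["tune", "run"] entrypoint that the current plugin compares against.
TORCHTUNE_ENTRYPOINT = ["tune", "run"]


def pick_path_by_command(command: list[str]) -> str:
    """Command-sniffing: every new backend needs another equality check here."""
    if command == TORCHTUNE_ENTRYPOINT:
        return "torchtune"
    return "torchrun"


# Hypothetical strategy table: adding a backend means adding one entry.
STRATEGIES: dict[str, str] = {
    "torchtune": "torchtune",
    "trl": "trl",
}


def pick_path_by_label(runtime_labels: dict[str, str]) -> str:
    """Label-based dispatch: unknown or missing labels fall back to torchrun."""
    framework = runtime_labels.get(FRAMEWORK_LABEL, "")
    return STRATEGIES.get(framework, "torchrun")
```

With command-sniffing, a `["trl"]` command silently falls through to the default branch unless the function itself is edited, whereas the label lookup routes it once a `"trl"` entry exists in the table.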
+ +### Why This Must Change + +- **TorchTune stopped adding features** in July 2025 + ([pytorch/torchtune#2883](https://github.com/pytorch/torchtune/issues/2883)). New + models and post-training methods (DPO, PPO, ORPO) will not be supported. +- **The command-sniffing pattern doesn't scale.** Each new backend would require + another `slices.Equal` check, another branch in `EnforceMLPolicy`, and another + branch in `Validate`. +- **Community consensus on Plan 3** (pluggable framework) from + [#2752](https://github.com/kubeflow/trainer/issues/2752) was unanimous. +- **TRL is actively maintained** by Hugging Face with native CLI support + (`trl sft`, `trl dpo`, etc.) and built-in accelerate integration for multi-GPU and + multi-node training. +- **KEP-285 is actively designing** the `BaseTrainer` hierarchy and the maintainers + are [asking exactly how config-driven trainers fit in](https://github.com/kubeflow/sdk/pull/308#discussion_r2912976804). + This KEP provides that answer. -This builds on [KEP-2401](../2401-llm-trainer-v2/README.md) and the community consensus -on "Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752). -TorchTune stopped adding features in July 2025 -([pytorch/torchtune#2883](https://github.com/pytorch/torchtune/issues/2883)). +--- ## Goals -1. Define an `LLMBackend` abstract interface in the Python SDK for config-driven trainers. -2. Refactor `TorchTuneConfig` to implement `LLMBackend` with zero breaking changes. -3. Implement `TRLConfig` backend supporting SFT and DPO. +1. Define a `ConfigTrainer` base class within KEP-285's `BaseTrainer` hierarchy for + config-driven LLM trainers. +2. Refactor `TorchTuneConfig` into `TorchTuneTrainer` implementing `ConfigTrainer` + with zero breaking changes to existing workflows. +3. Implement `TRLTrainer` supporting SFT and DPO training algorithms. 4. Create TRL container image and `ClusterTrainingRuntime` manifests. -5. 
Generalize the Go Torch plugin to dispatch via `LLMBackendStrategy` instead of - hardcoded TorchTune command-sniffing. +5. Generalize the Go Torch plugin to dispatch via `FrameworkStrategy` instead of + hardcoded command-sniffing. +6. Maintain full backward compatibility with existing `BuiltinTrainer` API. ## Non-Goals -1. Unsloth or LlamaFactory backends (future work). -2. CRD schema changes — operates within existing `.spec.trainer.command`/`.spec.trainer.args`. +1. Unsloth, LlamaFactory, or other backends (future work following the same pattern). +2. CRD schema changes -- operates within existing `.spec.trainer.command`/`.spec.trainer.args`. 3. New Kubernetes resource topologies (e.g., launcher/worker patterns). -4. Go-side distributed training plugins per backend (backends use existing torchrun infra). +4. Deprecating `BuiltinTrainer` or `CustomTrainer` (both remain supported). +5. Implementing function-based trainers (that is KEP-285's Tier 1 scope). --- -## Current State Analysis +## Relationship to KEP-285 (Specialized Trainer Abstractions) -### How TorchTune Is Wired Today +[KEP-285](https://github.com/kubeflow/sdk/pull/308) introduces a `BaseTrainer` ABC +with framework-specific Tier 1 trainers (`TorchTrainer`, `JAXTrainer`, etc.) and +community-contributed Tier 2 extensions. This KEP is designed to integrate directly +into that hierarchy. -The Torch plugin (`pkg/runtime/framework/plugins/torch/torch.go`) is the only ML policy -plugin that handles LLM fine-tuning. 
It hardcodes TorchTune support via **command-sniffing**: +### The ConfigTrainer vs FuncTrainer Question -```go -// torch.go:149 — the branching point -if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { - // Standard torchrun path: inject PET_MASTER_ADDR, PET_MASTER_PORT -} else { - // TorchTune path: mutate command with recipe, config, rdzv_endpoint -} +In the KEP-285 review, @andreyvelich +[asked](https://github.com/kubeflow/sdk/pull/308#discussion_r2912976804): + +> "How are we going to refactor the BuiltinTrainer interface once we implement the +> BaseTrainer? And how can we dynamically register new LLM fine-tuning framework +> backends?" + +And @tariq-hasan +[asked](https://github.com/kubeflow/sdk/pull/308#discussion_r2901688930): + +> "How do we handle config-driven trainers for post-training LLM fine-tuning? Do we +> segregate them outside BaseTrainer scope?" + +And @szaher +[proposed](https://github.com/kubeflow/sdk/pull/308#discussion_r2955718123): + +> "Should I rename the proposal to have two main abstract classes `ConfigTrainer` and +> `FuncTrainer`?" + +This KEP answers these questions. 
There are two fundamentally different trainer +patterns in Kubeflow: + +| Pattern | Entrypoint | Examples | KEP | +|---------|-----------|----------|-----| +| **Function-based** (`FuncTrainer`) | User's Python `train()` function | TorchTrainer, JAXTrainer | KEP-285 Tier 1 | +| **Config-driven** (`ConfigTrainer`) | Framework's own CLI | TorchTune, TRL, Unsloth | This KEP (Tier 2) | + +### Unified Type Hierarchy + +``` + BaseTrainer (ABC) ← KEP-285 + ├── get_train_func() + ├── get_framework_args() + ├── validate_runtime() + └── supported_frameworks + │ + ┌────────────────┼────────────────┐ + │ │ │ + TorchTrainer JAXTrainer ConfigTrainer (ABC) ← This KEP + (Tier 1) (Tier 1) ├── command + ├── to_args() + └── validate() + │ + ┌───────────────┼───────────────┐ + │ │ │ + TorchTuneTrainer TRLTrainer (future: Unsloth, + (Tier 2) (Tier 2) LlamaFactory) + + + Existing (unchanged, backward compatible): + + CustomTrainer BuiltinTrainer CustomTrainerContainer + (flat dataclass) (config: ConfigTrainer) (image-based) ``` -`constants.TorchTuneEntrypoint` is `[]string{"tune", "run"}`. When the trainer command -matches this, the plugin enters the TorchTune branch (torch.go:159-183) which: +`ConfigTrainer` extends `BaseTrainer` by adding: +- A `command` class variable (the CLI entrypoint, e.g., `("trl",)` or `("tune", "run")`) +- A `to_args()` method that translates config into CLI arguments +- A `validate()` method for config-level validation + +`ConfigTrainer.get_train_func()` returns `None` (there is no user function -- the +framework's CLI **is** the entrypoint). `ConfigTrainer.get_framework_args()` delegates +to `to_args()`. + +### Shared Design Points -1. Builds the rendezvous endpoint: `--rdzv_endpoint={name}-node-0-0.{name}:29500` -2. Calls `getRecipeAndConfig()` (torchtune.go:78) to select a recipe/config pair - from a matrix of `numNodes × numGPUs × LoRA/QLoRA` combinations. -3. 
Calls `extractOverridesFromRuntime()` (torchtune.go:131) to pull immutable config - overrides from the ClusterTrainingRuntime's node container command. -4. Appends all of this to `trainJob.Spec.Trainer.Command`. +- Both KEPs use `trainer.kubeflow.org/framework` as the dispatch key. KEP-285 uses it + for SDK runtime auto-discovery; this KEP uses it for Go strategy dispatch. +- Both KEPs are compatible with either keeping or deprecating `BuiltinTrainer`. +- If the framework label is + [promoted to a Runtime API spec field](https://github.com/kubeflow/sdk/pull/308#discussion_r2894627115) + (as discussed in KEP-285), both KEPs benefit with no changes. + +--- -The validation path (torch.go:88) also sniffs the same entrypoint to decide whether -to run `validateTorchTune()`. +## Current State Analysis ### SDK Coupling -In the Python SDK (`kubeflow/sdk` repo), `BuiltinTrainer` has a single field: +In the Python SDK, `BuiltinTrainer` has a single field +([types.py:226-237](https://github.com/kubeflow/sdk/blob/main/kubeflow/trainer/types/types.py#L226)): ```python @dataclass @@ -116,41 +231,33 @@ class BuiltinTrainer: config: TorchTuneConfig # No other option ``` -The `KubernetesBackend.train()` method calls `get_args_using_torchtune_config()` in -`backends/kubernetes/utils.py` to translate the config into CLI args. There is no -abstraction — adding a new backend means modifying this function and the type annotation. - -### Why This Must Change - -- **TorchTune stopped adding features** in July 2025. The project is in maintenance mode. -- **The command-sniffing pattern doesn't scale.** Each new backend would require another - `slices.Equal` check, another branch in `EnforceMLPolicy`, and another branch in `Validate`. -- **Community consensus on Plan 3** (pluggable framework) from #2752 was unanimous. -- **TRL is actively maintained** by HuggingFace with native CLI support (`trl sft`, `trl dpo`, etc.) - and built-in accelerate integration for multi-GPU/multi-node. 
- ---- +The comment at line 240 explicitly signals readiness for change: +```python +# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs. +``` -## Relationship to KEP-285 (Specialized Trainer Abstractions) +The `KubernetesBackend` calls `get_args_using_torchtune_config()` +([utils.py:467-521](https://github.com/kubeflow/sdk/blob/main/kubeflow/trainer/backends/kubernetes/utils.py#L467)) +with no abstraction -- adding a new backend means modifying this function and the +type annotation. -[KEP-285](https://github.com/kubeflow/sdk/pull/308) introduces a `BaseTrainer` ABC for -function-based trainers (`TorchTrainer`, `JAXTrainer`, etc.) and a `RuntimeConfig` -dataclass. This KEP is complementary — it addresses **config-driven trainers** where the -framework's own CLI is the entrypoint (e.g., `trl sft ...`, `tune run ...`), not a -user-supplied Python function. +### Go Control Plane: Command-Sniffing -In KEP-285's terminology, `LLMBackend` implementations are **Tier 2 config-driven -trainers**. If KEP-285 merges first, `LLMBackend` configs can be passed through -KEP-285's `TorchTrainer` instead of `BuiltinTrainer` — the interface is the same -(`command` class var / `to_args()`), only the entry point changes. +The Torch plugin +([torch.go:149](https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torch.go#L149)) +uses command-sniffing to branch: -**Shared design points**: +```go +if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) { + // Standard torchrun path +} else { + // TorchTune path: recipe selection, config overrides, rdzv_endpoint +} +``` -- Both use `trainer.kubeflow.org/framework` as the dispatch key — KEP-285 for SDK - runtime auto-discovery, this KEP for Go strategy dispatch. -- Both KEPs are compatible with either keeping or deprecating `BuiltinTrainer`. 
-- If the framework label is promoted to a Runtime API spec field (as discussed in the - KEP-285 review), both KEPs benefit with no changes. +This pattern requires a new `slices.Equal` check for every new backend. The +validation path ([torch.go:88](https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torch.go#L88)) +similarly sniffs the entrypoint to decide whether to run `validateTorchTune()`. --- @@ -162,41 +269,41 @@ The change is a **localized refactor** of two coupling points. No new CRDs, no n controllers, no changes to the plugin framework itself. ``` - BEFORE AFTER - ┌──────────────┐ ┌──────────────┐ - SDK │BuiltinTrainer│ │BuiltinTrainer│ - │ config: │ │ config: │ - │ TorchTune │ │ LLMBackend │ - │ Config │ │ (abstract) │ - └──────┬───────┘ └──────┬───────┘ - │ │ - │ to_args() │ config.command / to_args() - ▼ ▼ - get_args_using_ config.command - torchtune_config() config.to_args() - │ │ - │ creates TrainJob CR │ creates TrainJob CR - ▼ ▼ - ┌────────────────────────────────────────────────────────────────────────┐ - │ Kubernetes API │ - └────────────────────────────────┬───────────────────────────────────────┘ - │ - Go ▼ - Torch ┌─────────────────────────────────┐ - Plugin │ EnforceMLPolicy() │ - │ │ - BEFORE: │ if cmd == ["tune","run"]: │ - │ → TorchTune branch │ - │ else: │ - │ → torchrun branch │ - │ │ - AFTER: │ // common: PET env vars │ - │ label = info.Labels[framework] │ - │ if strategy = backends[label]: │ - │ → strategy.EnforceCommand() │ - │ else: │ - │ → default torchrun branch │ - └─────────────────────────────────┘ + BEFORE AFTER + ┌──────────────┐ ┌──────────────┐ + SDK │BuiltinTrainer│ │BuiltinTrainer│ + │ config: │ │ config: │ + │ TorchTune │ │ Config │ + │ Config │ │ Trainer │ + └──────┬───────┘ └──────┬───────┘ + │ │ + │ hardcoded │ config.command + │ get_args_using_ │ config.to_args() + │ torchtune_config() │ + ▼ ▼ + creates TrainJob CR creates TrainJob CR + │ │ + 
┌────────────────────────────────────────────────────────────────────┐ + │ Kubernetes API │ + └──────────────────────────┬─────────────────────────────────────────┘ + │ + Go ▼ + Torch ┌─────────────────────────────┐ + Plugin │ EnforceMLPolicy() │ + │ │ + BEFORE: │ if cmd == ["tune","run"]: │ + │ → TorchTune branch │ + │ else: │ + │ → torchrun branch │ + │ │ + AFTER: │ label = info.Labels │ + │ [framework] │ + │ if strategy = backends │ + │ [label]: │ + │ → strategy.Enforce() │ + │ else: │ + │ → default torchrun │ + └─────────────────────────────┘ ``` ### Component Interaction Flow @@ -204,25 +311,32 @@ controllers, no changes to the plugin framework itself. End-to-end for a TRL SFT job: ``` -1. User: TrainerClient.train(builtin_trainer=BuiltinTrainer(config=TRLConfig( - trainer_type=TRLTrainerType.SFT, ...))) +1. User: TrainerClient.train( + trainer=TRLTrainer(trainer_type=SFT, ...), + runtime="trl-llama3.2-1b") + + -- OR with auto-discovery -- + + User: TrainerClient.train( + trainer=TRLTrainer(trainer_type=SFT, ...)) + # SDK finds runtime with label trainer.kubeflow.org/framework: trl -2. SDK: TRLConfig.validate() → ok - TRLConfig.command → ("trl",) - TRLConfig.to_args() → ["sft", "--model_name_or_path", "/workspace/model", ...] +2. SDK: TRLTrainer.validate() → ok + TRLTrainer.command → ("trl",) + TRLTrainer.to_args() → ["sft", "--model_name_or_path", ...] Build TrainJob CR with: runtimeRef: { name: "trl-llama3.2-1b" } trainer: { command: ["trl"], args: ["sft", ...] } 3. K8s: Webhook validates TrainJob - Torch plugin Validate() → label=trl → TRLStrategy.Validate() → ok + Torch plugin Validate() → label=trl → TRLStrategy.Validate() 4. 
Go: TrainJob controller reconciles: Torch EnforceMLPolicy(): a) Common: set PET_NNODES, PET_NPROC_PER_NODE, PET_NODE_RANK b) Label "trl" → TRLStrategy.EnforceCommand(): inject PET_MASTER_ADDR, PET_MASTER_PORT - inject MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK (accelerate-compatible) + inject MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK c) Add container port 5. K8s: Controller SSA → JobSet → ReplicatedJobs → Pods @@ -243,9 +357,9 @@ End-to-end for a TRL SFT job: | Torch plugin (TorchTune path) | **Refactor** | Extract inline code → `TorchTuneStrategy` | | Torch plugin (dispatch) | **New** | Label-based strategy lookup replaces command-sniffing | | TRL strategy | **New** | `TRLStrategy` for TRL-specific env vars | -| SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `LLMBackend` | -| SDK `TorchTuneConfig` | **Implement** | Implements `LLMBackend` (backward compatible) | -| SDK `TRLConfig` | **New** | New backend class | +| SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `ConfigTrainer` | +| SDK `TorchTuneConfig` | **Refactor** | → `TorchTuneTrainer(ConfigTrainer)`, backward compatible | +| SDK `TRLTrainer` | **New** | New config-driven trainer | | Container images | **New** | `trl-trainer` image | | ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | @@ -253,33 +367,42 @@ End-to-end for a TRL SFT job: ## Design Details -### Python SDK: `LLMBackend` Interface +### Python SDK: ConfigTrainer Base Class -Today `BuiltinTrainer.config` is typed as `TorchTuneConfig` directly. This introduces an -abstract base class that every config-driven backend must implement. +`ConfigTrainer` extends KEP-285's `BaseTrainer` for config-driven trainers where the +framework's own CLI is the entrypoint. It bridges the gap between function-based +Tier 1 trainers and the existing `BuiltinTrainer`. 
```python -from abc import ABC, abstractmethod +from abc import abstractmethod from dataclasses import dataclass -from typing import ClassVar +from typing import Callable, ClassVar, Optional -class LLMBackend(ABC): - """Abstract base for config-driven LLM training backends. +@dataclass +class ConfigTrainer(BaseTrainer): + """Base class for config-driven LLM training backends. + + Config-driven trainers use the framework's own CLI as the entrypoint + (e.g., `trl sft ...`, `tune run ...`) rather than a user-supplied + Python function. Each implementation translates its config into a + (command, args) pair that the Kubernetes backend writes into the + TrainJob CR. - Each implementation translates its config into a (command, args) pair - that the Kubernetes backend writes into the TrainJob CR. + This class sits in KEP-285's BaseTrainer hierarchy as the foundation + for Tier 2 config-driven trainers. """ - # Subclasses set this to their CLI entrypoint, e.g. ("tune", "run") or ("trl",) + # Subclasses set this to their CLI entrypoint. + # e.g., ("tune", "run") for TorchTune, ("trl",) for TRL. command: ClassVar[tuple[str, ...]] - # Common fields shared by all backends - num_nodes: int | None = None - resources_per_node: dict | None = None + # Common fields shared by all config-driven trainers. + num_nodes: Optional[int] = None + resources_per_node: Optional[dict] = None @abstractmethod - def to_args(self, initializer: "Initializer | None" = None) -> list[str]: + def to_args(self, initializer: Optional["Initializer"] = None) -> list[str]: """Return CLI arguments for the entrypoint.""" ... @@ -287,40 +410,61 @@ class LLMBackend(ABC): def validate(self) -> None: """Raise ValueError if the config is invalid.""" ... 
-``` -`BuiltinTrainer` widens its type annotation: + # --- BaseTrainer interface implementation --- -```python -@dataclass -class BuiltinTrainer: - """Builtin Trainer configuration.""" - config: LLMBackend # was: TorchTuneConfig + def get_train_func(self) -> Optional[Callable]: + """Config-driven trainers have no user function.""" + return None + + def get_train_func_args(self) -> Optional[dict]: + """Config-driven trainers have no function args.""" + return None + + def get_framework_args(self) -> dict: + """Delegate to to_args() for CLI argument generation.""" + return {"_config_args": self.to_args()} ``` -`TorchTuneConfig` implements `LLMBackend` with no field changes — backward compatible: +### Python SDK: TorchTuneTrainer (Refactored) + +`TorchTuneConfig` is refactored into `TorchTuneTrainer` implementing `ConfigTrainer`. +All existing fields are preserved. `TorchTuneConfig` becomes a type alias for backward +compatibility. ```python @dataclass -class TorchTuneConfig(LLMBackend): - command = ("tune", "run") +class TorchTuneTrainer(ConfigTrainer): + """TorchTune LLM Trainer configuration. + + Supports runtimes labeled with trainer.kubeflow.org/framework: torchtune. + """ - dtype: DataType | None = None - batch_size: int | None = None - epochs: int | None = None - loss: Loss | None = None - peft_config: LoraConfig | None = None - dataset_preprocess_config: TorchTuneInstructDataset | None = None + supported_frameworks: ClassVar[list[str]] = ["torchtune"] + command: ClassVar[tuple[str, ...]] = ("tune", "run") + + # All existing TorchTuneConfig fields preserved. 
+ dtype: Optional[DataType] = None + batch_size: Optional[int] = None + epochs: Optional[int] = None + loss: Optional[Loss] = None + peft_config: Optional[LoraConfig] = None + dataset_preprocess_config: Optional[TorchTuneInstructDataset] = None def to_args(self, initializer=None) -> list[str]: - # Existing get_args_using_torchtune_config() logic moves here + # Existing get_args_using_torchtune_config() logic moves here. ... def validate(self) -> None: + # Validate supported model, LoRA config, etc. ... + + +# Backward compatibility alias. +TorchTuneConfig = TorchTuneTrainer ``` -### Python SDK: `TRLConfig` +### Python SDK: TRLTrainer ```python from enum import Enum @@ -335,13 +479,17 @@ class TRLTrainerType(Enum): @dataclass -class TRLConfig(LLMBackend): +class TRLTrainer(ConfigTrainer): """TRL LLM Trainer configuration. + Supports runtimes labeled with trainer.kubeflow.org/framework: trl. + TRL is maintained by Hugging Face with native CLI support and built-in + accelerate integration for multi-GPU/multi-node training. + Args: trainer_type: Training algorithm (SFT, DPO, KTO, GRPO). - model_name_or_path: HuggingFace model ID or local path. - dataset_name: HuggingFace dataset ID or local path. + model_name_or_path: Hugging Face model ID or local path. + dataset_name: Hugging Face dataset ID or local path. learning_rate: Learning rate. num_train_epochs: Number of training epochs. per_device_train_batch_size: Batch size per device. @@ -354,33 +502,34 @@ class TRLConfig(LLMBackend): extra_args: Additional CLI arguments passed through verbatim. 
""" - command = ("trl",) + supported_frameworks: ClassVar[list[str]] = ["trl"] + command: ClassVar[tuple[str, ...]] = ("trl",) trainer_type: TRLTrainerType = TRLTrainerType.SFT - model_name_or_path: str | None = None - dataset_name: str | None = None - learning_rate: float | None = None - num_train_epochs: int | None = None - per_device_train_batch_size: int | None = None + model_name_or_path: Optional[str] = None + dataset_name: Optional[str] = None + learning_rate: Optional[float] = None + num_train_epochs: Optional[int] = None + per_device_train_batch_size: Optional[int] = None gradient_checkpointing: bool = True bf16: bool = True use_peft: bool = False - lora_r: int | None = None - lora_alpha: int | None = None - lora_target_modules: str | None = None - extra_args: dict[str, str] | None = None + lora_r: Optional[int] = None + lora_alpha: Optional[int] = None + lora_target_modules: Optional[str] = None + extra_args: Optional[dict[str, str]] = None def to_args(self, initializer=None) -> list[str]: args = [self.trainer_type.value] # subcommand: "sft", "dpo", etc. - # Model path: prefer initializer workspace, fall back to config + # Model path: prefer initializer workspace, fall back to config. model_path = self.model_name_or_path if initializer and initializer.model: model_path = "/workspace/model" if model_path: args.extend(["--model_name_or_path", model_path]) - # Dataset: prefer initializer workspace, fall back to config + # Dataset: prefer initializer workspace, fall back to config. 
dataset = self.dataset_name if initializer and initializer.dataset: dataset = "/workspace/dataset" @@ -392,7 +541,8 @@ class TRLConfig(LLMBackend): if self.num_train_epochs is not None: args.extend(["--num_train_epochs", str(self.num_train_epochs)]) if self.per_device_train_batch_size is not None: - args.extend(["--per_device_train_batch_size", str(self.per_device_train_batch_size)]) + args.extend(["--per_device_train_batch_size", + str(self.per_device_train_batch_size)]) if self.gradient_checkpointing: args.append("--gradient_checkpointing") if self.bf16: @@ -406,7 +556,7 @@ class TRLConfig(LLMBackend): if self.lora_target_modules: args.extend(["--lora_target_modules", self.lora_target_modules]) - # Pass-through extra args + # Pass-through extra args. if self.extra_args: for k, v in self.extra_args.items(): args.extend([f"--{k}", v]) @@ -418,36 +568,58 @@ class TRLConfig(LLMBackend): raise ValueError("lora_r is required when use_peft=True") ``` -### Python SDK: Integration into `KubernetesBackend` +### Python SDK: TrainerClient Integration -The current `get_trainer_cr_from_builtin_trainer()` hardcodes `isinstance(trainer.config, TorchTuneConfig)`. -This changes to use the `LLMBackend` interface directly: +The `TrainerClient.train()` method gains support for `ConfigTrainer` through KEP-285's +`BaseTrainer` interface. When a `ConfigTrainer` is passed: + +1. If `runtime` is `None`, the SDK auto-discovers a runtime by matching the + `trainer.kubeflow.org/framework` label against `supported_frameworks` (using + KEP-285's `_resolve_runtime()` mechanism). +2. `validate_runtime()` ensures the runtime's framework label matches. +3. The backend uses `config.command` and `config.to_args()` to build the TrainJob CR. ```python -# backends/kubernetes/utils.py (modified) +# In KubernetesBackend — unified handler for ConfigTrainer. 
-def get_trainer_cr_from_builtin_trainer( +def get_trainer_cr( runtime: types.Runtime, - trainer: types.BuiltinTrainer, - initializer: types.Initializer | None = None, + trainer: ConfigTrainer, + initializer: Optional[types.Initializer] = None, ) -> models.TrainerV1alpha1Trainer: - config = trainer.config - config.validate() + trainer.validate() trainer_cr = models.TrainerV1alpha1Trainer() - if config.num_nodes: - trainer_cr.num_nodes = config.num_nodes - if config.resources_per_node: - trainer_cr.resources_per_node = get_resources_per_node(config.resources_per_node) + if trainer.num_nodes: + trainer_cr.num_nodes = trainer.num_nodes + if trainer.resources_per_node: + trainer_cr.resources_per_node = get_resources_per_node( + trainer.resources_per_node + ) - trainer_cr.command = list(config.command) - trainer_cr.args = config.to_args(initializer) + trainer_cr.command = list(trainer.command) + trainer_cr.args = trainer.to_args(initializer) return trainer_cr ``` -### Go Control Plane: `LLMBackendStrategy` Interface +### Python SDK: Backward Compatibility + +| Existing API | Status | Details | +|-------------|--------|---------| +| `BuiltinTrainer(config=TorchTuneConfig(...))` | **Works** | `TorchTuneConfig` is an alias for `TorchTuneTrainer` | +| `BuiltinTrainer(config=TRLTrainer(...))` | **New** | `BuiltinTrainer.config` type widens to `ConfigTrainer` | +| `CustomTrainer(func=...)` | **Unchanged** | No modifications | +| `CustomTrainerContainer(image=...)` | **Unchanged** | No modifications | +| `TrainerClient.train(trainer=TRLTrainer(...))` | **New** | Direct `BaseTrainer` subclass via KEP-285 | + +The `BuiltinTrainer.config` field type changes from `TorchTuneConfig` to +`ConfigTrainer`. Since `TorchTuneConfig` is a type alias for `TorchTuneTrainer` +which extends `ConfigTrainer`, all existing code continues to work. 
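
The alias-based compatibility claim can be checked with a small, self-contained sketch. The classes below are simplified stand-ins, not the SDK's actual definitions: the field set is trimmed to `epochs` and `batch_size`, and the `key=value` rendering in `to_args()` is an assumption made only for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar, Optional


class ConfigTrainer(ABC):
    """Simplified stand-in for the proposed base class."""

    command: ClassVar[tuple[str, ...]]

    @abstractmethod
    def to_args(self) -> list[str]:
        """Return CLI arguments for the entrypoint."""


@dataclass
class TorchTuneTrainer(ConfigTrainer):
    """Trimmed stand-in: only two of the real fields are modeled."""

    command: ClassVar[tuple[str, ...]] = ("tune", "run")
    epochs: Optional[int] = None
    batch_size: Optional[int] = None

    def to_args(self) -> list[str]:
        # Hypothetical key=value rendering, for illustration only.
        args: list[str] = []
        if self.epochs is not None:
            args.append(f"epochs={self.epochs}")
        if self.batch_size is not None:
            args.append(f"batch_size={self.batch_size}")
        return args


# The alias binds the old name to the very same class object.
TorchTuneConfig = TorchTuneTrainer


@dataclass
class BuiltinTrainer:
    """Stand-in wrapper whose config type is widened to ConfigTrainer."""

    config: ConfigTrainer


# Code written against the old name keeps working unchanged.
legacy = BuiltinTrainer(config=TorchTuneConfig(epochs=3, batch_size=4))
```

Because an alias is the same class object, `isinstance` checks and keyword construction behave identically under either name. Only code that inspects `__name__` would observe the rename, since `TorchTuneConfig.__name__` now reports `TorchTuneTrainer`.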
-Inside the Torch plugin package, a strategy interface replaces the inline if/else: +### Go Control Plane: FrameworkStrategy Interface + +Inside the Torch plugin package, a strategy interface replaces the inline if/else. +The naming follows the existing `trainer.kubeflow.org/framework` label convention. ```go // pkg/runtime/framework/plugins/torch/strategy.go @@ -462,23 +634,33 @@ import ( "github.com/kubeflow/trainer/v2/pkg/runtime" ) -// LLMBackendStrategy defines backend-specific behavior for the Torch plugin. -// Each strategy handles the portion of EnforceMLPolicy and Validate that differs -// between backends (e.g., command mutation, env var injection, validation rules). -type LLMBackendStrategy interface { - // EnforceCommand mutates the trainer container's command, args, and env vars - // with backend-specific values (e.g., rendezvous args for TorchTune, - // accelerate env vars for TRL). - EnforceCommand(info *runtime.Info, trainJob *trainer.TrainJob, container *runtime.Container) error - - // Validate performs backend-specific validation on the TrainJob. - Validate(runtimeInfo *runtime.Info, trainJob *trainer.TrainJob) (admission.Warnings, field.ErrorList) +// FrameworkStrategy defines backend-specific behavior for the Torch plugin. +// Each strategy handles the portion of EnforceMLPolicy and Validate that +// differs between frameworks (e.g., command mutation, env var injection, +// validation rules). +type FrameworkStrategy interface { + // EnforceCommand mutates the trainer container's command, args, and + // env vars with framework-specific values. + EnforceCommand( + info *runtime.Info, + trainJob *trainer.TrainJob, + container *runtime.Container, + ) error + + // Validate performs framework-specific validation on the TrainJob. 
+    Validate(
+        runtimeInfo *runtime.Info,
+        trainJob *trainer.TrainJob,
+    ) (admission.Warnings, field.ErrorList)
 }
 ```
 
-### Go Control Plane: `TorchTuneStrategy`
+### Go Control Plane: TorchTuneStrategy
 
-Extracts the existing inline code from `torch.go:159-183` and `torchtune.go`:
+Extracts the existing inline code from
+[torch.go:149-183](https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torch.go#L149)
+and the validation from
+[torchtune.go](https://github.com/kubeflow/trainer/blob/master/pkg/runtime/framework/plugins/torch/torchtune.go):
 
 ```go
 // pkg/runtime/framework/plugins/torch/torchtune_strategy.go
@@ -490,7 +672,7 @@ func (s *TorchTuneStrategy) EnforceCommand(
     trainJob *trainer.TrainJob,
     container *runtime.Container,
 ) error {
-    // Moved from torch.go:159-183
+    // Moved from torch.go:149-183 (unchanged logic):
     // 1. Build rendezvous endpoint args
     // 2. Call getRecipeAndConfig() for recipe/config selection
     // 3. Call extractOverridesFromRuntime() for immutable overrides
@@ -502,12 +684,17 @@ func (s *TorchTuneStrategy) Validate(
     runtimeInfo *runtime.Info,
     trainJob *trainer.TrainJob,
 ) (admission.Warnings, field.ErrorList) {
-    // Calls existing validateTorchTune()
+    // Delegates to existing validateTorchTune() (torchtune.go:35-74).
     return validateTorchTune(runtimeInfo, trainJob)
 }
 ```
 
-### Go Control Plane: `TRLStrategy`
+### Go Control Plane: TRLStrategy
+
+TRL uses Hugging Face's `accelerate` for distributed training, which reads standard
+environment variables (`MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`) rather
+than the `PET_*` variants used by torchrun. The strategy injects both sets for
+maximum compatibility.
```go // pkg/runtime/framework/plugins/torch/trl_strategy.go @@ -520,24 +707,36 @@ func (s *TRLStrategy) EnforceCommand( container *runtime.Container, ) error { trainerPS := info.FindPodSetByAncestor(constants.AncestorTrainer) - numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1) - masterAddr := fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name) + numNodes := ptr.Deref( + ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1, + ) + masterAddr := fmt.Sprintf( + "%s-%s-0-0.%s", + trainJob.Name, constants.Node, trainJob.Name, + ) masterPort := fmt.Sprintf("%d", constants.ContainerTrainerPort) - worldSize := fmt.Sprintf("%d", numNodes * numProcPerNode) + worldSize := fmt.Sprintf("%d", numNodes*numProcPerNode) - // TRL uses accelerate, which reads standard env vars (not PET_* variants). - // Inject both sets for compatibility. + // Inject both PET_* (torchrun compat) and standard env vars + // (accelerate/TRL). apply.UpsertEnvVars(&container.Env, - // PET env vars (for torchrun compatibility) - *corev1ac.EnvVar().WithName(constants.TorchEnvMasterAddr).WithValue(masterAddr), - *corev1ac.EnvVar().WithName(constants.TorchEnvMasterPort).WithValue(masterPort), - // Standard env vars (for accelerate/TRL) - *corev1ac.EnvVar().WithName("MASTER_ADDR").WithValue(masterAddr), - *corev1ac.EnvVar().WithName("MASTER_PORT").WithValue(masterPort), - *corev1ac.EnvVar().WithName("WORLD_SIZE").WithValue(worldSize), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterAddr). + WithValue(masterAddr), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterPort). + WithValue(masterPort), + *corev1ac.EnvVar(). + WithName("MASTER_ADDR").WithValue(masterAddr), + *corev1ac.EnvVar(). + WithName("MASTER_PORT").WithValue(masterPort), + *corev1ac.EnvVar(). 
+ WithName("WORLD_SIZE").WithValue(worldSize), *corev1ac.EnvVar().WithName("RANK").WithValueFrom( corev1ac.EnvVarSource().WithFieldRef( - corev1ac.ObjectFieldSelector().WithFieldPath(constants.JobCompletionIndexFieldPath), + corev1ac.ObjectFieldSelector().WithFieldPath( + constants.JobCompletionIndexFieldPath, + ), ), ), ) @@ -548,25 +747,30 @@ func (s *TRLStrategy) Validate( runtimeInfo *runtime.Info, trainJob *trainer.TrainJob, ) (admission.Warnings, field.ErrorList) { - // TRL validation: check that trainer_type subcommand is valid, etc. + // TRL validation is minimal -- config is fully constructed by the SDK. return nil, nil } ``` ### Go Control Plane: Refactored Torch Plugin Dispatch -The `Torch` struct gains a `backends` map, and `EnforceMLPolicy` dispatches by label: +The `Torch` struct gains a `strategies` map, and `EnforceMLPolicy` dispatches by +the `trainer.kubeflow.org/framework` label: ```go // pkg/runtime/framework/plugins/torch/torch.go (modified) type Torch struct { - backends map[string]LLMBackendStrategy + strategies map[string]FrameworkStrategy } -func New(ctx context.Context, c client.Client, fi client.FieldIndexer) (framework.Plugin, error) { +func New( + ctx context.Context, + c client.Client, + fi client.FieldIndexer, +) (framework.Plugin, error) { return &Torch{ - backends: map[string]LLMBackendStrategy{ + strategies: map[string]FrameworkStrategy{ "torchtune": &TorchTuneStrategy{}, "trl": &TRLStrategy{}, }, @@ -574,28 +778,30 @@ func New(ctx context.Context, c client.Client, fi client.FieldIndexer) (framewor } ``` -The dispatch logic in `EnforceMLPolicy` changes from command-sniffing to label lookup: +The dispatch logic in `EnforceMLPolicy` changes from command-sniffing to label +lookup: ```go -func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error { +func (t *Torch) EnforceMLPolicy( + info *runtime.Info, + trainJob *trainer.TrainJob, +) error { // ... 
(existing common logic: numNodes, numProcPerNode, PET_NNODES, // PET_NPROC_PER_NODE, PET_NODE_RANK — unchanged) ... - // NEW: label-based dispatch replaces command-sniffing - framework := info.Labels[constants.RuntimeFrameworkLabel] // "trainer.kubeflow.org/framework" - if strategy, ok := t.backends[framework]; ok { - if err := strategy.EnforceCommand(info, trainJob, trainerContainer); err != nil { - return err - } - } else { - // Default: standard torchrun path (PET_MASTER_ADDR, PET_MASTER_PORT) - apply.UpsertEnvVars(&trainerContainer.Env, - *corev1ac.EnvVar().WithName(constants.TorchEnvMasterAddr).WithValue(...), - *corev1ac.EnvVar().WithName(constants.TorchEnvMasterPort).WithValue(...), - ) + // Label-based dispatch replaces command-sniffing. + fw := info.Labels[constants.FrameworkLabel] + if strategy, ok := t.strategies[fw]; ok { + return strategy.EnforceCommand(info, trainJob, trainerContainer) } - // ... (existing: add container port) ... + // Default: standard torchrun path (PET_MASTER_ADDR, PET_MASTER_PORT). + apply.UpsertEnvVars(&trainerContainer.Env, + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterAddr).WithValue(masterAddr), + *corev1ac.EnvVar(). + WithName(constants.TorchEnvMasterPort).WithValue(masterPort), + ) return nil } ``` @@ -603,12 +809,15 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) The same pattern applies to `Validate`: ```go -func (t *Torch) Validate(ctx context.Context, runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) { +func (t *Torch) Validate( + ctx context.Context, + runtimeInfo *runtime.Info, + _, newObj *trainer.TrainJob, +) (admission.Warnings, field.ErrorList) { // ... (existing common validation: numProcPerNode, reserved envs) ... 
- // NEW: label-based dispatch replaces command-sniffing - framework := runtimeInfo.Labels[constants.RuntimeFrameworkLabel] - if strategy, ok := t.backends[framework]; ok { + fw := runtimeInfo.Labels[constants.FrameworkLabel] + if strategy, ok := t.strategies[fw]; ok { warnings, errs := strategy.Validate(runtimeInfo, newObj) allErrs = append(allErrs, errs...) return warnings, allErrs @@ -622,10 +831,11 @@ func (t *Torch) Validate(ctx context.Context, runtimeInfo *runtime.Info, _, newO ```go // pkg/constants/constants.go (addition) -// RuntimeFrameworkLabel is the label on ClusterTrainingRuntime manifests -// that identifies which LLM framework the runtime belongs to. -// Existing manifests already use this label (e.g., "torchtune"). -const RuntimeFrameworkLabel string = "trainer.kubeflow.org/framework" +// FrameworkLabel is the label on ClusterTrainingRuntime manifests that +// identifies which framework the runtime belongs to. +// Existing manifests already use this label (e.g., "torchtune", "torch", +// "deepspeed", "jax", "mlx", "xgboost"). +const FrameworkLabel string = "trainer.kubeflow.org/framework" ``` ### TRL Container Image @@ -643,12 +853,12 @@ RUN pip install --no-cache-dir \ ENTRYPOINT ["trl"] ``` -The image is published as `ghcr.io/kubeflow/trainer/trl-trainer` alongside the existing +Published as `ghcr.io/kubeflow/trainer/trl-trainer` alongside the existing `ghcr.io/kubeflow/trainer/torchtune-trainer`. 
-### TRL `ClusterTrainingRuntime` Manifest +### TRL ClusterTrainingRuntime Manifests -Example runtime for Llama 3.2 1B SFT with TRL (modeled on the existing TorchTune runtime): +Example runtime for Llama 3.2 1B SFT with TRL: ```yaml apiVersion: trainer.kubeflow.org/v1alpha1 @@ -744,60 +954,101 @@ spec: name: initializer ``` -### SDK Usage Example +--- -End-to-end TRL SFT fine-tuning from the Python SDK: +## User-Facing API Examples + +### TRL SFT Fine-Tuning + +Using KEP-285's `BaseTrainer` interface directly: ```python -from kubeflow.trainer import TrainerClient, types +from kubeflow.trainer import TrainerClient, TRLTrainer, TRLTrainerType, RuntimeConfig +from kubeflow.trainer.types import Initializer, HuggingFaceModelInitializer, HuggingFaceDatasetInitializer client = TrainerClient() +# Runtime auto-discovered via trainer.kubeflow.org/framework: trl client.train( - runtime="trl-llama3.2-1b", - initializer=types.Initializer( - model=types.HuggingFaceModelInitializer( + initializer=Initializer( + model=HuggingFaceModelInitializer( storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct", ), - dataset=types.HuggingFaceDatasetInitializer( + dataset=HuggingFaceDatasetInitializer( storage_uri="hf://tatsu-lab/alpaca", ), ), - trainer=types.BuiltinTrainer( - config=types.TRLConfig( - trainer_type=types.TRLTrainerType.SFT, - num_train_epochs=3, - per_device_train_batch_size=4, - learning_rate=2e-5, - bf16=True, - gradient_checkpointing=True, - use_peft=True, - lora_r=16, - lora_alpha=32, - ), + trainer=TRLTrainer( + trainer_type=TRLTrainerType.SFT, + num_train_epochs=3, + per_device_train_batch_size=4, + learning_rate=2e-5, + bf16=True, + gradient_checkpointing=True, + use_peft=True, + lora_r=16, + lora_alpha=32, + ), + runtime_config=RuntimeConfig( + packages=["flash-attn"], ), ) ``` -For DPO, the `trainer_type` changes and the dataset must be a preference dataset -with chosen/rejected pairs: +### TRL DPO Alignment ```python client.train( - runtime="trl-llama3.2-1b", - 
initializer=types.Initializer(
-        model=types.HuggingFaceModelInitializer(
+    initializer=Initializer(
+        model=HuggingFaceModelInitializer(
             storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct",
         ),
-        dataset=types.HuggingFaceDatasetInitializer(
+        dataset=HuggingFaceDatasetInitializer(
             storage_uri="hf://argilla/ultrafeedback-binarized-preferences",
         ),
     ),
-    trainer=types.BuiltinTrainer(
-        config=types.TRLConfig(
-            trainer_type=types.TRLTrainerType.DPO,
-            learning_rate=1e-6,
-            bf16=True,
+    trainer=TRLTrainer(
+        trainer_type=TRLTrainerType.DPO,
+        learning_rate=1e-6,
+        bf16=True,
+    ),
+)
+```
+
+### TorchTune (Backward Compatible)
+
+Existing TorchTune code continues to work unchanged:
+
+```python
+client.train(
+    runtime="torchtune-llama3.2-1b",
+    initializer=Initializer(
+        model=HuggingFaceModelInitializer(
+            storage_uri="hf://meta-llama/Llama-3.2-1B-Instruct",
+        ),
+        dataset=HuggingFaceDatasetInitializer(
+            storage_uri="hf://tatsu-lab/alpaca",
+        ),
+    ),
+    trainer=TorchTuneTrainer(
+        epochs=3,
+        batch_size=4,
+        peft_config=LoraConfig(lora_rank=16, lora_alpha=32),
+    ),
+)
+```
+
+### Backward Compatible: BuiltinTrainer Still Works
+
+```python
+# This existing code continues to work with no changes.
+client.train(
+    runtime="torchtune-llama3.2-1b",
+    initializer=Initializer(...),
+    trainer=BuiltinTrainer(
+        config=TorchTuneConfig(
+            epochs=3,
+            batch_size=4,
         ),
     ),
 )
@@ -805,14 +1056,83 @@ client.train(
 
 ---
 
+## Implementation Plan
+
+This proposal is scoped for 350 hours (GSoC Large) and can be implemented in phases:
+
+**Phase 1: SDK Foundation (Weeks 1-4)**
+- Add `ConfigTrainer` base class to `kubeflow/sdk`
+- Refactor `TorchTuneConfig` → `TorchTuneTrainer(ConfigTrainer)` with alias
+- Update `KubernetesBackend` to use `ConfigTrainer` interface
+- Update `BuiltinTrainer.config` type to `ConfigTrainer`
+- Unit tests for backward compatibility
+- Coordinate with KEP-285 on `BaseTrainer` integration
+
+**Phase 2: Go Control Plane Refactor (Weeks 5-8)**
+- Add `FrameworkLabel` constant to `pkg/constants/constants.go`
+- Implement `FrameworkStrategy` interface
+- Extract `TorchTuneStrategy` from existing inline code
+- Refactor Torch plugin dispatch from command-sniffing to label lookup
+- Unit tests for strategy dispatch and TorchTune regression
+- Integration tests
+
+**Phase 3: TRL Backend (Weeks 9-14)**
+- Implement `TRLTrainer` in SDK
+- Implement `TRLStrategy` in Go Torch plugin
+- Build TRL container image (`cmd/trainers/trl/`)
+- Create TRL `ClusterTrainingRuntime` manifests
+- E2E tests for TRL SFT on GPU
+- Documentation and examples
+
+**Phase 4: Polish and DPO (Weeks 15-18)**
+- Add DPO support and E2E tests
+- Helm chart additions for TRL runtimes
+- SDK documentation on sdk.kubeflow.org
+- TorchTune regression E2E validation
+
+---
+
+## Test Plan
+
+### Unit Tests (SDK)
+
+- `ConfigTrainer` interface compliance for `TorchTuneTrainer` and `TRLTrainer`
+- `TorchTuneConfig` alias backward compatibility
+- `TRLTrainer.to_args()` produces correct CLI arguments for SFT and DPO
+- `TRLTrainer.validate()` catches invalid configs (e.g., `use_peft=True` without `lora_r`)
+- `BuiltinTrainer(config=TRLTrainer(...))` constructs correctly
+- Runtime auto-discovery
for `supported_frameworks=["trl"]` + +### Unit Tests (Go) + +- `FrameworkStrategy` dispatch: label `"torchtune"` → `TorchTuneStrategy` +- `FrameworkStrategy` dispatch: label `"trl"` → `TRLStrategy` +- `FrameworkStrategy` dispatch: label `"torch"` → default torchrun path +- `TorchTuneStrategy.EnforceCommand()` produces same output as current inline code +- `TRLStrategy.EnforceCommand()` injects correct env vars (`MASTER_ADDR`, `WORLD_SIZE`, etc.) +- `TRLStrategy.Validate()` passes for valid TRL configs + +### Integration Tests + +- TRL TrainJob reconciliation with `ClusterTrainingRuntime` labeled `trl` +- TorchTune regression: existing TorchTune workflows produce identical TrainJobs + +### E2E Tests + +- TRL SFT fine-tuning on GPU +- TRL DPO alignment on GPU +- TorchTune regression on GPU (existing tests) + +--- + ## Risks and Mitigations | Risk | Mitigation | |------|------------| -| TRL CLI changes across versions | Pin version range in requirements.txt; version compat tests | -| TRL uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; accelerate reads `MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`; validated in E2E | -| Multi-node TRL untested at scale | Initial implementation scoped to single-node multi-GPU; multi-node validated with dedicated E2E before GA | -| SDK type widening affects static analysis | TorchTuneConfig is a subtype of LLMBackend; passes type checks | -| Scope creep from adding backends | Scoped to TorchTune + TRL only | -| `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `RuntimeFrameworkLabel` constant; existing manifests already use the label | -| KEP-285 `BaseTrainer` hierarchy merges before this KEP | `LLMBackend` is a separate ABC for config-driven trainers; if `BuiltinTrainer` is deprecated, `LLMBackend` implementations migrate to a config-driven Tier 2 trainer with minimal changes | \ No newline at end of file +| TRL CLI changes across versions | Pin version 
range in container image; version compat tests |
+| TRL uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; validated in E2E |
+| Multi-node TRL untested at scale | Initial scope: single-node multi-GPU; multi-node validated before GA |
+| SDK type widening breaks static analysis | `TorchTuneConfig` alias ensures existing type checks pass |
+| KEP-285 design changes before this KEP lands | `ConfigTrainer` is designed to adapt to either `BaseTrainer` integration or standalone use |
+| Scope creep from adding backends | Scoped to TorchTune + TRL only; other backends follow the same pattern |
+| `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `FrameworkLabel` constant; existing manifests already use the label |

From ff23bbffcbaabbc16f5d352a553fd7c29ecd6eb3 Mon Sep 17 00:00:00 2001
From: Sabari
Date: Tue, 31 Mar 2026 16:05:10 +0530
Subject: [PATCH 07/11] =?UTF-8?q?docs:=20redesign=20KEP-2839=20=E2=80=94?=
 =?UTF-8?q?=20ConfigTrainer=20as=20separate=20ABC=20from=20BaseTrainer?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Based on mentor feedback, ConfigTrainer is now a standalone ABC rather
than a subclass of KEP-285's BaseTrainer. This avoids LSP violations
(dead get_train_func() methods) and allows both hierarchies to evolve
independently.

Key architectural change:
- ConfigTrainer and BaseTrainer are separate ABCs for separate patterns
  (config-driven vs function-based)
- Both accepted through same TrainerClient.train(trainer=...) parameter
  for flat, unified user experience
- No inheritance relationship -- clean separation of concerns

Also adds Alternatives Considered section documenting the unified
hierarchy option and why it was rejected.
--- .../2839-dynamic-llm-trainer/README.md | 329 ++++++++++++------ 1 file changed, 217 insertions(+), 112 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index ecaf0c5552..195eb38893 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -5,7 +5,7 @@ | **Authors** | @NarayanaSabari | | **Status** | Provisional | | **Created** | 2026-02-27 | -| **Updated** | 2026-03-28 | +| **Updated** | 2026-03-31 | | **Reviewers** | @tariq-hasan, @andreyvelich, @Electronic-Waste | | **Tracking** | [kubeflow/trainer#2839](https://github.com/kubeflow/trainer/issues/2839) | @@ -19,8 +19,9 @@ - [Goals](#goals) - [Non-Goals](#non-goals) - [Relationship to KEP-285 (Specialized Trainer Abstractions)](#relationship-to-kep-285-specialized-trainer-abstractions) - - [The ConfigTrainer vs FuncTrainer Question](#the-configtrainer-vs-functrainer-question) - - [Unified Type Hierarchy](#unified-type-hierarchy) + - [Two Fundamentally Different Trainer Patterns](#two-fundamentally-different-trainer-patterns) + - [Why Separate ABCs Instead of a Unified Hierarchy](#why-separate-abcs-instead-of-a-unified-hierarchy) + - [Unified API Entry Point](#unified-api-entry-point) - [Shared Design Points](#shared-design-points) - [Current State Analysis](#current-state-analysis) - [SDK Coupling](#sdk-coupling) @@ -47,6 +48,7 @@ - [TRL DPO Alignment](#trl-dpo-alignment) - [TorchTune (Backward Compatible)](#torchtune-backward-compatible) - [Backward Compatible: BuiltinTrainer Still Works](#backward-compatible-builtintrainer-still-works) +- [Alternatives Considered](#alternatives-considered) - [Implementation Plan](#implementation-plan) - [Test Plan](#test-plan) - [Risks and Mitigations](#risks-and-mitigations) @@ -61,10 +63,11 @@ This KEP introduces a **pluggable config-driven trainer framework** for LLM fine in Kubeflow Trainer. 
It decouples the SDK and Go control plane from TorchTune by introducing: -1. A `ConfigTrainer` base class in the Python SDK that sits within - [KEP-285](https://github.com/kubeflow/sdk/pull/308)'s `BaseTrainer` hierarchy as - the foundation for all **config-driven trainers** (where the framework's own CLI is - the entrypoint, not a user-supplied Python function). +1. A `ConfigTrainer` ABC in the Python SDK — a **separate abstraction** from KEP-285's + `BaseTrainer`, purpose-built for **config-driven trainers** where the framework's + own CLI is the entrypoint (e.g., `trl sft ...`, `tune run ...`). Both ABCs are + accepted through the same `TrainerClient.train(trainer=...)` parameter, giving + data scientists a flat, unified API. 2. A `FrameworkStrategy` interface in the Go Torch plugin that replaces hardcoded command-sniffing with label-based dispatch via `trainer.kubeflow.org/framework`. @@ -73,9 +76,9 @@ introducing: refactored as a backward-compatible implementation. This builds on [KEP-2401](../2401-llm-trainer-v2/README.md), the community consensus on -"Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752), and aligns with -the `BaseTrainer` hierarchy being designed in -[KEP-285](https://github.com/kubeflow/sdk/pull/308). +"Plan 3" in [#2752](https://github.com/kubeflow/trainer/issues/2752), and is designed to +complement [KEP-285](https://github.com/kubeflow/sdk/pull/308)'s function-based trainer +hierarchy. --- @@ -114,8 +117,8 @@ at two coupling points: ## Goals -1. Define a `ConfigTrainer` base class within KEP-285's `BaseTrainer` hierarchy for - config-driven LLM trainers. +1. Define a `ConfigTrainer` ABC in the Python SDK as a separate abstraction for + config-driven LLM trainers, complementing KEP-285's function-based `BaseTrainer`. 2. Refactor `TorchTuneConfig` into `TorchTuneTrainer` implementing `ConfigTrainer` with zero breaking changes to existing workflows. 3. Implement `TRLTrainer` supporting SFT and DPO training algorithms. 
@@ -130,87 +133,100 @@ at two coupling points: 2. CRD schema changes -- operates within existing `.spec.trainer.command`/`.spec.trainer.args`. 3. New Kubernetes resource topologies (e.g., launcher/worker patterns). 4. Deprecating `BuiltinTrainer` or `CustomTrainer` (both remain supported). -5. Implementing function-based trainers (that is KEP-285's Tier 1 scope). +5. Implementing function-based trainers (that is KEP-285's scope). --- ## Relationship to KEP-285 (Specialized Trainer Abstractions) [KEP-285](https://github.com/kubeflow/sdk/pull/308) introduces a `BaseTrainer` ABC -with framework-specific Tier 1 trainers (`TorchTrainer`, `JAXTrainer`, etc.) and -community-contributed Tier 2 extensions. This KEP is designed to integrate directly -into that hierarchy. +with framework-specific trainers (`TorchTrainer`, `JAXTrainer`, etc.) for +**function-based** training — where the user passes a Python `train()` function. +This KEP addresses **config-driven** training — where the framework's own CLI is the +entrypoint. + +### Two Fundamentally Different Trainer Patterns + +| Pattern | Entrypoint | SDK Responsibility | Examples | +|---------|-----------|-------------------|----------| +| **Function-based** (KEP-285) | User's Python `train()` function | Package user code into a container | TorchTrainer, JAXTrainer | +| **Config-driven** (This KEP) | Framework's own CLI binary | Translate config fields into CLI args | TorchTuneTrainer, TRLTrainer | + +These are architecturally distinct: +- Function-based trainers need `get_train_func()`, `get_train_func_args()`, + `packages_to_install` — concepts that don't apply to config-driven trainers. +- Config-driven trainers need `command`, `to_args()`, framework-specific validation + — concepts that don't apply to function-based trainers. 
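To make the config-driven side of this comparison concrete, here is a minimal sketch of how such a trainer might turn its fields into a command and an argument list. `SketchConfigTrainer`, the `demo-cli` command, and the `--name value` flag style are hypothetical illustrations, not names taken from the SDK or from any real framework CLI.

```python
from dataclasses import dataclass, fields
from typing import ClassVar, Optional


@dataclass
class SketchConfigTrainer:
    """Hypothetical stand-in for a config-driven trainer (not the KEP's class)."""

    # CLI entrypoint; a property of the class, so it is excluded from fields().
    command: ClassVar[tuple[str, ...]] = ("demo-cli", "train")

    # Illustrative config fields; None means "not set, do not emit a flag".
    learning_rate: Optional[float] = None
    epochs: Optional[int] = None

    def to_args(self) -> list[str]:
        """Translate every field that is set into a `--name value` pair."""
        args: list[str] = []
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                args += [f"--{f.name}", str(value)]
        return args


trainer = SketchConfigTrainer(learning_rate=2e-5)

# These two lists are what would end up in the container's command and args.
container_command = list(SketchConfigTrainer.command)
container_args = trainer.to_args()
print(container_command, container_args)
```

No user function is packaged anywhere in this flow: the SDK's only job is the translation from typed fields to CLI arguments, which is why `get_train_func()` has nothing meaningful to return for this pattern.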
+ +### Why Separate ABCs Instead of a Unified Hierarchy + +Placing `ConfigTrainer` under `BaseTrainer` would force config-driven trainers to +implement methods that don't apply (`get_train_func()` returning `None`, +`get_train_func_args()` returning `None`). This violates the +[Liskov Substitution Principle](https://en.wikipedia.org/wiki/Liskov_substitution_principle) +— any code calling `trainer.get_train_func()` would need null-checks, and the +interface would carry dead methods. + +Separate ABCs allow each hierarchy to evolve independently: +- KEP-285 can add function-packaging features (e.g., dependency snapshotting) without + affecting config-driven trainers. +- This KEP can add config-driven features (e.g., recipe selection, config file + generation) without polluting function-based trainers. -### The ConfigTrainer vs FuncTrainer Question - -In the KEP-285 review, @andreyvelich -[asked](https://github.com/kubeflow/sdk/pull/308#discussion_r2912976804): - -> "How are we going to refactor the BuiltinTrainer interface once we implement the -> BaseTrainer? And how can we dynamically register new LLM fine-tuning framework -> backends?" - -And @tariq-hasan -[asked](https://github.com/kubeflow/sdk/pull/308#discussion_r2901688930): - -> "How do we handle config-driven trainers for post-training LLM fine-tuning? Do we -> segregate them outside BaseTrainer scope?" 
+``` + BaseTrainer (ABC) ConfigTrainer (ABC) + ├── get_train_func() ├── command (ClassVar) + ├── get_train_func_args() ├── to_args() + ├── get_framework_args() ├── validate() + ├── validate_runtime() └── supported_frameworks + └── supported_frameworks + │ │ + ┌────┴─────┐ ┌─────────┴─────────┐ + │ │ │ │ + Torch JAX TorchTune TRL + Trainer Trainer Trainer Trainer + (KEP-285) (KEP-285) (This KEP) (This KEP) + + + Existing (unchanged, backward compatible): + + CustomTrainer BuiltinTrainer CustomTrainerContainer + (flat dataclass) (config: ConfigTrainer) (image-based) +``` -And @szaher -[proposed](https://github.com/kubeflow/sdk/pull/308#discussion_r2955718123): +### Unified API Entry Point -> "Should I rename the proposal to have two main abstract classes `ConfigTrainer` and -> `FuncTrainer`?" +Despite being separate ABCs, both are accepted through the **same API parameter**. +Data scientists see a single, flat interface: -This KEP answers these questions. There are two fundamentally different trainer -patterns in Kubeflow: +```python +# Function-based (KEP-285) +client.train(trainer=TorchTrainer(func=my_train_fn, num_nodes=4)) -| Pattern | Entrypoint | Examples | KEP | -|---------|-----------|----------|-----| -| **Function-based** (`FuncTrainer`) | User's Python `train()` function | TorchTrainer, JAXTrainer | KEP-285 Tier 1 | -| **Config-driven** (`ConfigTrainer`) | Framework's own CLI | TorchTune, TRL, Unsloth | This KEP (Tier 2) | +# Config-driven (This KEP) — same parameter, same pattern +client.train(trainer=TRLTrainer(trainer_type=SFT, learning_rate=2e-5)) +``` -### Unified Type Hierarchy +The `TrainerClient.train()` signature widens to accept both: +```python +def train( + self, + trainer: BaseTrainer | ConfigTrainer | CustomTrainer + | CustomTrainerContainer | BuiltinTrainer | None = None, + ... 
+) ``` - BaseTrainer (ABC) ← KEP-285 - ├── get_train_func() - ├── get_framework_args() - ├── validate_runtime() - └── supported_frameworks - │ - ┌────────────────┼────────────────┐ - │ │ │ - TorchTrainer JAXTrainer ConfigTrainer (ABC) ← This KEP - (Tier 1) (Tier 1) ├── command - ├── to_args() - └── validate() - │ - ┌───────────────┼───────────────┐ - │ │ │ - TorchTuneTrainer TRLTrainer (future: Unsloth, - (Tier 2) (Tier 2) LlamaFactory) - - - Existing (unchanged, backward compatible): - - CustomTrainer BuiltinTrainer CustomTrainerContainer - (flat dataclass) (config: ConfigTrainer) (image-based) -``` - -`ConfigTrainer` extends `BaseTrainer` by adding: -- A `command` class variable (the CLI entrypoint, e.g., `("trl",)` or `("tune", "run")`) -- A `to_args()` method that translates config into CLI arguments -- A `validate()` method for config-level validation -`ConfigTrainer.get_train_func()` returns `None` (there is no user function -- the -framework's CLI **is** the entrypoint). `ConfigTrainer.get_framework_args()` delegates -to `to_args()`. +This gives the best of both worlds: **clean architecture** (separate ABCs, no LSP +violation, independent evolution) with **flat user experience** (one parameter, one +concept to learn, full IDE autocomplete). ### Shared Design Points - Both KEPs use `trainer.kubeflow.org/framework` as the dispatch key. KEP-285 uses it for SDK runtime auto-discovery; this KEP uses it for Go strategy dispatch. +- Both support runtime auto-discovery via `supported_frameworks`. - Both KEPs are compatible with either keeping or deprecating `BuiltinTrainer`. - If the framework label is [promoted to a Runtime API spec field](https://github.com/kubeflow/sdk/pull/308#discussion_r2894627115) @@ -298,7 +314,7 @@ controllers, no changes to the plugin framework itself. 
│ │ AFTER: │ label = info.Labels │ │ [framework] │ - │ if strategy = backends │ + │ if strategy = strategies │ │ [label]: │ │ → strategy.Enforce() │ │ else: │ @@ -360,6 +376,7 @@ End-to-end for a TRL SFT job: | SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `ConfigTrainer` | | SDK `TorchTuneConfig` | **Refactor** | → `TorchTuneTrainer(ConfigTrainer)`, backward compatible | | SDK `TRLTrainer` | **New** | New config-driven trainer | +| SDK `TrainerClient.train()` | **Widen** | `trainer=` union accepts `ConfigTrainer` directly | | Container images | **New** | `trl-trainer` image | | ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | @@ -369,18 +386,17 @@ End-to-end for a TRL SFT job: ### Python SDK: ConfigTrainer Base Class -`ConfigTrainer` extends KEP-285's `BaseTrainer` for config-driven trainers where the -framework's own CLI is the entrypoint. It bridges the gap between function-based -Tier 1 trainers and the existing `BuiltinTrainer`. +`ConfigTrainer` is a **standalone ABC** purpose-built for config-driven trainers. It +does not extend `BaseTrainer` — they are separate abstractions for separate patterns. ```python -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, ClassVar, Optional +from typing import ClassVar, Optional @dataclass -class ConfigTrainer(BaseTrainer): +class ConfigTrainer(ABC): """Base class for config-driven LLM training backends. Config-driven trainers use the framework's own CLI as the entrypoint @@ -389,13 +405,18 @@ class ConfigTrainer(BaseTrainer): (command, args) pair that the Kubernetes backend writes into the TrainJob CR. - This class sits in KEP-285's BaseTrainer hierarchy as the foundation - for Tier 2 config-driven trainers. + This is a separate ABC from KEP-285's BaseTrainer. Both are accepted + through TrainerClient.train(trainer=...) for a unified user experience. 
+ + Class Attributes: + command: The CLI entrypoint, e.g., ("tune", "run") or ("trl",). + supported_frameworks: Framework identifiers this trainer supports. + Must match values of the `trainer.kubeflow.org/framework` label + on ClusterTrainingRuntime resources. """ - # Subclasses set this to their CLI entrypoint. - # e.g., ("tune", "run") for TorchTune, ("trl",) for TRL. command: ClassVar[tuple[str, ...]] + supported_frameworks: ClassVar[list[str]] # Common fields shared by all config-driven trainers. num_nodes: Optional[int] = None @@ -411,20 +432,27 @@ class ConfigTrainer(BaseTrainer): """Raise ValueError if the config is invalid.""" ... - # --- BaseTrainer interface implementation --- + def validate_runtime(self, runtime: "Runtime") -> None: + """Validate that the given runtime is compatible with this trainer. + + Raises: + ValueError: If the runtime's framework is not in supported_frameworks. + """ + if runtime.trainer.framework not in self.supported_frameworks: + raise ValueError( + f"{type(self).__name__} supports frameworks " + f"{self.supported_frameworks}, but runtime '{runtime.name}' " + f"has framework '{runtime.trainer.framework}'" + ) +``` - def get_train_func(self) -> Optional[Callable]: - """Config-driven trainers have no user function.""" - return None +**Design rationale:** - def get_train_func_args(self) -> Optional[dict]: - """Config-driven trainers have no function args.""" - return None - - def get_framework_args(self) -> dict: - """Delegate to to_args() for CLI argument generation.""" - return {"_config_args": self.to_args()} -``` +- `ConfigTrainer` does not inherit from `BaseTrainer` — avoids dead methods + (`get_train_func() → None`) and LSP violations. +- `supported_frameworks` and `validate_runtime()` mirror KEP-285's pattern for + runtime auto-discovery, ensuring both ABCs work with the same mechanism. +- `command` as a `ClassVar` — it's a property of the trainer *class*, not instances. 
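The runtime check and the label-based auto-discovery described above can be exercised with a small, self-contained sketch. `StubRuntime`, `StubRuntimeTrainer`, `SketchTrainer`, `SketchTRLTrainer`, and `resolve_runtime` are stand-ins invented for illustration; the real SDK types and discovery helper may look different.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar, Optional


@dataclass
class StubRuntimeTrainer:
    """Assumed shape of the runtime's trainer info: just the framework label."""
    framework: str


@dataclass
class StubRuntime:
    """Assumed shape of a runtime: a name plus trainer info."""
    name: str
    trainer: StubRuntimeTrainer


@dataclass
class SketchTrainer(ABC):
    """Trimmed-down copy of the ABC pattern, for demonstration only."""

    command: ClassVar[tuple[str, ...]]
    supported_frameworks: ClassVar[list[str]]

    num_nodes: Optional[int] = None

    @abstractmethod
    def to_args(self) -> list[str]:
        ...

    def validate_runtime(self, runtime: StubRuntime) -> None:
        # Reject runtimes whose framework label this trainer does not support.
        if runtime.trainer.framework not in self.supported_frameworks:
            raise ValueError(
                f"{type(self).__name__} supports frameworks "
                f"{self.supported_frameworks}, but runtime '{runtime.name}' "
                f"has framework '{runtime.trainer.framework}'"
            )


@dataclass
class SketchTRLTrainer(SketchTrainer):
    command: ClassVar[tuple[str, ...]] = ("trl",)
    supported_frameworks: ClassVar[list[str]] = ["trl"]

    def to_args(self) -> list[str]:
        return []


def resolve_runtime(trainer: SketchTrainer, runtimes: list[StubRuntime]) -> StubRuntime:
    """Hypothetical helper: pick the first runtime with a supported framework."""
    for candidate in runtimes:
        if candidate.trainer.framework in trainer.supported_frameworks:
            return candidate
    raise ValueError(f"no runtime found for {trainer.supported_frameworks}")


runtimes = [
    StubRuntime("torchtune-demo", StubRuntimeTrainer("torchtune")),
    StubRuntime("trl-demo", StubRuntimeTrainer("trl")),
]
trainer = SketchTRLTrainer()

chosen = resolve_runtime(trainer, runtimes)
trainer.validate_runtime(chosen)  # passes silently for a matching runtime
print(chosen.name)

try:
    trainer.validate_runtime(runtimes[0])  # torchtune runtime, TRL trainer
except ValueError as err:
    print(err)
```

Keeping `supported_frameworks` on the class rather than on instances means discovery can inspect a trainer type without constructing it, which is one reason a `ClassVar` fits here.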
### Python SDK: TorchTuneTrainer (Refactored) @@ -570,19 +598,43 @@ class TRLTrainer(ConfigTrainer): ### Python SDK: TrainerClient Integration -The `TrainerClient.train()` method gains support for `ConfigTrainer` through KEP-285's -`BaseTrainer` interface. When a `ConfigTrainer` is passed: +The `TrainerClient.train()` signature widens to accept `ConfigTrainer` directly, +alongside KEP-285's `BaseTrainer`: + +```python +class TrainerClient: + + def train( + self, + runtime: Optional[Union[str, "Runtime"]] = None, + initializer: Optional["Initializer"] = None, + trainer: Optional[ + Union[ + "BaseTrainer", # KEP-285: function-based + "ConfigTrainer", # This KEP: config-driven + "CustomTrainer", # Existing + "CustomTrainerContainer", # Existing + "BuiltinTrainer", # Existing + ] + ] = None, + runtime_config: Optional["RuntimeConfig"] = None, # KEP-285 + options: Optional[list] = None, + ) -> str: +``` + +When a `ConfigTrainer` is passed: 1. If `runtime` is `None`, the SDK auto-discovers a runtime by matching the - `trainer.kubeflow.org/framework` label against `supported_frameworks` (using - KEP-285's `_resolve_runtime()` mechanism). + `trainer.kubeflow.org/framework` label against `supported_frameworks`. 2. `validate_runtime()` ensures the runtime's framework label matches. 3. The backend uses `config.command` and `config.to_args()` to build the TrainJob CR. +The backend handler for `ConfigTrainer`: + ```python # In KubernetesBackend — unified handler for ConfigTrainer. 
-def get_trainer_cr( +def get_trainer_cr_from_config_trainer( runtime: types.Runtime, trainer: ConfigTrainer, initializer: Optional[types.Initializer] = None, @@ -608,9 +660,9 @@ def get_trainer_cr( |-------------|--------|---------| | `BuiltinTrainer(config=TorchTuneConfig(...))` | **Works** | `TorchTuneConfig` is an alias for `TorchTuneTrainer` | | `BuiltinTrainer(config=TRLTrainer(...))` | **New** | `BuiltinTrainer.config` type widens to `ConfigTrainer` | +| `client.train(trainer=TRLTrainer(...))` | **New** | `ConfigTrainer` accepted directly in `trainer=` | | `CustomTrainer(func=...)` | **Unchanged** | No modifications | | `CustomTrainerContainer(image=...)` | **Unchanged** | No modifications | -| `TrainerClient.train(trainer=TRLTrainer(...))` | **New** | Direct `BaseTrainer` subclass via KEP-285 | The `BuiltinTrainer.config` field type changes from `TorchTuneConfig` to `ConfigTrainer`. Since `TorchTuneConfig` is a type alias for `TorchTuneTrainer` @@ -747,7 +799,7 @@ func (s *TRLStrategy) Validate( runtimeInfo *runtime.Info, trainJob *trainer.TrainJob, ) (admission.Warnings, field.ErrorList) { - // TRL validation is minimal -- config is fully constructed by the SDK. + // TRL validation is minimal — config is fully constructed by the SDK. return nil, nil } ``` @@ -960,10 +1012,10 @@ spec: ### TRL SFT Fine-Tuning -Using KEP-285's `BaseTrainer` interface directly: +Config-driven trainer passed directly — no wrapper needed: ```python -from kubeflow.trainer import TrainerClient, TRLTrainer, TRLTrainerType, RuntimeConfig +from kubeflow.trainer import TrainerClient, TRLTrainer, TRLTrainerType from kubeflow.trainer.types import Initializer, HuggingFaceModelInitializer, HuggingFaceDatasetInitializer client = TrainerClient() @@ -989,9 +1041,6 @@ client.train( lora_r=16, lora_alpha=32, ), - runtime_config=RuntimeConfig( - packages=["flash-attn"], - ), ) ``` @@ -1056,17 +1105,58 @@ client.train( --- +## Alternatives Considered + +### 1. 
ConfigTrainer as a subclass of BaseTrainer (unified hierarchy) + +Place `ConfigTrainer` under KEP-285's `BaseTrainer` so all trainers share one ABC. + +**Rejected because:** +- Forces `ConfigTrainer` to implement `get_train_func()` and + `get_train_func_args()` returning `None` — dead methods that violate LSP. +- Any code processing `BaseTrainer` must null-check function-based methods, + adding defensive logic throughout the backend. +- Couples the evolution of config-driven and function-based trainers — changes + to one hierarchy's interface affect the other. + +### 2. Keep config-driven trainers inside BuiltinTrainer only (no direct API) + +Keep the current pattern where config-driven trainers are always wrapped in +`BuiltinTrainer(config=...)`. + +**Rejected because:** +- Forces unnecessary nesting: `BuiltinTrainer(config=TRLTrainer(...))` vs + `TRLTrainer(...)` directly. +- Poor IDE discoverability — data scientists must know about `BuiltinTrainer` + as a wrapper concept. +- Doesn't enable runtime auto-discovery (BuiltinTrainer has no + `supported_frameworks`). + +### 3. Standalone LLMBackend ABC (original KEP-2839 design) + +The original proposal used `LLMBackend` as the ABC name with no relationship to +KEP-285. + +**Rejected because:** +- The name `LLMBackend` is too narrow — config-driven trainers could extend beyond + LLM fine-tuning (e.g., XGBoost config-driven training). +- Didn't address the KEP-285 integration questions raised by maintainers. +- `ConfigTrainer` better communicates the pattern (config-driven, trainer hierarchy). 
+ +--- + ## Implementation Plan This proposal is scoped for 350 hours (GSoC Large) and can be implemented in phases: **Phase 1: SDK Foundation (Weeks 1-4)** -- Add `ConfigTrainer` base class to `kubeflow/sdk` +- Add `ConfigTrainer` ABC to `kubeflow/sdk` - Refactor `TorchTuneConfig` → `TorchTuneTrainer(ConfigTrainer)` with alias - Update `KubernetesBackend` to use `ConfigTrainer` interface -- Update `BuiltinTrainer.config` type to `ConfigTrainer` +- Widen `BuiltinTrainer.config` type to `ConfigTrainer` +- Widen `TrainerClient.train()` to accept `ConfigTrainer` directly - Unit tests for backward compatibility -- Coordinate with KEP-285 on `BaseTrainer` integration +- Coordinate with KEP-285 on shared patterns **Phase 2: Go Control Plane Refactor (Weeks 5-8)** - Add `FrameworkLabel` constant to `pkg/constants/constants.go` @@ -1101,6 +1191,7 @@ This proposal is scoped for 350 hours (GSoC Large) and can be implemented in pha - `TRLTrainer.to_args()` produces correct CLI arguments for SFT and DPO - `TRLTrainer.validate()` catches invalid configs (e.g., `use_peft=True` without `lora_r`) - `BuiltinTrainer(config=TRLTrainer(...))` constructs correctly +- `TrainerClient.train(trainer=TRLTrainer(...))` dispatches correctly - Runtime auto-discovery for `supported_frameworks=["trl"]` ### Unit Tests (Go) @@ -1133,6 +1224,20 @@ This proposal is scoped for 350 hours (GSoC Large) and can be implemented in pha | TRL uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; validated in E2E | | Multi-node TRL untested at scale | Initial scope: single-node multi-GPU; multi-node validated before GA | | SDK type widening breaks static analysis | `TorchTuneConfig` alias ensures existing type checks pass | -| KEP-285 design changes before this KEP lands | `ConfigTrainer` is designed to adapt to either `BaseTrainer` integration or standalone use | +| KEP-285 design changes before this KEP lands | `ConfigTrainer` is a separate ABC; no 
dependency on `BaseTrainer` internals | | Scope creep from adding backends | Scoped to TorchTune + TRL only; other backends follow the same pattern | | `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `FrameworkLabel` constant; existing manifests already use the label | + +--- + +## Implementation History + +- **2025-09-19**: KEP-2839 tracking issue opened by @Electronic-Waste +- **2025-07-24**: Community consensus on Plan 3 (pluggable framework) in #2752 +- **2026-01-08**: @andreyvelich reopened issue, looking for contributors +- **2026-02-27**: Initial KEP proposal submitted by @NarayanaSabari +- **2026-03-28**: KEP redesigned to align with KEP-285 BaseTrainer hierarchy + (ConfigTrainer as subclass of BaseTrainer) +- **2026-03-31**: KEP redesigned again based on mentor feedback — ConfigTrainer as + separate ABC from BaseTrainer (clean separation of concerns), with unified API + entry point through TrainerClient.train(trainer=...) From 9570cc6954d02e4aa1a01dc4e52c563f053d585f Mon Sep 17 00:00:00 2001 From: Sabari Date: Tue, 31 Mar 2026 16:05:41 +0530 Subject: [PATCH 08/11] docs: rename ConfigTrainer to LLMTrainer in KEP-2839 --- .../2839-dynamic-llm-trainer/README.md | 84 +++++++++---------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index 195eb38893..b5d5cfb993 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -31,7 +31,7 @@ - [Component Interaction Flow](#component-interaction-flow) - [What Changes vs What Stays](#what-changes-vs-what-stays) - [Design Details](#design-details) - - [Python SDK: ConfigTrainer Base Class](#python-sdk-configtrainer-base-class) + - [Python SDK: LLMTrainer Base Class](#python-sdk-llmtrainer-base-class) - [Python SDK: TorchTuneTrainer (Refactored)](#python-sdk-torchtunetrainer-refactored) - [Python SDK:
TRLTrainer](#python-sdk-trltrainer) - [Python SDK: TrainerClient Integration](#python-sdk-trainerclient-integration) @@ -63,7 +63,7 @@ This KEP introduces a **pluggable config-driven trainer framework** for LLM fine in Kubeflow Trainer. It decouples the SDK and Go control plane from TorchTune by introducing: -1. A `ConfigTrainer` ABC in the Python SDK — a **separate abstraction** from KEP-285's +1. An `LLMTrainer` ABC in the Python SDK — a **separate abstraction** from KEP-285's `BaseTrainer`, purpose-built for **config-driven trainers** where the framework's own CLI is the entrypoint (e.g., `trl sft ...`, `tune run ...`). Both ABCs are accepted through the same `TrainerClient.train(trainer=...)` parameter, giving @@ -117,9 +117,9 @@ at two coupling points: ## Goals -1. Define a `ConfigTrainer` ABC in the Python SDK as a separate abstraction for +1. Define an `LLMTrainer` ABC in the Python SDK as a separate abstraction for config-driven LLM trainers, complementing KEP-285's function-based `BaseTrainer`. -2. Refactor `TorchTuneConfig` into `TorchTuneTrainer` implementing `ConfigTrainer` +2. Refactor `TorchTuneConfig` into `TorchTuneTrainer` implementing `LLMTrainer` with zero breaking changes to existing workflows. 3. Implement `TRLTrainer` supporting SFT and DPO training algorithms. 4. Create TRL container image and `ClusterTrainingRuntime` manifests. @@ -160,7 +160,7 @@ These are architecturally distinct: ### Why Separate ABCs Instead of a Unified Hierarchy -Placing `ConfigTrainer` under `BaseTrainer` would force config-driven trainers to +Placing `LLMTrainer` under `BaseTrainer` would force config-driven trainers to implement methods that don't apply (`get_train_func()` returning `None`, `get_train_func_args()` returning `None`).
This violates the [Liskov Substitution Principle](https://en.wikipedia.org/wiki/Liskov_substitution_principle) @@ -174,7 +174,7 @@ Separate ABCs allow each hierarchy to evolve independently: generation) without polluting function-based trainers. ``` - BaseTrainer (ABC) ConfigTrainer (ABC) + BaseTrainer (ABC) LLMTrainer (ABC) ├── get_train_func() ├── command (ClassVar) ├── get_train_func_args() ├── to_args() ├── get_framework_args() ├── validate() @@ -191,7 +191,7 @@ Separate ABCs allow each hierarchy to evolve independently: Existing (unchanged, backward compatible): CustomTrainer BuiltinTrainer CustomTrainerContainer - (flat dataclass) (config: ConfigTrainer) (image-based) + (flat dataclass) (config: LLMTrainer) (image-based) ``` ### Unified API Entry Point @@ -212,7 +212,7 @@ The `TrainerClient.train()` signature widens to accept both: ```python def train( self, - trainer: BaseTrainer | ConfigTrainer | CustomTrainer + trainer: BaseTrainer | LLMTrainer | CustomTrainer | CustomTrainerContainer | BuiltinTrainer | None = None, ... 
) @@ -373,10 +373,10 @@ End-to-end for a TRL SFT job: | Torch plugin (TorchTune path) | **Refactor** | Extract inline code → `TorchTuneStrategy` | | Torch plugin (dispatch) | **New** | Label-based strategy lookup replaces command-sniffing | | TRL strategy | **New** | `TRLStrategy` for TRL-specific env vars | -| SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `ConfigTrainer` | -| SDK `TorchTuneConfig` | **Refactor** | → `TorchTuneTrainer(ConfigTrainer)`, backward compatible | +| SDK `BuiltinTrainer` | **Widen** | `TorchTuneConfig` → `LLMTrainer` | +| SDK `TorchTuneConfig` | **Refactor** | → `TorchTuneTrainer(LLMTrainer)`, backward compatible | | SDK `TRLTrainer` | **New** | New config-driven trainer | -| SDK `TrainerClient.train()` | **Widen** | `trainer=` union accepts `ConfigTrainer` directly | +| SDK `TrainerClient.train()` | **Widen** | `trainer=` union accepts `LLMTrainer` directly | | Container images | **New** | `trl-trainer` image | | ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | @@ -384,9 +384,9 @@ End-to-end for a TRL SFT job: ## Design Details -### Python SDK: ConfigTrainer Base Class +### Python SDK: LLMTrainer Base Class -`ConfigTrainer` is a **standalone ABC** purpose-built for config-driven trainers. It +`LLMTrainer` is a **standalone ABC** purpose-built for config-driven trainers. It does not extend `BaseTrainer` — they are separate abstractions for separate patterns. ```python @@ -396,7 +396,7 @@ from typing import ClassVar, Optional @dataclass -class ConfigTrainer(ABC): +class LLMTrainer(ABC): """Base class for config-driven LLM training backends. Config-driven trainers use the framework's own CLI as the entrypoint @@ -448,7 +448,7 @@ class ConfigTrainer(ABC): **Design rationale:** -- `ConfigTrainer` does not inherit from `BaseTrainer` — avoids dead methods +- `LLMTrainer` does not inherit from `BaseTrainer` — avoids dead methods (`get_train_func() → None`) and LSP violations. 
- `supported_frameworks` and `validate_runtime()` mirror KEP-285's pattern for runtime auto-discovery, ensuring both ABCs work with the same mechanism. @@ -456,13 +456,13 @@ class ConfigTrainer(ABC): ### Python SDK: TorchTuneTrainer (Refactored) -`TorchTuneConfig` is refactored into `TorchTuneTrainer` implementing `ConfigTrainer`. +`TorchTuneConfig` is refactored into `TorchTuneTrainer` implementing `LLMTrainer`. All existing fields are preserved. `TorchTuneConfig` becomes a type alias for backward compatibility. ```python @dataclass -class TorchTuneTrainer(ConfigTrainer): +class TorchTuneTrainer(LLMTrainer): """TorchTune LLM Trainer configuration. Supports runtimes labeled with trainer.kubeflow.org/framework: torchtune. @@ -507,7 +507,7 @@ class TRLTrainerType(Enum): @dataclass -class TRLTrainer(ConfigTrainer): +class TRLTrainer(LLMTrainer): """TRL LLM Trainer configuration. Supports runtimes labeled with trainer.kubeflow.org/framework: trl. @@ -598,7 +598,7 @@ class TRLTrainer(ConfigTrainer): ### Python SDK: TrainerClient Integration -The `TrainerClient.train()` signature widens to accept `ConfigTrainer` directly, +The `TrainerClient.train()` signature widens to accept `LLMTrainer` directly, alongside KEP-285's `BaseTrainer`: ```python @@ -611,7 +611,7 @@ class TrainerClient: trainer: Optional[ Union[ "BaseTrainer", # KEP-285: function-based - "ConfigTrainer", # This KEP: config-driven + "LLMTrainer", # This KEP: config-driven "CustomTrainer", # Existing "CustomTrainerContainer", # Existing "BuiltinTrainer", # Existing @@ -622,21 +622,21 @@ class TrainerClient: ) -> str: ``` -When a `ConfigTrainer` is passed: +When an `LLMTrainer` is passed: 1. If `runtime` is `None`, the SDK auto-discovers a runtime by matching the `trainer.kubeflow.org/framework` label against `supported_frameworks`. 2. `validate_runtime()` ensures the runtime's framework label matches. 3. The backend uses `config.command` and `config.to_args()` to build the TrainJob CR.
-The backend handler for `ConfigTrainer`: +The backend handler for `LLMTrainer`: ```python -# In KubernetesBackend — unified handler for ConfigTrainer. +# In KubernetesBackend — unified handler for LLMTrainer. -def get_trainer_cr_from_config_trainer( +def get_trainer_cr_from_llm_trainer( runtime: types.Runtime, - trainer: ConfigTrainer, + trainer: LLMTrainer, initializer: Optional[types.Initializer] = None, ) -> models.TrainerV1alpha1Trainer: trainer.validate() @@ -659,14 +659,14 @@ def get_trainer_cr_from_config_trainer( | Existing API | Status | Details | |-------------|--------|---------| | `BuiltinTrainer(config=TorchTuneConfig(...))` | **Works** | `TorchTuneConfig` is an alias for `TorchTuneTrainer` | -| `BuiltinTrainer(config=TRLTrainer(...))` | **New** | `BuiltinTrainer.config` type widens to `ConfigTrainer` | -| `client.train(trainer=TRLTrainer(...))` | **New** | `ConfigTrainer` accepted directly in `trainer=` | +| `BuiltinTrainer(config=TRLTrainer(...))` | **New** | `BuiltinTrainer.config` type widens to `LLMTrainer` | +| `client.train(trainer=TRLTrainer(...))` | **New** | `LLMTrainer` accepted directly in `trainer=` | | `CustomTrainer(func=...)` | **Unchanged** | No modifications | | `CustomTrainerContainer(image=...)` | **Unchanged** | No modifications | The `BuiltinTrainer.config` field type changes from `TorchTuneConfig` to -`ConfigTrainer`. Since `TorchTuneConfig` is a type alias for `TorchTuneTrainer` -which extends `ConfigTrainer`, all existing code continues to work. +`LLMTrainer`. Since `TorchTuneConfig` is a type alias for `TorchTuneTrainer` +which extends `LLMTrainer`, all existing code continues to work. ### Go Control Plane: FrameworkStrategy Interface @@ -1107,12 +1107,12 @@ client.train( ## Alternatives Considered -### 1. ConfigTrainer as a subclass of BaseTrainer (unified hierarchy) +### 1. LLMTrainer as a subclass of BaseTrainer (unified hierarchy) -Place `ConfigTrainer` under KEP-285's `BaseTrainer` so all trainers share one ABC. 
+Place `LLMTrainer` under KEP-285's `BaseTrainer` so all trainers share one ABC. **Rejected because:** -- Forces `ConfigTrainer` to implement `get_train_func()` and +- Forces `LLMTrainer` to implement `get_train_func()` and `get_train_func_args()` returning `None` — dead methods that violate LSP. - Any code processing `BaseTrainer` must null-check function-based methods, adding defensive logic throughout the backend. @@ -1141,7 +1141,7 @@ KEP-285. - The name `LLMBackend` is too narrow — config-driven trainers could extend beyond LLM fine-tuning (e.g., XGBoost config-driven training). - Didn't address the KEP-285 integration questions raised by maintainers. -- `ConfigTrainer` better communicates the pattern (config-driven, trainer hierarchy). +- `LLMTrainer` better communicates the pattern (config-driven, trainer hierarchy). --- @@ -1150,11 +1150,11 @@ KEP-285. This proposal is scoped for 350 hours (GSoC Large) and can be implemented in phases: **Phase 1: SDK Foundation (Weeks 1-4)** -- Add `ConfigTrainer` ABC to `kubeflow/sdk` -- Refactor `TorchTuneConfig` → `TorchTuneTrainer(ConfigTrainer)` with alias -- Update `KubernetesBackend` to use `ConfigTrainer` interface -- Widen `BuiltinTrainer.config` type to `ConfigTrainer` -- Widen `TrainerClient.train()` to accept `ConfigTrainer` directly +- Add `LLMTrainer` ABC to `kubeflow/sdk` +- Refactor `TorchTuneConfig` → `TorchTuneTrainer(LLMTrainer)` with alias +- Update `KubernetesBackend` to use `LLMTrainer` interface +- Widen `BuiltinTrainer.config` type to `LLMTrainer` +- Widen `TrainerClient.train()` to accept `LLMTrainer` directly - Unit tests for backward compatibility - Coordinate with KEP-285 on shared patterns @@ -1186,7 +1186,7 @@ This proposal is scoped for 350 hours (GSoC Large) and can be implemented in pha ### Unit Tests (SDK) -- `ConfigTrainer` interface compliance for `TorchTuneTrainer` and `TRLTrainer` +- `LLMTrainer` interface compliance for `TorchTuneTrainer` and `TRLTrainer` - `TorchTuneConfig` alias 
backward compatibility - `TRLTrainer.to_args()` produces correct CLI arguments for SFT and DPO - `TRLTrainer.validate()` catches invalid configs (e.g., `use_peft=True` without `lora_r`) @@ -1224,7 +1224,7 @@ This proposal is scoped for 350 hours (GSoC Large) and can be implemented in pha | TRL uses accelerate, not torchrun, for distributed | TRLStrategy injects both `PET_*` and standard env vars; validated in E2E | | Multi-node TRL untested at scale | Initial scope: single-node multi-GPU; multi-node validated before GA | | SDK type widening breaks static analysis | `TorchTuneConfig` alias ensures existing type checks pass | -| KEP-285 design changes before this KEP lands | `ConfigTrainer` is a separate ABC; no dependency on `BaseTrainer` internals | +| KEP-285 design changes before this KEP lands | `LLMTrainer` is a separate ABC; no dependency on `BaseTrainer` internals | | Scope creep from adding backends | Scoped to TorchTune + TRL only; other backends follow the same pattern | | `trainer.kubeflow.org/framework` label not a Go constant | KEP adds `FrameworkLabel` constant; existing manifests already use the label | @@ -1237,7 +1237,7 @@ This proposal is scoped for 350 hours (GSoC Large) and can be implemented in pha - **2026-01-08**: @andreyvelich reopened issue, looking for contributors - **2026-02-27**: Initial KEP proposal submitted by @NarayanaSabari - **2026-03-28**: KEP redesigned to align with KEP-285 BaseTrainer hierarchy - (ConfigTrainer as subclass of BaseTrainer) -- **2026-03-31**: KEP redesigned again based on mentor feedback — ConfigTrainer as + (LLMTrainer as subclass of BaseTrainer) +- **2026-03-31**: KEP redesigned again based on mentor feedback — LLMTrainer as separate ABC from BaseTrainer (clean separation of concerns), with unified API entry point through TrainerClient.train(trainer=...) 
From 2f3e675a7d5fb529dde7b45d47a049c379f392cc Mon Sep 17 00:00:00 2001 From: Sabari Date: Tue, 31 Mar 2026 19:57:00 +0530 Subject: [PATCH 09/11] docs: add architecture diagrams to KEP-2839 Add three diagrams: - SDK type hierarchy showing LLMTrainer and BaseTrainer as separate ABCs - End-to-end system architecture from Python SDK to Kubernetes pods - Go Torch plugin strategy dispatch flow --- .../2839-dynamic-llm-trainer/README.md | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index b5d5cfb993..b4b5951718 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -152,6 +152,31 @@ entrypoint. | **Function-based** (KEP-285) | User's Python `train()` function | Package user code into a container | TorchTrainer, JAXTrainer | | **Config-driven** (This KEP) | Framework's own CLI binary | Translate config fields into CLI args | TorchTuneTrainer, TRLTrainer | +``` + ┌─────────────────────────────────────────────────────────────────────┐ + │ TrainerClient.train(trainer=...) 
│ + │ Unified API Entry Point │ + └────────────┬──────────────────────────┬─────────────────────────────┘ + │ │ + ▼ ▼ + ┌────────────────────────┐ ┌──────────────────────────────┐ + │ BaseTrainer (ABC) │ │ LLMTrainer (ABC) │ + │ ────────────────── │ │ ──────────────────── │ + │ KEP-285 │ │ This KEP │ + │ Function-based │ │ Config-driven │ + │ │ │ │ + │ get_train_func() │ │ command: ClassVar │ + │ get_train_func_args() │ │ to_args() → CLI args │ + │ get_framework_args() │ │ validate() │ + │ validate_runtime() │ │ validate_runtime() │ + ├────────────────────────┤ ├──────────────────────────────┤ + │ ┌──────┐ ┌──────────┐│ │ ┌──────────────┐ ┌─────────┐│ + │ │Torch │ │JAX ││ │ │TorchTune │ │TRL ││ + │ │Train.│ │Trainer ││ │ │Trainer │ │Trainer ││ + │ └──────┘ └──────────┘│ │ └──────────────┘ └─────────┘│ + └────────────────────────┘ └──────────────────────────────┘ +``` + These are architecturally distinct: - Function-based trainers need `get_train_func()`, `get_train_func_args()`, `packages_to_install` — concepts that don't apply to config-driven trainers. @@ -322,6 +347,124 @@ controllers, no changes to the plugin framework itself. └─────────────────────────────┘ ``` +### End-to-End System Architecture + +The following diagram shows how a TRL training job flows through the entire system, +from the data scientist's Python call to pods running on Kubernetes: + +``` + Data Scientist + │ + │ client.train(trainer=TRLTrainer(SFT, lr=2e-5, ...)) + ▼ + ┌──────────────────────────────────────────────────────────────────┐ + │ Python SDK │ + │ │ + │ TrainerClient.train(trainer=...) │ + │ │ │ + │ ├─ Is it LLMTrainer? ──yes──► Auto-discover runtime │ + │ │ (or BaseTrainer, by framework label │ + │ │ CustomTrainer, etc.) 
"trl" → trl-llama3.2-1b │ + │ │ │ + │ ▼ │ + │ trainer.validate() │ + │ trainer.validate_runtime(runtime) │ + │ │ │ + │ ▼ │ + │ KubernetesBackend │ + │ ┌─────────────────────────────────────────────────────────┐ │ + │ │ command = list(trainer.command) → ["trl"] │ │ + │ │ args = trainer.to_args(init) → ["sft", "--m..."] │ │ + │ │ Build TrainJob CR with runtimeRef │ │ + │ └─────────────────────────────────────────────────────────┘ │ + └──────────────────────────────┬───────────────────────────────────┘ + │ POST TrainJob CR + ▼ + ┌──────────────────────────────────────────────────────────────────┐ + │ Kubernetes API Server │ + │ │ + │ Webhook Validation │ + │ ┌──────────────────────────────────────────────────────────┐ │ + │ │ Torch Plugin .Validate() │ │ + │ │ label = runtime.Labels["trainer.kubeflow.org/framework"]│ │ + │ │ strategies["trl"] → TRLStrategy.Validate() │ │ + │ └──────────────────────────────────────────────────────────┘ │ + └──────────────────────────────┬───────────────────────────────────┘ + │ + ▼ + ┌──────────────────────────────────────────────────────────────────┐ + │ TrainJob Controller (Go) │ + │ │ + │ Torch Plugin .EnforceMLPolicy() │ + │ ┌──────────────────────────────────────────────────────────┐ │ + │ │ │ │ + │ │ 1. Common (all frameworks): │ │ + │ │ inject PET_NNODES, PET_NPROC_PER_NODE, PET_NODE_RANK│ │ + │ │ │ │ + │ │ 2. Strategy dispatch: │ │ + │ │ label "trl" → TRLStrategy.EnforceCommand() │ │ + │ │ ├─ MASTER_ADDR = │ │ + │ │ ├─ MASTER_PORT = 29500 │ │ + │ │ ├─ WORLD_SIZE = numNodes * numProcPerNode │ │ + │ │ └─ RANK = JOB_COMPLETION_INDEX │ │ + │ │ │ │ + │ │ 3. 
Add container port 29500 │ │ + │ └──────────────────────────────────────────────────────────┘ │ + │ │ + │ SSA Apply → JobSet → ReplicatedJobs │ + └──────────────────────────────┬───────────────────────────────────┘ + │ + ▼ + ┌──────────────────────────────────────────────────────────────────┐ + │ Kubernetes Pods │ + │ │ + │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ + │ │ dataset-init │ │ model-init │ │ trainer (node) │ │ + │ │ ───────────── │ │ ───────────── │ │ ─────────────── │ │ + │ │ Download dataset │→ │ Download model │→ │ trl sft \ │ │ + │ │ from HF Hub │ │ from HF Hub │ │ --model ... │ │ + │ │ │ │ │ │ --dataset ... │ │ + │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ + └──────────────────────────────────────────────────────────────────┘ +``` + +### Go Torch Plugin: Strategy Dispatch + +The core of the Go-side refactor — replacing command-sniffing with label-based +dispatch: + +``` + EnforceMLPolicy(info, trainJob) + │ + ▼ + ┌─────────────────────┐ + │ Common Path │ + │ (all frameworks) │ + │ │ + │ PET_NNODES │ + │ PET_NPROC_PER_NODE│ + │ PET_NODE_RANK │ + └──────────┬──────────┘ + │ + ▼ + info.Labels["trainer.kubeflow.org/framework"] + │ + ┌─────────────────┼─────────────────┐ + │ │ │ + ▼ ▼ ▼ + label="torchtune" label="trl" label="torch" + │ │ (or unknown) + ▼ ▼ │ + ┌─────────────────┐ ┌─────────────────┐ │ + │TorchTuneStrategy│ │ TRLStrategy │ │ + │ │ │ │ │ + │ rdzv_endpoint │ │ MASTER_ADDR │ ▼ + │ recipe selection│ │ MASTER_PORT │ ┌──────────┐ + │ config overrides│ │ WORLD_SIZE │ │ Default │ + │ immutable args │ │ RANK │ │ torchrun │ + └─────────────────┘ └─────────────────┘ └──────────┘ +``` + ### Component Interaction Flow End-to-end for a TRL SFT job: From 76f6b01b3c34adad0219661a582fd27cd5698f48 Mon Sep 17 00:00:00 2001 From: Sabari Date: Tue, 31 Mar 2026 20:59:19 +0530 Subject: [PATCH 10/11] docs: add architecture diagrams to KEP-2839 Add three visual diagrams to the High-Level Design section: - Before vs After: 
side-by-side comparison of SDK and Go coupling points - SDK Type Hierarchy: shows LLMTrainer and BaseTrainer as parallel ABCs feeding into unified TrainerClient.train(trainer=...) - End-to-End Flow: full stack trace of a TRL SFT job from data scientist through SDK, K8s API, Go Torch plugin, down to pods --- .../2839-dynamic-llm-trainer/README.md | 327 +++++++----------- 1 file changed, 119 insertions(+), 208 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index b4b5951718..844a1337b4 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -152,31 +152,6 @@ entrypoint. | **Function-based** (KEP-285) | User's Python `train()` function | Package user code into a container | TorchTrainer, JAXTrainer | | **Config-driven** (This KEP) | Framework's own CLI binary | Translate config fields into CLI args | TorchTuneTrainer, TRLTrainer | -``` - ┌─────────────────────────────────────────────────────────────────────┐ - │ TrainerClient.train(trainer=...) 
│ - │ Unified API Entry Point │ - └────────────┬──────────────────────────┬─────────────────────────────┘ - │ │ - ▼ ▼ - ┌────────────────────────┐ ┌──────────────────────────────┐ - │ BaseTrainer (ABC) │ │ LLMTrainer (ABC) │ - │ ────────────────── │ │ ──────────────────── │ - │ KEP-285 │ │ This KEP │ - │ Function-based │ │ Config-driven │ - │ │ │ │ - │ get_train_func() │ │ command: ClassVar │ - │ get_train_func_args() │ │ to_args() → CLI args │ - │ get_framework_args() │ │ validate() │ - │ validate_runtime() │ │ validate_runtime() │ - ├────────────────────────┤ ├──────────────────────────────┤ - │ ┌──────┐ ┌──────────┐│ │ ┌──────────────┐ ┌─────────┐│ - │ │Torch │ │JAX ││ │ │TorchTune │ │TRL ││ - │ │Train.│ │Trainer ││ │ │Trainer │ │Trainer ││ - │ └──────┘ └──────────┘│ │ └──────────────┘ └─────────┘│ - └────────────────────────┘ └──────────────────────────────┘ -``` - These are architecturally distinct: - Function-based trainers need `get_train_func()`, `get_train_func_args()`, `packages_to_install` — concepts that don't apply to config-driven trainers. @@ -309,199 +284,135 @@ similarly sniffs the entrypoint to decide whether to run `validateTorchTune()`. The change is a **localized refactor** of two coupling points. No new CRDs, no new controllers, no changes to the plugin framework itself. 
-``` - BEFORE AFTER - ┌──────────────┐ ┌──────────────┐ - SDK │BuiltinTrainer│ │BuiltinTrainer│ - │ config: │ │ config: │ - │ TorchTune │ │ Config │ - │ Config │ │ Trainer │ - └──────┬───────┘ └──────┬───────┘ - │ │ - │ hardcoded │ config.command - │ get_args_using_ │ config.to_args() - │ torchtune_config() │ - ▼ ▼ - creates TrainJob CR creates TrainJob CR - │ │ - ┌────────────────────────────────────────────────────────────────────┐ - │ Kubernetes API │ - └──────────────────────────┬─────────────────────────────────────────┘ - │ - Go ▼ - Torch ┌─────────────────────────────┐ - Plugin │ EnforceMLPolicy() │ - │ │ - BEFORE: │ if cmd == ["tune","run"]: │ - │ → TorchTune branch │ - │ else: │ - │ → torchrun branch │ - │ │ - AFTER: │ label = info.Labels │ - │ [framework] │ - │ if strategy = strategies │ - │ [label]: │ - │ → strategy.Enforce() │ - │ else: │ - │ → default torchrun │ - └─────────────────────────────┘ -``` - -### End-to-End System Architecture - -The following diagram shows how a TRL training job flows through the entire system, -from the data scientist's Python call to pods running on Kubernetes: +#### Before vs After ``` - Data Scientist - │ - │ client.train(trainer=TRLTrainer(SFT, lr=2e-5, ...)) - ▼ - ┌──────────────────────────────────────────────────────────────────┐ - │ Python SDK │ - │ │ - │ TrainerClient.train(trainer=...) │ - │ │ │ - │ ├─ Is it LLMTrainer? ──yes──► Auto-discover runtime │ - │ │ (or BaseTrainer, by framework label │ - │ │ CustomTrainer, etc.) 
"trl" → trl-llama3.2-1b │ - │ │ │ - │ ▼ │ - │ trainer.validate() │ - │ trainer.validate_runtime(runtime) │ - │ │ │ - │ ▼ │ - │ KubernetesBackend │ - │ ┌─────────────────────────────────────────────────────────┐ │ - │ │ command = list(trainer.command) → ["trl"] │ │ - │ │ args = trainer.to_args(init) → ["sft", "--m..."] │ │ - │ │ Build TrainJob CR with runtimeRef │ │ - │ └─────────────────────────────────────────────────────────┘ │ - └──────────────────────────────┬───────────────────────────────────┘ - │ POST TrainJob CR - ▼ - ┌──────────────────────────────────────────────────────────────────┐ - │ Kubernetes API Server │ - │ │ - │ Webhook Validation │ - │ ┌──────────────────────────────────────────────────────────┐ │ - │ │ Torch Plugin .Validate() │ │ - │ │ label = runtime.Labels["trainer.kubeflow.org/framework"]│ │ - │ │ strategies["trl"] → TRLStrategy.Validate() │ │ - │ └──────────────────────────────────────────────────────────┘ │ - └──────────────────────────────┬───────────────────────────────────┘ - │ - ▼ - ┌──────────────────────────────────────────────────────────────────┐ - │ TrainJob Controller (Go) │ - │ │ - │ Torch Plugin .EnforceMLPolicy() │ - │ ┌──────────────────────────────────────────────────────────┐ │ - │ │ │ │ - │ │ 1. Common (all frameworks): │ │ - │ │ inject PET_NNODES, PET_NPROC_PER_NODE, PET_NODE_RANK│ │ - │ │ │ │ - │ │ 2. Strategy dispatch: │ │ - │ │ label "trl" → TRLStrategy.EnforceCommand() │ │ - │ │ ├─ MASTER_ADDR = │ │ - │ │ ├─ MASTER_PORT = 29500 │ │ - │ │ ├─ WORLD_SIZE = numNodes * numProcPerNode │ │ - │ │ └─ RANK = JOB_COMPLETION_INDEX │ │ - │ │ │ │ - │ │ 3. 
Add container port 29500 │ │ - │ └──────────────────────────────────────────────────────────┘ │ - │ │ - │ SSA Apply → JobSet → ReplicatedJobs │ - └──────────────────────────────┬───────────────────────────────────┘ - │ - ▼ - ┌──────────────────────────────────────────────────────────────────┐ - │ Kubernetes Pods │ - │ │ - │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ - │ │ dataset-init │ │ model-init │ │ trainer (node) │ │ - │ │ ───────────── │ │ ───────────── │ │ ─────────────── │ │ - │ │ Download dataset │→ │ Download model │→ │ trl sft \ │ │ - │ │ from HF Hub │ │ from HF Hub │ │ --model ... │ │ - │ │ │ │ │ │ --dataset ... │ │ - │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ - └──────────────────────────────────────────────────────────────────┘ + BEFORE (hardcoded) AFTER (pluggable) + ══════════════════ ═════════════════ + + ┌─────────────────────┐ ┌─────────────────────┐ + │ Python SDK │ │ Python SDK │ + │ │ │ │ + │ BuiltinTrainer │ │ BuiltinTrainer │ + │ config: │ │ config: │ + │ TorchTuneConfig │ ← only option │ LLMTrainer (ABC) │ ← pluggable + │ │ │ ├─ TorchTuneTrainer│ + │ get_args_using_ │ │ └─ TRLTrainer │ + │ torchtune_config() │ ← hardcoded │ │ + └──────────┬───────────┘ │ config.command │ + │ │ config.to_args() │ ← generic + ───────────┼──────────── └──────────┬───────────┘ + │ TrainJob CR │ TrainJob CR + ┌──────────▼───────────┐ ┌──────────▼───────────┐ + │ Go Torch Plugin │ │ Go Torch Plugin │ + │ │ │ │ + │ if cmd == ["tune", │ │ label = info.Labels │ + │ "run"]: │ │ ["framework"] │ + │ → TorchTune │ ← cmd sniffing │ strategies[label] │ + │ else: │ │ .EnforceCommand() │ ← label dispatch + │ → torchrun │ │ │ + └──────────────────────┘ └──────────────────────┘ ``` -### Go Torch Plugin: Strategy Dispatch - -The core of the Go-side refactor — replacing command-sniffing with label-based -dispatch: +#### SDK Type Hierarchy ``` - EnforceMLPolicy(info, trainJob) - │ - ▼ - ┌─────────────────────┐ - │ Common Path │ - │ (all 
frameworks) │ - │ │ - │ PET_NNODES │ - │ PET_NPROC_PER_NODE│ - │ PET_NODE_RANK │ - └──────────┬──────────┘ - │ - ▼ - info.Labels["trainer.kubeflow.org/framework"] - │ - ┌─────────────────┼─────────────────┐ - │ │ │ - ▼ ▼ ▼ - label="torchtune" label="trl" label="torch" - │ │ (or unknown) - ▼ ▼ │ - ┌─────────────────┐ ┌─────────────────┐ │ - │TorchTuneStrategy│ │ TRLStrategy │ │ - │ │ │ │ │ - │ rdzv_endpoint │ │ MASTER_ADDR │ ▼ - │ recipe selection│ │ MASTER_PORT │ ┌──────────┐ - │ config overrides│ │ WORLD_SIZE │ │ Default │ - │ immutable args │ │ RANK │ │ torchrun │ - └─────────────────┘ └─────────────────┘ └──────────┘ + ┌──────────────────────────────────────────────────────────────────────────┐ + │ TrainerClient.train(trainer=...) │ + │ │ + │ Accepts ANY of these — unified API, separate abstractions: │ + └──┬────────────┬────────────────┬────────────────────┬───────────────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ + BaseTrainer LLMTrainer CustomTrainer BuiltinTrainer + (KEP-285) (This KEP) (existing) (existing) + func-based config-driven user function wraps LLMTrainer + │ │ + │ ├── TorchTuneTrainer + │ │ command: ("tune", "run") + │ │ framework: torchtune + │ │ + │ └── TRLTrainer + │ command: ("trl",) + │ framework: trl + │ + ├── TorchTrainer + │ func: user train() + │ framework: torch + │ + └── JAXTrainer + func: user train() + framework: jax ``` -### Component Interaction Flow - -End-to-end for a TRL SFT job: +#### End-to-End Flow: TRL SFT Job ``` -1. User: TrainerClient.train( - trainer=TRLTrainer(trainer_type=SFT, ...), - runtime="trl-llama3.2-1b") - - -- OR with auto-discovery -- - - User: TrainerClient.train( - trainer=TRLTrainer(trainer_type=SFT, ...)) - # SDK finds runtime with label trainer.kubeflow.org/framework: trl - -2. SDK: TRLTrainer.validate() → ok - TRLTrainer.command → ("trl",) - TRLTrainer.to_args() → ["sft", "--model_name_or_path", ...] - Build TrainJob CR with: - runtimeRef: { name: "trl-llama3.2-1b" } - trainer: { command: ["trl"], args: ["sft", ...] 
} - -3. K8s: Webhook validates TrainJob - Torch plugin Validate() → label=trl → TRLStrategy.Validate() - -4. Go: TrainJob controller reconciles: - Torch EnforceMLPolicy(): - a) Common: set PET_NNODES, PET_NPROC_PER_NODE, PET_NODE_RANK - b) Label "trl" → TRLStrategy.EnforceCommand(): - inject PET_MASTER_ADDR, PET_MASTER_PORT - inject MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK - c) Add container port - -5. K8s: Controller SSA → JobSet → ReplicatedJobs → Pods - Init: dataset-initializer downloads dataset - Init: model-initializer downloads model - Main: trl sft --model_name_or_path /workspace/model ... + ┌──────────────────────────────────────────────────────────────────────┐ + │ DATA SCIENTIST │ + │ │ + │ client.train( │ + │ trainer=TRLTrainer(trainer_type=SFT, learning_rate=2e-5, ...), │ + │ initializer=Initializer(model=HF("llama-3.2-1b"), dataset=...) │ + │ ) │ + └──────────────────────────────┬───────────────────────────────────────┘ + │ + ▼ + ┌──────────────────────────────────────────────────────────────────────┐ + │ PYTHON SDK │ + │ │ + │ 1. Auto-discover runtime: list_runtimes() │ + │ → filter by label trainer.kubeflow.org/framework: trl │ + │ → selects "trl-llama3.2-1b" │ + │ │ + │ 2. TRLTrainer.validate() → ok │ + │ │ + │ 3. Build TrainJob CR: │ + │ command: ["trl"] ← from TRLTrainer.command │ + │ args: ["sft", "--model_name_or_path", "/workspace/model", ...] 
│ + │ ← from TRLTrainer.to_args() │ + │ runtimeRef: "trl-llama3.2-1b" │ + └──────────────────────────────┬───────────────────────────────────────┘ + │ kubectl apply + ▼ + ┌──────────────────────────────────────────────────────────────────────┐ + │ KUBERNETES API SERVER │ + │ │ + │ Webhook → Torch plugin Validate() │ + │ → label "trl" → TRLStrategy.Validate() → ok │ + └──────────────────────────────┬───────────────────────────────────────┘ + │ reconcile + ▼ + ┌──────────────────────────────────────────────────────────────────────┐ + │ GO TORCH PLUGIN — EnforceMLPolicy() │ + │ │ + │ Common (all frameworks): │ + │ PET_NNODES=1, PET_NPROC_PER_NODE=auto, PET_NODE_RANK=... │ + │ │ + │ Label dispatch: │ + │ strategies["trl"] → TRLStrategy.EnforceCommand(): │ + │ + PET_MASTER_ADDR=-node-0-0. │ + │ + PET_MASTER_PORT=29500 │ + │ + MASTER_ADDR=-node-0-0. ← accelerate compat │ + │ + MASTER_PORT=29500 │ + │ + WORLD_SIZE= │ + │ + RANK= │ + └──────────────────────────────┬───────────────────────────────────────┘ + │ SSA + ▼ + ┌──────────────────────────────────────────────────────────────────────┐ + │ KUBERNETES PODS │ + │ │ + │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────┐ │ + │ │ dataset-init │ │ model-init │ │ trainer node │ │ + │ │ ──────────────── │ │ ──────────────── │ │ ─────────────────── │ │ + │ │ hf://tatsu-lab/ │→ │ hf://meta-llama/ │→ │ trl sft \ │ │ + │ │ alpaca │ │ Llama-3.2-1B │ │ --model ... \ │ │ + │ │ │ │ │ │ --dataset ... 
\ │ │ + │ │ /workspace/ │ │ /workspace/ │ │ --bf16 │ │ + │ │ dataset/ │ │ model/ │ │ --lora_r 16 │ │ + │ └─────────────────┘ └─────────────────┘ └─────────────────────┘ │ + └──────────────────────────────────────────────────────────────────────┘ ``` ### What Changes vs What Stays From 30e0d10fcd4567ad47030c028dbe6b2155303cf5 Mon Sep 17 00:00:00 2001 From: Sabari Date: Wed, 1 Apr 2026 10:43:09 +0530 Subject: [PATCH 11/11] docs: add user stories, refactor scope, and trim code in KEP-2839 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make the KEP more compelling for maintainer review: - Add User Stories section with 3 concrete data scientist pain points - Add Why TRL section with feature comparison table - Add Refactor Scope section showing exact file/line counts (~400 lines total to unlock every future LLM backend) - Trim verbose implementation code (to_args body, TRLStrategy body, dispatch code) — keep interfaces and concepts, defer details to LLD - Replace code-heavy Go sections with tables and concise explanations --- .../2839-dynamic-llm-trainer/README.md | 326 ++++++------------ 1 file changed, 113 insertions(+), 213 deletions(-) diff --git a/docs/proposals/2839-dynamic-llm-trainer/README.md b/docs/proposals/2839-dynamic-llm-trainer/README.md index 844a1337b4..648b47a019 100644 --- a/docs/proposals/2839-dynamic-llm-trainer/README.md +++ b/docs/proposals/2839-dynamic-llm-trainer/README.md @@ -113,6 +113,43 @@ at two coupling points: are [asking exactly how config-driven trainers fit in](https://github.com/kubeflow/sdk/pull/308#discussion_r2912976804). This KEP provides that answer. +### User Stories + +**"I want to do DPO alignment, but Kubeflow only supports SFT via TorchTune."** + +A data scientist wants to align a model using preference data (DPO). TorchTune +doesn't support DPO, and there's no way to plug in TRL without modifying the SDK +source code. They fall back to raw YAML or leave Kubeflow entirely. 
+ +**"I want to use a newer model that TorchTune doesn't have recipes for."** + +TorchTune supports 4 models (Llama 3.2 1B/3B, Llama 3.3 70B, Qwen 2.5 1.5B). +When a user tries a model outside this list, the Go validation rejects it. TRL +works with any Hugging Face model out of the box. + +**"I want to switch from TorchTune to TRL without relearning the SDK."** + +A team that started with TorchTune wants to migrate to TRL for its active +development and broader algorithm support. Today this requires understanding +`BuiltinTrainer` internals. With this KEP, it's a one-line change: +`TorchTuneTrainer(...)` → `TRLTrainer(...)`. + +### Why TRL as the First New Backend + +| | TorchTune | TRL | +|--|-----------|-----| +| **Status** | Maintenance mode (July 2025) | Actively maintained by Hugging Face | +| **Algorithms** | SFT only | SFT, DPO, KTO, GRPO, PPO, RLOO | +| **Models** | 4 hardcoded models | Any Hugging Face model | +| **CLI** | `tune run ` | `trl sft \| dpo \| kto \| grpo` | +| **Distributed** | torchrun | accelerate (+ torchrun compat) | +| **PEFT** | Built-in LoRA/QLoRA/DoRA | Via `peft` library (LoRA/QLoRA) | +| **Community** | ~12k GitHub stars | ~13k GitHub stars, 250+ contributors | + +TRL is the most requested alternative in +[#2839](https://github.com/kubeflow/trainer/issues/2839) and aligns with the +Hugging Face ecosystem that most Kubeflow users already use for models and datasets. + --- ## Goals @@ -434,6 +471,40 @@ controllers, no changes to the plugin framework itself. | Container images | **New** | `trl-trainer` image | | ClusterTrainingRuntimes | **New** | TRL-specific runtime manifests | +### Refactor Scope: What This Actually Touches + +This KEP is designed to be **low-risk and minimal**. 
Here is the concrete scope: + +**Python SDK (kubeflow/sdk) — ~200 lines changed across 3 files:** + +| File | Change | Lines | +|------|--------|-------| +| `types/types.py` | Add `LLMTrainer` ABC (~30 lines), rename `TorchTuneConfig` class + alias (~5 lines), add `TRLTrainer` (~60 lines) | ~95 new | +| `backends/kubernetes/utils.py` | Replace `isinstance(config, TorchTuneConfig)` with generic `config.command` / `config.to_args()` | ~20 changed | +| `api/trainer_client.py` | Widen `trainer=` union type | ~3 changed | + +**Go control plane (kubeflow/trainer) — ~150 lines moved, ~100 lines new:** + +| File | Change | Lines | +|------|--------|-------| +| `torch/strategy.go` | New `FrameworkStrategy` interface | ~15 new | +| `torch/torchtune_strategy.go` | **Moved** from `torch.go` (no logic change) | ~80 moved | +| `torch/trl_strategy.go` | New `TRLStrategy` | ~50 new | +| `torch/torch.go` | Replace if/else with `strategies[label]` lookup | ~10 changed | +| `constants/constants.go` | Add `FrameworkLabel` constant | 1 new | + +**Key point:** The TorchTune code path is **moved, not rewritten**. The +`TorchTuneStrategy` wraps the exact same functions (`getRecipeAndConfig`, +`extractOverridesFromRuntime`, `validateTorchTune`) that exist today. Existing +tests continue to pass without modification. + +**New infrastructure:** +- 1 Dockerfile (~10 lines) +- 1 ClusterTrainingRuntime manifest (~70 lines YAML) +- Helm chart additions (~20 lines) + +**Total: ~400 lines of new/changed code to unlock every future LLM backend.** + --- ## Design Details @@ -602,111 +673,29 @@ class TRLTrainer(LLMTrainer): extra_args: Optional[dict[str, str]] = None def to_args(self, initializer=None) -> list[str]: - args = [self.trainer_type.value] # subcommand: "sft", "dpo", etc. - - # Model path: prefer initializer workspace, fall back to config. 
- model_path = self.model_name_or_path - if initializer and initializer.model: - model_path = "/workspace/model" - if model_path: - args.extend(["--model_name_or_path", model_path]) - - # Dataset: prefer initializer workspace, fall back to config. - dataset = self.dataset_name - if initializer and initializer.dataset: - dataset = "/workspace/dataset" - if dataset: - args.extend(["--dataset_name", dataset]) - - if self.learning_rate is not None: - args.extend(["--learning_rate", str(self.learning_rate)]) - if self.num_train_epochs is not None: - args.extend(["--num_train_epochs", str(self.num_train_epochs)]) - if self.per_device_train_batch_size is not None: - args.extend(["--per_device_train_batch_size", - str(self.per_device_train_batch_size)]) - if self.gradient_checkpointing: - args.append("--gradient_checkpointing") - if self.bf16: - args.append("--bf16") - if self.use_peft: - args.append("--use_peft") - if self.lora_r is not None: - args.extend(["--lora_r", str(self.lora_r)]) - if self.lora_alpha is not None: - args.extend(["--lora_alpha", str(self.lora_alpha)]) - if self.lora_target_modules: - args.extend(["--lora_target_modules", self.lora_target_modules]) - - # Pass-through extra args. - if self.extra_args: - for k, v in self.extra_args.items(): - args.extend([f"--{k}", v]) - - return args + # Produces: ["sft", "--model_name_or_path", "/workspace/model", ...] + # Prefers initializer workspace paths over config values. + # Full implementation in LLD. + ... def validate(self) -> None: - if self.use_peft and self.lora_r is None: - raise ValueError("lora_r is required when use_peft=True") + # e.g., lora_r required when use_peft=True + ... 
``` ### Python SDK: TrainerClient Integration -The `TrainerClient.train()` signature widens to accept `LLMTrainer` directly, -alongside KEP-285's `BaseTrainer`: - -```python -class TrainerClient: - - def train( - self, - runtime: Optional[Union[str, "Runtime"]] = None, - initializer: Optional["Initializer"] = None, - trainer: Optional[ - Union[ - "BaseTrainer", # KEP-285: function-based - "LLMTrainer", # This KEP: config-driven - "CustomTrainer", # Existing - "CustomTrainerContainer", # Existing - "BuiltinTrainer", # Existing - ] - ] = None, - runtime_config: Optional["RuntimeConfig"] = None, # KEP-285 - options: Optional[list] = None, - ) -> str: -``` +`TrainerClient.train(trainer=...)` widens to accept `LLMTrainer` directly in the +union type, alongside `BaseTrainer` (KEP-285), `CustomTrainer`, and `BuiltinTrainer`. When a `LLMTrainer` is passed: -1. If `runtime` is `None`, the SDK auto-discovers a runtime by matching the - `trainer.kubeflow.org/framework` label against `supported_frameworks`. -2. `validate_runtime()` ensures the runtime's framework label matches. -3. The backend uses `config.command` and `config.to_args()` to build the TrainJob CR. - -The backend handler for `LLMTrainer`: - -```python -# In KubernetesBackend — unified handler for LLMTrainer. - -def get_trainer_cr_from_llm_trainer( - runtime: types.Runtime, - trainer: LLMTrainer, - initializer: Optional[types.Initializer] = None, -) -> models.TrainerV1alpha1Trainer: - trainer.validate() - - trainer_cr = models.TrainerV1alpha1Trainer() - if trainer.num_nodes: - trainer_cr.num_nodes = trainer.num_nodes - if trainer.resources_per_node: - trainer_cr.resources_per_node = get_resources_per_node( - trainer.resources_per_node - ) - - trainer_cr.command = list(trainer.command) - trainer_cr.args = trainer.to_args(initializer) - return trainer_cr -``` +1. 
**Runtime auto-discovery**: If `runtime` is `None`, the SDK calls + `list_runtimes()` and filters by `trainer.kubeflow.org/framework` matching + `supported_frameworks`. One match → auto-selected. Multiple → error with list. +2. **Validation**: `validate_runtime()` ensures the runtime's framework label matches. +3. **Generic dispatch**: The backend uses `config.command` and `config.to_args()` + to build the TrainJob CR — no `isinstance` checks, no framework-specific code paths. ### Python SDK: Backward Compatibility @@ -802,135 +791,46 @@ environment variables (`MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`) rathe than the `PET_*` variants used by torchrun. The strategy injects both sets for maximum compatibility. -```go -// pkg/runtime/framework/plugins/torch/trl_strategy.go - -type TRLStrategy struct{} +TRL uses `accelerate` for distributed training, which reads standard env vars +(`MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`) rather than the `PET_*` +variants. `TRLStrategy.EnforceCommand()` injects **both sets** for compatibility: -func (s *TRLStrategy) EnforceCommand( - info *runtime.Info, - trainJob *trainer.TrainJob, - container *runtime.Container, -) error { - trainerPS := info.FindPodSetByAncestor(constants.AncestorTrainer) - numNodes := ptr.Deref( - ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1, - ) - masterAddr := fmt.Sprintf( - "%s-%s-0-0.%s", - trainJob.Name, constants.Node, trainJob.Name, - ) - masterPort := fmt.Sprintf("%d", constants.ContainerTrainerPort) - worldSize := fmt.Sprintf("%d", numNodes*numProcPerNode) - - // Inject both PET_* (torchrun compat) and standard env vars - // (accelerate/TRL). - apply.UpsertEnvVars(&container.Env, - *corev1ac.EnvVar(). - WithName(constants.TorchEnvMasterAddr). - WithValue(masterAddr), - *corev1ac.EnvVar(). - WithName(constants.TorchEnvMasterPort). - WithValue(masterPort), - *corev1ac.EnvVar(). - WithName("MASTER_ADDR").WithValue(masterAddr), - *corev1ac.EnvVar(). 
- WithName("MASTER_PORT").WithValue(masterPort), - *corev1ac.EnvVar(). - WithName("WORLD_SIZE").WithValue(worldSize), - *corev1ac.EnvVar().WithName("RANK").WithValueFrom( - corev1ac.EnvVarSource().WithFieldRef( - corev1ac.ObjectFieldSelector().WithFieldPath( - constants.JobCompletionIndexFieldPath, - ), - ), - ), - ) - return nil -} +| Env Var | Source | Purpose | +|---------|--------|---------| +| `PET_MASTER_ADDR` | Existing | torchrun compatibility | +| `PET_MASTER_PORT` | Existing | torchrun compatibility | +| `MASTER_ADDR` | **New** | accelerate/TRL | +| `MASTER_PORT` | **New** | accelerate/TRL | +| `WORLD_SIZE` | **New** | accelerate/TRL | +| `RANK` | **New** | From `JOB_COMPLETION_INDEX` | -func (s *TRLStrategy) Validate( - runtimeInfo *runtime.Info, - trainJob *trainer.TrainJob, -) (admission.Warnings, field.ErrorList) { - // TRL validation is minimal — config is fully constructed by the SDK. - return nil, nil -} -``` +`TRLStrategy.Validate()` is minimal — TRL config is fully constructed by the SDK, +so Go-side validation only checks structural constraints. 
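The env-var table above can be read as a small pure function. The sketch below is a Python stand-in for what `TRLStrategy.EnforceCommand()` is described as injecting; it is not the Go implementation. The function name `trl_env_vars`, its parameters, and the example job name are hypothetical. The address pattern and port 29500 follow the diagrams and the earlier Go listing in this KEP, and the value `node` for `constants.Node` is an assumption.

```python
def trl_env_vars(
    job_name: str,
    num_nodes: int,
    num_proc_per_node: int,
    completion_index: int,
    port: int = 29500,  # assumed value of constants.ContainerTrainerPort
) -> dict[str, str]:
    """Illustrative only: mirrors the env vars listed in the table above."""
    # Assumed pattern: <job>-<constants.Node>-0-0.<job>, with Node == "node".
    master_addr = f"{job_name}-node-0-0.{job_name}"
    master_port = str(port)
    return {
        # torchrun-style variables (already injected today)
        "PET_MASTER_ADDR": master_addr,
        "PET_MASTER_PORT": master_port,
        # standard variables read by accelerate/TRL (new in this KEP)
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": master_port,
        "WORLD_SIZE": str(num_nodes * num_proc_per_node),
        # In the KEP this is a per-pod field reference to the Job completion
        # index; it is passed in as a plain integer here for illustration.
        "RANK": str(completion_index),
    }


# Hypothetical job: 2 nodes with 4 processes each, viewed from the second pod.
env = trl_env_vars("my-job", num_nodes=2, num_proc_per_node=4, completion_index=1)
```

Keeping the derivation free of Kubernetes types in this way would make it easy to unit-test in isolation; whether the Go strategy is factored like this is left to the LLD.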
 ### Go Control Plane: Refactored Torch Plugin Dispatch
 
-The `Torch` struct gains a `strategies` map, and `EnforceMLPolicy` dispatches by
-the `trainer.kubeflow.org/framework` label:
+The `Torch` struct gains a `strategies map[string]FrameworkStrategy` and both
+`EnforceMLPolicy` and `Validate` change from command-sniffing to a 3-line
+label lookup:
 
 ```go
-// pkg/runtime/framework/plugins/torch/torch.go (modified)
-
-type Torch struct {
-	strategies map[string]FrameworkStrategy
-}
-
-func New(
-	ctx context.Context,
-	c client.Client,
-	fi client.FieldIndexer,
-) (framework.Plugin, error) {
-	return &Torch{
-		strategies: map[string]FrameworkStrategy{
-			"torchtune": &TorchTuneStrategy{},
-			"trl":       &TRLStrategy{},
-		},
-	}, nil
+// BEFORE (torch.go:149)
+if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
+	// torchrun path
+} else {
+	// TorchTune path
 }
-```
-
-The dispatch logic in `EnforceMLPolicy` changes from command-sniffing to label
-lookup:
 
-```go
-func (t *Torch) EnforceMLPolicy(
-	info *runtime.Info,
-	trainJob *trainer.TrainJob,
-) error {
-	// ... (existing common logic: numNodes, numProcPerNode, PET_NNODES,
-	// PET_NPROC_PER_NODE, PET_NODE_RANK — unchanged) ...
-
-	// Label-based dispatch replaces command-sniffing.
-	fw := info.Labels[constants.FrameworkLabel]
-	if strategy, ok := t.strategies[fw]; ok {
-		return strategy.EnforceCommand(info, trainJob, trainerContainer)
-	}
-
-	// Default: standard torchrun path (PET_MASTER_ADDR, PET_MASTER_PORT).
-	apply.UpsertEnvVars(&trainerContainer.Env,
-		*corev1ac.EnvVar().
-			WithName(constants.TorchEnvMasterAddr).WithValue(masterAddr),
-		*corev1ac.EnvVar().
-			WithName(constants.TorchEnvMasterPort).WithValue(masterPort),
-	)
-	return nil
+// AFTER
+fw := info.Labels[constants.FrameworkLabel]
+if strategy, ok := t.strategies[fw]; ok {
+	return strategy.EnforceCommand(info, trainJob, trainerContainer)
 }
+// else: default torchrun path (unchanged)
 ```
 
-The same pattern applies to `Validate`:
-
-```go
-func (t *Torch) Validate(
-	ctx context.Context,
-	runtimeInfo *runtime.Info,
-	_, newObj *trainer.TrainJob,
-) (admission.Warnings, field.ErrorList) {
-	// ... (existing common validation: numProcPerNode, reserved envs) ...
-
-	fw := runtimeInfo.Labels[constants.FrameworkLabel]
-	if strategy, ok := t.strategies[fw]; ok {
-		warnings, errs := strategy.Validate(runtimeInfo, newObj)
-		allErrs = append(allErrs, errs...)
-		return warnings, allErrs
-	}
-	return nil, allErrs
-}
-```
+New strategies are registered in the constructor — adding a future backend is
+one line: `"unsloth": &UnslothStrategy{}`.
 
 ### Go Control Plane: New Constant