From 9366887e0ff30822a125300b705f8a946abb50ab Mon Sep 17 00:00:00 2001 From: Kovbo Date: Thu, 28 May 2026 18:56:31 +0000 Subject: [PATCH 1/4] Support PipelineTrainer with ServerlessBackend --- src/art/model.py | 26 ++- src/art/serverless/backend.py | 99 ++++++++- src/art/serverless/client.py | 11 + tests/unit/test_metric_routing.py | 25 ++- ...test_serverless_pipeline_trainer_compat.py | 189 ++++++++++++++++++ 5 files changed, 334 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test_serverless_pipeline_trainer_compat.py diff --git a/src/art/model.py b/src/art/model.py index f499fe1d3..391f79b3f 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -618,15 +618,16 @@ def _get_wandb_run(self) -> Optional["Run"]: # This allows out-of-order logging (e.g., async validation for previous steps). run.define_metric("training_step") run.define_metric("time/wall_clock_sec") + for split in sorted(METRIC_SPLITS): + run.define_metric(f"{split}/step", hidden=True) run.define_metric("reward/*", step_metric="training_step") run.define_metric("loss/*", step_metric="training_step") run.define_metric("throughput/*", step_metric="training_step") run.define_metric("costs/*", step_metric="training_step") run.define_metric("time/*", step_metric="training_step") run.define_metric("data/*", step_metric="training_step") - run.define_metric("train/*", step_metric="training_step") - run.define_metric("val/*", step_metric="training_step") - run.define_metric("test/*", step_metric="training_step") + for split in sorted(METRIC_SPLITS): + run.define_metric(f"{split}/*", step_metric=f"{split}/step") run.define_metric("discarded/*", step_metric="training_step") self._sync_wandb_config(run) return self._wandb_run @@ -654,6 +655,13 @@ def _log_metrics( prefixed = {f"{split}/{k}": v for k, v in metrics.items()} prefixed["training_step"] = step + split_prefixes = { + key.split("/", 1)[0] + for key in prefixed + if "/" in key and key.split("/", 1)[0] in METRIC_SPLITS + } + for split_prefix in split_prefixes: + prefixed[f"{split_prefix}/step"] = step prefixed["time/wall_clock_sec"] = time.time() - self._run_start_time output_dir = self._get_output_dir() @@ -691,11 +699,19 @@ def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None: return for key in keys: - if not key.startswith("costs/"): + first_component = key.split("/", 1)[0] + if key in {"training_step", "time/wall_clock_sec"} or key.endswith("/step"): + continue + if "/" not in key: continue if key in self._wandb_defined_metrics: continue - run.define_metric(key, step_metric="training_step") + step_metric = ( + f"{first_component}/step" + if first_component in METRIC_SPLITS + else "training_step" + ) + run.define_metric(key, step_metric=step_metric, overwrite=True) self._wandb_defined_metrics.add(key) def _route_metrics_and_collect_non_costs( diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 1aa5b1f55..fc0f5cb84 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,7 +1,6 @@ import asyncio import time from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal -import warnings from openai._types import NOT_GIVEN from tqdm import auto as tqdm @@ -197,8 +196,15 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, + loss_fn: Literal["cispo", "ppo"] | None = None, + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, + # KL-penalized advantage adjustment + kl_penalty_coef: float = 0.0, + kl_ref_adapter_path: str | None = None, # RL algorithm settings - ppo: bool = False, + ppo: bool | None = None, epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -213,6 +219,15 @@ async def train( # type: ignore[override] # Experimental parameters kimi_k2_tau: float | None = None, precalculate_logprobs: bool = False, + allow_training_without_logprobs: bool = False, + plot_tensors: bool = False, + truncated_importance_sampling: float | None = None, + scale_learning_rate_by_reward_std_dev: bool = False, + logprob_calculation_chunk_size: int = 1024, + packed_sequence_length: int | None = None, + num_trajectories_learning_rate_multiplier_power: float = 0.0, + # Checkpoint behavior + save_checkpoint: bool = True, # Verbosity verbose: bool = False, ) -> ServerlessTrainResult: @@ -226,7 +241,20 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. - ppo: Whether to use PPO clipping. Defaults to False. + loss_fn: RL loss function. ServerlessBackend supports "cispo" and + "ppo". If unset, the legacy ppo argument is used. + loss_fn_config: Additional loss-function config. Not supported by + ServerlessBackend. + normalize_advantages: Backward-compatible alias for reward std scaling. + When False, ServerlessBackend centers rewards but does not divide + by group reward std dev. + adam_params: Custom optimizer params. Not supported by + ServerlessBackend. + kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. + Defaults to 0.0 (disabled). + kl_ref_adapter_path: Direct filesystem path to a LoRA adapter + checkpoint to use as the KL reference. + ppo: Legacy flag for PPO clipping. Prefer loss_fn="ppo". epsilon: Clip epsilon for importance sampling. Defaults based on ppo. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages @@ -240,6 +268,22 @@ async def train( # type: ignore[override] mask_prob_ratio: Whether to mask probability ratios. Defaults to False. kimi_k2_tau: Tau parameter for Kimi K2 algorithm. precalculate_logprobs: Whether to precalculate logprobs. + allow_training_without_logprobs: Allow training even when no logprobs + are available. Defaults to False. + plot_tensors: Whether to plot training tensors for debugging. + Defaults to False. + truncated_importance_sampling: Truncation threshold for importance + sampling weights. + scale_learning_rate_by_reward_std_dev: Whether to scale learning rate + by reward standard deviation. Defaults to False. + logprob_calculation_chunk_size: Chunk size for logprob calculation. + Defaults to 1024. + packed_sequence_length: Packed sequence length to use for training. + num_trajectories_learning_rate_multiplier_power: Power for learning + rate multiplier based on number of trajectories. + save_checkpoint: Accepted for PipelineTrainer compatibility. Serverless + training currently always saves a trainable checkpoint for the next + inference step. verbose: Whether to print verbose output. Defaults to False. Returns: @@ -252,6 +296,23 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) + if loss_fn is None: + resolved_loss_fn: Literal["cispo", "ppo"] = "ppo" if ppo else "cispo" + else: + resolved_loss_fn = loss_fn + if ppo is not None and ppo != (loss_fn == "ppo"): + raise ValueError("ServerlessBackend got conflicting loss_fn and ppo.") + if resolved_loss_fn not in {"cispo", "ppo"}: + raise ValueError( + "ServerlessBackend only supports loss_fn='cispo' or 'ppo'." + ) + if loss_fn_config is not None: + raise ValueError("ServerlessBackend requires loss_fn_config=None.") + if not normalize_advantages: + scale_rewards = False + if adam_params is not None: + raise ValueError("ServerlessBackend requires adam_params=None.") + _ = save_checkpoint config, dev_config = build_rl_train_configs( learning_rate=learning_rate, @@ -259,12 +320,21 @@ async def train( # type: ignore[override] scale_rewards=scale_rewards, importance_sampling_level=importance_sampling_level, mask_prob_ratio=mask_prob_ratio, - ppo=ppo, + ppo=resolved_loss_fn == "ppo", precalculate_logprobs=precalculate_logprobs, epsilon=epsilon, epsilon_high=epsilon_high, max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, kimi_k2_tau=kimi_k2_tau, + kl_penalty_coef=kl_penalty_coef, + allow_training_without_logprobs=allow_training_without_logprobs, + plot_tensors=plot_tensors, + truncated_importance_sampling=truncated_importance_sampling, + scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev, + logprob_calculation_chunk_size=logprob_calculation_chunk_size, + packed_sequence_length=packed_sequence_length, + num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, + kl_ref_adapter_path=kl_ref_adapter_path, ) # Collect metrics from training @@ -317,18 +387,39 @@ async def _train_model( trajectory_groups=trajectory_groups, experimental_config=ExperimentalTrainingConfig( advantage_balance=dev_config.get("advantage_balance"), + allow_training_without_logprobs=dev_config.get( + "allow_training_without_logprobs" + ), epsilon=dev_config.get("epsilon"), epsilon_high=dev_config.get("epsilon_high"), importance_sampling_level=dev_config.get("importance_sampling_level"), kimi_k2_tau=dev_config.get("kimi_k2_tau"), + kl_penalty_coef=dev_config.get("kl_penalty_coef"), + kl_ref_adapter_path=dev_config.get("kl_ref_adapter_path"), learning_rate=config.learning_rate, + logprob_calculation_chunk_size=dev_config.get( + "logprob_calculation_chunk_size" + ), + loss_fn="ppo" if dev_config.get("ppo") else "cispo", mask_prob_ratio=dev_config.get("mask_prob_ratio"), max_negative_advantage_importance_sampling_weight=dev_config.get( "max_negative_advantage_importance_sampling_weight" ), + normalize_advantages=dev_config.get("scale_rewards"), + num_trajectories_learning_rate_multiplier_power=dev_config.get( + "num_trajectories_learning_rate_multiplier_power" + ), + packed_sequence_length=dev_config.get("packed_sequence_length"), + plot_tensors=dev_config.get("plot_tensors"), ppo=dev_config.get("ppo"), precalculate_logprobs=dev_config.get("precalculate_logprobs"), + scale_learning_rate_by_reward_std_dev=dev_config.get( + "scale_learning_rate_by_reward_std_dev" + ), scale_rewards=dev_config.get("scale_rewards"), + truncated_importance_sampling=dev_config.get( + "truncated_importance_sampling" + ), ), ) after: str | None = None diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 7b92c2276..19d724e7d 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -52,18 +52,29 @@ class DeleteCheckpointsResponse(BaseModel): class ExperimentalTrainingConfig(TypedDict, total=False): advantage_balance: float | None + allow_training_without_logprobs: bool | None epsilon: float | None epsilon_high: float | None importance_sampling_level: ( Literal["token", "sequence", "average", "geometric_average"] | None ) kimi_k2_tau: float | None + kl_penalty_coef: float | None + kl_ref_adapter_path: str | None learning_rate: float | None + logprob_calculation_chunk_size: int | None + loss_fn: Literal["cispo", "ppo"] | None mask_prob_ratio: bool | None max_negative_advantage_importance_sampling_weight: float | None + normalize_advantages: bool | None + num_trajectories_learning_rate_multiplier_power: float | None + packed_sequence_length: int | None + plot_tensors: bool | None ppo: bool | None precalculate_logprobs: bool | None + scale_learning_rate_by_reward_std_dev: bool | None scale_rewards: bool | None + truncated_importance_sampling: float | None class SFTTrainingConfig(TypedDict, total=False): diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index 5a290ebfb..cadca3bf5 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -67,19 +67,22 @@ def test_get_wandb_run_registers_taxonomy_sections(self, tmp_path: Path) -> None assert define_calls == [ (("training_step",), {}), (("time/wall_clock_sec",), {}), + (("test/step",), {"hidden": True}), + (("train/step",), {"hidden": True}), + (("val/step",), {"hidden": True}), (("reward/*",), {"step_metric": "training_step"}), (("loss/*",), {"step_metric": "training_step"}), (("throughput/*",), {"step_metric": "training_step"}), (("costs/*",), {"step_metric": "training_step"}), (("time/*",), {"step_metric": "training_step"}), (("data/*",), {"step_metric": "training_step"}), - (("train/*",), {"step_metric": "training_step"}), - (("val/*",), {"step_metric": "training_step"}), - (("test/*",), {"step_metric": "training_step"}), + (("test/*",), {"step_metric": "test/step"}), + (("train/*",), {"step_metric": "train/step"}), + (("val/*",), {"step_metric": "val/step"}), (("discarded/*",), {"step_metric": "training_step"}), ] - def test_log_metrics_defines_nested_cost_keys_with_training_step( + def test_log_metrics_defines_concrete_wandb_metric_keys_with_training_step( self, tmp_path: Path ) -> None: fake_run = MagicMock() @@ -103,8 +106,10 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step( { "costs/train/sample": 0.1, "costs/cum/train/prefill": 0.2, + "reward": 0.3, + "val/reward": 0.4, }, - split="train", + split="val", step=1, ) @@ -113,16 +118,22 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step( ] assert ( ("costs/train/sample",), - {"step_metric": "training_step"}, + {"step_metric": "training_step", "overwrite": True}, ) in define_calls assert ( ("costs/cum/train/prefill",), - {"step_metric": "training_step"}, + {"step_metric": "training_step", "overwrite": True}, + ) in define_calls + assert ( + ("val/reward",), + {"step_metric": "val/step", "overwrite": True}, ) in define_calls fake_run.log.assert_called_once() logged_metrics = fake_run.log.call_args.args[0] assert logged_metrics["costs/train/sample"] == 0.1 assert logged_metrics["costs/cum/train/prefill"] == 0.2 + assert logged_metrics["val/reward"] == 0.4 + assert logged_metrics["val/step"] == 1 assert logged_metrics["training_step"] == 1 assert "time/wall_clock_sec" in logged_metrics assert fake_run.log.call_args.kwargs == {} diff --git a/tests/unit/test_serverless_pipeline_trainer_compat.py b/tests/unit/test_serverless_pipeline_trainer_compat.py new file mode 100644 index 000000000..fec8d23f7 --- /dev/null +++ b/tests/unit/test_serverless_pipeline_trainer_compat.py @@ -0,0 +1,189 @@ +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.serverless.backend import ServerlessBackend +from art.types import TrainConfig + + +def _make_group() -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=1.0, + messages_and_choices=[ + {"role": "user", "content": "prompt"}, + {"role": "assistant", "content": "answer"}, + ], + ) + ] + ) + + +def _make_backend() -> ServerlessBackend: + with patch("art.serverless.backend.Client") as client_cls: + client = MagicMock() + client.base_url = "http://serverless.test/v1" + client_cls.return_value = client + return ServerlessBackend(api_key="test-key") + + +@pytest.mark.asyncio +async def test_serverless_train_accepts_pipeline_trainer_kwargs() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-pipeline-compat", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + model.entity = "entity" + + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: TrainConfig, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {"loss": 0.25} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=3) # type: ignore[method-assign] + + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group()], + learning_rate=2e-5, + loss_fn="ppo", + normalize_advantages=False, + save_checkpoint=False, + packed_sequence_length=4096, + kl_penalty_coef=0.1, + kl_ref_adapter_path="/tmp/ref-adapter", + allow_training_without_logprobs=True, + plot_tensors=True, + truncated_importance_sampling=2.0, + scale_learning_rate_by_reward_std_dev=True, + logprob_calculation_chunk_size=512, + num_trajectories_learning_rate_multiplier_power=0.5, + verbose=True, + ) + + assert result.step == 3 + assert ( + result.artifact_name == "entity/pipeline-tests/serverless-pipeline-compat:step3" + ) + assert seen["config"].learning_rate == 2e-5 + assert seen["config"].kl_penalty_coef == 0.1 + assert seen["verbose"] is True + assert seen["dev_config"] == { + "advantage_balance": 0.0, + "allow_training_without_logprobs": True, + "importance_sampling_level": "token", + "kl_penalty_coef": 0.1, + "kl_ref_adapter_path": "/tmp/ref-adapter", + "logprob_calculation_chunk_size": 512, + "mask_prob_ratio": False, + "num_trajectories_learning_rate_multiplier_power": 0.5, + "packed_sequence_length": 4096, + "plot_tensors": True, + "ppo": True, + "precalculate_logprobs": False, + "scale_learning_rate_by_reward_std_dev": True, + "scale_rewards": False, + "truncated_importance_sampling": 2.0, + } + + +@pytest.mark.asyncio +async def test_serverless_train_rejects_unsupported_pipeline_kwargs() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-pipeline-rejects", + project="pipeline-tests", + base_model="test-model", + ) + + with pytest.raises(ValueError, match="loss_fn_config=None"): + await backend.train(model, [_make_group()], loss_fn_config={"clip": 0.2}) + + with pytest.raises(ValueError, match="adam_params=None"): + await backend.train(model, [_make_group()], adam_params=object()) + + with pytest.raises(ValueError, match="conflicting loss_fn and ppo"): + await backend.train(model, [_make_group()], loss_fn="ppo", ppo=False) + + +@pytest.mark.asyncio +async def test_serverless_train_model_forwards_experimental_config() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-config-payload", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + + captured: dict[str, Any] = {} + backend._client.training_jobs.create = AsyncMock( # type: ignore[attr-defined] + side_effect=lambda **kwargs: ( + captured.update(kwargs) or SimpleNamespace(id="training-job-id") + ) + ) + + async def events_list(**_kwargs: Any): + yield SimpleNamespace(id="event-id", type="training_ended", data={}) + + backend._client.training_jobs.events.list = events_list # type: ignore[attr-defined] + + async def no_sleep(_seconds: float) -> None: + return None + + with patch("art.serverless.backend.asyncio.sleep", no_sleep): + async for _ in backend._train_model( + model, + [_make_group()], + TrainConfig(learning_rate=7e-6, kl_penalty_coef=0.2), + { + "advantage_balance": 0.3, + "allow_training_without_logprobs": True, + "epsilon": 0.1, + "epsilon_high": 0.2, + "importance_sampling_level": "sequence", + "kimi_k2_tau": 0.4, + "kl_penalty_coef": 0.2, + "kl_ref_adapter_path": "/tmp/ref", + "logprob_calculation_chunk_size": 512, + "mask_prob_ratio": True, + "max_negative_advantage_importance_sampling_weight": 3.0, + "num_trajectories_learning_rate_multiplier_power": 0.5, + "packed_sequence_length": 4096, + "plot_tensors": True, + "ppo": True, + "precalculate_logprobs": True, + "scale_learning_rate_by_reward_std_dev": True, + "scale_rewards": False, + "truncated_importance_sampling": 2.0, + }, + ): + pass + + payload = captured["experimental_config"] + assert payload["learning_rate"] == 7e-6 + assert payload["loss_fn"] == "ppo" + assert payload["normalize_advantages"] is False + assert payload["packed_sequence_length"] == 4096 + assert payload["kl_penalty_coef"] == 0.2 + assert payload["kl_ref_adapter_path"] == "/tmp/ref" + assert payload["allow_training_without_logprobs"] is True + assert payload["scale_learning_rate_by_reward_std_dev"] is True From e96dec2521b6f0489e513bdee3c6175e2f58fbba Mon Sep 17 00:00:00 2001 From: Bohdan Date: Thu, 28 May 2026 14:43:59 -0700 Subject: [PATCH 2/4] Revert metric routing changes --- src/art/model.py | 26 +++++--------------------- tests/unit/test_metric_routing.py | 25 +++++++------------------ 2 files changed, 12 insertions(+), 39 deletions(-) diff --git a/src/art/model.py b/src/art/model.py index 391f79b3f..f499fe1d3 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -618,16 +618,15 @@ def _get_wandb_run(self) -> Optional["Run"]: # This allows out-of-order logging (e.g., async validation for previous steps). run.define_metric("training_step") run.define_metric("time/wall_clock_sec") - for split in sorted(METRIC_SPLITS): - run.define_metric(f"{split}/step", hidden=True) run.define_metric("reward/*", step_metric="training_step") run.define_metric("loss/*", step_metric="training_step") run.define_metric("throughput/*", step_metric="training_step") run.define_metric("costs/*", step_metric="training_step") run.define_metric("time/*", step_metric="training_step") run.define_metric("data/*", step_metric="training_step") - for split in sorted(METRIC_SPLITS): - run.define_metric(f"{split}/*", step_metric=f"{split}/step") + run.define_metric("train/*", step_metric="training_step") + run.define_metric("val/*", step_metric="training_step") + run.define_metric("test/*", step_metric="training_step") run.define_metric("discarded/*", step_metric="training_step") self._sync_wandb_config(run) return self._wandb_run @@ -655,13 +654,6 @@ def _log_metrics( prefixed = {f"{split}/{k}": v for k, v in metrics.items()} prefixed["training_step"] = step - split_prefixes = { - key.split("/", 1)[0] - for key in prefixed - if "/" in key and key.split("/", 1)[0] in METRIC_SPLITS - } - for split_prefix in split_prefixes: - prefixed[f"{split_prefix}/step"] = step prefixed["time/wall_clock_sec"] = time.time() - self._run_start_time output_dir = self._get_output_dir() @@ -699,19 +691,11 @@ def _define_wandb_step_metrics(self, keys: Iterable[str]) -> None: return for key in keys: - first_component = key.split("/", 1)[0] - if key in {"training_step", "time/wall_clock_sec"} or key.endswith("/step"): - continue - if "/" not in key: + if not key.startswith("costs/"): continue if key in self._wandb_defined_metrics: continue - step_metric = ( - f"{first_component}/step" - if first_component in METRIC_SPLITS - else "training_step" - ) - run.define_metric(key, step_metric=step_metric, overwrite=True) + run.define_metric(key, step_metric="training_step") self._wandb_defined_metrics.add(key) def _route_metrics_and_collect_non_costs( diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index cadca3bf5..5a290ebfb 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -67,22 +67,19 @@ def test_get_wandb_run_registers_taxonomy_sections(self, tmp_path: Path) -> None assert define_calls == [ (("training_step",), {}), (("time/wall_clock_sec",), {}), - (("test/step",), {"hidden": True}), - (("train/step",), {"hidden": True}), - (("val/step",), {"hidden": True}), (("reward/*",), {"step_metric": "training_step"}), (("loss/*",), {"step_metric": "training_step"}), (("throughput/*",), {"step_metric": "training_step"}), (("costs/*",), {"step_metric": "training_step"}), (("time/*",), {"step_metric": "training_step"}), (("data/*",), {"step_metric": "training_step"}), - (("test/*",), {"step_metric": "test/step"}), - (("train/*",), {"step_metric": "train/step"}), - (("val/*",), {"step_metric": "val/step"}), + (("train/*",), {"step_metric": "training_step"}), + (("val/*",), {"step_metric": "training_step"}), + (("test/*",), {"step_metric": "training_step"}), (("discarded/*",), {"step_metric": "training_step"}), ] - def test_log_metrics_defines_concrete_wandb_metric_keys_with_training_step( + def test_log_metrics_defines_nested_cost_keys_with_training_step( self, tmp_path: Path ) -> None: fake_run = MagicMock() @@ -106,10 +103,8 @@ def test_log_metrics_defines_concrete_wandb_metric_keys_with_training_step( { "costs/train/sample": 0.1, "costs/cum/train/prefill": 0.2, - "reward": 0.3, - "val/reward": 0.4, }, - split="val", + split="train", step=1, ) @@ -118,22 +113,16 @@ def test_log_metrics_defines_concrete_wandb_metric_keys_with_training_step( ] assert ( ("costs/train/sample",), - {"step_metric": "training_step", "overwrite": True}, + {"step_metric": "training_step"}, ) in define_calls assert ( ("costs/cum/train/prefill",), - {"step_metric": "training_step", "overwrite": True}, - ) in define_calls - assert ( - ("val/reward",), - {"step_metric": "val/step", "overwrite": True}, + {"step_metric": "training_step"}, ) in define_calls fake_run.log.assert_called_once() logged_metrics = fake_run.log.call_args.args[0] assert logged_metrics["costs/train/sample"] == 0.1 assert logged_metrics["costs/cum/train/prefill"] == 0.2 - assert logged_metrics["val/reward"] == 0.4 - assert logged_metrics["val/step"] == 1 assert logged_metrics["training_step"] == 1 assert "time/wall_clock_sec" in logged_metrics assert fake_run.log.call_args.kwargs == {} From 642c85b0164f76b57e885128e4aa35c85cf22977 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Thu, 28 May 2026 22:02:31 +0000 Subject: [PATCH 3/4] Fix PipelineTrainer PR type diagnostics --- src/art/pipeline_trainer/trainer.py | 2 +- src/art/serverless/backend.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index ad38a7f47..721ddd791 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -36,7 +36,7 @@ def _to_async_iterator(iterable: Iterable[T] | AsyncIterator[T]) -> AsyncIterator[T]: """Convert a sync Iterable to an AsyncIterator, or pass through if already async.""" if isinstance(iterable, AsyncIterator): - return cast(AsyncIterator[T], iterable) + return iterable async def _iter(): for item in iterable: diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index fc0f5cb84..82869e013 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,6 +1,7 @@ import asyncio import time from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal +import warnings from openai._types import NOT_GIVEN from tqdm import auto as tqdm From 249d4b5883c03ea7eae2484283c4b7c8810d2ea5 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Fri, 29 May 2026 22:52:25 +0000 Subject: [PATCH 4/4] Pin serverless inference during adapter leases --- src/art/adapter_leases.py | 26 +++++++++ src/art/local/adapter_leases.py | 24 +------- src/art/serverless/backend.py | 16 +++++- tests/unit/test_serverless_adapter_lease.py | 61 +++++++++++++++++++++ 4 files changed, 103 insertions(+), 24 deletions(-) create mode 100644 src/art/adapter_leases.py create mode 100644 tests/unit/test_serverless_adapter_lease.py diff --git a/src/art/adapter_leases.py b/src/art/adapter_leases.py new file mode 100644 index 000000000..933fa6407 --- /dev/null +++ b/src/art/adapter_leases.py @@ -0,0 +1,26 @@ +from contextlib import asynccontextmanager +from contextvars import ContextVar +from typing import AsyncIterator + +_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar( + "art_pinned_inference_steps", + default={}, +) + + +def pinned_inference_step(model_name: str) -> int | None: + return _pinned_inference_steps.get().get(model_name) + + +@asynccontextmanager +async def pin_inference_step( + model_name: str, + step: int, +) -> AsyncIterator[None]: + pinned_steps = dict(_pinned_inference_steps.get()) + pinned_steps[model_name] = step + token = _pinned_inference_steps.set(pinned_steps) + try: + yield + finally: + _pinned_inference_steps.reset(token) diff --git a/src/art/local/adapter_leases.py b/src/art/local/adapter_leases.py index b63313e0c..2cc2f27c3 100644 --- a/src/art/local/adapter_leases.py +++ b/src/art/local/adapter_leases.py @@ -1,31 +1,9 @@ import asyncio from collections import Counter from contextlib import asynccontextmanager -from contextvars import ContextVar from typing import AsyncIterator -_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar( - "art_pinned_inference_steps", - default={}, -) - - -def pinned_inference_step(model_name: str) -> int | None: - return _pinned_inference_steps.get().get(model_name) - - -@asynccontextmanager -async def pin_inference_step( - model_name: str, - step: int, -) -> AsyncIterator[None]: - pinned_steps = dict(_pinned_inference_steps.get()) - pinned_steps[model_name] = step - token = _pinned_inference_steps.set(pinned_steps) - try: - yield - finally: - _pinned_inference_steps.reset(token) +from art.adapter_leases import pin_inference_step, pinned_inference_step class AdapterLeaseManager: diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 82869e013..4ab10742b 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager import time from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal import warnings @@ -6,6 +7,7 @@ from openai._types import NOT_GIVEN from tqdm import auto as tqdm +from art.adapter_leases import pin_inference_step, pinned_inference_step from art.serverless.client import Client, ExperimentalTrainingConfig from .. import dev @@ -144,14 +146,26 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str: model: The model. step: If provided, returns name for specific checkpoint using W&B artifact versioning (e.g., :step5). If None, returns - name for latest checkpoint (default, backwards compatible). + name for the pinned checkpoint when running inside an + adapter_lease, otherwise latest checkpoint. """ assert model.entity is not None, "Model entity is required" + if step is None: + step = pinned_inference_step(model.name) base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}" if step is not None: return f"{base_name}:step{step}" return base_name + @asynccontextmanager + async def adapter_lease( + self, + model: AnyTrainableModel, + step: int, + ) -> AsyncIterator[None]: + async with pin_inference_step(model.name, step): + yield + async def _get_step(self, model: "Model") -> int: if model.trainable: assert model.id is not None, "Model ID is required" diff --git a/tests/unit/test_serverless_adapter_lease.py b/tests/unit/test_serverless_adapter_lease.py new file mode 100644 index 000000000..a5b7cc7d4 --- /dev/null +++ b/tests/unit/test_serverless_adapter_lease.py @@ -0,0 +1,61 @@ +import art +from art.serverless.backend import ServerlessBackend + + +async def test_serverless_adapter_lease_pins_inference_step() -> None: + backend = ServerlessBackend(api_key="test-api-key") + model = art.TrainableModel( + name="test-model", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model._backend = backend + + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model" + ) + + async with backend.adapter_lease(model, 3): + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model:step3" + ) + assert ( + model.get_inference_name(step=4) + == "wandb-artifact:///test-entity/test-project/test-model:step4" + ) + + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model" + ) + + +async def test_serverless_adapter_lease_is_model_scoped() -> None: + backend = ServerlessBackend(api_key="test-api-key") + model_a = art.TrainableModel( + name="model-a", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model_b = art.TrainableModel( + name="model-b", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model_a._backend = backend + model_b._backend = backend + + async with backend.adapter_lease(model_a, 2): + assert ( + model_a.get_inference_name() + == "wandb-artifact:///test-entity/test-project/model-a:step2" + ) + assert ( + model_b.get_inference_name() + == "wandb-artifact:///test-entity/test-project/model-b" + )