diff --git a/src/art/adapter_leases.py b/src/art/adapter_leases.py new file mode 100644 index 000000000..933fa6407 --- /dev/null +++ b/src/art/adapter_leases.py @@ -0,0 +1,26 @@ +from contextlib import asynccontextmanager +from contextvars import ContextVar +from typing import AsyncIterator + +_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar( + "art_pinned_inference_steps", + default={}, +) + + +def pinned_inference_step(model_name: str) -> int | None: + return _pinned_inference_steps.get().get(model_name) + + +@asynccontextmanager +async def pin_inference_step( + model_name: str, + step: int, +) -> AsyncIterator[None]: + pinned_steps = dict(_pinned_inference_steps.get()) + pinned_steps[model_name] = step + token = _pinned_inference_steps.set(pinned_steps) + try: + yield + finally: + _pinned_inference_steps.reset(token) diff --git a/src/art/local/adapter_leases.py b/src/art/local/adapter_leases.py index b63313e0c..2cc2f27c3 100644 --- a/src/art/local/adapter_leases.py +++ b/src/art/local/adapter_leases.py @@ -1,31 +1,9 @@ import asyncio from collections import Counter from contextlib import asynccontextmanager -from contextvars import ContextVar from typing import AsyncIterator -_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar( - "art_pinned_inference_steps", - default={}, -) - - -def pinned_inference_step(model_name: str) -> int | None: - return _pinned_inference_steps.get().get(model_name) - - -@asynccontextmanager -async def pin_inference_step( - model_name: str, - step: int, -) -> AsyncIterator[None]: - pinned_steps = dict(_pinned_inference_steps.get()) - pinned_steps[model_name] = step - token = _pinned_inference_steps.set(pinned_steps) - try: - yield - finally: - _pinned_inference_steps.reset(token) +from art.adapter_leases import pin_inference_step, pinned_inference_step class AdapterLeaseManager: diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index ad38a7f47..721ddd791 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -36,7 +36,7 @@ def _to_async_iterator(iterable: Iterable[T] | AsyncIterator[T]) -> AsyncIterator[T]: """Convert a sync Iterable to an AsyncIterator, or pass through if already async.""" if isinstance(iterable, AsyncIterator): - return cast(AsyncIterator[T], iterable) + return iterable async def _iter(): for item in iterable: diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 1aa5b1f55..4ab10742b 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager import time from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal import warnings @@ -6,6 +7,7 @@ from openai._types import NOT_GIVEN from tqdm import auto as tqdm +from art.adapter_leases import pin_inference_step, pinned_inference_step from art.serverless.client import Client, ExperimentalTrainingConfig from .. import dev @@ -144,14 +146,26 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str: model: The model. step: If provided, returns name for specific checkpoint using W&B artifact versioning (e.g., :step5). If None, returns - name for latest checkpoint (default, backwards compatible). + name for the pinned checkpoint when running inside an + adapter_lease, otherwise latest checkpoint. """ assert model.entity is not None, "Model entity is required" + if step is None: + step = pinned_inference_step(model.name) base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}" if step is not None: return f"{base_name}:step{step}" return base_name + @asynccontextmanager + async def adapter_lease( + self, + model: AnyTrainableModel, + step: int, + ) -> AsyncIterator[None]: + async with pin_inference_step(model.name, step): + yield + async def _get_step(self, model: "Model") -> int: if model.trainable: assert model.id is not None, "Model ID is required" @@ -197,8 +211,15 @@ async def train( # type: ignore[override] *, # Core training parameters learning_rate: float = 5e-6, + loss_fn: Literal["cispo", "ppo"] | None = None, + loss_fn_config: dict | None = None, + normalize_advantages: bool = True, + adam_params: object | None = None, + # KL-penalized advantage adjustment + kl_penalty_coef: float = 0.0, + kl_ref_adapter_path: str | None = None, # RL algorithm settings - ppo: bool = False, + ppo: bool | None = None, epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -213,6 +234,15 @@ async def train( # type: ignore[override] # Experimental parameters kimi_k2_tau: float | None = None, precalculate_logprobs: bool = False, + allow_training_without_logprobs: bool = False, + plot_tensors: bool = False, + truncated_importance_sampling: float | None = None, + scale_learning_rate_by_reward_std_dev: bool = False, + logprob_calculation_chunk_size: int = 1024, + packed_sequence_length: int | None = None, + num_trajectories_learning_rate_multiplier_power: float = 0.0, + # Checkpoint behavior + save_checkpoint: bool = True, # Verbosity verbose: bool = False, ) -> ServerlessTrainResult: @@ -226,7 +256,20 @@ async def train( # type: ignore[override] model: The trainable model to train. trajectory_groups: Batches of trajectories to train on. learning_rate: Learning rate for training. Defaults to 5e-6. - ppo: Whether to use PPO clipping. Defaults to False. + loss_fn: RL loss function. ServerlessBackend supports "cispo" and + "ppo". If unset, the legacy ppo argument is used. + loss_fn_config: Additional loss-function config. Not supported by + ServerlessBackend. + normalize_advantages: Backward-compatible alias for reward std scaling. + When False, ServerlessBackend centers rewards but does not divide + by group reward std dev. + adam_params: Custom optimizer params. Not supported by + ServerlessBackend. + kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. + Defaults to 0.0 (disabled). + kl_ref_adapter_path: Direct filesystem path to a LoRA adapter + checkpoint to use as the KL reference. + ppo: Legacy flag for PPO clipping. Prefer loss_fn="ppo". epsilon: Clip epsilon for importance sampling. Defaults based on ppo. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages @@ -240,6 +283,22 @@ async def train( # type: ignore[override] mask_prob_ratio: Whether to mask probability ratios. Defaults to False. kimi_k2_tau: Tau parameter for Kimi K2 algorithm. precalculate_logprobs: Whether to precalculate logprobs. + allow_training_without_logprobs: Allow training even when no logprobs + are available. Defaults to False. + plot_tensors: Whether to plot training tensors for debugging. + Defaults to False. + truncated_importance_sampling: Truncation threshold for importance + sampling weights. + scale_learning_rate_by_reward_std_dev: Whether to scale learning rate + by reward standard deviation. Defaults to False. + logprob_calculation_chunk_size: Chunk size for logprob calculation. + Defaults to 1024. + packed_sequence_length: Packed sequence length to use for training. + num_trajectories_learning_rate_multiplier_power: Power for learning + rate multiplier based on number of trajectories. + save_checkpoint: Accepted for PipelineTrainer compatibility. Serverless + training currently always saves a trainable checkpoint for the next + inference step. verbose: Whether to print verbose output. Defaults to False. Returns: @@ -252,6 +311,23 @@ async def train( # type: ignore[override] # await model.log(metrics=result.metrics, step=result.step) """ groups_list = list(trajectory_groups) + if loss_fn is None: + resolved_loss_fn: Literal["cispo", "ppo"] = "ppo" if ppo else "cispo" + else: + resolved_loss_fn = loss_fn + if ppo is not None and ppo != (loss_fn == "ppo"): + raise ValueError("ServerlessBackend got conflicting loss_fn and ppo.") + if resolved_loss_fn not in {"cispo", "ppo"}: + raise ValueError( + "ServerlessBackend only supports loss_fn='cispo' or 'ppo'." + ) + if loss_fn_config is not None: + raise ValueError("ServerlessBackend requires loss_fn_config=None.") + if not normalize_advantages: + scale_rewards = False + if adam_params is not None: + raise ValueError("ServerlessBackend requires adam_params=None.") + _ = save_checkpoint config, dev_config = build_rl_train_configs( learning_rate=learning_rate, @@ -259,12 +335,21 @@ async def train( # type: ignore[override] scale_rewards=scale_rewards, importance_sampling_level=importance_sampling_level, mask_prob_ratio=mask_prob_ratio, - ppo=ppo, + ppo=resolved_loss_fn == "ppo", precalculate_logprobs=precalculate_logprobs, epsilon=epsilon, epsilon_high=epsilon_high, max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, kimi_k2_tau=kimi_k2_tau, + kl_penalty_coef=kl_penalty_coef, + allow_training_without_logprobs=allow_training_without_logprobs, + plot_tensors=plot_tensors, + truncated_importance_sampling=truncated_importance_sampling, + scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev, + logprob_calculation_chunk_size=logprob_calculation_chunk_size, + packed_sequence_length=packed_sequence_length, + num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, + kl_ref_adapter_path=kl_ref_adapter_path, ) # Collect metrics from training @@ -317,18 +402,39 @@ async def _train_model( trajectory_groups=trajectory_groups, experimental_config=ExperimentalTrainingConfig( advantage_balance=dev_config.get("advantage_balance"), + allow_training_without_logprobs=dev_config.get( + "allow_training_without_logprobs" + ), epsilon=dev_config.get("epsilon"), epsilon_high=dev_config.get("epsilon_high"), importance_sampling_level=dev_config.get("importance_sampling_level"), kimi_k2_tau=dev_config.get("kimi_k2_tau"), + kl_penalty_coef=dev_config.get("kl_penalty_coef"), + kl_ref_adapter_path=dev_config.get("kl_ref_adapter_path"), learning_rate=config.learning_rate, + logprob_calculation_chunk_size=dev_config.get( + "logprob_calculation_chunk_size" + ), + loss_fn="ppo" if dev_config.get("ppo") else "cispo", mask_prob_ratio=dev_config.get("mask_prob_ratio"), max_negative_advantage_importance_sampling_weight=dev_config.get( "max_negative_advantage_importance_sampling_weight" ), + normalize_advantages=dev_config.get("scale_rewards"), + num_trajectories_learning_rate_multiplier_power=dev_config.get( + "num_trajectories_learning_rate_multiplier_power" + ), + packed_sequence_length=dev_config.get("packed_sequence_length"), + plot_tensors=dev_config.get("plot_tensors"), ppo=dev_config.get("ppo"), precalculate_logprobs=dev_config.get("precalculate_logprobs"), + scale_learning_rate_by_reward_std_dev=dev_config.get( + "scale_learning_rate_by_reward_std_dev" + ), scale_rewards=dev_config.get("scale_rewards"), + truncated_importance_sampling=dev_config.get( + "truncated_importance_sampling" + ), ), ) after: str | None = None diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 7b92c2276..19d724e7d 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -52,18 +52,29 @@ class DeleteCheckpointsResponse(BaseModel): class ExperimentalTrainingConfig(TypedDict, total=False): advantage_balance: float | None + allow_training_without_logprobs: bool | None epsilon: float | None epsilon_high: float | None importance_sampling_level: ( Literal["token", "sequence", "average", "geometric_average"] | None ) kimi_k2_tau: float | None + kl_penalty_coef: float | None + kl_ref_adapter_path: str | None learning_rate: float | None + logprob_calculation_chunk_size: int | None + loss_fn: Literal["cispo", "ppo"] | None mask_prob_ratio: bool | None max_negative_advantage_importance_sampling_weight: float | None + normalize_advantages: bool | None + num_trajectories_learning_rate_multiplier_power: float | None + packed_sequence_length: int | None + plot_tensors: bool | None ppo: bool | None precalculate_logprobs: bool | None + scale_learning_rate_by_reward_std_dev: bool | None scale_rewards: bool | None + truncated_importance_sampling: float | None class SFTTrainingConfig(TypedDict, total=False): diff --git a/tests/unit/test_serverless_adapter_lease.py b/tests/unit/test_serverless_adapter_lease.py new file mode 100644 index 000000000..a5b7cc7d4 --- /dev/null +++ b/tests/unit/test_serverless_adapter_lease.py @@ -0,0 +1,61 @@ +import art +from art.serverless.backend import ServerlessBackend + + +async def test_serverless_adapter_lease_pins_inference_step() -> None: + backend = ServerlessBackend(api_key="test-api-key") + model = art.TrainableModel( + name="test-model", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model._backend = backend + + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model" + ) + + async with backend.adapter_lease(model, 3): + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model:step3" + ) + assert ( + model.get_inference_name(step=4) + == "wandb-artifact:///test-entity/test-project/test-model:step4" + ) + + assert ( + model.get_inference_name() + == "wandb-artifact:///test-entity/test-project/test-model" + ) + + +async def test_serverless_adapter_lease_is_model_scoped() -> None: + backend = ServerlessBackend(api_key="test-api-key") + model_a = art.TrainableModel( + name="model-a", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model_b = art.TrainableModel( + name="model-b", + project="test-project", + entity="test-entity", + base_model="test-base-model", + ) + model_a._backend = backend + model_b._backend = backend + + async with backend.adapter_lease(model_a, 2): + assert ( + model_a.get_inference_name() + == "wandb-artifact:///test-entity/test-project/model-a:step2" + ) + assert ( + model_b.get_inference_name() + == "wandb-artifact:///test-entity/test-project/model-b" + ) diff --git a/tests/unit/test_serverless_pipeline_trainer_compat.py b/tests/unit/test_serverless_pipeline_trainer_compat.py new file mode 100644 index 000000000..fec8d23f7 --- /dev/null +++ b/tests/unit/test_serverless_pipeline_trainer_compat.py @@ -0,0 +1,189 @@ +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from art import TrainableModel, Trajectory, TrajectoryGroup +from art.serverless.backend import ServerlessBackend +from art.types import TrainConfig + + +def _make_group() -> TrajectoryGroup: + return TrajectoryGroup( + [ + Trajectory( + reward=1.0, + messages_and_choices=[ + {"role": "user", "content": "prompt"}, + {"role": "assistant", "content": "answer"}, + ], + ) + ] + ) + + +def _make_backend() -> ServerlessBackend: + with patch("art.serverless.backend.Client") as client_cls: + client = MagicMock() + client.base_url = "http://serverless.test/v1" + client_cls.return_value = client + return ServerlessBackend(api_key="test-key") + + +@pytest.mark.asyncio +async def test_serverless_train_accepts_pipeline_trainer_kwargs() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-pipeline-compat", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + model.entity = "entity" + + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: TrainConfig, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {"loss": 0.25} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=3) # type: ignore[method-assign] + + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group()], + learning_rate=2e-5, + loss_fn="ppo", + normalize_advantages=False, + save_checkpoint=False, + packed_sequence_length=4096, + kl_penalty_coef=0.1, + kl_ref_adapter_path="/tmp/ref-adapter", + allow_training_without_logprobs=True, + plot_tensors=True, + truncated_importance_sampling=2.0, + scale_learning_rate_by_reward_std_dev=True, + logprob_calculation_chunk_size=512, + num_trajectories_learning_rate_multiplier_power=0.5, + verbose=True, + ) + + assert result.step == 3 + assert ( + result.artifact_name == "entity/pipeline-tests/serverless-pipeline-compat:step3" + ) + assert seen["config"].learning_rate == 2e-5 + assert seen["config"].kl_penalty_coef == 0.1 + assert seen["verbose"] is True + assert seen["dev_config"] == { + "advantage_balance": 0.0, + "allow_training_without_logprobs": True, + "importance_sampling_level": "token", + "kl_penalty_coef": 0.1, + "kl_ref_adapter_path": "/tmp/ref-adapter", + "logprob_calculation_chunk_size": 512, + "mask_prob_ratio": False, + "num_trajectories_learning_rate_multiplier_power": 0.5, + "packed_sequence_length": 4096, + "plot_tensors": True, + "ppo": True, + "precalculate_logprobs": False, + "scale_learning_rate_by_reward_std_dev": True, + "scale_rewards": False, + "truncated_importance_sampling": 2.0, + } + + +@pytest.mark.asyncio +async def test_serverless_train_rejects_unsupported_pipeline_kwargs() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-pipeline-rejects", + project="pipeline-tests", + base_model="test-model", + ) + + with pytest.raises(ValueError, match="loss_fn_config=None"): + await backend.train(model, [_make_group()], loss_fn_config={"clip": 0.2}) + + with pytest.raises(ValueError, match="adam_params=None"): + await backend.train(model, [_make_group()], adam_params=object()) + + with pytest.raises(ValueError, match="conflicting loss_fn and ppo"): + await backend.train(model, [_make_group()], loss_fn="ppo", ppo=False) + + +@pytest.mark.asyncio +async def test_serverless_train_model_forwards_experimental_config() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-config-payload", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + + captured: dict[str, Any] = {} + backend._client.training_jobs.create = AsyncMock( # type: ignore[attr-defined] + side_effect=lambda **kwargs: ( + captured.update(kwargs) or SimpleNamespace(id="training-job-id") + ) + ) + + async def events_list(**_kwargs: Any): + yield SimpleNamespace(id="event-id", type="training_ended", data={}) + + backend._client.training_jobs.events.list = events_list # type: ignore[attr-defined] + + async def no_sleep(_seconds: float) -> None: + return None + + with patch("art.serverless.backend.asyncio.sleep", no_sleep): + async for _ in backend._train_model( + model, + [_make_group()], + TrainConfig(learning_rate=7e-6, kl_penalty_coef=0.2), + { + "advantage_balance": 0.3, + "allow_training_without_logprobs": True, + "epsilon": 0.1, + "epsilon_high": 0.2, + "importance_sampling_level": "sequence", + "kimi_k2_tau": 0.4, + "kl_penalty_coef": 0.2, + "kl_ref_adapter_path": "/tmp/ref", + "logprob_calculation_chunk_size": 512, + "mask_prob_ratio": True, + "max_negative_advantage_importance_sampling_weight": 3.0, + "num_trajectories_learning_rate_multiplier_power": 0.5, + "packed_sequence_length": 4096, + "plot_tensors": True, + "ppo": True, + "precalculate_logprobs": True, + "scale_learning_rate_by_reward_std_dev": True, + "scale_rewards": False, + "truncated_importance_sampling": 2.0, + }, + ): + pass + + payload = captured["experimental_config"] + assert payload["learning_rate"] == 7e-6 + assert payload["loss_fn"] == "ppo" + assert payload["normalize_advantages"] is False + assert payload["packed_sequence_length"] == 4096 + assert payload["kl_penalty_coef"] == 0.2 + assert payload["kl_ref_adapter_path"] == "/tmp/ref" + assert payload["allow_training_without_logprobs"] is True + assert payload["scale_learning_rate_by_reward_std_dev"] is True