Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/art/adapter_leases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import AsyncIterator

# Context-local mapping of model name -> inference step currently pinned for
# that model. The default dict object is shared by every context that has not
# set the variable, so it must be treated as read-only: copy it before adding
# or changing entries.
_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar(
    "art_pinned_inference_steps", default={}
)


def pinned_inference_step(model_name: str) -> int | None:
    """Look up the inference step pinned for a model in the current context.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The pinned step, or None when no step is pinned for ``model_name``.
    """
    active_pins = _pinned_inference_steps.get()
    return active_pins.get(model_name)


@asynccontextmanager
async def pin_inference_step(model_name: str, step: int) -> AsyncIterator[None]:
    """Pin ``model_name`` to ``step`` for the duration of the ``async with`` body.

    Args:
        model_name: Name of the model whose inference step is pinned.
        step: Step to record for that model.

    Yields:
        None. On exit the context variable is reset to the value it held
        before this manager was entered, so nested pins unwind in order.
    """
    # Build a new dict instead of mutating the current one, which may be the
    # shared default or a mapping owned by an enclosing pin.
    updated_pins = {**_pinned_inference_steps.get(), model_name: step}
    reset_token = _pinned_inference_steps.set(updated_pins)
    try:
        yield
    finally:
        _pinned_inference_steps.reset(reset_token)
24 changes: 1 addition & 23 deletions src/art/local/adapter_leases.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,9 @@
import asyncio
from collections import Counter
from contextlib import asynccontextmanager
from contextvars import ContextVar
from typing import AsyncIterator

# Context-local mapping of model name -> inference step currently pinned for
# that model. The default dict object is shared by every context that has not
# set the variable, so it must be treated as read-only: copy it before adding
# or changing entries.
_pinned_inference_steps: ContextVar[dict[str, int]] = ContextVar(
    "art_pinned_inference_steps", default={}
)


def pinned_inference_step(model_name: str) -> int | None:
    """Look up the inference step pinned for a model in the current context.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The pinned step, or None when no step is pinned for ``model_name``.
    """
    active_pins = _pinned_inference_steps.get()
    return active_pins.get(model_name)


@asynccontextmanager
async def pin_inference_step(model_name: str, step: int) -> AsyncIterator[None]:
    """Pin ``model_name`` to ``step`` for the duration of the ``async with`` body.

    Args:
        model_name: Name of the model whose inference step is pinned.
        step: Step to record for that model.

    Yields:
        None. On exit the context variable is reset to the value it held
        before this manager was entered, so nested pins unwind in order.
    """
    # Build a new dict instead of mutating the current one, which may be the
    # shared default or a mapping owned by an enclosing pin.
    updated_pins = {**_pinned_inference_steps.get(), model_name: step}
    reset_token = _pinned_inference_steps.set(updated_pins)
    try:
        yield
    finally:
        _pinned_inference_steps.reset(reset_token)
from art.adapter_leases import pin_inference_step, pinned_inference_step


class AdapterLeaseManager:
Expand Down
2 changes: 1 addition & 1 deletion src/art/pipeline_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
def _to_async_iterator(iterable: Iterable[T] | AsyncIterator[T]) -> AsyncIterator[T]:
"""Convert a sync Iterable to an AsyncIterator, or pass through if already async."""
if isinstance(iterable, AsyncIterator):
return cast(AsyncIterator[T], iterable)
return iterable

async def _iter():
for item in iterable:
Expand Down
114 changes: 110 additions & 4 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import asyncio
from contextlib import asynccontextmanager
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
import warnings

from openai._types import NOT_GIVEN
from tqdm import auto as tqdm

from art.adapter_leases import pin_inference_step, pinned_inference_step
from art.serverless.client import Client, ExperimentalTrainingConfig

from .. import dev
Expand Down Expand Up @@ -144,14 +146,26 @@ def _model_inference_name(self, model: "Model", step: int | None = None) -> str:
model: The model.
step: If provided, returns name for specific checkpoint using
W&B artifact versioning (e.g., :step5). If None, returns
name for latest checkpoint (default, backwards compatible).
name for the pinned checkpoint when running inside an
adapter_lease, otherwise latest checkpoint.
"""
assert model.entity is not None, "Model entity is required"
if step is None:
step = pinned_inference_step(model.name)
base_name = f"wandb-artifact:///{model.entity}/{model.project}/{model.name}"
if step is not None:
return f"{base_name}:step{step}"
return base_name

@asynccontextmanager
async def adapter_lease(
    self, model: AnyTrainableModel, step: int
) -> AsyncIterator[None]:
    """Pin inference for ``model`` to checkpoint ``step`` inside the context.

    Delegates to ``pin_inference_step`` keyed by ``model.name``; the pin is
    released when the ``async with`` body exits.

    Args:
        model: Model whose name is used as the pin key.
        step: Checkpoint step to pin.

    Yields:
        None.
    """
    step_pin = pin_inference_step(model.name, step)
    async with step_pin:
        yield

async def _get_step(self, model: "Model") -> int:
if model.trainable:
assert model.id is not None, "Model ID is required"
Expand Down Expand Up @@ -197,8 +211,15 @@ async def train( # type: ignore[override]
*,
# Core training parameters
learning_rate: float = 5e-6,
loss_fn: Literal["cispo", "ppo"] | None = None,
loss_fn_config: dict | None = None,
normalize_advantages: bool = True,
adam_params: object | None = None,
# KL-penalized advantage adjustment
kl_penalty_coef: float = 0.0,
kl_ref_adapter_path: str | None = None,
# RL algorithm settings
ppo: bool = False,
ppo: bool | None = None,
epsilon: float | None = None,
epsilon_high: float | None = None,
# Advantage computation
Expand All @@ -213,6 +234,15 @@ async def train( # type: ignore[override]
# Experimental parameters
kimi_k2_tau: float | None = None,
precalculate_logprobs: bool = False,
allow_training_without_logprobs: bool = False,
plot_tensors: bool = False,
truncated_importance_sampling: float | None = None,
scale_learning_rate_by_reward_std_dev: bool = False,
logprob_calculation_chunk_size: int = 1024,
packed_sequence_length: int | None = None,
num_trajectories_learning_rate_multiplier_power: float = 0.0,
# Checkpoint behavior
save_checkpoint: bool = True,
# Verbosity
verbose: bool = False,
) -> ServerlessTrainResult:
Expand All @@ -226,7 +256,20 @@ async def train( # type: ignore[override]
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
ppo: Whether to use PPO clipping. Defaults to False.
loss_fn: RL loss function. ServerlessBackend supports "cispo" and
"ppo". If unset, the legacy ppo argument is used.
loss_fn_config: Additional loss-function config. Not supported by
ServerlessBackend.
normalize_advantages: Backward-compatible alias for reward std scaling.
When False, ServerlessBackend centers rewards but does not divide
by group reward std dev.
adam_params: Custom optimizer params. Not supported by
ServerlessBackend.
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
Defaults to 0.0 (disabled).
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
checkpoint to use as the KL reference.
ppo: Legacy flag for PPO clipping. Prefer loss_fn="ppo".
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
advantage_balance: Balance between negative and positive advantages
Expand All @@ -240,6 +283,22 @@ async def train( # type: ignore[override]
mask_prob_ratio: Whether to mask probability ratios. Defaults to False.
kimi_k2_tau: Tau parameter for Kimi K2 algorithm.
precalculate_logprobs: Whether to precalculate logprobs.
allow_training_without_logprobs: Allow training even when no logprobs
are available. Defaults to False.
plot_tensors: Whether to plot training tensors for debugging.
Defaults to False.
truncated_importance_sampling: Truncation threshold for importance
sampling weights.
scale_learning_rate_by_reward_std_dev: Whether to scale learning rate
by reward standard deviation. Defaults to False.
logprob_calculation_chunk_size: Chunk size for logprob calculation.
Defaults to 1024.
packed_sequence_length: Packed sequence length to use for training.
num_trajectories_learning_rate_multiplier_power: Power for learning
rate multiplier based on number of trajectories.
save_checkpoint: Accepted for PipelineTrainer compatibility. Serverless
training currently always saves a trainable checkpoint for the next
inference step.
verbose: Whether to print verbose output. Defaults to False.

Returns:
Expand All @@ -252,19 +311,45 @@ async def train( # type: ignore[override]
# await model.log(metrics=result.metrics, step=result.step)
"""
groups_list = list(trajectory_groups)
if loss_fn is None:
resolved_loss_fn: Literal["cispo", "ppo"] = "ppo" if ppo else "cispo"
else:
resolved_loss_fn = loss_fn
if ppo is not None and ppo != (loss_fn == "ppo"):
raise ValueError("ServerlessBackend got conflicting loss_fn and ppo.")
if resolved_loss_fn not in {"cispo", "ppo"}:
raise ValueError(
"ServerlessBackend only supports loss_fn='cispo' or 'ppo'."
)
if loss_fn_config is not None:
raise ValueError("ServerlessBackend requires loss_fn_config=None.")
if not normalize_advantages:
scale_rewards = False
if adam_params is not None:
raise ValueError("ServerlessBackend requires adam_params=None.")
_ = save_checkpoint

config, dev_config = build_rl_train_configs(
learning_rate=learning_rate,
advantage_balance=advantage_balance,
scale_rewards=scale_rewards,
importance_sampling_level=importance_sampling_level,
mask_prob_ratio=mask_prob_ratio,
ppo=ppo,
ppo=resolved_loss_fn == "ppo",
precalculate_logprobs=precalculate_logprobs,
epsilon=epsilon,
epsilon_high=epsilon_high,
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
kimi_k2_tau=kimi_k2_tau,
kl_penalty_coef=kl_penalty_coef,
allow_training_without_logprobs=allow_training_without_logprobs,
plot_tensors=plot_tensors,
truncated_importance_sampling=truncated_importance_sampling,
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
packed_sequence_length=packed_sequence_length,
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
kl_ref_adapter_path=kl_ref_adapter_path,
)

# Collect metrics from training
Expand Down Expand Up @@ -317,18 +402,39 @@ async def _train_model(
trajectory_groups=trajectory_groups,
experimental_config=ExperimentalTrainingConfig(
advantage_balance=dev_config.get("advantage_balance"),
allow_training_without_logprobs=dev_config.get(
"allow_training_without_logprobs"
),
epsilon=dev_config.get("epsilon"),
epsilon_high=dev_config.get("epsilon_high"),
importance_sampling_level=dev_config.get("importance_sampling_level"),
kimi_k2_tau=dev_config.get("kimi_k2_tau"),
kl_penalty_coef=dev_config.get("kl_penalty_coef"),
kl_ref_adapter_path=dev_config.get("kl_ref_adapter_path"),
learning_rate=config.learning_rate,
logprob_calculation_chunk_size=dev_config.get(
"logprob_calculation_chunk_size"
),
loss_fn="ppo" if dev_config.get("ppo") else "cispo",
mask_prob_ratio=dev_config.get("mask_prob_ratio"),
max_negative_advantage_importance_sampling_weight=dev_config.get(
"max_negative_advantage_importance_sampling_weight"
),
normalize_advantages=dev_config.get("scale_rewards"),
num_trajectories_learning_rate_multiplier_power=dev_config.get(
"num_trajectories_learning_rate_multiplier_power"
),
packed_sequence_length=dev_config.get("packed_sequence_length"),
plot_tensors=dev_config.get("plot_tensors"),
ppo=dev_config.get("ppo"),
precalculate_logprobs=dev_config.get("precalculate_logprobs"),
scale_learning_rate_by_reward_std_dev=dev_config.get(
"scale_learning_rate_by_reward_std_dev"
),
scale_rewards=dev_config.get("scale_rewards"),
truncated_importance_sampling=dev_config.get(
"truncated_importance_sampling"
),
),
)
after: str | None = None
Expand Down
11 changes: 11 additions & 0 deletions src/art/serverless/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,29 @@ class DeleteCheckpointsResponse(BaseModel):

class ExperimentalTrainingConfig(TypedDict, total=False):
advantage_balance: float | None
allow_training_without_logprobs: bool | None
epsilon: float | None
epsilon_high: float | None
importance_sampling_level: (
Literal["token", "sequence", "average", "geometric_average"] | None
)
kimi_k2_tau: float | None
kl_penalty_coef: float | None
kl_ref_adapter_path: str | None
learning_rate: float | None
logprob_calculation_chunk_size: int | None
loss_fn: Literal["cispo", "ppo"] | None
mask_prob_ratio: bool | None
max_negative_advantage_importance_sampling_weight: float | None
normalize_advantages: bool | None
num_trajectories_learning_rate_multiplier_power: float | None
packed_sequence_length: int | None
plot_tensors: bool | None
ppo: bool | None
precalculate_logprobs: bool | None
scale_learning_rate_by_reward_std_dev: bool | None
scale_rewards: bool | None
truncated_importance_sampling: float | None


class SFTTrainingConfig(TypedDict, total=False):
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/test_serverless_adapter_lease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import art
from art.serverless.backend import ServerlessBackend


async def test_serverless_adapter_lease_pins_inference_step() -> None:
    """adapter_lease pins the default inference name to a step, then unpins it."""
    unpinned_name = "wandb-artifact:///test-entity/test-project/test-model"
    pinned_name = "wandb-artifact:///test-entity/test-project/test-model:step3"
    explicit_name = "wandb-artifact:///test-entity/test-project/test-model:step4"

    backend = ServerlessBackend(api_key="test-api-key")
    model = art.TrainableModel(
        name="test-model",
        project="test-project",
        entity="test-entity",
        base_model="test-base-model",
    )
    model._backend = backend

    # Outside any lease the name carries no step suffix.
    assert model.get_inference_name() == unpinned_name

    async with backend.adapter_lease(model, 3):
        assert model.get_inference_name() == pinned_name
        # An explicitly requested step takes precedence over the pinned one.
        assert model.get_inference_name(step=4) == explicit_name

    # Leaving the lease restores the unpinned name.
    assert model.get_inference_name() == unpinned_name


async def test_serverless_adapter_lease_is_model_scoped() -> None:
    """A lease taken for one model leaves another model's inference name unpinned."""
    backend = ServerlessBackend(api_key="test-api-key")
    shared_kwargs = {
        "project": "test-project",
        "entity": "test-entity",
        "base_model": "test-base-model",
    }
    model_a = art.TrainableModel(name="model-a", **shared_kwargs)
    model_b = art.TrainableModel(name="model-b", **shared_kwargs)
    for model in (model_a, model_b):
        model._backend = backend

    async with backend.adapter_lease(model_a, 2):
        name_a = model_a.get_inference_name()
        assert name_a == "wandb-artifact:///test-entity/test-project/model-a:step2"
        name_b = model_b.get_inference_name()
        assert name_b == "wandb-artifact:///test-entity/test-project/model-b"
Loading
Loading