Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@

**Treat errors as signals, not obstacles.** An unexpected value or failed assertion means something upstream is wrong. Trace the data flow end-to-end: where was this value produced? What configuration or state fed into it? The goal is durable, correct software — not silencing errors.

## Type Safety

- Avoid `Any` in Python — prefer `cast()` with the correct type when the type checker can't infer it (e.g. GPU library return types). Use typed dicts, dataclasses, or Pydantic models over bare `dict` without type parameters.
- Data should have the correct shape from the moment it enters the system. If you need a `_normalize_*` or `_make_serializable` function, the upstream data model is wrong — fix the source instead.

## Code Style

- Catch specific exception types (`except httpx.HTTPError:`, `except json.JSONDecodeError:`), not bare `except Exception:` or `except:`.
- Keep imports at module top level. Exception: GPU dependencies (torch, peft, vllm, transformers) that aren't installed locally use deferred imports inside functions.
- Don't guard against impossible states — avoid redundant `None` checks, `hasattr()` duck-typing, or `isinstance()` checks when the type is already known from the signature or upstream logic.
- Every public function, class, and module should have a docstring. Keep them concise — one line for simple functions, a short paragraph for complex ones. If you touch a function that lacks one, add it.

## Code Quality Rules

### After Every Change
Expand Down
2 changes: 1 addition & 1 deletion claas/modal/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def distill(self, request: DistillBatchRequestPayload) -> DistillResponse:
Distillation response payload.
"""
try:
return self.trainer.distill(request)
return self.trainer.distill(request).response
finally:
self.trainer.offload_base_model()

Expand Down
118 changes: 110 additions & 8 deletions claas/training/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
import torch

from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput
from claas.training.engine.local.cache import (
DistillStepResult,
LoraAdapterConfig,
LoraCacheEntry,
cpu_optimizer_state,
gpu_optimizer_state,
)
from claas.training.sdpo_loss import compute_sdpo_loss
from claas.training.storage import (
cleanup_local_lora,
Expand Down Expand Up @@ -46,6 +53,7 @@ class PreparedSample(TypedDict):
behavior_logprobs: torch.Tensor



class DistillationTrainer:
"""Runs one SDPO distillation update using a loaded base model."""

Expand Down Expand Up @@ -98,6 +106,10 @@ def load_base_model(self) -> None:
self.optimizer_cls = torch.optim.AdamW
self.functional = torch.nn.functional

def reload_base_model(self) -> None:
"""Move base model from CPU back to CUDA."""
self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty

def offload_base_model(self) -> None:
"""Move base model to CPU and release CUDA memory."""

Expand Down Expand Up @@ -141,6 +153,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel":
)
return get_peft_model(self.base_model, lora_config)

def _load_lora_from_cache(
self,
cached: LoraCacheEntry,
) -> "PeftModel | PeftMixedModel":
"""Restore a LoRA adapter from a CPU cache entry.

Args:
cached: CPU-resident snapshot of adapter state.

Returns:
Trainable PEFT model with cached weights loaded.
"""
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

cfg = cached.adapter_config
lora_config = LoraConfig(
r=cfg.r,
lora_alpha=cfg.lora_alpha,
target_modules=cfg.target_modules,
lora_dropout=cfg.lora_dropout,
bias=cfg.bias, # type: ignore[arg-type] # peft Literal vs str
task_type=cfg.task_type,
)
model = get_peft_model(self.base_model, lora_config)
set_peft_model_state_dict(model, cached.lora_state_dict)
return model

def _load_optimizer_state(
self,
lora_path: str,
Expand Down Expand Up @@ -248,14 +287,59 @@ def _compute_student_response_logprobs(
torch.cuda.empty_cache()
return student_logprobs.to(dtype=torch.float32).detach()

def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
def _build_cache_entry(
self,
model: "PeftModel | PeftMixedModel",
optimizer: "torch.optim.Optimizer",
) -> LoraCacheEntry:
"""Snapshot current model + optimizer state into a CPU-resident cache entry."""
from peft import LoraConfig, PeftModel as PeftModelCls

peft_config = cast(LoraConfig, model.peft_config["default"])
adapter_config = LoraAdapterConfig(
r=peft_config.r,
lora_alpha=peft_config.lora_alpha,
target_modules=list(peft_config.target_modules or []),
lora_dropout=peft_config.lora_dropout,
bias=peft_config.bias,
task_type=str(peft_config.task_type),
)

# Extract adapter-only state dict via PEFT
if not isinstance(model, PeftModelCls):
raise TypeError(
f"Expected a PeftModel for cache entry construction, got {type(model).__name__}"
)
from peft import get_peft_model_state_dict

raw_state = get_peft_model_state_dict(model)

lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()}
opt_state = cpu_optimizer_state(optimizer.state_dict())

return LoraCacheEntry(
lora_state_dict=lora_state,
optimizer_state_dict=opt_state,
adapter_config=adapter_config,
)

def distill(
self,
payload: DistillBatchRequestPayload,
*,
cached: LoraCacheEntry | None = None,
) -> DistillStepResult:
"""Run one SDPO distillation step.

Args:
payload: Distillation request payload.
cached: When provided, skip disk reads and load LoRA + optimizer
state from this CPU-resident cache entry. When ``None``,
load from disk (cold start).

Returns:
Distillation response with metrics.
Result containing both the distillation response and a cache
entry for the post-step state.
"""

torch.cuda.empty_cache()
Expand All @@ -266,10 +350,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
if len(payload.samples) == 0:
raise ValueError("samples must contain at least one item")

lora_local_path = load_lora(payload.lora_id)
# Disk path (cold start) or cache path
lora_local_path: str | None = None
if cached is None:
lora_local_path = load_lora(payload.lora_id)

try:
try:
model = self._load_or_create_lora(lora_local_path)
if cached is not None:
model = self._load_lora_from_cache(cached)
else:
assert lora_local_path is not None
model = self._load_or_create_lora(lora_local_path)
model.train()
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False},
Expand All @@ -284,7 +376,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
betas=(0.9, 0.999),
weight_decay=0.01,
)
self._load_optimizer_state(lora_local_path, optimizer)

if cached is not None:
optimizer.load_state_dict(
gpu_optimizer_state(cached.optimizer_state_dict, self.device)
)
elif lora_local_path is not None:
self._load_optimizer_state(lora_local_path, optimizer)

prepared_samples: list[PreparedSample] = []
batch_teacher_scored_texts: list[str] = []
Expand Down Expand Up @@ -436,10 +534,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:

final_step_metrics = step_metrics[-1]

del model, optimizer
cache_entry = self._build_cache_entry(model, optimizer)

del model, optimizer, batch_loss_tensors
torch.cuda.empty_cache()

return DistillResponse.model_validate(
response = DistillResponse.model_validate(
{
"lora_id": new_lora_id,
"metadata": {
Expand All @@ -457,5 +557,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
},
}
)
return DistillStepResult(response=response, cache_entry=cache_entry)
finally:
cleanup_local_lora(lora_local_path)
if lora_local_path is not None:
cleanup_local_lora(lora_local_path)
60 changes: 60 additions & 0 deletions claas/training/engine/local/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""CPU-resident optimizer state helpers for the local training engine."""

from __future__ import annotations

import copy
from typing import cast

import torch

# Re-export types for backward compatibility
from claas.training.engine.local.types import ( # noqa: F401
DistillStepResult,
LoraAdapterConfig,
LoraCacheEntry,
)


def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
    """Deep-copy optimizer state with all tensors moved to CPU."""

    def _to_cpu(value: object) -> object:
        # Tensors are detached and copied off-device; anything else
        # (step counters, scalars) is deep-copied as-is.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().clone()
        return copy.deepcopy(value)

    result: dict[str, object] = {}
    for key, value in state_dict.items():
        if key == "state":
            # "state" maps param id -> per-parameter state (exp_avg, step, ...).
            per_param = cast("dict[int, dict[str, object]]", value)
            result[key] = {
                pid: {name: _to_cpu(entry) for name, entry in pstate.items()}
                for pid, pstate in per_param.items()
            }
        else:
            # e.g. "param_groups": hyperparameters, no tensors expected.
            result[key] = copy.deepcopy(value)
    return result


def gpu_optimizer_state(
    state_dict: dict[str, object],
    device: torch.device,
) -> dict[str, object]:
    """Deep-copy optimizer state with all tensors moved to a target device."""

    def _to_device(value: object) -> object:
        # Tensors are detached and copied onto the target device; non-tensor
        # entries are deep-copied unchanged.
        if isinstance(value, torch.Tensor):
            return value.detach().to(device).clone()
        return copy.deepcopy(value)

    out: dict[str, object] = {}
    for key, value in state_dict.items():
        if key != "state":
            # e.g. "param_groups": hyperparameters, no tensors expected.
            out[key] = copy.deepcopy(value)
            continue
        # "state" maps param id -> per-parameter state (exp_avg, step, ...).
        per_param = cast("dict[int, dict[str, object]]", value)
        out[key] = {
            pid: {name: _to_device(entry) for name, entry in pstate.items()}
            for pid, pstate in per_param.items()
        }
    return out
53 changes: 46 additions & 7 deletions claas/training/engine/local/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from __future__ import annotations

import asyncio
import logging
import re
import threading

from claas.core.config import LocalConfig
from claas.core.types import (
Expand All @@ -20,6 +22,7 @@
)
from claas.training.distillation import DistillationTrainer
from claas.training.engine.base import TrainingEngine
from claas.training.engine.local.cache import LoraCacheEntry
from claas.training.storage import (
configure_storage_backend,
create_initial_lora,
Expand All @@ -31,14 +34,36 @@
resolve_lora_id,
)

logger = logging.getLogger(__name__)


class LocalTrainingEngine(TrainingEngine):
"""Executes training and LoRA operations on local infrastructure."""

_trainer: DistillationTrainer
_lora_cache: dict[str, LoraCacheEntry]
_cache_lock: threading.Lock
_model_loaded: bool

def __init__(self, cfg: LocalConfig) -> None:
configure_storage_backend("local_fs")
self._base_model_id = cfg.base_model_id
self._attn_implementation = cfg.attn_implementation
self._trainer = DistillationTrainer(
base_model_id=cfg.base_model_id,
attn_implementation=cfg.attn_implementation,
)
self._lora_cache = {}
self._cache_lock = threading.Lock()
self._model_loaded = False
self._load_lock = asyncio.Lock()

async def _ensure_model_loaded(self) -> None:
"""One-time base model load on first distill() call."""
async with self._load_lock:
if not self._model_loaded:
await asyncio.to_thread(self._trainer.load_base_model)
self._model_loaded = True

async def distill(
self,
Expand All @@ -52,15 +77,25 @@ async def distill(
Returns:
Distillation response.
"""
trainer = DistillationTrainer(
base_model_id=self._base_model_id,
attn_implementation=self._attn_implementation,
)
await asyncio.to_thread(trainer.load_base_model)
await self._ensure_model_loaded()
await asyncio.to_thread(self._trainer.reload_base_model)

resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id)
with self._cache_lock:
cached = self._lora_cache.get(resolved_id)

try:
return await asyncio.to_thread(trainer.distill, payload)
result = await asyncio.to_thread(
self._trainer.distill, payload, cached=cached
)
finally:
await asyncio.to_thread(trainer.offload_base_model)
await asyncio.to_thread(self._trainer.offload_base_model)

new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id)
with self._cache_lock:
self._lora_cache[new_resolved] = result.cache_entry

return result.response

async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse:
"""Initialize a LoRA adapter locally.
Expand All @@ -82,7 +117,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse:
return LoraInitResponse(lora_id=lora_id)

async def delete_lora(self, lora_id: str) -> LoraDeleteResponse:
resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id)
deleted = await asyncio.to_thread(delete_lora, lora_id)
if deleted:
with self._cache_lock:
self._lora_cache.pop(resolved_id, None)
return LoraDeleteResponse(deleted=deleted)

async def list_loras(self, prefix: str) -> LoraListResponse:
Expand Down
Loading