diff --git a/CLAUDE.md b/CLAUDE.md index 8e60258..1422a3b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,6 +8,18 @@ **Treat errors as signals, not obstacles.** An unexpected value or failed assertion means something upstream is wrong. Trace the data flow end-to-end: where was this value produced? What configuration or state fed into it? The goal is durable, correct software — not silencing errors. +## Type Safety + +- Avoid `Any` in Python — prefer `cast()` with the correct type when the type checker can't infer it (e.g. GPU library return types). Use typed dicts, dataclasses, or Pydantic models over bare `dict` without type parameters. +- Data should have the correct shape from the moment it enters the system. If you need a `_normalize_*` or `_make_serializable` function, the upstream data model is wrong — fix the source instead. + +## Code Style + +- Catch specific exception types (`except httpx.HTTPError:`, `except json.JSONDecodeError:`), not bare `except Exception:` or `except:`. +- Keep imports at module top level. Exception: GPU dependencies (torch, peft, vllm, transformers) that aren't installed locally use deferred imports inside functions. +- Don't guard against impossible states — avoid redundant `None` checks, `hasattr()` duck-typing, or `isinstance()` checks when the type is already known from the signature or upstream logic. +- Every public function, class, and module should have a docstring. Keep them concise — one line for simple functions, a short paragraph for complex ones. If you touch a function that lacks one, add it. + ## Code Quality Rules ### After Every Change diff --git a/claas/modal/worker.py b/claas/modal/worker.py index 1cfcb78..63bf5b4 100644 --- a/claas/modal/worker.py +++ b/claas/modal/worker.py @@ -90,7 +90,7 @@ def distill(self, request: DistillBatchRequestPayload) -> DistillResponse: Distillation response payload. 
""" try: - return self.trainer.distill(request) + return self.trainer.distill(request).response finally: self.trainer.offload_base_model() diff --git a/claas/training/distillation.py b/claas/training/distillation.py index 3e2e007..c854194 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -12,6 +12,13 @@ import torch from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput +from claas.training.engine.local.cache import ( + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, +) from claas.training.sdpo_loss import compute_sdpo_loss from claas.training.storage import ( cleanup_local_lora, @@ -46,6 +53,7 @@ class PreparedSample(TypedDict): behavior_logprobs: torch.Tensor + class DistillationTrainer: """Runs one SDPO distillation update using a loaded base model.""" @@ -98,6 +106,10 @@ def load_base_model(self) -> None: self.optimizer_cls = torch.optim.AdamW self.functional = torch.nn.functional + def reload_base_model(self) -> None: + """Move base model from CPU back to CUDA.""" + self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty + def offload_base_model(self) -> None: """Move base model to CPU and release CUDA memory.""" @@ -141,6 +153,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel": ) return get_peft_model(self.base_model, lora_config) + def _load_lora_from_cache( + self, + cached: LoraCacheEntry, + ) -> "PeftModel | PeftMixedModel": + """Restore a LoRA adapter from a CPU cache entry. + + Args: + cached: CPU-resident snapshot of adapter state. + + Returns: + Trainable PEFT model with cached weights loaded. 
+ """ + from peft import LoraConfig, get_peft_model, set_peft_model_state_dict + + cfg = cached.adapter_config + lora_config = LoraConfig( + r=cfg.r, + lora_alpha=cfg.lora_alpha, + target_modules=cfg.target_modules, + lora_dropout=cfg.lora_dropout, + bias=cfg.bias, # type: ignore[arg-type] # peft Literal vs str + task_type=cfg.task_type, + ) + model = get_peft_model(self.base_model, lora_config) + set_peft_model_state_dict(model, cached.lora_state_dict) + return model + def _load_optimizer_state( self, lora_path: str, @@ -248,14 +287,59 @@ def _compute_student_response_logprobs( torch.cuda.empty_cache() return student_logprobs.to(dtype=torch.float32).detach() - def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: + def _build_cache_entry( + self, + model: "PeftModel | PeftMixedModel", + optimizer: "torch.optim.Optimizer", + ) -> LoraCacheEntry: + """Snapshot current model + optimizer state into a CPU-resident cache entry.""" + from peft import LoraConfig, PeftModel as PeftModelCls + + peft_config = cast(LoraConfig, model.peft_config["default"]) + adapter_config = LoraAdapterConfig( + r=peft_config.r, + lora_alpha=peft_config.lora_alpha, + target_modules=list(peft_config.target_modules or []), + lora_dropout=peft_config.lora_dropout, + bias=peft_config.bias, + task_type=str(peft_config.task_type), + ) + + # Extract adapter-only state dict via PEFT + if not isinstance(model, PeftModelCls): + raise TypeError( + f"Expected a PeftModel for cache entry construction, got {type(model).__name__}" + ) + from peft import get_peft_model_state_dict + + raw_state = get_peft_model_state_dict(model) + + lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} + opt_state = cpu_optimizer_state(optimizer.state_dict()) + + return LoraCacheEntry( + lora_state_dict=lora_state, + optimizer_state_dict=opt_state, + adapter_config=adapter_config, + ) + + def distill( + self, + payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = 
None, + ) -> DistillStepResult: """Run one SDPO distillation step. Args: payload: Distillation request payload. + cached: When provided, skip disk reads and load LoRA + optimizer + state from this CPU-resident cache entry. When ``None``, + load from disk (cold start). Returns: - Distillation response with metrics. + Result containing both the distillation response and a cache + entry for the post-step state. """ torch.cuda.empty_cache() @@ -266,10 +350,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: if len(payload.samples) == 0: raise ValueError("samples must contain at least one item") - lora_local_path = load_lora(payload.lora_id) + # Disk path (cold start) or cache path + lora_local_path: str | None = None + if cached is None: + lora_local_path = load_lora(payload.lora_id) + try: try: - model = self._load_or_create_lora(lora_local_path) + if cached is not None: + model = self._load_lora_from_cache(cached) + else: + assert lora_local_path is not None + model = self._load_or_create_lora(lora_local_path) model.train() model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False}, @@ -284,7 +376,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: betas=(0.9, 0.999), weight_decay=0.01, ) - self._load_optimizer_state(lora_local_path, optimizer) + + if cached is not None: + optimizer.load_state_dict( + gpu_optimizer_state(cached.optimizer_state_dict, self.device) + ) + elif lora_local_path is not None: + self._load_optimizer_state(lora_local_path, optimizer) prepared_samples: list[PreparedSample] = [] batch_teacher_scored_texts: list[str] = [] @@ -436,10 +534,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: final_step_metrics = step_metrics[-1] - del model, optimizer + cache_entry = self._build_cache_entry(model, optimizer) + + del model, optimizer, batch_loss_tensors torch.cuda.empty_cache() - return DistillResponse.model_validate( + response = 
DistillResponse.model_validate( { "lora_id": new_lora_id, "metadata": { @@ -457,5 +557,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: }, } ) + return DistillStepResult(response=response, cache_entry=cache_entry) finally: - cleanup_local_lora(lora_local_path) + if lora_local_path is not None: + cleanup_local_lora(lora_local_path) diff --git a/claas/training/engine/local/cache.py b/claas/training/engine/local/cache.py new file mode 100644 index 0000000..44ca4e9 --- /dev/null +++ b/claas/training/engine/local/cache.py @@ -0,0 +1,60 @@ +"""CPU-resident optimizer state helpers for the local training engine.""" + +from __future__ import annotations + +import copy +from typing import cast + +import torch + +# Re-export types for backward compatibility +from claas.training.engine.local.types import ( # noqa: F401 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) + + +def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to CPU.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + cpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + cpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + cpu_param[k] = v.detach().cpu().clone() + else: + cpu_param[k] = copy.deepcopy(v) + cpu_states[param_id] = cpu_param + result[key] = cpu_states + else: + result[key] = copy.deepcopy(value) + return result + + +def gpu_optimizer_state( + state_dict: dict[str, object], + device: torch.device, +) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to a target device.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + gpu_states: dict[int, 
dict[str, object]] = {} + for param_id, param_state in param_states.items(): + gpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + gpu_param[k] = v.detach().to(device).clone() + else: + gpu_param[k] = copy.deepcopy(v) + gpu_states[param_id] = gpu_param + result[key] = gpu_states + else: + result[key] = copy.deepcopy(value) + return result diff --git a/claas/training/engine/local/engine.py b/claas/training/engine/local/engine.py index 0533127..4eb302b 100644 --- a/claas/training/engine/local/engine.py +++ b/claas/training/engine/local/engine.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio +import logging import re +import threading from claas.core.config import LocalConfig from claas.core.types import ( @@ -20,6 +22,7 @@ ) from claas.training.distillation import DistillationTrainer from claas.training.engine.base import TrainingEngine +from claas.training.engine.local.cache import LoraCacheEntry from claas.training.storage import ( configure_storage_backend, create_initial_lora, @@ -31,14 +34,36 @@ resolve_lora_id, ) +logger = logging.getLogger(__name__) + class LocalTrainingEngine(TrainingEngine): """Executes training and LoRA operations on local infrastructure.""" + _trainer: DistillationTrainer + _lora_cache: dict[str, LoraCacheEntry] + _cache_lock: threading.Lock + _model_loaded: bool + def __init__(self, cfg: LocalConfig) -> None: configure_storage_backend("local_fs") self._base_model_id = cfg.base_model_id self._attn_implementation = cfg.attn_implementation + self._trainer = DistillationTrainer( + base_model_id=cfg.base_model_id, + attn_implementation=cfg.attn_implementation, + ) + self._lora_cache = {} + self._cache_lock = threading.Lock() + self._model_loaded = False + self._load_lock = asyncio.Lock() + + async def _ensure_model_loaded(self) -> None: + """One-time base model load on first distill() call.""" + async with self._load_lock: + if not self._model_loaded: + await 
asyncio.to_thread(self._trainer.load_base_model) + self._model_loaded = True async def distill( self, @@ -52,15 +77,25 @@ async def distill( Returns: Distillation response. """ - trainer = DistillationTrainer( - base_model_id=self._base_model_id, - attn_implementation=self._attn_implementation, - ) - await asyncio.to_thread(trainer.load_base_model) + await self._ensure_model_loaded() + await asyncio.to_thread(self._trainer.reload_base_model) + + resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id) + with self._cache_lock: + cached = self._lora_cache.get(resolved_id) + try: - return await asyncio.to_thread(trainer.distill, payload) + result = await asyncio.to_thread( + self._trainer.distill, payload, cached=cached + ) finally: - await asyncio.to_thread(trainer.offload_base_model) + await asyncio.to_thread(self._trainer.offload_base_model) + + new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id) + with self._cache_lock: + self._lora_cache[new_resolved] = result.cache_entry + + return result.response async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: """Initialize a LoRA adapter locally. 
@@ -82,7 +117,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: return LoraInitResponse(lora_id=lora_id) async def delete_lora(self, lora_id: str) -> LoraDeleteResponse: + resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id) deleted = await asyncio.to_thread(delete_lora, lora_id) + if deleted: + with self._cache_lock: + self._lora_cache.pop(resolved_id, None) return LoraDeleteResponse(deleted=deleted) async def list_loras(self, prefix: str) -> LoraListResponse: diff --git a/claas/training/engine/local/types.py b/claas/training/engine/local/types.py new file mode 100644 index 0000000..6b606d9 --- /dev/null +++ b/claas/training/engine/local/types.py @@ -0,0 +1,38 @@ +"""Typed structures for the local training engine's CPU-resident LoRA cache.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from claas.core.types import DistillResponse + + +@dataclass(frozen=True, slots=True) +class LoraAdapterConfig: + """Typed representation of LoRA adapter configuration.""" + + r: int + lora_alpha: int + target_modules: list[str] + lora_dropout: float + bias: str + task_type: str + + +@dataclass(frozen=True, slots=True) +class LoraCacheEntry: + """CPU-resident snapshot of LoRA adapter state between training steps.""" + + lora_state_dict: dict[str, torch.Tensor] + optimizer_state_dict: dict[str, object] + adapter_config: LoraAdapterConfig + + +@dataclass(frozen=True, slots=True) +class DistillStepResult: + """Result of a distillation step with both response and cache entry.""" + + response: DistillResponse + cache_entry: LoraCacheEntry diff --git a/tests/integration/test_local_engine_integration.py b/tests/integration/test_local_engine_integration.py index 06e945c..1398994 100644 --- a/tests/integration/test_local_engine_integration.py +++ b/tests/integration/test_local_engine_integration.py @@ -18,6 +18,11 @@ TrainingConfig, ) from claas.training import storage # noqa: E402 +from 
claas.training.engine.local.cache import ( # noqa: E402 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.engine.local.engine import LocalTrainingEngine # noqa: E402 @@ -41,11 +46,20 @@ def __init__(self, base_model_id: str, attn_implementation: str, state: TrainerS def load_base_model(self) -> None: self._state.loaded = True - def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: + def reload_base_model(self) -> None: + pass + + def distill(self, payload: DistillBatchRequestPayload, *, cached: object = None) -> DistillStepResult: self._state.payload = payload - return DistillResponse.model_validate( + response = DistillResponse.model_validate( {"lora_id": payload.lora_id, "metadata": {"tokens_processed": 5}} ) + cache_entry = LoraCacheEntry( + lora_state_dict={}, + optimizer_state_dict={}, + adapter_config=LoraAdapterConfig(r=8, lora_alpha=16, target_modules=[], lora_dropout=0.0, bias="none", task_type="CAUSAL_LM"), + ) + return DistillStepResult(response=response, cache_entry=cache_entry) def offload_base_model(self) -> None: self._state.cleaned_up = True diff --git a/tests/test_distillation_optimizer_state.py b/tests/test_distillation_optimizer_state.py index ecce841..bec05dd 100644 --- a/tests/test_distillation_optimizer_state.py +++ b/tests/test_distillation_optimizer_state.py @@ -7,6 +7,12 @@ torch = pytest.importorskip("torch") from claas.training.distillation import DistillationTrainer # noqa: E402 +from claas.training.engine.local.cache import ( # noqa: E402 + LoraAdapterConfig, + LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, +) class _SimpleLoraModel(torch.nn.Module): @@ -90,3 +96,81 @@ def test_optimizer_state_missing_gracefully_skips(trainer: DistillationTrainer, trainer._load_optimizer_state(str(tmp_path), optimizer) assert len(optimizer.state) == 0 + + +def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None: + """cpu_optimizer_state produces a state dict with all tensors on
CPU.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = cpu_optimizer_state(original) + + for param_state in cpu_state["state"].values(): # type: ignore[union-attr] + for v in param_state.values(): + if isinstance(v, torch.Tensor): + assert v.device == torch.device("cpu") + + +def test_cpu_gpu_optimizer_state_roundtrip() -> None: + """cpu_optimizer_state / gpu_optimizer_state round-trip preserves values.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = cpu_optimizer_state(original) + roundtripped = gpu_optimizer_state(cpu_state, torch.device("cpu")) + + # Step counts match + for param_id in original["state"]: + assert roundtripped["state"][param_id]["step"] == original["state"][param_id]["step"] # type: ignore[index] + + # Tensor values match + for param_id in original["state"]: + for key in ("exp_avg", "exp_avg_sq"): + orig_tensor = original["state"][param_id][key] + rt_tensor = roundtripped["state"][param_id][key] # type: ignore[index] + assert torch.equal(orig_tensor, rt_tensor) + + +def test_cpu_optimizer_state_does_not_mutate_original() -> None: + """cpu_optimizer_state deep-copies — mutating the copy leaves the original intact.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + original_exp_avg = original["state"][0]["exp_avg"].clone() + + cpu_state = cpu_optimizer_state(original) + # Mutate the copy + cpu_state["state"][0]["exp_avg"].zero_() # type: ignore[index] + + # Original is unchanged + assert torch.equal(original["state"][0]["exp_avg"], original_exp_avg) + + +def test_lora_cache_entry_is_frozen()
-> None: + """LoraCacheEntry is immutable — attribute assignment raises.""" + entry = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), + ) + with pytest.raises(AttributeError): + entry.lora_state_dict = {} # type: ignore[misc] diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index 3d72027..efff14b 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -13,50 +13,155 @@ DistillResponse, TrainingConfig, ) +from claas.training.engine.local.cache import ( # noqa: E402 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.engine.local.engine import LocalTrainingEngine # noqa: E402 +_DUMMY_CACHE_ENTRY = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), +) + + +def _make_payload(lora_id: str = "user/model") -> DistillBatchRequestPayload: + return DistillBatchRequestPayload( + lora_id=lora_id, + training=TrainingConfig(), + samples=[ + DistillBatchItem( + prompt="p", + response="r", + feedback="f", + response_logprobs=[-0.1], + prompt_token_ids=[1, 2], + response_token_ids=[3], + user_prompt="p", + system_prompt="You are a helpful assistant.", + ) + ], + ) + class _Trainer: + """Fake trainer that records method calls.""" + def __init__(self, base_model_id: str, attn_implementation: str): self.base_model_id = base_model_id self.attn_implementation = attn_implementation + self.load_base_model_count = 0 + self.reload_count = 0 + self.offload_count = 0 + self.distill_calls: list[dict] = [] def load_base_model(self) -> None: - return None + 
self.load_base_model_count += 1 - def distill(self, _payload: DistillBatchRequestPayload) -> DistillResponse: - return DistillResponse(lora_id="user/model", metadata={}) + def reload_base_model(self) -> None: + self.reload_count += 1 + + def distill( + self, + _payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = None, + ) -> DistillStepResult: + self.distill_calls.append({"cached": cached}) + return DistillStepResult( + response=DistillResponse(lora_id="user/model", metadata={}), + cache_entry=_DUMMY_CACHE_ENTRY, + ) + + def offload_base_model(self) -> None: + self.offload_count += 1 + + +class _FailingOffloadTrainer(_Trainer): + """Trainer whose offload raises to test error propagation.""" def offload_base_model(self) -> None: raise OSError("cleanup failed") -def test_local_engine_distill_propagates_cleanup_error(monkeypatch): +def _build_engine(monkeypatch, trainer_cls=_Trainer): + from claas.training.engine.local import engine as local_engine + + monkeypatch.setattr(local_engine, "DistillationTrainer", trainer_cls) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + return LocalTrainingEngine(cfg) + + +def test_trainer_created_eagerly_in_init(monkeypatch): + """Trainer is created in __init__, not lazily on first distill().""" + engine = _build_engine(monkeypatch) + assert isinstance(engine._trainer, _Trainer) + assert engine._model_loaded is False + + +def test_load_base_model_called_once(monkeypatch): + """load_base_model is called exactly once across multiple distill() calls.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.load_base_model_count == 1 + + +def test_reload_called_every_distill(monkeypatch): + """reload_base_model is called on every distill() call.""" + engine = _build_engine(monkeypatch) + + 
asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.reload_count == 2 + + +def test_cache_miss_then_hit(monkeypatch): + """First call has cached=None, second call uses the cached entry.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + # First call: no cache + assert engine._trainer.distill_calls[0]["cached"] is None + + asyncio.run(engine.distill(_make_payload())) + # Second call: cache hit + assert engine._trainer.distill_calls[1]["cached"] is _DUMMY_CACHE_ENTRY + + +def test_cache_evicted_on_delete(monkeypatch): + """delete_lora() evicts the cache entry for that lora_id.""" from claas.training.engine.local import engine as local_engine - monkeypatch.setenv("CLAAS_BASE_MODEL_ID", "Qwen/Qwen3-8B") - monkeypatch.setenv("CLAAS_ATTN_IMPLEMENTATION", "sdpa") monkeypatch.setattr(local_engine, "DistillationTrainer", _Trainer) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + monkeypatch.setattr(local_engine, "delete_lora", lambda lid: True) cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + engine = LocalTrainingEngine(cfg) + + asyncio.run(engine.distill(_make_payload())) + assert "user/model" in engine._lora_cache + + asyncio.run(engine.delete_lora("user/model")) + assert "user/model" not in engine._lora_cache + + +def test_offload_error_propagates(monkeypatch): + """Errors from offload_base_model propagate to the caller.""" + engine = _build_engine(monkeypatch, trainer_cls=_FailingOffloadTrainer) with pytest.raises(OSError, match="cleanup failed"): - asyncio.run( - LocalTrainingEngine(cfg).distill( - DistillBatchRequestPayload( - lora_id="user/model", - training=TrainingConfig(), - samples=[ - DistillBatchItem( - prompt="p", - response="r", - feedback="f", - response_logprobs=[-0.1], - prompt_token_ids=[1, 2], - response_token_ids=[3], - user_prompt="p", - system_prompt="You are a helpful 
assistant.", - ) - ], - ) - ) - ) + asyncio.run(engine.distill(_make_payload()))