From 6607a2cdbdaeceaa37a6bb36ac1fe13c6759ad1c Mon Sep 17 00:00:00 2001 From: Kion Date: Mon, 23 Feb 2026 13:18:49 -0800 Subject: [PATCH 1/8] Persistent trainer + CPU cache for local training engine Keep the DistillationTrainer and base model across distill() calls instead of recreating them each time. Cache LoRA adapter weights and optimizer state on CPU between steps so the second call for a given lora_id skips all disk I/O. GPU memory is still fully released after each step via the existing offload_base_model() pattern. Co-Authored-By: Claude Opus 4.6 --- claas/modal/worker.py | 2 +- claas/training/cache.py | 38 +++++ claas/training/distillation.py | 160 +++++++++++++++++++-- claas/training/engine/local/engine.py | 50 ++++++- tests/test_distillation_optimizer_state.py | 85 ++++++++++- tests/test_local_training_engine.py | 156 ++++++++++++++++---- 6 files changed, 448 insertions(+), 43 deletions(-) create mode 100644 claas/training/cache.py diff --git a/claas/modal/worker.py b/claas/modal/worker.py index 1cfcb78..63bf5b4 100644 --- a/claas/modal/worker.py +++ b/claas/modal/worker.py @@ -90,7 +90,7 @@ def distill(self, request: DistillBatchRequestPayload) -> DistillResponse: Distillation response payload. """ try: - return self.trainer.distill(request) + return self.trainer.distill(request).response finally: self.trainer.offload_base_model() diff --git a/claas/training/cache.py b/claas/training/cache.py new file mode 100644 index 0000000..17f7faf --- /dev/null +++ b/claas/training/cache.py @@ -0,0 +1,38 @@ +"""Typed cache structures for CPU-resident LoRA state between training steps.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from claas.core.types import DistillResponse + + +@dataclass(frozen=True, slots=True) +class LoraAdapterConfig: + """Typed representation of LoRA adapter configuration.""" + + r: int + lora_alpha: int + target_modules: list[str] + lora_dropout: float + bias: str + task_type: str + + +@dataclass(frozen=True, slots=True) +class LoraCacheEntry: + """CPU-resident snapshot of LoRA adapter state between training steps.""" + + lora_state_dict: dict[str, torch.Tensor] + optimizer_state_dict: dict[str, object] + adapter_config: LoraAdapterConfig + + +@dataclass(frozen=True, slots=True) +class DistillStepResult: + """Result of a distillation step with both response and cache entry.""" + + response: DistillResponse + cache_entry: LoraCacheEntry diff --git a/claas/training/distillation.py b/claas/training/distillation.py index 3e2e007..4875a2a 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -2,6 +2,7 @@ from __future__ import annotations +import copy import json import logging import os @@ -12,6 +13,11 @@ import torch from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput +from claas.training.cache import ( + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.sdpo_loss import compute_sdpo_loss from claas.training.storage import ( cleanup_local_lora, @@ -46,6 +52,51 @@ class PreparedSample(TypedDict): behavior_logprobs: torch.Tensor +def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to CPU.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + cpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + cpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + cpu_param[k] = v.detach().cpu().clone() + else: + cpu_param[k] = copy.deepcopy(v) + cpu_states[param_id] = cpu_param + result[key] = cpu_states + else: + result[key] = copy.deepcopy(value) + return result + + +def _gpu_optimizer_state( + state_dict: dict[str, object], + device: torch.device, +) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to a target device.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + gpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + gpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + gpu_param[k] = v.detach().to(device).clone() + else: + gpu_param[k] = copy.deepcopy(v) + gpu_states[param_id] = gpu_param + result[key] = gpu_states + else: + result[key] = copy.deepcopy(value) + return result + + class DistillationTrainer: """Runs one SDPO distillation update using a loaded base model.""" @@ -98,6 +149,10 @@ def load_base_model(self) -> None: self.optimizer_cls = torch.optim.AdamW self.functional = torch.nn.functional + def reload_base_model(self) -> None: + """Move base model from CPU back to CUDA.""" + self.base_model.to(self.device) # type: ignore[arg-type] # functools.wraps confuses ty + def offload_base_model(self) -> None: """Move base model to CPU and release CUDA memory.""" @@ -141,6 +196,33 @@ def _load_or_create_lora(self, lora_path: str) -> "PeftModel | PeftMixedModel": ) return get_peft_model(self.base_model, lora_config) + def _load_lora_from_cache( + self, + cached: LoraCacheEntry, + ) -> "PeftModel | PeftMixedModel": + """Restore a LoRA adapter from a CPU cache entry. + + Args: + cached: CPU-resident snapshot of adapter state. + + Returns: + Trainable PEFT model with cached weights loaded. + """ + from peft import LoraConfig, get_peft_model, set_peft_model_state_dict + + cfg = cached.adapter_config + lora_config = LoraConfig( + r=cfg.r, + lora_alpha=cfg.lora_alpha, + target_modules=cfg.target_modules, + lora_dropout=cfg.lora_dropout, + bias=cfg.bias, + task_type=cfg.task_type, + ) + model = get_peft_model(self.base_model, lora_config) + set_peft_model_state_dict(model, cached.lora_state_dict) + return model + def _load_optimizer_state( self, lora_path: str, @@ -248,14 +330,58 @@ def _compute_student_response_logprobs( torch.cuda.empty_cache() return student_logprobs.to(dtype=torch.float32).detach() - def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: + def _build_cache_entry( + self, + model: "PeftModel | PeftMixedModel", + optimizer: "torch.optim.Optimizer", + ) -> LoraCacheEntry: + """Snapshot current model + optimizer state into a CPU-resident cache entry.""" + from peft import PeftModel as PeftModelCls + + peft_config = model.peft_config["default"] + adapter_config = LoraAdapterConfig( + r=peft_config.r, + lora_alpha=peft_config.lora_alpha, + target_modules=list(peft_config.target_modules), + lora_dropout=peft_config.lora_dropout, + bias=peft_config.bias, + task_type=peft_config.task_type, + ) + + # Determine state dict — use PEFT's adapter-only extraction if available + if isinstance(model, PeftModelCls): + from peft import get_peft_model_state_dict + + raw_state = get_peft_model_state_dict(model) + else: + raw_state = model.state_dict() + + lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} + opt_state = _cpu_optimizer_state(optimizer.state_dict()) + + return LoraCacheEntry( + lora_state_dict=lora_state, + optimizer_state_dict=opt_state, + adapter_config=adapter_config, + ) + + def distill( + self, + payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = None, + ) -> DistillStepResult: """Run one SDPO distillation step. Args: payload: Distillation request payload. + cached: When provided, skip disk reads and load LoRA + optimizer + state from this CPU-resident cache entry. When ``None``, + load from disk (cold start). Returns: - Distillation response with metrics. + Result containing both the distillation response and a cache + entry for the post-step state. """ torch.cuda.empty_cache() @@ -266,10 +392,18 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: if len(payload.samples) == 0: raise ValueError("samples must contain at least one item") - lora_local_path = load_lora(payload.lora_id) + # Disk path (cold start) or cache path + lora_local_path: str | None = None + if cached is None: + lora_local_path = load_lora(payload.lora_id) + try: try: - model = self._load_or_create_lora(lora_local_path) + if cached is not None: + model = self._load_lora_from_cache(cached) + else: + assert lora_local_path is not None + model = self._load_or_create_lora(lora_local_path) model.train() model.gradient_checkpointing_enable( gradient_checkpointing_kwargs={"use_reentrant": False}, @@ -284,7 +418,13 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: betas=(0.9, 0.999), weight_decay=0.01, ) - self._load_optimizer_state(lora_local_path, optimizer) + + if cached is not None: + optimizer.load_state_dict( + _gpu_optimizer_state(cached.optimizer_state_dict, self.device) + ) + elif lora_local_path is not None: + self._load_optimizer_state(lora_local_path, optimizer) prepared_samples: list[PreparedSample] = [] batch_teacher_scored_texts: list[str] = [] @@ -436,10 +576,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: final_step_metrics = step_metrics[-1] - del model, optimizer + cache_entry = self._build_cache_entry(model, optimizer) + + del model, optimizer, batch_loss_tensors torch.cuda.empty_cache() - return DistillResponse.model_validate( + response = DistillResponse.model_validate( { "lora_id": new_lora_id, "metadata": { @@ -457,5 +599,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: }, } ) + return DistillStepResult(response=response, cache_entry=cache_entry) finally: - cleanup_local_lora(lora_local_path) + if lora_local_path is not None: + cleanup_local_lora(lora_local_path) diff --git a/claas/training/engine/local/engine.py b/claas/training/engine/local/engine.py index 0533127..6379eb0 100644 --- a/claas/training/engine/local/engine.py +++ b/claas/training/engine/local/engine.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio +import logging import re +import threading from claas.core.config import LocalConfig from claas.core.types import ( @@ -18,6 +20,7 @@ LoraRuntimeRef, ServiceHealth, ) +from claas.training.cache import LoraCacheEntry from claas.training.distillation import DistillationTrainer from claas.training.engine.base import TrainingEngine from claas.training.storage import ( @@ -31,14 +34,34 @@ resolve_lora_id, ) +logger = logging.getLogger(__name__) + class LocalTrainingEngine(TrainingEngine): """Executes training and LoRA operations on local infrastructure.""" + _trainer: DistillationTrainer + _lora_cache: dict[str, LoraCacheEntry] + _cache_lock: threading.Lock + _model_loaded: bool + def __init__(self, cfg: LocalConfig) -> None: configure_storage_backend("local_fs") self._base_model_id = cfg.base_model_id self._attn_implementation = cfg.attn_implementation + self._trainer = DistillationTrainer( + base_model_id=cfg.base_model_id, + attn_implementation=cfg.attn_implementation, + ) + self._lora_cache = {} + self._cache_lock = threading.Lock() + self._model_loaded = False + + async def _ensure_model_loaded(self) -> None: + """One-time base model load on first distill() call.""" + if not self._model_loaded: + await asyncio.to_thread(self._trainer.load_base_model) + self._model_loaded = True async def distill( self, @@ -52,15 +75,24 @@ async def distill( Returns: Distillation response. """ - trainer = DistillationTrainer( - base_model_id=self._base_model_id, - attn_implementation=self._attn_implementation, - ) - await asyncio.to_thread(trainer.load_base_model) + await self._ensure_model_loaded() + await asyncio.to_thread(self._trainer.reload_base_model) + + resolved_id = await asyncio.to_thread(resolve_lora_id, payload.lora_id) + with self._cache_lock: + cached = self._lora_cache.get(resolved_id) + try: - return await asyncio.to_thread(trainer.distill, payload) + result = await asyncio.to_thread( + self._trainer.distill, payload, cached=cached + ) finally: - await asyncio.to_thread(trainer.offload_base_model) + await asyncio.to_thread(self._trainer.offload_base_model) + + with self._cache_lock: + self._lora_cache[resolved_id] = result.cache_entry + + return result.response async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: """Initialize a LoRA adapter locally. @@ -82,7 +114,11 @@ async def init_lora(self, request: LoraInitRequest) -> LoraInitResponse: return LoraInitResponse(lora_id=lora_id) async def delete_lora(self, lora_id: str) -> LoraDeleteResponse: + resolved_id = await asyncio.to_thread(resolve_lora_id, lora_id) deleted = await asyncio.to_thread(delete_lora, lora_id) + if deleted: + with self._cache_lock: + self._lora_cache.pop(resolved_id, None) return LoraDeleteResponse(deleted=deleted) async def list_loras(self, prefix: str) -> LoraListResponse: diff --git a/tests/test_distillation_optimizer_state.py b/tests/test_distillation_optimizer_state.py index ecce841..c211570 100644 --- a/tests/test_distillation_optimizer_state.py +++ b/tests/test_distillation_optimizer_state.py @@ -6,7 +6,12 @@ torch = pytest.importorskip("torch") -from claas.training.distillation import DistillationTrainer # noqa: E402 +from claas.training.cache import LoraAdapterConfig, LoraCacheEntry # noqa: E402 +from claas.training.distillation import ( # noqa: E402 + DistillationTrainer, + _cpu_optimizer_state, + _gpu_optimizer_state, +) class _SimpleLoraModel(torch.nn.Module): @@ -90,3 +95,81 @@ def test_optimizer_state_missing_gracefully_skips(trainer: DistillationTrainer, trainer._load_optimizer_state(str(tmp_path), optimizer) assert len(optimizer.state) == 0 + + +def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None: + """_cpu_optimizer_state produces a state dict with all tensors on CPU.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = _cpu_optimizer_state(original) + + for param_state in cpu_state["state"].values(): + for v in param_state.values(): + if isinstance(v, torch.Tensor): + assert v.device == torch.device("cpu") + + +def test_cpu_gpu_optimizer_state_roundtrip() -> None: + """_cpu_optimizer_state / _gpu_optimizer_state round-trip preserves values.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + cpu_state = _cpu_optimizer_state(original) + roundtripped = _gpu_optimizer_state(cpu_state, torch.device("cpu")) + + # Step counts match + for param_id in original["state"]: + assert roundtripped["state"][param_id]["step"] == original["state"][param_id]["step"] + + # Tensor values match + for param_id in original["state"]: + for key in ("exp_avg", "exp_avg_sq"): + orig_tensor = original["state"][param_id][key] + rt_tensor = roundtripped["state"][param_id][key] + assert torch.equal(orig_tensor, rt_tensor) + + +def test_cpu_optimizer_state_does_not_mutate_original() -> None: + """_cpu_optimizer_state deep-copies — mutating the copy leaves the original intact.""" + model = _SimpleLoraModel() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + loss = model.first.sum() + loss.backward() + optimizer.step() + + original = optimizer.state_dict() + original_exp_avg = original["state"][0]["exp_avg"].clone() + + cpu_state = _cpu_optimizer_state(original) + # Mutate the copy + cpu_state["state"][0]["exp_avg"].zero_() + + # Original is unchanged + assert torch.equal(original["state"][0]["exp_avg"], original_exp_avg) + + +def test_lora_cache_entry_is_frozen() -> None: + """LoraCacheEntry is immutable — attribute assignment raises.""" + entry = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), + ) + with pytest.raises(AttributeError): + entry.lora_state_dict = {} # type: ignore[misc] diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index 3d72027..b7fba45 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -13,50 +13,154 @@ DistillResponse, TrainingConfig, ) +from claas.training.cache import ( # noqa: E402 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.engine.local.engine import LocalTrainingEngine # noqa: E402 +_DUMMY_CACHE_ENTRY = LoraCacheEntry( + lora_state_dict={"w": torch.zeros(2)}, + optimizer_state_dict={"state": {}, "param_groups": []}, + adapter_config=LoraAdapterConfig( + r=8, + lora_alpha=16, + target_modules=["q_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ), +) + + +def _make_payload(lora_id: str = "user/model") -> DistillBatchRequestPayload: + return DistillBatchRequestPayload( + lora_id=lora_id, + training=TrainingConfig(), + samples=[ + DistillBatchItem( + prompt="p", + response="r", + feedback="f", + response_logprobs=[-0.1], + prompt_token_ids=[1, 2], + response_token_ids=[3], + user_prompt="p", + ) + ], + ) + class _Trainer: + """Fake trainer that records method calls.""" + def __init__(self, base_model_id: str, attn_implementation: str): self.base_model_id = base_model_id self.attn_implementation = attn_implementation + self.load_base_model_count = 0 + self.reload_count = 0 + self.offload_count = 0 + self.distill_calls: list[dict] = [] def load_base_model(self) -> None: - return None + self.load_base_model_count += 1 - def distill(self, _payload: DistillBatchRequestPayload) -> DistillResponse: - return DistillResponse(lora_id="user/model", metadata={}) + def reload_base_model(self) -> None: + self.reload_count += 1 + + def distill( + self, + _payload: DistillBatchRequestPayload, + *, + cached: LoraCacheEntry | None = None, + ) -> DistillStepResult: + self.distill_calls.append({"cached": cached}) + return DistillStepResult( + response=DistillResponse(lora_id="user/model", metadata={}), + cache_entry=_DUMMY_CACHE_ENTRY, + ) + + def offload_base_model(self) -> None: + self.offload_count += 1 + + +class _FailingOffloadTrainer(_Trainer): + """Trainer whose offload raises to test error propagation.""" def offload_base_model(self) -> None: raise OSError("cleanup failed") -def test_local_engine_distill_propagates_cleanup_error(monkeypatch): +def _build_engine(monkeypatch, trainer_cls=_Trainer): + from claas.training.engine.local import engine as local_engine + + monkeypatch.setattr(local_engine, "DistillationTrainer", trainer_cls) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + return LocalTrainingEngine(cfg) + + +def test_trainer_created_eagerly_in_init(monkeypatch): + """Trainer is created in __init__, not lazily on first distill().""" + engine = _build_engine(monkeypatch) + assert isinstance(engine._trainer, _Trainer) + assert engine._model_loaded is False + + +def test_load_base_model_called_once(monkeypatch): + """load_base_model is called exactly once across multiple distill() calls.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.load_base_model_count == 1 + + +def test_reload_called_every_distill(monkeypatch): + """reload_base_model is called on every distill() call.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + asyncio.run(engine.distill(_make_payload())) + + assert engine._trainer.reload_count == 2 + + +def test_cache_miss_then_hit(monkeypatch): + """First call has cached=None, second call uses the cached entry.""" + engine = _build_engine(monkeypatch) + + asyncio.run(engine.distill(_make_payload())) + # First call: no cache + assert engine._trainer.distill_calls[0]["cached"] is None + + asyncio.run(engine.distill(_make_payload())) + # Second call: cache hit + assert engine._trainer.distill_calls[1]["cached"] is _DUMMY_CACHE_ENTRY + + +def test_cache_evicted_on_delete(monkeypatch): + """delete_lora() evicts the cache entry for that lora_id.""" from claas.training.engine.local import engine as local_engine - monkeypatch.setenv("CLAAS_BASE_MODEL_ID", "Qwen/Qwen3-8B") - monkeypatch.setenv("CLAAS_ATTN_IMPLEMENTATION", "sdpa") monkeypatch.setattr(local_engine, "DistillationTrainer", _Trainer) + monkeypatch.setattr(local_engine, "resolve_lora_id", lambda lid: lid.strip("/")) + monkeypatch.setattr(local_engine, "delete_lora", lambda lid: True) cfg = LocalConfig(base_model_id="Qwen/Qwen3-8B", attn_implementation="sdpa") + engine = LocalTrainingEngine(cfg) + + asyncio.run(engine.distill(_make_payload())) + assert "user/model" in engine._lora_cache + + asyncio.run(engine.delete_lora("user/model")) + assert "user/model" not in engine._lora_cache + + +def test_offload_error_propagates(monkeypatch): + """Errors from offload_base_model propagate to the caller.""" + engine = _build_engine(monkeypatch, trainer_cls=_FailingOffloadTrainer) with pytest.raises(OSError, match="cleanup failed"): - asyncio.run( - LocalTrainingEngine(cfg).distill( - DistillBatchRequestPayload( - lora_id="user/model", - training=TrainingConfig(), - samples=[ - DistillBatchItem( - prompt="p", - response="r", - feedback="f", - response_logprobs=[-0.1], - prompt_token_ids=[1, 2], - response_token_ids=[3], - user_prompt="p", - system_prompt="You are a helpful assistant.", - ) - ], - ) - ) - ) + asyncio.run(engine.distill(_make_payload())) From 166a16dbfca896e185cd7948ae278b4007114af9 Mon Sep 17 00:00:00 2001 From: Kion Date: Mon, 23 Feb 2026 13:25:01 -0800 Subject: [PATCH 2/8] Move cache types and optimizer state helpers into local engine Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the cpu/gpu_optimizer_state helpers from the shared training module into claas/training/engine/local/cache.py since they are only used by the local engine's CPU caching path. The Modal worker never uses caching. Co-Authored-By: Claude Opus 4.6 --- claas/training/cache.py | 38 ---------- claas/training/distillation.py | 53 ++------------ claas/training/engine/local/cache.py | 85 ++++++++++++++++++++++ claas/training/engine/local/engine.py | 2 +- tests/test_distillation_optimizer_state.py | 31 ++++---- tests/test_local_training_engine.py | 2 +- 6 files changed, 108 insertions(+), 103 deletions(-) delete mode 100644 claas/training/cache.py create mode 100644 claas/training/engine/local/cache.py diff --git a/claas/training/cache.py b/claas/training/cache.py deleted file mode 100644 index 17f7faf..0000000 --- a/claas/training/cache.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Typed cache structures for CPU-resident LoRA state between training steps.""" - -from __future__ import annotations - -from dataclasses import dataclass - -import torch - -from claas.core.types import DistillResponse - - -@dataclass(frozen=True, slots=True) -class LoraAdapterConfig: - """Typed representation of LoRA adapter configuration.""" - - r: int - lora_alpha: int - target_modules: list[str] - lora_dropout: float - bias: str - task_type: str - - -@dataclass(frozen=True, slots=True) -class LoraCacheEntry: - """CPU-resident snapshot of LoRA adapter state between training steps.""" - - lora_state_dict: dict[str, torch.Tensor] - optimizer_state_dict: dict[str, object] - adapter_config: LoraAdapterConfig - - -@dataclass(frozen=True, slots=True) -class DistillStepResult: - """Result of a distillation step with both response and cache entry.""" - - response: DistillResponse - cache_entry: LoraCacheEntry diff --git a/claas/training/distillation.py b/claas/training/distillation.py index 4875a2a..3eeae96 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -2,7 +2,6 @@ from __future__ import annotations -import copy import json import logging import os @@ -13,10 +12,12 @@ import torch from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput -from claas.training.cache import ( +from claas.training.engine.local.cache import ( DistillStepResult, LoraAdapterConfig, LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, ) from claas.training.sdpo_loss import compute_sdpo_loss from claas.training.storage import ( @@ -52,50 +53,6 @@ class PreparedSample(TypedDict): behavior_logprobs: torch.Tensor -def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: - """Deep-copy optimizer state with all tensors moved to CPU.""" - result: dict[str, object] = {} - for key, value in state_dict.items(): - if key == "state": - param_states = cast("dict[int, dict[str, object]]", value) - cpu_states: dict[int, dict[str, object]] = {} - for param_id, param_state in param_states.items(): - cpu_param: dict[str, object] = {} - for k, v in param_state.items(): - if isinstance(v, torch.Tensor): - cpu_param[k] = v.detach().cpu().clone() - else: - cpu_param[k] = copy.deepcopy(v) - cpu_states[param_id] = cpu_param - result[key] = cpu_states - else: - result[key] = copy.deepcopy(value) - return result - - -def _gpu_optimizer_state( - state_dict: dict[str, object], - device: torch.device, -) -> dict[str, object]: - """Deep-copy optimizer state with all tensors moved to a target device.""" - result: dict[str, object] = {} - for key, value in state_dict.items(): - if key == "state": - param_states = cast("dict[int, dict[str, object]]", value) - gpu_states: dict[int, dict[str, object]] = {} - for param_id, param_state in param_states.items(): - gpu_param: dict[str, object] = {} - for k, v in param_state.items(): - if isinstance(v, torch.Tensor): - gpu_param[k] = v.detach().to(device).clone() - else: - gpu_param[k] = copy.deepcopy(v) - gpu_states[param_id] = gpu_param - result[key] = gpu_states - else: - result[key] = copy.deepcopy(value) - return result - class DistillationTrainer: """Runs one SDPO distillation update using a loaded base model.""" @@ -357,7 +314,7 @@ def _build_cache_entry( raw_state = model.state_dict() lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} - opt_state = _cpu_optimizer_state(optimizer.state_dict()) + opt_state = cpu_optimizer_state(optimizer.state_dict()) return LoraCacheEntry( lora_state_dict=lora_state, @@ -421,7 +378,7 @@ def distill( if cached is not None: optimizer.load_state_dict( - _gpu_optimizer_state(cached.optimizer_state_dict, self.device) + gpu_optimizer_state(cached.optimizer_state_dict, self.device) ) elif lora_local_path is not None: self._load_optimizer_state(lora_local_path, optimizer) diff --git a/claas/training/engine/local/cache.py b/claas/training/engine/local/cache.py new file mode 100644 index 0000000..309e47d --- /dev/null +++ b/claas/training/engine/local/cache.py @@ -0,0 +1,85 @@ +"""Typed cache structures and helpers for CPU-resident LoRA state between training steps.""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import cast + +import torch + +from claas.core.types import DistillResponse + + +@dataclass(frozen=True, slots=True) +class LoraAdapterConfig: + """Typed representation of LoRA adapter configuration.""" + + r: int + lora_alpha: int + target_modules: list[str] + lora_dropout: float + bias: str + task_type: str + + +@dataclass(frozen=True, slots=True) +class LoraCacheEntry: + """CPU-resident snapshot of LoRA adapter state between training steps.""" + + lora_state_dict: dict[str, torch.Tensor] + optimizer_state_dict: dict[str, object] + adapter_config: LoraAdapterConfig + + +@dataclass(frozen=True, slots=True) +class DistillStepResult: + """Result of a distillation step with both response and cache entry.""" + + response: DistillResponse + cache_entry: LoraCacheEntry + + +def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to CPU.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + cpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + cpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + cpu_param[k] = v.detach().cpu().clone() + else: + cpu_param[k] = copy.deepcopy(v) + cpu_states[param_id] = cpu_param + result[key] = cpu_states + else: + result[key] = copy.deepcopy(value) + return result + + +def gpu_optimizer_state( + state_dict: dict[str, object], + device: torch.device, +) -> dict[str, object]: + """Deep-copy optimizer state with all tensors moved to a target device.""" + result: dict[str, object] = {} + for key, value in state_dict.items(): + if key == "state": + param_states = cast("dict[int, dict[str, object]]", value) + gpu_states: dict[int, dict[str, object]] = {} + for param_id, param_state in param_states.items(): + gpu_param: dict[str, object] = {} + for k, v in param_state.items(): + if isinstance(v, torch.Tensor): + gpu_param[k] = v.detach().to(device).clone() + else: + gpu_param[k] = copy.deepcopy(v) + gpu_states[param_id] = gpu_param + result[key] = gpu_states + else: + result[key] = copy.deepcopy(value) + return result diff --git a/claas/training/engine/local/engine.py b/claas/training/engine/local/engine.py index 6379eb0..4f850ed 100644 --- a/claas/training/engine/local/engine.py +++ b/claas/training/engine/local/engine.py @@ -20,9 +20,9 @@ LoraRuntimeRef, ServiceHealth, ) -from claas.training.cache import LoraCacheEntry from claas.training.distillation import DistillationTrainer from claas.training.engine.base import TrainingEngine +from claas.training.engine.local.cache import LoraCacheEntry from claas.training.storage import ( configure_storage_backend, create_initial_lora, diff --git a/tests/test_distillation_optimizer_state.py b/tests/test_distillation_optimizer_state.py index c211570..cdeebd0 100644 --- a/tests/test_distillation_optimizer_state.py +++ b/tests/test_distillation_optimizer_state.py @@ -6,11 +6,12 @@ torch = pytest.importorskip("torch") -from claas.training.cache import LoraAdapterConfig, LoraCacheEntry # noqa: E402 -from claas.training.distillation import ( # noqa: E402 - DistillationTrainer, - _cpu_optimizer_state, - _gpu_optimizer_state, +from claas.training.distillation import DistillationTrainer # noqa: E402 +from claas.training.engine.local.cache import ( # noqa: E402 + LoraAdapterConfig, + LoraCacheEntry, + cpu_optimizer_state, + gpu_optimizer_state, ) @@ -97,8 +98,8 @@ def test_optimizer_state_missing_gracefully_skips(trainer: DistillationTrainer, assert len(optimizer.state) == 0 -def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None: - """_cpu_optimizer_state produces a state dict with all tensors on CPU.""" +def testcpu_optimizer_state_moves_tensors_to_cpu() -> None: + """cpu_optimizer_state produces a state dict with all tensors on CPU.""" model = _SimpleLoraModel() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) loss = model.first.sum() @@ -106,7 +107,7 @@ def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None: optimizer.step() original = optimizer.state_dict() - cpu_state = _cpu_optimizer_state(original) + cpu_state = cpu_optimizer_state(original) for param_state in cpu_state["state"].values(): for v in param_state.values(): @@ -114,8 +115,8 @@ def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None: assert v.device == torch.device("cpu") -def test_cpu_gpu_optimizer_state_roundtrip() -> None: - """_cpu_optimizer_state / _gpu_optimizer_state round-trip preserves values.""" +def test_cpugpu_optimizer_state_roundtrip() -> None: + """cpu_optimizer_state / gpu_optimizer_state round-trip preserves values.""" model = _SimpleLoraModel() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) loss = model.first.sum() @@ -123,8 +124,8 @@ def test_cpu_gpu_optimizer_state_roundtrip() -> None: optimizer.step() original = optimizer.state_dict() - cpu_state = _cpu_optimizer_state(original) - roundtripped = _gpu_optimizer_state(cpu_state, torch.device("cpu")) + cpu_state = cpu_optimizer_state(original) + roundtripped = gpu_optimizer_state(cpu_state, torch.device("cpu")) # Step counts match for param_id in original["state"]: @@ -138,8 +139,8 @@ def test_cpu_gpu_optimizer_state_roundtrip() -> None: assert torch.equal(orig_tensor, rt_tensor) -def test_cpu_optimizer_state_does_not_mutate_original() -> None: - """_cpu_optimizer_state deep-copies — mutating the copy leaves the original intact.""" +def testcpu_optimizer_state_does_not_mutate_original() -> None: + """cpu_optimizer_state deep-copies — mutating the copy leaves the original intact.""" model = _SimpleLoraModel() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) loss = model.first.sum() @@ -149,7 +150,7 @@ def test_cpu_optimizer_state_does_not_mutate_original() -> None: original = optimizer.state_dict() original_exp_avg = original["state"][0]["exp_avg"].clone() - cpu_state = _cpu_optimizer_state(original) + cpu_state = cpu_optimizer_state(original) # Mutate the copy cpu_state["state"][0]["exp_avg"].zero_() diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index b7fba45..8648fc5 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -13,7 +13,7 @@ DistillResponse, TrainingConfig, ) -from claas.training.cache import ( # noqa: E402 +from claas.training.engine.local.cache import ( # noqa: E402 DistillStepResult, LoraAdapterConfig, LoraCacheEntry, From ecf921bee91295786284929dbd089f6840dc598d Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 16:44:32 -0800 Subject: [PATCH 3/8] Move cache types into separate types.py module Addresses review feedback: split dataclass types from helper functions in cache.py for clearer module organization. Co-Authored-By: Claude Opus 4.6 --- claas/training/engine/local/cache.py | 39 +++++----------------------- claas/training/engine/local/types.py | 38 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 32 deletions(-) create mode 100644 claas/training/engine/local/types.py diff --git a/claas/training/engine/local/cache.py b/claas/training/engine/local/cache.py index 309e47d..44ca4e9 100644 --- a/claas/training/engine/local/cache.py +++ b/claas/training/engine/local/cache.py @@ -1,43 +1,18 @@ -"""Typed cache structures and helpers for CPU-resident LoRA state between training steps.""" +"""CPU-resident optimizer state helpers for the local training engine.""" from __future__ import annotations import copy -from dataclasses import dataclass from typing import cast import torch -from claas.core.types import DistillResponse - - -@dataclass(frozen=True, slots=True) -class LoraAdapterConfig: - """Typed representation of LoRA adapter configuration.""" - - r: int - lora_alpha: int - target_modules: list[str] - lora_dropout: float - bias: str - task_type: str - - -@dataclass(frozen=True, slots=True) -class LoraCacheEntry: - """CPU-resident snapshot of LoRA adapter state between training steps.""" - - lora_state_dict: dict[str, torch.Tensor] - optimizer_state_dict: dict[str, object] - adapter_config: LoraAdapterConfig - - -@dataclass(frozen=True, slots=True) -class DistillStepResult: - """Result of a distillation step with both response and cache entry.""" - - response: DistillResponse - cache_entry: LoraCacheEntry +# Re-export types for backward compatibility +from claas.training.engine.local.types import ( # noqa: F401 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: diff --git a/claas/training/engine/local/types.py b/claas/training/engine/local/types.py new file mode 100644 index 0000000..6b606d9 --- /dev/null +++ b/claas/training/engine/local/types.py @@ -0,0 +1,38 @@ +"""Typed structures for the local training engine's CPU-resident LoRA cache.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from claas.core.types import DistillResponse + + +@dataclass(frozen=True, slots=True) +class LoraAdapterConfig: + """Typed representation of LoRA adapter configuration.""" + + r: int + lora_alpha: int + target_modules: list[str] + lora_dropout: float + bias: str + task_type: str + + +@dataclass(frozen=True, slots=True) +class LoraCacheEntry: + """CPU-resident snapshot of LoRA adapter state between training steps.""" + + lora_state_dict: dict[str, torch.Tensor] + optimizer_state_dict: dict[str, object] + adapter_config: LoraAdapterConfig + + +@dataclass(frozen=True, slots=True) +class DistillStepResult: + """Result of a distillation step with both response and cache entry.""" + + response: DistillResponse + cache_entry: LoraCacheEntry From a125d2396617d6f5f12bc28206ff6ecdaf44d67a Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 16:44:37 -0800 Subject: [PATCH 4/8] Raise TypeError instead of silent fallback in _build_cache_entry If the model is not a PeftModel, raise an explicit error instead of silently falling back to full state_dict extraction. Also coerces task_type enum to string for type consistency. Co-Authored-By: Claude Opus 4.6 --- claas/training/distillation.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/claas/training/distillation.py b/claas/training/distillation.py index 3eeae96..f848ee5 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -302,16 +302,17 @@ def _build_cache_entry( target_modules=list(peft_config.target_modules), lora_dropout=peft_config.lora_dropout, bias=peft_config.bias, - task_type=peft_config.task_type, + task_type=str(peft_config.task_type), ) - # Determine state dict — use PEFT's adapter-only extraction if available - if isinstance(model, PeftModelCls): - from peft import get_peft_model_state_dict + # Extract adapter-only state dict via PEFT + if not isinstance(model, PeftModelCls): + raise TypeError( + f"Expected a PeftModel for cache entry construction, got {type(model).__name__}" + ) + from peft import get_peft_model_state_dict - raw_state = get_peft_model_state_dict(model) - else: - raw_state = model.state_dict() + raw_state = get_peft_model_state_dict(model) lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} opt_state = cpu_optimizer_state(optimizer.state_dict()) From bbb9a496929d672b2f2e625703006e008a83fcb6 Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 16:44:42 -0800 Subject: [PATCH 5/8] Fix race in _ensure_model_loaded and cache key mismatch Add asyncio.Lock to prevent concurrent model loads. Cache distill results under the output lora_id instead of the input to handle non-in-place saves correctly. Co-Authored-By: Claude Opus 4.6 --- claas/training/engine/local/engine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/claas/training/engine/local/engine.py b/claas/training/engine/local/engine.py index 4f850ed..4eb302b 100644 --- a/claas/training/engine/local/engine.py +++ b/claas/training/engine/local/engine.py @@ -56,12 +56,14 @@ def __init__(self, cfg: LocalConfig) -> None: self._lora_cache = {} self._cache_lock = threading.Lock() self._model_loaded = False + self._load_lock = asyncio.Lock() async def _ensure_model_loaded(self) -> None: """One-time base model load on first distill() call.""" - if not self._model_loaded: - await asyncio.to_thread(self._trainer.load_base_model) - self._model_loaded = True + async with self._load_lock: + if not self._model_loaded: + await asyncio.to_thread(self._trainer.load_base_model) + self._model_loaded = True async def distill( self, @@ -89,8 +91,9 @@ async def distill( finally: await asyncio.to_thread(self._trainer.offload_base_model) + new_resolved = await asyncio.to_thread(resolve_lora_id, result.response.lora_id) with self._cache_lock: - self._lora_cache[resolved_id] = result.cache_entry + self._lora_cache[new_resolved] = result.cache_entry return result.response From 79fe5130f3ed1a6cf274730a2a9689fe1bad810b Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 17:01:26 -0800 Subject: [PATCH 6/8] Fix CI: type errors and missing system_prompt in tests - Cast peft_config to LoraConfig instead of PeftConfig base class - Add type: ignore for peft Literal bias type and dict[str, object] subscripts - Add missing system_prompt field to test payloads - Update integration test stubs for new distill() signature Co-Authored-By: Claude Opus 4.6 --- claas/training/distillation.py | 8 +++++--- .../test_local_engine_integration.py | 18 ++++++++++++++++-- tests/test_distillation_optimizer_state.py | 8 ++++---- tests/test_local_training_engine.py | 1 + 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/claas/training/distillation.py b/claas/training/distillation.py index f848ee5..ef9a595 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -173,7 +173,7 @@ def _load_lora_from_cache( lora_alpha=cfg.lora_alpha, target_modules=cfg.target_modules, lora_dropout=cfg.lora_dropout, - bias=cfg.bias, + bias=cfg.bias, # type: ignore[arg-type] # peft Literal vs str task_type=cfg.task_type, ) model = get_peft_model(self.base_model, lora_config) @@ -295,11 +295,13 @@ def _build_cache_entry( """Snapshot current model + optimizer state into a CPU-resident cache entry.""" from peft import PeftModel as PeftModelCls - peft_config = model.peft_config["default"] + from peft import LoraConfig + + peft_config = cast(LoraConfig, model.peft_config["default"]) adapter_config = LoraAdapterConfig( r=peft_config.r, lora_alpha=peft_config.lora_alpha, - target_modules=list(peft_config.target_modules), + target_modules=list(peft_config.target_modules or []), lora_dropout=peft_config.lora_dropout, bias=peft_config.bias, task_type=str(peft_config.task_type), diff --git a/tests/integration/test_local_engine_integration.py b/tests/integration/test_local_engine_integration.py index 06e945c..1398994 100644 --- a/tests/integration/test_local_engine_integration.py +++ b/tests/integration/test_local_engine_integration.py @@ -18,6 +18,11 @@ TrainingConfig, ) from claas.training import storage # noqa: E402 +from claas.training.engine.local.cache import ( # noqa: E402 + DistillStepResult, + LoraAdapterConfig, + LoraCacheEntry, +) from claas.training.engine.local.engine import LocalTrainingEngine # noqa: E402 @@ -41,11 +46,20 @@ def __init__(self, base_model_id: str, attn_implementation: str, state: TrainerS def load_base_model(self) -> None: self._state.loaded = True - def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: + def reload_base_model(self) -> None: + pass + + def distill(self, payload: DistillBatchRequestPayload, *, cached: object = None) -> DistillStepResult: self._state.payload = payload - return DistillResponse.model_validate( + response = DistillResponse.model_validate( {"lora_id": payload.lora_id, "metadata": {"tokens_processed": 5}} ) + cache_entry = LoraCacheEntry( + lora_state_dict={}, + optimizer_state_dict={}, + adapter_config=LoraAdapterConfig(r=8, lora_alpha=16, target_modules=[], lora_dropout=0.0, bias="none", task_type="CAUSAL_LM"), + ) + return DistillStepResult(response=response, cache_entry=cache_entry) def offload_base_model(self) -> None: self._state.cleaned_up = True diff --git a/tests/test_distillation_optimizer_state.py b/tests/test_distillation_optimizer_state.py index cdeebd0..bec05dd 100644 --- a/tests/test_distillation_optimizer_state.py +++ b/tests/test_distillation_optimizer_state.py @@ -109,7 +109,7 @@ def testcpu_optimizer_state_moves_tensors_to_cpu() -> None: original = optimizer.state_dict() cpu_state = cpu_optimizer_state(original) - for param_state in cpu_state["state"].values(): + for param_state in cpu_state["state"].values(): # type: ignore[union-attr] for v in param_state.values(): if isinstance(v, torch.Tensor): assert v.device == torch.device("cpu") @@ -129,13 +129,13 @@ def test_cpugpu_optimizer_state_roundtrip() -> None: # Step counts match for param_id in original["state"]: - assert roundtripped["state"][param_id]["step"] == original["state"][param_id]["step"] + assert roundtripped["state"][param_id]["step"] == original["state"][param_id]["step"] # type: ignore[index] # Tensor values match for param_id in original["state"]: for key in ("exp_avg", "exp_avg_sq"): orig_tensor = original["state"][param_id][key] - rt_tensor = roundtripped["state"][param_id][key] + rt_tensor = roundtripped["state"][param_id][key] # type: ignore[index] assert torch.equal(orig_tensor, rt_tensor) @@ -152,7 +152,7 @@ def testcpu_optimizer_state_does_not_mutate_original() -> None: cpu_state = cpu_optimizer_state(original) # Mutate the copy - cpu_state["state"][0]["exp_avg"].zero_() + cpu_state["state"][0]["exp_avg"].zero_() # type: ignore[index] # Original is unchanged assert torch.equal(original["state"][0]["exp_avg"], original_exp_avg) diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index 8648fc5..efff14b 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -47,6 +47,7 @@ def _make_payload(lora_id: str = "user/model") -> DistillBatchRequestPayload: prompt_token_ids=[1, 2], response_token_ids=[3], user_prompt="p", + system_prompt="You are a helpful assistant.", ) ], ) From 835b46da81df3c208f894c16c36072e747e0127e Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 17:08:45 -0800 Subject: [PATCH 7/8] Add type safety and code style guidelines to CLAUDE.md Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 8e60258..1422a3b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,6 +8,18 @@ **Treat errors as signals, not obstacles.** An unexpected value or failed assertion means something upstream is wrong. Trace the data flow end-to-end: where was this value produced? What configuration or state fed into it? The goal is durable, correct software — not silencing errors. +## Type Safety + +- Avoid `Any` in Python — prefer `cast()` with the correct type when the type checker can't infer it (e.g. GPU library return types). Use typed dicts, dataclasses, or Pydantic models over bare `dict` without type parameters. +- Data should have the correct shape from the moment it enters the system. If you need a `_normalize_*` or `_make_serializable` function, the upstream data model is wrong — fix the source instead. + +## Code Style + +- Catch specific exception types (`except httpx.HTTPError:`, `except json.JSONDecodeError:`), not bare `except Exception:` or `except:`. +- Keep imports at module top level. Exception: GPU dependencies (torch, peft, vllm, transformers) that aren't installed locally use deferred imports inside functions. +- Don't guard against impossible states — avoid redundant `None` checks, `hasattr()` duck-typing, or `isinstance()` checks when the type is already known from the signature or upstream logic. +- Every public function, class, and module should have a docstring. Keep them concise — one line for simple functions, a short paragraph for complex ones. If you touch a function that lacks one, add it. + ## Code Quality Rules ### After Every Change From 7f18d52f71ba728e9c4184249b4c645b03c80210 Mon Sep 17 00:00:00 2001 From: Kion Date: Fri, 6 Mar 2026 17:43:10 -0800 Subject: [PATCH 8/8] Fix ruff I001: merge peft imports in _build_cache_entry Co-Authored-By: Claude Opus 4.6 --- claas/training/distillation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/claas/training/distillation.py b/claas/training/distillation.py index ef9a595..c854194 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -293,9 +293,7 @@ def _build_cache_entry( optimizer: "torch.optim.Optimizer", ) -> LoraCacheEntry: """Snapshot current model + optimizer state into a CPU-resident cache entry.""" - from peft import PeftModel as PeftModelCls - - from peft import LoraConfig + from peft import LoraConfig, PeftModel as PeftModelCls peft_config = cast(LoraConfig, model.peft_config["default"]) adapter_config = LoraAdapterConfig(