
Commit 166a16d

Kion and claude committed
Move cache types and optimizer state helpers into local engine
Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the cpu/gpu_optimizer_state helpers from the shared training module into claas/training/engine/local/cache.py, since they are only used by the local engine's CPU caching path. The Modal worker never uses caching.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6607a2c commit 166a16d
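For context, the caching path this commit isolates works roughly as follows: after a training step the local engine snapshots the LoRA weights and optimizer state onto CPU, and before the next step it moves them back onto the training device. A minimal sketch of that flow; the snapshot/restore wrapper names and the generic model/optimizer arguments are hypothetical, while the imported names and the clone/load calls mirror _build_cache_entry and distill in the diff below:

import torch

from claas.training.engine.local.cache import (
    LoraCacheEntry,
    cpu_optimizer_state,
    gpu_optimizer_state,
)


def snapshot(model, optimizer, adapter_config) -> LoraCacheEntry:
    # Park adapter weights and optimizer state on CPU between steps.
    lora_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    return LoraCacheEntry(
        lora_state_dict=lora_state,
        optimizer_state_dict=cpu_optimizer_state(optimizer.state_dict()),
        adapter_config=adapter_config,
    )


def restore(entry: LoraCacheEntry, model, optimizer, device: torch.device) -> None:
    # Rehydrate the snapshot onto the training device for the next step.
    model.load_state_dict(entry.lora_state_dict)
    optimizer.load_state_dict(gpu_optimizer_state(entry.optimizer_state_dict, device))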

6 files changed

Lines changed: 108 additions & 103 deletions


claas/training/cache.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

claas/training/distillation.py

Lines changed: 5 additions & 48 deletions
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-import copy
 import json
 import logging
 import os
@@ -13,10 +12,12 @@
 import torch
 
 from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput
-from claas.training.cache import (
+from claas.training.engine.local.cache import (
     DistillStepResult,
     LoraAdapterConfig,
     LoraCacheEntry,
+    cpu_optimizer_state,
+    gpu_optimizer_state,
 )
 from claas.training.sdpo_loss import compute_sdpo_loss
 from claas.training.storage import (
@@ -52,50 +53,6 @@ class PreparedSample(TypedDict):
     behavior_logprobs: torch.Tensor
 
 
-def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to CPU."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            cpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                cpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        cpu_param[k] = v.detach().cpu().clone()
-                    else:
-                        cpu_param[k] = copy.deepcopy(v)
-                cpu_states[param_id] = cpu_param
-            result[key] = cpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
-
-def _gpu_optimizer_state(
-    state_dict: dict[str, object],
-    device: torch.device,
-) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to a target device."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            gpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                gpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        gpu_param[k] = v.detach().to(device).clone()
-                    else:
-                        gpu_param[k] = copy.deepcopy(v)
-                gpu_states[param_id] = gpu_param
-            result[key] = gpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
 
 class DistillationTrainer:
     """Runs one SDPO distillation update using a loaded base model."""
@@ -357,7 +314,7 @@ def _build_cache_entry(
         raw_state = model.state_dict()
 
         lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()}
-        opt_state = _cpu_optimizer_state(optimizer.state_dict())
+        opt_state = cpu_optimizer_state(optimizer.state_dict())
 
         return LoraCacheEntry(
             lora_state_dict=lora_state,
@@ -421,7 +378,7 @@ def distill(
 
         if cached is not None:
            optimizer.load_state_dict(
-                _gpu_optimizer_state(cached.optimizer_state_dict, self.device)
+                gpu_optimizer_state(cached.optimizer_state_dict, self.device)
            )
        elif lora_local_path is not None:
            self._load_optimizer_state(lora_local_path, optimizer)
claas/training/engine/local/cache.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+"""Typed cache structures and helpers for CPU-resident LoRA state between training steps."""
+
+from __future__ import annotations
+
+import copy
+from dataclasses import dataclass
+from typing import cast
+
+import torch
+
+from claas.core.types import DistillResponse
+
+
+@dataclass(frozen=True, slots=True)
+class LoraAdapterConfig:
+    """Typed representation of LoRA adapter configuration."""
+
+    r: int
+    lora_alpha: int
+    target_modules: list[str]
+    lora_dropout: float
+    bias: str
+    task_type: str
+
+
+@dataclass(frozen=True, slots=True)
+class LoraCacheEntry:
+    """CPU-resident snapshot of LoRA adapter state between training steps."""
+
+    lora_state_dict: dict[str, torch.Tensor]
+    optimizer_state_dict: dict[str, object]
+    adapter_config: LoraAdapterConfig
+
+
+@dataclass(frozen=True, slots=True)
+class DistillStepResult:
+    """Result of a distillation step with both response and cache entry."""
+
+    response: DistillResponse
+    cache_entry: LoraCacheEntry
+
+
+def cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to CPU."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            cpu_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                cpu_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        cpu_param[k] = v.detach().cpu().clone()
+                    else:
+                        cpu_param[k] = copy.deepcopy(v)
+                cpu_states[param_id] = cpu_param
+            result[key] = cpu_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
+
+
+def gpu_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to a target device."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            gpu_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                gpu_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        gpu_param[k] = v.detach().to(device).clone()
+                    else:
+                        gpu_param[k] = copy.deepcopy(v)
+                gpu_states[param_id] = gpu_param
+            result[key] = gpu_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
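
A quick roundtrip through the two helpers above, as a minimal sketch: the linear model stands in for a LoRA adapter, and CPU is used as the target device so the snippet runs anywhere. Only the two imported helpers come from this commit; everything else is plain PyTorch.

import torch

from claas.training.engine.local.cache import cpu_optimizer_state, gpu_optimizer_state

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populates exp_avg / exp_avg_sq state

# Park the optimizer state on CPU, then rehydrate it for a target device.
cpu_state = cpu_optimizer_state(optimizer.state_dict())
restored = gpu_optimizer_state(cpu_state, torch.device("cpu"))
optimizer.load_state_dict(restored)

# Both helpers deep-copy, so mutating the snapshot leaves the live state intact.
cpu_state["state"][0]["exp_avg"].zero_()
assert optimizer.state_dict()["state"][0]["exp_avg"].abs().sum() > 0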

claas/training/engine/local/engine.py

Lines changed: 1 addition & 1 deletion
@@ -20,9 +20,9 @@
     LoraRuntimeRef,
     ServiceHealth,
 )
-from claas.training.cache import LoraCacheEntry
 from claas.training.distillation import DistillationTrainer
 from claas.training.engine.base import TrainingEngine
+from claas.training.engine.local.cache import LoraCacheEntry
 from claas.training.storage import (
     configure_storage_backend,
     create_initial_lora,

tests/test_distillation_optimizer_state.py

Lines changed: 16 additions & 15 deletions
@@ -6,11 +6,12 @@
 
 torch = pytest.importorskip("torch")
 
-from claas.training.cache import LoraAdapterConfig, LoraCacheEntry  # noqa: E402
-from claas.training.distillation import (  # noqa: E402
-    DistillationTrainer,
-    _cpu_optimizer_state,
-    _gpu_optimizer_state,
+from claas.training.distillation import DistillationTrainer  # noqa: E402
+from claas.training.engine.local.cache import (  # noqa: E402
+    LoraAdapterConfig,
+    LoraCacheEntry,
+    cpu_optimizer_state,
+    gpu_optimizer_state,
 )
 
 
@@ -97,34 +98,34 @@ def test_optimizer_state_missing_gracefully_skips(trainer: DistillationTrainer,
     assert len(optimizer.state) == 0
 
 
 def test_cpu_optimizer_state_moves_tensors_to_cpu() -> None:
-    """_cpu_optimizer_state produces a state dict with all tensors on CPU."""
+    """cpu_optimizer_state produces a state dict with all tensors on CPU."""
     model = _SimpleLoraModel()
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
     loss = model.first.sum()
     loss.backward()
     optimizer.step()
 
     original = optimizer.state_dict()
-    cpu_state = _cpu_optimizer_state(original)
+    cpu_state = cpu_optimizer_state(original)
 
     for param_state in cpu_state["state"].values():
         for v in param_state.values():
             if isinstance(v, torch.Tensor):
                 assert v.device == torch.device("cpu")
 
 
 def test_cpu_gpu_optimizer_state_roundtrip() -> None:
-    """_cpu_optimizer_state / _gpu_optimizer_state round-trip preserves values."""
+    """cpu_optimizer_state / gpu_optimizer_state round-trip preserves values."""
     model = _SimpleLoraModel()
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
     loss = model.first.sum()
     loss.backward()
     optimizer.step()
 
     original = optimizer.state_dict()
-    cpu_state = _cpu_optimizer_state(original)
-    roundtripped = _gpu_optimizer_state(cpu_state, torch.device("cpu"))
+    cpu_state = cpu_optimizer_state(original)
+    roundtripped = gpu_optimizer_state(cpu_state, torch.device("cpu"))
 
     # Step counts match
     for param_id in original["state"]:
@@ -138,8 +139,8 @@ def test_cpu_gpu_optimizer_state_roundtrip() -> None:
                 assert torch.equal(orig_tensor, rt_tensor)
 
 
 def test_cpu_optimizer_state_does_not_mutate_original() -> None:
-    """_cpu_optimizer_state deep-copies — mutating the copy leaves the original intact."""
+    """cpu_optimizer_state deep-copies — mutating the copy leaves the original intact."""
     model = _SimpleLoraModel()
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
     loss = model.first.sum()
@@ -149,7 +150,7 @@ def test_cpu_optimizer_state_does_not_mutate_original() -> None:
     original = optimizer.state_dict()
     original_exp_avg = original["state"][0]["exp_avg"].clone()
 
-    cpu_state = _cpu_optimizer_state(original)
+    cpu_state = cpu_optimizer_state(original)
     # Mutate the copy
     cpu_state["state"][0]["exp_avg"].zero_()
 
tests/test_local_training_engine.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
     DistillResponse,
     TrainingConfig,
 )
-from claas.training.cache import (  # noqa: E402
+from claas.training.engine.local.cache import (  # noqa: E402
     DistillStepResult,
     LoraAdapterConfig,
     LoraCacheEntry,
