|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import copy |
6 | 5 | import json |
7 | 6 | import logging |
8 | 7 | import os |
|
13 | 12 | import torch |
14 | 13 |
|
15 | 14 | from claas.core.types import DistillBatchRequestPayload, DistillResponse, SDPOLossInput |
16 | | -from claas.training.cache import ( |
| 15 | +from claas.training.engine.local.cache import ( |
17 | 16 | DistillStepResult, |
18 | 17 | LoraAdapterConfig, |
19 | 18 | LoraCacheEntry, |
| 19 | + cpu_optimizer_state, |
| 20 | + gpu_optimizer_state, |
20 | 21 | ) |
21 | 22 | from claas.training.sdpo_loss import compute_sdpo_loss |
22 | 23 | from claas.training.storage import ( |
@@ -52,50 +53,6 @@ class PreparedSample(TypedDict): |
52 | 53 | behavior_logprobs: torch.Tensor |
53 | 54 |
|
54 | 55 |
|
55 | | -def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]: |
56 | | - """Deep-copy optimizer state with all tensors moved to CPU.""" |
57 | | - result: dict[str, object] = {} |
58 | | - for key, value in state_dict.items(): |
59 | | - if key == "state": |
60 | | - param_states = cast("dict[int, dict[str, object]]", value) |
61 | | - cpu_states: dict[int, dict[str, object]] = {} |
62 | | - for param_id, param_state in param_states.items(): |
63 | | - cpu_param: dict[str, object] = {} |
64 | | - for k, v in param_state.items(): |
65 | | - if isinstance(v, torch.Tensor): |
66 | | - cpu_param[k] = v.detach().cpu().clone() |
67 | | - else: |
68 | | - cpu_param[k] = copy.deepcopy(v) |
69 | | - cpu_states[param_id] = cpu_param |
70 | | - result[key] = cpu_states |
71 | | - else: |
72 | | - result[key] = copy.deepcopy(value) |
73 | | - return result |
74 | | - |
75 | | - |
76 | | -def _gpu_optimizer_state( |
77 | | - state_dict: dict[str, object], |
78 | | - device: torch.device, |
79 | | -) -> dict[str, object]: |
80 | | - """Deep-copy optimizer state with all tensors moved to a target device.""" |
81 | | - result: dict[str, object] = {} |
82 | | - for key, value in state_dict.items(): |
83 | | - if key == "state": |
84 | | - param_states = cast("dict[int, dict[str, object]]", value) |
85 | | - gpu_states: dict[int, dict[str, object]] = {} |
86 | | - for param_id, param_state in param_states.items(): |
87 | | - gpu_param: dict[str, object] = {} |
88 | | - for k, v in param_state.items(): |
89 | | - if isinstance(v, torch.Tensor): |
90 | | - gpu_param[k] = v.detach().to(device).clone() |
91 | | - else: |
92 | | - gpu_param[k] = copy.deepcopy(v) |
93 | | - gpu_states[param_id] = gpu_param |
94 | | - result[key] = gpu_states |
95 | | - else: |
96 | | - result[key] = copy.deepcopy(value) |
97 | | - return result |
98 | | - |
99 | 56 |
|
100 | 57 | class DistillationTrainer: |
101 | 58 | """Runs one SDPO distillation update using a loaded base model.""" |
@@ -357,7 +314,7 @@ def _build_cache_entry( |
357 | 314 | raw_state = model.state_dict() |
358 | 315 |
|
359 | 316 | lora_state = {k: v.detach().cpu().clone() for k, v in raw_state.items()} |
360 | | - opt_state = _cpu_optimizer_state(optimizer.state_dict()) |
| 317 | + opt_state = cpu_optimizer_state(optimizer.state_dict()) |
361 | 318 |
|
362 | 319 | return LoraCacheEntry( |
363 | 320 | lora_state_dict=lora_state, |
@@ -421,7 +378,7 @@ def distill( |
421 | 378 |
|
422 | 379 | if cached is not None: |
423 | 380 | optimizer.load_state_dict( |
424 | | - _gpu_optimizer_state(cached.optimizer_state_dict, self.device) |
| 381 | + gpu_optimizer_state(cached.optimizer_state_dict, self.device) |
425 | 382 | ) |
426 | 383 | elif lora_local_path is not None: |
427 | 384 | self._load_optimizer_state(lora_local_path, optimizer) |
|
0 commit comments