From 57df7105e300bf9976415f295e49b21443fe09f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=80=B8=E7=A3=8A?=
Date: Mon, 25 May 2026 10:46:06 +0800
Subject: [PATCH] add callback: mem_snapshot

---
 swift/callbacks/mapping.py      |  4 ++-
 swift/callbacks/mem_snapshot.py | 84 +++++++++++++++++++++++++++++++++
 swift/trainers/arguments.py     |  4 +++
 3 files changed, 91 insertions(+), 1 deletion(-)
 create mode 100644 swift/callbacks/mem_snapshot.py

diff --git a/swift/callbacks/mapping.py b/swift/callbacks/mapping.py
index 3f18235e79..3ff54cba2b 100644
--- a/swift/callbacks/mapping.py
+++ b/swift/callbacks/mapping.py
@@ -4,6 +4,7 @@
 from .deepspeed_elastic import DeepspeedElasticCallback, GracefulExitCallback
 from .early_stop import EarlyStopCallback
 from .lisa import LISACallback
+from .mem_snapshot import MemorySnapshotCallback
 from .perf_log import PerfMetricsLogCallback
 
 callbacks_map = {
@@ -13,5 +14,6 @@
     'early_stop': EarlyStopCallback,
     'graceful_exit': GracefulExitCallback,
     'lisa': LISACallback,
-    'perf_log': PerfMetricsLogCallback
+    'perf_log': PerfMetricsLogCallback,
+    'mem_snapshot': MemorySnapshotCallback
 }
diff --git a/swift/callbacks/mem_snapshot.py b/swift/callbacks/mem_snapshot.py
new file mode 100644
index 0000000000..451956ed45
--- /dev/null
+++ b/swift/callbacks/mem_snapshot.py
@@ -0,0 +1,84 @@
+import os
+import pickle
+from typing import TYPE_CHECKING
+
+import torch
+
+from swift import TrainerCallback, get_logger
+
+if TYPE_CHECKING:
+    from .base import Trainer, TrainingArguments
+
+logger = get_logger()
+
+
+class MemorySnapshotCallback(TrainerCallback):
+    """Record CUDA memory history and dump snapshots at a fixed step interval.
+
+    Recording starts in ``on_train_begin`` when CUDA is available. Every
+    ``args.mem_snapshot_interval`` steps, and once more in ``on_train_end``,
+    a ``snapshot_step{step}_rank{rank}.pickle`` file and a matching ``.html``
+    trace plot are written under ``args.mem_snapshot_path``.
+    """
+
+    def __init__(self, args: 'TrainingArguments', trainer: 'Trainer'):
+        super().__init__(args, trainer)
+        if args.mem_snapshot_path is None:
+            # Fail fast: otherwise the first dump (possibly only at the end of
+            # training) dies with an opaque TypeError inside os.path.join.
+            raise ValueError('`mem_snapshot_path` must be set to use the `mem_snapshot` callback.')
+        self.dump_interval = args.mem_snapshot_interval
+        self.dump_path = args.mem_snapshot_path
+        self._recording = False
+
+    def _dump_and_visualize(self, step: int, tag: str = ''):
+        """Write the pickle snapshot and its HTML trace plot for ``step``.
+
+        Args:
+            step: Global step embedded in the output file names.
+            tag: Prefix prepended to the log messages.
+        """
+        rank = int(os.environ.get("RANK", 0))
+        raw = f'snapshot_step{step}_rank{rank}'
+        pickle_path = os.path.join(self.dump_path, f'{raw}.pickle')
+        html_path = os.path.join(self.dump_path, f'{raw}.html')
+        os.makedirs(os.path.abspath(self.dump_path), exist_ok=True)
+
+        # Take one snapshot and reuse it for both outputs, so the pickle and
+        # the HTML describe the same state and the history is collected once.
+        snapshot = torch.cuda.memory._snapshot()
+        with open(pickle_path, 'wb') as f:
+            pickle.dump(snapshot, f)
+        logger.info(f"{tag}CUDA memory snapshot dumped: {pickle_path}")
+
+        # Imported only after the pickle is on disk, so a failure to import
+        # the visualizer does not lose the raw snapshot.
+        from torch.cuda._memory_viz import trace_plot
+        with open(html_path, 'w', encoding='utf-8') as f:
+            f.write(trace_plot(snapshot))
+        logger.info(f"{tag}CUDA memory html visualization saved: {html_path}")
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """Start recording CUDA memory history; does nothing without CUDA."""
+        if not torch.cuda.is_available():
+            logger.warning('CUDA is not available, memory snapshot recording is disabled')
+            return
+        torch.cuda.memory._record_memory_history(max_entries=100000)
+        self._recording = True
+        logger.info("CUDA memory history recording started")
+
+    def on_step_end(self, args, state, control, **kwargs):
+        """Dump a snapshot when the global step is a multiple of the interval."""
+        if self._recording and self.dump_interval and state.global_step % self.dump_interval == 0:
+            self._dump_and_visualize(state.global_step, tag=f"step:{state.global_step}, ")
+
+    def on_train_end(self, args, state, control, **kwargs):
+        """Dump a final snapshot and stop recording."""
+        if not self._recording:
+            return
+        try:
+            self._dump_and_visualize(state.global_step, tag="[final] ")
+        finally:
+            # Always stop recording, even if the final dump fails.
+            torch.cuda.memory._record_memory_history(enabled=None)
+            self._recording = False
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index ab253ab593..c3c06c67c2 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -159,6 +159,8 @@ class TrainArgumentsMixin:
     use_logits_to_keep: Optional[bool] = None
     ds3_gather_for_generation: bool = True
     resume_only_model: bool = False
+    mem_snapshot_path: Optional[str] = None
+    mem_snapshot_interval: Optional[int] = None
 
     # plugins
     optimizer: Optional[str] = None
@@ -234,6 +236,8 @@ def _init_callbacks(self):
         fsdp_config = getattr(self, 'fsdp_config', {})
         if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False):
             self.callbacks.append('activation_cpu_offload')
+        if self.mem_snapshot_path is not None and (self.mem_snapshot_interval or 0) > 0:
+            self.callbacks.append('mem_snapshot')
 
     def __post_init__(self):
         if hasattr(self, 'output_dir'):