Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion swift/callbacks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .deepspeed_elastic import DeepspeedElasticCallback, GracefulExitCallback
from .early_stop import EarlyStopCallback
from .lisa import LISACallback
from .mem_snapshot import MemorySnapshotCallback
from .perf_log import PerfMetricsLogCallback

callbacks_map = {
Expand All @@ -13,5 +14,6 @@
'early_stop': EarlyStopCallback,
'graceful_exit': GracefulExitCallback,
'lisa': LISACallback,
'perf_log': PerfMetricsLogCallback
'perf_log': PerfMetricsLogCallback,
'mem_snapshot': MemorySnapshotCallback
}
55 changes: 55 additions & 0 deletions swift/callbacks/mem_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
import os
import torch

if TYPE_CHECKING:
from .base import Trainer, TrainingArguments

from swift import TrainerCallback, get_logger

logger = get_logger()


class MemorySnapshotCallback(TrainerCallback):
"""
Record CUDA memory history and dump snapshot with specified interval steps.
"""

def __init__(self, args: 'TrainingArguments', trainer: 'Trainer'):
super().__init__(args, trainer)
self.dump_interval = args.mem_snapshot_interval
self.dump_path = args.mem_snapshot_path
self._recording = False

def _dump_and_visualize(self, step: int, tag: str = ''):
rank = int(os.environ.get("RANK", 0))
raw = f'snapshot_step{step}_rank{rank}'
pickle_path = os.path.join(self.dump_path, f'{raw}.pickle')
html_path = os.path.join(self.dump_path, f'{raw}.html')

snapshot = torch.cuda.memory._snapshot()
os.makedirs(os.path.dirname(os.path.abspath(pickle_path)), exist_ok=True)
torch.cuda.memory._dump_snapshot(pickle_path)
Comment thread
mugglewei97 marked this conversation as resolved.
logger.info(f"{tag}CUDA memory snapshot dumped: {pickle_path}")

from torch.cuda._memory_viz import trace_plot
html_content = trace_plot(snapshot)
with open(html_path, 'w') as f:
f.write(html_content)
logger.info(f"{tag}CUDA memory html visualization saved: {html_path}")

def on_train_begin(self, args, state, control, **kwargs):
if torch.cuda.is_available():
torch.cuda.memory._record_memory_history(max_entries=100000)
self._recording = True
logger.info("CUDA memory history recording started")

def on_step_end(self, args, state, control, **kwargs):
if self._recording and self.dump_interval and state.global_step % self.dump_interval == 0:
self._dump_and_visualize(state.global_step, tag=f"step:{state.global_step}, ")

def on_train_end(self, args, state, control, **kwargs):
if self._recording:
self._dump_and_visualize(state.global_step, tag="[final] ")
torch.cuda.memory._record_memory_history(enabled=None)
self._recording = False
4 changes: 4 additions & 0 deletions swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class TrainArgumentsMixin:
use_logits_to_keep: Optional[bool] = None
ds3_gather_for_generation: bool = True
resume_only_model: bool = False
mem_snapshot_path: str = None
mem_snapshot_interval: int = None

# plugins
optimizer: Optional[str] = None
Expand Down Expand Up @@ -234,6 +236,8 @@ def _init_callbacks(self):
fsdp_config = getattr(self, 'fsdp_config', {})
if isinstance(fsdp_config, dict) and fsdp_config.get('activation_cpu_offload', False):
self.callbacks.append('activation_cpu_offload')
if self.mem_snapshot_path is not None and self.mem_snapshot_interval is not None and self.mem_snapshot_interval > 0:
self.callbacks.append('mem_snapshot')

def __post_init__(self):
if hasattr(self, 'output_dir'):
Expand Down