diff --git a/swift/arguments/base_args/base_args.py b/swift/arguments/base_args/base_args.py index 03697db861..1cf3f03050 100644 --- a/swift/arguments/base_args/base_args.py +++ b/swift/arguments/base_args/base_args.py @@ -17,6 +17,7 @@ from .data_args import DataArguments from .generation_args import GenerationArguments from .model_args import ModelArguments +from .profile_args import ProfilerArguments from .quant_args import QuantizeArguments from .template_args import TemplateArguments @@ -30,12 +31,11 @@ def get_supported_tuners(): @dataclass class BaseArguments(GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments, - RayArguments): + RayArguments, ProfilerArguments): """BaseArguments class is a dataclass that inherits from multiple argument classes. This class consolidates arguments from GenerationArguments, QuantizeArguments, DataArguments, - TemplateArguments, ModelArguments, RayArguments. - + TemplateArguments, ModelArguments, RayArguments, and ProfilerArguments. Args: tuner_backend (str): The tuner backend to use. Choices are 'peft' or 'unsloth'. Default is 'peft'. tuner_type (str): The tuner type. Choices include 'lora', 'full', 'longlora', 'adalora', 'llamapro', @@ -171,6 +171,7 @@ def __post_init__(self): TemplateArguments.__post_init__(self) DataArguments.__post_init__(self) RayArguments.__post_init__(self) + ProfilerArguments.__post_init__(self) self._init_stream() if self.max_length is None and self.model_info is not None: self.max_length = self.model_info.max_model_len diff --git a/swift/arguments/base_args/profile_args.py b/swift/arguments/base_args/profile_args.py new file mode 100644 index 0000000000..e77e457f92 --- /dev/null +++ b/swift/arguments/base_args/profile_args.py @@ -0,0 +1,60 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
@dataclass
class ProfilerArguments:
    """Arguments that configure step-level performance profiling.

    Attributes:
        enable_profiler: Master switch. It is also turned on automatically when
            ``'profiler'`` is listed in ``callbacks``.
        profiler_save_path: Output location; required when profiling is enabled.
        profiler_all_ranks: Profile every rank.
        profiler_ranks: Ranks to profile when ``profiler_all_ranks`` is False.
        profiler_contents: Data categories to record; required when profiling
            is enabled.
        profiler_discrete: Discrete mode. Not supported yet; must stay False.
        profiler_tool: Name of the profiling backend.
        profiler_steps: Training steps to profile; required when profiling is
            enabled.
    """

    enable_profiler: bool = False
    profiler_save_path: Optional[str] = None
    profiler_all_ranks: bool = False
    profiler_ranks: List[int] = field(default_factory=list)
    profiler_contents: List[str] = field(default_factory=list)  # e.g., "cpu", "cuda", "stack", "memory", "shapes"
    profiler_discrete: bool = False
    profiler_tool: Optional[str] = 'torch'
    profiler_steps: Optional[List[int]] = field(default_factory=list)  # Steps to profile

    def __post_init__(self):
        """Reconcile ``enable_profiler`` with ``callbacks`` and validate the options.

        Raises:
            AssertionError: If discrete mode is requested, if profiling is
                enabled on an object without a ``callbacks`` attribute, or if
                profiling is enabled without a save path, contents, steps, or
                a rank selection.
        """
        assert not self.profiler_discrete, \
            'Profiler discrete mode is not supported yet, please set profiler_discrete to false'

        if hasattr(self, 'callbacks'):
            # Keep the flag and the callback list in sync: either one implies the other.
            # After this branch both agree, so no further consistency check is needed.
            if self.enable_profiler and 'profiler' not in self.callbacks:
                self.callbacks.append('profiler')
            elif 'profiler' in self.callbacks:
                self.enable_profiler = True
        else:
            assert not self.enable_profiler, \
                'Profiler cannot be enabled without callbacks attribute or with profiler callback missing in callbacks.'

        if self.enable_profiler:
            assert self.profiler_save_path is not None, \
                'Profiler save path must be specified when profiler is enabled.'
            assert self.profiler_contents, \
                'Profiler contents must be specified when profiler is enabled.'
            assert self.profiler_steps, \
                'Profiler steps must be specified when profiler is enabled.'
            assert self.profiler_ranks != [] or self.profiler_all_ranks, \
                'Either profiler_ranks must be specified or profiler_all_ranks must be set to True.'

    def get_profiler_kwargs(self):
        """Return all profiler options as a plain dict keyed by field name."""
        return {
            'enable_profiler': self.enable_profiler,
            'profiler_save_path': self.profiler_save_path,
            'profiler_all_ranks': self.profiler_all_ranks,
            'profiler_ranks': self.profiler_ranks,
            'profiler_contents': self.profiler_contents,
            'profiler_discrete': self.profiler_discrete,
            'profiler_tool': self.profiler_tool,
            'profiler_steps': self.profiler_steps,
        }
class ProfilerCallback(ProgressCallback):
    """Trainer callback that starts and stops the distributed profiler around selected steps.

    On construction it attaches a ``DistProfiler`` built from ``args`` to
    ``trainer.profiler``. Consecutive steps listed in ``args.profiler_steps``
    are captured as one profiling window.

    NOTE(review): this inherits from ``ProgressCallback`` rather than the plain
    callback base class; confirm that the extra progress-bar behaviour and the
    ``isinstance(cb, ProgressCallback)`` match are intended.
    """

    def __init__(self, args, trainer):
        super().__init__()
        self.args = args
        self.trainer = trainer
        self.trainer.profiler = DistProfiler(global_config=args)
        # True while a start() issued by this callback has not been matched by a stop().
        # Keeping the calls strictly paired matters because the profiler backend counts them.
        self._profiling = False

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Open a profiling window if the step about to run is listed in ``profiler_steps``."""
        steps = self.args.profiler_steps
        if steps and not self._profiling and state.global_step in steps:
            self.trainer.profiler.start()
            self._profiling = True
        super().on_step_begin(args, state, control, **kwargs)

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Close the profiling window unless the next step is profiled too."""
        # The HF Trainer increments `state.global_step` before firing on_step_end, so at this
        # point it already holds the index that the next on_step_begin will see.
        if self._profiling and state.global_step not in self.args.profiler_steps:
            self.trainer.profiler.stop()
            self._profiling = False
        super().on_step_end(args, state, control, **kwargs)

    def on_train_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Close a window still open when training finishes so the trace is written."""
        if self._profiling:
            self.trainer.profiler.stop()
            self._profiling = False
        super().on_train_end(args, state, control, **kwargs)
class ProfilerCallback(MegatronCallback):
    """Megatron callback that starts and stops the distributed profiler around selected steps.

    On construction it attaches a ``DistProfiler`` built from ``trainer.args``
    to ``trainer.profiler``.
    """

    def __init__(self, trainer):
        super().__init__(trainer)
        self.args = trainer.args
        self.trainer = trainer
        self.trainer.profiler = DistProfiler(global_config=self.args)
        # True while a start() issued by this callback has not been matched by a stop().
        # Keeping the calls strictly paired matters because the profiler backend counts them.
        self._profiling = False

    def on_step_begin(self):
        """Open a profiling window if the current step is listed in ``profiler_steps``."""
        steps = self.args.profiler_steps
        if steps and not self._profiling and self.state.global_step in steps:
            self.trainer.profiler.start()
            self._profiling = True

    def on_step_end(self):
        """Close the profiling window unless the following step is profiled too."""
        # NOTE(review): the `+ 1` look-ahead assumes `state.global_step` has not been advanced
        # yet when this hook fires - confirm against the Megatron trainer loop.
        if self._profiling and self.state.global_step + 1 not in self.args.profiler_steps:
            self.trainer.profiler.stop()
            self._profiling = False
@dataclass
class BaseConfig(collections.abc.Mapping):
    """Dataclass config base that can also be read like a read-only dict.

    A field becomes frozen once it has been assigned; only names listed in
    ``_mutable_fields`` may be reassigned afterwards. The Mapping interface
    iterates over dataclass field names and looks values up by attribute.
    """

    _mutable_fields = set()
    _target_: str = ''

    def __setattr__(self, name: str, value):
        """Assign ``value`` to ``name`` unless the field is already set and frozen.

        Raises:
            FrozenInstanceError: If ``name`` already has a value on the instance
                and is not listed in ``_mutable_fields``.
        """
        already_assigned = name in self.__dict__
        if already_assigned:
            allowed = getattr(self, '_mutable_fields', set())
            if name not in allowed:
                raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
        super().__setattr__(name, value)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the attribute named ``key``, or ``default`` when it does not exist.

        Args:
            key (str): Attribute name to look up.
            default (Any, optional): Fallback value. Defaults to None.

        Returns:
            Any: The attribute value or ``default``.
        """
        return getattr(self, key, default)

    def __getitem__(self, key: str):
        """Support ``config[key]`` by delegating to attribute access.

        Raises:
            AttributeError: If no attribute named ``key`` exists.
            TypeError: If ``key`` is not a string.
        """
        return getattr(self, key)

    def __iter__(self):
        """Yield the name of every dataclass field, in declaration order."""
        yield from (spec.name for spec in fields(self))

    def __len__(self):
        """Return how many dataclass fields the config declares."""
        declared = fields(self)
        return len(declared)


@dataclass
class ProfilerConfig(BaseConfig):
    """Worker profiler config.

    Args:
        tool (Optional[str]): Name of the profiling tool.
        enable (bool): Whether profiling is enabled.
        all_ranks (bool): Whether to profile all ranks.
        ranks (list[int]): The ranks that will be profiled. Defaults to [].
        save_path (Optional[str]): Where profiling output is saved.
        tool_config (Any): Tool-specific configuration.
        global_tool_config (Any): Global tool configuration for all profiling tools.
    """

    tool: Optional[str] = None
    enable: bool = False
    all_ranks: bool = False
    ranks: list[int] = field(default_factory=list)
    save_path: Optional[str] = None
    tool_config: Any = None
    global_tool_config: Optional[Any] = None  # Global tool configuration for all profiling tools

    def union(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
        """Merge two configs for the same tool: flags are OR-ed, ranks united, first non-empty value wins."""
        assert self.tool == other.tool, f"Cannot union ProfilerConfig with different tools: {self.tool} vs {other.tool}"
        mine = set(self.ranks or [])
        theirs = set(other.ranks or [])
        merged = dict(
            tool=self.tool,
            enable=self.enable or other.enable,
            all_ranks=self.all_ranks or other.all_ranks,
            ranks=list(mine | theirs),
            save_path=self.save_path or other.save_path,
            tool_config=self.tool_config or other.tool_config,
            global_tool_config=self.global_tool_config or other.global_tool_config,
        )
        return ProfilerConfig(**merged)

    def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
        """Combine two configs for the same tool: flags are AND-ed, ranks intersected.

        ``save_path`` and ``tool_config`` are always taken from ``self``.
        """
        assert self.tool == other.tool, (
            f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
        mine = set(self.ranks or [])
        theirs = set(other.ranks or [])
        common = dict(
            tool=self.tool,
            enable=self.enable and other.enable,
            all_ranks=self.all_ranks and other.all_ranks,
            ranks=list(mine & theirs),
            save_path=self.save_path,
            tool_config=self.tool_config,
            global_tool_config=self.global_tool_config or other.global_tool_config,
        )
        return ProfilerConfig(**common)

    def __post_init__(self) -> None:
        """Validate that ``ranks`` is a set, list or tuple."""
        accepted_types = (set, list, tuple)
        assert isinstance(self.ranks, accepted_types), (f"Profiler ranks must be of type list, got {type(self.ranks)}")


@dataclass
class TorchProfilerToolConfig(BaseConfig):
    """Torch profiler tool config."""

    # options: cuda, cpu, memory, shapes, stack
    contents: list[str] = field(default_factory=list)
    discrete: bool = False
    name: str = 'torch'

    def __post_init__(self) -> None:
        """Validate that ``contents`` is a list containing only supported entries."""
        assert isinstance(self.contents, list), f"Profiler contents must be of type list, got {type(self.contents)}"
        supported = ['cuda', 'cpu', 'memory', 'shapes', 'stack']
        for entry in self.contents:
            assert entry in supported, (f"Profiler contents only supports {supported}, but gets {entry}")
class DistProfiler:
    """Rank-aware front end that forwards start/stop/annotate to a tool-specific profiler.

    The backend is the torch profiler when the configured tool is ``'torch'``
    and a no-op object otherwise. Calls are forwarded only when profiling is
    enabled and the current rank is selected.
    """

    def __init__(self,
                 global_config=None,
                 rank: Optional[int] = None,
                 config: Optional['ProfilerConfig'] = None,
                 tool_config: Optional[object] = None,
                 **kwargs):
        """Build the profiler for one rank.

        Args:
            global_config: Object exposing the ``profiler_*`` / ``enable_profiler``
                attributes. When given, it takes precedence over ``config``.
            rank: Rank of this process. When None it is read from
                ``torch.distributed`` if initialized, otherwise from the
                ``RANK`` environment variable (default 0).
            config: Ready-made profiler config, used when ``global_config`` is None.
            tool_config: Tool-specific config. Defaults to ``config.tool_config``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        if rank is None:
            if torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            else:
                rank = int(os.environ.get('RANK', 0))
                logger.warning(f"Warning: torch.distributed is not initialized, using RANK env var for rank: {rank}")
        if global_config is not None:
            config = ProfilerConfig(
                tool=global_config.profiler_tool,
                enable=global_config.enable_profiler,
                all_ranks=global_config.profiler_all_ranks,
                ranks=global_config.profiler_ranks,
                save_path=global_config.profiler_save_path,
                tool_config=tool_config or TorchProfilerToolConfig(
                    contents=global_config.profiler_contents, discrete=global_config.profiler_discrete),
            )
        elif not config:
            config = ProfilerConfig(ranks=[], enable=False, tool_config=None)

        if tool_config is None:
            tool_config = config.tool_config

        self.config = config
        self.tool_config = tool_config

        self._impl = None
        self._tool = getattr(config, 'tool', None)
        self._enable = config.enable
        # True between a forwarded start() and the matching stop().
        self._this_step = False

        # Normalize rank selection
        self._this_rank = False
        if config.all_ranks:
            self._this_rank = True
        elif config.ranks:
            self._this_rank = rank in config.ranks
        else:
            # default rank 0 if enabled but ranks unspecified
            self._this_rank = (rank == 0) if self._enable else False

        self._discrete = getattr(tool_config, 'discrete', False) if tool_config else False

        if self._tool == 'torch':
            from .torch_profile import Profiler as _Torch

            self._impl = _Torch(rank=rank, config=config, tool_config=tool_config)
        else:
            # Fallback to a no-op impl
            self._impl = _NoOpProfiler()

    def check_enable(self):
        """Return whether profiling is enabled in the config."""
        return self._enable

    def check_this_rank(self):
        """Return whether the current rank is selected for profiling."""
        return self._this_rank

    def check_this_step(self):
        """Return whether a profiling window is currently open."""
        return self._this_step

    def is_discrete_mode(self):
        """Return the ``discrete`` flag of the tool config (False when absent)."""
        return self._discrete

    def start(self, **kwargs):
        """Open a profiling window on the backend.

        Does nothing when profiling is disabled, this rank is not selected, or
        a window is already open. The last guard keeps start/stop strictly
        paired towards the backend even if the caller repeats the call.
        """
        if not (self.check_enable() and self.check_this_rank()):
            return None
        if self._this_step:
            return None
        self._this_step = True
        return getattr(self._impl, 'start', lambda **_: None)(**kwargs)

    def stop(self):
        """Close the open profiling window on the backend.

        Does nothing when profiling is disabled, this rank is not selected, or
        no window is open, so an unmatched stop never reaches the backend.
        """
        if not (self.check_enable() and self.check_this_rank()):
            return None
        if not self._this_step:
            return None
        self._this_step = False
        return getattr(self._impl, 'stop', lambda: None)()

    @classmethod
    def annotate(
        cls,
        message: Optional[str] = None,
        color: Optional[str] = None,
        domain: Optional[str] = None,
        category: Optional[str] = None,
        **kwargs_outer,
    ) -> Callable:
        """Decorate a method so it is annotated in the trace while profiling is active.

        The decorated function must be a method of an object that exposes a
        ``profiler`` attribute. When that profiler is missing, disabled, not
        selected for this rank, or has no open window, the method runs
        undecorated.

        Args:
            message: Label passed to the backend's ``annotate``.
            color: Passed through to the backend.
            domain: Passed through to the backend.
            category: Passed through to the backend.
            **kwargs_outer: Extra keyword arguments passed through to the backend.
        """

        def decorator(func):

            @functools.wraps(func)
            def wrapper(self_instance, *args, **kwargs_inner):
                profiler = getattr(self_instance, 'profiler', None)

                if (not profiler or not profiler.check_enable() or not profiler.check_this_step()
                        or not profiler.check_this_rank()):
                    return func(self_instance, *args, **kwargs_inner)

                impl = profiler._impl
                if not hasattr(impl, 'annotate'):
                    return func(self_instance, *args, **kwargs_inner)

                # Only building the annotation is best-effort. The call itself stays outside
                # the try block so an exception raised by `func` propagates once instead of
                # triggering a second execution of `func`.
                try:
                    annotated = impl.annotate(
                        message=message, color=color, domain=domain, category=category, **kwargs_outer)(
                            func)
                except Exception:
                    annotated = func
                return annotated(self_instance, *args, **kwargs_inner)

            return wrapper

        return decorator


class DistProfilerExtension:
    """Thin wrapper exposing ``start_profile`` / ``stop_profile`` on top of a ``DistProfiler``."""

    def __init__(self, profiler: DistProfiler):
        self.profiler = profiler

    def start_profile(self, **kwargs) -> None:
        """Start profiling for the current rank in the current training step."""
        self.profiler.start(**kwargs)

    def stop_profile(self) -> None:
        """Stop profiling for the current rank in the current training step."""
        self.profiler.stop()


class _NoOpProfiler:
    """Backend used when no supported tool is configured; every call does nothing."""

    def start(self, **kwargs):
        return

    def stop(self):
        return
def get_torch_profiler(
    contents: list[str],
    save_path: str,
    role: Optional[str] = None,
    save_file_prefix: Optional[str] = None,
    rank: int = 0,
):
    """Create a ``torch.profiler.profile`` that exports a chrome trace when it is ready.

    Args:
        contents: Categories to record. CPU and CUDA activities are both
            recorded when the list is empty or None; ``'stack'``, ``'shapes'``
            and ``'memory'`` turn on the matching torch profiler options.
        save_path: Directory for the trace file; created if missing.
        role: Optional sub-directory appended to ``save_path``.
        save_file_prefix: Optional prefix for the trace file name.
        rank: Rank embedded in the trace file name.

    Returns:
        The configured, not yet started, torch profiler.
    """
    if role:
        save_path = os.path.join(save_path, role)

    os.makedirs(save_path, exist_ok=True)

    current_time = datetime.now(tz=timezone.utc).astimezone()
    # Drop the last three digits of the microseconds to keep millisecond precision.
    timestamp = current_time.strftime('%Y%m%d%H%M%S%f')[:-3]
    pid = os.getpid()

    save_file_name = f"prof_rank-{rank}_{pid}_{timestamp}.json.gz"
    if save_file_prefix:
        save_file_name = f"{save_file_prefix}_{save_file_name}"
    trace_path = os.path.join(save_path, save_file_name)

    def _trace_handler(prof):
        logger.info(f"[Profiler] Saving trace to {trace_path}")
        prof.export_chrome_trace(trace_path)

    contents = set(contents) if contents else set()
    activities = []
    if not contents or 'cpu' in contents:
        activities.append(torch.profiler.ProfilerActivity.CPU)
    if not contents or 'cuda' in contents:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    return torch.profiler.profile(
        activities=activities,
        with_stack='stack' in contents,
        record_shapes='shapes' in contents,
        profile_memory='memory' in contents,
        on_trace_ready=_trace_handler,
    )


class Profiler(DistProfiler):
    """Torch profiler backend.

    ``_define_count`` is a class-level counter of outstanding ``start`` calls,
    shared by all instances. In non-discrete mode the torch profiler is created
    on the first start and stopped on the matching last stop.

    NOTE(review): ``__init__`` does not call ``DistProfiler.__init__``, so the
    attributes that constructor sets are not present on instances of this class.
    """

    _define_count = 0

    def __init__(
        self,
        rank,
        config: ProfilerConfig,
        tool_config: Optional[TorchProfilerToolConfig] = None,
        save_file_prefix=None,
    ):
        """Store the configuration; the torch profiler itself is created lazily in ``start``.

        Raises:
            AssertionError: If profiling is enabled but no ``tool_config`` is given.
        """
        # Fall back to a disabled config so that every method degrades to a no-op.
        config = config or ProfilerConfig(ranks=[], enable=False)
        self.save_file_prefix = save_file_prefix

        if not tool_config:
            assert not config.enable, 'tool_config must be provided when profiler is enabled'

        self.prof = None
        self.rank = rank
        self.config = config
        self.tool_config = tool_config
        self.contents = self.tool_config.contents if self.tool_config else []
        self.save_path = self.config.save_path
        # Align with other profilers: read discrete mode, default to False for torch profiler
        self.discrete = getattr(self.tool_config, 'discrete', False)

    def check(self):
        """Return whether a torch profiler object has been created."""
        return self.prof is not None

    def start(self, **kwargs):
        """Register a start; in non-discrete mode the first one creates and starts the profiler.

        Args:
            **kwargs: ``role`` (optional) selects the output sub-directory.
        """
        role = kwargs.get('role', None)
        if not self.discrete and Profiler._define_count == 0:
            self.prof = get_torch_profiler(
                contents=self.contents,
                save_path=self.save_path,
                role=role,
                save_file_prefix=self.save_file_prefix,
                rank=self.rank,
            )
            logger.info(f"[Profiler] started for rank {self.rank}")
            self.prof.start()
        Profiler._define_count += 1

    def step(self):
        """Advance the torch profiler by one step if it exists."""
        if self.check():
            self.prof.step()

    def stop(self):
        """Register a stop; in non-discrete mode the last matching one stops the profiler.

        A stop without an outstanding start is ignored. Letting the shared
        counter drop below zero would make every later ``start`` skip profiler
        creation, because that only happens when the counter is exactly zero.
        """
        if Profiler._define_count <= 0:
            return
        if not self.discrete and Profiler._define_count == 1 and self.check():
            self.step()
            logger.info(f"[Profiler] stopped for rank {self.rank}")
            self.prof.stop()
        Profiler._define_count -= 1

    def annotate(self, message: Optional[str] = None, role: Optional[str] = None, **kwargs_outer) -> Callable:
        """Decorate a Worker member function to profile the current rank in the current training step.

        Requires the target function to be a member function of a Worker,
        which has a member field `profiler` with Profiler type.

        Args:
            message (str, optional):
                The message to be displayed in the profiler. Defaults to None.
            role (str, optional):
                The role of the current data collection. Defaults to None.
        """

        def decorator(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs_inner):
                profile_name = message or func.__name__

                if not self.discrete:
                    # In continuous mode, we just record function, profiler started globally
                    with torch.profiler.record_function(profile_name):
                        return func(*args, **kwargs_inner)

                # In discrete mode, we start/stop profiler around the function
                prof = get_torch_profiler(
                    contents=self.contents,
                    save_path=self.save_path,
                    role=role,
                    save_file_prefix=self.save_file_prefix,
                    rank=self.rank,
                )
                prof.start()
                try:
                    with torch.profiler.record_function(profile_name):
                        result = func(*args, **kwargs_inner)
                finally:
                    prof.stop()
                return result

            return wrapper

        return decorator