Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions swift/arguments/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .data_args import DataArguments
from .generation_args import GenerationArguments
from .model_args import ModelArguments
from .profile_args import ProfilerArguments
from .quant_args import QuantizeArguments
from .template_args import TemplateArguments

Expand All @@ -30,12 +31,11 @@ def get_supported_tuners():

@dataclass
class BaseArguments(GenerationArguments, QuantizeArguments, DataArguments, TemplateArguments, ModelArguments,
RayArguments):
RayArguments, ProfilerArguments):
"""BaseArguments class is a dataclass that inherits from multiple argument classes.

This class consolidates arguments from GenerationArguments, QuantizeArguments, DataArguments,
TemplateArguments, ModelArguments, RayArguments.

TemplateArguments, ModelArguments, RayArguments, and ProfilerArguments.
Args:
tuner_backend (str): The tuner backend to use. Choices are 'peft' or 'unsloth'. Default is 'peft'.
tuner_type (str): The tuner type. Choices include 'lora', 'full', 'longlora', 'adalora', 'llamapro',
Expand Down Expand Up @@ -171,6 +171,7 @@ def __post_init__(self):
TemplateArguments.__post_init__(self)
DataArguments.__post_init__(self)
RayArguments.__post_init__(self)
ProfilerArguments.__post_init__(self)
self._init_stream()
if self.max_length is None and self.model_info is not None:
self.max_length = self.model_info.max_model_len
Expand Down
60 changes: 60 additions & 0 deletions swift/arguments/base_args/profile_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from dataclasses import dataclass, field
from typing import List, Optional

from swift.utils import get_logger

logger = get_logger()


@dataclass
class ProfilerArguments:

enable_profiler: bool = False
profiler_save_path: Optional[str] = None
profiler_all_ranks: bool = False
profiler_ranks: List[int] = field(default_factory=list)
profiler_contents: List[str] = field(default_factory=list) # e.g., "cpu", "cuda", "stack", "memory", "shapes"
profiler_discrete: bool = False
profiler_tool: Optional[str] = 'torch'
profiler_steps: Optional[List[int]] = field(default_factory=list) # Steps to profile

def __post_init__(self):
assert not self.profiler_discrete, \
'Profiler discrete mode is not supported yet, please set profiler_discrete to false'

if hasattr(self, 'callbacks'):
if self.enable_profiler and 'profiler' not in self.callbacks:
self.callbacks.append('profiler')
if 'profiler' in self.callbacks and not self.enable_profiler:
self.enable_profiler = True
if self.enable_profiler:
assert 'profiler' in self.callbacks, \
'Profiler callback must be included in callbacks when profiler is enabled.'
if 'profiler' in self.callbacks:
assert self.enable_profiler, \
'Profiler callback is included in callbacks but profiler is not enabled.'
else:
assert not self.enable_profiler, \
'Profiler cannot be enabled without callbacks attribute or with profiler callback missing in callbacks.'
Comment on lines +26 to +39

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assertion assert not self.enable_profiler in the else block will raise an AssertionError and crash the application whenever enable_profiler is set to True on any arguments class that does not have a callbacks attribute (such as BaseArguments, DeployArguments, EvalArguments, or ExportArguments). This prevents using the profiler for inference, evaluation, or custom training loops (like RLHF rollout/actor phases) where standard trainer callbacks are not used.

We should remove this assertion to allow enabling the profiler without requiring a callbacks attribute.

Suggested change
if hasattr(self, 'callbacks'):
if self.enable_profiler and 'profiler' not in self.callbacks:
self.callbacks.append('profiler')
if 'profiler' in self.callbacks and not self.enable_profiler:
self.enable_profiler = True
if self.enable_profiler:
assert 'profiler' in self.callbacks, \
'Profiler callback must be included in callbacks when profiler is enabled.'
if 'profiler' in self.callbacks:
assert self.enable_profiler, \
'Profiler callback is included in callbacks but profiler is not enabled.'
else:
assert not self.enable_profiler, \
'Profiler cannot be enabled without callbacks attribute or with profiler callback missing in callbacks.'
if hasattr(self, 'callbacks'):
if self.enable_profiler and 'profiler' not in self.callbacks:
self.callbacks.append('profiler')
if 'profiler' in self.callbacks and not self.enable_profiler:
self.enable_profiler = True
if self.enable_profiler:
assert 'profiler' in self.callbacks, \
'Profiler callback must be included in callbacks when profiler is enabled.'
if 'profiler' in self.callbacks:
assert self.enable_profiler, \
'Profiler callback is included in callbacks but profiler is not enabled.'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseArguments 不应该直接使用,一般都是被其他Arguments继承,例如SftArguments,是有callback 属性的,并且profiler功能依赖profiler callback ,因此我必须确保profiler callback和enable_profiler开启。而如果没有 profiler callback, 那也确实说明不该开启profiler

if self.enable_profiler:
assert self.profiler_save_path is not None, \
'Profiler save path must be specified when profiler is enabled.'
assert self.profiler_contents, \
'Profiler contents must be specified when profiler is enabled.'
assert self.profiler_steps, \
'Profiler steps must be specified when profiler is enabled.'
assert self.profiler_ranks != [] or self.profiler_all_ranks, \
'Either profiler_ranks must be specified or profiler_all_ranks must be set to True.'
Comment thread
qq1243196045 marked this conversation as resolved.

def get_profiler_kwargs(self):
    """Collect the profiler options of this instance into a plain dict.

    Returns:
        dict: Maps each profiler option name to the current value of the
            same-named attribute on ``self``, in declaration order.
    """
    option_names = (
        'enable_profiler',
        'profiler_save_path',
        'profiler_all_ranks',
        'profiler_ranks',
        'profiler_contents',
        'profiler_discrete',
        'profiler_tool',
        'profiler_steps',
    )
    return {option: getattr(self, option) for option in option_names}
4 changes: 3 additions & 1 deletion swift/callbacks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .early_stop import EarlyStopCallback
from .lisa import LISACallback
from .perf_log import PerfMetricsLogCallback
from .profiler import ProfilerCallback

callbacks_map = {
'activation_cpu_offload': ActivationCpuOffloadCallBack,
Expand All @@ -13,5 +14,6 @@
'early_stop': EarlyStopCallback,
'graceful_exit': GracefulExitCallback,
'lisa': LISACallback,
'perf_log': PerfMetricsLogCallback
'perf_log': PerfMetricsLogCallback,
'profiler': ProfilerCallback,
}
26 changes: 26 additions & 0 deletions swift/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from transformers.trainer_callback import ProgressCallback, TrainerControl, TrainerState

from swift.utils import get_logger
from swift.utils.profiler import DistProfiler

logger = get_logger()


class ProfilerCallback(ProgressCallback):
    """Trainer callback that starts/stops a ``DistProfiler`` around selected steps.

    On construction a ``DistProfiler`` built from ``args`` is attached to the
    trainer as ``trainer.profiler``. The steps to profile are read from
    ``args.profiler_steps``.

    NOTE(review): this class derives from ``ProgressCallback``, so it presumably
    also inherits that callback's progress-reporting behavior -- confirm that
    is intended.
    """

    def __init__(self, args, trainer):
        """Store ``args``/``trainer`` and attach a fresh profiler to the trainer."""
        super().__init__()
        self.args = args
        self.trainer = trainer
        self.trainer.profiler = DistProfiler(global_config=args)

    def on_step_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Start the profiler when the current global step is listed in ``profiler_steps``."""
        scheduled_steps = self.args.profiler_steps
        if scheduled_steps and state.global_step in scheduled_steps:
            self.trainer.profiler.start()
        super().on_step_begin(args, state, control, **kwargs)

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Stop the profiler unless ``global_step + 1`` is also listed in ``profiler_steps``."""
        scheduled_steps = self.args.profiler_steps
        # stop() is requested on every step end whose successor is not scheduled,
        # even if start() was never called -- presumably DistProfiler.stop()
        # tolerates that; confirm.
        # NOTE(review): Hugging Face's Trainer appears to increment
        # state.global_step before firing on_step_end -- confirm the +1 offset.
        if scheduled_steps and state.global_step + 1 not in scheduled_steps:
            self.trainer.profiler.stop()
        super().on_step_end(args, state, control, **kwargs)
2 changes: 2 additions & 0 deletions swift/megatron/callbacks/mapping.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .default_flow import DefaultFlowCallback
from .print import PrintCallback
from .profiler import ProfilerCallback
from .swanlab import SwanlabCallback
from .tensorboard import TensorboardCallback
from .wandb import WandbCallback
Expand All @@ -11,4 +12,5 @@
'swanlab': SwanlabCallback,
'wandb': WandbCallback,
'tensorboard': TensorboardCallback,
'profiler': ProfilerCallback,
}
23 changes: 23 additions & 0 deletions swift/megatron/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from swift.utils import get_logger
from swift.utils.profiler import DistProfiler
from .base import MegatronCallback

logger = get_logger()


class ProfilerCallback(MegatronCallback):
    """Megatron callback that starts/stops a ``DistProfiler`` around selected steps.

    A ``DistProfiler`` built from ``trainer.args`` is attached to the trainer as
    ``trainer.profiler``; the steps to profile come from ``args.profiler_steps``.
    """

    def __init__(self, trainer):
        """Keep references to the trainer and its args and attach the profiler."""
        super().__init__(trainer)
        self.args = trainer.args
        self.trainer = trainer
        self.trainer.profiler = DistProfiler(global_config=self.args)

    def on_step_begin(self):
        """Start the profiler when the current global step is listed in ``profiler_steps``."""
        scheduled_steps = self.args.profiler_steps
        if scheduled_steps and self.state.global_step in scheduled_steps:
            self.trainer.profiler.start()

    def on_step_end(self):
        """Stop the profiler unless ``global_step + 1`` is also listed in ``profiler_steps``."""
        scheduled_steps = self.args.profiler_steps
        # stop() is requested even when start() was never called for this step --
        # presumably DistProfiler.stop() tolerates that; confirm.
        if scheduled_steps and self.state.global_step + 1 not in scheduled_steps:
            self.trainer.profiler.stop()
20 changes: 20 additions & 0 deletions swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ class TrainArgumentsMixin:
shared memory and then asynchronously persisted to disk. Currently does not support the safetensors format.
It is recommended to use this with `PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"` to prevent CUDA OOM
errors during training. Defaults to False.

enable_profiler (bool): Master switch to enable or disable performance profiling. Default is False.
profiler_save_path (Optional[str]): Directory path where the profiling results and trace files will be saved.
profiler_all_ranks (bool): If True, collects profiling data from all distributed processes.
profiler_ranks (List[int]): A list of specific rank IDs to collect profiling data from.
profiler_contents (List[str]): List of data categories to record, such as "cpu", "cuda", "memory", or "stack".
profiler_discrete (bool): If True, records data for each step independently.
profiler_tool (Optional[str]): Specifies the backend tool used for profiling.
profiler_steps (Optional[List[int]]): A list of specific training steps during which profiling should be active.

"""
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
Expand Down Expand Up @@ -202,6 +212,16 @@ class TrainArgumentsMixin:
# dlrover flash_checkpoint
use_flash_ckpt: bool = False

# profiler
enable_profiler: bool = False
profiler_save_path: Optional[str] = None
profiler_all_ranks: bool = False
profiler_ranks: List[int] = field(default_factory=list)
profiler_contents: List[str] = field(default_factory=list) # e.g., "cpu", "cuda", "stack", "memory", "shapes"
profiler_discrete: bool = False
profiler_tool: Optional[str] = 'torch'
profiler_steps: Optional[List[int]] = field(default_factory=list) # Steps to profile

@staticmethod
def _patch_liger_kernel():
# fix logits_to_keep
Expand Down
7 changes: 7 additions & 0 deletions swift/utils/profiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .profile import DistProfiler, DistProfilerExtension, ProfilerConfig

__all__ = [
'DistProfiler',
'DistProfilerExtension',
'ProfilerConfig',
]
142 changes: 142 additions & 0 deletions swift/utils/profiler/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import collections
from dataclasses import FrozenInstanceError, dataclass, field, fields
from typing import Any, Optional

# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary


@dataclass
class BaseConfig(collections.abc.Mapping):
    """Dataclass base class exposing its attributes through a dict-like interface.

    Attributes are frozen once assigned: a later assignment raises
    ``FrozenInstanceError`` unless the attribute name is listed in
    ``_mutable_fields``. Because the class implements the ``Mapping`` abstract
    base class, instances support ``cfg[key]``, ``key in cfg``, ``len(cfg)``,
    iteration over field names and ``cfg.get(key, default)``.
    """

    # Names of attributes that may be reassigned after their first assignment.
    _mutable_fields = set()
    _target_: str = ''

    def __setattr__(self, name: str, value):
        """Assign an attribute, rejecting reassignment of frozen ones.

        Raises:
            FrozenInstanceError: If ``name`` is already set on the instance and
                is not listed in ``_mutable_fields``.
        """
        # The first assignment (e.g. from the generated __init__) is always
        # allowed because the name is not yet in the instance __dict__.
        if name in self.__dict__ and name not in getattr(self, '_mutable_fields', set()):
            raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified")
        super().__setattr__(name, value)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the attribute named ``key``, or ``default`` if it does not exist.

        Args:
            key (str): The attribute name to retrieve.
            default (Any, optional): Value returned when the attribute is missing.
                Defaults to None.

        Returns:
            Any: The attribute value or ``default``.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default

    def __getitem__(self, key: str):
        """Implement ``cfg[key]`` by attribute lookup.

        Args:
            key (str): The attribute name to retrieve.

        Returns:
            Any: The value of the attribute.

        Raises:
            AttributeError: If the attribute does not exist.
            TypeError: If the key is not a string.
        """
        return getattr(self, key)

    def __contains__(self, key) -> bool:
        """Implement ``key in cfg`` consistently with ``__getitem__``.

        ``Mapping.__contains__`` only treats ``KeyError`` as "missing", but
        ``__getitem__`` raises ``AttributeError``; without this override a
        membership test for an absent key would raise instead of returning
        False.

        Args:
            key: Candidate attribute name.

        Returns:
            bool: True if ``key`` is a string naming an existing attribute.
        """
        return isinstance(key, str) and hasattr(self, key)

    def __iter__(self):
        """Iterate over the names of the dataclass fields.

        Yields:
            str: The name of each field in the dataclass.
        """
        for f in fields(self):
            yield f.name

    def __len__(self):
        """Return the number of fields in the dataclass.

        Returns:
            int: The number of fields in the dataclass.
        """
        return len(fields(self))


@dataclass
class ProfilerConfig(BaseConfig):
"""Worker profiler config.

Args:
discrete (bool): True for each task has its own database, False for all tasks in one training step
share one database.
all_ranks (bool): Whether to profile all ranks.
ranks (list[int]): The ranks that will be profiled. Defaults to [].
global_tool_config (Any): Global tool configuration for all profiling tools.
"""

tool: Optional[str] = None
enable: bool = False
all_ranks: bool = False
ranks: list[int] = field(default_factory=list)
save_path: Optional[str] = None
tool_config: Any = None
global_tool_config: Optional[Any] = None # Global tool configuration for all profiling tools

def union(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
    """Combine two configs of the same tool, keeping what either one enables.

    Boolean flags are OR-ed, ``ranks`` is the set union, and the remaining
    options take ``self``'s value when it is truthy and ``other``'s otherwise.

    Args:
        other (ProfilerConfig): Config to merge with; must use the same tool.

    Returns:
        ProfilerConfig: A new merged config.

    Raises:
        AssertionError: If the two configs name different tools.
    """
    assert self.tool == other.tool, f"Cannot union ProfilerConfig with different tools: {self.tool} vs {other.tool}"
    own_ranks = set(self.ranks or [])
    other_ranks = set(other.ranks or [])
    merged = dict(
        tool=self.tool,
        enable=self.enable or other.enable,
        all_ranks=self.all_ranks or other.all_ranks,
        ranks=list(own_ranks | other_ranks),
        save_path=self.save_path or other.save_path,
        tool_config=self.tool_config or other.tool_config,
        global_tool_config=self.global_tool_config or other.global_tool_config,
    )
    return ProfilerConfig(**merged)
Comment thread
qq1243196045 marked this conversation as resolved.

def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
    """Combine two configs of the same tool, keeping only what both enable.

    Boolean flags are AND-ed and ``ranks`` is the set intersection.
    ``save_path`` and ``tool_config`` are taken from ``self`` only, while
    ``global_tool_config`` falls back to ``other``'s value when ``self``'s is
    falsy.

    Args:
        other (ProfilerConfig): Config to intersect with; must use the same tool.

    Returns:
        ProfilerConfig: A new intersected config.

    Raises:
        AssertionError: If the two configs name different tools.
    """
    assert self.tool == other.tool, (
        f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
    own_ranks = set(self.ranks or [])
    other_ranks = set(other.ranks or [])
    # NOTE(review): save_path/tool_config do not fall back to `other` here,
    # unlike global_tool_config -- confirm this asymmetry is intended.
    common = dict(
        tool=self.tool,
        enable=self.enable and other.enable,
        all_ranks=self.all_ranks and other.all_ranks,
        ranks=list(own_ranks & other_ranks),
        save_path=self.save_path,
        tool_config=self.tool_config,
        global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
    )
    return ProfilerConfig(**common)
Comment on lines +108 to +119

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the intersect method, if self.save_path or self.tool_config is None, it should fallback to other.save_path or other.tool_config instead of strictly using self's values which might be None.

Suggested change
def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, (
f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
return ProfilerConfig(
tool=self.tool,
enable=self.enable and other.enable,
all_ranks=self.all_ranks and other.all_ranks,
ranks=list(set(self.ranks or []) & set(other.ranks or [])),
save_path=self.save_path,
tool_config=self.tool_config,
global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
)
def intersect(self, other: 'ProfilerConfig') -> 'ProfilerConfig':
assert self.tool == other.tool, (
f"Cannot intersect ProfilerConfig with different tools: {self.tool} vs {other.tool}")
return ProfilerConfig(
tool=self.tool,
enable=self.enable and other.enable,
all_ranks=self.all_ranks and other.all_ranks,
ranks=list(set(self.ranks or []) & set(other.ranks or [])),
save_path=self.save_path or other.save_path,
tool_config=self.tool_config or other.tool_config,
global_tool_config=self.global_tool_config if self.global_tool_config else other.global_tool_config,
)


def __post_init__(self) -> None:
    """Validate the config after dataclass initialisation.

    Raises:
        AssertionError: If ``ranks`` is not a set, list or tuple.
    """
    accepted_containers = (set, list, tuple)
    ranks_ok = isinstance(self.ranks, accepted_containers)
    assert ranks_ok, f"Profiler ranks must be of type list, got {type(self.ranks)}"


@dataclass
class TorchProfilerToolConfig(BaseConfig):
    """Options for the torch profiler tool.

    ``contents`` lists what to record; each entry must be one of ``cuda``,
    ``cpu``, ``memory``, ``shapes`` or ``stack``.
    """

    # options: cuda, cpu, memory, shapes, stack
    contents: list[str] = field(default_factory=list)
    discrete: bool = False
    name: str = 'torch'

    def __post_init__(self) -> None:
        """Validate ``contents`` after dataclass initialisation.

        Raises:
            AssertionError: If ``contents`` is not a list or holds an
                unsupported entry.
        """
        assert isinstance(self.contents, list), f"Profiler contents must be of type list, got {type(self.contents)}"
        supported = ['cuda', 'cpu', 'memory', 'shapes', 'stack']
        # Check entries one by one so the message names the first offender.
        for entry in self.contents:
            assert entry in supported, (
                f"Profiler contents only supports {supported}, but gets {entry}")
Loading
Loading