Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions swift/arguments/sft_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def __post_init__(self) -> None:
self.eval_strategy = 'no'
self.training_args = TrainerFactory.get_training_args(self)
self.training_args.remove_unused_columns = False
# The generic HF log patch only sees training_args, so expose the sequence length used for the TGS (tokens per GPU per second) metric there.
self.training_args.seq_length = self.packing_length or self.max_length
self._add_version()

if 'swanlab' in self.report_to:
Expand Down Expand Up @@ -413,9 +415,6 @@ def _init_metric(self):
self.eval_metric = 'reranker'
if self.eval_metric == 'nlg':
require_version('jieba', 'Setting `--eval_metric nlg` requires installing the jieba dependency.')
self._init_metric_for_best_model()

def _init_metric_for_best_model(self):
if self.metric_for_best_model is None:
self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss'
if self.greater_is_better is None and self.metric_for_best_model is not None:
Expand Down
57 changes: 56 additions & 1 deletion swift/megatron/callbacks/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm

from swift.megatron.utils import reduce_max_stat_across_model_parallel_group
from swift.utils import JsonlWriter, format_time, get_logger, is_last_rank
from swift.utils import JsonlWriter, format_time, get_env_args, get_logger, is_last_rank
from .base import MegatronCallback

logger = get_logger()
Expand All @@ -19,6 +19,11 @@ def __init__(self, trainer):
self.eval_bar = None
self.jsonl_writer = None
self.is_write_rank = is_last_rank()
self.device_peak_tflops = get_env_args('DEVICE_TFLOPS', float, None)
if self.device_peak_tflops is not None:
logger.info(
f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{self.device_peak_tflops} TFLOPS]")
self.device_peak_tflops = float(self.device_peak_tflops)

def on_train_begin(self):
self.training_bar = tqdm(
Expand Down Expand Up @@ -59,6 +64,16 @@ def on_log(self, logs):
logs['elapsed_time'] = format_time(elapsed)
n_steps = state.iteration - self.start_step
train_speed = elapsed / n_steps if n_steps > 0 else 0.0
seq_length = getattr(args, 'seq_length', None)
world_size = getattr(args, 'world_size', None) or 1
if train_speed > 0 and seq_length:
logs['tgs'] = round(
args.global_batch_size * args.seq_length / train_speed / world_size, 3)
if self.device_peak_tflops:
throughput = self._get_throughput_tflops_per_gpu(train_speed)
if throughput is not None:
logs['throughput(TFLOP/s/GPU)'] = round(throughput, 3)
logs['MFU'] = round(throughput / self.device_peak_tflops, 6)
logs['remaining_time'] = format_time((args.train_iters - state.iteration) * train_speed)
memory = reduce_max_stat_across_model_parallel_group(torch.cuda.max_memory_reserved() / 1024**3)
logs['memory(GiB)'] = round(memory, 2)
Expand All @@ -67,3 +82,43 @@ def on_log(self, logs):
self.jsonl_writer.append(logs)
if self.is_write_rank:
self.training_bar.write(str(logs))

def _get_throughput_tflops_per_gpu(self, train_speed):
if train_speed <= 0:
return None
world_size = getattr(self.args, 'world_size', None) or 1
num_flops = self._num_floating_point_operations(self.args.global_batch_size)
if num_flops is None:
return None
return num_flops / (train_speed * 10**12 * world_size)

def _num_floating_point_operations(self, batch_size):
seq_length = getattr(self.args, 'seq_length', None)
if seq_length is None:
return None
config = self.trainer.config
hidden_size = getattr(config, 'hidden_size', None)
num_layers = getattr(config, 'num_layers', None)
num_attention_heads = getattr(config, 'num_attention_heads', None)
ffn_hidden_size = getattr(config, 'ffn_hidden_size', None)
if None in {hidden_size, num_layers, num_attention_heads, ffn_hidden_size}:
return None

kv_channels = getattr(config, 'kv_channels', None) or hidden_size // num_attention_heads
num_query_groups = getattr(config, 'num_query_groups', None) or num_attention_heads
padded_vocab_size = getattr(config, 'padded_vocab_size', None) or getattr(config, 'vocab_size', None)
if padded_vocab_size is None:
return None

query_projection_size = kv_channels * num_attention_heads
query_projection_to_hidden_size_ratio = query_projection_size / hidden_size
num_experts_routed_to = 1 if getattr(config, 'num_moe_experts', None) is None else getattr(
config, 'moe_router_topk', 1)
gated_linear_multiplier = 1.5 if getattr(config, 'swiglu', False) else 1.0

return (
12 * batch_size * seq_length * num_layers * hidden_size * hidden_size * (
(1 + (num_query_groups / num_attention_heads) + (seq_length / hidden_size))
* query_projection_to_hidden_size_ratio
+ (ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier
+ padded_vocab_size / (2 * num_layers * hidden_size)))
Comment on lines +95 to +124

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If self.args.seq_length is None, this calculation will raise a TypeError. Retrieving seq_length defensively and returning None early if it is not set prevents this issue.

    def _num_floating_point_operations(self, batch_size):
        seq_length = getattr(self.args, 'seq_length', None)
        if seq_length is None:
            return None
        config = self.trainer.config
        hidden_size = getattr(config, 'hidden_size', None)
        num_layers = getattr(config, 'num_layers', None)
        num_attention_heads = getattr(config, 'num_attention_heads', None)
        ffn_hidden_size = getattr(config, 'ffn_hidden_size', None)
        if None in {hidden_size, num_layers, num_attention_heads, ffn_hidden_size}:
            return None

        kv_channels = getattr(config, 'kv_channels', None) or hidden_size // num_attention_heads
        num_query_groups = getattr(config, 'num_query_groups', None) or num_attention_heads
        padded_vocab_size = getattr(config, 'padded_vocab_size', None) or getattr(config, 'vocab_size', None)
        if padded_vocab_size is None:
            return None

        query_projection_size = kv_channels * num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / hidden_size
        num_experts_routed_to = 1 if getattr(config, 'num_moe_experts', None) is None else getattr(
            config, 'moe_router_topk', 1)
        gated_linear_multiplier = 1.5 if getattr(config, 'swiglu', False) else 1.0

        return (
            12 * batch_size * seq_length * num_layers * hidden_size * hidden_size * (
                (1 + (num_query_groups / num_attention_heads) + (seq_length / hidden_size))
                * query_projection_to_hidden_size_ratio
                + (ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier
                + padded_vocab_size / (2 * num_layers * hidden_size)))

56 changes: 51 additions & 5 deletions swift/trainers/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@
TrainerState)
from transformers.trainer_utils import IntervalStrategy, has_length

from swift.utils import append_to_jsonl, format_time, get_logger, get_max_reserved_memory, is_pai_training_job
from swift.utils import (append_to_jsonl, format_time, get_env_args, get_logger, get_max_reserved_memory,
is_pai_training_job)
from .arguments import TrainingArguments

logger = get_logger()


def add_train_message(logs, state, start_time, start_step) -> None:
def add_train_message(logs, state, start_time, start_step, args=None) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If logs is None, attempting to set keys on it will raise a TypeError. Adding an early return when logs is None prevents potential runtime crashes.

Suggested change
def add_train_message(logs, state, start_time, start_step, args=None) -> None:
def add_train_message(logs, state, start_time, start_step, args=None) -> None:
if logs is None:
return

logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}'
elapsed = time.time() - start_time
logs['elapsed_time'] = format_time(elapsed)
n_steps = state.global_step - start_step
train_speed = elapsed / n_steps if n_steps > 0 else 0.0
seq_length = getattr(args, 'seq_length', None) if args is not None else None
if train_speed > 0 and seq_length:
world_size = max(getattr(args, 'world_size', None) or 1, 1)
global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
logs['tgs'] = round(global_batch_size * seq_length / train_speed / world_size, 3)
logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed)
for k, v in logs.items():
if isinstance(v, float):
Expand Down Expand Up @@ -49,7 +55,7 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader
self.prediction_bar.update()

def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
add_train_message(logs, state, self.start_time, self.start_step)
add_train_message(logs, state, self.start_time, self.start_step, args)
if not is_pai_training_job() and state.is_world_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, logs)
Expand All @@ -58,6 +64,42 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non
self.training_bar.refresh()


class ProgressCallbackNewWithMFU(ProgressCallbackNew):
    """Progress callback that adds an ``MFU`` entry to the logs before the parent writes them.

    MFU is computed as FLOPs accumulated since training began, divided by the time
    spent inside training steps and by the aggregate peak FLOP/s of all devices.
    The per-device peak comes from the ``DEVICE_TFLOPS`` environment variable.
    """
    # perf_log is a user callback and can run after progress logs are written; inject MFU before printing here.

    def __init__(self):
        super().__init__()
        # Aggregate peak TFLOPS (per-device value times world size); None disables MFU.
        self.device_tflops = None
        # Seconds accumulated between on_step_begin and on_step_end.
        self.elapsed = 0.0
        # Value of state.total_flos when training began, so earlier FLOPs are excluded.
        self.start_flos = 0
        # Timestamp of the step currently in flight, or None when no step is open.
        self.step_start_time = None

    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Read the per-device peak TFLOPS and scale it by the world size."""
        tflops = get_env_args('DEVICE_TFLOPS', float, None)
        if tflops is not None:
            # presumably state.total_flos is summed over all ranks, hence the world-size scaling - confirm
            self.device_tflops = tflops * max(getattr(args, 'world_size', None) or 1, 1)
        return super().on_init_end(args, state, control, **kwargs)

    def on_train_begin(self, args, state, control, **kwargs):
        """Remember the FLOP counter at the start of training."""
        self.start_flos = getattr(state, 'total_flos', 0)
        return super().on_train_begin(args, state, control, **kwargs)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Mark the start time of the step."""
        self.step_start_time = time.time()
        return super().on_step_begin(args, state, control, **kwargs)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Add the duration of the finished step to the accumulated step time."""
        if self.step_start_time is not None:
            self.elapsed += time.time() - self.step_start_time
            # Clear the marker so an unmatched on_step_end cannot count the same interval twice.
            self.step_start_time = None
        return super().on_step_end(args, state, control, **kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
        """Insert ``MFU`` into ``logs`` when it can be computed, then delegate to the parent."""
        total_flos = getattr(state, 'total_flos', 0) - self.start_flos
        if self.elapsed > 0 and self.device_tflops and total_flos > 0 and logs is not None:
            logs['MFU'] = round(total_flos / self.elapsed / (self.device_tflops * 1e12), 6)
        return super().on_log(args, state, control, logs, **kwargs)
Comment on lines +96 to +100

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If logs is None, attempting to assign logs['MFU'] will raise a TypeError. Checking that logs is not None before assigning makes the callback robust.

Suggested change
def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
total_flos = getattr(state, 'total_flos', 0) - self.start_flos
if self.elapsed > 0 and self.device_tflops and total_flos > 0:
logs['MFU'] = round(total_flos / self.elapsed / (self.device_tflops * 1e12), 6)
return super().on_log(args, state, control, logs, **kwargs)
def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
total_flos = getattr(state, 'total_flos', 0) - self.start_flos
if self.elapsed > 0 and self.device_tflops and total_flos > 0 and logs is not None:
logs['MFU'] = round(total_flos / self.elapsed / (self.device_tflops * 1e12), 6)
return super().on_log(args, state, control, logs, **kwargs)



class DefaultFlowCallbackNew(DefaultFlowCallback):

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
Expand Down Expand Up @@ -92,7 +134,7 @@ def on_train_begin(self, args, state, control, **kwargs):
return super().on_train_begin(args, state, control, **kwargs)

def on_log(self, args, state, control, logs=None, **kwargs):
add_train_message(logs, state, self.start_time, self.start_step)
add_train_message(logs, state, self.start_time, self.start_step, args)
if not is_pai_training_job() and state.is_world_process_zero:
jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
append_to_jsonl(jsonl_path, logs)
Expand All @@ -103,6 +145,10 @@ def on_log(self, args, state, control, logs=None, **kwargs):


# monkey patching
# MFU needs user-provided per-device peak TFLOPS; otherwise keep the default progress callback unchanged.
# The branch below always assigns DEFAULT_PROGRESS_CALLBACK, so no unconditional assignment is needed first.
if get_env_args('DEVICE_TFLOPS', float, None) is not None:
    trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNewWithMFU
else:
    trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.PrinterCallback = PrinterCallbackNew