diff --git a/swift/arguments/sft_args.py b/swift/arguments/sft_args.py index 6194d3a7f4..e63d797b0f 100644 --- a/swift/arguments/sft_args.py +++ b/swift/arguments/sft_args.py @@ -231,6 +231,8 @@ def __post_init__(self) -> None: self.eval_strategy = 'no' self.training_args = TrainerFactory.get_training_args(self) self.training_args.remove_unused_columns = False + # The generic HF log patch only sees training_args, so pass the fixed-length TGS size through. + self.training_args.seq_length = self.packing_length or self.max_length self._add_version() if 'swanlab' in self.report_to: @@ -413,9 +415,6 @@ def _init_metric(self): self.eval_metric = 'reranker' if self.eval_metric == 'nlg': require_version('jieba', 'Setting `--eval_metric nlg` requires installing the jieba dependency.') - self._init_metric_for_best_model() - - def _init_metric_for_best_model(self): if self.metric_for_best_model is None: self.metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss' if self.greater_is_better is None and self.metric_for_best_model is not None: diff --git a/swift/megatron/callbacks/print.py b/swift/megatron/callbacks/print.py index 01a41b385d..15323c2873 100644 --- a/swift/megatron/callbacks/print.py +++ b/swift/megatron/callbacks/print.py @@ -5,7 +5,7 @@ from tqdm import tqdm from swift.megatron.utils import reduce_max_stat_across_model_parallel_group -from swift.utils import JsonlWriter, format_time, get_logger, is_last_rank +from swift.utils import JsonlWriter, format_time, get_env_args, get_logger, is_last_rank from .base import MegatronCallback logger = get_logger() @@ -19,6 +19,11 @@ def __init__(self, trainer): self.eval_bar = None self.jsonl_writer = None self.is_write_rank = is_last_rank() + self.device_peak_tflops = get_env_args('DEVICE_TFLOPS', float, None) + if self.device_peak_tflops is not None: + logger.info( + f"Specify theoretical max TFLOPS through ENV 'DEVICE_TFLOPS'. [{self.device_peak_tflops} TFLOPS]") + self.device_peak_tflops = float(self.device_peak_tflops) def on_train_begin(self): self.training_bar = tqdm( @@ -59,6 +64,16 @@ def on_log(self, logs): logs['elapsed_time'] = format_time(elapsed) n_steps = state.iteration - self.start_step train_speed = elapsed / n_steps if n_steps > 0 else 0.0 + seq_length = getattr(args, 'seq_length', None) + world_size = getattr(args, 'world_size', None) or 1 + if train_speed > 0 and seq_length: + logs['tgs'] = round( + args.global_batch_size * args.seq_length / train_speed / world_size, 3) + if self.device_peak_tflops: + throughput = self._get_throughput_tflops_per_gpu(train_speed) + if throughput is not None: + logs['throughput(TFLOP/s/GPU)'] = round(throughput, 3) + logs['MFU'] = round(throughput / self.device_peak_tflops, 6) logs['remaining_time'] = format_time((args.train_iters - state.iteration) * train_speed) memory = reduce_max_stat_across_model_parallel_group(torch.cuda.max_memory_reserved() / 1024**3) logs['memory(GiB)'] = round(memory, 2) @@ -67,3 +82,43 @@ def on_log(self, logs): self.jsonl_writer.append(logs) if self.is_write_rank: self.training_bar.write(str(logs)) + + def _get_throughput_tflops_per_gpu(self, train_speed): + if train_speed <= 0: + return None + world_size = getattr(self.args, 'world_size', None) or 1 + num_flops = self._num_floating_point_operations(self.args.global_batch_size) + if num_flops is None: + return None + return num_flops / (train_speed * 10**12 * world_size) + + def _num_floating_point_operations(self, batch_size): + seq_length = getattr(self.args, 'seq_length', None) + if seq_length is None: + return None + config = self.trainer.config + hidden_size = getattr(config, 'hidden_size', None) + num_layers = getattr(config, 'num_layers', None) + num_attention_heads = getattr(config, 'num_attention_heads', None) + ffn_hidden_size = getattr(config, 'ffn_hidden_size', None) + if None in {hidden_size, num_layers, num_attention_heads, ffn_hidden_size}: + return None + + kv_channels = getattr(config, 'kv_channels', None) or hidden_size // num_attention_heads + num_query_groups = getattr(config, 'num_query_groups', None) or num_attention_heads + padded_vocab_size = getattr(config, 'padded_vocab_size', None) or getattr(config, 'vocab_size', None) + if padded_vocab_size is None: + return None + + query_projection_size = kv_channels * num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / hidden_size + num_experts_routed_to = 1 if getattr(config, 'num_moe_experts', None) is None else getattr( + config, 'moe_router_topk', 1) + gated_linear_multiplier = 1.5 if getattr(config, 'swiglu', False) else 1.0 + + return ( + 12 * batch_size * seq_length * num_layers * hidden_size * hidden_size * ( + (1 + (num_query_groups / num_attention_heads) + (seq_length / hidden_size)) + * query_projection_to_hidden_size_ratio + + (ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier + + padded_vocab_size / (2 * num_layers * hidden_size))) diff --git a/swift/trainers/patcher.py b/swift/trainers/patcher.py index a18071647e..644bc2632a 100644 --- a/swift/trainers/patcher.py +++ b/swift/trainers/patcher.py @@ -8,18 +8,24 @@ TrainerState) from transformers.trainer_utils import IntervalStrategy, has_length -from swift.utils import append_to_jsonl, format_time, get_logger, get_max_reserved_memory, is_pai_training_job +from swift.utils import (append_to_jsonl, format_time, get_env_args, get_logger, get_max_reserved_memory, + is_pai_training_job) from .arguments import TrainingArguments logger = get_logger() -def add_train_message(logs, state, start_time, start_step) -> None: +def add_train_message(logs, state, start_time, start_step, args=None) -> None: logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}' elapsed = time.time() - start_time logs['elapsed_time'] = format_time(elapsed) n_steps = state.global_step - start_step train_speed = elapsed / n_steps if n_steps > 0 else 0.0 + seq_length = getattr(args, 'seq_length', None) if args is not None else None + if train_speed > 0 and seq_length: + world_size = max(getattr(args, 'world_size', None) or 1, 1) + global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size + logs['tgs'] = round(global_batch_size * seq_length / train_speed / world_size, 3) logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) for k, v in logs.items(): if isinstance(v, float): @@ -49,7 +55,7 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader self.prediction_bar.update() def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): - add_train_message(logs, state, self.start_time, self.start_step) + add_train_message(logs, state, self.start_time, self.start_step, args) if not is_pai_training_job() and state.is_world_process_zero: jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') append_to_jsonl(jsonl_path, logs) @@ -58,6 +64,42 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non self.training_bar.refresh() +class ProgressCallbackNewWithMFU(ProgressCallbackNew): + # perf_log is a user callback and can run after progress logs are written; inject MFU before printing here. + + def __init__(self): + super().__init__() + self.device_tflops = None + self.elapsed = 0.0 + self.start_flos = 0 + self.step_start_time = None + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + tflops = get_env_args('DEVICE_TFLOPS', float, None) + if tflops is not None: + self.device_tflops = tflops * max(getattr(args, 'world_size', None) or 1, 1) + return super().on_init_end(args, state, control, **kwargs) + + def on_train_begin(self, args, state, control, **kwargs): + self.start_flos = getattr(state, 'total_flos', 0) + return super().on_train_begin(args, state, control, **kwargs) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + self.step_start_time = time.time() + return super().on_step_begin(args, state, control, **kwargs) + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if self.step_start_time is not None: + self.elapsed += time.time() - self.step_start_time + return super().on_step_end(args, state, control, **kwargs) + + def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): + total_flos = getattr(state, 'total_flos', 0) - self.start_flos + if self.elapsed > 0 and self.device_tflops and total_flos > 0 and logs is not None: + logs['MFU'] = round(total_flos / self.elapsed / (self.device_tflops * 1e12), 6) + return super().on_log(args, state, control, logs, **kwargs) + + class DefaultFlowCallbackNew(DefaultFlowCallback): def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): @@ -92,7 +134,7 @@ def on_train_begin(self, args, state, control, **kwargs): return super().on_train_begin(args, state, control, **kwargs) def on_log(self, args, state, control, logs=None, **kwargs): - add_train_message(logs, state, self.start_time, self.start_step) + add_train_message(logs, state, self.start_time, self.start_step, args) if not is_pai_training_job() and state.is_world_process_zero: jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') append_to_jsonl(jsonl_path, logs) @@ -103,6 +145,10 @@ def on_log(self, args, state, control, logs=None, **kwargs): # monkey patching -trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew +# MFU needs user-provided per-device peak TFLOPS; otherwise keep the default progress callback unchanged. +if get_env_args('DEVICE_TFLOPS', float, None) is not None: + trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNewWithMFU +else: + trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew] trainer.PrinterCallback = PrinterCallbackNew