-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Add lightweight TGS and MFU logging for training #9465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,18 +8,24 @@ | |||||||||||||||||||||
| TrainerState) | ||||||||||||||||||||||
| from transformers.trainer_utils import IntervalStrategy, has_length | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from swift.utils import append_to_jsonl, format_time, get_logger, get_max_reserved_memory, is_pai_training_job | ||||||||||||||||||||||
| from swift.utils import (append_to_jsonl, format_time, get_env_args, get_logger, get_max_reserved_memory, | ||||||||||||||||||||||
| is_pai_training_job) | ||||||||||||||||||||||
| from .arguments import TrainingArguments | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| logger = get_logger() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def add_train_message(logs, state, start_time, start_step) -> None: | ||||||||||||||||||||||
| def add_train_message(logs, state, start_time, start_step, args=None) -> None: | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
Suggested change
|
||||||||||||||||||||||
| logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}' | ||||||||||||||||||||||
| elapsed = time.time() - start_time | ||||||||||||||||||||||
| logs['elapsed_time'] = format_time(elapsed) | ||||||||||||||||||||||
| n_steps = state.global_step - start_step | ||||||||||||||||||||||
| train_speed = elapsed / n_steps if n_steps > 0 else 0.0 | ||||||||||||||||||||||
| seq_length = getattr(args, 'seq_length', None) if args is not None else None | ||||||||||||||||||||||
| if train_speed > 0 and seq_length: | ||||||||||||||||||||||
| world_size = max(getattr(args, 'world_size', None) or 1, 1) | ||||||||||||||||||||||
| global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size | ||||||||||||||||||||||
| logs['tgs'] = round(global_batch_size * seq_length / train_speed / world_size, 3) | ||||||||||||||||||||||
| logs['remaining_time'] = format_time((state.max_steps - state.global_step) * train_speed) | ||||||||||||||||||||||
| for k, v in logs.items(): | ||||||||||||||||||||||
| if isinstance(v, float): | ||||||||||||||||||||||
|
|
@@ -49,7 +55,7 @@ def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader | |||||||||||||||||||||
| self.prediction_bar.update() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): | ||||||||||||||||||||||
| add_train_message(logs, state, self.start_time, self.start_step) | ||||||||||||||||||||||
| add_train_message(logs, state, self.start_time, self.start_step, args) | ||||||||||||||||||||||
| if not is_pai_training_job() and state.is_world_process_zero: | ||||||||||||||||||||||
| jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') | ||||||||||||||||||||||
| append_to_jsonl(jsonl_path, logs) | ||||||||||||||||||||||
|
|
@@ -58,6 +64,42 @@ def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=Non | |||||||||||||||||||||
| self.training_bar.refresh() | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class ProgressCallbackNewWithMFU(ProgressCallbackNew): | ||||||||||||||||||||||
| # perf_log is a user callback and can run after progress logs are written; inject MFU before printing here. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||
| self.device_tflops = None | ||||||||||||||||||||||
| self.elapsed = 0.0 | ||||||||||||||||||||||
| self.start_flos = 0 | ||||||||||||||||||||||
| self.step_start_time = None | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||
| tflops = get_env_args('DEVICE_TFLOPS', float, None) | ||||||||||||||||||||||
| if tflops is not None: | ||||||||||||||||||||||
| self.device_tflops = tflops * max(getattr(args, 'world_size', None) or 1, 1) | ||||||||||||||||||||||
| return super().on_init_end(args, state, control, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_train_begin(self, args, state, control, **kwargs): | ||||||||||||||||||||||
| self.start_flos = getattr(state, 'total_flos', 0) | ||||||||||||||||||||||
| return super().on_train_begin(args, state, control, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||
| self.step_start_time = time.time() | ||||||||||||||||||||||
| return super().on_step_begin(args, state, control, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||
| if self.step_start_time is not None: | ||||||||||||||||||||||
| self.elapsed += time.time() - self.step_start_time | ||||||||||||||||||||||
| return super().on_step_end(args, state, control, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs): | ||||||||||||||||||||||
| total_flos = getattr(state, 'total_flos', 0) - self.start_flos | ||||||||||||||||||||||
| if self.elapsed > 0 and self.device_tflops and total_flos > 0 and logs is not None: | ||||||||||||||||||||||
| logs['MFU'] = round(total_flos / self.elapsed / (self.device_tflops * 1e12), 6) | ||||||||||||||||||||||
| return super().on_log(args, state, control, logs, **kwargs) | ||||||||||||||||||||||
|
Comment on lines
+96
to
+100
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class DefaultFlowCallbackNew(DefaultFlowCallback): | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): | ||||||||||||||||||||||
|
|
@@ -92,7 +134,7 @@ def on_train_begin(self, args, state, control, **kwargs): | |||||||||||||||||||||
| return super().on_train_begin(args, state, control, **kwargs) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def on_log(self, args, state, control, logs=None, **kwargs): | ||||||||||||||||||||||
| add_train_message(logs, state, self.start_time, self.start_step) | ||||||||||||||||||||||
| add_train_message(logs, state, self.start_time, self.start_step, args) | ||||||||||||||||||||||
| if not is_pai_training_job() and state.is_world_process_zero: | ||||||||||||||||||||||
| jsonl_path = os.path.join(args.output_dir, 'logging.jsonl') | ||||||||||||||||||||||
| append_to_jsonl(jsonl_path, logs) | ||||||||||||||||||||||
|
|
@@ -103,6 +145,10 @@ def on_log(self, args, state, control, logs=None, **kwargs): | |||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # monkey patching | ||||||||||||||||||||||
| trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew | ||||||||||||||||||||||
| # MFU needs user-provided per-device peak TFLOPS; otherwise keep the default progress callback unchanged. | ||||||||||||||||||||||
| if get_env_args('DEVICE_TFLOPS', float, None) is not None: | ||||||||||||||||||||||
| trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNewWithMFU | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew | ||||||||||||||||||||||
| trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew] | ||||||||||||||||||||||
| trainer.PrinterCallback = PrinterCallbackNew | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
self.args.seq_lengthisNone, this calculation will raise aTypeError. Retrievingseq_lengthdefensively and returningNoneearly if it is not set prevents this issue.