From d68f2e883994bc5d65e85cbae43519475f03364b Mon Sep 17 00:00:00 2001
From: Jingyuan-zhu
Date: Thu, 26 Feb 2026 04:46:22 +0000
Subject: [PATCH] feat: add training stability guards and fix transformers
 5.2.0 compatibility

Add three runtime guards to the finetuner: fall back from bf16 to fp16
when the CUDA device lacks bfloat16 support, disable use_cache when
gradient checkpointing is enabled, and resize the model embeddings when
the tokenizer vocabulary has outgrown them. Also read
overwrite_output_dir defensively via getattr for transformers 5.2.0
compatibility, and skip the sglang inferencer test when sglang is not
installed.
---
 src/lmflow/pipeline/finetuner.py          | 27 ++++++++++++++++++++++-
 src/lmflow/pipeline/utils/raft_trainer.py |  4 ++--
 tests/pipeline/test_sglang_infernecer.py  |  3 +++
 3 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py
index a4ff2c25c..195799415 100644
--- a/src/lmflow/pipeline/finetuner.py
+++ b/src/lmflow/pipeline/finetuner.py
@@ -107,7 +107,7 @@ def __init__(
         if (
             os.path.isdir(finetuner_args.output_dir)
             and finetuner_args.do_train
-            and not finetuner_args.overwrite_output_dir
+            and not getattr(finetuner_args, "overwrite_output_dir", False)
         ):
             last_checkpoint = get_last_checkpoint(finetuner_args.output_dir)
             if last_checkpoint is None and len(os.listdir(finetuner_args.output_dir)) > 0:
@@ -322,6 +322,31 @@ def compute_metrics(eval_preds):
                 max_train_samples = min(len(train_dataset), data_args.max_train_samples)
                 train_dataset = train_dataset.select(range(max_train_samples))
 
+        if getattr(finetuner_args, "bf16", False) and torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
+            logger.warning(
+                "CUDA device does not support bfloat16 (requires Ampere architecture or newer). "
+                "Automatically falling back to fp16 to prevent CUDA crashes."
+            )
+            finetuner_args.bf16 = False
+            finetuner_args.fp16 = True
+
+        if getattr(finetuner_args, "gradient_checkpointing", False):
+            backend_model = model.get_backend_model()
+            if hasattr(backend_model, "config") and getattr(backend_model.config, "use_cache", False):
+                logger.info("Gradient checkpointing is enabled. Automatically setting use_cache=False for the model.")
+                backend_model.config.use_cache = False
+
+        backend_model = model.get_backend_model()
+        tokenizer = model.get_tokenizer()
+        if hasattr(backend_model, "get_input_embeddings") and tokenizer is not None:
+            embeddings = backend_model.get_input_embeddings()
+            if embeddings is not None and len(tokenizer) > embeddings.weight.shape[0]:
+                logger.warning(
+                    f"Tokenizer vocabulary size ({len(tokenizer)}) is greater than the model's embedding vocabulary size "
+                    f"({embeddings.weight.shape[0]}). Resizing model embeddings to prevent CUDA out-of-bounds crashes."
+                )
+                backend_model.resize_token_embeddings(len(tokenizer))
+
         # Initialize our Trainer
         training_args = finetuner_args
         FinetuningTrainer = Trainer
diff --git a/src/lmflow/pipeline/utils/raft_trainer.py b/src/lmflow/pipeline/utils/raft_trainer.py
index 83686c11b..0d89b18f2 100644
--- a/src/lmflow/pipeline/utils/raft_trainer.py
+++ b/src/lmflow/pipeline/utils/raft_trainer.py
@@ -3389,7 +3389,7 @@ def init_git_repo(self, at_init: bool = False):
         Initializes a git repo in `self.args.hub_model_id`.
 
         Args:
             at_init (`bool`, *optional*, defaults to `False`):
-                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
+                Whether this function is called before any training or not. If `self.args.overwrite_output_dir` (read via `getattr`, since the attribute may be absent) is
                 `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
""" @@ -3407,7 +3407,7 @@ def init_git_repo(self, at_init: bool = False): try: self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) except OSError: - if self.args.overwrite_output_dir and at_init: + if getattr(self.args, "overwrite_output_dir", False) and at_init: # Try again after wiping output_dir shutil.rmtree(self.args.output_dir) self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) diff --git a/tests/pipeline/test_sglang_infernecer.py b/tests/pipeline/test_sglang_infernecer.py index 5dec650bc..42f2912e4 100644 --- a/tests/pipeline/test_sglang_infernecer.py +++ b/tests/pipeline/test_sglang_infernecer.py @@ -1,5 +1,8 @@ import numpy as np import pytest + +pytest.importorskip("sglang") + from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs