diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 7f2db2203..8e92dce92 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -197,9 +197,14 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name): row, skipped_count = proc_fn(item) if row is None: continue + # Data starts with an assistant message, skip the entire conversation + if row["conversations"][0]["role"] == "assistant": + total_skipped_count += len(row["conversations"]) + continue total_skipped_count += skipped_count else: row = item + row, skipped_count = proc_fn(item) f.write(json.dumps(row, ensure_ascii=False) + "\n") if test_ds is not None: @@ -210,6 +215,10 @@ def process_and_save_ds(train_ds, test_ds, output_path, proc_fn, dataset_name): row, skipped_count = proc_fn(item) if row is None: continue + # Data starts with an assistant message, skip the entire conversation + if row["conversations"][0]["role"] == "assistant": + total_skipped_count += len(row["conversations"]) + continue total_skipped_count += skipped_count else: row = item diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 7a652df4a..5055bb7ff 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -4,7 +4,8 @@ import os import time from argparse import ArgumentParser, Namespace -from typing import List, Optional, Tuple, Union +from itertools import islice +from typing import Any, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -315,6 +316,39 @@ def build_target_model( return target_head, None +def load_checkpoint(args: Namespace) -> Tuple[int, int, Optional[str], Any]: + """ + Load checkpoint and return starting epoch and global_step + returns: + start_epoch: The starting epoch. + global_step: The global step. + checkpoint_path: The path to the checkpoint. + state: The training state. + """ + if not args.resume or not os.path.isdir(args.output_dir): + print_on_rank0("Starting training from scratch") + return 0, 0, None, None + + checkpoint_path = get_last_checkpoint(args.output_dir) + if not checkpoint_path: + print_on_rank0("No checkpoint found, starting from scratch") + return 0, 0, None, None + + training_state_path = os.path.join(checkpoint_path, "training_state.pt") + if not os.path.exists(training_state_path): + # Could be fine-tuning from a pretrained model without training state + print_on_rank0(f"Training state not found at {training_state_path}") + return 0, 0, checkpoint_path, None + + # Load training state + state = torch.load(training_state_path, weights_only=False, map_location="cpu") + start_epoch = state["epoch"] + global_step = state["global_step"] + + print_on_rank0(f"Resumed from epoch {start_epoch}, step {global_step}") + return start_epoch, global_step, checkpoint_path, state + + def sanity_check(args: Namespace) -> None: """ Perform sanity checks on the arguments. @@ -336,7 +370,9 @@ def sanity_check(args: Namespace) -> None: ), "train_hidden_states_path should not be None for usp" -def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]: +def build_draft_model( + args: Namespace, draft_model_last_checkpoint: Optional[str] +) -> Tuple[AutoDraftModelConfig, nn.Module]: # Handle draft model config if args.draft_model_config is None: # Auto-generate and save config file @@ -348,27 +384,10 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] # Use provided config file draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) - # Handle base ckpt, config file - draft_model_last_checkpoint = None - if args.ckpt_dir is not None: - if os.path.isdir(args.ckpt_dir): - draft_model_config = AutoDraftModelConfig.from_file( - os.path.join(args.ckpt_dir, "config.json") - ) - draft_model_last_checkpoint = args.ckpt_dir - print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}") - else: - raise ValueError( - f"Provided base model dir {args.ckpt_dir} is not a valid directory." - ) - - # detecting last ckpt for draft model - if args.resume and os.path.isdir(args.output_dir): - print_on_rank0(args.output_dir) - draft_model_last_checkpoint = get_last_checkpoint(args.output_dir) - print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}") - if draft_model_last_checkpoint: + draft_model_config = AutoDraftModelConfig.from_file( + os.path.join(args.ckpt_dir, "config.json") + ) draft_model = AutoEagle3DraftModel.from_pretrained( draft_model_last_checkpoint, attention_backend=args.attention_backend, @@ -395,13 +414,21 @@ def build_dataloaders( tokenizer = AutoTokenizer.from_pretrained(args.target_model_path) # convert to dataloader - cache_params_string = ( + train_cache_params_string = ( f"{args.train_data_path}-" f"{args.max_length}-" f"{args.chat_template}-" f"{args.target_model_path}" # Tokenizer may also different ) - cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() + train_cache_key = hashlib.md5(train_cache_params_string.encode()).hexdigest() + eval_cache_params_string = ( + f"{args.eval_data_path}-" + f"{args.max_length}-" + f"{args.chat_template}-" + f"{args.target_model_path}" # Tokenizer may also different + ) + eval_cache_key = hashlib.md5(eval_cache_params_string.encode()).hexdigest() + cache_dir = os.path.join(args.cache_dir, "processed_dataset") train_dataset = load_dataset("json", data_files=args.train_data_path)["train"] with rank_0_priority(): train_eagle3_dataset = build_eagle3_dataset( @@ -409,8 +436,8 @@ def build_dataloaders( tokenizer=tokenizer, chat_template=args.chat_template, max_length=args.max_length, - cache_dir=os.path.join(args.cache_dir, "processed_dataset"), - cache_key=cache_key, + cache_dir=cache_dir, + cache_key=train_cache_key, is_vlm=args.is_vlm, is_preformatted=args.is_preformatted, processor=processor, @@ -421,7 +448,7 @@ def build_dataloaders( target_vocab_size=draft_model_config.vocab_size, draft_vocab_size=draft_model_config.draft_vocab_size, cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), - cache_key=cache_key, + cache_key=train_cache_key, ) if args.train_hidden_states_path is not None: @@ -449,6 +476,8 @@ def build_dataloaders( tokenizer, args.chat_template, args.max_length, + cache_dir=cache_dir, + cache_key=eval_cache_key, is_vlm=args.is_vlm, processor=processor, num_proc=args.build_dataset_num_proc, @@ -652,10 +681,16 @@ def main(): print_args_with_dots(args) print_with_rank("Initialized distributed environment") + start_epoch, global_step, draft_model_last_checkpoint, optimizer_state = ( + load_checkpoint(args) + ) + # ================================================ # 2. Build models # ================================================ - draft_model_config, draft_model = build_draft_model(args) + draft_model_config, draft_model = build_draft_model( + args, draft_model_last_checkpoint + ) target_model, processor = build_target_model(args, draft_model_config, is_online) # ================================================ @@ -664,6 +699,8 @@ def main(): train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders( args, draft_model_config, processor ) + # Set this attribute to show draft model config in the tracker + args.draft_model_config_dict = draft_model_config.__dict__ # we load the vocab mapping then draft_model.load_vocab_mapping(vocab_mapping_path) @@ -724,14 +761,15 @@ def main(): warmup_ratio=args.warmup_ratio, total_steps=args.total_steps, ) + if optimizer_state is not None: + optimizer.load_state_dict(optimizer_state) + print_with_rank("Loaded optimizer state from checkpoint") print_with_rank("Initialized optimizer and scheduler") # ================================================ # 6. Build tracker # ================================================ tracker = build_tracker(args, parser) - global_step = 0 - start_epoch = 0 dist.barrier() last_time = time.time() @@ -741,6 +779,8 @@ def main(): # ================================================ print_on_rank0(f"Starting training from epoch {start_epoch}") + steps_to_skip = global_step - steps_per_epoch * start_epoch + for epoch in range(start_epoch, args.num_epochs): # Run training train_dataloader.sampler.set_epoch(epoch + 1) @@ -748,12 +788,18 @@ def main(): if dist.get_rank() == 0: progress_bar = tqdm( - train_dataloader, desc=f"Training Epoch {epoch}", leave=True + train_dataloader, + desc=f"Training Epoch {epoch}", + leave=True, + initial=steps_to_skip, ) else: progress_bar = train_dataloader - for data in progress_bar: + for batch_idx, data in enumerate( + islice(progress_bar, steps_to_skip, None), start=steps_to_skip + ): + steps_to_skip = 0 # reset for next epoch global_step += 1 # ================================================ @@ -789,12 +835,16 @@ def main(): ) run_backward_and_update(args, plosses, optimizer, global_step) + # detach losses and accuracies to avoid memory leak + plosses_for_metrics = [p.detach() for p in plosses] + acces_for_metrics = [a.detach() for a in acces] + # log training metrics if global_step % (args.log_interval * args.draft_accumulation_steps) == 0: record_metrcs( args, - acces, - plosses, + acces_for_metrics, + plosses_for_metrics, global_step // args.draft_accumulation_steps, tracker, optimizer, @@ -804,8 +854,10 @@ def main(): if dist.get_rank() == 0: time_per_step = time.time() - last_time last_time = time.time() - avg_loss = sum(pl for pl in plosses) / len(plosses) - avg_acc = sum(acces) / len(acces) + avg_loss = sum(pl for pl in plosses_for_metrics) / len( + plosses_for_metrics + ) + avg_acc = sum(acces_for_metrics) / len(acces_for_metrics) progress_bar.set_postfix( { "loss": f"{avg_loss:.2f}", diff --git a/specforge/tracker.py b/specforge/tracker.py index 02b7498c1..6a8f0a64e 100644 --- a/specforge/tracker.py +++ b/specforge/tracker.py @@ -130,8 +130,14 @@ def __init__(self, args, output_dir: str): if self.rank == 0: wandb.login(key=args.wandb_key) wandb.init( - project=args.wandb_project, name=args.wandb_name, config=vars(args) + project=args.wandb_project, + name=args.wandb_name, + config={ + **vars(args), + "draft_model_config_dict": args.draft_model_config_dict, + }, ) + wandb.save(args.draft_model_config) self.is_initialized = True def log(self, log_dict: Dict[str, Any], step: Optional[int] = None): diff --git a/specforge/utils.py b/specforge/utils.py index 57a423bbd..c015d2b6f 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -76,9 +76,9 @@ def print_on_rank0(message): logger.info(message) -def get_last_checkpoint(folder, prefix="epoch"): +def get_last_checkpoint(folder): content = os.listdir(folder) - _re_checkpoint = re.compile(r"^" + prefix + r"_(\d+)$") + _re_checkpoint = re.compile(r"^epoch_(\d+)_step_(\d+)$") checkpoints = [ path for path in content