Commit 5896940

Add lora-checkpointing option to evo2 predict

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent fc1cb60 commit 5896940

4 files changed: 205 additions & 64 deletions

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 87 additions & 63 deletions
@@ -41,6 +41,7 @@
 
 # Add import for Mamba models
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
+from bionemo.evo2.models.peft import Evo2LoRA
 from bionemo.llm.data import collate
 from bionemo.llm.lightning import LightningPassthroughPredictionMixin
 from bionemo.llm.utils.callbacks import PredictionWriter
@@ -159,6 +160,13 @@ def parse_args():
         "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
         "difference in accuracy.",
     )
+    ap.add_argument(
+        "--lora-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the lora states to restore from.",
+    )
     return ap.parse_args()
 
 
@@ -261,6 +269,11 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
 class HyenaPredictor(BasePredictor, HyenaModel):
     """A predictor for the Hyena model. This adds in the predict step and the passthrough method."""
 
+    def configure_model(self, *args, **kwargs) -> None:
+        """Configure the model."""
+        super().configure_model(*args, **kwargs)
+        self.trainer.strategy._init_model_parallel = True
+
 
 class MambaPredictor(BasePredictor, MambaModel):
     """Mamba model for prediction with additional metrics."""
@@ -397,6 +410,7 @@ def predict(
     num_layers: int | None = None,
     seq_len_interpolation_factor: int | None = None,
     files_per_subdir: int | None = None,
+    lora_checkpoint_path: Path | None = None,
 ):
     """Inference workflow for Evo2.
 
@@ -424,6 +438,77 @@ def predict(
     )
     global_batch_size = micro_batch_size * world_size // model_parallel_size
 
+    callbacks = [
+        PredictionWriter(
+            output_dir=output_dir,
+            write_interval=write_interval,
+            batch_dim_key_defaults={"token_logits": 0},
+            seq_dim_key_defaults={"token_logits": 1},
+            files_per_subdir=files_per_subdir,
+            save_all_model_parallel_ranks=False,  # only write one copy of predictions.
+        )
+    ]
+
+    # The following two config options are really only used for testing, but may also be useful for getting output from
+    # specific layers of the model.
+    config_modifiers_init = {}
+    if hybrid_override_pattern is not None:
+        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
+    if num_layers is not None:
+        config_modifiers_init["num_layers"] = num_layers
+
+    tokenizer = get_nmt_tokenizer("byte-level")
+
+    # Select model config based on model type
+    if model_type == "hyena":
+        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
+            # TODO remove this override once we add this as a default upstream in NeMo.
+            # if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
+            # has this option set.
+            config_modifiers_init["seq_len_interpolation_factor"] = 128
+
+        if model_size not in HYENA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Hyena: {model_size}")
+        config = HYENA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,
+            data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
+            # the projection layer of the hyena mixer.
+            vortex_style_fp8=fp8 and not full_fp8,
+            **config_modifiers_init,
+        )
+
+        if lora_checkpoint_path:
+            model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
+            callbacks.append(model_transform)
+        else:
+            model_transform = None
+
+        model = HyenaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+            model_transform=model_transform,
+        )
+    else:  # mamba
+        if model_size not in MAMBA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Mamba: {model_size}")
+        config = MAMBA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps
+            data_step_fn=hyena_predict_data_step,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            **config_modifiers_init,
+        )
+
+        model = MambaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+        )
+
     # Create PTL trainer.
     trainer = nl.Trainer(
         accelerator="gpu",
@@ -451,16 +536,7 @@ def predict(
         log_every_n_steps=1,
         limit_val_batches=10,
         num_sanity_val_steps=0,
-        callbacks=[
-            PredictionWriter(
-                output_dir=output_dir,
-                write_interval=write_interval,
-                batch_dim_key_defaults={"token_logits": 0},
-                seq_dim_key_defaults={"token_logits": 1},
-                files_per_subdir=files_per_subdir,
-                save_all_model_parallel_ranks=False,  # only write one copy of predictions.
-            )
-        ],
+        callbacks=callbacks,
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
             params_dtype=torch.bfloat16,
@@ -471,42 +547,6 @@ def predict(
             fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
         ),
     )
-    # The following two config options are really only used for testing, but may also be useful for getting output from
-    # specific layers of the model.
-    config_modifiers_init = {}
-    if hybrid_override_pattern is not None:
-        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
-    if num_layers is not None:
-        config_modifiers_init["num_layers"] = num_layers
-    # Select model config based on model type
-    if model_type == "hyena":
-        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
-            # TODO remove this override once we add this as a default upstream in NeMo.
-            # if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
-            # has this option set.
-            config_modifiers_init["seq_len_interpolation_factor"] = 128
-
-        if model_size not in HYENA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Hyena: {model_size}")
-        config = HYENA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,
-            data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
-            # the projection layer of the hyena mixer.
-            vortex_style_fp8=fp8 and not full_fp8,
-            **config_modifiers_init,
-        )
-    else:  # mamba
-        if model_size not in MAMBA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Mamba: {model_size}")
-        config = MAMBA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps
-            data_step_fn=hyena_predict_data_step,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            **config_modifiers_init,
-        )
-
     trainer.strategy._setup_optimizers = False
 
     nemo_logger = NeMoLogger(log_dir=work_dir)
@@ -518,23 +558,6 @@ def predict(
         resume_from_path=str(ckpt_dir),
         restore_config=None,
     )
-    tokenizer = get_nmt_tokenizer("byte-level")
-
-    # Create appropriate model based on type
-    if model_type == "hyena":
-        model = HyenaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
-    else:  # mamba
-        model = MambaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
 
     resume.setup(trainer, model)  # this pulls weights from the starting checkpoint.
 
@@ -573,6 +596,7 @@ def main():
         num_layers=args.num_layers,
         files_per_subdir=args.files_per_subdir,
         write_interval=args.write_interval,
+        lora_checkpoint_path=args.lora_checkpoint_path,
     )
 
 
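With this change, a LoRA adapter produced by fine-tuning can be applied at inference time by pointing --lora-checkpoint-path at the adapter checkpoint directory; when the flag is omitted, model_transform stays None and prediction behaves exactly as before. A minimal sketch of an end-to-end invocation, mirroring the new test further below; the paths are placeholders, and the other flags are ones predict_evo2 already accepts (they appear in the test helper predict_cmd added in this commit):

from bionemo.testing.subprocess_utils import run_command_in_subprocess

# Placeholder paths: a base Evo2 checkpoint, a LoRA checkpoint written by
# `train_evo2 ... --lora-finetune`, and a FASTA file with sequences to score.
base_ckpt = "/ckpts/evo2-1b-8k"
lora_ckpt = "/results/lora_finetune/evo2/checkpoints/last"
cmd = (
    f"predict_evo2 --fasta /data/seqs.fasta --ckpt-dir {base_ckpt} --output-dir /results/peft_preds "
    "--model-size 1b_nv --tensor-parallel-size 1 --pipeline-model-parallel-size 1 --context-parallel-size 1 "
    f"--lora-checkpoint-path {lora_ckpt}"
)
stdout = run_command_in_subprocess(command=cmd, path="/results")
assert "Loading adapters from" in stdout  # logged when the adapter weights are restored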

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 22 additions & 1 deletion
@@ -636,6 +636,18 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         default=False,
         help="Enable CUDA memory cleanup before validation to prevent initialization errors.",
     )
+    parser.add_argument(
+        "--lora-alpha",
+        type=int,
+        default=None,
+        help="Alpha parameter for LoRA fine-tuning.",
+    )
+    parser.add_argument(
+        "--lora-dim",
+        type=int,
+        default=None,
+        help="Dim parameter for LoRA fine-tuning.",
+    )
 
     recompute_group = parser.add_mutually_exclusive_group(required=False)
     recompute_group.add_argument("--no-activation-checkpointing", action="store_true", default=False)
@@ -801,7 +813,16 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         # Lora adaptors configuration
         lora_transform = None
         if args.lora_finetune:
-            lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path)
+            lora_kwargs = {
+                k: v
+                for k, v in {
+                    "alpha": args.lora_alpha,
+                    "dim": args.lora_dim,
+                }.items()
+                if v is not None
+            }
+
+            lora_transform = Evo2LoRA(peft_ckpt_path=args.lora_checkpoint_path, **lora_kwargs)
 
         model = llm.HyenaModel(model_config, tokenizer=data_module.tokenizer, model_transform=lora_transform)
     elif model_type == "mamba":  # mamba
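
The two new flags are only forwarded to Evo2LoRA when they are explicitly set (e.g. `train_evo2 ... --lora-finetune --lora-alpha 32 --lora-dim 16`; the values here are illustrative, not defaults), so the adapter's own defaults for alpha and dim still apply when the flags are omitted. A small self-contained sketch of the filtering, using a hypothetical build_lora_kwargs helper:

def build_lora_kwargs(lora_alpha: int | None = None, lora_dim: int | None = None) -> dict:
    """Keep only the LoRA options the user actually provided."""
    return {k: v for k, v in {"alpha": lora_alpha, "dim": lora_dim}.items() if v is not None}


print(build_lora_kwargs())                            # {} -> Evo2LoRA falls back to its own defaults
print(build_lora_kwargs(lora_alpha=32))               # {'alpha': 32}
print(build_lora_kwargs(lora_alpha=32, lora_dim=16))  # {'alpha': 32, 'dim': 16}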

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py

Lines changed: 11 additions & 0 deletions
@@ -58,3 +58,14 @@ def small_training_finetune_cmd(
         f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}"
     )
     return cmd
+
+
+def predict_cmd(ckpt_dir: str, output_dir: str, fasta_file_path: str, additional_args: str = ""):
+    """Command fro predict."""
+    cmd = (
+        f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {ckpt_dir} --output-dir {output_dir} "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 "
+        f"--pipeline-model-parallel-size 1 --context-parallel-size 1 {additional_args}"
+    )
+
+    return cmd
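
The helper only assembles a predict_evo2 command string; any extra flags (such as the new --lora-checkpoint-path) are appended verbatim through additional_args. An example with hypothetical paths:

cmd = predict_cmd(
    ckpt_dir="/ckpts/evo2-1b-8k",
    output_dir="/tmp/preds",
    fasta_file_path="/tmp/test.fasta",
    additional_args="--lora-checkpoint-path /tmp/lora_ckpt",
)
# cmd ==
#   "predict_evo2 --fasta /tmp/test.fasta --ckpt-dir /ckpts/evo2-1b-8k --output-dir /tmp/preds "
#   "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 "
#   "--pipeline-model-parallel-size 1 --context-parallel-size 1 --lora-checkpoint-path /tmp/lora_ckpt"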

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py

Lines changed: 85 additions & 0 deletions
@@ -31,8 +31,11 @@
 from bionemo.core.data.load import load
 from bionemo.llm.lightning import batch_collator
 from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
+from bionemo.testing.subprocess_utils import run_command_in_subprocess
 from bionemo.testing.torch import check_fp8_support
 
+from .common import predict_cmd, small_training_finetune_cmd
+
 
 def is_a6000_gpu() -> bool:
     # Check if any of the visible GPUs is an A6000
@@ -364,3 +367,85 @@ def test_predict_evo2_equivalent_with_log_probs(
     else:
         rel = 1e-6
     assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel)
+
+
+@pytest.mark.timeout(512)
+@pytest.mark.slow
+def test_different_results_with_without_peft(tmp_path):
+    try:
+        base_model_checkpoint_path = load("evo2/1b-8k:1.0")
+    except ValueError as e:
+        if e.args[0].endswith("does not have an NGC URL."):
+            raise ValueError(
+                "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, "
+                "one or more files are missing from ngc."
+            )
+        else:
+            raise e
+
+    num_steps = 2
+
+    result_dir = tmp_path / "lora_finetune"
+
+    # Note: The command assumes that `train_evo2` is in your PATH.
+    command_finetune = small_training_finetune_cmd(
+        result_dir,
+        max_steps=num_steps,
+        val_check=num_steps,
+        prev_ckpt=base_model_checkpoint_path,
+        create_tflops_callback=False,
+        additional_args="--lora-finetune",
+    )
+    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
+    assert "Loading adapters from" not in stdout_finetune
+
+    # Check if checkpoints dir exists
+    checkpoints_dir = result_dir / "evo2" / "checkpoints"
+    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
+
+    # Create a sample FASTA file to run predictions
+    fasta_file_path = tmp_path / "test.fasta"
+    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)
+
+    result_dir_original = tmp_path / "results_original"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_original, fasta_file_path)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+
+    # Assert that the output directory was created.
+    pred_files_original = glob.glob(str(result_dir_original / "predictions__rank_*.pt"))
+    assert len(pred_files_original) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_original)}"
+
+    # Find the checkpoint dir generated by finetuning
+    expected_checkpoint_suffix = f"{num_steps}.0-last"
+    # Check if any subfolder ends with the expected suffix
+    matching_subfolders = [
+        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
+    ]
+
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+
+    result_dir_peft = tmp_path / "results_peft"
+    additional_args = f"--lora-checkpoint-path {matching_subfolders[0]}"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_peft, fasta_file_path, additional_args)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+    assert "Loading adapters from" in stdout_predict
+
+    pred_files_peft = glob.glob(str(result_dir_peft / "predictions__rank_*.pt"))
+    assert len(pred_files_peft) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_peft)}"
+
+    results_original = torch.load(f"{result_dir_original}/predictions__rank_0__dp_rank_0.pt")
+    results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0__dp_rank_0.pt")
+
+    seq_idx_original = results_original["seq_idx"]
+    seq_idx_peft = results_peft["seq_idx"]
+    assert torch.equal(seq_idx_original, seq_idx_peft), f"Tensors differ: {seq_idx_original} vs {seq_idx_peft}"
+
+    logits_original = results_original["token_logits"]
+    logits_peft = results_peft["token_logits"]
+    assert (logits_original != logits_peft).any()
+    assert logits_original.shape == logits_peft.shape, (
+        f"Shapes don't match: {logits_original.shape} vs {logits_peft.shape}"
+    )
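
The test compares the .pt files that PredictionWriter writes for the base model and the LoRA-adapted model. The same files can be inspected by hand; a short sketch with a placeholder path, using the keys the test relies on:

import torch

preds = torch.load("/results/peft_preds/predictions__rank_0__dp_rank_0.pt")  # placeholder path
print(preds["seq_idx"])             # order in which the input sequences were processed
print(preds["token_logits"].shape)  # per-token logits; these differ from the base model once adapters are applied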
