 
 # Add import for Mamba models
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
+from bionemo.evo2.models.peft import Evo2LoRA
 from bionemo.llm.data import collate
 from bionemo.llm.lightning import LightningPassthroughPredictionMixin
 from bionemo.llm.utils.callbacks import PredictionWriter
@@ -159,6 +160,13 @@ def parse_args():
159160 "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
160161 "difference in accuracy." ,
161162 )
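+    # Illustrative usage (path is hypothetical): --lora-checkpoint-path /results/evo2_lora/checkpoints/last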
+    ap.add_argument(
+        "--lora-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the LoRA adapter state to restore from.",
+    )
     return ap.parse_args()
 
 
@@ -261,6 +269,11 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor | dict[str
 class HyenaPredictor(BasePredictor, HyenaModel):
     """A predictor for the Hyena model. This adds in the predict step and the passthrough method."""
 
+    def configure_model(self, *args, **kwargs) -> None:
+        """Configure the model."""
+        super().configure_model(*args, **kwargs)
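+        # NOTE: _init_model_parallel is a private strategy flag; re-enabling it after configure_model is assumed
+        # to be needed so Megatron model-parallel setup still runs when a model transform (e.g., LoRA) is attached.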
+        self.trainer.strategy._init_model_parallel = True
+
 
 class MambaPredictor(BasePredictor, MambaModel):
     """Mamba model for prediction with additional metrics."""
@@ -397,6 +410,7 @@ def predict(
     num_layers: int | None = None,
     seq_len_interpolation_factor: int | None = None,
     files_per_subdir: int | None = None,
+    lora_checkpoint_path: Path | None = None,
 ):
     """Inference workflow for Evo2.
 
@@ -424,6 +438,77 @@ def predict(
     )
     global_batch_size = micro_batch_size * world_size // model_parallel_size
 
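+    # Build the callback list up-front so the optional LoRA transform below can be appended before the trainer
+    # is constructed.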
+    callbacks = [
+        PredictionWriter(
+            output_dir=output_dir,
+            write_interval=write_interval,
+            batch_dim_key_defaults={"token_logits": 0},
+            seq_dim_key_defaults={"token_logits": 1},
+            files_per_subdir=files_per_subdir,
+            save_all_model_parallel_ranks=False,  # only write one copy of predictions.
+        )
+    ]
+
+    # The following two config options are really only used for testing, but may also be useful for getting output
+    # from specific layers of the model.
+    config_modifiers_init = {}
+    if hybrid_override_pattern is not None:
+        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
+    if num_layers is not None:
+        config_modifiers_init["num_layers"] = num_layers
+
+    tokenizer = get_nmt_tokenizer("byte-level")
+
+    # Select model config based on model type.
+    if model_type == "hyena":
+        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
+            # TODO: remove this override once it is the default upstream in NeMo. If you see this, check whether
+            #  the corresponding 1m model option in NeMo already has this option set.
+            config_modifiers_init["seq_len_interpolation_factor"] = 128
+
+        if model_size not in HYENA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Hyena: {model_size}")
+        config = HYENA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,
+            data_step_fn=hyena_predict_data_step,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            # Only use vortex-style FP8 in the model config if using FP8 but not full FP8. This applies FP8 only to
+            # the projection layer of the hyena mixer.
+            vortex_style_fp8=fp8 and not full_fp8,
+            **config_modifiers_init,
+        )
+
+        if lora_checkpoint_path:
+            model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
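+            # Registering the transform as a callback is assumed to follow the NeMo 2.0 PEFT flow, where the
+            # transform doubles as a trainer callback that applies the adapter and restores its weights on setup.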
+            callbacks.append(model_transform)
+        else:
+            model_transform = None
+
+        model = HyenaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+            model_transform=model_transform,
+        )
+    else:  # mamba
+        if model_size not in MAMBA_MODEL_OPTIONS:
+            raise ValueError(f"Invalid model size for Mamba: {model_size}")
+        config = MAMBA_MODEL_OPTIONS[model_size](
+            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps.
+            data_step_fn=hyena_predict_data_step,
+            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
+            **config_modifiers_init,
+        )
+
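+        # NOTE: lora_checkpoint_path is currently only honored on the Hyena path above; Mamba runs without PEFT.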
+        model = MambaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+        )
+
     # Create PTL trainer.
     trainer = nl.Trainer(
         accelerator="gpu",
@@ -451,16 +536,7 @@ def predict(
         log_every_n_steps=1,
         limit_val_batches=10,
         num_sanity_val_steps=0,
-        callbacks=[
-            PredictionWriter(
-                output_dir=output_dir,
-                write_interval=write_interval,
-                batch_dim_key_defaults={"token_logits": 0},
-                seq_dim_key_defaults={"token_logits": 1},
-                files_per_subdir=files_per_subdir,
-                save_all_model_parallel_ranks=False,  # only write one copy of predictions.
-            )
-        ],
+        callbacks=callbacks,
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
             params_dtype=torch.bfloat16,
@@ -471,42 +547,6 @@ def predict(
             fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
         ),
     )
-    # The following two config options are really only used for testing, but may also be useful for getting output from
-    # specific layers of the model.
-    config_modifiers_init = {}
-    if hybrid_override_pattern is not None:
-        config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
-    if num_layers is not None:
-        config_modifiers_init["num_layers"] = num_layers
-    # Select model config based on model type
-    if model_type == "hyena":
-        if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
-            # TODO remove this override once we add this as a default upstream in NeMo.
-            # if you see this, just check the pointed to model option for the 1m model in nemo and see if it already
-            # has this option set.
-            config_modifiers_init["seq_len_interpolation_factor"] = 128
-
-        if model_size not in HYENA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Hyena: {model_size}")
-        config = HYENA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,
-            data_step_fn=hyena_predict_data_step,  # , attention_backend=AttnBackend.fused,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to
-            # the projection layer of the hyena mixer.
-            vortex_style_fp8=fp8 and not full_fp8,
-            **config_modifiers_init,
-        )
-    else:  # mamba
-        if model_size not in MAMBA_MODEL_OPTIONS:
-            raise ValueError(f"Invalid model size for Mamba: {model_size}")
-        config = MAMBA_MODEL_OPTIONS[model_size](
-            forward_step_fn=hyena_predict_forward_step,  # Can reuse the same forward steps
-            data_step_fn=hyena_predict_data_step,
-            distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
-            **config_modifiers_init,
-        )
-
     trainer.strategy._setup_optimizers = False
 
     nemo_logger = NeMoLogger(log_dir=work_dir)
@@ -518,23 +558,6 @@ def predict(
         resume_from_path=str(ckpt_dir),
         restore_config=None,
     )
-    tokenizer = get_nmt_tokenizer("byte-level")
-
-    # Create appropriate model based on type
-    if model_type == "hyena":
-        model = HyenaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
-    else:  # mamba
-        model = MambaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
 
     resume.setup(trainer, model)  # this pulls weights from the starting checkpoint.
 
@@ -573,6 +596,7 @@ def main():
         num_layers=args.num_layers,
         files_per_subdir=args.files_per_subdir,
         write_interval=args.write_interval,
+        lora_checkpoint_path=args.lora_checkpoint_path,
     )
 
 