 
 # Add import for Mamba models
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
+from bionemo.evo2.run.peft import Evo2LoRA
 from bionemo.llm.lightning import LightningPassthroughPredictionMixin
 from bionemo.llm.model.biobert.lightning import get_batch_on_this_context_parallel_rank
 from bionemo.llm.utils.callbacks import PredictionWriter
@@ -130,6 +131,13 @@ def parse_args():
         "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
         "difference in accuracy.",
     )
+    ap.add_argument(
+        "--lora-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the LoRA checkpoint to restore adapter states from.",
+    )
     return ap.parse_args()


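For orientation: the new flag is optional and defaults to `None`, so existing invocations are unchanged and LoRA only engages when a path is supplied. A hedged sketch of the expected flow; the entry-point name and the early existence check are illustrative additions, not part of this patch:

```python
# Hypothetical invocation (entry-point name is an assumption):
#   predict_evo2 ... --lora-checkpoint-path /results/evo2_lora/checkpoints/last
args = parse_args()

# Illustrative early validation, not in this patch: fail fast if the adapter
# checkpoint is missing instead of erroring deep inside trainer setup.
if args.lora_checkpoint_path is not None and not args.lora_checkpoint_path.exists():
    raise FileNotFoundError(f"LoRA checkpoint not found: {args.lora_checkpoint_path}")
```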
@@ -168,6 +176,10 @@ def __init__(
         self.output_log_prob_seqs = output_log_prob_seqs
         self.log_prob_collapse_option = log_prob_collapse_option
 
+    def configure_model(self, *args, **kwargs) -> None:
+        super().configure_model(*args, **kwargs)
+        self.trainer.strategy._init_model_parallel = True
+
     def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
         """Alias for forward_step; also logs the pad mask since sequences may not all have the same length."""
         if len(batch) == 0:
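The `configure_model` override is the piece that makes the PEFT path work here: NeMo's `MegatronStrategy` keeps an internal `_init_model_parallel` flag, and re-arming it after the base model is configured appears intended to ensure model-parallel initialization covers the module after the LoRA transform has wrapped it. A toy illustration of the assumed semantics (plain Python; `ToyStrategy` and `ToyPredictor` are invented for illustration, not NeMo classes):

```python
# Toy model of the assumed flag semantics, not NeMo code.
class ToyStrategy:
    def __init__(self):
        self._init_model_parallel = False

    def init_model_parallel(self, module):
        if self._init_model_parallel:
            print(f"(re)initializing model parallel for {type(module).__name__}")


class ToyPredictor:
    def __init__(self, strategy):
        self.strategy = strategy

    def configure_model(self):
        # Mirrors the override above: re-arm the strategy so parallel state is
        # set up over the transformed (e.g. LoRA-wrapped) module.
        self.strategy._init_model_parallel = True


strategy = ToyStrategy()
predictor = ToyPredictor(strategy)
predictor.configure_model()
strategy.init_model_parallel(predictor)  # -> "(re)initializing model parallel for ToyPredictor"
```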
@@ -326,6 +338,7 @@ def predict(
     hybrid_override_pattern: str | None = None,
     num_layers: int | None = None,
     seq_len_interpolation_factor: int | None = None,
+    lora_checkpoint_path: Path | None = None,
 ):
     """Inference workflow for Evo2.
 
@@ -342,57 +355,26 @@ def predict(
             f"Requested model parallel size {model_parallel_size} is greater than the "
             f"number of available CUDA devices {torch.cuda.device_count()}"
         )
-    # Create PTL trainer.
-    trainer = nl.Trainer(
-        accelerator="gpu",
-        devices=model_parallel_size,
-        strategy=nl.MegatronStrategy(
-            drop_last_batch=False,
-            tensor_model_parallel_size=tensor_parallel_size,
-            pipeline_model_parallel_size=pipeline_model_parallel_size,
-            context_parallel_size=context_parallel_size,
-            pipeline_dtype=torch.bfloat16,
-            ckpt_load_optimizer=False,  # Needs to be false for a normal model checkpoint.
-            ckpt_save_optimizer=False,
-            ckpt_async_save=False,
-            sequence_parallel=tensor_parallel_size > 1 and sequence_parallel,
-            save_ckpt_format=ckpt_format,
-            ckpt_load_strictness="log_all",
-            data_sampler=nl.MegatronDataSampler(
-                micro_batch_size=batch_size,
-                global_batch_size=batch_size,
-                seq_len=8192,
-                output_log=False,  # this is needed for predict step to work
-            ),
-        ),
-        log_every_n_steps=1,
-        limit_val_batches=10,
-        num_sanity_val_steps=0,
-        callbacks=[
-            PredictionWriter(
-                output_dir=output_dir,
-                write_interval="epoch",
-                batch_dim_key_defaults={"token_logits": 0},
-                seq_dim_key_defaults={"token_logits": 1},
-            )
-        ],
-        plugins=nl.MegatronMixedPrecision(
-            precision="bf16-mixed",
-            params_dtype=torch.bfloat16,
-            # Only use FP8 in this plugin when using full FP8 precision and FP8.
-            # Otherwise use vortex_style_fp8 in the model config.
-            fp8="hybrid" if fp8 and full_fp8 else None,
-            fp8_amax_history_len=16 if fp8 and full_fp8 else 1,
-            fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
-        ),
-    )
+
+    callbacks = [
+        PredictionWriter(
+            output_dir=output_dir,
+            write_interval="epoch",
+            batch_dim_key_defaults={"token_logits": 0},
+            seq_dim_key_defaults={"token_logits": 1},
+        )
+    ]
+
     # The following two config options are really only used for testing, but may also be useful for getting output from
     # specific layers of the model.
     config_modifiers_init = {}
     if hybrid_override_pattern is not None:
         config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
     if num_layers is not None:
         config_modifiers_init["num_layers"] = num_layers
+
+    tokenizer = get_nmt_tokenizer("byte-level")
+
     # Select model config based on model type
     if model_type == "hyena":
         if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
@@ -412,6 +394,20 @@ def predict(
             vortex_style_fp8=fp8 and not full_fp8,
             **config_modifiers_init,
         )
+
+        if lora_checkpoint_path:
+            model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
+            callbacks.append(model_transform)
+        else:
+            model_transform = None
+
+        model = HyenaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+            model_transform=model_transform,
+        )
     else:  # mamba
         if model_size not in MAMBA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Mamba: {model_size}")
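Note the dual registration of the adapter: the same `Evo2LoRA` instance is passed to the predictor as `model_transform` (so the module is wrapped with LoRA adapters when the model is configured) and appended to `callbacks` (so, assuming it follows NeMo's PEFT callback pattern, the adapter weights at `peft_ckpt_path` are restored during trainer setup). A condensed sketch of that assumed contract, reusing `config` and `tokenizer` from the surrounding function; the checkpoint path is hypothetical:

```python
from pathlib import Path

from bionemo.evo2.run.peft import Evo2LoRA

lora_path = Path("/results/evo2_lora/checkpoints/last")  # hypothetical path
transform = Evo2LoRA(peft_ckpt_path=str(lora_path))

callbacks = [transform]  # callback role: restore adapter weights at setup time
model = HyenaPredictor(
    config,                     # from the surrounding predict() scope
    tokenizer=tokenizer,        # from the surrounding predict() scope
    model_transform=transform,  # transform role: wrap the module with adapters
    # other kwargs from the patch (output_log_prob_seqs, log_prob_collapse_option)
    # omitted for brevity
)
```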
@@ -422,6 +418,50 @@ def predict(
             **config_modifiers_init,
         )
 
+        model = MambaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+        )
+
+    # Create PTL trainer.
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=model_parallel_size,
+        strategy=nl.MegatronStrategy(
+            drop_last_batch=False,
+            tensor_model_parallel_size=tensor_parallel_size,
+            pipeline_model_parallel_size=pipeline_model_parallel_size,
+            context_parallel_size=context_parallel_size,
+            pipeline_dtype=torch.bfloat16,
+            ckpt_load_optimizer=False,  # Needs to be False for a normal model checkpoint.
+            ckpt_save_optimizer=False,
+            ckpt_async_save=False,
+            sequence_parallel=tensor_parallel_size > 1 and sequence_parallel,
+            save_ckpt_format=ckpt_format,
+            ckpt_load_strictness="log_all",
+            data_sampler=nl.MegatronDataSampler(
+                micro_batch_size=batch_size,
+                global_batch_size=batch_size,
+                seq_len=8192,
+                output_log=False,  # this is needed for predict step to work
+            ),
+        ),
+        log_every_n_steps=1,
+        limit_val_batches=10,
+        num_sanity_val_steps=0,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(
+            precision="bf16-mixed",
+            params_dtype=torch.bfloat16,
+            # Only use FP8 in this plugin when both the fp8 and full_fp8 flags are set.
+            # Otherwise use vortex_style_fp8 in the model config.
+            fp8="hybrid" if fp8 and full_fp8 else None,
+            fp8_amax_history_len=16 if fp8 and full_fp8 else 1,
+            fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
+        ),
+    )
     trainer.strategy._setup_optimizers = False
 
     nemo_logger = NeMoLogger(log_dir=work_dir)
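The trainer construction itself is unchanged by this patch; it moved below model construction because the `callbacks` list is consumed when `nl.Trainer` is built, so the optional LoRA callback must be appended first. A trivial illustration of the ordering constraint (`ToyTrainer` is invented for illustration):

```python
# ToyTrainer only shows that callbacks are captured at construction time,
# so anything appended afterwards would be invisible to the trainer.
class ToyTrainer:
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)  # snapshot taken here


cbs = ["prediction_writer"]
cbs.append("lora_transform")  # must happen before ToyTrainer is built
trainer = ToyTrainer(callbacks=cbs)
print(trainer.callbacks)  # ['prediction_writer', 'lora_transform']
```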
@@ -437,23 +477,6 @@ def predict(
             load_artifacts=False,
         ),
     )
-    tokenizer = get_nmt_tokenizer("byte-level")
-
-    # Create appropriate model based on type
-    if model_type == "hyena":
-        model = HyenaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
-    else:  # mamba
-        model = MambaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
 
     resume.setup(trainer, model)  # this pulls weights from the starting checkpoint.
 
@@ -488,6 +511,7 @@ def main():
         hybrid_override_pattern=args.hybrid_override_pattern,
         seq_len_interpolation_factor=args.seq_len_interpolation_factor,
         num_layers=args.num_layers,
+        lora_checkpoint_path=args.lora_checkpoint_path,
     )

