Commit c91aef2

Add lora-checkpointing option to evo2 predict
1 parent 9750267 commit c91aef2

3 files changed

Lines changed: 183 additions & 62 deletions

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 85 additions & 61 deletions
@@ -40,6 +40,7 @@
 
 # Add import for Mamba models
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
+from bionemo.evo2.run.peft import Evo2LoRA
 from bionemo.llm.lightning import LightningPassthroughPredictionMixin
 from bionemo.llm.model.biobert.lightning import get_batch_on_this_context_parallel_rank
 from bionemo.llm.utils.callbacks import PredictionWriter
@@ -130,6 +131,13 @@ def parse_args():
         "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big "
         "difference in accuracy.",
     )
+    ap.add_argument(
+        "--lora-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the lora states to restore from.",
+    )
     return ap.parse_args()

@@ -168,6 +176,10 @@ def __init__(
         self.output_log_prob_seqs = output_log_prob_seqs
         self.log_prob_collapse_option = log_prob_collapse_option

+    def configure_model(self, *args, **kwargs) -> None:
+        super().configure_model(*args, **kwargs)
+        self.trainer.strategy._init_model_parallel = True
+
     def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
         """Alias for forward_step, also log the pad mask since sequences may not all have the same length."""
         if len(batch) == 0:
@@ -326,6 +338,7 @@ def predict(
     hybrid_override_pattern: str | None = None,
     num_layers: int | None = None,
     seq_len_interpolation_factor: int | None = None,
+    lora_checkpoint_path: Path | None = None,
 ):
     """Inference workflow for Evo2.

@@ -342,57 +355,26 @@ def predict(
             f"Requested model parallel size {model_parallel_size} is greater than the "
             f"number of available CUDA devices {torch.cuda.device_count()}"
         )
-    # Create PTL trainer.
-    trainer = nl.Trainer(
-        accelerator="gpu",
-        devices=model_parallel_size,
-        strategy=nl.MegatronStrategy(
-            drop_last_batch=False,
-            tensor_model_parallel_size=tensor_parallel_size,
-            pipeline_model_parallel_size=pipeline_model_parallel_size,
-            context_parallel_size=context_parallel_size,
-            pipeline_dtype=torch.bfloat16,
-            ckpt_load_optimizer=False,  # Needs to be false for a normal model checkpoint.
-            ckpt_save_optimizer=False,
-            ckpt_async_save=False,
-            sequence_parallel=tensor_parallel_size > 1 and sequence_parallel,
-            save_ckpt_format=ckpt_format,
-            ckpt_load_strictness="log_all",
-            data_sampler=nl.MegatronDataSampler(
-                micro_batch_size=batch_size,
-                global_batch_size=batch_size,
-                seq_len=8192,
-                output_log=False,  # this is needed for predict step to work
-            ),
-        ),
-        log_every_n_steps=1,
-        limit_val_batches=10,
-        num_sanity_val_steps=0,
-        callbacks=[
-            PredictionWriter(
-                output_dir=output_dir,
-                write_interval="epoch",
-                batch_dim_key_defaults={"token_logits": 0},
-                seq_dim_key_defaults={"token_logits": 1},
-            )
-        ],
-        plugins=nl.MegatronMixedPrecision(
-            precision="bf16-mixed",
-            params_dtype=torch.bfloat16,
-            # Only use FP8 in this plugin when using full FP8 precision and FP8.
-            # Otherwise use vortex_style_fp8 in the model config.
-            fp8="hybrid" if fp8 and full_fp8 else None,
-            fp8_amax_history_len=16 if fp8 and full_fp8 else 1,
-            fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
-        ),
-    )
+
+    callbacks = [
+        PredictionWriter(
+            output_dir=output_dir,
+            write_interval="epoch",
+            batch_dim_key_defaults={"token_logits": 0},
+            seq_dim_key_defaults={"token_logits": 1},
+        )
+    ]
+
     # The following two config options are really only used for testing, but may also be useful for getting output from
     # specific layers of the model.
     config_modifiers_init = {}
     if hybrid_override_pattern is not None:
         config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern
     if num_layers is not None:
         config_modifiers_init["num_layers"] = num_layers
+
+    tokenizer = get_nmt_tokenizer("byte-level")
+
     # Select model config based on model type
     if model_type == "hyena":
         if "-1m" in model_size and "nv" not in model_size and seq_len_interpolation_factor is None:
@@ -412,6 +394,20 @@ def predict(
             vortex_style_fp8=fp8 and not full_fp8,
             **config_modifiers_init,
         )
+
+        if lora_checkpoint_path:
+            model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path))
+            callbacks.append(model_transform)
+        else:
+            model_transform = None
+
+        model = HyenaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+            model_transform=model_transform,
+        )
     else:  # mamba
         if model_size not in MAMBA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Mamba: {model_size}")
@@ -422,6 +418,50 @@ def predict(
             **config_modifiers_init,
         )

+        model = MambaPredictor(
+            config,
+            tokenizer=tokenizer,
+            output_log_prob_seqs=output_log_prob_seqs,
+            log_prob_collapse_option=log_prob_collapse_option,
+        )
+
+    # Create PTL trainer.
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=model_parallel_size,
+        strategy=nl.MegatronStrategy(
+            drop_last_batch=False,
+            tensor_model_parallel_size=tensor_parallel_size,
+            pipeline_model_parallel_size=pipeline_model_parallel_size,
+            context_parallel_size=context_parallel_size,
+            pipeline_dtype=torch.bfloat16,
+            ckpt_load_optimizer=False,  # Needs to be false for a normal model checkpoint.
+            ckpt_save_optimizer=False,
+            ckpt_async_save=False,
+            sequence_parallel=tensor_parallel_size > 1 and sequence_parallel,
+            save_ckpt_format=ckpt_format,
+            ckpt_load_strictness="log_all",
+            data_sampler=nl.MegatronDataSampler(
+                micro_batch_size=batch_size,
+                global_batch_size=batch_size,
+                seq_len=8192,
+                output_log=False,  # this is needed for predict step to work
+            ),
+        ),
+        log_every_n_steps=1,
+        limit_val_batches=10,
+        num_sanity_val_steps=0,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(
+            precision="bf16-mixed",
+            params_dtype=torch.bfloat16,
+            # Only use FP8 in this plugin when using full FP8 precision and FP8.
+            # Otherwise use vortex_style_fp8 in the model config.
+            fp8="hybrid" if fp8 and full_fp8 else None,
+            fp8_amax_history_len=16 if fp8 and full_fp8 else 1,
+            fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent",
+        ),
+    )
     trainer.strategy._setup_optimizers = False

     nemo_logger = NeMoLogger(log_dir=work_dir)
@@ -437,23 +477,6 @@ def predict(
             load_artifacts=False,
         ),
     )
-    tokenizer = get_nmt_tokenizer("byte-level")
-
-    # Create appropriate model based on type
-    if model_type == "hyena":
-        model = HyenaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )
-    else:  # mamba
-        model = MambaPredictor(
-            config,
-            tokenizer=tokenizer,
-            output_log_prob_seqs=output_log_prob_seqs,
-            log_prob_collapse_option=log_prob_collapse_option,
-        )

     resume.setup(trainer, model)  # this pulls weights from the starting checkpoint.

@@ -488,6 +511,7 @@ def main():
         hybrid_override_pattern=args.hybrid_override_pattern,
         seq_len_interpolation_factor=args.seq_len_interpolation_factor,
         num_layers=args.num_layers,
+        lora_checkpoint_path=args.lora_checkpoint_path,
     )

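With these changes, a LoRA adapter produced by a `--lora-finetune` training run can be applied at inference time through the new `--lora-checkpoint-path` flag. Below is a minimal invocation sketch modeled on the command string assembled by `predict_cmd` in the test helper that follows; the checkpoint, FASTA, and output paths are placeholders, and calling `predict_evo2` directly from a shell works just as well as going through the test utility.

from bionemo.testing.subprocess_utils import run_command_in_subprocess

# Placeholder paths: substitute a real base-model checkpoint, a LoRA adapter checkpoint
# produced by finetuning, and a FASTA file.
base_ckpt = "/path/to/evo2_base_checkpoint"
lora_ckpt = "/path/to/lora_finetune/evo2/checkpoints/last_checkpoint"
cmd = (
    f"predict_evo2 --fasta test.fasta --ckpt-dir {base_ckpt} --output-dir results_peft "
    "--model-size 1b_nv --tensor-parallel-size 1 --pipeline-model-parallel-size 1 "
    f"--context-parallel-size 1 --lora-checkpoint-path {lora_ckpt}"
)
run_command_in_subprocess(command=cmd, path=".")

As in the tests, predictions are written to the output directory as predictions__rank_*.pt files, and the adapter load is reported by a "Loading adapters from" message in the log.
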
sub-packages/bionemo-evo2/tests/bionemo/evo2/run/common.py

Lines changed: 11 additions & 0 deletions
@@ -48,3 +48,14 @@ def small_training_finetune_cmd(
         f"{'--create-tflops-callback' if create_tflops_callback else ''}"
     )
     return cmd
+
+
+def predict_cmd(ckpt_dir: str, output_dir: str, fasta_file_path: str, additional_args: str = ""):
+    """Command for predict."""
+    cmd = (
+        f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {ckpt_dir} --output-dir {output_dir} "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 "
+        f"--pipeline-model-parallel-size 1 --context-parallel-size 1 {additional_args}"
+    )
+
+    return cmd

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_lora.py

Lines changed: 87 additions & 1 deletion
@@ -15,11 +15,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
+
 import pytest
+import torch

+from bionemo.core.data.load import load
+from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
 from bionemo.testing.subprocess_utils import run_command_in_subprocess

-from .common import small_training_cmd, small_training_finetune_cmd
+from .common import predict_cmd, small_training_cmd, small_training_finetune_cmd


 @pytest.mark.timeout(512)  # Optional: fail if the test takes too long.
@@ -130,3 +135,84 @@ def test_train_evo2_finetune_runs_lora(tmp_path, with_peft: bool):
     assert log_dir_ft.exists(), "Logs folder should exist."
     # Check if checkpoints dir exists
     assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist."
+
+
+@pytest.mark.timeout(512)
+@pytest.mark.slow
+def test_different_results_with_peft(tmp_path):
+    try:
+        base_model_checkpoint_path = load("evo2/1b-8k:1.0")
+    except ValueError as e:
+        if e.args[0].endswith("does not have an NGC URL."):
+            raise ValueError(
+                "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, "
+                "one or more files are missing from ngc."
+            )
+        else:
+            raise e
+
+    num_steps = 2
+
+    result_dir = tmp_path / "lora_finetune"
+
+    # Note: The command assumes that `train_evo2` is in your PATH.
+    command_finetune = small_training_finetune_cmd(
+        result_dir,
+        max_steps=num_steps,
+        val_check=num_steps,
+        prev_ckpt=base_model_checkpoint_path,
+        create_tflops_callback=False,
+        additional_args="--lora-finetune",
+    )
+    stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune
+    assert "Loading adapters from" not in stdout_finetune
+
+    # Check if checkpoints dir exists
+    checkpoints_dir = result_dir / "evo2" / "checkpoints"
+    assert checkpoints_dir.exists(), "Checkpoints folder does not exist."
+
+    # Create a sample FASTA file to run predictions
+    fasta_file_path = tmp_path / "test.fasta"
+    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)
+
+    result_dir_original = tmp_path / "results_original"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_original, fasta_file_path)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+
+    # Assert that the output directory was created.
+    pred_files_original = glob.glob(str(result_dir_original / "predictions__rank_*.pt"))
+    assert len(pred_files_original) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_original)}"
+
+    # Find the checkpoint dir generated by finetuning
+    expected_checkpoint_suffix = f"{num_steps}.0-last"
+    # Check if any subfolder ends with the expected suffix
+    matching_subfolders = [
+        p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name)
+    ]
+
+    assert matching_subfolders, (
+        f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}."
+    )
+
+    result_dir_peft = tmp_path / "results_peft"
+    additional_args = f"--lora-checkpoint-path {matching_subfolders[0]}"
+    cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_peft, fasta_file_path, additional_args)
+    stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path))
+    assert "Restoring model weights from RestoreConfig(path='" in stdout_predict
+    assert "Loading adapters from" in stdout_predict
+
+    pred_files_peft = glob.glob(str(result_dir_peft / "predictions__rank_*.pt"))
+    assert len(pred_files_peft) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_peft)}"
+
+    results_original = torch.load(f"{result_dir_original}/predictions__rank_0.pt")
+    results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0.pt")
+
+    seq_idx_original = results_original["seq_idx"]
+    seq_idx_peft = results_peft["seq_idx"]
+    assert torch.equal(seq_idx_original, seq_idx_peft), f"Tensors differ: {seq_idx_original} vs {seq_idx_peft}"
+
+    logits_original = results_original["token_logits"]
+    logits_peft = results_peft["token_logits"]
+    assert (logits_original != logits_peft).any()
+    assert logits_original.shape == logits_peft.shape, f"Shapes don't match: {logits_original.shape} vs {logits_peft.shape}"

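To exercise the new end-to-end check locally, an invocation along these lines should work (a sketch, not the project's documented workflow): the BIONEMO_DATA_SOURCE=pbss environment variable comes from the error message raised in the test when the base checkpoint is not available from NGC, and how slow-marked tests are enabled depends on the repository's pytest configuration.

import subprocess

# Hypothetical invocation; the test file path and test name come from this commit.
subprocess.run(
    "BIONEMO_DATA_SOURCE=pbss py.test -v -k test_different_results_with_peft "
    "sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_lora.py",
    shell=True,
    check=True,
)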