ESM2 NVFP4 and MXFP8 support and documentation update. #1484
jomitchellnv wants to merge 10 commits into main from
Conversation
Important: Review skipped — auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configuration. Configuration used: Path: .coderabbit.yaml · Review profile: CHILL · Plan: Pro
📝 Walkthrough

Introduces per-layer FP8/FP4 quantization support for TransformerEngine-accelerated ESM2 models via new TE-optimized model classes, layer-wise quantization resolution utilities, per-layer autocast contexts, NVTX instrumentation, and updates to five training strategies (DDP, DDP+CP, FSDP2, FSDP2+CP, mFSDP).
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TrainScript as Train Script<br/>(DDP/FSDP2/mFSDP)
    participant QuantUtil as Quantization<br/>Utils
    participant Modeling as NVEsmConfig &<br/>NVEsmForMaskedLM
    participant Encoder as NVEsmEncoder
    participant TEAutocast as TE Autocast<br/>Context
    participant Forward as Forward<br/>Pass
    User->>TrainScript: Launch training with fp8/fp4 config
    TrainScript->>QuantUtil: resolve_quantization_layers(num_layers, fp8_enabled, fp4_enabled, fp8_layers, fp4_layers)
    QuantUtil-->>TrainScript: QuantizationLayers{fp8_0idx, fp4_0idx, ...}
    TrainScript->>Modeling: NVEsmConfig.from_pretrained(model_tag)
    Modeling-->>TrainScript: config
    TrainScript->>Modeling: NVEsmForMaskedLM(config)
    Modeling->>Encoder: create NVEsmEncoder instance
    Encoder-->>Modeling: encoder
    Modeling-->>TrainScript: model
    TrainScript->>Encoder: initialize_quantization(fp8_layers, fp4_layers, fp8_recipe, fp4_recipe)
    Encoder->>Encoder: build _layer_precision map {layer_idx: 'fp8'|'fp4'|None}
    Encoder->>Encoder: store _fp8_recipe, _fp4_recipe
    Encoder-->>TrainScript: initialized
    TrainScript->>Forward: forward(input_ids, attention_mask)
    Forward->>Encoder: for each layer_idx in layers
    Encoder->>Encoder: get_layer_autocast(layer_idx)
    alt Layer is FP8
        Encoder-->>TEAutocast: nullcontext
    else Layer is FP4
        Encoder-->>TEAutocast: autocast(enabled=True, recipe=fp4_recipe)
    else Layer is BF16/None
        Encoder-->>TEAutocast: autocast(enabled=False)
    end
    TEAutocast->>Forward: execute layer within context
    Forward->>Forward: nvtx.range_push("encoder_layer_N")
    Forward->>Forward: hidden_states = layer(hidden_states, ...)
    Forward->>Forward: nvtx.range_pop()
    Forward-->>Encoder: hidden_states
    Encoder-->>Forward: processed output
    Forward-->>TrainScript: logits, loss
    TrainScript->>QuantUtil: initialize_quant_stats_logging(if enabled)
    QuantUtil->>QuantUtil: update_quant_stats_config with resolved layers
    QuantUtil-->>TrainScript: debug API initialized
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 inconclusive)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches: 🧪 Generate unit tests (beta)
Force-pushed from a67637c to 091299c (Compare)
- includes capability to log out stats for MXFP8 and NVFP4 at the same time
- Enables layer-wise precision setting

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Force-pushed from 091299c to 734a25a (Compare)
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1428.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py (outdated, resolved)
@coderabbitai review

✅ Actions performed: Review triggered.
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Actionable comments posted: 11
🧹 Nitpick comments (11)
bionemo-recipes/recipes/esm2_native_te/.dockerignore (1)
33-34: Consider anchoring the scratch path ignore pattern.

`j/` on line 34 matches any directory named `j` at any depth. If this is intended as a repo-root local scratch dir, prefer `/j/` to avoid accidental exclusions in nested paths.

Proposed tweak:

```diff
- j/
+ /j/
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/.dockerignore` around lines 33-34: the ignore pattern "j/" matches any directory named "j" at any depth; change it to "/j/" (preserve the trailing slash) to anchor it to the repository root so only the top-level scratch dir is ignored and nested directories named "j" are not accidentally excluded.

bionemo-recipes/recipes/esm2_native_te/tests/test_train.py (1)
146-158: Config key rename looks correct; consider updating test naming for consistency.

The config keys are correctly updated from `fp8_stats_config` to `quant_stats_config`. However, the test function name (`test_sanity_ddp_fp8_stats_logging`), docstring ("FP8 stats logging"), and variable names (`fp8_log_dir`) still reference FP8 specifically.

Since `quant_stats_config` now supports both FP8 and FP4, consider renaming for clarity:

- `test_sanity_ddp_quant_stats_logging`
- `quant_log_dir` variable

This is optional since the test still validates FP8 stats specifically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/tests/test_train.py` around lines 146-158: rename test identifiers and docstring to reflect the config rename from fp8_stats_config to quant_stats_config — update the test function name test_sanity_ddp_fp8_stats_logging to test_sanity_ddp_quant_stats_logging, change the fp8_log_dir variable to quant_log_dir (or similar), and update the docstring "FP8 stats logging" to "quant stats logging" while keeping all uses of quant_stats_config and the existing assertions intact so the test still validates FP8 behavior under the new config name.

bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py (1)
60-67: Consider making `FP4_RECIPES` a tuple for consistency with `FP8_RECIPES`.

`FP8_RECIPES` is defined as a tuple of classes, but `FP4_RECIPES` is a single class. While this works with `isinstance()`, making it a tuple would be more consistent and future-proof if additional FP4 recipes are added.

♻️ Suggested change for consistency

```diff
-FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling
+FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py` around lines 60-67: FP4_RECIPES is assigned a single class (transformer_engine.common.recipe.NVFP4BlockScaling) while FP8_RECIPES is a tuple; make FP4_RECIPES a tuple for consistency and to allow adding more entries later — update the assignment to FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,) so code that expects a sequence of recipe classes (similar to FP8_RECIPES) will work uniformly.

bionemo-recipes/models/esm2/modeling_esm_te.py (1)
60-67: Same consistency suggestion: consider making `FP4_RECIPES` a tuple.

This matches the same pattern seen in the checkpoint `esm_nv.py` files. For consistency across the codebase, consider using a tuple.

♻️ Suggested change

```diff
-FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling
+FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/models/esm2/modeling_esm_te.py` around lines 60-67: FP4_RECIPES is defined as a single value whereas FP8_RECIPES is a tuple; make FP4_RECIPES a tuple for consistency by wrapping transformer_engine.common.recipe.NVFP4BlockScaling in a one-element tuple (use a trailing comma) so the constant mirrors FP8_RECIPES and any tuple-based handling of recipes works uniformly.

bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml (1)
9-12: Minor formatting inconsistency: missing blank line before WandB section.

Other config files (L1_3B.yaml, L1_15B_perf_test.yaml) have a blank line between the `dataset` section and the `wandb_init_args` comment. Consider adding one for consistency.

♻️ Suggested formatting fix

```diff
 dataset:
   micro_batch_size: 4
   tokenizer_revision: null
+
 # WandB config
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml` around lines 9-12: add a blank line between the dataset block and the WandB section to match other config files — locate the dataset section (keys dataset, micro_batch_size, tokenizer_revision) and insert a single empty line before the existing "# WandB config" comment so formatting is consistent with L1_3B.yaml and L1_15B_perf_test.yaml.

bionemo-recipes/recipes/esm2_native_te/dataset.py (1)
60-62: Simplify redundant conditional for `revision` parameter.

The expression `revision=tokenizer_revision if tokenizer_revision else None` is equivalent to just `revision=tokenizer_revision`, since both `None` and the empty string are falsy and the default is `None` anyway.

♻️ Simplify revision parameter

```diff
 tokenizer = AutoTokenizer.from_pretrained(
-    tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None
+    tokenizer_name, revision=tokenizer_revision
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/dataset.py` around lines 60-62: the call to AutoTokenizer.from_pretrained uses a redundant conditional for the revision argument; replace revision=tokenizer_revision if tokenizer_revision else None with simply revision=tokenizer_revision so the revision parameter directly uses the tokenizer_revision variable.

bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1)
104-107: Consider adding validation comments for `use_fp32_master_weights: null`.

Setting `use_fp32_master_weights` to `null` rather than a boolean may be intentional (e.g., to require explicit user configuration), but the training script at line 125 uses `args.use_fp32_master_weights` in a conditional. Verify that `null` is handled correctly (treated as falsy).

💡 Consider adding a comment explaining the null default

```diff
-# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
+# Note: The layers come in 1-indexed and we convert them to 0-indexed at runtime.
 fp8_layers: null
 fp4_layers: null
+# Set explicitly to true/false. When null, defaults to false behavior.
 use_fp32_master_weights: null
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml` around lines 104-107: the default for use_fp32_master_weights is null, which may be unexpected when later checked as args.use_fp32_master_weights; document that null is intentional and treated as falsy (or change the default to false), add a short comment next to use_fp32_master_weights explaining that the training script expects a boolean-like value and that null is treated as false, and ensure consumers handle null explicitly if different behavior is wanted.

bionemo-recipes/recipes/esm2_native_te/quantization.py (2)
61-62: Specify explicit encoding when opening files.

Opening files without an explicit encoding can lead to platform-dependent behavior. Specify `encoding="utf-8"` for consistent behavior.

♻️ Add explicit encoding

```diff
-    with open(config_file, "r") as f:
+    with open(config_file, "r", encoding="utf-8") as f:
         config = yaml.safe_load(f)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/quantization.py` around lines 61-62: the open() call reading the YAML should specify an explicit encoding to avoid platform-dependent behavior; update the with open(config_file, "r") as f: statement to include encoding="utf-8" so the YAML is read consistently across platforms.
136-157: Consider using `@dataclass` for `QuantizationLayers`.

The class is essentially a data container with no methods. Using `@dataclass` would reduce boilerplate and provide `__repr__`, `__eq__`, etc. automatically.

♻️ Convert to dataclass

```diff
+from dataclasses import dataclass
+
+@dataclass
 class QuantizationLayers:
     """Resolved layer-wise quantization assignments.

     Attributes:
         fp8_layers_0indexed: 0-indexed FP8 layer numbers (for model internals), or None.
         fp4_layers_0indexed: 0-indexed FP4 layer numbers (for model internals), or None.
         fp8_layers_1indexed: 1-indexed FP8 layer numbers (for user-facing logs / quant stats), or None.
         fp4_layers_1indexed: 1-indexed FP4 layer numbers (for user-facing logs / quant stats), or None.
     """

-    def __init__(
-        self,
-        fp8_layers_0indexed: list[int] | None,
-        fp4_layers_0indexed: list[int] | None,
-        fp8_layers_1indexed: list[int] | None,
-        fp4_layers_1indexed: list[int] | None,
-    ):
-        """Initialize QuantizationLayers with the resolved layer assignments."""
-        self.fp8_layers_0indexed = fp8_layers_0indexed
-        self.fp4_layers_0indexed = fp4_layers_0indexed
-        self.fp8_layers_1indexed = fp8_layers_1indexed
-        self.fp4_layers_1indexed = fp4_layers_1indexed
+    fp8_layers_0indexed: list[int] | None
+    fp4_layers_0indexed: list[int] | None
+    fp8_layers_1indexed: list[int] | None
+    fp4_layers_1indexed: list[int] | None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/quantization.py` around lines 136-157: replace the manual boilerplate class QuantizationLayers with a dataclass — add "from dataclasses import dataclass", annotate the class with `@dataclass`, convert the four constructor args (fp8_layers_0indexed, fp4_layers_0indexed, fp8_layers_1indexed, fp4_layers_1indexed) to dataclass fields typed list[int] | None, and remove the explicit __init__; keep the existing class docstring and attribute names so __repr__ and __eq__ are provided automatically.

bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py (2)
604-609: Hardcoded `mask_ratio_train = 0.15 * 0.8` is documented but not configurable.

The comment explains this matches ESM training, which is appropriate. However, consider extracting it to a class constant for clarity and potential future configurability.

💡 Extract to class constant

```diff
 class NVEsmEmbeddings(nn.Module):
     """Modified version of EsmEmbeddings to support THD inputs."""
+
+    # Hardcoded mask ratio used in all ESM model training runs (0.15 * 0.8)
+    _MASK_RATIO_TRAIN = 0.12

     def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask):
         ...
-        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+        mask_ratio_train = self._MASK_RATIO_TRAIN
```

Also applies to: 625-637

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 604-609: replace the hardcoded mask ratio literal with a named class constant and use it wherever needed — define a class-level constant (e.g., MASK_RATIO_TRAIN = 0.15 * 0.8) on the model class, then update the local variable usage in the method that currently sets mask_ratio_train and in the similar block at lines 625-637 to read from that constant (self.MASK_RATIO_TRAIN or ClassName.MASK_RATIO_TRAIN) so scale_factor and embedding scaling use the named constant for clarity and potential configurability.
228-234: Address TODO: create a unit test for per-layer FP context selection.

The TODO at line 234 notes the need to verify and test this logic. The FP context selection logic (BF16 vs FP8 vs FP4) is critical for correctness.

Would you like me to generate a unit test skeleton for verifying the per-layer FP context selection logic?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 228-234: extract the per-layer FP context selection into a small function (e.g., get_fp_context(fp_recipe)) that returns nullcontext when fp_recipe is in FP8_RECIPES, returns transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) when fp_recipe is in FP4_RECIPES, and returns transformer_engine.pytorch.autocast(enabled=False) otherwise; then add a unit test file that parametrizes several fp_recipe values (members of FP8_RECIPES, FP4_RECIPES, and a default/None case) and asserts that get_fp_context returns the expected context object type/behavior (e.g., is nullcontext for FP8 cases and an autocast context for FP4/default) to validate the per-layer FP context selection logic.
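A skeleton for that unit test might look like the following. It is a self-contained sketch: the recipe classes and the recording `fake_autocast` are stand-ins for the TransformerEngine API, and `get_fp_context` is the helper the comment proposes extracting, with the autocast callable injected so the test needs no monkeypatching framework:

```python
from contextlib import nullcontext


class DelayedScaling: ...      # stand-in for an FP8 recipe class
class NVFP4BlockScaling: ...   # stand-in for the FP4 recipe class

FP8_RECIPES = (DelayedScaling,)
FP4_RECIPES = (NVFP4BlockScaling,)

calls = []  # records (enabled, recipe) for every autocast invocation


def fake_autocast(enabled, recipe=None):
    """Recording stub replacing transformer_engine.pytorch.autocast."""
    calls.append((enabled, recipe))
    return nullcontext()


def get_fp_context(fp_recipe, autocast=fake_autocast):
    """Proposed helper: map a layer's recipe to its precision context."""
    if isinstance(fp_recipe, FP8_RECIPES):
        return nullcontext()  # FP8 is handled by the outer fp8 autocast
    if isinstance(fp_recipe, FP4_RECIPES):
        return autocast(enabled=True, recipe=fp_recipe)
    return autocast(enabled=False)


def test_fp8_layer_uses_nullcontext():
    calls.clear()
    assert isinstance(get_fp_context(DelayedScaling()), nullcontext)
    assert calls == []  # autocast never touched


def test_fp4_layer_enables_autocast_with_recipe():
    calls.clear()
    recipe = NVFP4BlockScaling()
    get_fp_context(recipe)
    assert calls == [(True, recipe)]


def test_default_layer_disables_autocast():
    calls.clear()
    get_fp_context(None)
    assert calls == [(False, None)]
```

The three functions can be collected by pytest or simply called directly; against the real module, the stubs would be replaced by monkeypatching `transformer_engine.pytorch.autocast`.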
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@bionemo-recipes/models/esm2/modeling_esm_te.py`:
- Around line 214-234: Add unit tests that exercise the per-layer FP context
selection in the loop that iterates self.layers and reads
self.layer_number_quantized_recipe_map: create three scenarios where the mapped
fp_recipe for a given layer is (a) an instance of FP8_RECIPES, (b) an instance
of FP4_RECIPES, and (c) None/other; for each case assert that fp_context becomes
nullcontext() for FP8, transformer_engine.pytorch.autocast(enabled=True,
recipe=fp_recipe) for FP4, and
transformer_engine.pytorch.autocast(enabled=False) for the default/BF16 path.
Use lightweight mocks/monkeypatching to replace
transformer_engine.pytorch.autocast with a stub that records its args/returns so
you can assert enabled and recipe values, and construct minimal model instances
(or unit-test the loop function directly) that set
layer_number_quantized_recipe_map and self.layers to trigger each branch; also
include a test that output_hidden_states True appends hidden_states to
all_hidden_states.
In `@bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py`:
- Line 234: The TODO on the forward path flags unverified precision-routing
behavior; replace it by adding focused unit tests that exercise the precision
routing logic for FP8, FP4 and BF16 (e.g., create tests that run the module's
forward pass with tensors/configs that should route through each precision
branch and assert the correct branch was used and outputs match expected
numerical/shape properties), and then remove the TODO comment. Target the
functions and code paths that implement precision routing (the forward path /
precision-routing conditional blocks referenced by the TODO in esm_nv.py) and
add parametrized tests that cover boundary cases and mixed precision
combinations to ensure deterministic routing.
In `@bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml`:
- Around line 21-23: Update the comment describing layer ranges so it matches
the actual regex stored in layer_name_regex_pattern; the current comment
mentions layers 0-4 / 1-5 but the regex
'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
targets layers 6–10, so change the comment to state layers 6-10 (or 7-11 if
using 1-indexed wording) to prevent confusion and ensure the comment and the
pattern are consistent.
In `@bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml`:
- Around line 19-23: Replace the invalid tensor type "fprop" in the
LogTensorStats block with the valid TransformerEngine tensor type "activation"
(modify the tensors: [dgrad, wgrad, fprop] entry to tensors: [dgrad, wgrad,
activation]); keep stats and freq as-is, or optionally refactor LogTensorStats
to use tensors_struct for per-tensor configs if you need different stats per
tensor type.
In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml`:
- Around line 55-61: Either remove the unused fp4_model_init_kwargs entry from
the fp4_config block in defaults.yaml, or implement FP4 model initialization in
train_fsdp2.py to mirror the FP8 pattern: when args.fp4_config.enabled and
args.fp4_config.fp4_model_init_kwargs.enabled are true, call
transformer_engine.pytorch.quantized_model_init(...) (or the appropriate FP4
init API) with args.fp4_config.fp4_model_init_kwargs before model training;
update any related code paths that reference
fp4_config.fp4_recipe/fp4_format/fp4_recipe_kwargs to ensure consistent
behavior.
In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`:
- Around line 130-133: The assertion message for the padded_vocab_size check is
too long; update the assert in the block that checks self.padded_vocab_size and
self.vocab_size to keep lines <=119 characters by moving the long f-string into
a shorter expression (e.g., build the message in a local variable like msg =
(f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to
" f"vocab_size ({self.vocab_size})") or use implicit string concatenation/split
across lines) and then call assert self.padded_vocab_size >= self.vocab_size,
msg; keep references to self.padded_vocab_size and self.vocab_size so the
semantic check and error content remain unchanged.
In `@bionemo-recipes/recipes/esm2_native_te/quantization.py`:
- Around line 86-94: The temp YAML file created with "temp_file =
tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)" is left on
disk because you return temp_file.name; change this by either (A) making the
function return the config contents or a context-managed path (use
tempfile.NamedTemporaryFile with delete=True or write via
tempfile.TemporaryDirectory and yield the path) so the file is removed
automatically, or (B) accept a caller-provided "quant_log_dir" or "cleanup" flag
and write the file into that deterministic log directory (or delete the temp
file before returning when appropriate). Update the function's docstring and any
callers to reflect whether the caller is responsible for cleanup, and adjust
logging (logger.info uses temp_file.name) to log the deterministic path or note
that the file is ephemeral. Ensure you modify the code around the
NamedTemporaryFile usage and the return value accordingly.
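Option (A) from the prompt above can be sketched with a context-managed path. The helper name `quant_config_path` is hypothetical, and the real code would `yaml.safe_dump` the config; here a plain write keeps the sketch stdlib-only:

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def quant_config_path(config: dict):
    """Yield a path to a temporary quant-stats config file, removed on exit."""
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "quant_stats_config.yaml")
        with open(path, "w", encoding="utf-8") as f:
            f.write(repr(config))  # real code: yaml.safe_dump(config, f)
        yield path
    # TemporaryDirectory removes the directory (and the file) here,
    # so nothing is left on disk and no caller-side cleanup is needed.
```

Callers use `with quant_config_path(cfg) as path:` and the file disappears when the block exits, unlike `NamedTemporaryFile(delete=False)`, which leaks the file unless someone remembers to delete it.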
In `@bionemo-recipes/recipes/esm2_native_te/README.md`:
- Line 106: Replace the vague anchor text "here" with a descriptive label that
explains the target, e.g., change "and
[here](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html)"
to something like "and the NVIDIA Transformer Engine FP8 primer" (update the
anchor text around the existing URL), so the README sentence reads clearly and
satisfies markdown linting.
- Around line 80-83: Fix the wording in the low-precision benchmark paragraph:
change "low precision" to "low-precision" throughout, add missing period after
"etc" ("etc."), correct the typo "outweights" to "outweighs", and rephrase the
sentence that reads "the cost to quantize activations from high precision to low
precision outweights the benefits of performing matrix multiplication with low
precision" to a clearer form (e.g., "the cost to quantize activations from
high-precision to low-precision outweighs the benefits of using low-precision
matrix multiplication") so the paragraph reads smoothly and consistently.
In `@bionemo-recipes/recipes/esm2_native_te/train_ddp.py`:
- Around line 62-64: Replace the silent warning when args.fp4_config.enabled is
true with a fail-fast guard: in the block that currently calls logger.warning
(the check of args.fp4_config.enabled), either raise a clear RuntimeError to
stop execution or require an explicit override flag (e.g.,
args.allow_experimental_nvfp4_ddp) before proceeding; update the check to
validate that if fp4 is enabled and DDP is in use (the current DDP launch path),
then if not args.allow_experimental_nvfp4_ddp raise an error with a message
explaining NVFP4+DDP is unsupported and how to opt into experimental mode.
Ensure you reference and change the condition around args.fp4_config.enabled and
the logger.warning call so the run is blocked unless the explicit override is
provided.
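The fail-fast guard described above might look like this. A sketch: `allow_experimental_nvfp4_ddp` is the override flag name the prompt itself proposes, not an existing config key:

```python
def check_nvfp4_ddp_support(fp4_enabled: bool, allow_experimental_nvfp4_ddp: bool = False) -> None:
    """Raise instead of warn when NVFP4 is requested on the DDP launch path."""
    if fp4_enabled and not allow_experimental_nvfp4_ddp:
        raise RuntimeError(
            "NVFP4 quantization is not supported with DDP. "
            "Set allow_experimental_nvfp4_ddp=true to opt into experimental mode."
        )
```

Called once at startup in train_ddp.py, before model construction, this stops misconfigured runs immediately instead of letting them fail later (or silently train with an unsupported setup).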
In `@bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py`:
- Around line 214-234: Add unit tests validating the per-layer quantization
context selection in the loop that uses layer_number_quantized_recipe_map:
ensure that when layer_number_quantized_recipe_map returns an FP8 recipe the
code sets fp_context to nullcontext(), when it returns an FP4 recipe it uses
transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe), and for
None/other recipes it uses transformer_engine.pytorch.autocast(enabled=False);
add tests covering output_hidden_states path as well (all_hidden_states
behavior) and include edge cases (missing map, unexpected recipe types), and
also correct the typo in the TODO comment from "funciton" to "function".
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (5)

- `docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg` is excluded by `!**/*.svg`
- `docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg` is excluded by `!**/*.svg`
- `docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png` is excluded by `!**/*.png`
- `docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png` is excluded by `!**/*.png`
- `docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png` is excluded by `!**/*.png`
📒 Files selected for processing (24)

- `.vscode/settings.json`
- `bionemo-recipes/models/esm2/modeling_esm_te.py`
- `bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py`
- `bionemo-recipes/recipes/esm2_native_te/.dockerignore`
- `bionemo-recipes/recipes/esm2_native_te/README.md`
- `bionemo-recipes/recipes/esm2_native_te/dataset.py`
- `bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py`
- `bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml`
- `bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml`
- `bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml`
- `bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml`
- `bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml`
- `bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml`
- `bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`
- `bionemo-recipes/recipes/esm2_native_te/perf_logger.py`
- `bionemo-recipes/recipes/esm2_native_te/quantization.py`
- `bionemo-recipes/recipes/esm2_native_te/requirements.txt`
- `bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py`
- `bionemo-recipes/recipes/esm2_native_te/tests/test_train.py`
- `bionemo-recipes/recipes/esm2_native_te/train_ddp.py`
- `bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py`
- `bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py`
- `bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py`
- `bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py`
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
- Removes nsys
- Adds unit tests for layer-wise quantization

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@coderabbitai resolve

✅ Actions performed: Comments resolved.

@coderabbitai resolve

✅ Actions performed: Comments resolved.

/ok to test 8c53527

@coderabbitai review

✅ Actions performed: Review triggered.
Actionable comments posted: 3
🧹 Nitpick comments (2)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)
69-77: Gate FP8 recipe construction behind `fp8_config.enabled`.

`fp8_recipe` is instantiated even when FP8 is disabled. This can make BF16/FP4-only runs fail on irrelevant FP8 recipe config errors.

♻️ Suggested change

```diff
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+    fp8_recipe = None
+    if args.fp8_config.enabled:
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py` around lines 69-77, fp8_recipe is always instantiated even when FP8 is disabled; change the logic to only construct fp8_recipe when args.fp8_config.enabled is true (mirror how fp4_recipe is handled). Specifically, initialize fp8_recipe = None and wrap the hydra.utils.get_class(...) construction in an if args.fp8_config.enabled: block so that fp8_recipe is only created when enabled, referencing the existing symbols fp8_recipe and args.fp8_config.enabled to locate and modify the code.

bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py (1)
87-90: Gate FP8 recipe creation on `fp8_config.enabled` here as well.

This avoids requiring FP8 recipe validity for runs that only use BF16/FP4.

♻️ Suggested change

```diff
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+    fp8_recipe = None
+    if args.fp8_config.enabled:
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py` around lines 87 - 90, The FP8 recipe is created unconditionally which forces FP8 recipe validation even when FP8 is disabled; wrap the hydra.utils.get_class(...) call that constructs fp8_recipe in a conditional that checks args.fp8_config.enabled and only builds fp8_recipe (using Format[args.fp8_config.fp8_format] and args.fp8_config.fp8_recipe_kwargs) when enabled, otherwise set fp8_recipe to None or skip creation so BF16/FP4-only runs don't require FP8 recipe validity.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@bionemo-recipes/models/esm2/tests/test_layer_quantization.py`:
- Around line 16-242: Add a golden-value parity test that runs the
TransformerEngine ESM encoder (NVEsmEncoder) and the reference ESM encoder on
the same deterministic input/seed and asserts numerical parity (e.g., final
token logits or pooled embeddings) within a small tolerance; create a new test
function (e.g., test_te_vs_reference_golden_value_parity) in this module that
uses torch.manual_seed, a small random input tensor on CUDA, constructs an
NVEsmEncoder via NVEsmConfig and constructs the reference ESM model (import the
reference model used in the repo), runs both forward passes with identical
settings, and asserts outputs are close with pytest.approx or torch.allclose;
ensure the test uses the existing encoder fixture pattern/device and keeps the
comparison deterministic and tolerant to tiny numeric differences.
In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`:
- Around line 213-222: In initialize_quantization, validate fp8_layers and
fp4_layers before applying them: convert to sets (fp8_layers_set,
fp4_layers_set), check every layer id is an int within 0..len(self.layers)-1 and
raise a ValueError if any id is out of range, check for overlap by computing
intersection = fp8_layers_set & fp4_layers_set and raise a ValueError if
non-empty, and optionally ensure inputs are unique/convertible to int; only
after these checks populate self._layer_precision using range(len(self.layers)).
In `@bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py`:
- Around line 78-81: The FP8 recipe is being constructed unconditionally
(fp8_recipe via hydra.utils.get_class and Format[args.fp8_config.fp8_format])
even when FP8 is disabled; wrap that construction in an if
args.fp8_config.enabled: guard (same pattern used for fp4_recipe) so fp8_recipe
is only created when args.fp8_config.enabled is true, and ensure any references
to fp8_recipe are only used within that guarded block or handled when disabled.
---
Nitpick comments:
In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py`:
- Around line 87-90: The FP8 recipe is created unconditionally which forces FP8
recipe validation even when FP8 is disabled; wrap the hydra.utils.get_class(...)
call that constructs fp8_recipe in a conditional that checks
args.fp8_config.enabled and only builds fp8_recipe (using
Format[args.fp8_config.fp8_format] and args.fp8_config.fp8_recipe_kwargs) when
enabled, otherwise set fp8_recipe to None or skip creation so BF16/FP4-only runs
don't require FP8 recipe validity.
In `@bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py`:
- Around line 69-77: fp8_recipe is always instantiated even when FP8 is
disabled; change the logic to only construct fp8_recipe when
args.fp8_config.enabled is true (mirror how fp4_recipe is handled).
Specifically, initialize fp8_recipe = None and wrap the
hydra.utils.get_class(...) construction in an if args.fp8_config.enabled: block
so that fp8_recipe is only created when enabled, referencing the existing
symbols fp8_recipe and args.fp8_config.enabled to locate and modify the code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 18a61b77-3157-42be-b9b8-05feda4cad9e
📒 Files selected for processing (14)
- `bionemo-recipes/models/esm2/modeling_esm_te.py`
- `bionemo-recipes/models/esm2/tests/test_layer_quantization.py`
- `bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py`
- `bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py`
- `bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml`
- `bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml`
- `bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`
- `bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py`
- `bionemo-recipes/recipes/esm2_native_te/train_ddp.py`
- `bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py`
- `bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py`
- `bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py`
- `bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py`
- `bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py`
🚧 Files skipped from review as they are similar to previous changes (1)
- bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
| """Unit tests for NVEsmEncoder.initialize_quantization and get_layer_autocast.""" | ||
|
|
||
| from contextlib import nullcontext | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
| import transformer_engine.common.recipe | ||
| import transformer_engine.pytorch | ||
|
|
||
| from modeling_esm_te import NVEsmConfig, NVEsmEncoder | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def encoder(): | ||
| """Create a small NVEsmEncoder on CUDA for testing.""" | ||
| config = NVEsmConfig( | ||
| hidden_size=320, | ||
| intermediate_size=1280, | ||
| num_hidden_layers=6, | ||
| num_attention_heads=20, | ||
| max_position_embeddings=1026, | ||
| ) | ||
| return NVEsmEncoder(config) | ||
|
|
||
|
|
||
| class TestInitializeQuantization: | ||
| """Tests for NVEsmEncoder.initialize_quantization.""" | ||
|
|
||
| def test_all_fp8(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 1, 2, 3, 4, 5], | ||
| fp4_layers=None, | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=None, | ||
| ) | ||
| assert encoder._fp8_recipe is fp8_recipe | ||
| assert encoder._fp4_recipe is None | ||
| assert all(encoder._layer_precision[i] == "fp8" for i in range(6)) | ||
|
|
||
| def test_all_fp4(self, encoder): | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=None, | ||
| fp4_layers=[0, 1, 2, 3, 4, 5], | ||
| fp8_recipe=None, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| assert encoder._fp8_recipe is None | ||
| assert encoder._fp4_recipe is fp4_recipe | ||
| assert all(encoder._layer_precision[i] == "fp4" for i in range(6)) | ||
|
|
||
| def test_all_bf16(self, encoder): | ||
| encoder.initialize_quantization( | ||
| fp8_layers=None, | ||
| fp4_layers=None, | ||
| fp8_recipe=None, | ||
| fp4_recipe=None, | ||
| ) | ||
| assert all(encoder._layer_precision[i] is None for i in range(6)) | ||
|
|
||
| def test_mixed_fp8_fp4(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 1, 2], | ||
| fp4_layers=[3, 4, 5], | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| for i in range(3): | ||
| assert encoder._layer_precision[i] == "fp8" | ||
| for i in range(3, 6): | ||
| assert encoder._layer_precision[i] == "fp4" | ||
|
|
||
| def test_mixed_fp8_bf16(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 2, 4], | ||
| fp4_layers=None, | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=None, | ||
| ) | ||
| assert encoder._layer_precision[0] == "fp8" | ||
| assert encoder._layer_precision[1] is None | ||
| assert encoder._layer_precision[2] == "fp8" | ||
| assert encoder._layer_precision[3] is None | ||
| assert encoder._layer_precision[4] == "fp8" | ||
| assert encoder._layer_precision[5] is None | ||
|
|
||
| def test_mixed_all_three(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 1], | ||
| fp4_layers=[4, 5], | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| assert encoder._layer_precision[0] == "fp8" | ||
| assert encoder._layer_precision[1] == "fp8" | ||
| assert encoder._layer_precision[2] is None # BF16 | ||
| assert encoder._layer_precision[3] is None # BF16 | ||
| assert encoder._layer_precision[4] == "fp4" | ||
| assert encoder._layer_precision[5] == "fp4" | ||
|
|
||
| def test_empty_lists_treated_as_none(self, encoder): | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[], | ||
| fp4_layers=[], | ||
| fp8_recipe=None, | ||
| fp4_recipe=None, | ||
| ) | ||
| assert all(encoder._layer_precision[i] is None for i in range(6)) | ||
|
|
||
| def test_covers_all_layers(self, encoder): | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0], | ||
| fp4_layers=None, | ||
| fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), | ||
| fp4_recipe=None, | ||
| ) | ||
| assert len(encoder._layer_precision) == 6 | ||
|
|
||
| def test_recipes_stored_as_attributes(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0], | ||
| fp4_layers=[1], | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| # Recipes are stored once, not duplicated per-layer in the map. | ||
| assert encoder._fp8_recipe is fp8_recipe | ||
| assert encoder._fp4_recipe is fp4_recipe | ||
| # The map only contains strings, not recipe objects. | ||
| for v in encoder._layer_precision.values(): | ||
| assert v is None or isinstance(v, str) | ||
|
|
||
|
|
||
| class TestGetLayerAutocast: | ||
| """Tests for NVEsmEncoder.get_layer_autocast.""" | ||
|
|
||
| def test_fp8_layer_returns_nullcontext(self, encoder): | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0], | ||
| fp4_layers=None, | ||
| fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), | ||
| fp4_recipe=None, | ||
| ) | ||
| ctx = encoder.get_layer_autocast(0) | ||
| assert isinstance(ctx, nullcontext) | ||
|
|
||
| def test_fp4_layer_returns_te_autocast(self, encoder): | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=None, | ||
| fp4_layers=[0], | ||
| fp8_recipe=None, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: | ||
| mock_autocast.return_value = "fp4_context" | ||
| ctx = encoder.get_layer_autocast(0) | ||
| mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe) | ||
| assert ctx == "fp4_context" | ||
|
|
||
| def test_bf16_layer_returns_te_autocast_disabled(self, encoder): | ||
| encoder.initialize_quantization( | ||
| fp8_layers=None, | ||
| fp4_layers=None, | ||
| fp8_recipe=None, | ||
| fp4_recipe=None, | ||
| ) | ||
| with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: | ||
| mock_autocast.return_value = "bf16_context" | ||
| ctx = encoder.get_layer_autocast(0) | ||
| mock_autocast.assert_called_once_with(enabled=False) | ||
| assert ctx == "bf16_context" | ||
|
|
||
| def test_uninitialized_defaults_to_bf16(self, encoder): | ||
| """When initialize_quantization was never called, all layers default to BF16.""" | ||
| with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: | ||
| mock_autocast.return_value = "bf16_context" | ||
| ctx = encoder.get_layer_autocast(0) | ||
| mock_autocast.assert_called_once_with(enabled=False) | ||
| assert ctx == "bf16_context" | ||
|
|
||
| def test_mixed_layers_return_correct_contexts(self, encoder): | ||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 1], | ||
| fp4_layers=[2, 3], | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| # FP8 layers -> nullcontext | ||
| assert isinstance(encoder.get_layer_autocast(0), nullcontext) | ||
| assert isinstance(encoder.get_layer_autocast(1), nullcontext) | ||
|
|
||
| # FP4 and BF16 layers -> te.pytorch.autocast (not nullcontext) | ||
| with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: | ||
| mock_autocast.return_value = "fp4_context" | ||
| encoder.get_layer_autocast(2) | ||
| mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe) | ||
|
|
||
| with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: | ||
| mock_autocast.return_value = "bf16_context" | ||
| encoder.get_layer_autocast(4) | ||
| mock_autocast.assert_called_with(enabled=False) | ||
|
|
||
| def test_layer_precision_map_is_pickleable(self, encoder): | ||
| """The _layer_precision map should be trivially pickleable (only strings/None).""" | ||
| import pickle | ||
|
|
||
| fp8_recipe = transformer_engine.common.recipe.DelayedScaling() | ||
| fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() | ||
| encoder.initialize_quantization( | ||
| fp8_layers=[0, 1], | ||
| fp4_layers=[2, 3], | ||
| fp8_recipe=fp8_recipe, | ||
| fp4_recipe=fp4_recipe, | ||
| ) | ||
| roundtripped = pickle.loads(pickle.dumps(encoder._layer_precision)) | ||
| assert roundtripped == encoder._layer_precision |
🛠️ Refactor suggestion | 🟠 Major
Add at least one TE-vs-reference golden-value parity test in this module.
These tests cover routing/context behavior, but they do not assert numerical parity between the TE model and the reference ESM model for a fixed input/seed.
As per coding guidelines: "In bionemo-recipes/models/, create golden value tests proving that the TransformerEngine model matches the reference model".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@bionemo-recipes/models/esm2/tests/test_layer_quantization.py` around lines 16
- 242, Add a golden-value parity test that runs the TransformerEngine ESM
encoder (NVEsmEncoder) and the reference ESM encoder on the same deterministic
input/seed and asserts numerical parity (e.g., final token logits or pooled
embeddings) within a small tolerance; create a new test function (e.g.,
test_te_vs_reference_golden_value_parity) in this module that uses
torch.manual_seed, a small random input tensor on CUDA, constructs an
NVEsmEncoder via NVEsmConfig and constructs the reference ESM model (import the
reference model used in the repo), runs both forward passes with identical
settings, and asserts outputs are close with pytest.approx or torch.allclose;
ensure the test uses the existing encoder fixture pattern/device and keeps the
comparison deterministic and tolerant to tiny numeric differences.
| fp8_layers_set = set(fp8_layers) if fp8_layers else set() | ||
| fp4_layers_set = set(fp4_layers) if fp4_layers else set() | ||
| self._layer_precision = {} | ||
| for layer_number in range(len(self.layers)): | ||
| if layer_number in fp8_layers_set: | ||
| self._layer_precision[layer_number] = "fp8" | ||
| elif layer_number in fp4_layers_set: | ||
| self._layer_precision[layer_number] = "fp4" | ||
| else: | ||
| self._layer_precision[layer_number] = None |
Validate overlaps and bounds in initialize_quantization.
Right now, overlapping or out-of-range layer IDs are silently accepted. This can hide config mistakes and route layers to unintended precision.
♻️ Suggested fix

```diff
         self._fp8_recipe = fp8_recipe
         self._fp4_recipe = fp4_recipe
         fp8_layers_set = set(fp8_layers) if fp8_layers else set()
         fp4_layers_set = set(fp4_layers) if fp4_layers else set()
+        overlap = fp8_layers_set & fp4_layers_set
+        if overlap:
+            raise ValueError(f"fp8_layers and fp4_layers overlap: {sorted(overlap)}")
+
+        valid_layers = set(range(len(self.layers)))
+        invalid = (fp8_layers_set | fp4_layers_set) - valid_layers
+        if invalid:
+            raise ValueError(
+                f"Layer indices out of range [0, {len(self.layers) - 1}]: {sorted(invalid)}"
+            )
+
         self._layer_precision = {}
         for layer_number in range(len(self.layers)):
             if layer_number in fp8_layers_set:
                 self._layer_precision[layer_number] = "fp8"
             elif layer_number in fp4_layers_set:
                 self._layer_precision[layer_number] = "fp4"
             else:
                 self._layer_precision[layer_number] = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 213 -
222, In initialize_quantization, validate fp8_layers and fp4_layers before
applying them: convert to sets (fp8_layers_set, fp4_layers_set), check every
layer id is an int within 0..len(self.layers)-1 and raise a ValueError if any id
is out of range, check for overlap by computing intersection = fp8_layers_set &
fp4_layers_set and raise a ValueError if non-empty, and optionally ensure inputs
are unique/convertible to int; only after these checks populate
self._layer_precision using range(len(self.layers)).
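To make the validation the prompt describes concrete, here is a minimal standalone sketch. The function name `build_layer_precision_map` and its free-function signature are invented for illustration (the real logic is a method on `NVEsmEncoder`), but the mapping and the two `ValueError` checks mirror the suggested fix above.

```python
def build_layer_precision_map(num_layers, fp8_layers=None, fp4_layers=None):
    """Map each 0-indexed layer to 'fp8', 'fp4', or None (BF16), with validation."""
    fp8_set = set(fp8_layers or [])
    fp4_set = set(fp4_layers or [])

    # A layer cannot be quantized in two formats at once.
    overlap = fp8_set & fp4_set
    if overlap:
        raise ValueError(f"fp8_layers and fp4_layers overlap: {sorted(overlap)}")

    # Reject layer indices outside the encoder's layer range.
    invalid = (fp8_set | fp4_set) - set(range(num_layers))
    if invalid:
        raise ValueError(f"Layer indices out of range [0, {num_layers - 1}]: {sorted(invalid)}")

    return {
        i: "fp8" if i in fp8_set else "fp4" if i in fp4_set else None
        for i in range(num_layers)
    }
```

Unlisted layers fall through to `None` (BF16), matching the test expectations for mixed configurations.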
```python
# Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
    fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
# Verify whether FP8 recipe construction is gated by fp8_config.enabled in train_ddp_cp.py.
cat -n bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py | sed -n '74,90p'
echo "---"
rg -n "fp8_recipe|fp8_config.enabled" -C3 bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
```

Repository: NVIDIA/bionemo-framework
Length of output: 2742
Guard FP8 recipe construction behind fp8_config.enabled.
Lines 79–81 construct fp8_recipe unconditionally, even when FP8 is disabled. This contradicts the comment on line 78 and creates an asymmetry with fp4_recipe (lines 82–86), which is correctly guarded by if args.fp4_config.enabled:. Disabled FP8 runs will still fail if the FP8 config is invalid.
Suggested fix

```diff
+    fp8_recipe = None
+    if args.fp8_config.enabled:
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
             fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py` around lines 78 - 81,
The FP8 recipe is being constructed unconditionally (fp8_recipe via
hydra.utils.get_class and Format[args.fp8_config.fp8_format]) even when FP8 is
disabled; wrap that construction in an if args.fp8_config.enabled: guard (same
pattern used for fp4_recipe) so fp8_recipe is only created when
args.fp8_config.enabled is true, and ensure any references to fp8_recipe are
only used within that guarded block or handled when disabled.
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>

Layer-wise MXFP8/NVFP4 precision for ESM-2 TransformerEngine training
Adds support for per-layer quantization precision control, enabling mixed FP8/FP4/BF16
configurations across transformer layers during training. This allows users to assign
different quantization formats to different layers via Hydra config (1-indexed fp8_layers
and fp4_layers lists), enabling convergence/performance tradeoff exploration.
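Since the config lists are 1-indexed while model internals are 0-indexed, a conversion step sits between the two. The helper below is a hypothetical sketch of that step (the actual conversion happens inside `resolve_quantization_layers`; this name and signature are invented for illustration):

```python
def to_zero_indexed(user_layers, num_layers):
    """Convert a 1-indexed, user-facing layer list (as in the Hydra config)
    to the 0-indexed form the model internals expect."""
    if user_layers is None:
        return []
    zero_idx = [n - 1 for n in user_layers]
    if any(i < 0 or i >= num_layers for i in zero_idx):
        raise ValueError(f"Layer numbers must be in [1, {num_layers}], got {list(user_layers)}")
    return zero_idx
```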
Key changes:
- A `layer_number_quantized_recipe_map` that selects the appropriate TE autocast context per layer (`nullcontext` for FP8 to respect the outer autocast, explicit autocast for FP4, or `autocast(enabled=False)` for BF16).
- Quantization utilities for resolving layer assignments (`resolve_quantization_layers`), generating debug API regex patterns (`generate_layer_regex`), and initializing nvdlfw_inspect quant stats logging (`initialize_quant_stats_logging`). Handles 0-indexed (model internals) and 1-indexed (user-facing) layer numbering.
- Each training script resolves layer assignments from config, builds the recipe map, assigns it to the encoder, and optionally initializes quant stats logging.
- `NVEsmConfig`/`NVEsmForMaskedLM` are used directly for consistency and to avoid remote code trust issues.
- Hydra configs gain `fp8_layers`, `fp4_layers`, and `use_fp32_master_weights` settings.
- Updated recipes (accelerate, peft) and the models package with layer-wise quantization support, NVTX annotations per encoder layer, and `FP8_RECIPES`/`FP4_RECIPES` type constants.
- Unit tests for `resolve_quantization_layers`, `generate_layer_regex`, and `update_quant_stats_config` covering defaults, explicit layers, mixed assignments, overlap validation, and edge cases.

### Description
Usage
Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see CONTRIBUTING.
Note
By default, only basic unit tests are run. Add the appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.
Authorized contributions will automatically be copied to a `pull-request/` prefixed branch in the source repository (e.g. `pull-request/123`).
Leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.

Triggering Code Rabbit AI Review
To trigger a code review from code rabbit, comment on a pull request with one of these commands:
See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.
Pre-submit Checklist
Summary by CodeRabbit
Release Notes
New Features
Documentation
Configuration