
ESMFold2 produces all-NaN outputs on ROCm 6.2 — workaround: disable PyTorch flash/mem-efficient SDP backends #322

@jfpettit

OS: RedHat Linux
GPU: AMD MI300A
ROCm version: 6.2
PyTorch version: 2.5.0

Running ESMFold2 on ROCm 6.2 on any protein-only complex (e.g., a nanobody/antigen complex) produces all-NaN values for pLDDT, pTM, and ipTM. The issue also reproduces on the RNase H1 example from the published Google Colab tutorial notebook. It is specific to ROCm.

Here is some code to reproduce the issue on the PDB structure 7EOW:

```python
from esm.models.esmfold2 import (
    ESMFold2InputBuilder, ProteinInput, StructurePredictionInput
)
from transformers.models.esmfold2.modeling_esmfold2 import ESMFold2Model
import torch

model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval()

spi = StructurePredictionInput(sequences=[
    ProteinInput(id="A", sequence="MEVQLVESGGGLVQPGGSLRLSCAASGRTFSYNPMGWFRQAPGKGRELVAAISRTGGSTYYPDSVEGRFTISRDNAKRMVYLQMNSLRAEDTAVYYCAAAGVRAEDGRVRTLPSEYTFWGQGTQVTVSSLEHHHHHH"),
    ProteinInput(id="B", sequence="MGSSHHHHHHSSGENLYFQSGSHMDISEPPLHDFYCSRLLDLVFLLDGSSRLSEAEFEVLKAFVVDMMERLRISQKWVRVAVVEYHDGSHAYIGLKDRKRPSELRRIASQVKYAGSQVASTSEVLKYTLFQIFSKIDRPEASRITLLLMASQEPQRMSRNFVRYVQGLKKKKVIVIPVGIGPHANLKQIRLIEKQAPENKAFVLSSVDELEQQRDEIVSYLCDLAPEAPPPT"),
])

result = ESMFold2InputBuilder().fold(
    model, spi, num_loops=3, num_sampling_steps=50, num_diffusion_samples=1, seed=0
)
print(f"pLDDT: {float(result.plddt.mean()):.3f}")  # prints NaN
```

While debugging this, I confirmed the following:

  1. Model weights are clean: no NaNs, Infs, or all-zero tensors in any shard (checked after re-downloading)
  2. ESMC runs without NaNs when used standalone
  3. When ESMC is called from ESMFold2, its hidden states contain NaNs; I traced these to the attention computation inside ESMC, not to the input featurization
  4. Root cause: PyTorch's flash and memory-efficient SDP (scaled dot-product attention) backends produce NaNs on ROCm 6.2, while the math backend does not
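For anyone repeating checks 1 and 3 on their own setup, here is a minimal sketch of how they could be done in plain PyTorch. `find_bad_weights` and `register_nan_hooks` are hypothetical helpers written for this note, not part of ESM or transformers. Forward hooks fire when a module's forward finishes, so children report before their parents and the first name recorded is the earliest module whose output contained NaNs. Only tensor outputs and tuples/lists of tensors are inspected.

```python
import torch
import torch.nn as nn


def find_bad_weights(model: nn.Module) -> list:
    """Names of floating-point parameters/buffers with NaN/Inf or all zeros.

    Note: zero-initialized tensors (e.g. LayerNorm biases in an untrained
    model) are legitimately all zeros and would also be listed.
    """
    bad = []
    for name, tensor in model.state_dict().items():
        if not tensor.is_floating_point() or tensor.numel() == 0:
            continue  # skip integer buffers and empty tensors
        if not bool(torch.isfinite(tensor).all()) or not bool(tensor.any()):
            bad.append(name)
    return bad


def register_nan_hooks(model: nn.Module, report: list) -> list:
    """Attach forward hooks that append the name of every module whose output
    contains NaNs (only tensors or tuples/lists of tensors are checked).
    Returns the hook handles so they can be removed afterwards."""
    handles = []
    for name, module in model.named_modules():
        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, inputs, output, name=name):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            if any(
                torch.is_tensor(o) and o.is_floating_point() and bool(torch.isnan(o).any())
                for o in outs
            ):
                report.append(name or "<root>")
        handles.append(module.register_forward_hook(hook))
    return handles
```

With `model` and `spi` from the reproduction above, one would create `report = []`, call `handles = register_nan_hooks(model, report)` before `fold(...)`, then inspect `report[:5]` and call `h.remove()` on each handle. This has not been run against ESMFold2 itself.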

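I was able to resolve this by disabling those two backends before running the model, so that attention falls back to the math implementation:

```python
import torch

# Run once, before the model's forward pass.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
```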

Update: the issue is also resolved by upgrading to ROCm 6.4 and PyTorch 2.9, so a note in the README about supported ROCm versions might be sufficient.
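If the workaround is kept for users on older stacks, one option is to apply it only on ROCm builds older than 6.4. This sketch rests on two assumptions: that `torch.version.hip` (which is `None` on non-ROCm builds) starts with `major.minor`, and that 6.4 is the right cutoff, even though the update above also changed PyTorch from 2.5 to 2.9, so it is not established which upgrade fixed the problem. All function names here are hypothetical.

```python
import re
from typing import Optional, Tuple

import torch


def parse_hip_version(hip: Optional[str]) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a HIP version string such as torch.version.hip.

    Assumes the string starts with "major.minor"; returns None on non-ROCm
    builds (where torch.version.hip is None) or if the format is unexpected.
    """
    if not hip:
        return None
    match = re.match(r"(\d+)\.(\d+)", hip)
    return (int(match.group(1)), int(match.group(2))) if match else None


def needs_sdp_workaround(hip: Optional[str], cutoff: Tuple[int, int] = (6, 4)) -> bool:
    """True for ROCm builds older than `cutoff` (6.4 assumed from this report)."""
    version = parse_hip_version(hip)
    return version is not None and version < cutoff


def apply_sdp_workaround_if_needed() -> bool:
    """Disable flash / memory-efficient SDP on older ROCm builds; call before inference."""
    if not needs_sdp_workaround(torch.version.hip):
        return False
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    return True
```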

Either way, I thought this was worth sharing with other ROCm users. Thanks to all the maintainers for releasing such awesome models for everyone to use.
