OS: RedHat Linux
GPU: AMD MI300A
ROCm version: 6.2
PyTorch version: 2.5.0
Running ESMFold2 on any protein-only complex (e.g. a nanobody/antigen complex) on ROCm 6.2 produces all-NaN outputs for pLDDT, pTM, and ipTM. The issue also reproduces on the published tutorial example for RNase H1 in the Google Colab notebook here. This is specific to ROCm.
Here is some code to reproduce the issue on the PDB structure 7EOW:
from esm.models.esmfold2 import (
    ESMFold2InputBuilder, ProteinInput, StructurePredictionInput
)
from transformers.models.esmfold2.modeling_esmfold2 import ESMFold2Model
import torch

model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval()

spi = StructurePredictionInput(sequences=[
    ProteinInput(id="A", sequence="MEVQLVESGGGLVQPGGSLRLSCAASGRTFSYNPMGWFRQAPGKGRELVAAISRTGGSTYYPDSVEGRFTISRDNAKRMVYLQMNSLRAEDTAVYYCAAAGVRAEDGRVRTLPSEYTFWGQGTQVTVSSLEHHHHHH"),
    ProteinInput(id="B", sequence="MGSSHHHHHHSSGENLYFQSGSHMDISEPPLHDFYCSRLLDLVFLLDGSSRLSEAEFEVLKAFVVDMMERLRISQKWVRVAVVEYHDGSHAYIGLKDRKRPSELRRIASQVKYAGSQVASTSEVLKYTLFQIFSKIDRPEASRITLLLMASQEPQRMSRNFVRYVQGLKKKKVIVIPVGIGPHANLKQIRLIEKQAPENKAFVLSSVDELEQQRDEIVSYLCDLAPEAPPPT"),
])

result = ESMFold2InputBuilder().fold(
    model, spi, num_loops=3, num_sampling_steps=50, num_diffusion_samples=1, seed=0
)
print(f"pLDDT: {float(result.plddt.mean()):.3f}")  # prints NaN
While debugging this, I confirmed the following:
- Model weights are clean — no NaNs, Infs, or all-zero tensors in any shard (after re-download)
- ESMC runs without NaNs in standalone use
- When called from ESMFold2, ESMC hidden states contain NaNs — traced to the attention computation inside ESMC, not the input featurization
- Root cause: PyTorch's flash SDP and memory-efficient SDP backends produce NaNs on ROCm 6.2
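For anyone who wants to repeat the tracing step on their own setup, here is a minimal sketch of how NaNs can be localized with forward hooks. It is not the exact script I used: `find_nan_modules` is a hypothetical helper name (not part of esm or transformers), and it only inspects outputs that are plain tensors or tuples/lists of tensors, so modules returning dicts or dataclasses are skipped.

```python
import torch
from torch import nn


def find_nan_modules(model: nn.Module, *inputs, **kwargs) -> list[str]:
    """Run one forward pass and return the names of submodules whose
    output contains NaN, in the order their forward calls finished."""
    offenders: list[str] = []
    handles = []

    def make_hook(name: str):
        def hook(module, args, output):
            # Normalize the output to a sequence so tuples and single tensors
            # are handled the same way; other container types are ignored.
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if torch.is_tensor(t) and t.is_floating_point() and torch.isnan(t).any():
                    offenders.append(name)
                    break
        return hook

    # Attach a hook to every submodule (the root module has the name "").
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(*inputs, **kwargs)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()

    return offenders
```

The first name in the returned list is the earliest submodule to emit a NaN; later entries are usually just downstream propagation.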
I was able to resolve this by disabling the flash and memory-efficient SDP backends and forcing the math backend:
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
Update: this issue was also resolved by upgrading to ROCm 6.4 and PyTorch 2.9, so a note in the README about which ROCm versions to use might be sufficient.
Either way, this felt like valuable info to share for ROCm users. Thanks to all the maintainers for releasing such awesome models for everyone to use.