From 81ac4013bdd0e344531f8c1503aa0fc29c3ecd97 Mon Sep 17 00:00:00 2001 From: sakzgupzzz Date: Tue, 2 Jun 2026 21:52:58 -0400 Subject: [PATCH] Fix ESMC 300M/600M local loading: torch.load the .pth weights The 300M and 600M HF repos (biohub/esmc-300m-2024-12, biohub/esmc-600m-2024-12) ship a single legacy checkpoint at data/weights/esmc_*_2024_12_v0.pth and no safetensors. ESMC_300M_202412 / ESMC_600M_202412 loaded them via load_torch_model(model, dir), which defaults to safe=True and looks for a safetensors checkpoint in the directory, so loading fails with: ValueError: Directory '.../snapshots/...' does not contain a valid checkpoint. Expected either a sharded checkpoint with an index file, or a single model file. Load the .pth explicitly with torch.load + load_state_dict(assign=True), exactly as the ESM3 loaders in this file already do. The 6B loader is unchanged because its repo ships sharded safetensors, for which load_torch_model works. Co-Authored-By: Claude Opus 4.8 --- esm/pretrained.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/esm/pretrained.py b/esm/pretrained.py index 5c12bdf8..f076a835 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -72,7 +72,14 @@ def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = tokenizer=get_esmc_model_tokenizers(), use_flash_attn=use_flash_attn, ).eval() - load_torch_model(model, data_root("esmc-300")) + # The 300M repo ships a single legacy .pth (no safetensors), so load it + # explicitly like the ESM3 loaders above rather than via load_torch_model, + # which (safe=True) expects a safetensors checkpoint in the directory. + state_dict = torch.load( + data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth", + map_location=device, + ) + model.load_state_dict(state_dict, assign=True) model = model.to(device) return model @@ -86,7 +93,14 @@ def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = tokenizer=get_esmc_model_tokenizers(), use_flash_attn=use_flash_attn, ).eval() - load_torch_model(model, data_root("esmc-600")) + # The 600M repo ships a single legacy .pth (no safetensors), so load it + # explicitly like the ESM3 loaders above rather than via load_torch_model, + # which (safe=True) expects a safetensors checkpoint in the directory. + state_dict = torch.load( + data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth", + map_location=device, + ) + model.load_state_dict(state_dict, assign=True) model = model.to(device) return model