diff --git a/esm/pretrained.py b/esm/pretrained.py index 5c12bdf..f076a83 100644 --- a/esm/pretrained.py +++ b/esm/pretrained.py @@ -72,7 +72,14 @@ def ESMC_300M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = tokenizer=get_esmc_model_tokenizers(), use_flash_attn=use_flash_attn, ).eval() - load_torch_model(model, data_root("esmc-300")) + # The 300M repo ships a single legacy .pth (no safetensors), so load it + # explicitly like the ESM3 loaders above rather than via load_torch_model, + # which (safe=True) expects a safetensors checkpoint in the directory. + state_dict = torch.load( + data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth", + map_location=device, + ) + model.load_state_dict(state_dict, assign=True) model = model.to(device) return model @@ -86,7 +93,14 @@ def ESMC_600M_202412(device: torch.device | str = "cpu", use_flash_attn: bool = tokenizer=get_esmc_model_tokenizers(), use_flash_attn=use_flash_attn, ).eval() - load_torch_model(model, data_root("esmc-600")) + # The 600M repo ships a single legacy .pth (no safetensors), so load it + # explicitly like the ESM3 loaders above rather than via load_torch_model, + # which (safe=True) expects a safetensors checkpoint in the directory. + state_dict = torch.load( + data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth", + map_location=device, + ) + model.load_state_dict(state_dict, assign=True) model = model.to(device) return model