Following up on #332 (which fixes 300M/600M .pth loading), ESMC.from_pretrained("esmc_6b") fails locally for two further reasons. Filing as an issue rather than a PR since one of them looks like a repo-packaging mismatch that a maintainer should decide on.
1. Non-persistent RoPE buffer left on meta
ESMC_6B_202412 builds under init_empty_weights(), then loads the sharded safetensors. But esm/layers/rotary.py registers the buffer as:
self.register_buffer("inv_freq", inv_freq, persistent=False)
inv_freq is non-persistent, so it is not in the checkpoint. Under init_empty_weights it stays on meta, and the first .to(device) / forward raises:
NotImplementedError: Cannot copy out of meta tensor; no data!
(The 300M/600M .pth path happens to dodge this because that checkpoint carries the buffer.)
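To illustrate the "re-init post-load" option, here is a minimal sketch of rebuilding a non-persistent buffer that was left on meta. ToyRotary is a hypothetical stand-in, not the class in esm/layers/rotary.py, and the _compute_inv_freq method name and the dim/base attributes are assumptions made for this example; the real layer may expose something different.

```python
import torch
from torch import nn


class ToyRotary(nn.Module):
    """Hypothetical stand-in for a rotary layer (NOT the esm implementation)."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        # Non-persistent: excluded from state_dict, so no checkpoint can restore it.
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)

    def _compute_inv_freq(self, device=None) -> torch.Tensor:
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return 1.0 / (self.base ** exponents)


def rematerialize_meta_buffers(model: nn.Module, device="cpu") -> list[str]:
    """Recompute every `inv_freq` buffer that is still a meta tensor."""
    fixed = []
    for name, module in model.named_modules():
        buf = getattr(module, "inv_freq", None)
        if isinstance(buf, torch.Tensor) and buf.is_meta:
            # Re-registering an existing buffer name overwrites it in place.
            module.register_buffer(
                "inv_freq", module._compute_inv_freq(device=device), persistent=False
            )
            fixed.append(f"{name}.inv_freq" if name else "inv_freq")
    return fixed


# Build on the meta device (similar in effect to init_empty_weights for this buffer).
with torch.device("meta"):
    toy = ToyRotary(dim=8)

was_meta = toy.inv_freq.is_meta
fixed = rematerialize_meta_buffers(toy, device="cpu")
```

A 6B loader could run a pass like this after loading the shards and before the first .to(device), provided the rotary layer offers a way to recompute the buffer.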
2. Checkpoint key names do not match the SDK module tree
biohub/esmc-6b-2024-12 ships transformers-wrapped safetensors. Keys are:
esmc.embed.weight
esmc.transformer.blocks.{0..79}....
lm_head.0.weight # MLM head
but esm.models.esmc.ESMC expects embed.weight, transformer.blocks.N..., sequence_head.*. So load_state_dict reports ~808 missing and ~808 unexpected — the backbone never loads (silently, with strict=False).
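Because strict=False hides this mismatch, a quick diagnostic helps. The sketch below groups missing and unexpected keys by their top-level prefix; summarize_mismatch is a hypothetical helper, and the key names in the demo are shortened placeholders rather than the real checkpoint contents.

```python
from collections import Counter
from typing import Dict, Iterable


def top_level_counts(keys: Iterable[str]) -> Dict[str, int]:
    """Count keys by the part before the first dot."""
    return dict(Counter(k.split(".", 1)[0] for k in keys))


def summarize_mismatch(expected: Iterable[str], found: Iterable[str]) -> dict:
    """Compare the module tree's keys against the checkpoint's keys."""
    expected_set, found_set = set(expected), set(found)
    return {
        "missing": top_level_counts(expected_set - found_set),
        "unexpected": top_level_counts(found_set - expected_set),
        "matched": len(expected_set & found_set),
    }


# Placeholder keys for illustration only.
expected = ["embed.weight", "transformer.blocks.0.w", "sequence_head.0.weight"]
found = ["esmc.embed.weight", "esmc.transformer.blocks.0.w", "lm_head.0.weight"]
report = summarize_mismatch(expected, found)
```

With a real model, expected would come from model.state_dict().keys() and found from the merged shard keys. A "matched" count of 0 with everything unexpected under a single prefix such as esmc is the signature of a wrapper mismatch.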
Workaround that loads it correctly
Build without init_empty_weights (so inv_freq materializes), then remap keys:
raw = {} # merged from the safetensors shards
state_dict = {k[len("esmc."):]: v for k, v in raw.items() if k.startswith("esmc.")}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# -> 6 missing (sequence_head + non-persistent buffers), 0 unexpected; backbone loads.
This produces correct embeddings on an A100.
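The remap step can be sketched as a small standalone helper that also reports which keys were left behind, so nothing is dropped silently. remap_wrapped_keys is a hypothetical name, the demo values are placeholders instead of tensors, and the optional lm_head to sequence_head rename is an unverified assumption about how the two heads correspond.

```python
from typing import Any, Dict, List, Mapping, Optional, Tuple


def remap_wrapped_keys(
    raw: Mapping[str, Any],
    prefix: str = "esmc.",
    head_rename: Optional[Tuple[str, str]] = None,
) -> Tuple[Dict[str, Any], List[str]]:
    """Strip `prefix` from wrapped keys; return (remapped, leftover_keys)."""
    remapped: Dict[str, Any] = {}
    leftover: List[str] = []
    for key, value in raw.items():
        if key.startswith(prefix):
            remapped[key[len(prefix):]] = value
        elif head_rename is not None and key.startswith(head_rename[0]):
            # ASSUMPTION: the head's submodule layout matches after renaming.
            remapped[head_rename[1] + key[len(head_rename[0]):]] = value
        else:
            leftover.append(key)
    return remapped, leftover


# Placeholder values; in practice these are tensors merged from the shards.
raw = {
    "esmc.embed.weight": 0,
    "esmc.transformer.blocks.0.w": 1,  # shortened, hypothetical key
    "lm_head.0.weight": 2,
}
state_dict, leftover = remap_wrapped_keys(raw)
with_head, leftover_head = remap_wrapped_keys(
    raw, head_rename=("lm_head.", "sequence_head.")
)
```

By default the head keys are left in leftover rather than renamed, which matches the workaround above where sequence_head stays in the missing list.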
Suggested fixes
- Materialize non-persistent buffers after load (e.g. avoid meta for them, or re-init post-load).
- Either republish esmc-6b-2024-12 with SDK-native key names, or have the 6B loader remap the esmc.-prefixed / lm_head keys.
Happy to PR whichever direction you prefer.