Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ pip install -r code/requirements_public_inference.txt

2. **Get the pre-trained models**
This repo ships two pre-trained model files (tracked with Git LFS):
- `models/PreDecoderModelMemory_r9_v1.0.77.pt` (receptive field R=9, checkpoint 77)
- `models/PreDecoderModelMemory_r13_v1.0.86.pt` (receptive field R=13, checkpoint 86)
- `models/Ising-Decoder-SurfaceCode-1-Fast.pt` (receptive field R=9)
- `models/Ising-Decoder-SurfaceCode-1-Accurate.pt` (receptive field R=13)

After cloning, fetch the model files with `git lfs pull`. Optionally, set `PREDECODER_MODEL_URL` to the LFS/raw URL so the files can be fetched when they are not in the working tree (e.g. in a minimal checkout or CI).

Expand Down
4 changes: 2 additions & 2 deletions code/export/checkpoint_to_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

Usage:
PYTHONPATH=code python code/export/checkpoint_to_safetensors.py \\
--checkpoint models/PreDecoderModelMemory_r9_v1.0.77.pt \\
--checkpoint models/Ising-Decoder-SurfaceCode-1-Fast.pt \\
--model-id 1 [--fp16]

Then run inference with:
PREDECODER_SAFETENSORS_CHECKPOINT=models/PreDecoderModelMemory_r9_v1.0.77_fp16.safetensors \\
PREDECODER_SAFETENSORS_CHECKPOINT=models/Ising-Decoder-SurfaceCode-1-Fast_fp16.safetensors \\
WORKFLOW=inference DISTANCE=9 N_ROUNDS=9 EXPERIMENT_NAME=predecoder_model_1 \\
bash code/scripts/local_run.sh
"""
Expand Down
9 changes: 3 additions & 6 deletions code/tests/test_inference_public_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@
MODELS_DIR = REPO_ROOT / "models"

MODEL_R9 = {
"filename": "PreDecoderModelMemory_r9_v1.0.77.pt",
"checkpoint": 77,
"filename": "Ising-Decoder-SurfaceCode-1-Fast.pt",
"model_id": 1,
}
MODEL_R13 = {
"filename": "PreDecoderModelMemory_r13_v1.0.86.pt",
"checkpoint": 86,
"filename": "Ising-Decoder-SurfaceCode-1-Accurate.pt",
"model_id": 4,
}

Expand Down Expand Up @@ -74,8 +72,7 @@ def _run_inference_rtest(distance: int, n_rounds: int, model_info: dict):
f"Missing model file: {model_file}. It must be in the repo (Git LFS). Run 'git lfs pull' or restore the file."
)

merged.model_checkpoint_dir = str(model_file.parent)
merged.test.use_model_checkpoint = model_info["checkpoint"]
merged.model_checkpoint_file = str(model_file)
merged.test.latency_num_samples = 0
merged.test.verbose_inference = False
if "dataloader" in merged.test:
Expand Down
20 changes: 19 additions & 1 deletion code/workflows/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def find_best_model(path, *, rank: int = 0):
print(f" [{marker}] {filename} (epoch {epoch})")

if best_file is None:
raise FileNotFoundError(f"No valid PreDecoderModelMemory files found in {path}")
raise FileNotFoundError(f"No valid model checkpoint files found in {path}")

best_model_path = os.path.join(path, best_file)
if rank == 0:
Expand Down Expand Up @@ -211,6 +211,24 @@ def _load_model(cfg, dist):
cfg.enable_fp16 = True
return model

# Direct file path override (for named pretrained models without epoch numbers)
model_checkpoint_file = getattr(cfg, 'model_checkpoint_file', None)
if model_checkpoint_file:
model_checkpoint_file = _resolve_dir(str(model_checkpoint_file))
if not os.path.exists(model_checkpoint_file):
raise FileNotFoundError(f"Checkpoint not found: {model_checkpoint_file}")
if dist.rank == 0:
print(f"Loading model from: {model_checkpoint_file}")
model = ModelFactory.create_model(cfg).to(dist.device)
if cfg.enable_fp16:
model = model.half()
state_dict = _load_state_dict_from_pt(model_checkpoint_file, dist.device)
model.load_state_dict(state_dict)
if dist.rank == 0:
param_count = sum(p.numel() for p in model.parameters())
print(f"Model loaded ({param_count:,} parameters)")
return model

model = ModelFactory.create_model(cfg).to(dist.device)

if cfg.enable_fp16:
Expand Down
Loading