Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions code/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@

Contains:
- factory: Factory for creating models from config
- predecoder: Pre-decoder model architectures (PreDecoderModelMemory_v1, PreDecoderModelMemory_v2)
- predecoder: Pre-decoder model architectures (PreDecoderModelMemory_v1)
"""
from model.factory import ModelFactory

# Import predecoder models lazily to avoid hard dependency on optional training
# stacks (e.g., physicsnemo) during lightweight config validation.
try:
from model.predecoder import PreDecoderModelMemory_v1, PreDecoderModelMemory_v2
from model.predecoder import PreDecoderModelMemory_v1
except ModuleNotFoundError:
PreDecoderModelMemory_v1 = None
PreDecoderModelMemory_v2 = None
4 changes: 0 additions & 4 deletions code/model/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,5 @@ def _create_surface_model(cfg):
from model.predecoder import PreDecoderModelMemory_v1
model = PreDecoderModelMemory_v1(cfg)
return model
elif cfg.model.version == "predecoder_memory_v2":
from model.predecoder import PreDecoderModelMemory_v2
model = PreDecoderModelMemory_v2(cfg)
return model
else:
raise ValueError(f"Invalid model version: {cfg.model.version}")
106 changes: 0 additions & 106 deletions code/model/predecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,82 +112,6 @@ def forward(self, x):
return self.net(x) # x: (B, 4, T, D, D)


class PreDecoderModelMemory_v2(nn.Module):

def __init__(self, cfg):
super(PreDecoderModelMemory_v2, self).__init__()

self.distance = cfg.distance
self.n_rounds = cfg.n_rounds
self.dropout_p = cfg.model.dropout_p
activation_class = self._get_activation_class(cfg.model.activation)
self.activation_fn = activation_class()

filters = cfg.model.num_filters
kernel_sizes = cfg.model.kernel_size

assert (len(filters) - 2) % 3 == 0, \
"The number of filters minus 2 (for the first and last layers) must be divisible by 3."
assert len(filters) == len(kernel_sizes), \
"Mismatch: num_filters and kernel_size must be the same length."
input_channels = cfg.model.input_channels
out_channels = cfg.model.out_channels
assert filters[-1] == out_channels, \
f"The last element of num_filters must match the configured out_channels ({out_channels}), but got {filters[-1]}"

# === Initial Conv3D layer ===
self.layers = nn.ModuleList()
self.layers.append(
nn.Sequential(
nn.Conv3d(
in_channels=input_channels,
out_channels=filters[0],
kernel_size=kernel_sizes[0],
padding=kernel_sizes[0] // 2
), nn.BatchNorm3d(filters[0]), self.activation_fn
)
)

# === Residual Blocks ===
in_ch = filters[0]
i = 1
while i + 2 < len(filters) - 1:
out_ch1, out_ch2, out_ch3 = filters[i], filters[i + 1], filters[i + 2]
ks = [kernel_sizes[i], kernel_sizes[i + 1], kernel_sizes[i + 2]]
self.layers.append(
ResidualBlock3D(
channels=[in_ch, out_ch1, out_ch2, out_ch3],
kernel_sizes=ks,
activation=activation_class
)
)
in_ch = out_ch3
i += 3

# === Final Conv3D layer ===
self.final_conv = nn.Conv3d(
in_channels=filters[-2],
out_channels=out_channels,
kernel_size=kernel_sizes[-1],
padding=kernel_sizes[-1] // 2
)

def _get_activation_class(self, name):
if name == "relu":
return nn.ReLU
elif name == "gelu":
return nn.GELU
elif name == "leakyrelu":
return nn.LeakyReLU
else:
raise ValueError(f"Unsupported activation: {name}")

def forward(self, x):
for layer in self.layers:
x = layer(x)
return self.final_conv(x)


# === Define a mock config using SimpleNamespace ===
def get_mock_config():
cfg = SimpleNamespace()
Expand All @@ -203,35 +127,6 @@ def get_mock_config():
return cfg


# === Mock config for testing ===
def get_mock_config_v2():
cfg = SimpleNamespace()
cfg.model = SimpleNamespace()
cfg.distance = 11
cfg.n_rounds = 3
cfg.model.dropout_p = 0.1
cfg.model.activation = 'relu'
cfg.model.input_channels = 4
cfg.model.out_channels = 2
cfg.model.num_filters = [8, 16, 16, 8, 8, 8, 4, 2] # (len - 2) % 3 == 0
cfg.model.kernel_size = [3] * len(cfg.model.num_filters)
return cfg


# === Test ===
def test_model_v2():
cfg = get_mock_config_v2()
model = PreDecoderModelMemory_v2(cfg)

B, C_in, T, D = 2, cfg.model.input_channels, cfg.n_rounds, cfg.distance
x = torch.randn(B, C_in, T, D, D)
out = model(x)

expected_shape = (B, cfg.model.out_channels, T, D, D)
assert out.shape == expected_shape, f"❌ Output shape mismatch: expected {expected_shape}, got {out.shape}"
print("✅ Model v2 test passed. Output shape:", out.shape)


# === Run the test ===
def test_model():
cfg = get_mock_config()
Expand All @@ -251,4 +146,3 @@ def test_model():

if __name__ == "__main__":
test_model()
test_model_v2()
13 changes: 1 addition & 12 deletions code/tests/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_generator_default_device_is_cuda(self):
# ---------------------------------------------------------------------------
@_require_cuda
class TestPreDecoderModelGPU(unittest.TestCase):
"""PreDecoderModelMemory v1 and v2 forward pass on CUDA."""
"""PreDecoderModelMemory v1 forward pass on CUDA."""

def setUp(self):
self.device = torch.device("cuda")
Expand All @@ -252,17 +252,6 @@ def test_v1_forward_on_cuda(self):
self.assertEqual(out.device.type, "cuda")
self.assertEqual(out.shape, (B, cfg.model.out_channels, T, D, D))

def test_v2_forward_on_cuda(self):
from model.predecoder import PreDecoderModelMemory_v2, get_mock_config_v2

cfg = get_mock_config_v2()
model = PreDecoderModelMemory_v2(cfg).to(self.device)
B, C, T, D = 4, cfg.model.input_channels, cfg.n_rounds, cfg.distance
x = torch.randn(B, C, T, D, D, device=self.device)
out = model(x)
self.assertEqual(out.device.type, "cuda")
self.assertEqual(out.shape, (B, cfg.model.out_channels, T, D, D))

def test_v1_gradient_flow_on_cuda(self):
"""Verify gradients propagate through the model on GPU."""
from model.predecoder import PreDecoderModelMemory_v1, get_mock_config
Expand Down
11 changes: 1 addition & 10 deletions code/tests/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
sys.path.insert(0, str(_repo_code))

from model.factory import ModelFactory
from model.predecoder import get_mock_config, get_mock_config_v2
from model.predecoder import get_mock_config


class TestModelFactory(unittest.TestCase):
Expand All @@ -44,12 +44,3 @@ def test_create_surface_model_v1(self):
self.assertIsNotNone(model)
self.assertEqual(model.distance, cfg.distance)
self.assertEqual(model.n_rounds, cfg.n_rounds)

def test_create_surface_model_v2(self):
cfg = get_mock_config_v2()
cfg.code = "surface"
cfg.model.version = "predecoder_memory_v2"
model = ModelFactory.create_model(cfg)
self.assertIsNotNone(model)
self.assertEqual(model.distance, cfg.distance)
self.assertEqual(model.n_rounds, cfg.n_rounds)
15 changes: 1 addition & 14 deletions code/tests/test_predecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Tests for model/predecoder: forward pass shape (v1 and v2). Catches breakage from architecture/config changes."""
"""Tests for model/predecoder: forward pass shape (v1). Catches breakage from architecture/config changes."""

import unittest
from pathlib import Path
Expand All @@ -18,9 +18,7 @@

from model.predecoder import (
PreDecoderModelMemory_v1,
PreDecoderModelMemory_v2,
get_mock_config,
get_mock_config_v2,
)


Expand All @@ -35,16 +33,5 @@ def test_forward_shape(self):
self.assertEqual(out.shape, (B, cfg.model.out_channels, T, D, D))


class TestPreDecoderModelMemoryV2(unittest.TestCase):

def test_forward_shape(self):
cfg = get_mock_config_v2()
model = PreDecoderModelMemory_v2(cfg)
B, C, T, D = 2, cfg.model.input_channels, cfg.n_rounds, cfg.distance
x = torch.randn(B, C, T, D, D)
out = model(x)
self.assertEqual(out.shape, (B, cfg.model.out_channels, T, D, D))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion conf/config_pre_decoder_memory_surface_model_1_d9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ data:
code_rotation: XV # Surface code orientation: XV, XH, ZV, ZH (default: XV)

model:
version: predecoder_memory_v1 # predecoder_memory_v1 (no skip conections and batch norm), predecoder_memory_v2 (includes skip connections and batch norm)
version: predecoder_memory_v1
dropout_p: 0.05
activation: gelu
num_filters: [128, 128, 128, 4]
Expand Down
Loading