Skip to content

Commit 9399bdc

Browse files
add test (#6)
1 parent 1c6833b commit 9399bdc

5 files changed

Lines changed: 91 additions & 5 deletions

File tree

.pre-commit-config.yaml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ repos:
2020
rev: v3.2.0
2121
hooks:
2222
- id: add-trailing-comma
23-
- repo: https://github.com/PyCQA/flake8
24-
rev: 7.2.0
25-
hooks:
26-
- id: flake8
27-
args: ["--ignore=E501,W503,E203"]
2823
- repo: https://github.com/ibm/detect-secrets
2924
rev: 0.13.1+ibm.62.dss
3025
hooks:

tests/create_reference.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# save_reference_output.py
2+
import torch
3+
import numpy as np
4+
5+
torch.manual_seed(0)
6+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7+
8+
num_nodes = 1
9+
x_dim = 9
10+
pe_dim = 20
11+
edge_attr_dim = 2
12+
13+
# Dummy all-zero input
14+
x = torch.zeros((num_nodes, x_dim)).to(device)
15+
pe = torch.zeros((num_nodes, pe_dim)).to(device)
16+
edge_index = torch.tensor([[0], [0]]).to(device)
17+
edge_attr = torch.zeros((1, edge_attr_dim)).to(device)
18+
batch = torch.zeros(num_nodes, dtype=torch.long).to(device)
19+
20+
21+
models = {
22+
"v0_1_2": "../examples/models/GridFM_v0_1_2.pth",
23+
"v0_2_3": "../examples/models/GridFM_v0_2_3.pth",
24+
}
25+
26+
for version, path in models.items():
27+
print(f"Loading model {version}...")
28+
model = torch.load(path, weights_only=False, map_location=device).to(device)
29+
model.eval()
30+
31+
with torch.no_grad():
32+
output = model(x, pe, edge_index, edge_attr, batch)
33+
34+
out_path = f"./data/reference_output_{version}.npy"
35+
np.save(out_path, output.cpu().numpy())
36+
print(f"Saved output for {version} to {out_path}")
152 Bytes
Binary file not shown.
152 Bytes
Binary file not shown.

tests/test_model_outputs.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import numpy as np
3+
import pytest
4+
5+
# Device setup
6+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7+
8+
# Input shape config
9+
num_nodes = 1
10+
x_dim = 9
11+
pe_dim = 20
12+
edge_attr_dim = 2
13+
14+
# List of models and reference files to check
15+
models_to_test = [
16+
(
17+
"v0_1_2",
18+
"examples/models/GridFM_v0_1_2.pth",
19+
"tests/data/reference_output_v0_1_2.npy",
20+
),
21+
(
22+
"v0_2_3",
23+
"examples/models/GridFM_v0_2_3.pth",
24+
"tests/data/reference_output_v0_2_3.npy",
25+
),
26+
]
27+
28+
29+
@pytest.mark.parametrize("version, model_path, ref_output_path", models_to_test)
30+
def test_model_matches_reference(version, model_path, ref_output_path):
31+
torch.manual_seed(0)
32+
33+
# Prepare zero input
34+
x = torch.zeros((num_nodes, x_dim), device=device)
35+
pe = torch.zeros((num_nodes, pe_dim), device=device)
36+
edge_index = torch.tensor([[0], [0]], device=device)
37+
edge_attr = torch.zeros((1, edge_attr_dim), device=device)
38+
batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
39+
40+
# Load model
41+
model = torch.load(model_path, weights_only=False, map_location=device).to(device)
42+
model.eval()
43+
44+
# Get current output
45+
with torch.no_grad():
46+
output = model(x, pe, edge_index, edge_attr, batch).cpu().numpy()
47+
48+
# Load saved reference
49+
reference = np.load(ref_output_path)
50+
51+
# Exact match assertion
52+
assert np.allclose(output, reference, rtol=1e-5, atol=1e-6), (
53+
f"Model output for {version} does not match reference within tolerance.\n"
54+
f"Max absolute difference: {np.max(np.abs(output - reference))}"
55+
)

0 commit comments

Comments
 (0)