Skip to content

Commit 2bd0764

Browse files
committed
Fix test imports: move helpers to importable module
conftest.py is auto-loaded by pytest and can't be imported as a regular module. Move get_device/skip_if_no_metal to helpers.py and remove __init__.py so pytest adds tests/ to sys.path. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
1 parent a70d94c commit 2bd0764

4 files changed

Lines changed: 34 additions & 37 deletions

File tree

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,8 @@
1-
"""Shared test utilities for Metal kernel tests."""
1+
"""Pytest conftest — fixtures auto-injected by pytest."""
22

33
import pytest
4-
import torch
54

6-
7-
# ── Device detection ──
8-
9-
10-
def get_device():
11-
"""Get the best available compute device."""
12-
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
13-
return torch.device("mps")
14-
if torch.cuda.is_available():
15-
return torch.device("cuda")
16-
return torch.device("cpu")
17-
18-
19-
def skip_if_no_metal():
20-
"""Skip test if MPS device is not available."""
21-
if get_device().type != "mps":
22-
pytest.skip("Metal kernel requires MPS device")
23-
24-
25-
# ── Tolerance helpers ──
26-
27-
DEFAULT_TOLERANCES = {
28-
torch.float32: {"atol": 1e-5, "rtol": 1e-5},
29-
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
30-
torch.bfloat16: {"atol": 1e-2, "rtol": 1.6e-2},
31-
}
32-
33-
34-
def get_tolerances(dtype):
35-
"""Get atol/rtol for a given dtype."""
36-
return DEFAULT_TOLERANCES.get(dtype, {"atol": 0.1, "rtol": 0.1})
37-
38-
39-
# ── Fixtures ──
5+
from helpers import get_device
406

417

428
@pytest.fixture

tests/helpers.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Shared test helpers for Metal kernel tests."""
2+
3+
import pytest
4+
import torch
5+
6+
7+
def get_device():
8+
"""Get the best available compute device."""
9+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
10+
return torch.device("mps")
11+
if torch.cuda.is_available():
12+
return torch.device("cuda")
13+
return torch.device("cpu")
14+
15+
16+
def skip_if_no_metal():
17+
"""Skip test if MPS device is not available."""
18+
if get_device().type != "mps":
19+
pytest.skip("Metal kernel requires MPS device")
20+
21+
22+
DEFAULT_TOLERANCES = {
23+
torch.float32: {"atol": 1e-5, "rtol": 1e-5},
24+
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
25+
torch.bfloat16: {"atol": 1e-2, "rtol": 1.6e-2},
26+
}
27+
28+
29+
def get_tolerances(dtype):
30+
"""Get atol/rtol for a given dtype."""
31+
return DEFAULT_TOLERANCES.get(dtype, {"atol": 0.1, "rtol": 0.1})

tests/test_dequant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55

6-
from conftest import get_device, skip_if_no_metal
6+
from helpers import get_device, skip_if_no_metal
77

88

99
# ── Pure PyTorch reference implementations ──

0 commit comments

Comments
 (0)