File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1- """Shared test utilities for Metal kernel tests ."""
1+ """Pytest conftest — fixtures auto-injected by pytest ."""
22
33import pytest
4- import torch
54
6-
7- # ── Device detection ──
8-
9-
10- def get_device ():
11- """Get the best available compute device."""
12- if hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
13- return torch .device ("mps" )
14- if torch .cuda .is_available ():
15- return torch .device ("cuda" )
16- return torch .device ("cpu" )
17-
18-
19- def skip_if_no_metal ():
20- """Skip test if MPS device is not available."""
21- if get_device ().type != "mps" :
22- pytest .skip ("Metal kernel requires MPS device" )
23-
24-
25- # ── Tolerance helpers ──
26-
27- DEFAULT_TOLERANCES = {
28- torch .float32 : {"atol" : 1e-5 , "rtol" : 1e-5 },
29- torch .float16 : {"atol" : 1e-3 , "rtol" : 1e-3 },
30- torch .bfloat16 : {"atol" : 1e-2 , "rtol" : 1.6e-2 },
31- }
32-
33-
34- def get_tolerances (dtype ):
35- """Get atol/rtol for a given dtype."""
36- return DEFAULT_TOLERANCES .get (dtype , {"atol" : 0.1 , "rtol" : 0.1 })
37-
38-
39- # ── Fixtures ──
5+ from helpers import get_device
406
417
428@pytest .fixture
Original file line number Diff line number Diff line change 1+ """Shared test helpers for Metal kernel tests."""
2+
3+ import pytest
4+ import torch
5+
6+
7+ def get_device ():
8+ """Get the best available compute device."""
9+ if hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
10+ return torch .device ("mps" )
11+ if torch .cuda .is_available ():
12+ return torch .device ("cuda" )
13+ return torch .device ("cpu" )
14+
15+
16+ def skip_if_no_metal ():
17+ """Skip test if MPS device is not available."""
18+ if get_device ().type != "mps" :
19+ pytest .skip ("Metal kernel requires MPS device" )
20+
21+
22+ DEFAULT_TOLERANCES = {
23+ torch .float32 : {"atol" : 1e-5 , "rtol" : 1e-5 },
24+ torch .float16 : {"atol" : 1e-3 , "rtol" : 1e-3 },
25+ torch .bfloat16 : {"atol" : 1e-2 , "rtol" : 1.6e-2 },
26+ }
27+
28+
29+ def get_tolerances (dtype ):
30+ """Get atol/rtol for a given dtype."""
31+ return DEFAULT_TOLERANCES .get (dtype , {"atol" : 0.1 , "rtol" : 0.1 })
Original file line number Diff line number Diff line change 33import pytest
44import torch
55
6- from conftest import get_device , skip_if_no_metal
6+ from helpers import get_device , skip_if_no_metal
77
88
99# ── Pure PyTorch reference implementations ──
You can’t perform that action at this time.
0 commit comments