Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 3 additions & 80 deletions helion/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from .autotuner.benchmarking import synchronize_device
from .runtime.settings import _get_backend
from .runtime.settings import is_pallas_interpret
from helion.autotuner.base_search import (
_assert_close as assert_close_with_mismatch_tolerance,
)
from helion.autotuner.base_search import _clone_args

if _get_backend() == "pallas":
Expand Down Expand Up @@ -1618,83 +1621,3 @@ def capture_output(self) -> Generator[_OutputCapture, None, None]:
yield capture
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr


def assert_close_with_mismatch_tolerance(
actual: object,
expected: object,
*,
atol: float = 1e-4,
rtol: float = 1e-4,
max_mismatch_pct: float = 0.01,
max_abs_diff: float | None = None,
max_rel_diff: float | None = None,
) -> None:
"""Check that actual and expected are close, tolerating a small fraction of mismatches.

First tries ``torch.testing.assert_close`` with the given *atol*/*rtol*.
If that fails **and** both arguments are tensors, falls back to a relaxed
check using the same mismatch definition as ``torch.testing.assert_close``
(``|actual - expected| > atol + rtol * |expected|``):

- *max_mismatch_pct*: maximum allowed fraction of mismatched elements
(default 1%). Always checked.
- *max_abs_diff*: if not None, the greatest absolute difference across
all elements must not exceed this value.
- *max_rel_diff*: if not None, the greatest relative difference
(``|actual - expected| / |expected|``) must not exceed this value.

This is useful for kernels where most elements match but a tiny
fraction have large relative differences. Pass this function directly as
``autotune_baseline_accuracy_check_fn`` for the default thresholds, or use
``functools.partial`` to customize them::

from functools import partial
from helion._testing import assert_close_with_mismatch_tolerance

@helion.kernel(
autotune_baseline_accuracy_check_fn=partial(
assert_close_with_mismatch_tolerance,
max_mismatch_pct=0.05,
max_abs_diff=10.0,
max_rel_diff=15.0,
),
)
def my_kernel(...): ...
"""
try:
torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
return
except AssertionError:
if not (
isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor)
):
raise

abs_diff = (actual - expected).abs()
total = actual.numel()

# Use the same mismatch definition as torch.testing.assert_close:
# an element is mismatched when |actual - expected| > atol + rtol * |expected|
mismatched = (abs_diff > atol + rtol * expected.abs()).sum().item()
mismatch_pct = mismatched / total if total > 0 else 0.0

if mismatch_pct > max_mismatch_pct:
raise AssertionError(
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
)

if max_abs_diff is not None:
worst_abs = abs_diff.max().item()
if worst_abs > max_abs_diff:
raise AssertionError(
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
)

if max_rel_diff is not None:
rel_diff = abs_diff / expected.abs().clamp(min=1e-6)
worst_rel = rel_diff.max().item()
if worst_rel > max_rel_diff:
raise AssertionError(
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
)
81 changes: 55 additions & 26 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,22 @@ class BenchmarkResult(NamedTuple):
}


def _assert_close(actual: object, expected: object, atol: float, rtol: float) -> None:
"""Like torch.testing.assert_close but handles fp8 and uses chunked comparison for large tensors."""
def _assert_close(
actual: object,
expected: object,
*,
atol: float = 1e-4,
rtol: float = 1e-4,
max_mismatch_pct: float | None = None,
max_abs_diff: float | None = None,
max_rel_diff: float | None = None,
Comment on lines +223 to +224
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this different from atol/rtol?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't use these 2 arguments and just copied over from the existing code. @yf225 is there any usage of them?

) -> None:
"""Like torch.testing.assert_close but handles fp8, pytree structures, and strings.

For tensors, uses chunked comparison for large tensors. When
*max_mismatch_pct* is set, falls back to a relaxed mismatch-fraction check
instead of raising immediately on the first out-of-tolerance element.
"""

def convert(t: torch.Tensor) -> torch.Tensor:
return t.view(torch.uint8) if t.dtype in _FP8_DTYPES else t
Expand All @@ -235,7 +249,35 @@ def convert(t: torch.Tensor) -> torch.Tensor:

for a, e in zip(actual_flat, expected_flat, strict=True):
if isinstance(a, torch.Tensor):
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
if max_mismatch_pct is not None:
try:
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
continue
except AssertionError:
pass
abs_diff = (a - e).abs()
total = a.numel()
mismatched = (abs_diff > atol + rtol * e.abs()).sum().item()
mismatch_pct = mismatched / total if total > 0 else 0.0
if mismatch_pct > max_mismatch_pct:
raise AssertionError(
f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}"
)
Comment on lines +260 to +265
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it semantically different to do this in a chunked way? You are measuring percentage different on a chunk not on the entire tensor.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chunk is handled inside _chunked_assert_close. This part of code operates on the entire tensor though

if max_abs_diff is not None:
worst_abs = abs_diff.max().item()
if worst_abs > max_abs_diff:
raise AssertionError(
f"Absolute diff too large: {worst_abs} > {max_abs_diff}"
)
if max_rel_diff is not None:
rel_diff = abs_diff / e.abs().clamp(min=1e-6)
worst_rel = rel_diff.max().item()
if worst_rel > max_rel_diff:
raise AssertionError(
f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}"
)
else:
_chunked_assert_close(a, e, atol=atol, rtol=rtol)
elif isinstance(a, str):
if not isinstance(e, str):
raise AssertionError(f"Type mismatch {a} vs {e}")
Expand Down Expand Up @@ -584,30 +626,17 @@ def _validate_against_baseline(
self, config: Config, output: object, args: Sequence[object]
) -> bool:
try:
custom_check = self.settings.autotune_baseline_accuracy_check_fn
if custom_check is not None:
custom_check(output, self._baseline_output)
if len(self._mutated_arg_indices) > 0:
custom_check(args, self._baseline_post_args)
else:
_assert_close(
output,
self._baseline_output,
atol=self._effective_atol,
rtol=self._effective_rtol,
check_fn = (
self.settings.autotune_baseline_accuracy_check_fn
or functools.partial(
_assert_close, atol=self._effective_atol, rtol=self._effective_rtol
)
if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1":
if len(self._mutated_arg_indices) > 0:
# For distributed kernel, group_name may also be a argument.
# torch.testing.assert_close does not handle str argument.
# Filter needed.
assert self._baseline_post_args is not None
_assert_close(
args,
self._baseline_post_args,
atol=self._effective_atol,
rtol=self._effective_rtol,
)
)
check_fn(output, self._baseline_output)
if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1":
if len(self._mutated_arg_indices) > 0:
assert self._baseline_post_args is not None
check_fn(args, self._baseline_post_args)
except AssertionError as e:
if not self.settings.autotune_ignore_errors:
self.log.warning(
Expand Down
Loading