-
Notifications
You must be signed in to change notification settings - Fork 144
consolidate accuracy check APIs #1910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -213,8 +213,22 @@ class BenchmarkResult(NamedTuple): | |
| } | ||
|
|
||
|
|
||
| def _assert_close(actual: object, expected: object, atol: float, rtol: float) -> None: | ||
| """Like torch.testing.assert_close but handles fp8 and uses chunked comparison for large tensors.""" | ||
| def _assert_close( | ||
| actual: object, | ||
| expected: object, | ||
| *, | ||
| atol: float = 1e-4, | ||
| rtol: float = 1e-4, | ||
| max_mismatch_pct: float | None = None, | ||
| max_abs_diff: float | None = None, | ||
| max_rel_diff: float | None = None, | ||
| ) -> None: | ||
| """Like torch.testing.assert_close but handles fp8, pytree structures, and strings. | ||
|
|
||
| For tensors, uses chunked comparison for large tensors. When | ||
| *max_mismatch_pct* is set, falls back to a relaxed mismatch-fraction check | ||
| instead of raising immediately on the first out-of-tolerance element. | ||
| """ | ||
|
|
||
| def convert(t: torch.Tensor) -> torch.Tensor: | ||
| return t.view(torch.uint8) if t.dtype in _FP8_DTYPES else t | ||
|
|
@@ -235,7 +249,35 @@ def convert(t: torch.Tensor) -> torch.Tensor: | |
|
|
||
| for a, e in zip(actual_flat, expected_flat, strict=True): | ||
| if isinstance(a, torch.Tensor): | ||
| _chunked_assert_close(a, e, atol=atol, rtol=rtol) | ||
| if max_mismatch_pct is not None: | ||
| try: | ||
| _chunked_assert_close(a, e, atol=atol, rtol=rtol) | ||
| continue | ||
| except AssertionError: | ||
| pass | ||
| abs_diff = (a - e).abs() | ||
| total = a.numel() | ||
| mismatched = (abs_diff > atol + rtol * e.abs()).sum().item() | ||
| mismatch_pct = mismatched / total if total > 0 else 0.0 | ||
| if mismatch_pct > max_mismatch_pct: | ||
| raise AssertionError( | ||
| f"Too many mismatches: {mismatch_pct:.4%} > {max_mismatch_pct:.4%}" | ||
| ) | ||
|
Comment on lines
+260
to
+265
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't it semantically different to do this in a chunked way? You are measuring the mismatch percentage on a chunk, not on the entire tensor.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Chunking is handled inside `_chunked_assert_close`. |
||
| if max_abs_diff is not None: | ||
| worst_abs = abs_diff.max().item() | ||
| if worst_abs > max_abs_diff: | ||
| raise AssertionError( | ||
| f"Absolute diff too large: {worst_abs} > {max_abs_diff}" | ||
| ) | ||
| if max_rel_diff is not None: | ||
| rel_diff = abs_diff / e.abs().clamp(min=1e-6) | ||
| worst_rel = rel_diff.max().item() | ||
| if worst_rel > max_rel_diff: | ||
| raise AssertionError( | ||
| f"Relative diff too large: {worst_rel:.2f} > {max_rel_diff}" | ||
| ) | ||
| else: | ||
| _chunked_assert_close(a, e, atol=atol, rtol=rtol) | ||
| elif isinstance(a, str): | ||
| if not isinstance(e, str): | ||
| raise AssertionError(f"Type mismatch {a} vs {e}") | ||
|
|
@@ -584,30 +626,17 @@ def _validate_against_baseline( | |
| self, config: Config, output: object, args: Sequence[object] | ||
| ) -> bool: | ||
| try: | ||
| custom_check = self.settings.autotune_baseline_accuracy_check_fn | ||
| if custom_check is not None: | ||
| custom_check(output, self._baseline_output) | ||
| if len(self._mutated_arg_indices) > 0: | ||
| custom_check(args, self._baseline_post_args) | ||
| else: | ||
| _assert_close( | ||
| output, | ||
| self._baseline_output, | ||
| atol=self._effective_atol, | ||
| rtol=self._effective_rtol, | ||
| check_fn = ( | ||
| self.settings.autotune_baseline_accuracy_check_fn | ||
| or functools.partial( | ||
| _assert_close, atol=self._effective_atol, rtol=self._effective_rtol | ||
| ) | ||
| if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1": | ||
| if len(self._mutated_arg_indices) > 0: | ||
| # For distributed kernel, group_name may also be a argument. | ||
| # torch.testing.assert_close does not handle str argument. | ||
| # Filter needed. | ||
| assert self._baseline_post_args is not None | ||
| _assert_close( | ||
| args, | ||
| self._baseline_post_args, | ||
| atol=self._effective_atol, | ||
| rtol=self._effective_rtol, | ||
| ) | ||
| ) | ||
| check_fn(output, self._baseline_output) | ||
| if os.getenv("CHECK_INPUT_ACCURACY", "1") == "1": | ||
| if len(self._mutated_arg_indices) > 0: | ||
| assert self._baseline_post_args is not None | ||
| check_fn(args, self._baseline_post_args) | ||
| except AssertionError as e: | ||
| if not self.settings.autotune_ignore_errors: | ||
| self.log.warning( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this different from atol/rtol?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't use these two arguments; I just copied them over from the existing code. @yf225, is there any usage of them?