diff --git a/tests/_comparison.py b/tests/_comparison.py new file mode 100644 index 0000000..66b8bc2 --- /dev/null +++ b/tests/_comparison.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Comparison helpers for tests. + +``assert_equal`` is :func:`torch.testing.assert_close` with ``atol=rtol=0``: it asserts that two +tensors are exactly equal in value (so ``-0.0 == 0.0`` passes and ``NaN`` never matches ``NaN``). +Use it for exact-match checks instead of repeating ``atol=0, rtol=0``. +""" + +import functools + +import torch +from torch import testing as torch_testing + + +assert_equal = functools.partial(torch_testing.assert_close, rtol=0, atol=0) + + +def assert_close_to_identity(actual, *, off_diag_atol=0, diag_atol=0): + r"""Assert that ``actual`` is close to the identity matrix. + + Checks the identity structure with separate tolerances for the diagonal (compared to ones) and the + off-diagonal entries (compared to zeros). This is more informative, and allows looser tolerances on + the off-diagonal, than comparing the whole matrix to ``torch.eye`` with a single tolerance. + + Args: + actual: A square 2D tensor expected to be (approximately) the identity matrix. + off_diag_atol: Absolute tolerance for the off-diagonal entries (compared to 0). + diag_atol: Absolute tolerance for the diagonal entries (compared to 1). 
+ + Raises: + ValueError: If ``actual`` is not a square 2D matrix. + AssertionError: If ``actual`` is not close to the identity matrix. + """ + if actual.ndim != 2 or actual.shape[0] != actual.shape[1]: + raise ValueError(f"actual must be a square 2D matrix, got shape {tuple(actual.shape)}") + + n = actual.shape[-1] + off_diag_mask = ~torch.eye(n, dtype=torch.bool, device=actual.device) + diag = torch.diagonal(actual) + off_diag = actual[off_diag_mask] + torch_testing.assert_close( + diag, + torch.ones_like(diag), + atol=diag_atol, + rtol=0, + msg=lambda msg: f"Diagonal is not close to ones.\n\n{msg}", + ) + torch_testing.assert_close( + off_diag, + torch.zeros_like(off_diag), + atol=off_diag_atol, + rtol=0, + msg=lambda msg: f"Off-diagonal is not close to zeros.\n\n{msg}", + ) + + +def assert_close_to_orthogonal(actual, *, off_diag_atol=0, diag_atol=0): + r"""Assert that a 2D matrix is (semi-)orthogonal. + + Builds the Gram matrix over the smaller dimension (``X @ Xᵀ`` when ``X`` has no more rows than + columns, otherwise ``Xᵀ @ X``) and asserts it is close to the identity via + :func:`assert_close_to_identity`. + + Args: + actual: A 2D tensor expected to have (semi-)orthonormal rows or columns. + off_diag_atol: Absolute tolerance for the off-diagonal entries of the Gram matrix. + diag_atol: Absolute tolerance for the diagonal entries of the Gram matrix. + + Raises: + ValueError: If ``actual`` is not a 2D matrix. + AssertionError: If ``actual`` is not (semi-)orthogonal. 
+ """ + if actual.ndim != 2: + raise ValueError(f"actual must be a 2D matrix, got shape {tuple(actual.shape)}") + + m, n = actual.shape + gram = actual @ actual.mT if m <= n else actual.mT @ actual + assert_close_to_identity(gram, off_diag_atol=off_diag_atol, diag_atol=diag_atol) diff --git a/tests/test_distributed_muon_utils_cpu.py b/tests/test_distributed_muon_utils_cpu.py index 7a0d8c1..03a3094 100644 --- a/tests/test_distributed_muon_utils_cpu.py +++ b/tests/test_distributed_muon_utils_cpu.py @@ -17,6 +17,7 @@ import numpy as np import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -195,7 +196,7 @@ def test_fall_back_to_non_tp(self, shape): ) ref_out = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic") - torch.testing.assert_close(test_out, ref_out, atol=0, rtol=0) + assert_equal(test_out, ref_out) @parameterized.product( shape=((20, 16), (16, 32)), diff --git a/tests/test_distributed_rekls_cpu.py b/tests/test_distributed_rekls_cpu.py index 9d9eb80..3c3913a 100644 --- a/tests/test_distributed_rekls_cpu.py +++ b/tests/test_distributed_rekls_cpu.py @@ -16,6 +16,7 @@ import sys import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -122,11 +123,9 @@ def test_5steps_matches_non_distributed_rekls(self): ref_local = ref_p.detach() else: ref_local = ref_p.detach().chunk(self.world_size, dim=pd)[self.rank] - torch.testing.assert_close( + assert_equal( tp_p.detach(), ref_local, - atol=0, - rtol=0, ) diff --git a/tests/test_distributed_soap_utils_cpu.py b/tests/test_distributed_soap_utils_cpu.py index 55b1dc3..0370d29 100644 --- a/tests/test_distributed_soap_utils_cpu.py +++ b/tests/test_distributed_soap_utils_cpu.py @@ -16,6 +16,7 @@ import sys import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -75,9 +76,9 @@ def 
test_matches_non_distributed(self, shape, partition_dim): tp_group=self.tp_group, ) - torch.testing.assert_close(gathered_grad, full_grad, atol=0, rtol=0) - torch.testing.assert_close(gathered_factors[0], full_l, atol=0, rtol=0) - torch.testing.assert_close(gathered_factors[1], full_r, atol=0, rtol=0) + assert_equal(gathered_grad, full_grad) + assert_equal(gathered_factors[0], full_l) + assert_equal(gathered_factors[1], full_r) @parameterized.product( shape=((16, 32), (32, 16), (96, 200)), @@ -113,8 +114,8 @@ def test_updated_factors_match_non_distributed(self, shape, partition_dim): ) soap.update_kronecker_factors(gathered_factors, gathered_grad, shampoo_beta=shampoo_beta) - torch.testing.assert_close(gathered_factors[0], ref_l, atol=0, rtol=0) - torch.testing.assert_close(gathered_factors[1], ref_r, atol=0, rtol=0) + assert_equal(gathered_factors[0], ref_l) + assert_equal(gathered_factors[1], ref_r) if __name__ == "__main__": diff --git a/tests/test_eig_utils.py b/tests/test_eig_utils.py index 49d0e6c..2b1ae5e 100644 --- a/tests/test_eig_utils.py +++ b/tests/test_eig_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -145,7 +146,7 @@ def test_conjugate_match_reference(self) -> None: _, p = torch.linalg.eigh(a) ref = p.T @ a @ p - torch.testing.assert_close(eig_utils.conjugate(a, p), ref, atol=0, rtol=0) + assert_equal(eig_utils.conjugate(a, p), ref) def test_eigh_with_fallback_reraises_runtime_error_when_force_double(self) -> None: """Test that eigh_with_fallback re-raises when force_double=True and eigh fails.""" diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index dbc56e6..38585c1 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -16,6 +16,7 @@ from copy import deepcopy import torch +from _comparison import assert_close_to_orthogonal, assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -83,7 +84,7 @@ def tearDown(self): (512, 256), (256, 512), ) - def test_newtonschulz5_svd_close(self, dim1, dim2): + def test_newtonschulz5_close_to_svd(self, dim1, dim2): shape = (dim1, dim2) x = torch.randn(*shape, device=self.device, dtype=torch.float32) out_zeropowerns = muon_utils.newton_schulz(x, steps=5, coefficient_type="quintic") @@ -262,6 +263,18 @@ def test_polar_express_and_deepseekv4_10steps_better_than_quintic(self, size, co f"{coefficient_type} norm is larger than Quintic norm: {l2_norm_diff_polar:.6f} > {l2_norm_diff_quintic:.6f}", ) + @parameterized.product(size=[(512, 256), (256, 512)]) + def test_polar_express_16steps_almost_orthogonal(self, size): + """Polar Express Newton-Schulz with enough steps yields an almost-orthogonal matrix. + + The output ``O`` should satisfy ``O Oᵀ ≈ I`` (or ``Oᵀ O ≈ I`` for the tall case), i.e. its + smaller-dimension Gram is close to the identity. 
+ """ + dim1, dim2 = size + x = torch.randn(dim1, dim2, device=self.device, dtype=torch.float32) + out = muon_utils.newton_schulz(x, steps=16, coefficient_type="polar_express") + assert_close_to_orthogonal(out, diag_atol=1e-5, off_diag_atol=1e-5) + @parameterized.parameters( (511, 513), (511, 257), @@ -461,7 +474,7 @@ def test_match_newton_schulz_step_by_gemm(self, dim1, dim2): test_out = muon_utils.newton_schulz_step_tsyrk(x, 2**-1, 2**-2, 2**-3) test_ref = muon_utils.newton_schulz_step(x, 2**-1, 2**-2, 2**-3) - torch.testing.assert_close(test_out, test_ref, atol=0, rtol=0) + assert_equal(test_out, test_ref) if __name__ == "__main__": diff --git a/tests/test_normalized_optimizer.py b/tests/test_normalized_optimizer.py index bfc8d95..77868f6 100644 --- a/tests/test_normalized_optimizer.py +++ b/tests/test_normalized_optimizer.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -164,22 +165,18 @@ def test_oblique_sgd_momentum_buffer_accumulates_across_steps(self) -> None: param.grad = first_grad.clone() optimizer.step() - torch.testing.assert_close( + assert_equal( optimizer.state[param]["momentum_buffer"], first_grad, - atol=0, - rtol=0, ) param.grad = second_grad.clone() optimizer.step() expected_buffer = second_grad + 0.8 * first_grad - torch.testing.assert_close( + assert_equal( optimizer.state[param]["momentum_buffer"], expected_buffer, - atol=0, - rtol=0, ) def test_oblique_adam_zero_gradient(self) -> None: diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 9453be8..4bb719f 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -14,6 +14,7 @@ # limitations under the License. 
import torch import torch.nn as nn +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -94,11 +95,9 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None: orthogonalized_opt.step() sgd_opt.step() - torch.testing.assert_close( + assert_equal( test_param.data, ref_param.data, - atol=0, - rtol=0, ) @parameterized.parameters( @@ -139,11 +138,9 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> orthogonalized_opt.step() sgd_opt.step() - torch.testing.assert_close( + assert_equal( test_param.data, ref_param.data, - atol=0, - rtol=0, ) @parameterized.parameters( @@ -206,11 +203,9 @@ def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor: assert not torch.allclose(test_param, test_param.grad) ref_out = dummy_interleaved_split_orth_fn(test_param.grad) - torch.testing.assert_close( + assert_equal( test_param, ref_out, - atol=0, - rtol=0, ) def test_non_2d_param_raises_value_error(self) -> None: @@ -292,11 +287,9 @@ def test_use_independent_wd(self) -> None: ) muon_opt_indep.step() - torch.testing.assert_close( + assert_equal( test_param, expected_param, - atol=0, - rtol=0, ) def test_zero_num_ns_steps_raises_value_error(self) -> None: diff --git a/tests/test_polargrad.py b/tests/test_polargrad.py index 9bde161..56dead5 100644 --- a/tests/test_polargrad.py +++ b/tests/test_polargrad.py @@ -16,6 +16,7 @@ import torch import torch.nn as nn +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -74,11 +75,9 @@ def test_orthogonalize_fn_matches_ref(self, shape, extra_scale_factor) -> None: test_out = polargrad_opt.scaled_orthogonalize_fn(dummy_grad) - torch.testing.assert_close( + assert_equal( ref_out, test_out, - atol=0, - rtol=0, ) def test_negative_num_ns_steps_raises_value_error(self) -> None: @@ -104,8 +103,8 @@ def test_right_orthogonal_equivariance(self, shape, 
center_rows) -> None: torch.testing.assert_close( rotated, expected, - atol=1e-5, - rtol=1e-5, + atol=1e-4, + rtol=1e-4, ) def test_usable_as_scaled_orthogonalize_fn(self) -> None: diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index b6b6dd4..b2eaa08 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from _comparison import assert_equal from absl import flags, logging, testing from absl.testing import parameterized @@ -282,7 +283,7 @@ def test_calculate_signum_update_returns_sign(self, shape, momentum, correct_bia grad, exp_avg, momentum=momentum, correct_bias=correct_bias, nesterov=nesterov, step=step ) - torch.testing.assert_close(update.abs(), torch.ones(shape, device=self.device), atol=0, rtol=0) + assert_equal(update.abs(), torch.ones(shape, device=self.device)) def test_calculate_signum_with_shape_scaling_returns_sign(self) -> None: shape = (8, 12) @@ -300,7 +301,7 @@ def test_calculate_signum_with_shape_scaling_returns_sign(self) -> None: use_shape_scaling=True, ).abs() expected_update = torch.sign(exp_avg).abs() * (2 / (shape[0] + shape[1])) - torch.testing.assert_close(update_abs, expected_update, atol=0, rtol=0) + assert_equal(update_abs, expected_update) def test_calculate_lion_update_returns_sign(self) -> None: """Tests that Lion update returns sign of interpolated momentum.""" @@ -314,7 +315,7 @@ def test_calculate_lion_update_returns_sign(self) -> None: # Update should be sign(beta * m + (1 - beta) * g) expected_update = torch.sign(beta * exp_avg_clone + (1 - beta) * grad) - torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + assert_equal(update, expected_update) # exp_avg should be updated in-place: lerp_(grad, 1 - beta) expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - beta) @@ -331,7 +332,7 @@ def 
test_calculate_lion_update_with_separate_betas(self) -> None: update = update_functions.calculate_lion_update(grad, exp_avg, betas=(beta1, beta2), step=1) expected_update = torch.sign(beta1 * exp_avg_clone + (1 - beta1) * grad) - torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + assert_equal(update, expected_update) # With separate beta2, momentum uses beta2 expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - beta2) @@ -403,21 +404,17 @@ def test_calculate_madam_update_matches_adam_eps_zero(self, scale_log2: int, cor ) case = f"scale_log2={scale_log2}, correct_bias={correct_bias}" - torch.testing.assert_close( + assert_equal( madam_update, adam_update, - atol=0, - rtol=0, msg=lambda msg: f"MAdam vs Adam(eps=0) mismatch at {case}:\n\n{msg}", ) # First-moment EMA is unaffected by scaling. - torch.testing.assert_close(exp_avg_madam, exp_avg_adam, atol=0, rtol=0) + assert_equal(exp_avg_madam, exp_avg_adam) # Second-moment storage differs by exactly the (power-of-two) scale. - torch.testing.assert_close( + assert_equal( exp_avg_sq_scaled, exp_avg_sq_adam * (2**scale_log2), - atol=0, - rtol=0, msg=lambda msg: f"exp_avg_sq_scaled != s * exp_avg_sq at {case}:\n\n{msg}", ) @@ -446,11 +443,9 @@ def test_calculate_madam_update_5steps_zero_masked_is_finite(self) -> None: ) # Masked column: exactly zero (not NaN/Inf). - torch.testing.assert_close( + assert_equal( update[:, zero_col], torch.zeros(shape[0], device=self.device, dtype=update.dtype), - atol=0, - rtol=0, ) # Non-masked columns: finite and non-zero. 
nonzero_col_idx = [c for c in range(shape[1]) if c != zero_col] @@ -487,7 +482,7 @@ def test_no_grad_no_update_params_unchanged(self) -> None: original = param.detach().clone() opt = self.OPTIMIZER_CLS([param], lr=1e-4) opt.step() - torch.testing.assert_close(param, original, atol=0, rtol=0) + assert_equal(param, original) def test_state_keys_after_first_step(self) -> None: """First step populates exactly the expected state keys, with step==1 and matching-shape buffers.""" @@ -598,7 +593,7 @@ def test_update_is_sign_based(self, betas, shape) -> None: old_param = param.data.clone() opt.step() diff = old_param - param.data - torch.testing.assert_close(diff.abs(), torch.full_like(diff, 0.25), atol=0, rtol=0) + assert_equal(diff.abs(), torch.full_like(diff, 0.25)) @parameterized.parameters( {"shape": (3, 3)}, @@ -633,7 +628,7 @@ def test_weight_decay_decoupled_matches_analytical(self, shape) -> None: param.grad = torch.zeros(*shape, device=self.device) old_param = param.data.clone() opt.step() - torch.testing.assert_close(param.data, old_param * (1 - lr * wd), atol=0, rtol=0) + assert_equal(param.data, old_param * (1 - lr * wd)) @parameterized.parameters( {"shape": (3, 3)}, @@ -646,7 +641,7 @@ def test_weight_decay_independent_matches_analytical(self, shape) -> None: param.grad = torch.zeros(*shape, device=self.device) old_param = param.data.clone() opt.step() - torch.testing.assert_close(param.data, old_param * (1 - wd), atol=0, rtol=0) + assert_equal(param.data, old_param * (1 - wd)) @parameterized.parameters( {"shape": (3, 3)}, @@ -660,7 +655,7 @@ def test_weight_decay_l2(self, shape) -> None: param.grad = torch.zeros(*shape, device=self.device) old_param = param.data.clone() opt.step() - torch.testing.assert_close(param.data, old_param - lr, atol=0, rtol=0) + assert_equal(param.data, old_param - lr) @parameterized.parameters( {"shape": (3, 3)}, @@ -674,7 +669,7 @@ def test_weight_decay_l2_masked_by_gradient(self, shape) -> None: param.grad = torch.randint(-10, -5, 
shape, device=self.device, dtype=torch.float32) old_param = param.data.clone() opt.step() - torch.testing.assert_close(param.data, old_param + lr, atol=0, rtol=0) + assert_equal(param.data, old_param + lr) class SignumOptimizerTest(_CommonScalarOptimizerTests, parameterized.TestCase): @@ -708,7 +703,7 @@ def test_update_is_sign_based(self, shape) -> None: old_param = param.data.clone() opt.step() # bias_correction at step 1 cancels the (1-momentum) factor, so sign(corrected) = sign(grad) = +1. - torch.testing.assert_close(old_param - param.data, torch.full(shape, 0.25, device=self.device), atol=0, rtol=0) + assert_equal(old_param - param.data, torch.full(shape, 0.25, device=self.device)) class LaPropOptimizerTest(_CommonScalarOptimizerTests, _HasBetasTests, _HasEpsTests, parameterized.TestCase): @@ -741,8 +736,8 @@ def test_state_evolves_correctly(self, shape) -> None: expected_exp_avg_sq = (1 - beta2) * grad.square() normalized_grad = grad / (grad.abs() + opt.param_groups[0]["eps"]) expected_exp_avg = (1 - beta1) * normalized_grad - torch.testing.assert_close(opt.state[param]["exp_avg_sq"], expected_exp_avg_sq, atol=0, rtol=0) - torch.testing.assert_close(opt.state[param]["exp_avg"], expected_exp_avg, atol=0, rtol=0) + assert_equal(opt.state[param]["exp_avg_sq"], expected_exp_avg_sq) + assert_equal(opt.state[param]["exp_avg"], expected_exp_avg) @parameterized.parameters( {"shape": (3, 3)}, @@ -765,9 +760,9 @@ def test_optimizer_step_matches_update_function(self, shape) -> None: ) param.grad = grad.clone() opt.step() - torch.testing.assert_close(param, old_param - lr * expected_update, atol=0, rtol=0) - torch.testing.assert_close(opt.state[param]["exp_avg"], exp_avg, atol=0, rtol=0) - torch.testing.assert_close(opt.state[param]["exp_avg_sq"], exp_avg_sq, atol=0, rtol=0) + assert_equal(param, old_param - lr * expected_update) + assert_equal(opt.state[param]["exp_avg"], exp_avg) + assert_equal(opt.state[param]["exp_avg_sq"], exp_avg_sq) @parameterized.parameters( 
{"shape": (3, 3)}, @@ -839,9 +834,9 @@ def test_optimizer_step_matches_update_function(self, shape) -> None: ) param.grad = grad.clone() opt.step() - torch.testing.assert_close(param, old_param - lr * expected_update, atol=0, rtol=0) - torch.testing.assert_close(opt.state[param]["exp_avg"], exp_avg, atol=0, rtol=0) - torch.testing.assert_close(opt.state[param]["exp_avg_sq"], exp_avg_sq, atol=0, rtol=0) + assert_equal(param, old_param - lr * expected_update) + assert_equal(opt.state[param]["exp_avg"], exp_avg) + assert_equal(opt.state[param]["exp_avg_sq"], exp_avg_sq) if __name__ == "__main__": diff --git a/tests/test_soap.py b/tests/test_soap.py index 4ed8f91..e3312c4 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -17,6 +17,7 @@ import soap_reference import torch +from _comparison import assert_close_to_identity, assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -248,14 +249,7 @@ def test_update_eigenbasis_and_exp_avgs(self, M: int, N: int, use_eigh: bool) -> # Check eigenbasis orthogonality for Q in updated_eigenbasis_list: - identity = torch.eye(Q.shape[0], device=Q.device, dtype=Q.dtype) - torch.testing.assert_close( - Q.T @ Q, - identity, - atol=1e-5, - rtol=1e-5, - msg="Updated eigenbasis is not orthogonal.", - ) + assert_close_to_identity(Q.T @ Q, diag_atol=1e-5, off_diag_atol=1e-5) # exp_avg is projected via orthogonal transforms, so norm should be preserved torch.testing.assert_close( @@ -440,11 +434,9 @@ def test_8streams_matches_no_streams(self, use_kl_shampoo: bool, use_eigh: bool) torch.cuda.synchronize() for i, (p_no, p_with) in enumerate(zip(params_no_stream, params_with_stream)): - torch.testing.assert_close( + assert_equal( p_with, p_no, - atol=0, - rtol=0, msg=lambda msg: f"Parameter {i} mismatch at step {step}:\n{msg}", ) diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index 15d1882..b5ffb14 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -13,6 
+13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from _comparison import assert_close_to_identity, assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -76,23 +77,8 @@ def test_get_eigenbasis_qr(self, N: int, M: int) -> None: self.assertEqual(Q_new_R.shape, (N, N)) # check Q^T Q ~ I - expected_identity = torch.eye(M, dtype=Q_new_L.dtype, device=Q_new_L.device) - torch.testing.assert_close( - Q_new_L.t() @ Q_new_L, - expected_identity, - atol=1e-5, - rtol=1e-5, - msg="Orthogonalization failed: Q^T Q is not close enough to the identity matrix.", - ) - - expected_identity = torch.eye(N, dtype=Q_new_R.dtype, device=Q_new_R.device) - torch.testing.assert_close( - Q_new_R.t() @ Q_new_R, - expected_identity, - atol=1e-5, - rtol=1e-5, - msg="Orthogonalization failed: Q^T Q is not close enough to the identity matrix.", - ) + assert_close_to_identity(Q_new_L.t() @ Q_new_L, diag_atol=1e-5, off_diag_atol=1e-5) + assert_close_to_identity(Q_new_R.t() @ Q_new_R, diag_atol=1e-5, off_diag_atol=1e-5) @parameterized.parameters( # type: ignore[misc] {"N": 4, "M": 8}, @@ -122,24 +108,20 @@ def test_sort_eigenbasis_by_approx_eigvals(self, N: int, M: int) -> None: # Each eigenbasis is column-permuted by its own sort_idx. for i, (Q_old, Q_sorted) in enumerate(zip(eigenbasis_list, sorted_eigenbasis_list, strict=True)): - torch.testing.assert_close( + assert_equal( Q_sorted, Q_old[:, sort_idx_list[i]], msg=lambda m, i=i: f"eigenbasis i={i} not permuted by sort_idx\n\n{m}", - atol=0, - rtol=0, ) # exp_avg_sq is permuted along every axis cumulatively. 
expected_sq = exp_avg_sq for i, sort_idx in enumerate(sort_idx_list): expected_sq = expected_sq.index_select(i, sort_idx) - torch.testing.assert_close( + assert_equal( sorted_exp_avg_sq, expected_sq, msg=lambda m: f"exp_avg_sq not permuted to match sorted eigenbases\n\n{m}", - atol=0, - rtol=0, ) # Sorted eigenbases yield descending approximate eigenvalues. @@ -168,16 +150,10 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: self.assertEqual(Q.shape, (orig_dim, orig_dim)) # Check orthogonality: Q.T @ Q should be close to identity due to orthonormal matrix property - identity = torch.eye(orig_dim, dtype=Q.dtype, device=self.device) with utils.fp32_matmul_precision("highest"): orthogonality_check = Q.T @ Q - torch.testing.assert_close( - orthogonality_check, - identity, - atol=1e-3, - rtol=1e-3, - ) + assert_close_to_identity(orthogonality_check, diag_atol=1e-3, off_diag_atol=1e-3) with utils.fp32_matmul_precision("highest"): # Check that Q diagonalizes the original matrix, by checking if off-diagonal elements are close to zero diagonalized_matrix = Q.T @ kronecker_factor_list[i].float() @ Q diff --git a/tests/test_spel.py b/tests/test_spel.py index be96341..b08eb94 100644 --- a/tests/test_spel.py +++ b/tests/test_spel.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +from _comparison import assert_close_to_identity from absl import flags, logging from absl.testing import absltest, parameterized @@ -85,23 +86,8 @@ def test_post_update_produces_approximately_orthogonal_weights(self, shape) -> N else: gram = W.mT @ W - diag = torch.diagonal(gram) - off_diag = gram[~torch.eye(gram.shape[0], dtype=torch.bool, device=FLAGS.device)] - - # Newton-Schulz produces an approximate orthogonal factor, so check - # the identity structure directly instead of using one tolerance for all entries. 
- torch.testing.assert_close( - diag, - torch.ones_like(diag), - atol=0.0, - rtol=0.06, - ) - torch.testing.assert_close( - off_diag, - torch.zeros_like(off_diag), - atol=0.06, - rtol=0.0, - ) + # Newton-Schulz produces an approximate orthogonal factor, so allow a loose tolerance. + assert_close_to_identity(gram, diag_atol=0.06, off_diag_atol=0.06) if __name__ == "__main__": diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index 2a13c21..285e56d 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -14,6 +14,7 @@ # limitations under the License. import torch import triton +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -253,7 +254,7 @@ def test_tsyrk_ex_match_matmul(self, n: int, k: int, trans: bool): # warmup the triton kernel to avoid the wrong result from the first run. _ = triton_kernels.tsyrk_ex(a_warmup) c = triton_kernels.tsyrk_ex(a) - torch.testing.assert_close(c, ref, atol=0, rtol=0) + assert_equal(c, ref) @parameterized.product( ({"n": 128, "k": 128}, {"n": 256, "k": 64}), @@ -269,7 +270,7 @@ def test_tsyrk_ex_small_matrix_match_matmul(self, n: int, k: int, trans: bool): # warmup the triton kernel to avoid the wrong result from the first run. _ = triton_kernels.tsyrk_ex_small_matrix(a_warmup) c = triton_kernels.tsyrk_ex_small_matrix(a) - torch.testing.assert_close(c, ref, atol=0, rtol=0) + assert_equal(c, ref) @parameterized.product( ({"n": 128, "alpha": 0.5, "beta": 0.5}, {"n": 256, "alpha": 0.25, "beta": 0.25}), @@ -287,7 +288,7 @@ def test_tsyrk_ex_match_addmm(self, n: int, alpha: float, beta: float, trans: bo # warmup the triton kernel to avoid the wrong result from the first run. 
_ = triton_kernels.tsyrk_ex(a_warmup, a_warmup, alpha=alpha, beta=beta) c = triton_kernels.tsyrk_ex(a, a, alpha=alpha, beta=beta) - torch.testing.assert_close(c, ref, atol=0, rtol=0) + assert_equal(c, ref) @parameterized.product( ({"n": 128, "alpha": 0.5, "beta": 0.5}, {"n": 256, "alpha": 0.25, "beta": 0.25}), @@ -305,7 +306,7 @@ def test_tsyrk_ex_small_matrix_match_addmm(self, n: int, alpha: float, beta: flo # warmup the triton kernel to avoid the wrong result from the first run. _ = triton_kernels.tsyrk_ex_small_matrix(a_warmup, a_warmup, alpha=alpha, beta=beta) c = triton_kernels.tsyrk_ex_small_matrix(a, a, alpha=alpha, beta=beta) - torch.testing.assert_close(c, ref, atol=0, rtol=0) + assert_equal(c, ref) if __name__ == "__main__": diff --git a/tests/test_utils_modules.py b/tests/test_utils_modules.py index 0059a37..053c8d3 100644 --- a/tests/test_utils_modules.py +++ b/tests/test_utils_modules.py @@ -15,6 +15,7 @@ import torch import torch.nn as nn +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -60,17 +61,15 @@ def test_matches_conv1d(self, in_channels, out_channels, kernel_size, bias, batc y_ref = conv(x) y_test = conv_flat(x) - torch.testing.assert_close(y_ref, y_test, atol=0, rtol=0) + assert_equal(y_ref, y_test) y_ref.sum().backward() y_test.sum().backward() if bias: - torch.testing.assert_close( - conv.weight.grad.view(-1), conv_flat.weight.grad[:, :-1].reshape(-1), atol=0, rtol=0 - ) - torch.testing.assert_close(conv.bias.grad, conv_flat.weight.grad[:, -1], atol=0, rtol=0) + assert_equal(conv.weight.grad.view(-1), conv_flat.weight.grad[:, :-1].reshape(-1)) + assert_equal(conv.bias.grad, conv_flat.weight.grad[:, -1]) else: - torch.testing.assert_close(conv.weight.grad.view(-1), conv_flat.weight.grad.reshape(-1), atol=0, rtol=0) + assert_equal(conv.weight.grad.view(-1), conv_flat.weight.grad.reshape(-1)) @parameterized.product( bias=[False, True], @@ -97,16 +96,14 @@ def 
test_from_conv1d(self, in_channels, out_channels, kernel_size, bias, batch_s x = torch.randn(batch_size, in_channels, kernel_size, device=self.device) y_ref = conv(x) y_test = conv_flat(x) - torch.testing.assert_close(y_ref, y_test, atol=0, rtol=0) + assert_equal(y_ref, y_test) y_ref.sum().backward() y_test.sum().backward() if bias: - torch.testing.assert_close( - conv.weight.grad.view(-1), conv_flat.weight.grad[:, :-1].reshape(-1), atol=0, rtol=0 - ) - torch.testing.assert_close(conv.bias.grad, conv_flat.weight.grad[:, -1], atol=0, rtol=0) + assert_equal(conv.weight.grad.view(-1), conv_flat.weight.grad[:, :-1].reshape(-1)) + assert_equal(conv.bias.grad, conv_flat.weight.grad[:, -1]) else: - torch.testing.assert_close(conv.weight.grad.view(-1), conv_flat.weight.grad.reshape(-1), atol=0, rtol=0) + assert_equal(conv.weight.grad.view(-1), conv_flat.weight.grad.reshape(-1)) if __name__ == "__main__": diff --git a/tests/test_weight_decay_mixin.py b/tests/test_weight_decay_mixin.py index 0176d50..1f02346 100644 --- a/tests/test_weight_decay_mixin.py +++ b/tests/test_weight_decay_mixin.py @@ -14,6 +14,7 @@ # limitations under the License. 
import torch +from _comparison import assert_equal from absl import flags, logging from absl.testing import absltest, parameterized @@ -55,8 +56,8 @@ def test_zero_weight_decay_is_noop(self, method): helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=0.0) - torch.testing.assert_close(p, p_orig, atol=0, rtol=0) - torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0) + assert_equal(p, p_orig) + assert_equal(grad, grad_orig) @parameterized.parameters( {"lr": 0.25, "wd": 0.5}, @@ -73,8 +74,8 @@ def test_decoupled(self, lr, wd): helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd) expected_p = p_orig * (1 - wd * lr) - torch.testing.assert_close(p, expected_p, atol=0, rtol=0) - torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0) + assert_equal(p, expected_p) + assert_equal(grad, grad_orig) @parameterized.parameters( {"lr": 0.25, "wd": 0.5}, @@ -91,8 +92,8 @@ def test_independent(self, lr, wd): helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd) expected_p = p_orig * (1 - wd) - torch.testing.assert_close(p, expected_p, atol=0, rtol=0) - torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0) + assert_equal(p, expected_p) + assert_equal(grad, grad_orig) def test_independent_ignores_lr(self): """Two different lr values must produce identical results for independent decay.""" @@ -105,7 +106,7 @@ def test_independent_ignores_lr(self): _WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd) _WeightDecayTestHelper("independent")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd) - torch.testing.assert_close(p1, p2, atol=0, rtol=0) + assert_equal(p1, p2) @parameterized.parameters( {"lr": 0.1, "wd": 0.5}, @@ -122,8 +123,8 @@ def test_l2(self, lr, wd): helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd) expected_grad = grad_orig + p_orig * wd - torch.testing.assert_close(p, p_orig, atol=0, rtol=0) - torch.testing.assert_close(grad, 
expected_grad, atol=0, rtol=0) + assert_equal(p, p_orig) + assert_equal(grad, expected_grad) def test_l2_ignores_lr(self): """Two different lr values must produce identical results for L2 decay.""" @@ -136,7 +137,7 @@ def test_l2_ignores_lr(self): _WeightDecayTestHelper("l2")._apply_weight_decay_inplace(p1, grad1, lr=0.001, weight_decay=wd) _WeightDecayTestHelper("l2")._apply_weight_decay_inplace(p2, grad2, lr=100.0, weight_decay=wd) - torch.testing.assert_close(grad1, grad2, atol=0, rtol=0) + assert_equal(grad1, grad2) @parameterized.parameters( {"lr": 0.25, "wd": 0.5}, @@ -153,8 +154,8 @@ def test_palm(self, lr, wd): helper._apply_weight_decay_inplace(p, grad, lr=lr, weight_decay=wd) expected_p = p_orig * (1 - wd * lr * lr) - torch.testing.assert_close(p, expected_p, atol=0, rtol=0) - torch.testing.assert_close(grad, grad_orig, atol=0, rtol=0) + assert_equal(p, expected_p) + assert_equal(grad, grad_orig) def test_default_method_is_l2(self): """When weight_decay_method attribute is absent, default to L2.""" @@ -167,8 +168,8 @@ def test_default_method_is_l2(self): helper._apply_weight_decay_inplace(p, grad, lr=0.1, weight_decay=wd) expected_grad = grad_orig + p_orig * wd - torch.testing.assert_close(p, p_orig, atol=0, rtol=0) - torch.testing.assert_close(grad, expected_grad, atol=0, rtol=0) + assert_equal(p, p_orig) + assert_equal(grad, expected_grad) def test_invalid_method_raises(self): """An unrecognized weight_decay_method must raise ValueError."""