From 1fe927bd247ed143f2b661474d62e0b4dbf73452 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Jun 2026 23:08:47 -0700 Subject: [PATCH 1/2] implement base model output caching in model-level tests --- tests/models/testing_utils/common.py | 88 ++++++++++++------- .../test_models_transformer_hunyuan_dit.py | 4 +- .../test_models_transformer_hunyuan_video.py | 8 +- .../test_models_transformer_wan_animate.py | 4 +- 4 files changed, 64 insertions(+), 40 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index 626f1eb7f1bf..ec66bd30f0aa 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -24,9 +24,13 @@ from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging -from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator -from ...testing_utils import assert_tensors_close, torch_device +from ...testing_utils import ( + assert_tensors_close, + require_accelerator, + require_torch_multi_accelerator, + torch_device, +) def named_persistent_module_tensors( @@ -278,8 +282,30 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): pass """ + @pytest.fixture(scope="class") + def base_model_output(self): + """Class-scoped reference forward output, built once and reused across the class. + + Building the model and running its forward pass is fully deterministic (``torch.manual_seed(0)`` + plus the deterministic ``get_dummy_inputs`` contract), so the reference ("base") output is + identical for every test in the class. The save/load and parallelism tests compare a reloaded + model against this output; computing it a single time here — instead of rebuilding the model and + re-running the forward in each test — removes that redundant work and speeds up the suite. + + The hardware-gated tests that consume this fixture use ``pytest.mark.skipif`` (via the + ``require_*`` decorators), which pytest evaluates before fixture setup, so skipping on a machine + without the required accelerators never triggers this forward. + + Tests that still need a live model (e.g. to save it) build their own with the same seed, so the + reloaded model's weights match this cached output. + """ + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()).eval().to(torch_device) + with torch.no_grad(): + return model(**self.get_dummy_inputs(), return_dict=False)[0] + @torch.no_grad() - def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): + def test_from_save_pretrained(self, base_model_output, tmp_path, atol=5e-5, rtol=5e-5): torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) @@ -296,13 +322,15 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" ) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + assert_tensors_close( + base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes." + ) @torch.no_grad() - def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): + def test_from_save_pretrained_variant(self, base_model_output, tmp_path, atol=5e-5, rtol=0): + torch.manual_seed(0) model = self.model_class(**self.get_init_dict()) model.to(torch_device) model.eval() @@ -317,10 +345,11 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): new_model.to(torch_device) - image = model(**self.get_dummy_inputs(), return_dict=False)[0] new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] - assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + assert_tensors_close( + base_model_output, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes." + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) def test_from_save_pretrained_dtype(self, tmp_path, dtype): @@ -360,13 +389,8 @@ def test_determinism(self, atol=1e-5, rtol=0): ) @torch.no_grad() - def test_output(self, expected_output_shape=None): - model = self.model_class(**self.get_init_dict()) - model.to(torch_device) - model.eval() - - inputs_dict = self.get_dummy_inputs() - output = model(**inputs_dict, return_dict=False)[0] + def test_output(self, base_model_output, expected_output_shape=None): + output = base_model_output assert output is not None, "Model output is None" assert output[0].shape == expected_output_shape or self.output_shape, ( @@ -509,14 +533,12 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, @require_accelerator @torch.no_grad() - def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints(self, base_model_output, tmp_path, atol=1e-5, rtol=0): torch.manual_seed(0) config = self.get_init_dict() model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -537,19 +559,17 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" + base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" ) @require_accelerator @torch.no_grad() - def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints_with_variant(self, base_model_output, tmp_path, atol=1e-5, rtol=0): torch.manual_seed(0) config = self.get_init_dict() model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small variant = "fp16" @@ -575,11 +595,15 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load" + base_model_output, + new_output, + atol=atol, + rtol=rtol, + msg="Output should match after variant sharded save/load", ) @torch.no_grad() - def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0): + def test_sharded_checkpoints_with_parallel_loading(self, base_model_output, tmp_path, atol=1e-5, rtol=0): from diffusers.utils import constants torch.manual_seed(0) @@ -587,8 +611,6 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt model = self.model_class(**config).eval() model = model.to(torch_device) - base_output = model(**self.get_dummy_inputs(), return_dict=False)[0] - model_size = compute_module_persistent_sizes(model)[""] max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small @@ -624,7 +646,11 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt output_parallel = model_parallel(**self.get_dummy_inputs(), return_dict=False)[0] assert_tensors_close( - base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" + base_model_output, + output_parallel, + atol=atol, + rtol=rtol, + msg="Output should match with parallel loading", ) finally: @@ -635,19 +661,17 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt @require_torch_multi_accelerator @torch.no_grad() - def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0): + def test_model_parallelism(self, base_model_output, tmp_path, atol=1e-5, rtol=0): if self.model_class._no_split_modules is None: pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + torch.manual_seed(0) config = self.get_init_dict() inputs_dict = self.get_dummy_inputs() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) - base_output = model(**inputs_dict, return_dict=False)[0] - model_size = compute_module_sizes(model)[""] max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] @@ -665,5 +689,5 @@ def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0): new_output = new_model(**inputs_dict, return_dict=False)[0] assert_tensors_close( - base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism" + base_model_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism" ) diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index 1c08244b620c..370033ef319f 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -120,9 +120,9 @@ def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]: class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin): - def test_output(self): + def test_output(self, base_model_output): batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0] - super().test_output(expected_output_shape=(batch_size,) + self.output_shape) + super().test_output(base_model_output, expected_output_shape=(batch_size,) + self.output_shape) class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin): diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 90c716a336a5..cc934be125aa 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -223,8 +223,8 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin): - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) + def test_output(self, base_model_output): + super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape)) # ======================== HunyuanVideo Token Replace Image-to-Video ======================== @@ -299,5 +299,5 @@ def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin): - def test_output(self): - super().test_output(expected_output_shape=(1, *self.output_shape)) + def test_output(self, base_model_output): + super().test_output(base_model_output, expected_output_shape=(1, *self.output_shape)) diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py index 30f78ca1c3de..bd751974637b 100644 --- a/tests/models/transformers/test_models_transformer_wan_animate.py +++ b/tests/models/transformers/test_models_transformer_wan_animate.py @@ -146,11 +146,11 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin): """Core model tests for Wan Animate Transformer 3D.""" - def test_output(self): + def test_output(self, base_model_output): # Override test_output because the transformer output is expected to have less channels # than the main transformer input. expected_output_shape = (1, 4, 21, 16, 16) - super().test_output(expected_output_shape=expected_output_shape) + super().test_output(base_model_output, expected_output_shape=expected_output_shape) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): From ef206c246cb966de388323106b4315d7abfd67db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Jun 2026 23:17:28 -0700 Subject: [PATCH 2/2] single quotes --- tests/models/testing_utils/common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index ec66bd30f0aa..129c8197887d 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -286,14 +286,14 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin): def base_model_output(self): """Class-scoped reference forward output, built once and reused across the class. - Building the model and running its forward pass is fully deterministic (``torch.manual_seed(0)`` - plus the deterministic ``get_dummy_inputs`` contract), so the reference ("base") output is + Building the model and running its forward pass is fully deterministic (`torch.manual_seed(0)` + plus the deterministic `get_dummy_inputs` contract), so the reference ("base") output is identical for every test in the class. The save/load and parallelism tests compare a reloaded model against this output; computing it a single time here — instead of rebuilding the model and re-running the forward in each test — removes that redundant work and speeds up the suite. - The hardware-gated tests that consume this fixture use ``pytest.mark.skipif`` (via the - ``require_*`` decorators), which pytest evaluates before fixture setup, so skipping on a machine + The hardware-gated tests that consume this fixture use `pytest.mark.skipif` (via the + `require_*` decorators), which pytest evaluates before fixture setup, so skipping on a machine without the required accelerators never triggers this forward. Tests that still need a live model (e.g. to save it) build their own with the same seed, so the