diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py
index 4f95e61e..dd20cfa7 100644
--- a/tests/models/test_decoders.py
+++ b/tests/models/test_decoders.py
@@ -7,6 +7,8 @@
 import torch
 from torch import distributed as dist
 from torch.fx.experimental import _config as fx_config
+from torch_sendnn.backends.sendnn_backend import _get_global_state
+from torch_sendnn.utils.graph_cache import SpyreGraphCache
 
 from aiu_fms_testing_utils.testing.validation import (
     extract_validation_information,
@@ -29,6 +31,7 @@
 from transformers import AutoTokenizer
 
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os
 
 try:
@@ -132,7 +135,7 @@
 if USE_MICRO_MODELS:
     VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")
 
-# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
+# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
 if isinstance(COMMON_MODEL_PATHS, str):
     COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
 
@@ -593,7 +596,6 @@ def _get_device_validation_information(
             token_iter,
             ATTN_NAME,
         )
-
        if cpu_validation_info is not None:
            return cpu_validation_info
 
@@ -830,6 +832,8 @@ def _run_cpu_aiu_validation_test(
     aiu_model,
     micro_model_path,
     record_property,
+    verify_cache_state=None,
+    warmup_only=False,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -853,6 +857,16 @@
         aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs
     )
 
+    # Used only for cache tests; this is a no-argument closure that
+    # should assert the torch_sendnn cache is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
+    # For some tests, e.g., cache checks, we only need to run the warmup
+    if warmup_only:
+        return
+
     # Run validation level 0
     failed_validation_level_0, validation_zero_info = _run_validation_level_0(
         model_path,
@@ -888,6 +902,88 @@
     )
 
 
+def _reset_cache_settings(purge_cache_dir, cache_dir=None):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    if cache_dir is not None:
+        # Might be a posixpath
+        cache_dir = str(cache_dir)
+        os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    # NOTE: currently, the cache dir is pulled from
+    # TORCH_SENDNN_CACHE_DIR at initialization time,
+    # so this should correctly use the cache_dir
+    _get_global_state().use_aiu_cache = True
+    _get_global_state().spyre_graph_cache = SpyreGraphCache()
+
+
+@pytest.fixture
+def use_cached_model(request, persistent_model, record_property, tmp_path):
+    """Configures the torch_sendnn cache and runs the AIU model (warmup)
+    prior to test execution; this is computationally expensive and should
+    only be used in situations like testing cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)
+
+    model_path, batch_size, seq_length, max_new_tokens = request.param
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = str(tmp_path)
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+    is_fp8 = "fp8" in ATTN_NAME
+    model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+    # Get the AIU model w/ the persistent model fixture
+    model = persistent_model.get_or_create(
+        is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+    )
+
+    validation_model = _get_cpu_model(
+        is_gptq,
+        is_fp8,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+        **gptq_kwargs_cpu,
+        **model_kwargs,
+    )
+    # We also return the models so that we can reuse them in the cache hit check
+    models = (model, validation_model)
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_miss,
+        warmup_only=True,
+    )
+    return request.param, models
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
 )
@@ -937,3 +1033,50 @@ def test_common_shapes(
         micro_model_path,
         record_property,
     )
+
+
+@pytest.mark.parametrize(
+    "use_cached_model",
+    COMMON_SHAPES,
+    indirect=True,
+)
+def test_cache(use_cached_model, record_property, tmp_path):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)
+
+    # use_cached_model is an indirectly parametrized fixture, and the returned
+    # value is an expanded tuple from COMMON_SHAPES, so we unpack it here.
+    # In addition, we also pass the model created on AIU in the fixture to
+    # avoid recreating it.
+    test_params, models = use_cached_model
+    model, validation_model = models
+    model_path, batch_size, seq_length, max_new_tokens = test_params
+
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = str(tmp_path)
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_hit,
+        warmup_only=True,
+    )
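Example invocation (a sketch, not part of the patch above): test_cache reuses the
COMMON_SHAPES parametrization, so a single model can be exercised by narrowing the
model path list through the FMS_TESTING_COMMON_MODEL_PATHS variable referenced in
the comment near the top of the file; the granite path below is a placeholder. The
snippet is equivalent to exporting that variable in the shell and running pytest on
the test_cache node directly.

    # run_cache_test.py -- hypothetical helper script, not part of the patch
    import os

    import pytest

    # Narrow COMMON_SHAPES to a single model before the test module is imported
    # during collection (assumes this is the env var the module reads).
    os.environ["FMS_TESTING_COMMON_MODEL_PATHS"] = "/tmp/models/granite-3-8b-base"

    # Run only the cache tests; the use_cached_model fixture performs the
    # cache-miss warmup, then test_cache checks for a cache hit.
    raise SystemExit(pytest.main(["tests/models/test_decoders.py::test_cache", "-v"]))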