diff --git a/docs/engine.md b/docs/engine.md
index cd4db23..86c9735 100644
--- a/docs/engine.md
+++ b/docs/engine.md
@@ -147,8 +147,19 @@ knobs `manual_eviction`, `allow_block_sharing`, `max_blocks_per_request`, `use_a
 When `attn_implementation` resolves to paged FlashAttention and `max_blocks_per_request` is left
 unset in Evalution, the engine seeds the block-table decode fast path defaults it needs for that
 runtime. Evalution also keeps a compatibility monkeypatch for `transformers` builds that still
-need FA2 decode-fast-path enablement, and that fallback defaults `use_cuda_graph=False`. Evalution
-keeps a session-owned manager alive while stop
+need FA2 decode-fast-path enablement, and that fallback defaults `use_cuda_graph=False`. Set
+`EVALUTION_PATCH_TRANSFORMERS=0` to skip Evalution's local Transformers/FlashAttention monkey
+patches entirely; the patches are applied by default (see the example below). Evalution keeps a
+session-owned manager alive while stop
 strings and sampling settings stay compatible, then tears it down on `gc()` between suites or on
 `close()`.
 
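+For example, to skip the patches for a single run (a minimal sketch; setting the variable in the
+shell before launching works the same way):
+
+```python
+import os
+
+# Any value other than "0" (including unset) leaves the patches applied.
+os.environ["EVALUTION_PATCH_TRANSFORMERS"] = "0"  # set before the engine's build() runs
+```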
diff --git a/evalution/engines/transformers.py b/evalution/engines/transformers.py
index 3f84458..9fbff35 100644
--- a/evalution/engines/transformers.py
+++ b/evalution/engines/transformers.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import inspect
+import os
 import sys
 import threading
 from collections.abc import Iterable, Iterator
@@ -45,6 +46,7 @@
 _FLASH_ATTENTION_SMALL_MAX_BATCH_TOKENS = 4096
 _FLASH_ATTENTION_SMALL_BLOCKS_PER_REQUEST = 1
 _FLASH_ATTENTION_LARGE_BLOCKS_PER_REQUEST = 16
+_PATCH_TRANSFORMERS_ENV = "EVALUTION_PATCH_TRANSFORMERS"
 
 
 @dataclass(slots=True)
@@ -84,11 +86,17 @@ def build(self, model: Model) -> BaseTransformerSession:
             return TransformersCompat.from_transformers(self).build(model)
 
         _warn_pending_nogil_transformers_pr_once()
-        _patch_flash_attn_varlen_fwd_cuda_context_once()
+        if _transformers_monkey_patches_enabled():
+            _patch_flash_attn_varlen_fwd_cuda_context_once()
         self.resolved_engine = "Transformers"
         return TransformersSession.from_config(self, model)
 
 
+def _transformers_monkey_patches_enabled() -> bool:
+    """Return whether Evalution's local Transformers monkey patches should be applied."""
+    return os.environ.get(_PATCH_TRANSFORMERS_ENV, "1") != "0"
+
+
 def _warn_pending_nogil_transformers_pr_once() -> None:
     """Implement warn pending no-GIL transformers pr once for this module."""
     is_gil_enabled = getattr(sys, "_is_gil_enabled", None)
@@ -698,8 +706,9 @@ def _ensure_continuous_batching_manager(
         """Ensure continuous batching manager."""
         from transformers import ContinuousBatchingManager
 
-        _patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager)
-        _patch_continuous_batching_flash_attention_decode_once()
+        if _transformers_monkey_patches_enabled():
+            _patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager)
+            _patch_continuous_batching_flash_attention_decode_once()
 
         generation_config = self._build_generation_config([request])
         with self._state_lock:
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
index 68156a1..bc8cd33 100644
--- a/tests/test_transformer.py
+++ b/tests/test_transformer.py
@@ -844,6 +844,68 @@ def warning(self, message, *args):
     assert warnings == []
 
 
+def test_transformer_build_skips_monkey_patches_when_patch_env_disabled(monkeypatch) -> None:
+    """Verify transformer build skips local monkey patches when the env gate disables them."""
+    import evalution.engines.transformers as transformer_module
+
+    fake_session = object()
+    varlen_patch_calls: list[str] = []
+    warning_calls: list[str] = []
+
+    monkeypatch.setenv("EVALUTION_PATCH_TRANSFORMERS", "0")
+    monkeypatch.setattr(
+        "evalution.engines.transformers.transformers_continuous_batching_support",
+        lambda: (True, "ok"),
+    )
+    monkeypatch.setattr(
+        TransformersSession,
+        "from_config",
+        classmethod(lambda cls, engine, model_config: fake_session),
+    )
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_flash_attn_varlen_fwd_cuda_context_once",
+        lambda: varlen_patch_calls.append("patched"),
+    )
+    monkeypatch.setattr(
+        transformer_module,
+        "_warn_pending_nogil_transformers_pr_once",
+        lambda: warning_calls.append("warned"),
+    )
+
+    assert Transformers(device="cpu").build(Model(path="/tmp/model")) is fake_session
+    assert warning_calls == ["warned"]
+    assert varlen_patch_calls == []
+
+
+def test_transformer_build_applies_monkey_patches_by_default(monkeypatch) -> None:
+    """Verify transformer build applies local monkey patches when the env gate is unset."""
+    import evalution.engines.transformers as transformer_module
+
+    fake_session = object()
+    varlen_patch_calls: list[str] = []
+
+    monkeypatch.delenv("EVALUTION_PATCH_TRANSFORMERS", raising=False)
+    monkeypatch.setattr(
+        "evalution.engines.transformers.transformers_continuous_batching_support",
+        lambda: (True, "ok"),
+    )
+    monkeypatch.setattr(
+        TransformersSession,
+        "from_config",
+        classmethod(lambda cls, engine, model_config: fake_session),
+    )
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_flash_attn_varlen_fwd_cuda_context_once",
+        lambda: varlen_patch_calls.append("patched"),
+    )
+    monkeypatch.setattr(transformer_module, "_warn_pending_nogil_transformers_pr_once", lambda: None)
+
+    assert Transformers(device="cpu").build(Model(path="/tmp/model")) is fake_session
+    assert varlen_patch_calls == ["patched"]
+
+
 def test_transformer_monkey_patches_continuous_batching_generation_loop_once(monkeypatch) -> None:
     """Verify transformer monkey patches continuous batching generation loop once."""
     import evalution.engines.transformers as transformer_module
@@ -940,6 +1002,120 @@ def fail_if_called(device: object) -> object:
     assert "_run_generation_loop" not in FakeContinuousBatchingManager.__dict__
 
 
+def test_transformer_session_skips_continuous_batching_monkey_patches_when_patch_env_disabled(monkeypatch) -> None:
+    """Verify transformer session skips paged-batching monkey patches when the env gate disables them."""
+    import evalution.engines.transformers as transformer_module
+
+    patch_calls: list[str] = []
+
+    class FakeContinuousBatchingManager:
+        """Provide the fake continuous batching manager helper used by the surrounding tests."""
+        def __init__(self) -> None:
+            """Initialize this object."""
+            self.started = False
+
+        def is_running(self) -> bool:
+            """Report whether this manager is already running."""
+            return self.started
+
+        def start(self) -> None:
+            """Start this manager."""
+            self.started = True
+
+    session = TransformersSession(
+        config=Transformers(attn_implementation="paged|flash_attention_2"),
+        model_config=Model(path="/tmp/model"),
+        model=SimpleNamespace(dtype="bfloat16"),
+        tokenizer=SimpleNamespace(),
+        input_device=SimpleNamespace(type="cpu"),
+    )
+
+    monkeypatch.setenv("EVALUTION_PATCH_TRANSFORMERS", "0")
+    monkeypatch.setattr(transformers, "ContinuousBatchingManager", FakeContinuousBatchingManager)
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_continuous_batching_manager_cuda_context_once",
+        lambda manager_cls: patch_calls.append(f"cuda:{manager_cls.__name__}"),
+    )
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_continuous_batching_flash_attention_decode_once",
+        lambda: patch_calls.append("decode"),
+    )
+    monkeypatch.setattr(session, "_build_generation_config", lambda requests: object())
+    monkeypatch.setattr(
+        session,
+        "_build_continuous_batching_manager",
+        lambda **kwargs: FakeContinuousBatchingManager(),
+    )
+
+    manager = session._ensure_continuous_batching_manager(
+        request_signature=("sig",),
+        request=GenerationRequest(prompt="alpha"),
+    )
+
+    assert isinstance(manager, FakeContinuousBatchingManager)
+    assert manager.started is True
+    assert patch_calls == []
+
+
+def test_transformer_session_applies_continuous_batching_monkey_patches_by_default(monkeypatch) -> None:
+    """Verify transformer session applies paged-batching monkey patches when the env gate is unset."""
+    import evalution.engines.transformers as transformer_module
+
+    patch_calls: list[str] = []
+
+    class FakeContinuousBatchingManager:
+        """Provide the fake continuous batching manager helper used by the surrounding tests."""
+        def __init__(self) -> None:
+            """Initialize this object."""
+            self.started = False
+
+        def is_running(self) -> bool:
+            """Report whether this manager is already running."""
+            return self.started
+
+        def start(self) -> None:
+            """Start this manager."""
+            self.started = True
+
+    session = TransformersSession(
+        config=Transformers(attn_implementation="paged|flash_attention_2"),
+        model_config=Model(path="/tmp/model"),
+        model=SimpleNamespace(dtype="bfloat16"),
+        tokenizer=SimpleNamespace(),
+        input_device=SimpleNamespace(type="cpu"),
+    )
+
+    monkeypatch.delenv("EVALUTION_PATCH_TRANSFORMERS", raising=False)
+    monkeypatch.setattr(transformers, "ContinuousBatchingManager", FakeContinuousBatchingManager)
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_continuous_batching_manager_cuda_context_once",
+        lambda manager_cls: patch_calls.append(f"cuda:{manager_cls.__name__}"),
+    )
+    monkeypatch.setattr(
+        transformer_module,
+        "_patch_continuous_batching_flash_attention_decode_once",
+        lambda: patch_calls.append("decode"),
+    )
+    monkeypatch.setattr(session, "_build_generation_config", lambda requests: object())
+    monkeypatch.setattr(
+        session,
+        "_build_continuous_batching_manager",
+        lambda **kwargs: FakeContinuousBatchingManager(),
+    )
+
+    manager = session._ensure_continuous_batching_manager(
+        request_signature=("sig",),
+        request=GenerationRequest(prompt="alpha"),
+    )
+
+    assert isinstance(manager, FakeContinuousBatchingManager)
+    assert manager.started is True
+    assert patch_calls == ["cuda:FakeContinuousBatchingManager", "decode"]
+
+
 @pytest.mark.parametrize(
     ("attn_implementation", "max_batch_tokens", "expected_blocks"),
     [