docs/engine.md (5 changes: 3 additions & 2 deletions)
@@ -147,8 +147,9 @@ knobs `manual_eviction`, `allow_block_sharing`, `max_blocks_per_request`, `use_a
When `attn_implementation` resolves to paged FlashAttention and `max_blocks_per_request` is left
unset in Evalution, the engine seeds the block-table decode fast path defaults it needs for that
runtime. Evalution also keeps a compatibility monkeypatch for `transformers` builds that still
need FA2 decode-fast-path enablement, and that fallback defaults `use_cuda_graph=False`. Evalution
keeps a session-owned manager alive while stop
need FA2 decode-fast-path enablement, and that fallback defaults `use_cuda_graph=False`. Set
`EVALUTION_PATCH_TRANSFOMRERS=0` to skip Evalution's local Transformers/FlashAttention monkey
patches entirely; the default is enabled. Evalution keeps a session-owned manager alive while stop
strings and sampling settings stay compatible, then tears it down on `gc()` between suites or on
`close()`.

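For readers of this doc change, a minimal usage sketch of the new gate (illustrative, not part of the diff; the variable name is copied verbatim from the change):

```python
import os

# Opt out of Evalution's local Transformers/FlashAttention monkey patches.
# Only the literal string "0" disables them; leaving the variable unset, or
# setting any other value, keeps the default behaviour (patches applied).
# Set this before the engine builds a session so the patch helpers never run.
os.environ["EVALUTION_PATCH_TRANSFOMRERS"] = "0"
```

The same effect can be had by exporting the variable in the shell that launches the evaluation process.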
evalution/engines/transformers.py (15 changes: 12 additions & 3 deletions)
@@ -6,6 +6,7 @@
from __future__ import annotations

import inspect
import os
import sys
import threading
from collections.abc import Iterable, Iterator
@@ -45,6 +46,7 @@
_FLASH_ATTENTION_SMALL_MAX_BATCH_TOKENS = 4096
_FLASH_ATTENTION_SMALL_BLOCKS_PER_REQUEST = 1
_FLASH_ATTENTION_LARGE_BLOCKS_PER_REQUEST = 16
_PATCH_TRANSFORMERS_ENV = "EVALUTION_PATCH_TRANSFOMRERS"


@dataclass(slots=True)
@@ -84,11 +86,17 @@ def build(self, model: Model) -> BaseTransformerSession:
return TransformersCompat.from_transformers(self).build(model)

_warn_pending_nogil_transformers_pr_once()
_patch_flash_attn_varlen_fwd_cuda_context_once()
if _transformers_monkey_patches_enabled():
_patch_flash_attn_varlen_fwd_cuda_context_once()
self.resolved_engine = "Transformers"
return TransformersSession.from_config(self, model)


def _transformers_monkey_patches_enabled() -> bool:
"""Return whether Evalution's local Transformers monkey patches should be applied."""
return os.environ.get(_PATCH_TRANSFORMERS_ENV, "1") != "0"


def _warn_pending_nogil_transformers_pr_once() -> None:
"""Implement warn pending no-GIL transformers pr once for this module."""
is_gil_enabled = getattr(sys, "_is_gil_enabled", None)
@@ -698,8 +706,9 @@ def _ensure_continuous_batching_manager(
"""Ensure continuous batching manager."""
from transformers import ContinuousBatchingManager

_patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager)
_patch_continuous_batching_flash_attention_decode_once()
if _transformers_monkey_patches_enabled():
_patch_continuous_batching_manager_cuda_context_once(ContinuousBatchingManager)
_patch_continuous_batching_flash_attention_decode_once()
generation_config = self._build_generation_config([request])

with self._state_lock:
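Editor's note on the gate's semantics, since both call sites above route through `_transformers_monkey_patches_enabled()`: only the literal string "0" opts out; unset, "1", or any other value (including "false") leaves the patches on. A small self-contained sketch mirroring the helper added in this file:

```python
import os

_PATCH_TRANSFORMERS_ENV = "EVALUTION_PATCH_TRANSFOMRERS"  # spelling as in the diff


def patches_enabled() -> bool:
    """Mirror of _transformers_monkey_patches_enabled(): anything but "0" enables patches."""
    return os.environ.get(_PATCH_TRANSFORMERS_ENV, "1") != "0"


for value in (None, "1", "0", "false"):
    if value is None:
        os.environ.pop(_PATCH_TRANSFORMERS_ENV, None)
    else:
        os.environ[_PATCH_TRANSFORMERS_ENV] = value
    print(f"{value!r:>7}: patches {'applied' if patches_enabled() else 'skipped'}")
```

When the gate is enabled, `build()` applies the varlen-fwd CUDA-context patch and `_ensure_continuous_batching_manager()` applies the CUDA-context and FlashAttention-decode patches, each at most once.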
tests/test_transformer.py (176 changes: 176 additions & 0 deletions)
@@ -844,6 +844,68 @@ def warning(self, message, *args):
assert warnings == []


def test_transformer_build_skips_monkey_patches_when_patch_env_disabled(monkeypatch) -> None:
"""Verify transformer build skips local monkey patches when the env gate disables them."""
import evalution.engines.transformers as transformer_module

fake_session = object()
varlen_patch_calls: list[str] = []
warning_calls: list[str] = []

monkeypatch.setenv("EVALUTION_PATCH_TRANSFOMRERS", "0")
monkeypatch.setattr(
"evalution.engines.transformers.transformers_continuous_batching_support",
lambda: (True, "ok"),
)
monkeypatch.setattr(
TransformersSession,
"from_config",
classmethod(lambda cls, engine, model_config: fake_session),
)
monkeypatch.setattr(
transformer_module,
"_patch_flash_attn_varlen_fwd_cuda_context_once",
lambda: varlen_patch_calls.append("patched"),
)
monkeypatch.setattr(
transformer_module,
"_warn_pending_nogil_transformers_pr_once",
lambda: warning_calls.append("warned"),
)

assert Transformers(device="cpu").build(Model(path="/tmp/model")) is fake_session
assert warning_calls == ["warned"]
assert varlen_patch_calls == []


def test_transformer_build_applies_monkey_patches_by_default(monkeypatch) -> None:
"""Verify transformer build applies local monkey patches when the env gate is unset."""
import evalution.engines.transformers as transformer_module

fake_session = object()
varlen_patch_calls: list[str] = []

monkeypatch.delenv("EVALUTION_PATCH_TRANSFOMRERS", raising=False)
monkeypatch.setattr(
"evalution.engines.transformers.transformers_continuous_batching_support",
lambda: (True, "ok"),
)
monkeypatch.setattr(
TransformersSession,
"from_config",
classmethod(lambda cls, engine, model_config: fake_session),
)
monkeypatch.setattr(
transformer_module,
"_patch_flash_attn_varlen_fwd_cuda_context_once",
lambda: varlen_patch_calls.append("patched"),
)
monkeypatch.setattr(transformer_module, "_warn_pending_nogil_transformers_pr_once", lambda: None)

assert Transformers(device="cpu").build(Model(path="/tmp/model")) is fake_session
assert varlen_patch_calls == ["patched"]


def test_transformer_monkey_patches_continuous_batching_generation_loop_once(monkeypatch) -> None:
"""Verify transformer monkey patches continuous batching generation loop once."""
import evalution.engines.transformers as transformer_module
@@ -940,6 +1002,120 @@ def fail_if_called(device: object) -> object:
assert "_run_generation_loop" not in FakeContinuousBatchingManager.__dict__


def test_transformer_session_skips_continuous_batching_monkey_patches_when_patch_env_disabled(monkeypatch) -> None:
"""Verify transformer session skips paged-batching monkey patches when the env gate disables them."""
import evalution.engines.transformers as transformer_module

patch_calls: list[str] = []

class FakeContinuousBatchingManager:
"""Provide the fake continuous batching manager helper used by the surrounding tests."""
def __init__(self) -> None:
"""Initialize this object."""
self.started = False

def is_running(self) -> bool:
"""Report whether this manager is already running."""
return self.started

def start(self) -> None:
"""Start this manager."""
self.started = True

session = TransformersSession(
config=Transformers(attn_implementation="paged|flash_attention_2"),
model_config=Model(path="/tmp/model"),
model=SimpleNamespace(dtype="bfloat16"),
tokenizer=SimpleNamespace(),
input_device=SimpleNamespace(type="cpu"),
)

monkeypatch.setenv("EVALUTION_PATCH_TRANSFOMRERS", "0")
monkeypatch.setattr(transformers, "ContinuousBatchingManager", FakeContinuousBatchingManager)
monkeypatch.setattr(
transformer_module,
"_patch_continuous_batching_manager_cuda_context_once",
lambda manager_cls: patch_calls.append(f"cuda:{manager_cls.__name__}"),
)
monkeypatch.setattr(
transformer_module,
"_patch_continuous_batching_flash_attention_decode_once",
lambda: patch_calls.append("decode"),
)
monkeypatch.setattr(session, "_build_generation_config", lambda requests: object())
monkeypatch.setattr(
session,
"_build_continuous_batching_manager",
lambda **kwargs: FakeContinuousBatchingManager(),
)

manager = session._ensure_continuous_batching_manager(
request_signature=("sig",),
request=GenerationRequest(prompt="alpha"),
)

assert isinstance(manager, FakeContinuousBatchingManager)
assert manager.started is True
assert patch_calls == []


def test_transformer_session_applies_continuous_batching_monkey_patches_by_default(monkeypatch) -> None:
"""Verify transformer session applies paged-batching monkey patches when the env gate is unset."""
import evalution.engines.transformers as transformer_module

patch_calls: list[str] = []

class FakeContinuousBatchingManager:
"""Provide the fake continuous batching manager helper used by the surrounding tests."""
def __init__(self) -> None:
"""Initialize this object."""
self.started = False

def is_running(self) -> bool:
"""Report whether this manager is already running."""
return self.started

def start(self) -> None:
"""Start this manager."""
self.started = True

session = TransformersSession(
config=Transformers(attn_implementation="paged|flash_attention_2"),
model_config=Model(path="/tmp/model"),
model=SimpleNamespace(dtype="bfloat16"),
tokenizer=SimpleNamespace(),
input_device=SimpleNamespace(type="cpu"),
)

monkeypatch.delenv("EVALUTION_PATCH_TRANSFOMRERS", raising=False)
monkeypatch.setattr(transformers, "ContinuousBatchingManager", FakeContinuousBatchingManager)
monkeypatch.setattr(
transformer_module,
"_patch_continuous_batching_manager_cuda_context_once",
lambda manager_cls: patch_calls.append(f"cuda:{manager_cls.__name__}"),
)
monkeypatch.setattr(
transformer_module,
"_patch_continuous_batching_flash_attention_decode_once",
lambda: patch_calls.append("decode"),
)
monkeypatch.setattr(session, "_build_generation_config", lambda requests: object())
monkeypatch.setattr(
session,
"_build_continuous_batching_manager",
lambda **kwargs: FakeContinuousBatchingManager(),
)

manager = session._ensure_continuous_batching_manager(
request_signature=("sig",),
request=GenerationRequest(prompt="alpha"),
)

assert isinstance(manager, FakeContinuousBatchingManager)
assert manager.started is True
assert patch_calls == ["cuda:FakeContinuousBatchingManager", "decode"]


@pytest.mark.parametrize(
("attn_implementation", "max_batch_tokens", "expected_blocks"),
[
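The new tests drive the gate with `monkeypatch.setenv` / `monkeypatch.delenv`. A downstream suite that wants to exercise unpatched Transformers across several tests could wrap the same idea in a fixture; this is an illustrative sketch, not part of the PR:

```python
import pytest


@pytest.fixture
def unpatched_transformers(monkeypatch: pytest.MonkeyPatch) -> None:
    """Disable Evalution's local Transformers/FlashAttention monkey patches for one test."""
    monkeypatch.setenv("EVALUTION_PATCH_TRANSFOMRERS", "0")
```

Any test that requests `unpatched_transformers` then builds sessions against stock `transformers` code paths.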