Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/vllm_serve/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This is a simple example to demonstrate calibrating and serving ModelOpt fakequa

Compared with realquant, fakequant is 2-5x slower, but doesn't require dedicated kernel support and facilitates research.

This example is tested with vllm 0.9.0 and 0.11.2
This example is tested with vllm 0.9.0 and 0.19.1

## Prepare environment

Expand Down
7 changes: 5 additions & 2 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,14 @@ def determine_available_memory(self) -> int:
with disable_compilation(model):
return super().determine_available_memory()

def compile_or_warm_up_model(self) -> float:
    """Run fakequant calibration (when configured) before the base worker's warm-up.

    Returns:
        The base worker's compilation time in seconds. Returning ``None`` here
        breaks the vLLM V1 executor: ``initialize_from_config`` computes
        ``max(compilation_times)`` across TP workers, so the base return value
        must be propagated unchanged.
    """
    # Only calibrate when some form of quantization is actually requested.
    if (
        quant_config["quant_cfg"]
        or quant_config["kv_quant_cfg"]
        or quant_config["modelopt_state_path"]
    ):
        _fakequant_run_prolog_worker(self)
    return super().compile_or_warm_up_model()
16 changes: 12 additions & 4 deletions examples/vllm_serve/vllm_serve_fakequant.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,12 @@

vllm_version = version.parse(vllm.__version__)
if vllm_version <= version.parse("0.11.0"):
from vllm.executor.ray_distributed_executor import RayDistributedExecutor
from vllm.utils import FlexibleArgumentParser
else:
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.executor.ray_executor import RayDistributedExecutor


# Adding the envs you want to pass to the workers
# Env vars to copy from the driver to Ray workers (must match fakequant_worker / vllm_ptq_utils).
additional_env_vars = {
"QUANT_DATASET",
"QUANT_CALIB_SIZE",
Expand All @@ -81,7 +79,17 @@
"TRUST_REMOTE_CODE",
}

try:
    # vLLM <= 0.11: the Ray executor exposes a class-level set of env var names
    # that are copied from the driver to every Ray worker.
    from vllm.executor.ray_distributed_executor import RayDistributedExecutor

    RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars)
except (ImportError, AttributeError):
    # vLLM v1 Ray path (vllm/ray/ray_env.py, get_env_vars_to_copy): the set of
    # env vars to copy is read from VLLM_RAY_EXTRA_ENV_VARS_TO_COPY. Merge our
    # names with any user-provided list instead of overwriting it.
    extra_env_var = "VLLM_RAY_EXTRA_ENV_VARS_TO_COPY"
    merged_env_vars = {
        t.strip() for t in os.environ.get(extra_env_var, "").split(",") if t.strip()
    } | additional_env_vars
    os.environ[extra_env_var] = ",".join(sorted(merged_env_vars))


def main():
Expand Down
103 changes: 82 additions & 21 deletions modelopt/torch/quantization/plugins/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

"""Support quantization for VLLM layers."""

import contextvars
import importlib
from collections.abc import Callable
from contextlib import contextmanager
from functools import partial
from itertools import chain

import torch
Expand Down Expand Up @@ -85,6 +88,21 @@
)

vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
# vLLM may call one entry (e.g. ``dispatch_fused_moe_kernel``) which then calls another on the same
# module (e.g. ``invoke_fused_moe_triton_kernel``). Patching every name would otherwise apply fakequant
# twice; see ``_moe_fakequant_active`` in ``invoke_fused_moe_quantized``.
_FUSED_MOE_KERNEL_CANDIDATES = (
"invoke_fused_moe_kernel",
"invoke_fused_moe_triton_kernel",
"dispatch_fused_moe_kernel",
)
_FUSED_MOE_KERNEL_FUNCS = tuple(
n for n in _FUSED_MOE_KERNEL_CANDIDATES if hasattr(vllm_fused_moe_package, n)
)

_moe_fakequant_active: contextvars.ContextVar[bool] = contextvars.ContextVar(
"moe_fakequant_active", default=False
)


@contextmanager
Expand Down Expand Up @@ -340,29 +358,64 @@ def invoke_fused_moe_quantized(
B: torch.Tensor, # noqa: N803
C: torch.Tensor, # noqa: N803
*args,
original_kernel: Callable,
**kwargs,
):
# Nested module-level entry (e.g. dispatch -> triton): call the real kernel once, no second quant.
if _moe_fakequant_active.get():
return original_kernel(A, B, C, *args, **kwargs)
token = _moe_fakequant_active.set(True)
try:
return self._invoke_fused_moe_quantized_function(
A, B, C, *args, original_kernel=original_kernel, **kwargs
)
finally:
_moe_fakequant_active.reset(token)

def _invoke_fused_moe_quantized_function(
self,
A: torch.Tensor, # noqa: N803
B: torch.Tensor, # noqa: N803
C: torch.Tensor, # noqa: N803
*args,
original_kernel: Callable,
**kwargs,
):
if B is self.w13_weight:
# First layer of expert
A = self.w13_input_quantizer(A) # noqa: N806
if self.w13_weight_quantizer.is_enabled:
original_weight = self.w13_weight
self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
self.w13_weight = original_weight
if self.w13_weight_quantizer.is_enabled: # pragma: no cover
original_weight, self.w13_weight = (
self.w13_weight,
self.w13_weight_quantizer(self.w13_weight),
)
Comment thread
kinjalpatel27 marked this conversation as resolved.
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
# quantized weight to the kernel.
Comment thread
kinjalpatel27 marked this conversation as resolved.
B = self.w13_weight # noqa: N806
try:
original_kernel(A, B, C, *args, **kwargs)
finally:
self.w13_weight = original_weight
else:
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
original_kernel(A, B, C, *args, **kwargs)
if self.w13_output_quantizer.is_enabled:
C[:] = self.w13_output_quantizer(C)
elif B is self.w2_weight:
A = self.w2_input_quantizer(A) # noqa: N806
if self.w2_weight_quantizer.is_enabled:
original_weight = self.w2_weight
self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
self.w2_weight = original_weight
if self.w2_weight_quantizer.is_enabled: # pragma: no cover
original_weight, self.w2_weight = (
self.w2_weight,
self.w2_weight_quantizer(self.w2_weight),
)
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
# quantized weight to the kernel.
B = self.w2_weight # noqa: N806
try:
original_kernel(A, B, C, *args, **kwargs)
finally:
self.w2_weight = original_weight
else:
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
original_kernel(A, B, C, *args, **kwargs)
if self.w2_output_quantizer.is_enabled:
C[:] = self.w2_output_quantizer(C)
else:
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
    """Run the base FusedMoE forward with the fused-MoE kernels patched for fakequant.

    vLLM overwrites the ``fused_moe`` submodule name with a function of the same
    name, so the module object imported explicitly (``vllm_fused_moe_package``)
    must be patched by attribute name. A ``replace_function`` context manager
    conflicts with torch.compile, hence the manual try/finally patching.
    """
    assert _FUSED_MOE_KERNEL_FUNCS and all(
        getattr(vllm_fused_moe_package, n, None) is not None
        for n in _FUSED_MOE_KERNEL_FUNCS
    )
    # Snapshot the real kernels so each patched entry can invoke its own original.
    originals = {n: getattr(vllm_fused_moe_package, n) for n in _FUSED_MOE_KERNEL_FUNCS}
    try:
        for n in _FUSED_MOE_KERNEL_FUNCS:
            setattr(
                vllm_fused_moe_package,
                n,
                partial(self.invoke_fused_moe_quantized, original_kernel=originals[n]),
            )
        return super().forward(hidden_states, router_logits)
    finally:
        # Always restore the unpatched kernels, even if forward raises.
        for n in _FUSED_MOE_KERNEL_FUNCS:
            setattr(vllm_fused_moe_package, n, originals[n])

@torch.no_grad()
def fold_weight(self, keep_attrs: bool = False):
Expand All @@ -409,7 +469,8 @@ def fold_weight(self, keep_attrs: bool = False):
)
self.w2_weight_quantizer.disable()

torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()


@QuantModuleRegistry.register({vllm_fused_moe_layer.FusedMoE: "vllm_FusedMoE"})
Expand Down
Loading