Skip to content

Commit c42e844

Browse files
committed
minor
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 5ecfaca commit c42e844

File tree

1 file changed

+21
-18
lines changed
  • modelopt/torch/quantization/plugins

1 file changed

+21
-18
lines changed

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 21 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -87,20 +87,6 @@
8787

8888
vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe")
8989

90-
_vllm_fused_moe_invoke_name_cache: str | None = None
91-
92-
93-
def _vllm_fused_moe_invoke_name() -> str:
94-
"""Return the vLLM public fused_moe entrypoint (renamed across versions)."""
95-
global _vllm_fused_moe_invoke_name_cache
96-
if _vllm_fused_moe_invoke_name_cache is not None:
97-
return _vllm_fused_moe_invoke_name_cache
98-
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
99-
if hasattr(vllm_fused_moe_package, name):
100-
_vllm_fused_moe_invoke_name_cache = name
101-
return name
102-
raise ValueError("fused_moe_kernel is not found")
103-
10490

10591
@contextmanager
10692
def disable_compilation(model):
@@ -349,6 +335,17 @@ def _setup(self):
349335
)
350336
self.parallel_state = create_parallel_state()
351337

338+
if getattr(self, "invoke_fused_moe_kernel_func", None) is None: # pragma: no cover
339+
for name in ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel"):
340+
if hasattr(vllm_fused_moe_package, name):
341+
self.invoke_fused_moe_kernel_func = name
342+
break
343+
assert ( # pragma: no cover
344+
getattr(self, "invoke_fused_moe_kernel_func", None) is not None
345+
), (
346+
"fused_moe_kernel is not found"
347+
)
348+
352349
def invoke_fused_moe_quantized(
353350
self,
354351
A: torch.Tensor, # noqa: N803
@@ -360,11 +357,14 @@ def invoke_fused_moe_quantized(
360357
if B is self.w13_weight:
361358
# First layer of expert
362359
A = self.w13_input_quantizer(A) # noqa: N806
363-
if self.w13_weight_quantizer.is_enabled:
360+
if self.w13_weight_quantizer.is_enabled: # pragma: no cover
364361
original_weight, self.w13_weight = (
365362
self.w13_weight,
366363
self.w13_weight_quantizer(self.w13_weight),
367364
)
365+
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
366+
# quantized weight to the kernel.
367+
B = self.w13_weight # noqa: N806
368368
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
369369
self.w13_weight = original_weight
370370
else:
@@ -373,11 +373,14 @@ def invoke_fused_moe_quantized(
373373
C[:] = self.w13_output_quantizer(C)
374374
elif B is self.w2_weight:
375375
A = self.w2_input_quantizer(A) # noqa: N806
376-
if self.w2_weight_quantizer.is_enabled:
376+
if self.w2_weight_quantizer.is_enabled: # pragma: no cover
377377
original_weight, self.w2_weight = (
378378
self.w2_weight,
379379
self.w2_weight_quantizer(self.w2_weight),
380380
)
381+
# In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
382+
# quantized weight to the kernel.
383+
B = self.w2_weight # noqa: N806
381384
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
382385
self.w2_weight = original_weight
383386
else:
@@ -388,9 +391,9 @@ def invoke_fused_moe_quantized(
388391
raise ValueError("Cannot determine first or second layer of expert")
389392

390393
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
391-
with replace_function(
394+
with replace_function( # pragma: no cover
392395
vllm_fused_moe_package,
393-
_vllm_fused_moe_invoke_name(),
396+
self.invoke_fused_moe_kernel_func,
394397
self.invoke_fused_moe_quantized,
395398
og_func_cache_name="_invoke_fused_moe_kernel",
396399
):

0 commit comments

Comments (0)