8787
8888vllm_fused_moe_package = importlib .import_module ("vllm.model_executor.layers.fused_moe.fused_moe" )
8989
_vllm_fused_moe_invoke_name_cache: str | None = None


def _vllm_fused_moe_invoke_name() -> str:
    """Resolve and memoize the name of vLLM's fused-MoE kernel entrypoint.

    The entrypoint function was renamed across vLLM releases, so probe the
    known candidate names on the fused_moe module and cache the first match
    for all subsequent calls.

    Returns:
        The attribute name of the fused-MoE invoke function on the module.

    Raises:
        ValueError: if none of the known candidate names is present.
    """
    global _vllm_fused_moe_invoke_name_cache
    resolved = _vllm_fused_moe_invoke_name_cache
    if resolved is None:
        candidates = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")
        # Probe in order of preference; next() yields None when nothing matches.
        resolved = next(
            (n for n in candidates if hasattr(vllm_fused_moe_package, n)), None
        )
        if resolved is None:
            raise ValueError("fused_moe_kernel is not found")
        _vllm_fused_moe_invoke_name_cache = resolved
    return resolved
103-
10490
10591@contextmanager
10692def disable_compilation (model ):
@@ -349,6 +335,17 @@ def _setup(self):
349335 )
350336 self .parallel_state = create_parallel_state ()
351337
338+ if getattr (self , "invoke_fused_moe_kernel_func" , None ) is None : # pragma: no cover
339+ for name in ("invoke_fused_moe_kernel" , "invoke_fused_moe_triton_kernel" ):
340+ if hasattr (vllm_fused_moe_package , name ):
341+ self .invoke_fused_moe_kernel_func = name
342+ break
343+ assert ( # pragma: no cover
344+ getattr (self , "invoke_fused_moe_kernel_func" , None ) is not None
345+ ), (
346+ "fused_moe_kernel is not found"
347+ )
348+
352349 def invoke_fused_moe_quantized (
353350 self ,
354351 A : torch .Tensor , # noqa: N803
@@ -360,11 +357,14 @@ def invoke_fused_moe_quantized(
360357 if B is self .w13_weight :
361358 # First layer of expert
362359 A = self .w13_input_quantizer (A ) # noqa: N806
363- if self .w13_weight_quantizer .is_enabled :
360+ if self .w13_weight_quantizer .is_enabled : # pragma: no cover
364361 original_weight , self .w13_weight = (
365362 self .w13_weight ,
366363 self .w13_weight_quantizer (self .w13_weight ),
367364 )
365+ # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
366+ # quantized weight to the kernel.
367+ B = self .w13_weight # noqa: N806
368368 vllm_fused_moe_package ._invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
369369 self .w13_weight = original_weight
370370 else :
@@ -373,11 +373,14 @@ def invoke_fused_moe_quantized(
373373 C [:] = self .w13_output_quantizer (C )
374374 elif B is self .w2_weight :
375375 A = self .w2_input_quantizer (A ) # noqa: N806
376- if self .w2_weight_quantizer .is_enabled :
376+ if self .w2_weight_quantizer .is_enabled : # pragma: no cover
377377 original_weight , self .w2_weight = (
378378 self .w2_weight ,
379379 self .w2_weight_quantizer (self .w2_weight ),
380380 )
381+ # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
382+ # quantized weight to the kernel.
383+ B = self .w2_weight # noqa: N806
381384 vllm_fused_moe_package ._invoke_fused_moe_kernel (A , B , C , * args , ** kwargs )
382385 self .w2_weight = original_weight
383386 else :
@@ -388,9 +391,9 @@ def invoke_fused_moe_quantized(
388391 raise ValueError ("Cannot determine first or second layer of expert" )
389392
390393 def forward (self , hidden_states : torch .Tensor , router_logits : torch .Tensor ):
391- with replace_function (
394+ with replace_function ( # pragma: no cover
392395 vllm_fused_moe_package ,
393- _vllm_fused_moe_invoke_name () ,
396+ self . invoke_fused_moe_kernel_func ,
394397 self .invoke_fused_moe_quantized ,
395398 og_func_cache_name = "_invoke_fused_moe_kernel" ,
396399 ):
0 commit comments