
Commit 2cdf922

SouthWest7 authored
[Feature]: Remove Chunking From FusedMoE (vllm-project#34086)
Signed-off-by: SouthWest7 <am1ao@qq.com>
Signed-off-by: Southwest <1403572259@qq.com>
Signed-off-by: southwest <am1ao@qq.com>
Signed-off-by: Xinan Miao <1403572259@qq.com>
Co-authored-by: SouthWest7 <am1ao@qq.com>
1 parent c973ecd commit 2cdf922

28 files changed: 153 additions & 524 deletions

docs/design/fused_moe_modular_kernel.md

Lines changed: 2 additions & 5 deletions
@@ -167,9 +167,6 @@ FusedMoEExpertsModular performs the core of the FusedMoE operations. The various
 
 `FusedMoEExpertsModular::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
 
-`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
-implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
-
 `FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
 
 `FusedMoEExpertsModular::workspace_shapes()` /
@@ -220,8 +217,8 @@ If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsModular` …
 
 1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
 2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
-   `Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`,
-   `Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
+   `Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`
+   methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
 
 Doing this will add the new implementation to the test suite.

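For orientation, a minimal sketch of the slimmed-down experts interface the design doc now describes: `supports_chunking()` is gone, leaving `activation_formats()` and `supports_expert_map()` (plus `workspace_shapes()`, omitted here) as the capability hooks. The base class, enum, and `_DemoExperts` subclass below are illustrative stand-ins, not vLLM's actual definitions.

# Illustrative sketch only: names mirror the doc's terminology, but the
# signatures are simplified stand-ins, not vLLM's real API.
from abc import ABC, abstractmethod
from enum import Enum


class FusedMoEActivationFormat(Enum):
    Standard = "standard"          # contiguous activations
    BatchedExperts = "batched"     # per-expert batched activations


class FusedMoEExpertsModular(ABC):
    """Core experts op; note there is no supports_chunking() hook anymore."""

    @abstractmethod
    def activation_formats(
        self,
    ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
        """Supported (input, output) activation formats."""

    @abstractmethod
    def supports_expert_map(self) -> bool:
        """True if the implementation can honor an expert map."""


class _DemoExperts(FusedMoEExpertsModular):
    def activation_formats(self):
        return (FusedMoEActivationFormat.Standard, FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        return True


print(_DemoExperts().supports_expert_map())  # True
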
tests/kernels/moe/modular_kernel_tools/cli_args.py

Lines changed: 0 additions & 6 deletions
@@ -82,11 +82,6 @@ def to_quant_torch_dtype(s: str) -> torch.dtype:
         "--num-experts", type=int, default=32, help="Global num experts"
     )
     parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
-    parser.add_argument(
-        "--fused-moe-chunk-size",
-        type=int,
-        help="Fused moe chunk size used for the non-batched fused experts impl.",
-    )
 
     # Quant args
     parser.add_argument(
@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
         quant_config=quant_config,
         prepare_finalize_type=args.pf_type,
         fused_experts_type=args.experts_type,
-        fused_moe_chunk_size=args.fused_moe_chunk_size,
         world_size=args.world_size,
         torch_trace_dir_path=args.torch_trace_dir_path,
     )

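As a reading aid, a self-contained reconstruction of the argument style that remains after the `--fused-moe-chunk-size` flag is dropped; the parser below is illustrative and not the real cli_args.py module.

# Stand-alone, illustrative reconstruction of the surviving argument pattern;
# not the actual cli_args.py. The --fused-moe-chunk-size flag no longer exists.
import argparse

parser = argparse.ArgumentParser(description="modular kernel test config (sketch)")
parser.add_argument("--num-experts", type=int, default=32, help="Global num experts")
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")

args = parser.parse_args(["--topk", "4", "2"])
print(args.num_experts, args.topk)  # 32 [4, 2]
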
tests/kernels/moe/modular_kernel_tools/common.py

Lines changed: 0 additions & 15 deletions
@@ -68,7 +68,6 @@ class Config:
     prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
     fused_experts_type: mk.FusedMoEExperts
 
-    fused_moe_chunk_size: int | None
     world_size: int
 
     torch_trace_dir_path: str | None = None
@@ -89,7 +88,6 @@ def describe(self) -> str:
         s += f" K={self.K}\n"
         s += f" topk={self.topks}\n"
         s += f" dtype={self.dtype}\n"
-        s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
         s += " Quant:\n"
         if self.quant_config is not None:
             s += f" q_dtype={self.quant_dtype}\n"
@@ -152,11 +150,6 @@ def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
 
         vllm_config.parallel_config.all2all_backend = self.all2all_backend()
 
-        if self.fused_moe_chunk_size is not None:
-            env_dict.update(
-                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
-            )
-
         return vllm_config, env_dict
 
     def is_fp8_block_quantized(self):
@@ -189,10 +182,6 @@ def is_block_quant_supported(self):
         info = expert_info(self.fused_experts_type)
         return info.blocked_quantization_support
 
-    def is_fe_supports_chunking(self):
-        info = expert_info(self.fused_experts_type)
-        return info.supports_chunking
-
     def supports_expert_map(self):
         info = expert_info(self.fused_experts_type)
         return info.supports_expert_map
@@ -233,10 +222,6 @@ def is_valid(self) -> tuple[bool, str | None]:
         if not self.is_standard_fused_experts():
             return False, "Mismatched format."
 
-        use_chunking = self.fused_moe_chunk_size is not None
-        if use_chunking and not self.is_fe_supports_chunking():
-            return False, "Chunking not supported."
-
         # Check quantization sanity
         if (
             int(self.is_per_act_token_quant)

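The helpers that survive in Config all follow the same delegate-to-expert_info pattern that `is_fe_supports_chunking()` used to follow: look up an ExpertInfo record for the fused-experts type and return one of its flags. A compact, self-contained sketch of that pattern (the registry, record fields, and the toy `is_valid()` rule are stand-ins, not the real common.py logic):

# Sketch of the capability-query pattern used by Config; all names here are
# illustrative stand-ins for the real expert_info()/Config machinery.
from dataclasses import dataclass


@dataclass(frozen=True)
class ExpertInfoSketch:
    blocked_quantization_support: bool
    supports_expert_map: bool


_REGISTRY: dict[str, ExpertInfoSketch] = {
    "standard_experts": ExpertInfoSketch(True, True),
    "batched_experts": ExpertInfoSketch(True, False),
}


def expert_info_sketch(kind: str) -> ExpertInfoSketch:
    return _REGISTRY[kind]


@dataclass
class ConfigSketch:
    fused_experts_type: str

    def supports_expert_map(self) -> bool:
        return expert_info_sketch(self.fused_experts_type).supports_expert_map

    def is_valid(self) -> tuple[bool, str | None]:
        # Toy rule only: with chunking removed, validity is driven purely by
        # the remaining capability flags.
        if not self.supports_expert_map():
            return False, "Expert map not supported."
        return True, None


print(ConfigSketch("standard_experts").is_valid())  # (True, None)
print(ConfigSketch("batched_experts").is_valid())   # (False, 'Expert map not supported.')
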
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py

Lines changed: 0 additions & 7 deletions
@@ -42,12 +42,6 @@ def rank_worker(
 ):
     set_random_seed(pgi.rank)
 
-    # sanity check
-    from vllm import envs
-
-    if config.fused_moe_chunk_size is not None:
-        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
-
     # get weights to this device
     weights.to_current_device()
 
@@ -135,7 +129,6 @@ def add_to_results(
         fused_experts_type=experts_type,
         quant_config=quant_config,
         world_size=2,
-        fused_moe_chunk_size=None,
     )
 
     success = None

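The deleted sanity check was the worker-side half of an env-var round trip: the parent encodes the chunk size into an env_dict (see the common.py hunk above), and the spawned rank worker re-reads the variable and asserts it matches its Config. A generic sketch of that round trip using plain os.environ, with no vLLM specifics:

# Generic sketch of the env-var round trip the removed sanity check relied on;
# plain os.environ is used here instead of vLLM's envs module.
import os


def make_env_data(chunk_size: int | None) -> dict[str, str]:
    # Parent side: encode the override for the worker environment.
    return {} if chunk_size is None else {"VLLM_FUSED_MOE_CHUNK_SIZE": str(chunk_size)}


def rank_worker(expected_chunk_size: int | None) -> None:
    # Worker side: confirm the parent's override actually propagated.
    if expected_chunk_size is not None:
        assert os.environ.get("VLLM_FUSED_MOE_CHUNK_SIZE") == str(expected_chunk_size)


os.environ.update(make_env_data(1024))  # stand-in for spawning workers with env_dict
rank_worker(1024)
print("env round trip ok")
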
tests/kernels/moe/modular_kernel_tools/mk_objects.py

Lines changed: 0 additions & 14 deletions
@@ -64,7 +64,6 @@ class ExpertInfo:
     activation_format: mk.FusedMoEActivationFormat
     supported_dtypes: list[torch.dtype | str]
     blocked_quantization_support: bool
-    supports_chunking: bool
     supports_expert_map: bool
     needs_matching_quant: bool = False
     needs_deep_gemm: bool = False
@@ -127,7 +126,6 @@ def register_experts(
     activation_format: mk.FusedMoEActivationFormat,
     supported_dtypes: list[torch.dtype | str],
     blocked_quantization_support: bool,
-    supports_chunking: bool,
     supports_expert_map: bool,
     needs_matching_quant: bool = False,
     needs_deep_gemm: bool = False,
@@ -141,7 +139,6 @@ def register_experts(
         activation_format,
         supported_dtypes,
         blocked_quantization_support,
-        supports_chunking,
         supports_expert_map,
         needs_matching_quant,
         needs_deep_gemm,
@@ -176,7 +173,6 @@ def expert_info(kind) -> ExpertInfo:
         batched_format,
         common_float_types,
         blocked_quantization_support=True,
-        supports_chunking=False,
         supports_expert_map=False,
         needs_matching_quant=True,
     )
@@ -186,7 +182,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         common_float_and_int_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         supports_expert_map=True,
         needs_matching_quant=True,
     )
@@ -196,7 +191,6 @@ def expert_info(kind) -> ExpertInfo:
         batched_format,
         common_float_and_int_types,
         blocked_quantization_support=True,
-        supports_chunking=False,
         supports_expert_map=True,
     )
 
@@ -262,7 +256,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         nvfp4_types + fp8_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         # Note: this is a hack to get it to run for now
         supports_expert_map=True,
     )
@@ -281,7 +274,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         fp8_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         supports_expert_map=True,
         needs_aiter=True,
     )
@@ -294,7 +286,6 @@ def expert_info(kind) -> ExpertInfo:
         batched_format,
         fp8_types,
         blocked_quantization_support=True,
-        supports_chunking=False,
         supports_expert_map=False,
         needs_matching_quant=False,
         needs_deep_gemm=True,
@@ -304,7 +295,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         fp8_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         supports_expert_map=True,
         needs_matching_quant=False,
         needs_deep_gemm=True,
@@ -314,7 +304,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         common_float_and_int_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         supports_expert_map=True,
         needs_matching_quant=True,
         needs_deep_gemm=True,
@@ -331,15 +320,13 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         fp8_types,
         blocked_quantization_support=False,
-        supports_chunking=True,
         supports_expert_map=False,
     )
     register_experts(
         CutlassBatchedExpertsFp8,
         batched_format,
         fp8_types,
         blocked_quantization_support=False,
-        supports_chunking=False,
         supports_expert_map=False,
     )
 else:
@@ -354,7 +341,6 @@ def expert_info(kind) -> ExpertInfo:
         standard_format,
         nvfp4_types,
         blocked_quantization_support=True,
-        supports_chunking=True,
         supports_expert_map=False,
     )
 else:

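Tying the hunks above together: the ExpertInfo record loses its supports_chunking field and every register_experts call drops the matching keyword. A self-contained sketch of the trimmed record and registration flow (the registry, the DemoExperts class, and the omission of supported_dtypes are simplifications, not the real mk_objects.py):

# Sketch of the trimmed ExpertInfo / register_experts flow; the registry and
# DemoExperts are hypothetical stand-ins, and supported_dtypes is omitted.
from dataclasses import dataclass
from enum import Enum


class ActivationFormat(Enum):
    STANDARD = "standard"
    BATCHED = "batched"


@dataclass(frozen=True)
class ExpertInfo:
    activation_format: ActivationFormat
    blocked_quantization_support: bool
    supports_expert_map: bool
    needs_matching_quant: bool = False
    needs_deep_gemm: bool = False
    # No supports_chunking field anymore.


_EXPERT_REGISTRY: dict[type, ExpertInfo] = {}


def register_experts(
    kind: type,
    activation_format: ActivationFormat,
    blocked_quantization_support: bool,
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
) -> None:
    _EXPERT_REGISTRY[kind] = ExpertInfo(
        activation_format,
        blocked_quantization_support,
        supports_expert_map,
        needs_matching_quant,
        needs_deep_gemm,
    )


def expert_info(kind: type) -> ExpertInfo:
    return _EXPERT_REGISTRY[kind]


class DemoExperts:  # stand-in for a concrete experts implementation
    pass


register_experts(
    DemoExperts,
    ActivationFormat.STANDARD,
    blocked_quantization_support=True,
    supports_expert_map=True,
)
print(expert_info(DemoExperts))
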
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py

Lines changed: 0 additions & 6 deletions
@@ -85,12 +85,6 @@ def rank_worker(
 ):
     set_random_seed(pgi.rank)
 
-    # sanity check
-    from vllm import envs
-
-    if config.fused_moe_chunk_size is not None:
-        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
-
     # get weights to this device
     weights.to_current_device()

tests/kernels/moe/test_block_fp8.py

Lines changed: 1 addition & 8 deletions
@@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe(
 
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
-
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
@@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     if not _valid_deep_gemm_shape(M, N, K):
         pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
 
-    chunk_size = 1024
-
     torch.manual_seed(seed)
 
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
     block_size = get_mk_alignment_for_contiguous_layout()
     dtype = torch.bfloat16
 
@@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     # setup code in case we are able to revisit this later.
     use_compile = False
 
-    use_cudagraph = (
-        chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
-    )
+    use_cudagraph = N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
 
     topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

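With the chunk_size < M term removed, use_cudagraph in this test now depends only on problem shape and platform. For readers unfamiliar with what that flag gates, a generic sketch of PyTorch's CUDA-graph capture-and-replay pattern follows; it uses plain matmuls, not the actual fused-MoE call, and only runs when CUDA is available:

# Generic CUDA-graph capture/replay sketch (plain matmul, not the real test).
import torch

if torch.cuda.is_available():
    x = torch.randn(128, 128, device="cuda")
    w = torch.randn(128, 128, device="cuda")

    # Warm up on a side stream before capture, as the CUDA-graph docs advise.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = x @ w
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y = x @ w  # work captured into the graph

    x.copy_(torch.randn_like(x))  # mutate the static input in place
    g.replay()                    # re-run the captured kernels
    print(float(y.sum()))
else:
    print("CUDA not available; skipping graph demo")
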
tests/kernels/moe/test_cutlass_moe.py

Lines changed: 0 additions & 2 deletions
@@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph(
     ep_size: int | None = None,
 ):
     set_random_seed(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
 
@@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
     workspace_init,
 ):
     set_random_seed(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         dtype = torch.half

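These tests change only by dropping the VLLM_FUSED_MOE_CHUNK_SIZE override. For reference, a minimal sketch of the pytest monkeypatch.setenv pattern the removed lines used; SOME_ENV_OVERRIDE is a placeholder name, not a real vLLM setting:

# Minimal pytest sketch of the monkeypatch.setenv pattern; SOME_ENV_OVERRIDE
# is a placeholder, and the fixture restores the environment after the test.
import os


def test_env_override(monkeypatch):
    monkeypatch.setenv("SOME_ENV_OVERRIDE", "8192")
    assert os.environ["SOME_ENV_OVERRIDE"] == "8192"
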
tests/kernels/moe/test_flashinfer.py

Lines changed: 0 additions & 2 deletions
@@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     if not current_platform.has_device_capability(100):
         pytest.skip("Test is only supported for sm >= 100")
     set_random_seed(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         td = TestData.make_moe_tensors_8bit(
             m, k, n, e, is_trtllm=True, activation=activation
@@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     workspace_init,
 ):
     set_random_seed(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         td = TestData.make_moe_tensors_8bit(
             m, k, n, e, is_trtllm=False, activation=activation

0 commit comments
