 from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.utils import get_logger

-from ..moe import DlinferMoeMetada, DlinferMoeType
+from ..moe import DlinferMoECommType, DlinferMoeMetadata
 from ..op_backend import DlinferOpsBackend

 logger = get_logger('lmdeploy')
@@ -281,19 +281,19 @@ def get_dist_meta():
         def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
             if ep_size <= 1:
                 return 0, 0, 0
-            # get runtime_tokens_current_rank
+            # get padded_tokens_current_rank
             is_graph = cls.enable_graph and step_context.is_decoding
             if is_graph:
                 from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
                 actual_tokens_current_rank = step_context.q_seqlens.shape[0]
-                runtime_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
-                                                  cls.max_batches)
+                padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
+                                                 cls.max_batches)
             else:
                 actual_tokens_current_rank = step_context.q_seqlens.sum().item()
-                runtime_tokens_current_rank = actual_tokens_current_rank
+                padded_tokens_current_rank = actual_tokens_current_rank
             # get max_tokens_across_dp
             if dp_size > 1:
-                runtime_tokens_tensor = torch.tensor([runtime_tokens_current_rank],
+                runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank],
                                                      dtype=step_context.q_seqlens.dtype,
                                                      device=torch.npu.current_device())
                 world_size = dp_size * tp_size
@@ -303,49 +303,49 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
                 dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group)
                 max_tokens_across_dp = torch.max(runtime_tokens_buffer).item()
             else:
-                max_tokens_across_dp = runtime_tokens_current_rank
-            return actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp
+                max_tokens_across_dp = padded_tokens_current_rank
+            return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp

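Taken together, get_tokens_info distinguishes three counts: the tokens this rank actually holds, a padded per-rank count (rounded up to a CUDA-graph-capturable size in graph mode, capped at cls.max_batches), and the maximum padded count across all DP ranks, obtained via the all-gather above. A minimal standalone sketch of that arithmetic; the rounding granularity and per-rank values are illustrative assumptions, and round_up merely stands in for get_ascend_compatible_size:

# Standalone sketch of the token accounting; all concrete numbers are assumed.
def round_up(n, multiple):
    # stand-in for get_ascend_compatible_size: next graph-capturable size
    return (n + multiple - 1) // multiple * multiple

actual_tokens = 13                           # decode: one token per sequence
padded_tokens = round_up(actual_tokens, 16)  # 16, assuming 16-aligned graph sizes
gathered = [padded_tokens, 32, 16, 8]        # padded counts from every DP rank
max_tokens_across_dp = max(gathered)         # 32: the size every rank pads toward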
         @lru_cache
         def init_mc2_token_capacity(tp_size):
             max_num_tokens = min(cls.max_batches, 512)
             num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
             return num_tokens_per_tp_rank * tp_size

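init_mc2_token_capacity fixes the ceiling for the MC2 path: the largest per-step token count it will accept, rounded up so it divides evenly across TP ranks; lru_cache makes this a one-time computation per tp_size. A quick worked example with assumed values:

# Worked example of the MC2 capacity computation (values are illustrative).
max_batches = 300                                                   # assumed cls.max_batches
tp_size = 8
max_num_tokens = min(max_batches, 512)                              # 300
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size  # ceil(300 / 8) = 38
capacity = num_tokens_per_tp_rank * tp_size                         # 304, a multiple of tp_size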
-        def select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
+        def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
             if ep_size <= 1:
-                return DlinferMoeType.ALLGATHER
+                return DlinferMoECommType.ALLGATHER
             mc2_token_capacity = init_mc2_token_capacity(tp_size)
             is_graph = cls.enable_graph and step_context.is_decoding
             if is_graph:
                 import math
                 max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
             if SocVersion.is_A2():
                 if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:
-                    return DlinferMoeType.MC2
+                    return DlinferMoECommType.MC2
                 else:
-                    return DlinferMoeType.ALLGATHER
+                    return DlinferMoECommType.ALLGATHER
             elif SocVersion.is_A3():
                 if max_tokens_across_dp <= mc2_token_capacity:
-                    return DlinferMoeType.MC2
+                    return DlinferMoECommType.MC2
                 else:
-                    return DlinferMoeType.ALLTOALL
+                    return DlinferMoECommType.ALLTOALL
             else:
                 raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')

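Read across the branches, the dispatch rule is: no expert parallelism always means ALLGATHER; otherwise MC2 is chosen whenever the DP-wide peak token count (aligned up to a multiple of tp_size in graph mode) fits the MC2 capacity, with an extra dp_size * tp_size >= 16 gate on A2 SoCs; the overflow fallback is ALLGATHER on A2 and ALLTOALL on A3. A condensed restatement of the same decision, with plain strings standing in for the DlinferMoECommType enum (the unsupported-SoC error branch is omitted):

# Condensed restatement of select_moe_comm_type; strings stand in for the enum.
import math

def pick_comm_type(max_tokens, dp_size, tp_size, ep_size, capacity, is_a2, is_graph):
    if ep_size <= 1:
        return 'allgather'
    if is_graph:
        max_tokens = math.ceil(max_tokens / tp_size) * tp_size  # align to TP shards
    if is_a2:
        return 'mc2' if (max_tokens <= capacity and dp_size * tp_size >= 16) else 'allgather'
    return 'mc2' if max_tokens <= capacity else 'alltoall'      # A3 path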
-        def get_pad_info(actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp, tp_size,
-                         moe_type):
+        def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size,
+                         moe_comm_type):
             x_active_mask = None
-            if moe_type == DlinferMoeType.MC2:
+            if moe_comm_type == DlinferMoECommType.MC2:
                 paded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size
-                pad_size = paded_size - runtime_tokens_current_rank
+                pad_size = paded_size - padded_tokens_current_rank
                 x_active_mask = torch.ones(actual_tokens_current_rank,
                                            dtype=torch.bool,
                                            device=torch.npu.current_device())
-            elif moe_type == DlinferMoeType.ALLTOALL:
-                pad_size = tp_size - runtime_tokens_current_rank
-            elif moe_type == DlinferMoeType.ALLGATHER:
-                pad_size = max_tokens_across_dp - runtime_tokens_current_rank
+            elif moe_comm_type == DlinferMoECommType.ALLTOALL:
+                pad_size = tp_size - padded_tokens_current_rank
+            elif moe_comm_type == DlinferMoECommType.ALLGATHER:
+                pad_size = max_tokens_across_dp - padded_tokens_current_rank
             else:
                 pad_size = 0
             return pad_size, x_active_mask
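get_pad_info then computes how many filler tokens bring this rank up to the shape the chosen kernel expects: for MC2, a tp_size-aligned share of the DP-wide maximum, with x_active_mask flagging which rows are real; tp_size for ALLTOALL; and the DP-wide maximum itself for ALLGATHER. A worked example of the MC2 branch, with assumed sizes:

# Worked example of the MC2 branch of get_pad_info (numbers are illustrative).
import math

tp_size = 4
max_tokens_across_dp = 30
padded_tokens_current_rank = 16
actual_tokens_current_rank = 13

paded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size  # 32
pad_size = paded_size - padded_tokens_current_rank                # 16 filler rows
# x_active_mask would hold 13 True entries, so MC2 skips the padding rows.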
@@ -404,15 +404,15 @@ def get_moe_group_name(group):
         step_context.attn_metadata = attn_metadata

         cls.dist_meta = get_dist_meta()
-        actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
+        actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
             cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group)
-        moe_type = select_moe_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
-                                   cls.dist_meta.ep_size)
-        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, runtime_tokens_current_rank,
-                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_type)
+        moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
+                                             cls.dist_meta.ep_size)
+        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank,
+                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type)
         moe_group_name = get_moe_group_name(cls.dist_meta.ep_group)

-        moe_metadata = DlinferMoeMetada(
+        moe_metadata = DlinferMoeMetadata(
             max_tokens_across_dp=max_tokens_across_dp,
             pad_size=pad_size,
             dp_size=cls.dist_meta.dp_size,
@@ -422,7 +422,7 @@ def get_moe_group_name(group):
             ep_rank=cls.dist_meta.ep_rank,
             tp_group=cls.dist_meta.tp_group,
             ep_group=cls.dist_meta.ep_group,
-            moe_type=moe_type,
+            moe_comm_type=moe_comm_type,
             x_active_mask=x_active_mask,
             moe_group_name=moe_group_name,
         )
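From this call site alone, DlinferMoeMetadata behaves as a plain container handed to the MoE layers at dispatch time. A hypothetical sketch of its shape, inferred only from the visible keyword arguments (the diff elides a few fields between dp_size and ep_rank, and every type annotation here is an assumption):

# Hypothetical shape of DlinferMoeMetadata, inferred from the call site above.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class DlinferMoeMetadataSketch:
    max_tokens_across_dp: int
    pad_size: int
    dp_size: int
    # ... fields elided by the diff hunk ...
    ep_rank: int
    tp_group: Any                   # distributed process-group handle
    ep_group: Any
    moe_comm_type: Any              # a DlinferMoECommType member
    x_active_mask: Optional[Any]    # boolean tensor marking real tokens, or None
    moe_group_name: str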