Commit 700db7d

update code
1 parent 41613f2 commit 700db7d

7 files changed: 42 additions & 40 deletions

docker/Dockerfile_ascend_a3

Lines changed: 1 addition & 0 deletions
@@ -22,5 +22,6 @@ ARG LMDEPLOY_TAG=main
 RUN --mount=type=cache,target=/root/.cache \
     pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
     pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn && \
+    pip install --no-cache-dir torch==2.9.0 torch-npu==2.9.0 torchvision==0.24.0 && \
     TORCH_DEVICE_BACKEND_AUTOLOAD=0 DEVICE=ascend pip install git+https://github.com/DeepLink-org/dlinfer.git@${DLINFER_TAG} && \
     LMDEPLOY_TARGET_DEVICE=ascend pip install git+https://github.com/InternLM/lmdeploy.git@${LMDEPLOY_TAG}
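
The added line pins torch, torch-npu, and torchvision before the dlinfer and lmdeploy source installs, so those builds resolve against fixed wheel versions rather than whatever pip would otherwise pick. A minimal sanity check for the built image (a hypothetical snippet, not part of this commit; torch_npu only initializes fully on an Ascend device):

```python
# Hypothetical smoke test for the pinned stack; run inside the built image.
import torch
import torchvision

assert torch.__version__.startswith('2.9.0'), torch.__version__
assert torchvision.__version__.startswith('0.24.0'), torchvision.__version__

# torch_npu needs an Ascend runtime; guard the import accordingly.
try:
    import torch_npu  # noqa: F401
    print('torch_npu loaded:', torch.npu.is_available())
except Exception as exc:
    print('torch_npu unavailable outside an Ascend environment:', exc)
```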

docs/zh_cn/supported_models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@
 | QWen2.5-VL | 3B - 72B | MLLM | Yes | Yes | - | - | Yes | - | Yes | No |
 | QWen2-MoE | A14.57B | LLM | Yes | - | No | No | - | - | Yes | - |
 | QWen3 | 0.6B-235B | LLM | Yes | Yes | No | No | Yes | Yes | Yes | Yes |
-| DeepSeek-V2 | 16B | LLM | Yes | Yes | No | No | - | - | - | - |
+| DeepSeek-V2 | 16B | LLM | No | Yes | No | No | - | - | - | - |
 | InternVL(v1.5) | 2B-26B | MLLM | Yes | - | Yes | Yes | - | - | Yes | - |
 | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes |
 | InternVL2.5 | 1B-78B | MLLM | Yes | Yes | Yes | Yes | Yes | - | Yes | Yes |

lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 29 additions & 29 deletions
@@ -16,7 +16,7 @@
 from lmdeploy.pytorch.distributed import get_dist_manager
 from lmdeploy.utils import get_logger

-from ..moe import DlinferMoeMetada, DlinferMoeType
+from ..moe import DlinferMoECommType, DlinferMoeMetadata
 from ..op_backend import DlinferOpsBackend

 logger = get_logger('lmdeploy')
@@ -281,19 +281,19 @@ def get_dist_meta():
         def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
             if ep_size <= 1:
                 return 0, 0, 0
-            # get runtime_tokens_current_rank
+            # get padded_tokens_current_rank
             is_graph = cls.enable_graph and step_context.is_decoding
             if is_graph:
                 from dlinfer.framework.lmdeploy_ext.cudagraph.ascend_cudagraph import get_ascend_compatible_size
                 actual_tokens_current_rank = step_context.q_seqlens.shape[0]
-                runtime_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
-                                                  cls.max_batches)
+                padded_tokens_current_rank = min(get_ascend_compatible_size(actual_tokens_current_rank),
+                                                 cls.max_batches)
             else:
                 actual_tokens_current_rank = step_context.q_seqlens.sum().item()
-                runtime_tokens_current_rank = actual_tokens_current_rank
+                padded_tokens_current_rank = actual_tokens_current_rank
             # get max_tokens_across_dp
             if dp_size > 1:
-                runtime_tokens_tensor = torch.tensor([runtime_tokens_current_rank],
+                runtime_tokens_tensor = torch.tensor([padded_tokens_current_rank],
                                                      dtype=step_context.q_seqlens.dtype,
                                                      device=torch.npu.current_device())
                 world_size = dp_size * tp_size
@@ -303,49 +303,49 @@ def get_tokens_info(dp_size, tp_size, ep_size, ep_group):
                 dist.all_gather_into_tensor(runtime_tokens_buffer, runtime_tokens_tensor, ep_group)
                 max_tokens_across_dp = torch.max(runtime_tokens_buffer).item()
             else:
-                max_tokens_across_dp = runtime_tokens_current_rank
-            return actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp
+                max_tokens_across_dp = padded_tokens_current_rank
+            return actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp

         @lru_cache
         def init_mc2_token_capacity(tp_size):
             max_num_tokens = min(cls.max_batches, 512)
             num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
             return num_tokens_per_tp_rank * tp_size

-        def select_moe_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
+        def select_moe_comm_type(max_tokens_across_dp, dp_size, tp_size, ep_size):
             if ep_size <= 1:
-                return DlinferMoeType.ALLGATHER
+                return DlinferMoECommType.ALLGATHER
             mc2_token_capacity = init_mc2_token_capacity(tp_size)
             is_graph = cls.enable_graph and step_context.is_decoding
             if is_graph:
                 import math
                 max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
             if SocVersion.is_A2():
                 if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:
-                    return DlinferMoeType.MC2
+                    return DlinferMoECommType.MC2
                 else:
-                    return DlinferMoeType.ALLGATHER
+                    return DlinferMoECommType.ALLGATHER
             elif SocVersion.is_A3():
                 if max_tokens_across_dp <= mc2_token_capacity:
-                    return DlinferMoeType.MC2
+                    return DlinferMoECommType.MC2
                 else:
-                    return DlinferMoeType.ALLTOALL
+                    return DlinferMoECommType.ALLTOALL
             else:
                 raise ValueError(f'Unsupported soc_version: {SocVersion.soc_version()}')

-        def get_pad_info(actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp, tp_size,
-                         moe_type):
+        def get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp, tp_size,
+                         moe_comm_type):
             x_active_mask = None
-            if moe_type == DlinferMoeType.MC2:
+            if moe_comm_type == DlinferMoECommType.MC2:
                 paded_size = math.ceil(max_tokens_across_dp / tp_size) * tp_size
-                pad_size = paded_size - runtime_tokens_current_rank
+                pad_size = paded_size - padded_tokens_current_rank
                 x_active_mask = torch.ones(actual_tokens_current_rank,
                                            dtype=torch.bool,
                                            device=torch.npu.current_device())
-            elif moe_type == DlinferMoeType.ALLTOALL:
-                pad_size = tp_size - runtime_tokens_current_rank
-            elif moe_type == DlinferMoeType.ALLGATHER:
-                pad_size = max_tokens_across_dp - runtime_tokens_current_rank
+            elif moe_comm_type == DlinferMoECommType.ALLTOALL:
+                pad_size = tp_size - padded_tokens_current_rank
+            elif moe_comm_type == DlinferMoECommType.ALLGATHER:
+                pad_size = max_tokens_across_dp - padded_tokens_current_rank
             else:
                 pad_size = 0
             return pad_size, x_active_mask
@@ -404,15 +404,15 @@ def get_moe_group_name(group):
         step_context.attn_metadata = attn_metadata

         cls.dist_meta = get_dist_meta()
-        actual_tokens_current_rank, runtime_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
+        actual_tokens_current_rank, padded_tokens_current_rank, max_tokens_across_dp = get_tokens_info(
             cls.dist_meta.dp_size, cls.dist_meta.tp_size, cls.dist_meta.ep_size, cls.dist_meta.ep_group)
-        moe_type = select_moe_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
-                                   cls.dist_meta.ep_size)
-        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, runtime_tokens_current_rank,
-                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_type)
+        moe_comm_type = select_moe_comm_type(max_tokens_across_dp, cls.dist_meta.dp_size, cls.dist_meta.tp_size,
+                                             cls.dist_meta.ep_size)
+        pad_size, x_active_mask = get_pad_info(actual_tokens_current_rank, padded_tokens_current_rank,
+                                               max_tokens_across_dp, cls.dist_meta.tp_size, moe_comm_type)
         moe_group_name = get_moe_group_name(cls.dist_meta.ep_group)

-        moe_metadata = DlinferMoeMetada(
+        moe_metadata = DlinferMoeMetadata(
             max_tokens_across_dp=max_tokens_across_dp,
             pad_size=pad_size,
             dp_size=cls.dist_meta.dp_size,
@@ -422,7 +422,7 @@ def get_moe_group_name(group):
             ep_rank=cls.dist_meta.ep_rank,
             tp_group=cls.dist_meta.tp_group,
             ep_group=cls.dist_meta.ep_group,
-            moe_type=moe_type,
+            moe_comm_type=moe_comm_type,
             x_active_mask=x_active_mask,
             moe_group_name=moe_group_name,
         )
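
Beyond the renames (runtime_tokens_current_rank → padded_tokens_current_rank, moe_type → moe_comm_type, select_moe_type → select_moe_comm_type), the selection logic is unchanged: MC2 when the padded token count fits the MC2 capacity (with an extra dp_size * tp_size >= 16 gate on A2 SoCs), otherwise ALLGATHER on A2 and ALLTOALL on A3. A self-contained sketch of that decision table (a hypothetical standalone form of the nested select_moe_comm_type above; the enum only mirrors dlinfer's MoECommType for illustration):

```python
import math
from enum import Enum


class MoECommType(Enum):
    """Illustration-only mirror of dlinfer's MoECommType."""
    ALLGATHER = 'allgather'
    ALLTOALL = 'alltoall'
    MC2 = 'mc2'


def select_moe_comm_type(max_tokens_across_dp: int, dp_size: int, tp_size: int, ep_size: int,
                         mc2_token_capacity: int, is_a2: bool, is_graph: bool) -> MoECommType:
    # No expert parallelism: the plain allgather path is always used.
    if ep_size <= 1:
        return MoECommType.ALLGATHER
    # Graph mode first pads the token count up to a multiple of tp_size.
    if is_graph:
        max_tokens_across_dp = math.ceil(max_tokens_across_dp / tp_size) * tp_size
    if is_a2:
        # A2 additionally requires at least 16 ranks for MC2.
        if max_tokens_across_dp <= mc2_token_capacity and dp_size * tp_size >= 16:
            return MoECommType.MC2
        return MoECommType.ALLGATHER
    # A3: fall back to all-to-all when tokens exceed the MC2 capacity.
    if max_tokens_across_dp <= mc2_token_capacity:
        return MoECommType.MC2
    return MoECommType.ALLTOALL


# Example: graph-mode decode on A3 with tokens above the MC2 capacity.
print(select_moe_comm_type(max_tokens_across_dp=600, dp_size=2, tp_size=8, ep_size=16,
                           mc2_token_capacity=512, is_a2=False, is_graph=True))
# MoECommType.ALLTOALL
```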

lmdeploy/pytorch/backends/dlinfer/moe.py

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@

 import torch

-from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetada  # noqa: F401
-from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeType  # noqa: F401
+from lmdeploy.pytorch.kernels.dlinfer import DlinferMoECommType  # noqa: F401
+from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetadata  # noqa: F401
 from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
 from lmdeploy.pytorch.model_inputs import get_step_ctx_manager

lmdeploy/pytorch/kernels/dlinfer/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 from .awq_kernels import awq_linear
 from .fill_kv_cache import fill_kv_cache
 from .flash_attention import flash_attention_fwd
-from .fused_moe import DlinferMoeMetada, DlinferMoeType, fused_moe
+from .fused_moe import DlinferMoECommType, DlinferMoeMetadata, fused_moe
 from .linear import linear
 from .moe_gating_topk_softmax import moe_gating_topk_softmax
 from .pagedattention import paged_attention_fwd
@@ -15,8 +15,8 @@
     'apply_rotary_pos_emb',
     'awq_linear',
     'fill_kv_cache',
-    'DlinferMoeType',
-    'DlinferMoeMetada',
+    'DlinferMoECommType',
+    'DlinferMoeMetadata',
     'fused_moe',
     'paged_attention_fwd',
     'flash_attention_fwd',
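
Anything importing the old names from lmdeploy.pytorch.kernels.dlinfer must move to the new ones; a minimal before/after of the import surface (assuming lmdeploy and dlinfer are installed at matching post-rename versions):

```python
# Before this commit (no longer valid):
#   from lmdeploy.pytorch.kernels.dlinfer import DlinferMoeMetada, DlinferMoeType
# After:
from lmdeploy.pytorch.kernels.dlinfer import (DlinferMoECommType, DlinferMoeMetadata, fused_moe,
                                              moe_gating_topk_softmax)

print(DlinferMoECommType.ALLGATHER)  # enum member used by the Ascend backend above
```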

lmdeploy/pytorch/kernels/dlinfer/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import dlinfer.ops as ext_ops
-from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetada
-from dlinfer.utils.type_annotation import MoeType as DlinferMoeType  # noqa: F401
+from dlinfer.utils.type_annotation import MoECommType as DlinferMoECommType  # noqa: F401
+from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
 from torch import Tensor


@@ -13,7 +13,7 @@ def fused_moe(
     topk_ids: Tensor,
     topk: int,
     renormalize: bool,
-    moe_metadata: DlinferMoeMetada,
+    moe_metadata: DlinferMoeMetadata,
 ):
     """Dlinfer fused moe."""
     return ext_ops.fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights, topk_ids, topk, renormalize,
lmdeploy/pytorch/kernels/dlinfer/moe_gating_topk_softmax.py

Lines changed: 3 additions & 2 deletions

@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import dlinfer.ops as ext_ops
-from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetada
+from dlinfer.utils.type_annotation import MoeMetadata as DlinferMoeMetadata
 from torch import Tensor


-def moe_gating_topk_softmax(router_logits: Tensor, topk: int, moe_metadata: DlinferMoeMetada) -> tuple[Tensor, Tensor]:
+def moe_gating_topk_softmax(router_logits: Tensor, topk: int,
+                            moe_metadata: DlinferMoeMetadata) -> tuple[Tensor, Tensor]:
     routing_weights, selected_experts = ext_ops.moe_gating_topk_softmax(router_logits, topk, moe_metadata)
     return routing_weights, selected_experts
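
The wrapper delegates entirely to ext_ops.moe_gating_topk_softmax, which is device-optimized and metadata-aware. For intuition only, a plain-torch sketch of standard MoE gating semantics — softmax over experts followed by top-k selection (an assumption about the kernel's high-level behavior, not a drop-in replacement):

```python
import torch


def gating_topk_softmax_reference(router_logits: torch.Tensor, topk: int):
    """CPU reference for standard MoE gating; the real kernel may differ in details."""
    routing_weights = torch.softmax(router_logits, dim=-1)
    # Keep the top-k experts per token along the expert dimension.
    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    return routing_weights, selected_experts


logits = torch.randn(4, 8)  # [num_tokens, num_experts]
weights, experts = gating_topk_softmax_reference(logits, topk=2)
print(weights.shape, experts.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```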
