From d24aa6a26a2d0dd6026d6414f525c77e1c6c7127 Mon Sep 17 00:00:00 2001 From: addsubmuldiv Date: Tue, 3 Mar 2026 11:33:26 +0800 Subject: [PATCH] npu_megatron_adapt --- src/twinkle/metric/loss.py | 44 +++++++--- src/twinkle/model/megatron/args.py | 27 ++++-- src/twinkle/model/megatron/megatron.py | 88 +++++++++++++++---- .../model/megatron/model/gpt_bridge.py | 16 +++- src/twinkle/model/megatron/model/gpt_model.py | 11 ++- 5 files changed, 147 insertions(+), 39 deletions(-) diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index b15f1f96..9f6c7d41 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -1,6 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from typing import List, Union +import torch.distributed as dist + +from transformers.utils import is_torch_npu_available + +from twinkle import Platform from twinkle.data_format import InputFeature, ModelOutput from .base import Metric @@ -46,19 +51,36 @@ def reset(self): self.num_tokens = 0 def calculate(self): - local_results = [{ - 'loss': self.total_loss, - 'count': self.total_count, - 'grad_norm': self.grad_norm, - 'num_tokens': self.num_tokens - }] + total_loss = float(self.total_loss) + total_count = float(self.total_count) + grad_norm = float(self.grad_norm) + num_tokens = float(self.num_tokens) - all_results = self.gather_results(local_results) + if self.device_mesh is not None and self.process_group is not None and dist.is_initialized(): + is_npu = is_torch_npu_available() + if is_npu: + # On NPU/HCCL, tensor all_reduce is more stable than all_gather_object. + import torch + device = Platform.get_local_device() + stats = torch.tensor([total_loss, total_count, num_tokens], dtype=torch.float64, device=device) + dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=self.process_group) + total_loss, total_count, num_tokens = stats.tolist() - total_loss = sum(r['loss'] for r in all_results) - total_count = sum(r['count'] for r in all_results) - grad_norm = max(r['grad_norm'] for r in all_results) - num_tokens = sum(r['num_tokens'] for r in all_results) + grad_tensor = torch.tensor([grad_norm], dtype=torch.float64, device=device) + dist.all_reduce(grad_tensor, op=dist.ReduceOp.MAX, group=self.process_group) + grad_norm = grad_tensor.item() + else: + local_results = [{ + 'loss': total_loss, + 'count': total_count, + 'grad_norm': grad_norm, + 'num_tokens': num_tokens + }] + all_results = self.gather_results(local_results) + total_loss = sum(r['loss'] for r in all_results) + total_count = sum(r['count'] for r in all_results) + grad_norm = max(r['grad_norm'] for r in all_results) + num_tokens = sum(r['num_tokens'] for r in all_results) if num_tokens > 0: avg_loss = total_loss / num_tokens else: diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py index 858c2f0d..ea6d8052 100644 --- a/src/twinkle/model/megatron/args.py +++ b/src/twinkle/model/megatron/args.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from typing import Any, Dict, List, Literal, Optional +from transformers.utils import is_torch_npu_available + from twinkle import DeviceMesh from twinkle.utils import exists from .utils import convert_hf_config @@ -435,7 +437,8 @@ def create_model(self, ) -> List[nn.Module]: if self._model is not None: return self._model from megatron.core import parallel_state as mpu - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_layer_local_spec, + 
get_gpt_layer_with_transformer_engine_spec) from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnBackend @@ -614,17 +617,27 @@ def _get_base_model(m): # Save transformer config for later use (e.g., DDP wrapping) self.config = config - # Get layer spec - enable moe_grouped_gemm for MoE models + is_npu = is_torch_npu_available() + moe_grouped_gemm = num_experts > 0 - try: - layer_spec = get_gpt_layer_with_transformer_engine_spec( + if is_npu: + layer_spec = get_gpt_layer_local_spec( num_experts=mg_config_dict.get('num_experts'), moe_grouped_gemm=moe_grouped_gemm, qk_layernorm=mg_config_dict.get('qk_layernorm', False), + normalization='RMSNorm', ) - except (ImportError, AttributeError): - raise RuntimeError( - 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.') + else: + try: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=mg_config_dict.get('num_experts'), + moe_grouped_gemm=moe_grouped_gemm, + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + ) + except (ImportError, AttributeError) as e: + raise RuntimeError( + 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.' + ) from e # Create model max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index aa74e72e..90db918c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -34,7 +34,13 @@ from twinkle.template import Template from twinkle.utils import construct_class, exists from .strategy import MegatronStrategy +from transformers.utils import is_torch_npu_available +if is_torch_npu_available(): + # Enable Megatron on Ascend NPU + from mindspeed.megatron_adaptor import repatch +else: + repatch = None @dataclass class MegatronOptimizerGroup: @@ -71,7 +77,17 @@ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> boo def __post_init__(self): if self._device_mesh.data_world_size > 1: - self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp']) + is_npu = is_torch_npu_available() + has_fsdp = (getattr(self._device_mesh, 'fsdp_world_size', 0) or 0) > 1 + + if is_npu and not has_fsdp: + # On NPU/HCCL without FSDP, cached dim-group creation avoids + # inconsistent creation order issues. + self._dp_group = self._device_mesh.get_dim_group('dp') + else: + # Keep metrics/data aggregation on the full data axis. + # This must include fsdp when fsdp_size > 1. + self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp']) self.train_metrics = [ LossMetric(self._device_mesh, self._dp_group), TrainMetric(self._device_mesh, self._dp_group), @@ -223,6 +239,24 @@ def __init__( sequence_parallel=self.strategy.sequence_parallel, **ac_kwargs, ) + + is_npu = is_torch_npu_available() + + if repatch is not None and is_npu: + from dataclasses import asdict + megatron_args = asdict(args) + try: + repatch(megatron_args) + except NameError as e: + # MindSpeed 0.12.1 has a known repatch bug: + # mindspeed/patch_utils.py references `inspect` without importing it. + # Keep training alive with initial patches already applied at import time. + if 'inspect' in str(e): + logging.getLogger(__name__).warning( + 'Skip MindSpeed repatch due to upstream bug (%s). 
Continue with initial patches.', e) + else: + raise + set_args(args) self._initialized = False self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs) @@ -433,6 +467,10 @@ def post_loss_function(output_tensor, inputs): # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func)) def forward_step_func(data_iterator, model): batch = next(data_iterator) + local_device = Platform.get_local_device() + for key, value in batch.items(): + if isinstance(value, torch.Tensor) and value.device != local_device: + batch[key] = value.to(local_device, non_blocking=True) labels = batch.pop('labels', None) output_tensor = model(**batch) batch['labels'] = labels @@ -726,30 +764,46 @@ def _create_megatron_optimizer(self, **kwargs): lr = kwargs.pop('lr', 1e-4) use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', False) - opt_config = OptimizerConfig( - optimizer='adam', - lr=lr, - min_lr=kwargs.get('min_lr', 0.0), - weight_decay=kwargs.get('weight_decay', 0.01), - adam_beta1=kwargs.get('adam_beta1', 0.9), - adam_beta2=kwargs.get('adam_beta2', 0.999), - adam_eps=kwargs.get('adam_eps', 1e-8), - clip_grad=kwargs.get('clip_grad', 1.0), - bf16=kwargs.get('bf16', True), - use_distributed_optimizer=use_distributed_optimizer, - overlap_param_gather=kwargs.get('overlap_param_gather', False), - log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False), - **kwargs, - ) + config_sig = inspect.signature(OptimizerConfig).parameters + overlap_param_gather = kwargs.get('overlap_param_gather', False) + overlap_param_gather_with_step = kwargs.get('overlap_param_gather_with_optimizer_step', overlap_param_gather) + + config_kwargs = { + 'optimizer': 'adam', + 'lr': lr, + 'min_lr': kwargs.get('min_lr', 0.0), + 'weight_decay': kwargs.get('weight_decay', 0.01), + 'adam_beta1': kwargs.get('adam_beta1', 0.9), + 'adam_beta2': kwargs.get('adam_beta2', 0.999), + 'adam_eps': kwargs.get('adam_eps', 1e-8), + 'clip_grad': kwargs.get('clip_grad', 1.0), + 'bf16': kwargs.get('bf16', True), + 'use_distributed_optimizer': use_distributed_optimizer, + 'log_num_zeros_in_grad': kwargs.get('log_num_zeros_in_grad', False), + } + if 'overlap_param_gather' in config_sig: + config_kwargs['overlap_param_gather'] = overlap_param_gather + if 'overlap_param_gather_with_optimizer_step' in config_sig: + config_kwargs['overlap_param_gather_with_optimizer_step'] = overlap_param_gather_with_step + + # Keep compatibility across Megatron-Core versions by only forwarding supported args. 
+ for key, value in kwargs.items(): + if key in config_sig and key not in config_kwargs: + config_kwargs[key] = value + + opt_config = OptimizerConfig(**config_kwargs) # Ensure each model chunk has ddp_config attached (required by Megatron optimizer) from megatron.core.distributed import DistributedDataParallelConfig + is_npu = is_torch_npu_available() + model_chunks = self.model for model_chunk in model_chunks: assert hasattr(model_chunk, 'ddp_config') optimizer = get_megatron_optimizer( config=opt_config, model_chunks=model_chunks, + use_gloo_process_groups=False if is_npu else True ) return optimizer @@ -1419,12 +1473,14 @@ def initialize(self, **kwargs) -> None: from .args import get_args self._try_init_process_group() args = get_args() + is_npu = is_torch_npu_available() init_kwargs = { 'tensor_model_parallel_size': args.tensor_model_parallel_size, 'pipeline_model_parallel_size': args.pipeline_model_parallel_size, 'context_parallel_size': args.context_parallel_size, 'virtual_pipeline_model_parallel_size': args.virtual_pipeline_model_parallel_size, 'expert_model_parallel_size': args.expert_model_parallel_size, + 'create_gloo_process_groups': False if is_npu else True, } if args.order: diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py index a4b59c9c..3e03c3ac 100644 --- a/src/twinkle/model/megatron/model/gpt_bridge.py +++ b/src/twinkle/model/megatron/model/gpt_bridge.py @@ -1249,8 +1249,12 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore) else: hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict, - 'input_layernorm.weight', to_mcore) + # TE spec keeps attention LN under linear_qkv.layer_norm_weight, + # while local spec keeps it as input_layernorm.weight. + attn_ln_key = 'self_attention.linear_qkv.layer_norm_weight' + if deep_getattr(mg_layer, attn_ln_key) is None: + attn_ln_key = 'input_layernorm.weight' + self._set_state_dict(mg_layer, attn_ln_key, hf_state_dict, 'input_layernorm.weight', to_mcore) return hf_state_dict def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool): @@ -1264,8 +1268,12 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool to_mcore) else: hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) + # TE spec keeps MLP LN under linear_fc1.layer_norm_weight, + # while local spec keeps it as pre_mlp_layernorm.weight. 
+ mlp_ln_key = 'mlp.linear_fc1.layer_norm_weight' + if deep_getattr(mg_layer, mlp_ln_key) is None: + mlp_ln_key = 'pre_mlp_layernorm.weight' + self._set_state_dict(mg_layer, mlp_ln_key, hf_state_dict, 'post_attention_layernorm.weight', to_mcore) return hf_state_dict def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): diff --git a/src/twinkle/model/megatron/model/gpt_model.py b/src/twinkle/model/megatron/model/gpt_model.py index 477ccaf5..2aed5a53 100644 --- a/src/twinkle/model/megatron/model/gpt_model.py +++ b/src/twinkle/model/megatron/model/gpt_model.py @@ -6,7 +6,6 @@ from megatron.core import mpu from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.extensions.transformer_engine import TELinear from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel as McoreGPTModel @@ -28,6 +27,16 @@ mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +try: + from megatron.core.extensions.transformer_engine import TELinear +except ImportError: + + class TELinear(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + raise RuntimeError('TransformerEngine is required to instantiate OutputLayerLinear.') + class OutputLayerLinear(TELinear):
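
The loss.py hunk above switches the NPU path from all_gather_object to a tensor all_reduce, since object collectives are less reliable over HCCL. A minimal sketch of the same reduction pattern, assuming only torch.distributed and a caller-supplied group and device (twinkle's Platform and device-mesh helpers are not reused here):

    import torch
    import torch.distributed as dist

    def reduce_loss_stats(total_loss, total_count, num_tokens, grad_norm, group=None, device='cpu'):
        # Fall back to the local values when no process group has been initialized.
        if not (dist.is_available() and dist.is_initialized()):
            return total_loss, total_count, num_tokens, grad_norm

        # The summed statistics travel as one float64 tensor; the same call works over
        # NCCL, HCCL, or gloo as long as the tensor lives on the backend's expected device.
        sums = torch.tensor([total_loss, total_count, num_tokens], dtype=torch.float64, device=device)
        dist.all_reduce(sums, op=dist.ReduceOp.SUM, group=group)

        # The gradient norm is reduced separately with MAX, matching the patch above.
        peak = torch.tensor([grad_norm], dtype=torch.float64, device=device)
        dist.all_reduce(peak, op=dist.ReduceOp.MAX, group=group)

        total_loss, total_count, num_tokens = sums.tolist()
        return total_loss, total_count, num_tokens, peak.item()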
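
The optimizer hunk builds OptimizerConfig by forwarding only the keyword arguments that the installed Megatron-Core version actually declares, using inspect.signature. The same idea in isolation, with a toy dataclass standing in for OptimizerConfig:

    import inspect
    from dataclasses import dataclass

    @dataclass
    class ExampleConfig:  # hypothetical stand-in for a version-dependent config class
        lr: float = 1e-4
        clip_grad: float = 1.0

    def build_config(config_cls, **kwargs):
        # inspect.signature on a dataclass exposes its generated __init__ parameters,
        # so unsupported keys can be dropped instead of raising TypeError.
        accepted = inspect.signature(config_cls).parameters
        return config_cls(**{k: v for k, v in kwargs.items() if k in accepted})

    cfg = build_config(ExampleConfig, lr=3e-5, clip_grad=0.5, overlap_param_gather=True)
    # -> ExampleConfig(lr=3e-05, clip_grad=0.5); the unknown key is silently ignored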
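
The gpt_bridge.py hunks choose the layer-norm parameter name by probing the Megatron layer: the TransformerEngine spec fuses the norm weight into linear_qkv / linear_fc1, while the local spec used on NPU keeps standalone input_layernorm and pre_mlp_layernorm modules. A sketch of that probing with a simple dotted-path lookup; twinkle's own deep_getattr may behave differently for missing attributes:

    def dotted_lookup(obj, path, default=None):
        # Walk 'a.b.c' one attribute at a time; stop at the first missing hop.
        for name in path.split('.'):
            obj = getattr(obj, name, None)
            if obj is None:
                return default
        return obj

    def attn_layernorm_key(mg_layer):
        te_key = 'self_attention.linear_qkv.layer_norm_weight'
        # Prefer the fused TE location; otherwise assume the local-spec module name.
        return te_key if dotted_lookup(mg_layer, te_key) is not None else 'input_layernorm.weight'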