Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions src/twinkle/metric/loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from typing import List, Union

import torch.distributed as dist

from transformers.utils import is_torch_npu_available

from twinkle import Platform
from twinkle.data_format import InputFeature, ModelOutput
from .base import Metric

Expand Down Expand Up @@ -46,19 +51,36 @@ def reset(self):
self.num_tokens = 0

def calculate(self):
local_results = [{
'loss': self.total_loss,
'count': self.total_count,
'grad_norm': self.grad_norm,
'num_tokens': self.num_tokens
}]
total_loss = float(self.total_loss)
total_count = float(self.total_count)
grad_norm = float(self.grad_norm)
num_tokens = float(self.num_tokens)

all_results = self.gather_results(local_results)
if self.device_mesh is not None and self.process_group is not None and dist.is_initialized():
is_npu = is_torch_npu_available()
if is_npu:
# On NPU/HCCL, tensor all_reduce is more stable than all_gather_object.
import torch
device = Platform.get_local_device()
stats = torch.tensor([total_loss, total_count, num_tokens], dtype=torch.float64, device=device)
dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=self.process_group)
total_loss, total_count, num_tokens = stats.tolist()

total_loss = sum(r['loss'] for r in all_results)
total_count = sum(r['count'] for r in all_results)
grad_norm = max(r['grad_norm'] for r in all_results)
num_tokens = sum(r['num_tokens'] for r in all_results)
grad_tensor = torch.tensor([grad_norm], dtype=torch.float64, device=device)
dist.all_reduce(grad_tensor, op=dist.ReduceOp.MAX, group=self.process_group)
grad_norm = grad_tensor.item()
else:
local_results = [{
'loss': total_loss,
'count': total_count,
'grad_norm': grad_norm,
'num_tokens': num_tokens
}]
all_results = self.gather_results(local_results)
total_loss = sum(r['loss'] for r in all_results)
total_count = sum(r['count'] for r in all_results)
grad_norm = max(r['grad_norm'] for r in all_results)
num_tokens = sum(r['num_tokens'] for r in all_results)
if num_tokens > 0:
avg_loss = total_loss / num_tokens
else:
Expand Down
27 changes: 20 additions & 7 deletions src/twinkle/model/megatron/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from types import SimpleNamespace
from typing import Any, Dict, List, Literal, Optional

from transformers.utils import is_torch_npu_available

from twinkle import DeviceMesh
from twinkle.utils import exists
from .utils import convert_hf_config
Expand Down Expand Up @@ -435,7 +437,8 @@ def create_model(self, ) -> List[nn.Module]:
if self._model is not None:
return self._model
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec)
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.enums import AttnBackend

Expand Down Expand Up @@ -614,17 +617,27 @@ def _get_base_model(m):
# Save transformer config for later use (e.g., DDP wrapping)
self.config = config

# Get layer spec - enable moe_grouped_gemm for MoE models
is_npu = is_torch_npu_available()

moe_grouped_gemm = num_experts > 0
try:
layer_spec = get_gpt_layer_with_transformer_engine_spec(
if is_npu:
layer_spec = get_gpt_layer_local_spec(
num_experts=mg_config_dict.get('num_experts'),
moe_grouped_gemm=moe_grouped_gemm,
qk_layernorm=mg_config_dict.get('qk_layernorm', False),
normalization='RMSNorm',
)
except (ImportError, AttributeError):
raise RuntimeError(
'TransformerEngine is not installed or not compatible with this version of Megatron-Core.')
else:
try:
layer_spec = get_gpt_layer_with_transformer_engine_spec(
num_experts=mg_config_dict.get('num_experts'),
moe_grouped_gemm=moe_grouped_gemm,
qk_layernorm=mg_config_dict.get('qk_layernorm', False),
)
except (ImportError, AttributeError) as e:
raise RuntimeError(
'TransformerEngine is not installed or not compatible with this version of Megatron-Core.'
) from e

# Create model
max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096)
Expand Down
88 changes: 72 additions & 16 deletions src/twinkle/model/megatron/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
from twinkle.template import Template
from twinkle.utils import construct_class, exists
from .strategy import MegatronStrategy
from transformers.utils import is_torch_npu_available

if is_torch_npu_available():
# Enable Megatron on Ascend NPU
from mindspeed.megatron_adaptor import repatch
else:
repatch = None

@dataclass
class MegatronOptimizerGroup:
Expand Down Expand Up @@ -71,7 +77,17 @@ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> boo

def __post_init__(self):
if self._device_mesh.data_world_size > 1:
self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp'])
is_npu = is_torch_npu_available()
has_fsdp = (getattr(self._device_mesh, 'fsdp_world_size', 0) or 0) > 1

if is_npu and not has_fsdp:
# On NPU/HCCL without FSDP, cached dim-group creation avoids
# inconsistent creation order issues.
self._dp_group = self._device_mesh.get_dim_group('dp')
else:
# Keep metrics/data aggregation on the full data axis.
# This must include fsdp when fsdp_size > 1.
self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp'])
self.train_metrics = [
LossMetric(self._device_mesh, self._dp_group),
TrainMetric(self._device_mesh, self._dp_group),
Expand Down Expand Up @@ -223,6 +239,24 @@ def __init__(
sequence_parallel=self.strategy.sequence_parallel,
**ac_kwargs,
)

is_npu = is_torch_npu_available()

if repatch is not None and is_npu:
Comment on lines +243 to +245
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid repeated calls to is_torch_npu_available(), it's better to call it once in the initializer and store the result in an instance attribute like self.is_npu. This attribute can then be reused in other methods of this class, such as _create_megatron_optimizer and initialize, improving code clarity and maintainability.

Suggested change
is_npu = is_torch_npu_available()
if repatch is not None and is_npu:
self.is_npu = is_torch_npu_available()
if repatch is not None and self.is_npu:

from dataclasses import asdict
megatron_args = asdict(args)
try:
repatch(megatron_args)
except NameError as e:
# MindSpeed 0.12.1 has a known repatch bug:
# mindspeed/patch_utils.py references `inspect` without importing it.
# Keep training alive with initial patches already applied at import time.
if 'inspect' in str(e):
logging.getLogger(__name__).warning(
'Skip MindSpeed repatch due to upstream bug (%s). Continue with initial patches.', e)
else:
raise

set_args(args)
self._initialized = False
self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs)
Expand Down Expand Up @@ -433,6 +467,10 @@ def post_loss_function(output_tensor, inputs):
# forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
def forward_step_func(data_iterator, model):
batch = next(data_iterator)
local_device = Platform.get_local_device()
for key, value in batch.items():
if isinstance(value, torch.Tensor) and value.device != local_device:
batch[key] = value.to(local_device, non_blocking=True)
labels = batch.pop('labels', None)
output_tensor = model(**batch)
batch['labels'] = labels
Expand Down Expand Up @@ -726,30 +764,46 @@ def _create_megatron_optimizer(self, **kwargs):
lr = kwargs.pop('lr', 1e-4)
use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', False)

opt_config = OptimizerConfig(
optimizer='adam',
lr=lr,
min_lr=kwargs.get('min_lr', 0.0),
weight_decay=kwargs.get('weight_decay', 0.01),
adam_beta1=kwargs.get('adam_beta1', 0.9),
adam_beta2=kwargs.get('adam_beta2', 0.999),
adam_eps=kwargs.get('adam_eps', 1e-8),
clip_grad=kwargs.get('clip_grad', 1.0),
bf16=kwargs.get('bf16', True),
use_distributed_optimizer=use_distributed_optimizer,
overlap_param_gather=kwargs.get('overlap_param_gather', False),
log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False),
**kwargs,
)
config_sig = inspect.signature(OptimizerConfig).parameters
overlap_param_gather = kwargs.get('overlap_param_gather', False)
overlap_param_gather_with_step = kwargs.get('overlap_param_gather_with_optimizer_step', overlap_param_gather)

config_kwargs = {
'optimizer': 'adam',
'lr': lr,
'min_lr': kwargs.get('min_lr', 0.0),
'weight_decay': kwargs.get('weight_decay', 0.01),
'adam_beta1': kwargs.get('adam_beta1', 0.9),
'adam_beta2': kwargs.get('adam_beta2', 0.999),
'adam_eps': kwargs.get('adam_eps', 1e-8),
'clip_grad': kwargs.get('clip_grad', 1.0),
'bf16': kwargs.get('bf16', True),
'use_distributed_optimizer': use_distributed_optimizer,
'log_num_zeros_in_grad': kwargs.get('log_num_zeros_in_grad', False),
}
if 'overlap_param_gather' in config_sig:
config_kwargs['overlap_param_gather'] = overlap_param_gather
if 'overlap_param_gather_with_optimizer_step' in config_sig:
config_kwargs['overlap_param_gather_with_optimizer_step'] = overlap_param_gather_with_step

# Keep compatibility across Megatron-Core versions by only forwarding supported args.
for key, value in kwargs.items():
if key in config_sig and key not in config_kwargs:
config_kwargs[key] = value

opt_config = OptimizerConfig(**config_kwargs)

# Ensure each model chunk has ddp_config attached (required by Megatron optimizer)
from megatron.core.distributed import DistributedDataParallelConfig
is_npu = is_torch_npu_available()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To maintain consistency and avoid redundant calls, please use the self.is_npu attribute that should be initialized in the constructor.

Suggested change
is_npu = is_torch_npu_available()
is_npu = self.is_npu


model_chunks = self.model
for model_chunk in model_chunks:
assert hasattr(model_chunk, 'ddp_config')
optimizer = get_megatron_optimizer(
config=opt_config,
model_chunks=model_chunks,
use_gloo_process_groups=False if is_npu else True
)
return optimizer

Expand Down Expand Up @@ -1419,12 +1473,14 @@ def initialize(self, **kwargs) -> None:
from .args import get_args
self._try_init_process_group()
args = get_args()
is_npu = is_torch_npu_available()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To maintain consistency and avoid redundant calls, please use the self.is_npu attribute that should be initialized in the constructor.

Suggested change
is_npu = is_torch_npu_available()
is_npu = self.is_npu

init_kwargs = {
'tensor_model_parallel_size': args.tensor_model_parallel_size,
'pipeline_model_parallel_size': args.pipeline_model_parallel_size,
'context_parallel_size': args.context_parallel_size,
'virtual_pipeline_model_parallel_size': args.virtual_pipeline_model_parallel_size,
'expert_model_parallel_size': args.expert_model_parallel_size,
'create_gloo_process_groups': False if is_npu else True,
}

if args.order:
Expand Down
16 changes: 12 additions & 4 deletions src/twinkle/model/megatron/model/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,12 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo
self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
else:
hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
'input_layernorm.weight', to_mcore)
# TE spec keeps attention LN under linear_qkv.layer_norm_weight,
# while local spec keeps it as input_layernorm.weight.
attn_ln_key = 'self_attention.linear_qkv.layer_norm_weight'
if deep_getattr(mg_layer, attn_ln_key) is None:
attn_ln_key = 'input_layernorm.weight'
self._set_state_dict(mg_layer, attn_ln_key, hf_state_dict, 'input_layernorm.weight', to_mcore)
return hf_state_dict

def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
Expand All @@ -1264,8 +1268,12 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool
to_mcore)
else:
hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore))
self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict,
'post_attention_layernorm.weight', to_mcore)
# TE spec keeps MLP LN under linear_fc1.layer_norm_weight,
# while local spec keeps it as pre_mlp_layernorm.weight.
mlp_ln_key = 'mlp.linear_fc1.layer_norm_weight'
if deep_getattr(mg_layer, mlp_ln_key) is None:
mlp_ln_key = 'pre_mlp_layernorm.weight'
self._set_state_dict(mg_layer, mlp_ln_key, hf_state_dict, 'post_attention_layernorm.weight', to_mcore)
return hf_state_dict

def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
Expand Down
11 changes: 10 additions & 1 deletion src/twinkle/model/megatron/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from megatron.core import mpu
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import TELinear
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.gpt import GPTModel as McoreGPTModel
Expand All @@ -28,6 +27,16 @@

mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')

# TransformerEngine is an optional dependency (e.g. absent on Ascend NPU builds).
# Fall back to a stub so this module still imports; constructing the stub fails loudly.
try:
    from megatron.core.extensions.transformer_engine import TELinear
except ImportError:

    class TELinear(torch.nn.Module):
        """Placeholder base class used when TransformerEngine is unavailable."""

        def __init__(self, *args: object, **kwargs: object) -> None:
            super().__init__()
            msg = 'TransformerEngine is required to instantiate OutputLayerLinear.'
            raise RuntimeError(msg)


class OutputLayerLinear(TELinear):

Expand Down