From ae30102390d839ce09570d6ba24694061f0a4617 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 24 Feb 2026 10:15:17 +0800 Subject: [PATCH 1/9] wip --- .../model/transformers/moe/expert_parallel.py | 65 ++++++++++++++++-- .../transformers/strategy/native_fsdp.py | 66 +++++++++++++++++++ src/twinkle/utils/grad_clip.py | 27 +++++++- src/twinkle/utils/platform.py | 38 +++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 7aab3c42..cdd844b9 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -31,6 +31,8 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dic if ep_world_size <= 1: return model + ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled() + if cfg.pad_to_max: raise NotImplementedError('pad_to_max is not implemented.') if cfg.all_to_all != 'torch': @@ -44,7 +46,7 @@ def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dic raise RuntimeError('EP process group is not available in device_mesh.') for block in find_moe_blocks(model): - shard_experts(block, device_mesh, cfg) + shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled) patch_forward(block, device_mesh, cfg) return model @@ -75,7 +77,8 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: return blocks -def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: +def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *, + ep_fsdp_enabled: bool) -> None: num_experts = _get_num_experts(block) ep_world_size = device_mesh.ep_world_size ep_rank = device_mesh.ep_rank @@ -88,6 +91,9 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel local_end = local_start + experts_per_rank if isinstance(block.experts, nn.ModuleList): + if ep_fsdp_enabled: + raise NotImplementedError('EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. 
' + 'Only tensor experts (gate_up_proj/down_proj) are supported.') local_experts = nn.ModuleList(block.experts[local_start:local_end]) block.experts = local_experts block._ep_tensor_experts = False @@ -102,6 +108,7 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel block._ep_rank = ep_rank block._ep_world_size = ep_world_size block._ep_ignore_shared_experts = cfg.ignore_shared_experts + block._ep_fsdp_enabled = ep_fsdp_enabled def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: @@ -120,6 +127,7 @@ def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel ep_group = device_mesh.get_dim_group('ep') def forward(hidden_states: torch.Tensor, *args, **kwargs): + ep_rank = block._ep_rank if args or kwargs: raise RuntimeError('Expert parallel patch only supports forward(hidden_states).') @@ -193,11 +201,14 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): group=ep_group, ) recv_out = torch.empty_like(recv_tokens) - for expert_id in torch.unique(recv_expert_ids).tolist(): - idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1) - expert_in = recv_tokens.index_select(0, idx) - expert_out = _run_expert(block, expert_id, expert_in) - recv_out.index_copy_(0, idx, expert_out) + if getattr(block, '_ep_fsdp_enabled', False) and getattr(block, '_ep_tensor_experts', False): + recv_out = _run_experts_ep_fsdp_batch(block, recv_tokens, recv_expert_ids) + else: + for expert_id in torch.unique(recv_expert_ids).tolist(): + idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1) + expert_in = recv_tokens.index_select(0, idx) + expert_out = _run_expert(block, expert_id, expert_in) + recv_out.index_copy_(0, idx, expert_out) send_out = torch.empty_like(send_tokens) send_out = dist_nn.functional.all_to_all_single( @@ -327,6 +338,25 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to expert = block.experts[expert_id] return _run_module_with_casting(expert, expert_in) experts = block.experts + if getattr(block, '_ep_fsdp_enabled', False): + # In EP+EP_FSDP mode, execute experts.forward so FSDP hooks can + # manage unshard/reshard around forward/backward safely. 
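+        # The synthetic routing below sends every incoming token to this single
+        # local expert with weight 1.0, so experts.forward reduces to
+        # down_proj(act(gate_up_proj(x))) for `expert_id` while still triggering
+        # FSDP's pre/post-forward unshard/reshard hooks.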
+ top_k_index = torch.full( + (expert_in.shape[0], 1), + int(expert_id), + dtype=torch.long, + device=expert_in.device, + ) + top_k_weights = torch.ones( + (expert_in.shape[0], 1), + dtype=expert_in.dtype, + device=expert_in.device, + ) + out = experts(expert_in, top_k_index, top_k_weights) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + gate_up = experts.gate_up_proj[expert_id] down = experts.down_proj[expert_id] compute_dtype = gate_up.dtype @@ -340,6 +370,27 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to return out +def _run_experts_ep_fsdp_batch( + block: nn.Module, + expert_in: torch.Tensor, + local_expert_ids: torch.Tensor, +) -> torch.Tensor: + input_dtype = expert_in.dtype + if expert_in.numel() == 0: + return torch.empty_like(expert_in) + experts = block.experts + top_k_index = local_expert_ids.view(-1, 1).to(torch.long) + top_k_weights = torch.ones( + (expert_in.shape[0], 1), + dtype=expert_in.dtype, + device=expert_in.device, + ) + out = experts(expert_in, top_k_index, top_k_weights) + if out.dtype != input_dtype: + out = out.to(input_dtype) + return out + + def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: for param in module.parameters(): if param.dtype.is_floating_point: diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index a0b75d94..6d960a70 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -29,6 +29,10 @@ def wrap_model(self, model, optimizer=None): from torch.distributed.fsdp import fully_shard fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: + ep_fsdp_mode = _is_ep_fsdp_mode_enabled( + self.device_mesh, + self.enable_ep, + ) if self.enable_ep: _ensure_moe_patched_if_needed(model, self.device_mesh) _place_ep_experts_on_local_device(model, self.device_mesh) @@ -36,6 +40,19 @@ def wrap_model(self, model, optimizer=None): reshard_after_forward = self.fsdp_config.get('reshard_after_forward', True) ignored_params = _collect_expert_params(model) if self.enable_ep else None + if ep_fsdp_mode: + _ensure_ep_fsdp_supported(model) + ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh) + if ep_fsdp_mesh is None: + raise RuntimeError( + 'Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp.') + sharded_blocks = _maybe_shard_ep_expert_blocks( + model, + mesh=ep_fsdp_mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + ) + _maybe_shard_layers( model, mesh=fsdp_mesh, @@ -85,6 +102,21 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', )) +def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: + if device_mesh is None or not device_mesh.has_dim('dp'): + return None + ranks = device_mesh.get_ranks_for_dims('dp') + if len(ranks) <= 1: + return None + return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=('ep_fsdp', )) + + +def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool: + if not enable_ep or device_mesh is None: + return False + return device_mesh.is_implicit_ep_fsdp_enabled() + + def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: ignored: Set[nn.Parameter] = set() ep_patched = False @@ -137,6 +169,40 @@ def 
_ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) -> 'Call apply_expert_parallel(model, device_mesh, config) before wrapping with FSDP2.') +def _ensure_ep_fsdp_supported(model: nn.Module) -> None: + for module in model.modules(): + if not getattr(module, '_ep_patched', False): + continue + experts = getattr(module, 'experts', None) + if isinstance(experts, nn.ModuleList): + raise NotImplementedError('EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. ' + 'Only tensor experts (gate_up_proj/down_proj) are supported.') + + +def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], + mp_policy: 'MixedPrecisionPolicy') -> int: + from torch.distributed.fsdp import fully_shard + from torch.distributed.tensor import Shard + sharded_blocks = 0 + for module in model.modules(): + if not getattr(module, '_ep_patched', False): + continue + experts = getattr(module, 'experts', None) + if experts is None: + continue + # Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh. + # Non-expert params (router/gate etc.) are left to global FSDP wrapping. + fully_shard( + experts, + mesh=mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + shard_placement_fn=lambda param: Shard(1), + ) + sharded_blocks += 1 + return sharded_blocks + + def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], mp_policy: 'MixedPrecisionPolicy', ignored_params: Optional[Set[nn.Parameter]]) -> None: from torch.distributed.fsdp import fully_shard diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index 3f678053..4b9f192c 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -34,7 +34,23 @@ def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter], has_dtensor_grad = any(hasattr(grad, 'to_local') for grad in grads) has_local_tensor_grad = any(not hasattr(grad, 'to_local') for grad in grads) - if not (has_dtensor_grad and has_local_tensor_grad): + dtensor_mesh_keys = set() + for grad in grads: + if not hasattr(grad, 'to_local'): + continue + mesh = getattr(grad, 'device_mesh', None) + if mesh is None: + dtensor_mesh_keys.add('dtensor:unknown') + continue + try: + mesh_key = (tuple(mesh.mesh.flatten().tolist()), tuple(mesh.mesh_dim_names or ())) + except Exception: + mesh_key = repr(mesh) + dtensor_mesh_keys.add(mesh_key) + + has_mixed_dtensor_mesh = len(dtensor_mesh_keys) > 1 + + if not (has_dtensor_grad and has_local_tensor_grad) and not has_mixed_dtensor_mesh: grad_norm = torch.nn.utils.clip_grad_norm_( parameters, max_grad_norm, @@ -64,6 +80,11 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: reduce_device = torch.device(Platform.get_local_device()) else: reduce_device = torch.device('cpu') + reduce_group = group + if has_mixed_dtensor_mesh: + # Different DTensor meshes cannot be reduced by DTensor op propagation (e.g. aten.stack). + # Fall back to world reduction over local shards. 
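+        # dist.all_reduce with group=None targets the default (world) process
+        # group; each rank contributes only the norms of its local shards, so the
+        # SUM/MAX reductions below still recover the global norm as long as
+        # shards are disjoint across ranks.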
+ reduce_group = None if norm_type == float('inf'): local_norm = 0.0 @@ -74,7 +95,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_norm = max(local_norm, local_grad.detach().abs().max().item()) total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group) + dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group) total_norm = float(total_norm_tensor.item()) else: local_sq = 0.0 @@ -85,7 +106,7 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: local_sq += local_grad.detach().float().pow(2).sum().item() total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32) if dist.is_initialized(): - dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group) + dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=reduce_group) total_norm = float(total_sq_tensor.sqrt().item()) clip_coef = float(max_grad_norm) / (total_norm + 1e-6) diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 0e1d9c97..16106063 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -187,6 +187,44 @@ def get_dim_group(self, dims): key = tuple(c for i, c in enumerate(coord) if i != dim_idx) return group_map[key] + def get_ranks_for_dims(self, dims): + if self.mesh_dim_names is None: + raise ValueError('mesh_dim_names is not set.') + if isinstance(dims, str): + dims = (dims, ) + for dim_name in dims: + if dim_name not in self.mesh_dim_names: + raise ValueError(f"Dimension '{dim_name}' not found in mesh. Available: {self.mesh_dim_names}") + + coord = self._get_coord() + if coord is None: + raise RuntimeError('Current rank is not found in mesh.') + + slices = [] + for i, dim_name in enumerate(self.mesh_dim_names): + if dim_name in dims: + slices.append(slice(None)) + else: + slices.append(coord[i]) + return sorted(self.mesh[tuple(slices)].flatten().tolist()) + + def is_implicit_ep_fsdp_enabled(self) -> bool: + ep_world_size = self.ep_world_size or 1 + dp_world_size = self.dp_world_size or 1 + if ep_world_size <= 1 or dp_world_size <= 1: + return False + + world_size = self.world_size or 1 + if world_size % ep_world_size != 0: + raise ValueError(f'world_size ({world_size}) must be divisible by ep_world_size ({ep_world_size}) ' + 'to infer implicit EP_FSDP from dp.') + expected_dp_size = world_size // ep_world_size + if dp_world_size != expected_dp_size: + raise ValueError(f'Implicit EP_FSDP requires dp_world_size == world_size // ep_world_size, ' + f'but got dp_world_size={dp_world_size}, world_size={world_size}, ' + f'ep_world_size={ep_world_size}.') + return True + @property def order(self): """The order of the dimensions for megatron""" From 181edbd0d3b9ac4c9b9e1e97ca92238e7ac1aea7 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 24 Feb 2026 10:19:42 +0800 Subject: [PATCH 2/9] lint --- src/twinkle/metric/train_metric.py | 4 ++-- src/twinkle/model/transformers/moe/expert_parallel.py | 1 - src/twinkle/model/transformers/strategy/native_fsdp.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py index f144c837..201ff859 100644 --- a/src/twinkle/metric/train_metric.py +++ b/src/twinkle/metric/train_metric.py @@ -44,7 +44,7 @@ def calculate(self): self.lr = self.lr[0] if isinstance(self.lr, list): for idx, lr in enumerate(self.lr): - results[f'learning rate(param group {idx+1})'] = lr + 
results[f'learning rate(param group {idx + 1})'] = lr else: results['learning rate'] = self.lr if self.step is not None: @@ -54,7 +54,7 @@ def calculate(self): if interval < 60: results['total time elapse'] = f'{(time.time() - self.start_time):.0f} seconds' else: - results['total time elapse'] = f'{(time.time() - self.start_time)/60:.1f} minutes' + results['total time elapse'] = f'{(time.time() - self.start_time) / 60:.1f} minutes' results['speed'] = f'{speed:.2f} iters/s' self.reset() return results diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index cdd844b9..00e83104 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -127,7 +127,6 @@ def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel ep_group = device_mesh.get_dim_group('ep') def forward(hidden_states: torch.Tensor, *args, **kwargs): - ep_rank = block._ep_rank if args or kwargs: raise RuntimeError('Expert parallel patch only supports forward(hidden_states).') diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 6d960a70..c0eca90a 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -46,7 +46,7 @@ def wrap_model(self, model, optimizer=None): if ep_fsdp_mesh is None: raise RuntimeError( 'Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp.') - sharded_blocks = _maybe_shard_ep_expert_blocks( + _maybe_shard_ep_expert_blocks( model, mesh=ep_fsdp_mesh, reshard_after_forward=reshard_after_forward, From 0358d974dbc92458a8f935857346067d043e345e Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Tue, 24 Feb 2026 11:35:39 +0800 Subject: [PATCH 3/9] wip --- .../model/transformers/moe/expert_parallel.py | 18 ------------------ .../model/transformers/strategy/native_fsdp.py | 4 +--- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 00e83104..8c81e3b0 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -337,24 +337,6 @@ def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> to expert = block.experts[expert_id] return _run_module_with_casting(expert, expert_in) experts = block.experts - if getattr(block, '_ep_fsdp_enabled', False): - # In EP+EP_FSDP mode, execute experts.forward so FSDP hooks can - # manage unshard/reshard around forward/backward safely. 
- top_k_index = torch.full( - (expert_in.shape[0], 1), - int(expert_id), - dtype=torch.long, - device=expert_in.device, - ) - top_k_weights = torch.ones( - (expert_in.shape[0], 1), - dtype=expert_in.dtype, - device=expert_in.device, - ) - out = experts(expert_in, top_k_index, top_k_weights) - if out.dtype != input_dtype: - out = out.to(input_dtype) - return out gate_up = experts.gate_up_proj[expert_id] down = experts.down_proj[expert_id] diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index c0eca90a..91668a00 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -2,6 +2,7 @@ import torch from torch import nn from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh +from torch.distributed.fsdp import fully_shard from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set from twinkle.utils import DeviceMesh, Platform @@ -26,7 +27,6 @@ def __init__(self, def wrap_model(self, model, optimizer=None): if self.device_mesh is None: return model, optimizer - from torch.distributed.fsdp import fully_shard fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: ep_fsdp_mode = _is_ep_fsdp_mode_enabled( @@ -181,7 +181,6 @@ def _ensure_ep_fsdp_supported(model: nn.Module) -> None: def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], mp_policy: 'MixedPrecisionPolicy') -> int: - from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import Shard sharded_blocks = 0 for module in model.modules(): @@ -205,7 +204,6 @@ def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, re def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], mp_policy: 'MixedPrecisionPolicy', ignored_params: Optional[Set[nn.Parameter]]) -> None: - from torch.distributed.fsdp import fully_shard layers = getattr(model, 'layers', None) if not isinstance(layers, nn.ModuleList): return From dbd78df533888eddaa4fa2eaf39d7e7b1b9d25ca Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 25 Feb 2026 14:07:39 +0800 Subject: [PATCH 4/9] wip --- cookbook/transformers/ep_fsdp_qwen3_moe.py | 38 +- cookbook/transformers/fsdp2_moe_full.py | 103 +++++ .../transformers/strategy/native_fsdp.py | 3 +- tests/moe/test_expert_parallel_qwen3_fsdp.py | 405 ------------------ 4 files changed, 135 insertions(+), 414 deletions(-) create mode 100644 cookbook/transformers/fsdp2_moe_full.py delete mode 100644 tests/moe/test_expert_parallel_qwen3_fsdp.py diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py index 16706eae..d2c854d3 100644 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.py +++ b/cookbook/transformers/ep_fsdp_qwen3_moe.py @@ -17,6 +17,12 @@ TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template') _num_layers_env = os.environ.get('NUM_LAYERS') NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LR = float(os.environ.get('LR', '1e-5')) +DISABLE_CLIP = os.environ.get('DISABLE_CLIP', '1') == '1' +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1' # 4 gpus, dp=2, ep=2 dp_size = 2 @@ -51,11 +57,10 @@ def train(): dataset.encode(batched=True) 
dataloader = DataLoader( dataset=dataset, - batch_size=4, + batch_size=BATCH_SIZE, device_mesh=device_mesh, ) - grad_accum_steps = 4 model = TransformersModel( model_id=MODEL_ID, config=config, @@ -65,26 +70,43 @@ def train(): 'enabled': True, 'router_dtype': 'fp32', 'all_to_all': 'torch', - 'keep_router_logits': False, + 'keep_router_logits': KEEP_ROUTER_LOGITS, } }, ) # Disable foreach to avoid DTensor mixed-type errors in EP runs. - model.set_optimizer('AdamW', foreach=False) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) logger.info(get_device_placement()) logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'lr={LR:.2e}, disable_clip={DISABLE_CLIP}, max_grad_norm={MAX_GRAD_NORM}, ' + f'keep_router_logits={KEEP_ROUTER_LOGITS}') for step, batch in enumerate(dataloader): if callable(batch): batch = batch() - model.forward_backward(inputs=batch, gradient_accumulation_steps=grad_accum_steps) - model.clip_grad_and_step(gradient_accumulation_steps=grad_accum_steps) - if step % grad_accum_steps == 0: + model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + if DISABLE_CLIP: + model.step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.zero_grad(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.lr_step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + else: + model.clip_grad_and_step( + max_grad_norm=MAX_GRAD_NORM, + gradient_accumulation_steps=GRAD_ACCUM_STEPS, + ) + if step % GRAD_ACCUM_STEPS == 0: metric = model.calculate_metric(is_training=True) if callable(metric): metric = metric() - logger.info(f'Current is step {step // grad_accum_steps}, metric: {metric}') + logger.info(f'Current is step {step // GRAD_ACCUM_STEPS}, metric: {metric}') if step > 0 and step % 50 == 0: model.save('./output') diff --git a/cookbook/transformers/fsdp2_moe_full.py b/cookbook/transformers/fsdp2_moe_full.py new file mode 100644 index 00000000..c1677322 --- /dev/null +++ b/cookbook/transformers/fsdp2_moe_full.py @@ -0,0 +1,103 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os + +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template') + +_num_layers_env = os.environ.get('NUM_LAYERS') +NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None + +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LR = float(os.environ.get('LR', '1e-5')) +DISABLE_CLIP = os.environ.get('DISABLE_CLIP', '1') == '1' +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) + +# Pure FSDP topology (no EP): default 4 GPUs -> fsdp=2, dp=2. 
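+# Note: fsdp_size * dp_size should match the number of visible devices
+# (2 * 2 = 4 in the default setup); tune FSDP_SIZE/DP_SIZE together when scaling.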
+fsdp_size = int(os.environ.get('FSDP_SIZE', '2')) +dp_size = int(os.environ.get('DP_SIZE', '2')) +device_mesh = DeviceMesh.from_sizes(fsdp_size=fsdp_size, dp_size=dp_size) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'): + config.num_hidden_layers = NUM_LAYERS + if hasattr(config, 'use_cache'): + config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) + try: + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + except ValueError: + dataset.set_template('Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) + dataset.encode(batched=True) + + dataloader = DataLoader( + dataset=dataset, + batch_size=BATCH_SIZE, + device_mesh=device_mesh, + ) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + fsdp_config={'transformer_cls_names_to_wrap': ['Qwen3MoeSparseMoeBlock']}, + ) + + # Full-parameter training: no LoRA adapter is added. + model.set_optimizer(optimizer_cls='AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'lr={LR:.2e}, disable_clip={DISABLE_CLIP}, max_grad_norm={MAX_GRAD_NORM}, ' + f'dp_size={dp_size}, fsdp_size={fsdp_size}') + if NUM_LAYERS is not None: + logger.info(f'NUM_LAYERS={NUM_LAYERS}') + + for step, batch in enumerate(dataloader): + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + if DISABLE_CLIP: + model.step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.zero_grad(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.lr_step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + else: + model.clip_grad_and_step( + max_grad_norm=MAX_GRAD_NORM, + gradient_accumulation_steps=GRAD_ACCUM_STEPS, + ) + if step % GRAD_ACCUM_STEPS == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {step // GRAD_ACCUM_STEPS}, metric: {metric}') + if step > 0 and step % 50 == 0: + model.save('./output') + + +if __name__ == '__main__': + train() diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 91668a00..2e9db3e3 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -38,7 +38,6 @@ def wrap_model(self, model, optimizer=None): _place_ep_experts_on_local_device(model, self.device_mesh) mp_policy = _build_mp_policy(self.mixed_precision) reshard_after_forward = self.fsdp_config.get('reshard_after_forward', True) - ignored_params = _collect_expert_params(model) if self.enable_ep else None if ep_fsdp_mode: _ensure_ep_fsdp_supported(model) @@ -53,6 +52,8 @@ def wrap_model(self, model, optimizer=None): mp_policy=mp_policy, ) + ignored_params = _collect_expert_params(model) if self.enable_ep else None + _maybe_shard_layers( model, mesh=fsdp_mesh, diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp.py b/tests/moe/test_expert_parallel_qwen3_fsdp.py deleted file mode 100644 index 
88da379a..00000000 --- a/tests/moe/test_expert_parallel_qwen3_fsdp.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import json -import numpy as np -import os -import socket -import sys -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import unittest -from pathlib import Path -from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig -from typing import Dict, List - -from twinkle.model.transformers.moe import apply_expert_parallel -from twinkle.model.transformers.strategy import NativeFSDPStrategy -from twinkle.utils import DeviceMesh - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(('127.0.0.1', 0)) - return sock.getsockname()[1] - - -def _find_moe_blocks(model: nn.Module) -> List[nn.Module]: - blocks = [] - for module in model.modules(): - experts = getattr(module, 'experts', None) - if experts is None: - continue - if not isinstance(experts, nn.ModuleList): - if not (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')): - continue - gate = getattr(module, 'gate', None) or getattr(module, 'router', None) - if gate is None: - continue - blocks.append(module) - return blocks - - -def _capture_router_logits(model: nn.Module): - router_logits: List[torch.Tensor] = [] - handles = [] - for block in _find_moe_blocks(model): - gate = getattr(block, 'gate', None) or getattr(block, 'router', None) - if gate is None: - continue - - def _hook(module, inputs, output): - if isinstance(output, tuple): - router_logits.append(output[0].detach()) - else: - router_logits.append(output.detach()) - - handles.append(gate.register_forward_hook(_hook)) - return router_logits, handles - - -def _get_top_k(block: nn.Module) -> int: - if hasattr(block, 'num_experts_per_tok') and getattr(block, 'num_experts_per_tok') is not None: - return int(getattr(block, 'num_experts_per_tok')) - if hasattr(block, 'top_k') and getattr(block, 'top_k') is not None: - return int(getattr(block, 'top_k')) - gate = getattr(block, 'gate', None) or getattr(block, 'router', None) - if gate is not None and hasattr(gate, 'top_k') and getattr(gate, 'top_k') is not None: - return int(getattr(gate, 'top_k')) - raise RuntimeError('Cannot infer top_k for MoE block.') - - -def _capture_router_state(model: nn.Module): - states: List[Dict[str, torch.Tensor]] = [] - handles = [] - for block in _find_moe_blocks(model): - gate = getattr(block, 'gate', None) or getattr(block, 'router', None) - if gate is None: - continue - top_k = _get_top_k(block) - norm_topk_prob = getattr(block, 'norm_topk_prob', False) - - def _hook(module, inputs, output, *, _top_k=top_k, _norm=norm_topk_prob): - if isinstance(output, tuple): - router_logits, routing_weights, selected_experts = output[:3] - else: - router_logits = output - routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, _top_k, dim=-1) - if _norm: - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - states.append({ - 'selected_experts': selected_experts.detach().cpu(), - 'routing_weights': routing_weights.detach().cpu(), - }) - - handles.append(gate.register_forward_hook(_hook)) - return states, handles - - -def _collect_baseline_local_expert_grads( - block: nn.Module, - ep_rank: int, - ep_world_size: int, - ep_group, -) -> Dict[int, Dict[str, torch.Tensor]]: - if isinstance(block.experts, 
nn.ModuleList): - num_experts = len(block.experts) - else: - num_experts = int(block.experts.gate_up_proj.shape[0]) - if num_experts % ep_world_size != 0: - raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') - experts_per_rank = num_experts // ep_world_size - local_start = ep_rank * experts_per_rank - local_end = local_start + experts_per_rank - local_grads: Dict[int, Dict[str, torch.Tensor]] = {} - - if isinstance(block.experts, nn.ModuleList): - for global_idx, expert in enumerate(block.experts): - param_grads: Dict[str, torch.Tensor] = {} - for name, param in expert.named_parameters(): - grad = param.grad - if grad is None: - grad = torch.zeros_like(param, dtype=param.dtype) - dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ep_group) - if local_start <= global_idx < local_end: - param_grads[name] = grad.detach().cpu() - if local_start <= global_idx < local_end: - local_grads[global_idx] = param_grads - else: - gate_up = block.experts.gate_up_proj - down = block.experts.down_proj - gate_up_grad = gate_up.grad if gate_up.grad is not None else torch.zeros_like(gate_up) - down_grad = down.grad if down.grad is not None else torch.zeros_like(down) - dist.all_reduce(gate_up_grad, op=dist.ReduceOp.SUM, group=ep_group) - dist.all_reduce(down_grad, op=dist.ReduceOp.SUM, group=ep_group) - for global_idx in range(num_experts): - if local_start <= global_idx < local_end: - local_grads[global_idx] = { - 'gate_up_proj': gate_up_grad[global_idx].detach().cpu(), - 'down_proj': down_grad[global_idx].detach().cpu(), - } - - return local_grads - - -def _load_qwen3_moe_config(model_id: str, local_files_only: bool): - try: - return AutoConfig.from_pretrained( - model_id, - trust_remote_code=True, - local_files_only=local_files_only, - ) - except Exception as exc: # noqa: BLE001 - config_path = Path(model_id) / 'config.json' - if not config_path.exists(): - raise exc - with config_path.open('r', encoding='utf-8') as handle: - data = json.load(handle) - if 'model_type' not in data: - data['model_type'] = 'qwen3_moe' - if 'architectures' not in data: - data['architectures'] = ['Qwen3MoeForCausalLM'] - try: - return AutoConfig.from_dict(data) - except Exception as exc: # noqa: BLE001 - print(f'AutoConfig.from_dict fallback to PretrainedConfig for {model_id}: {exc}') - return PretrainedConfig.from_dict(data) - - -def _load_qwen3_moe_pretrained(model_id: str, local_files_only: bool, device: torch.device) -> nn.Module: - config = _load_qwen3_moe_config(model_id, local_files_only) - if hasattr(config, 'num_hidden_layers'): - config.num_hidden_layers = 1 - if hasattr(config, 'use_cache'): - config.use_cache = False - if hasattr(config, '_experts_implementation'): - config._experts_implementation = 'eager' - model = AutoModelForCausalLM.from_pretrained( - model_id, - config=config, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - local_files_only=local_files_only, - ) - model.to(device) - model.eval() - return model - - -def _run_worker_ep_fsdp_pretrained(rank: int, world_size: int, port: int, model_id: str, local_files_only: bool): - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(port) - if not torch.cuda.is_available(): - raise RuntimeError('This test requires CUDA (4 GPUs).') - device = torch.device(f'cuda:{rank}') - torch.cuda.set_device(device) - os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1' - dist.init_process_group( - 
backend='nccl', - rank=rank, - world_size=world_size, - init_method=f'tcp://127.0.0.1:{port}', - device_id=device, - ) - dist.barrier() - - try: - torch.manual_seed(1234) - model = _load_qwen3_moe_pretrained(model_id, local_files_only, device) - input_ids = torch.randint( - low=0, - high=model.config.vocab_size, - size=(2, 8), - device=device, - ) - - baseline_router_logits, baseline_handles = _capture_router_logits(model.model) - baseline_router_state, baseline_state_handles = _capture_router_state(model.model) - baseline_out = model(input_ids=input_ids).logits - for handle in baseline_handles: - handle.remove() - for handle in baseline_state_handles: - handle.remove() - baseline_out_ref = baseline_out.detach() - baseline_out.sum().backward() - - device_mesh = DeviceMesh( - device_type='cuda', - mesh=np.arange(world_size).reshape(2, 2), - mesh_dim_names=('fsdp', 'ep'), - ) - ep_group = device_mesh.get_dim_group('ep') - - baseline_blocks = _find_moe_blocks(model.model) - if not baseline_blocks: - raise RuntimeError('No MoE blocks found in Qwen3 model.') - - baseline_block_grads = [] - for block in baseline_blocks: - baseline_block_grads.append( - _collect_baseline_local_expert_grads( - block, - device_mesh.ep_rank, - device_mesh.ep_world_size, - ep_group, - )) - - model.zero_grad(set_to_none=True) - - apply_expert_parallel( - model.model, - device_mesh, - config={ - 'enabled': True, - 'router_dtype': 'fp32', - 'all_to_all': 'torch', - 'keep_router_logits': False, - }, - ) - - strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={}) - model.model, _ = strategy.wrap_model(model.model, optimizer=None) - - ep_router_logits, ep_handles = _capture_router_logits(model.model) - ep_router_state, ep_state_handles = _capture_router_state(model.model) - ep_out = model(input_ids=input_ids).logits - for handle in ep_handles: - handle.remove() - for handle in ep_state_handles: - handle.remove() - - out_diff = (ep_out - baseline_out_ref).abs() - if not torch.allclose(ep_out, baseline_out_ref, rtol=1e-3, atol=1e-4): - print(f'[rank{rank}] ep_out diff mean={out_diff.mean().item():.6e} ' - f'max={out_diff.max().item():.6e}') - assert torch.allclose(ep_out, baseline_out_ref, rtol=1e-3, atol=1e-4) - - if baseline_router_logits and ep_router_logits: - for idx, (base_logits, ep_logits) in enumerate(zip(baseline_router_logits, ep_router_logits)): - logits_diff = (ep_logits - base_logits).abs() - if not torch.allclose(ep_logits, base_logits, rtol=1e-3, atol=1e-4): - print(f'[rank{rank}] router_logits[{idx}] diff ' - f'mean={logits_diff.mean().item():.6e} ' - f'max={logits_diff.max().item():.6e}') - else: - print(f'[rank{rank}] router_logits not captured for comparison.') - - if baseline_router_state and ep_router_state: - for idx, (base_state, ep_state) in enumerate(zip(baseline_router_state, ep_router_state)): - base_sel = base_state['selected_experts'] - ep_sel = ep_state['selected_experts'] - if not torch.equal(base_sel, ep_sel): - num_experts = int(base_sel.max().item()) + 1 - base_counts = torch.bincount(base_sel.reshape(-1), minlength=num_experts) - ep_counts = torch.bincount(ep_sel.reshape(-1), minlength=num_experts) - diff = (base_counts - ep_counts).abs() - print( - f'[rank{rank}] selected_experts[{idx}] mismatch ' - f'max_diff={diff.max().item()} mean_diff={diff.float().mean().item():.6e}', - flush=True, - ) - - ep_out.sum().backward() - - ep_blocks = _find_moe_blocks(model.model) - assert len(ep_blocks) == len(baseline_block_grads) - - for block_idx, ep_block in 
enumerate(ep_blocks): - baseline_grads = baseline_block_grads[block_idx] - printed_grad_diff = False - if isinstance(ep_block.experts, nn.ModuleList): - for local_idx, expert in enumerate(ep_block.experts): - global_idx = ep_block._ep_local_start + local_idx - baseline_params = baseline_grads[global_idx] - for name, param in expert.named_parameters(): - baseline_grad = baseline_params[name] - ep_grad = param.grad - if ep_grad is None: - assert torch.allclose( - baseline_grad, - torch.zeros_like(baseline_grad), - rtol=1e-5, - atol=1e-6, - ) - else: - base = baseline_grad.to(ep_grad.device, dtype=torch.float32) - diff = (ep_grad.to(torch.float32) - base) - rel = diff.norm() / (base.norm() + 1e-12) - if rel.item() > 1e-3 and not printed_grad_diff: - abs_diff = diff.abs() - base_norm = base.norm().item() - ep_norm = ep_grad.norm().item() - ratio = ep_norm / base_norm if base_norm != 0 else float('inf') - print(f'[rank{rank}] expert{global_idx}.{name} grad diff ' - f'mean={abs_diff.mean().item():.6e} max={abs_diff.max().item():.6e} ' - f'ep_norm={ep_norm:.6e} base_norm={base_norm:.6e} ratio={ratio:.6e} ' - f'rel_norm={rel.item():.6e}') - printed_grad_diff = True - assert rel.item() <= 1e-3 - else: - gate_up = ep_block.experts.gate_up_proj - down = ep_block.experts.down_proj - gate_up_grad = gate_up.grad - down_grad = down.grad - for local_idx in range(gate_up.shape[0]): - global_idx = ep_block._ep_local_start + local_idx - baseline_params = baseline_grads[global_idx] - for name, tensor, grad in ( - ('gate_up_proj', gate_up[local_idx], gate_up_grad), - ('down_proj', down[local_idx], down_grad), - ): - baseline_grad = baseline_params[name] - ep_grad = None if grad is None else grad[local_idx] - if ep_grad is None: - assert torch.allclose( - baseline_grad, - torch.zeros_like(baseline_grad), - rtol=1e-5, - atol=1e-6, - ) - else: - base = baseline_grad.to(ep_grad.device, dtype=torch.float32) - diff = (ep_grad.to(torch.float32) - base) - rel = diff.norm() / (base.norm() + 1e-12) - if rel.item() > 1e-3 and not printed_grad_diff: - abs_diff = diff.abs() - base_norm = base.norm().item() - ep_norm = ep_grad.norm().item() - ratio = ep_norm / base_norm if base_norm != 0 else float('inf') - print(f'[rank{rank}] expert{global_idx}.{name} grad diff ' - f'mean={abs_diff.mean().item():.6e} max={abs_diff.max().item():.6e} ' - f'ep_norm={ep_norm:.6e} base_norm={base_norm:.6e} ratio={ratio:.6e} ' - f'rel_norm={rel.item():.6e}') - printed_grad_diff = True - assert rel.item() <= 1e-3 - finally: - dist.destroy_process_group() - - -class TestExpertParallelFSDPPretrained(unittest.TestCase): - - def test_qwen3_moe_pretrained_ep_fsdp(self): - if not dist.is_available(): - self.skipTest('torch.distributed is not available') - if not torch.cuda.is_available(): - self.skipTest('CUDA is required for this test.') - world_size = 4 - if torch.cuda.device_count() < world_size: - self.skipTest('Requires at least 4 GPUs for EP+FSDP test.') - model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-30B-A3B-Instruct-2507') - local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0' - try: - _load_qwen3_moe_config(model_id, local_files_only) - except Exception as exc: # noqa: BLE001 - self.skipTest(f'Qwen3 model not available locally: {exc}') - port = _find_free_port() - mp.spawn( - _run_worker_ep_fsdp_pretrained, - args=(world_size, port, model_id, local_files_only), - nprocs=world_size, - join=True, - ) From de7555715704d30188b923d760a37f6a38ec28c5 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 26 Feb 2026 
18:31:38 +0800 Subject: [PATCH 5/9] wip --- cookbook/transformers/ep_fsdp_qwen3_moe.py | 11 +- .../model/transformers/moe/__init__.py | 4 +- .../model/transformers/moe/ep_utils.py | 299 ++++++++++++ .../model/transformers/moe/expert_parallel.py | 359 ++++++++------ .../transformers/strategy/native_fsdp.py | 195 ++++---- .../model/transformers/transformers.py | 35 +- src/twinkle/utils/grad_clip.py | 117 ++++- src/twinkle/utils/platform.py | 33 +- tests/moe/test_ep_fsdp_vs_single.py | 460 ++++++++++++++++++ 9 files changed, 1243 insertions(+), 270 deletions(-) create mode 100644 src/twinkle/model/transformers/moe/ep_utils.py create mode 100644 tests/moe/test_ep_fsdp_vs_single.py diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py index d2c854d3..85a943b2 100644 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.py +++ b/cookbook/transformers/ep_fsdp_qwen3_moe.py @@ -24,14 +24,16 @@ MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1' -# 4 gpus, dp=2, ep=2 -dp_size = 2 +# 4 gpus, fsdp=4 (data parallel), ep_size=2 (expert parallel) +# The main mesh does NOT include 'ep' dimension - EP is handled by separate ep_fsdp_device_mesh +fsdp_size = 4 ep_size = 2 device_mesh = DeviceMesh( device_type=Platform.get_platform().device_prefix(), - mesh=np.arange(dp_size * ep_size).reshape(dp_size, ep_size), - mesh_dim_names=('dp', 'ep'), + mesh=np.arange(fsdp_size).reshape(fsdp_size), + mesh_dim_names=('fsdp',), + ep_size=ep_size, # ep_size is stored as attribute, not a mesh dimension ) twinkle.initialize( @@ -69,7 +71,6 @@ def train(): 'expert_parallel': { 'enabled': True, 'router_dtype': 'fp32', - 'all_to_all': 'torch', 'keep_router_logits': KEEP_ROUTER_LOGITS, } }, diff --git a/src/twinkle/model/transformers/moe/__init__.py b/src/twinkle/model/transformers/moe/__init__.py index f80d6d48..af11026c 100644 --- a/src/twinkle/model/transformers/moe/__init__.py +++ b/src/twinkle/model/transformers/moe/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from .expert_parallel import apply_expert_parallel +from .expert_parallel import ExpertShardingSpec, apply_expert_parallel -__all__ = ['apply_expert_parallel'] +__all__ = ['ExpertShardingSpec', 'apply_expert_parallel'] diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py new file mode 100644 index 00000000..7446c6a6 --- /dev/null +++ b/src/twinkle/model/transformers/moe/ep_utils.py @@ -0,0 +1,299 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# +# Adapted from VeOmni (https://github.com/volcengine/VeOmni) +# Copyright 2025 Bytedance Ltd. 
and/or its affiliates +# Licensed under the Apache License, Version 2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + + +# ========================== comm ========================== +# Ported from veomni/distributed/moe/comm.py + + +class _AllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx, group, input, output_split_sizes, input_split_sizes): + ctx.group = group + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + + world_size = dist.get_world_size(group=group) + + if world_size == 1: + return input + + input = input.contiguous() + + if output_split_sizes is None: + output = torch.empty_like(input) + else: + output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device) + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + def backward(ctx, *grad_output): + return ( + None, + _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes), + None, + None, + ) + + +class _AllToAll_Async(torch.autograd.Function): + @staticmethod + def forward(ctx, group, input, output_split_sizes, input_split_sizes): + ctx.group = group + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + + world_size = dist.get_world_size(group=group) + + if world_size == 1: + return input + + input = input.contiguous() + + if output_split_sizes is None: + output = torch.empty_like(input) + else: + output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device) + async_handle = dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True, + ) + return output, async_handle + + @staticmethod + def backward(ctx, grad_output, grad_async_handle): + return ( + None, + _AllToAll_Async.apply(ctx.group, grad_output, ctx.input_split_sizes, ctx.output_split_sizes), + None, + None, + ) + + +def all_to_all(group, input, output_split_size=None, input_split_size=None): + return _AllToAll.apply(group, input, output_split_size, input_split_size) + + +def all_to_all_async(group, input, output_split_size, input_split_size): + return _AllToAll_Async.apply(group, input, output_split_size, input_split_size) + + +# ========================== moe_utils ========================== +# Ported from veomni/distributed/moe/moe_utils.py + + +def permute(tokens: torch.Tensor, routing_map: torch.Tensor): + """ + Permutes the tokens according to the routing map. + + Args: + tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim]. + routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens]. 
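+
+    Example (shapes only): with num_tokens=4 and num_experts=2, a True entry at
+    routing_map[e, t] routes token t to expert e; permuted_input stacks all
+    expert-0 rows first, then expert-1 rows, and sorted_indices records the
+    source row in `tokens` for each permuted row.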
+
+    """
+    num_tokens, _ = tokens.shape
+    num_experts = routing_map.shape[0]
+
+    # routing_map is a boolean mask of shape [num_experts, num_tokens]
+    routing_map = routing_map.bool()
+
+    # Create a dense expert-to-token mapping from the sparse token-to-expert mapping
+    token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
+    sorted_indices = token_indices.masked_select(routing_map)
+
+    # use the mapping to permute the tokens
+    permuted_input = tokens.index_select(0, sorted_indices)
+
+    return permuted_input, sorted_indices
+
+
+def unpermute(
+    tokens: torch.Tensor,
+    routing_weights: torch.Tensor,
+    hidden_states_shape: torch.Size,
+    permutation_mapping: torch.Tensor,
+    routing_map: torch.Tensor,
+):
+    """
+    Unpermutes the tokens and applies the routing weights.
+
+    Args:
+        tokens (torch.Tensor): The permuted token tensor, [num_permuted_tokens, hidden_dim].
+        routing_weights (torch.Tensor): The dense routing weights, [num_tokens, num_experts].
+        hidden_states_shape (torch.Size): The shape of the original hidden states, [num_tokens, hidden_dim].
+        permutation_mapping (torch.Tensor): The source row of each permuted token, as returned by permute().
+        routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, num_tokens].
+
+    Returns:
+        torch.Tensor: The unpermuted token tensor, [num_tokens, hidden_dim].
+    """
+    tokens_weight = routing_weights.T.contiguous().masked_select(routing_map.bool())
+
+    tokens = tokens * tokens_weight.unsqueeze(-1)
+    hidden_dim = hidden_states_shape[-1]
+
+    unpermuted_tokens = torch.zeros(hidden_states_shape, device=tokens.device, dtype=tokens.dtype)
+
+    # Scatter-add the permuted tokens back to their original positions
+    unpermuted_tokens.scatter_add_(0, permutation_mapping.unsqueeze(1).expand(-1, hidden_dim), tokens)
+    return unpermuted_tokens
+
+
+def generate_weights_idx(routing_weights: torch.Tensor, selected_experts: torch.Tensor, num_experts) -> torch.Tensor:
+    """
+    Generate the dense weight matrix for the unpermute operation.
+
+    Args:
+        routing_weights (torch.Tensor): The routing weights, [num_tokens, topk].
+        selected_experts (torch.Tensor): The selected experts, [num_tokens, topk].
+        num_experts (int): The total number of experts.
+
+    Returns:
+        torch.Tensor: Dense per-expert weights, [num_tokens, num_experts].
+    """
+    num_tokens, topk = routing_weights.shape
+    weights_idx = torch.zeros((num_tokens, num_experts), dtype=routing_weights.dtype, device=routing_weights.device)
+
+    weights_idx.scatter_add_(1, selected_experts, routing_weights)
+
+    return weights_idx
+
+
+def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
+    """Split and sort the input tensor based on the split_sizes and sorted indices."""
+    input = torch.split(input, split_sizes.tolist(), dim=0)
+    output = torch.cat([input[i] for i in sorted_idxs], dim=0)
+    return output
+
+
+# ========================== moe_layer ==========================
+# Ported from veomni/distributed/moe/moe_layer.py (preprocess, token_pre_all2all, tokens_post_all2all)
+# EPGroupGemm is NOT ported (requires triton group_gemm); Twinkle uses F.linear loop instead.
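+# Illustrative flow (a sketch, not normative), assuming ep_size=2 and
+# num_experts=4, i.e. num_local_experts=2 per rank:
+#   1. preprocess: expert_mask [num_experts, top_k, num_tokens] ->
+#      input_splits/output_splits (token counts exchanged with each EP rank)
+#      plus per-local-expert counts moved to CPU for torch.split.
+#   2. token_pre_all2all: permute tokens by expert id, all_to_all them to the
+#      owning rank, then regroup so all tokens for local expert 0 come first.
+#   3. expert compute: per-expert MLP (Twinkle runs an F.linear loop here).
+#   4. tokens_post_all2all: inverse regroup, all_to_all back, then weighted
+#      scatter-add (unpermute) restores the original token order.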
+
+
+def preprocess(
+    expert_mask: torch.Tensor,
+    num_experts: int,
+    ep_group: dist.ProcessGroup,
+) -> tuple[list, list, torch.Tensor, torch.Tensor]:
+    """Compute all_to_all split sizes and per-local-expert token counts from the routing mask."""
+    ep_size = ep_group.size()
+    num_local_experts = num_experts // ep_size
+    rank = dist.get_rank(ep_group)
+    num_local_tokens_per_expert = expert_mask.sum(dim=(1, 2))
+
+    # [ep_size]: total number of tokens this rank sends to each EP rank
+    input_splits = num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1).tolist()
+
+    # gather the per-expert token counts from all ep ranks
+    # [ep_size, num_experts]
+    num_global_tokens_per_expert = torch.zeros(
+        ep_size,
+        num_local_tokens_per_expert.size(0),
+        dtype=num_local_tokens_per_expert.dtype,
+        device=num_local_tokens_per_expert.device,
+    )
+    dist.all_gather_into_tensor(num_global_tokens_per_expert, num_local_tokens_per_expert, group=ep_group)
+
+    # [ep_size, num_local_experts]
+    start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts
+    num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, start_idx:end_idx].contiguous()
+
+    # [ep_size]: total number of tokens this rank receives from each EP rank
+    output_splits = num_global_tokens_per_local_expert.sum(dim=1).tolist()
+
+    # [num_local_experts]
+    num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=0).to(
+        torch.device("cpu"), non_blocking=True
+    )
+
+    num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(-1, num_local_experts).to(
+        torch.device("cpu"), non_blocking=True
+    )
+
+    return input_splits, output_splits, num_global_tokens_per_local_expert, num_global_sum_tokens_per_local_expert
+
+
+def token_pre_all2all(
+    hidden_states: torch.Tensor,
+    expert_mask: torch.Tensor,
+    num_experts: int,
+    input_splits: list,
+    output_splits: list,
+    num_global_tokens_per_local_expert: torch.Tensor,
+    ep_group: Optional[dist.ProcessGroup] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Size]:
+    """Permute tokens by expert, exchange them across EP ranks, and group them by local expert."""
+    hidden_dim = hidden_states.size(-1)
+    hidden_states = hidden_states.reshape(-1, hidden_dim)
+    org_hidden_states_shape = hidden_states.shape
+    routing_map = expert_mask.sum(dim=1)
+
+    local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, routing_map)
+
+    global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
+
+    # group tokens together by expert
+    num_local_experts = num_experts // ep_group.size()
+    permute_order = torch.arange(num_experts).reshape(-1, num_local_experts).T.ravel().tolist()
+    global_permuted_hidden_states = sort_chunks_by_idxs(
+        global_permuted_hidden_states,
+        num_global_tokens_per_local_expert.ravel(),
+        permute_order,
+    )
+
+    return global_permuted_hidden_states, routing_map, local_input_permutation_mapping, org_hidden_states_shape
+
+
+def tokens_post_all2all(
+    expert_outputs: torch.Tensor,
+    routing_weights: torch.Tensor,
+    selected_experts: torch.Tensor,
+    num_experts: int,
+    input_splits: list,
+    output_splits: list,
+    num_global_tokens_per_local_expert: torch.Tensor,
+    routing_map: torch.Tensor,
+    local_input_permutation_mapping: torch.Tensor,
+    org_hidden_states_shape: torch.Size,
+    ep_group: Optional[dist.ProcessGroup] = None,
+) -> torch.Tensor:
+    """Inverse of token_pre_all2all: regroup, exchange back, and apply the routing weights."""
+    # group tokens together by expert
+    num_local_experts = num_experts // ep_group.size()
+    unpermute_order = torch.arange(num_experts).reshape(num_local_experts, -1).T.ravel().tolist()
+    expert_outputs = sort_chunks_by_idxs(
+        expert_outputs,
+        num_global_tokens_per_local_expert.T.ravel(),
+        unpermute_order,
+    )
+
+    unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, 
output_splits) + + # [tokens, experts] + weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts) + + unpermute_outputs = unpermute( + unpermute_outputs, + weights_idx, + org_hidden_states_shape, + local_input_permutation_mapping, + routing_map, + ) + + return unpermute_outputs diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 8c81e3b0..5798bcdc 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -6,50 +6,72 @@ import torch.nn.functional as F from dataclasses import dataclass from torch import nn -from torch.distributed import nn as dist_nn -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple from twinkle.utils import DeviceMesh +from twinkle.model.transformers.moe.ep_utils import ( + preprocess, + token_pre_all2all, + tokens_post_all2all, +) @dataclass class ExpertParallelConfig: enabled: bool = True router_dtype: str = 'fp32' - all_to_all: str = 'torch' keep_router_logits: bool = True - pad_to_max: bool = False ignore_shared_experts: bool = False + ep_size: Optional[int] = None # consumed by TransformersModel, not used in expert_parallel logic -def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dict[str, Any] | None = None): +@dataclass +class ExpertShardingSpec: + """Describes expert sharding info for a single MoE block. Extensible for other models.""" + block: nn.Module + experts_module: nn.Module + num_experts: int + experts_per_rank: int + local_start: int + local_end: int + ep_rank: int + ep_world_size: int + is_tensor_experts: bool + + +def apply_expert_parallel( + model: nn.Module, + device_mesh: DeviceMesh, + config: dict[str, Any] | None = None, + ep_fsdp_device_mesh: Optional['torch.distributed.DeviceMesh'] = None, +) -> List[ExpertShardingSpec]: + """Apply expert parallelism to all MoE blocks in the model.""" cfg = _merge_config(config) - if not cfg.enabled or device_mesh is None or not device_mesh.has_dim('ep'): - return model - - ep_world_size = device_mesh.ep_world_size - if ep_world_size <= 1: - return model - ep_fsdp_enabled = device_mesh.is_implicit_ep_fsdp_enabled() + # EP info comes from the separate ep_fsdp_device_mesh, not from main mesh + if not cfg.enabled or ep_fsdp_device_mesh is None: + return [] - if cfg.pad_to_max: - raise NotImplementedError('pad_to_max is not implemented.') - if cfg.all_to_all != 'torch': - raise NotImplementedError(f'all_to_all={cfg.all_to_all} is not supported.') + # Always query EP via the 1D submesh to avoid relying on Tensor named dims. + ep_mesh = ep_fsdp_device_mesh['ep'] + ep_world_size = ep_mesh.size() + if ep_world_size <= 1: + return [] if not dist.is_initialized(): raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.') - ep_group = device_mesh.get_dim_group('ep') - if ep_group is None: - raise RuntimeError('EP process group is not available in device_mesh.') + # Get process group and local rank from EP submesh. 
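+    # (A 1D torch DeviceMesh exposes get_group()/get_local_rank(), giving the EP
+    #  process group and this rank's position along the 'ep' dimension.)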
+ ep_group = ep_mesh.get_group() + ep_rank = ep_mesh.get_local_rank() + specs = [] for block in find_moe_blocks(model): - shard_experts(block, device_mesh, cfg, ep_fsdp_enabled=ep_fsdp_enabled) - patch_forward(block, device_mesh, cfg) + spec = shard_experts(block, ep_world_size, ep_rank, cfg) + patch_forward(block, ep_group, ep_world_size, cfg) + specs.append(spec) - return model + return specs def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: @@ -77,11 +99,23 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: return blocks -def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig, *, - ep_fsdp_enabled: bool) -> None: +def shard_experts( + block: nn.Module, + ep_world_size: int, + ep_rank: int, + cfg: ExpertParallelConfig, +) -> ExpertShardingSpec: + """Shard experts in a MoE block across EP ranks. + + Args: + block: The MoE block containing experts. + ep_world_size: The world size for expert parallelism. + ep_rank: The current rank in the EP group. + cfg: Expert parallel configuration. + + Returns an ExpertShardingSpec describing the sharding. + """ num_experts = _get_num_experts(block) - ep_world_size = device_mesh.ep_world_size - ep_rank = device_mesh.ep_rank if num_experts % ep_world_size != 0: raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).') @@ -91,15 +125,12 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel local_end = local_start + experts_per_rank if isinstance(block.experts, nn.ModuleList): - if ep_fsdp_enabled: - raise NotImplementedError('EP+EP_FSDP currently does not support MoE experts stored as nn.ModuleList. ' - 'Only tensor experts (gate_up_proj/down_proj) are supported.') local_experts = nn.ModuleList(block.experts[local_start:local_end]) block.experts = local_experts - block._ep_tensor_experts = False + is_tensor_experts = False else: _shard_tensor_experts(block.experts, local_start, local_end) - block._ep_tensor_experts = True + is_tensor_experts = True block._ep_num_experts = num_experts block._ep_experts_per_rank = experts_per_rank @@ -107,11 +138,39 @@ def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel block._ep_local_end = local_end block._ep_rank = ep_rank block._ep_world_size = ep_world_size + block._ep_tensor_experts = is_tensor_experts block._ep_ignore_shared_experts = cfg.ignore_shared_experts - block._ep_fsdp_enabled = ep_fsdp_enabled + + return ExpertShardingSpec( + block=block, + experts_module=block.experts, + num_experts=num_experts, + experts_per_rank=experts_per_rank, + local_start=local_start, + local_end=local_end, + ep_rank=ep_rank, + ep_world_size=ep_world_size, + is_tensor_experts=is_tensor_experts, + ) -def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None: +def patch_forward( + block: nn.Module, + ep_group: dist.ProcessGroup, + ep_world_size: int, + cfg: ExpertParallelConfig, +) -> None: + """Replace the MoE block forward with EP-aware communication flow. + + Communication pattern follows VeOmni: + preprocess → token_pre_all2all → expert_compute (F.linear loop) → tokens_post_all2all + + Args: + block: The MoE block to patch. + ep_group: The process group for EP communication (from ep_fsdp_device_mesh["ep"]). + ep_world_size: The world size for expert parallelism. + cfg: Expert parallel configuration. 
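+
+    Shape sketch (illustrative; T tokens, hidden size H, top_k K):
+        hidden_states_2d:            [T, H]
+        expert_mask:                 [num_experts, K, T] one-hot routing
+        permuted (pre all_to_all):   [T * K, H], grouped by destination rank
+        received (post all_to_all):  [sum(output_splits), H], regrouped so each
+                                     local expert's tokens are contiguous.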
+ """ if getattr(block, '_ep_patched', False): return @@ -124,13 +183,15 @@ def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallel raise ValueError('MoE block must define top_k/num_experts_per_tok.') orig_forward = block.forward - ep_group = device_mesh.get_dim_group('ep') + num_experts = block._ep_num_experts + experts_per_rank = block._ep_experts_per_rank def forward(hidden_states: torch.Tensor, *args, **kwargs): if args or kwargs: raise RuntimeError('Expert parallel patch only supports forward(hidden_states).') input_dtype = hidden_states.dtype + orig_shape = hidden_states.shape if hidden_states.ndim == 3: batch_size, seq_len, hidden_dim = hidden_states.shape hidden_states_2d = hidden_states.view(-1, hidden_dim) @@ -151,92 +212,71 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): if cast_weights: routing_weights = routing_weights.to(hidden_states_2d.dtype) - num_tokens = hidden_states_2d.shape[0] - flat_token_idx = torch.arange(num_tokens, device=hidden_states_2d.device).repeat_interleave(top_k) - flat_expert_id = selected_experts.reshape(-1) - flat_weight = routing_weights.reshape(-1) - - experts_per_rank = block._ep_experts_per_rank - dest_rank = flat_expert_id // experts_per_rank - local_expert_id = flat_expert_id - dest_rank * experts_per_rank - - order = torch.argsort(dest_rank) - ordered_token_idx = flat_token_idx[order] - ordered_weight = flat_weight[order] - ordered_global_expert_id = flat_expert_id[order] - ordered_expert_id = local_expert_id[order] - - send_counts = torch.bincount(dest_rank, minlength=block._ep_world_size) - send_counts_list = send_counts.cpu().tolist() - - recv_counts = _exchange_counts(send_counts, ep_group) - recv_counts_list = recv_counts.cpu().tolist() - - send_tokens = hidden_states_2d.index_select(0, ordered_token_idx) - recv_tokens = torch.empty( - (int(recv_counts.sum().item()), hidden_dim), - device=hidden_states_2d.device, - dtype=hidden_states_2d.dtype, - ) - send_expert_ids = ordered_expert_id.to(torch.int64) - recv_expert_ids = torch.empty( - (int(recv_counts.sum().item()), ), - device=hidden_states_2d.device, - dtype=torch.int64, + # Build expert_mask: [num_experts, top_k, num_tokens] (VeOmni convention) + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=num_experts + ).permute(2, 1, 0) # [num_experts, top_k, num_tokens] + + # 1. preprocess: compute splits and token counts + ( + input_splits, + output_splits, + num_global_tokens_per_local_expert, + num_global_sum_tokens_per_local_expert, + ) = preprocess(expert_mask, num_experts, ep_group) + + # 2. token_pre_all2all: permute → all_to_all → sort_chunks + ( + global_permuted_hidden_states, + routing_map, + local_input_permutation_mapping, + org_hidden_states_shape, + ) = token_pre_all2all( + hidden_states_2d, + expert_mask, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + ep_group, ) - recv_tokens = dist_nn.functional.all_to_all_single( - recv_tokens, - send_tokens, - input_split_sizes=send_counts_list, - output_split_sizes=recv_counts_list, - group=ep_group, + # 3. expert_compute: F.linear loop per local expert (no routing weight here) + # When FSDP2 wraps experts, params are sharded DTensors. Manually + # unshard (all-gather) so _run_local_experts sees full tensors. 
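+        # Sketch of the pattern used below (unshard()/reshard() are FSDP2's
+        # FSDPModule methods; names match this function):
+        #     experts.unshard()       # all-gather the Shard(1) DTensor params
+        #     out = run_experts(...)  # plain tensor compute on full weights
+        #     experts.reshard()       # free the unsharded copy
+        # Wrapping the compute in try/finally would be a safer variant.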
+ _experts_mod = block.experts + _need_unshard = hasattr(_experts_mod, 'unshard') and hasattr(_experts_mod, 'reshard') + if _need_unshard: + _experts_mod.unshard() + expert_outputs = _run_local_experts( + block, + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, ) - dist.all_to_all_single( - recv_expert_ids, - send_expert_ids.to(torch.int64), - input_split_sizes=send_counts_list, - output_split_sizes=recv_counts_list, - group=ep_group, + if _need_unshard: + _experts_mod.reshard() + + # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) + final_hidden = tokens_post_all2all( + expert_outputs, + routing_weights, + selected_experts, + num_experts, + input_splits, + output_splits, + num_global_tokens_per_local_expert, + routing_map, + local_input_permutation_mapping, + org_hidden_states_shape, + ep_group, ) - recv_out = torch.empty_like(recv_tokens) - if getattr(block, '_ep_fsdp_enabled', False) and getattr(block, '_ep_tensor_experts', False): - recv_out = _run_experts_ep_fsdp_batch(block, recv_tokens, recv_expert_ids) - else: - for expert_id in torch.unique(recv_expert_ids).tolist(): - idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1) - expert_in = recv_tokens.index_select(0, idx) - expert_out = _run_expert(block, expert_id, expert_in) - recv_out.index_copy_(0, idx, expert_out) - - send_out = torch.empty_like(send_tokens) - send_out = dist_nn.functional.all_to_all_single( - send_out, - recv_out, - input_split_sizes=recv_counts_list, - output_split_sizes=send_counts_list, - group=ep_group, - ) - - final_hidden = torch.zeros((num_tokens, hidden_dim), device=hidden_states_2d.device, dtype=input_dtype) - expert_hit = torch.unique(ordered_global_expert_id) - if expert_hit.numel() > 0: - expert_hit, _ = torch.sort(expert_hit) - for expert_id in expert_hit: - idx = (ordered_global_expert_id == expert_id).nonzero(as_tuple=False).view(-1) - if idx.numel() == 0: - continue - token_idx = ordered_token_idx.index_select(0, idx) - weight = ordered_weight.index_select(0, idx) - contrib = send_out.index_select(0, idx) - scaled = contrib * weight.unsqueeze(-1) - final_hidden.index_add_(0, token_idx, scaled.to(input_dtype)) shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg) if shared_out is not None: final_hidden = final_hidden + shared_out - if hidden_states.ndim == 3: + if len(orig_shape) == 3: final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) if cfg.keep_router_logits and not getattr(block, '_ep_tensor_experts', False): @@ -248,19 +288,6 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): block._ep_patched = True -def _exchange_counts(send_counts: torch.Tensor, group) -> torch.Tensor: - ep_world_size = int(send_counts.numel()) - recv_counts = torch.empty_like(send_counts) - dist.all_to_all_single( - recv_counts, - send_counts.to(torch.int64), - input_split_sizes=[1] * ep_world_size, - output_split_sizes=[1] * ep_world_size, - group=group, - ) - return recv_counts - - def _get_gate(block: nn.Module): gate = getattr(block, 'gate', None) if gate is None: @@ -331,45 +358,59 @@ def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None: experts.num_experts = end - start -def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> torch.Tensor: - input_dtype = expert_in.dtype - if not getattr(block, '_ep_tensor_experts', False): - expert = block.experts[expert_id] - return _run_module_with_casting(expert, expert_in) +def _run_local_experts( + 
block: nn.Module, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, +) -> torch.Tensor: + """Run local experts on permuted tokens using F.linear loop. + + Tokens are already grouped by expert (contiguous chunks), sizes given by + num_global_sum_tokens_per_local_expert. No routing weight is applied here; + that happens in unpermute. + """ + if permuted_tokens.numel() == 0: + return torch.empty_like(permuted_tokens) + + input_dtype = permuted_tokens.dtype + is_tensor_experts = getattr(block, '_ep_tensor_experts', False) experts = block.experts - gate_up = experts.gate_up_proj[expert_id] - down = experts.down_proj[expert_id] - compute_dtype = gate_up.dtype - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1) - out = experts.act_fn(gate) * up - out = F.linear(out, down) - if out.dtype != input_dtype: - out = out.to(input_dtype) - return out + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + if is_tensor_experts: + gate_up = experts.gate_up_proj[i] + down = experts.down_proj[i] + compute_dtype = gate_up.dtype + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1) + out = experts.act_fn(gate) * up + out = F.linear(out, down) + else: + expert = experts[i] + compute_dtype = _module_compute_dtype(expert, input_dtype) + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + out = expert(expert_in) -def _run_experts_ep_fsdp_batch( - block: nn.Module, - expert_in: torch.Tensor, - local_expert_ids: torch.Tensor, -) -> torch.Tensor: - input_dtype = expert_in.dtype - if expert_in.numel() == 0: - return torch.empty_like(expert_in) - experts = block.experts - top_k_index = local_expert_ids.view(-1, 1).to(torch.long) - top_k_weights = torch.ones( - (expert_in.shape[0], 1), - dtype=expert_in.dtype, - device=expert_in.device, - ) - out = experts(expert_in, top_k_index, top_k_weights) - if out.dtype != input_dtype: - out = out.to(input_dtype) - return out + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype: diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 2e9db3e3..845e428e 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -12,63 +12,95 @@ class NativeFSDPStrategy: - """FSDP2 strategy with explicit process group control for EP compatibility.""" def __init__(self, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', fsdp_config: Dict[str, Any] = None, - enable_ep: bool = True): + enable_ep: bool = True, ep_fsdp_device_mesh: Optional[TorchDeviceMesh] = None): self.device_mesh = device_mesh self.mixed_precision = mixed_precision 
self.fsdp_config = fsdp_config or {} self.enable_ep = enable_ep + self.ep_fsdp_device_mesh = ep_fsdp_device_mesh def wrap_model(self, model, optimizer=None): if self.device_mesh is None: return model, optimizer fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: - ep_fsdp_mode = _is_ep_fsdp_mode_enabled( - self.device_mesh, - self.enable_ep, - ) - if self.enable_ep: - _ensure_moe_patched_if_needed(model, self.device_mesh) - _place_ep_experts_on_local_device(model, self.device_mesh) + ep_enabled = (self.enable_ep and self.ep_fsdp_device_mesh is not None) + if ep_enabled: + _ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh) + _place_ep_experts_on_local_device(model, self.ep_fsdp_device_mesh) mp_policy = _build_mp_policy(self.mixed_precision) reshard_after_forward = self.fsdp_config.get('reshard_after_forward', True) - if ep_fsdp_mode: + if ep_enabled: _ensure_ep_fsdp_supported(model) - ep_fsdp_mesh = _build_ep_fsdp_mesh(self.device_mesh) - if ep_fsdp_mesh is None: - raise RuntimeError( - 'Implicit EP_FSDP requires dp dim with size > 1, but could not build an ep_fsdp mesh from dp.') - _maybe_shard_ep_expert_blocks( - model, - mesh=ep_fsdp_mesh, + + # Collect experts map and expert params + experts_map = _collect_ep_experts_map(model) if ep_enabled else {} + expert_params = _collect_expert_params(model) if self.enable_ep else None + + # Build layer_pairs: [(layer_mod, experts_mod_or_None)] + layers = _get_decoder_layers(model) + layer_pairs = [] + if layers is not None: + for layer_mod in layers: + experts_mod = _find_experts_in_layer(layer_mod, experts_map) + layer_pairs.append((layer_mod, experts_mod)) + + # FSDP2 wrapping per layer (mirrors VeOmni parallelize_model_fsdp2) + world_size = self.device_mesh.world_size + ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None + + for layer_mod, experts_mod in layer_pairs: + layer_mod._fsdp_modules = [] + + if experts_mod is not None and ep_fsdp_mesh_1d is not None: + from torch.distributed.tensor import Shard + fully_shard( + experts_mod, + mesh=ep_fsdp_mesh_1d, + reshard_after_forward=reshard_after_forward, + mp_policy=mp_policy, + shard_placement_fn=lambda param: Shard(1), + ) + # gradient_divide_factor = world_size (VeOmni convention) + experts_mod.set_gradient_divide_factor(world_size) + layer_mod._fsdp_modules.append(experts_mod) + + fully_shard( + layer_mod, + mesh=fsdp_mesh, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, + ignored_params=expert_params, ) + layer_mod._fsdp_modules.append(layer_mod) - ignored_params = _collect_expert_params(model) if self.enable_ep else None - - _maybe_shard_layers( - model, - mesh=fsdp_mesh, - reshard_after_forward=reshard_after_forward, - mp_policy=mp_policy, - ignored_params=ignored_params, - ) + # Root model fully_shard( model, mesh=fsdp_mesh, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, - ignored_params=ignored_params, + ignored_params=expert_params, ) + # Manual prefetch (mirrors VeOmni lines 396-411) + if ep_enabled and layer_pairs: + _setup_manual_prefetch([lp[0] for lp in layer_pairs]) + + # Tag ep_param_groups for EP-aware grad clip + if ep_enabled and expert_params: + all_params = set(model.parameters()) + model._ep_param_groups = { + 'ep': list(expert_params), + 'non_ep': [p for p in all_params if p not in expert_params], + } + if optimizer is not None: optimizer = _rebind_optimizer(optimizer, model) @@ -103,19 +135,14 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: 
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', )) -def _build_ep_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: - if device_mesh is None or not device_mesh.has_dim('dp'): - return None - ranks = device_mesh.get_ranks_for_dims('dp') - if len(ranks) <= 1: - return None - return TorchDeviceMesh(device_mesh.device_type, ranks, mesh_dim_names=('ep_fsdp', )) - +def _get_decoder_layers(model: nn.Module) -> Optional[nn.ModuleList]: + inner_model = getattr(model, 'model', None) + if inner_model is not None: + inner_layers = getattr(inner_model, 'layers', None) + if isinstance(inner_layers, nn.ModuleList): + return inner_layers -def _is_ep_fsdp_mode_enabled(device_mesh: Optional[DeviceMesh], enable_ep: bool) -> bool: - if not enable_ep or device_mesh is None: - return False - return device_mesh.is_implicit_ep_fsdp_enabled() + return None def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: @@ -142,8 +169,44 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]: return ignored or None -def _place_ep_experts_on_local_device(model: nn.Module, device_mesh: DeviceMesh) -> None: - ep_world_size = device_mesh.ep_world_size or 1 +def _collect_ep_experts_map(model: nn.Module) -> Dict[str, nn.Module]: + """Collect {fqn: experts_module} for all EP-patched MoE blocks.""" + experts_map = {} + for fqn, module in model.named_modules(): + if not getattr(module, '_ep_patched', False): + continue + experts = getattr(module, 'experts', None) + if experts is not None: + experts_fqn = fqn + '.experts' if fqn else 'experts' + experts_map[experts_fqn] = experts + return experts_map + + +def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: + """Find the experts module inside a decoder layer, if any.""" + for module in layer_mod.modules(): + if module in experts_map.values(): + return module + return None + + +def _setup_manual_prefetch(blocks: list) -> None: + """Configure forward/backward prefetch for FSDP modules (mirrors VeOmni).""" + for i, block in enumerate(blocks): + if i + 1 < len(blocks): + next_fsdp_modules = getattr(blocks[i + 1], '_fsdp_modules', []) + if next_fsdp_modules: + block.set_modules_to_forward_prefetch(list(reversed(next_fsdp_modules))) + for i in range(len(blocks) - 1, 0, -1): + prev_fsdp_modules = getattr(blocks[i - 1], '_fsdp_modules', []) + if prev_fsdp_modules: + blocks[i].set_modules_to_backward_prefetch(list(reversed(prev_fsdp_modules))) + + +def _place_ep_experts_on_local_device(model: nn.Module, ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> None: + if ep_fsdp_device_mesh is None: + return + ep_world_size = ep_fsdp_device_mesh['ep'].size() if ep_world_size <= 1: return local_device = torch.device(Platform.get_local_device()) @@ -159,13 +222,19 @@ def _place_ep_experts_on_local_device(model: nn.Module, device_mesh: DeviceMesh) shared.to(local_device) -def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) -> None: - ep_world_size = device_mesh.ep_world_size or 1 +def _ensure_moe_patched_if_needed(model: nn.Module, ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> None: + if ep_fsdp_device_mesh is None: + return + ep_world_size = ep_fsdp_device_mesh['ep'].size() if ep_world_size <= 1: return for module in model.modules(): experts = getattr(module, 'experts', None) - if isinstance(experts, nn.ModuleList) and not getattr(module, '_ep_patched', False): + is_moe_experts = ( + isinstance(experts, nn.ModuleList) + or 
(hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')) + ) + if is_moe_experts and not getattr(module, '_ep_patched', False): raise RuntimeError('Found MoE experts but expert parallel is not applied. ' 'Call apply_expert_parallel(model, device_mesh, config) before wrapping with FSDP2.') @@ -180,44 +249,6 @@ def _ensure_ep_fsdp_supported(model: nn.Module) -> None: 'Only tensor experts (gate_up_proj/down_proj) are supported.') -def _maybe_shard_ep_expert_blocks(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], - mp_policy: 'MixedPrecisionPolicy') -> int: - from torch.distributed.tensor import Shard - sharded_blocks = 0 - for module in model.modules(): - if not getattr(module, '_ep_patched', False): - continue - experts = getattr(module, 'experts', None) - if experts is None: - continue - # Correct EP+EP_FSDP behavior: only experts are sharded on ep_fsdp mesh. - # Non-expert params (router/gate etc.) are left to global FSDP wrapping. - fully_shard( - experts, - mesh=mesh, - reshard_after_forward=reshard_after_forward, - mp_policy=mp_policy, - shard_placement_fn=lambda param: Shard(1), - ) - sharded_blocks += 1 - return sharded_blocks - - -def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool], - mp_policy: 'MixedPrecisionPolicy', ignored_params: Optional[Set[nn.Parameter]]) -> None: - layers = getattr(model, 'layers', None) - if not isinstance(layers, nn.ModuleList): - return - for layer in layers: - fully_shard( - layer, - mesh=mesh, - reshard_after_forward=reshard_after_forward, - mp_policy=mp_policy, - ignored_params=ignored_params, - ) - - def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> torch.optim.Optimizer: if optimizer.state: raise RuntimeError('Optimizer already has state. Create the optimizer after FSDP wrapping, ' diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6f80699b..ff36374a 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -218,13 +218,25 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config, self.device_mesh) self._expert_parallel_applied = False + # Store ep_size for later use (EP mesh construction, grad clip, etc.) 
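+        # Resolution order, matching the code below: an explicit
+        # expert_parallel_config['ep_size'] wins, then device_mesh.ep_size,
+        # then 1 (EP disabled). E.g. {'ep_size': 2} on 4 GPUs yields
+        # ep_size=2 regardless of the mesh attribute.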
+ self._ep_size = (self._expert_parallel_config.get('ep_size') + if self._expert_parallel_config else None) + if self._ep_size is None and self.device_mesh is not None: + self._ep_size = getattr(self.device_mesh, 'ep_size', None) + if self._ep_size is None: + self._ep_size = 1 + use_native_fsdp = self._enable_expert_parallel or strategy == 'native_fsdp' if use_native_fsdp: + ep_fsdp_mesh = None + if self._enable_expert_parallel and self.device_mesh is not None: + ep_fsdp_mesh = self.device_mesh.build_ep_fsdp_device_mesh(ep_size=self._ep_size) self.strategy = NativeFSDPStrategy( mixed_precision=self.mixed_precision, fsdp_config=self._fsdp_config, device_mesh=self.device_mesh, enable_ep=self._enable_expert_parallel, + ep_fsdp_device_mesh=ep_fsdp_mesh, ) else: self.strategy = AccelerateStrategy( @@ -294,10 +306,9 @@ def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, An device_mesh: Optional[DeviceMesh]) -> bool: if expert_parallel_config is None or device_mesh is None: return False - if not device_mesh.has_dim('ep'): - return False - ep_world_size = device_mesh.ep_world_size or 1 - if ep_world_size <= 1: + # Check ep_size from config first, then from device_mesh.ep_size attribute + ep_size = expert_parallel_config.get('ep_size') or getattr(device_mesh, 'ep_size', None) or 1 + if ep_size <= 1: return False return expert_parallel_config.get('enabled', True) @@ -306,10 +317,13 @@ def _maybe_apply_expert_parallel(self): return self._ensure_optimizer_dp_groups() model = self.strategy.unwrap_model(self.model) + # Get the ep_fsdp_device_mesh from the strategy (NativeFSDPStrategy stores it) + ep_fsdp_mesh = getattr(self.strategy, 'ep_fsdp_device_mesh', None) apply_expert_parallel( model, self.device_mesh, config=self._expert_parallel_config, + ep_fsdp_device_mesh=ep_fsdp_mesh, ) self._expert_parallel_applied = True @@ -525,12 +539,25 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group) num_tokens = sum(num_tokens) parameters = list(self._get_trainable_parameters(adapter_name).values()) + + # EP-aware grad clip kwargs + ep_clip_kwargs = {} + model = self.strategy.unwrap_model(self.model) + if hasattr(model, '_ep_param_groups'): + ep_clip_kwargs['ep_param_groups'] = model._ep_param_groups + # Get EP groups from ep_fsdp_device_mesh, not from main mesh + ep_fsdp_mesh = getattr(self.strategy, 'ep_fsdp_device_mesh', None) + if ep_fsdp_mesh is not None: + ep_clip_kwargs['ep_group'] = ep_fsdp_mesh['ep'].get_group() + ep_clip_kwargs['ep_fsdp_group'] = ep_fsdp_mesh['ep_fsdp'].get_group() + grad_norm = normalize_and_clip_grad_norm( parameters, num_tokens=num_tokens, max_grad_norm=max_grad_norm, norm_type=norm_type, group=optimizer_config._dp_group, + **ep_clip_kwargs, ) optimizer_config._last_grad_norm = grad_norm optimizer_config.num_tokens = 0 diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index 4b9f192c..11ded854 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -15,7 +15,16 @@ def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter], num_tokens: int, max_grad_norm: float, norm_type: float, - group=None) -> float: + group=None, + ep_param_groups=None, + ep_group=None, + ep_fsdp_group=None) -> float: + """Normalize gradients by num_tokens, then clip by max_grad_norm. 
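+
+    Illustrative math (assuming norm_type=2): every grad is first scaled by
+    1 / num_tokens; then total_norm = sqrt(sum_i ||g_i||^2) over all grads,
+    and if total_norm exceeds max_grad_norm each grad is multiplied by
+    max_grad_norm / (total_norm + 1e-6).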
+ + If ep_param_groups is provided, uses EP-aware two-phase reduction: + - non-EP params: all-reduce over group (fsdp_group) + - EP params: all-reduce over ep_fsdp_group, then ep_group + """ import torch import torch.distributed as dist parameters = list(parameters) @@ -32,6 +41,18 @@ def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter], if not grads: return 0.0 + # EP-aware path (mirrors VeOmni ep_fsdp2_clip_grad_norm) + if ep_param_groups is not None: + return _ep_aware_clip_grad_norm( + ep_param_groups=ep_param_groups, + max_grad_norm=max_grad_norm, + norm_type=norm_type, + fsdp_group=group, + ep_group=ep_group, + ep_fsdp_group=ep_fsdp_group, + ) + + # Standard path (backward compatible) has_dtensor_grad = any(hasattr(grad, 'to_local') for grad in grads) has_local_tensor_grad = any(not hasattr(grad, 'to_local') for grad in grads) dtensor_mesh_keys = set() @@ -114,3 +135,97 @@ def _local_grad(grad: torch.Tensor) -> torch.Tensor: for grad in grads: grad.mul_(clip_coef) return total_norm + + +def _ep_aware_clip_grad_norm( + *, + ep_param_groups, + max_grad_norm: float, + norm_type: float, + fsdp_group=None, + ep_group=None, + ep_fsdp_group=None, +) -> float: + """EP-aware gradient clipping (mirrors VeOmni ep_fsdp2_clip_grad_norm). + + - non-EP params: all-reduce over fsdp_group + - EP params: all-reduce over ep_fsdp_group, then ep_group + - Unified clip coefficient applied to both groups + """ + import math + import torch + import torch.distributed as dist + + ep_params = [p for p in ep_param_groups.get('ep', []) if p.grad is not None] + non_ep_params = [p for p in ep_param_groups.get('non_ep', []) if p.grad is not None] + + norm_type = float(norm_type) + + # non-EP: reduce over fsdp_group + non_ep_val = _local_norm_stat(non_ep_params, norm_type) + if fsdp_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(non_ep_val, op=op, group=fsdp_group) + + # EP: reduce over ep_fsdp_group, then ep_group + ep_val = _local_norm_stat(ep_params, norm_type) + if ep_fsdp_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(ep_val, op=op, group=ep_fsdp_group) + if ep_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(ep_val, op=op, group=ep_group) + + # Combine + if math.isinf(norm_type): + total_norm = torch.maximum(non_ep_val, ep_val) + else: + total_norm = (non_ep_val + ep_val) ** (1.0 / norm_type) + + # Clip both groups with the same coefficient + clip_coef = float(max_grad_norm) / (float(total_norm.item()) + 1e-6) + if clip_coef < 1.0: + all_params = ep_params + non_ep_params + for p in all_params: + if p.grad is not None: + p.grad.mul_(clip_coef) + + return float(total_norm.item()) + + +def _local_norm_stat(params, norm_type: float): + """Compute local norm statistic: sum of p-th powers (finite p) or max (inf).""" + import math + import torch + from torch.distributed._tensor import DTensor + + device = None + for p in params: + if p.grad is not None: + g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + if g.is_cuda or getattr(g, 'is_npu', False): + device = g.device + break + if device is None: + device = torch.device(Platform.get_local_device()) + + if math.isinf(norm_type): + val = torch.tensor(0.0, device=device, dtype=torch.float32) + for p in params: + if p.grad is None: + continue + g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + if g.numel() == 0: + continue + val = 
torch.maximum(val, g.detach().to(torch.float32).abs().max()) + return val + else: + val = torch.tensor(0.0, device=device, dtype=torch.float32) + for p in params: + if p.grad is None: + continue + g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + if g.numel() == 0: + continue + val += g.detach().to(torch.float32).pow(norm_type).sum() + return val diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 16106063..81472d1d 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -124,7 +124,7 @@ def __post_init__(self): if not isinstance(self.mesh, np.ndarray): self.mesh = np.array(self.mesh) - valid_dim_names = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep'} + valid_dim_names = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep', 'ep_fsdp'} if self.mesh_dim_names is not None: if len(self.mesh_dim_names) != len(self.mesh.shape): raise ValueError(f'The shape of `mesh_dim_names`:({len(self.mesh_dim_names)}) ' @@ -208,22 +208,21 @@ def get_ranks_for_dims(self, dims): slices.append(coord[i]) return sorted(self.mesh[tuple(slices)].flatten().tolist()) - def is_implicit_ep_fsdp_enabled(self) -> bool: - ep_world_size = self.ep_world_size or 1 - dp_world_size = self.dp_world_size or 1 - if ep_world_size <= 1 or dp_world_size <= 1: - return False - - world_size = self.world_size or 1 - if world_size % ep_world_size != 0: - raise ValueError(f'world_size ({world_size}) must be divisible by ep_world_size ({ep_world_size}) ' - 'to infer implicit EP_FSDP from dp.') - expected_dp_size = world_size // ep_world_size - if dp_world_size != expected_dp_size: - raise ValueError(f'Implicit EP_FSDP requires dp_world_size == world_size // ep_world_size, ' - f'but got dp_world_size={dp_world_size}, world_size={world_size}, ' - f'ep_world_size={ep_world_size}.') - return True + def build_ep_fsdp_device_mesh(self, ep_size: int = None): + import math + import torch + ep_size = ep_size or self.ep_size or 1 + if ep_size <= 1: + return None + world_size = self.world_size + assert world_size % ep_size == 0, ( + f'world_size ({world_size}) must be divisible by ep_size ({ep_size})') + ep_fsdp_size = world_size // ep_size + with torch.device('cpu'): + mesh = (torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int) + .view(ep_fsdp_size, ep_size) + .transpose(0, 1)) + return torch.distributed.DeviceMesh(self.device_type, mesh, mesh_dim_names=('ep', 'ep_fsdp')) @property def order(self): diff --git a/tests/moe/test_ep_fsdp_vs_single.py b/tests/moe/test_ep_fsdp_vs_single.py new file mode 100644 index 00000000..f610ce30 --- /dev/null +++ b/tests/moe/test_ep_fsdp_vs_single.py @@ -0,0 +1,460 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Test EP+FSDP vs single-GPU precision: +1. Forward Logits & Loss +2. Gradients (non-expert and expert layers) +3. 
Updated Weights after optimizer step
+"""
+
+import numpy as np
+import os
+import socket
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import unittest
+from datetime import timedelta
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from twinkle.model.transformers.moe import apply_expert_parallel
+from twinkle.model.transformers.strategy import NativeFSDPStrategy
+from twinkle.utils import DeviceMesh
+
+
+ABS_TOL = 5e-3
+LOSS_TOL = 1e-4
+REL_TOL = 1e-4
+
+
+def _find_free_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        sock.bind(('127.0.0.1', 0))
+        return sock.getsockname()[1]
+
+
+def _load_config(model_id: str, local_files_only: bool):
+    return AutoConfig.from_pretrained(model_id, trust_remote_code=True, local_files_only=local_files_only)
+
+
+def _single_snapshot_path(port: int) -> str:
+    return f'/tmp/twinkle_ep_fsdp_vs_single_{port}.pt'
+
+
+def _load_model(model_id: str, local_only: bool, device: torch.device, num_layers: int = 1):
+    config = _load_config(model_id, local_only)
+    if hasattr(config, 'num_hidden_layers'):
+        config.num_hidden_layers = num_layers
+    if hasattr(config, 'use_cache'):
+        config.use_cache = False
+    if hasattr(config, '_experts_implementation'):
+        config._experts_implementation = 'eager'
+
+    # Disable dropout to ensure determinism
+    dropout_attrs = ['attention_dropout', 'hidden_dropout', 'classifier_dropout',
+                     'resid_pdrop', 'embd_pdrop']
+    for attr in dropout_attrs:
+        if hasattr(config, attr):
+            setattr(config, attr, 0.0)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, config=config, torch_dtype=torch.float32,
+        low_cpu_mem_usage=True, trust_remote_code=True, local_files_only=local_only
+    )
+    model.to(device)
+    return model
+
+
+def _clean_name(name: str) -> str:
+    """Normalize parameter names by stripping FSDP wrapper prefixes."""
+    name = name.replace('_fsdp_wrapped_module.', '')
+    name = name.replace('module.', '')
+    return name
+
+
+def _get_full_tensor(tensor_obj):
+    """Materialize DTensors to full tensors; plain tensors are returned as-is."""
+    if tensor_obj is None:
+        return None
+    if hasattr(tensor_obj, 'full_tensor'):
+        return tensor_obj.full_tensor().detach().cpu()
+    elif hasattr(tensor_obj, '_local_tensor'):
+        return tensor_obj._local_tensor.detach().cpu()
+    else:
+        return tensor_obj.detach().cpu()
+
+
+def _split_range(total: int, rank: int, world_size: int) -> tuple[int, int]:
+    if world_size <= 1:
+        return 0, total
+    if rank < 0 or rank >= world_size:
+        return 0, 0
+    base, rem = divmod(total, world_size)
+    start = rank * base + min(rank, rem)
+    end = start + base + (1 if rank < rem else 0)
+    return start, end
+
+
+def _match_single_tensor_for_compare(mapped_name: str,
+                                     multi_tensor: torch.Tensor,
+                                     single_dict: dict,
+                                     ep_rank: int,
+                                     fsdp_rank: int,
+                                     fsdp_world_size: int):
+    """Return the single-GPU tensor aligned with multi_tensor; expert params support dim0 (EP) + dim1 (FSDP) slicing."""
+    single_tensor = single_dict.get(mapped_name)
+    if single_tensor is None:
+        return None
+
+    if multi_tensor.shape == single_tensor.shape:
+        return single_tensor
+
+    # Slicing rules for expert params under EP/FSDP:
+    # - dim0 is sharded by EP: single=[num_experts, ...] -> local experts
+    # - dim1 is sharded by FSDP: slice dim1 after taking the local experts
+    if 'experts.' 
not in mapped_name:
+        return None
+    if multi_tensor.ndim < 1 or single_tensor.ndim < 1:
+        return None
+    candidate = single_tensor
+
+    # 1) dim0: ep shard
+    if candidate.ndim >= 1 and candidate.shape[0] != multi_tensor.shape[0]:
+        local_experts = multi_tensor.shape[0]
+        total_experts = candidate.shape[0]
+        if local_experts == 0 or total_experts % local_experts != 0:
+            return None
+        ep_world_size = total_experts // local_experts
+        if ep_rank is None or ep_rank < 0 or ep_rank >= ep_world_size:
+            return None
+        start0 = ep_rank * local_experts
+        end0 = start0 + local_experts
+        candidate = candidate[start0:end0]
+
+    # 2) dim1: fsdp shard
+    if candidate.ndim >= 2 and candidate.shape[1] != multi_tensor.shape[1]:
+        if fsdp_world_size is None or fsdp_world_size <= 1:
+            return None
+        if fsdp_rank is None:
+            return None
+        start1, end1 = _split_range(candidate.shape[1], int(fsdp_rank), int(fsdp_world_size))
+        if end1 <= start1:
+            return None
+        candidate = candidate[:, start1:end1, ...]
+
+    if candidate.shape != multi_tensor.shape:
+        return None
+    return candidate
+
+
+def _run_single_gpu(rank, world_size, port, model_id, local_only):
+    """Single GPU baseline."""
+    os.environ.update({
+        'RANK': '0', 'WORLD_SIZE': '1', 'LOCAL_RANK': '0',
+        'LOCAL_WORLD_SIZE': '1', 'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': str(port)
+    })
+
+    if not torch.cuda.is_available():
+        raise RuntimeError('CUDA required')
+    device = torch.device('cuda:0')
+    torch.cuda.set_device(device)
+    torch.manual_seed(1234)
+
+    model = _load_model(model_id, local_only, device, num_layers=1)
+    model.train()
+    vocab_size = int(model.config.vocab_size)
+
+    batch_size, seq_len = 2, 1024
+    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
+    labels = input_ids.clone()
+    labels[:, 0] = -100
+
+    outputs = model(input_ids=input_ids, position_ids=position_ids, labels=labels, use_cache=False)
+    loss = outputs.loss
+    loss.backward()
+
+    # Collect grads and weights
+    grad_dict = {n: p.grad.detach().cpu() for n, p in model.named_parameters() if p.grad is not None}
+    opt = torch.optim.AdamW(model.parameters(), lr=1e-5, foreach=False)
+    opt.step()
+    new_weight_dict = {n: p.detach().cpu() for n, p in model.named_parameters()}
+
+    # Save the inputs so the multi-GPU run consumes identical data
+    torch.save({
+        'input_ids': input_ids.cpu(),
+        'position_ids': position_ids.cpu(),
+        'logits': outputs.logits.detach().cpu(),
+        'loss': loss.item(),
+        'grad_dict': grad_dict,
+        'new_weight_dict': new_weight_dict,
+    }, _single_snapshot_path(port))
+
+    print(f'[Single] Loss={loss.item():.4f}')
+
+
+def _run_multi_gpu(rank, world_size, port, model_id, local_only):
+    """4-GPU EP+FSDP with two independent meshes."""
+    os.environ.update({
+        'RANK': str(rank), 'WORLD_SIZE': str(world_size),
+        'LOCAL_RANK': str(rank), 'LOCAL_WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': str(port)
+    })
+
+    if not torch.cuda.is_available():
+        raise RuntimeError('CUDA required')
+    device = torch.device(f'cuda:{rank}')
+    torch.cuda.set_device(device)
+
+    dist.init_process_group(
+        backend='nccl', rank=rank, world_size=world_size,
+        init_method=f'tcp://127.0.0.1:{port}', device_id=device, timeout=timedelta(minutes=15)
+    )
+    dist.barrier()
+
+    try:
+        torch.manual_seed(1234)
+        torch.cuda.manual_seed_all(1234)
+
+        # New design: main mesh does NOT include 'ep' dimension
+        # Main mesh: fsdp=4 (all 4 GPUs for FSDP)
+        # ep_size=2 is stored as attribute, used to build separate ep_fsdp_device_mesh
+        
ep_size = 2
+        device_mesh = DeviceMesh(
+            device_type='cuda',
+            mesh=np.arange(world_size).reshape(world_size),  # 1D mesh: [0, 1, 2, 3]
+            mesh_dim_names=('fsdp',),
+            ep_size=ep_size,  # ep_size as attribute, not mesh dimension
+        )
+
+        model = _load_model(model_id, local_only, device, num_layers=1)
+        model.train()
+
+        # Load the inputs saved by the single-GPU run
+        single_data = torch.load(_single_snapshot_path(port), weights_only=True)
+        input_ids = single_data['input_ids'].to(device)
+        position_ids = single_data['position_ids'].to(device)
+        labels = input_ids.clone()
+        labels[:, 0] = -100
+
+        # Build explicit ep_fsdp_device_mesh
+        ep_fsdp_mesh = device_mesh.build_ep_fsdp_device_mesh(ep_size=ep_size)
+
+        # Apply EP with ep_fsdp_device_mesh
+        apply_expert_parallel(
+            getattr(model, 'model', model),
+            device_mesh,
+            config={'enabled': True, 'router_dtype': 'fp32', 'keep_router_logits': False},
+            ep_fsdp_device_mesh=ep_fsdp_mesh,
+        )
+
+        # FSDP2 wrap
+        fsdp = NativeFSDPStrategy(
+            device_mesh=device_mesh,
+            mixed_precision='no',
+            fsdp_config={},
+            enable_ep=True,
+            ep_fsdp_device_mesh=ep_fsdp_mesh,
+        )
+        model, _ = fsdp.wrap_model(model, optimizer=None)
+
+        outputs = model(input_ids=input_ids, position_ids=position_ids, labels=labels, use_cache=False)
+        loss = outputs.loss
+        loss.backward()
+
+        # Extract gradients, materializing DTensors via full_tensor()
+        grad_dict = {}
+        for n, p in model.named_parameters():
+            if p.grad is not None:
+                grad_dict[n] = _get_full_tensor(p.grad)
+
+        # Optimizer step
+        opt = torch.optim.AdamW(model.parameters(), lr=1e-5, foreach=False)
+        opt.step()
+
+        # Weights after the optimizer step
+        new_weight_dict = {}
+        for n, p in model.named_parameters():
+            new_weight_dict[n] = _get_full_tensor(p)
+
+        # Compare - every rank participates
+        single_grads = single_data['grad_dict']
+        single_weights = single_data['new_weight_dict']
+
+        # Get EP rank from ep_fsdp_mesh
+        ep_rank = ep_fsdp_mesh.get_local_rank('ep') if ep_fsdp_mesh is not None else 0
+        fsdp_rank = device_mesh.fsdp_rank or 0
+        fsdp_world_size = device_mesh.fsdp_world_size or 1
+
+        # Compute the forward comparison on rank 0 only, then broadcast the
+        # result so the other ranks do not hang at the barrier
+        forward_err = None
+        if rank == 0:
+            single_logits = single_data['logits']
+            multi_logits = _get_full_tensor(outputs.logits)
+            logits_abs_diff = (single_logits - multi_logits).abs()
+            logits_max_diff = logits_abs_diff.max().item()
+            logits_mean_diff = logits_abs_diff.mean().item()
+            print(f'\n=== Forward: Logits ===')
+            print(f' Max diff: {logits_max_diff:.2e}, Mean diff: {logits_mean_diff:.2e}')
+            if not torch.allclose(single_logits, multi_logits, rtol=REL_TOL, atol=ABS_TOL):
+                forward_err = f'Logits mismatch! Max diff: {logits_max_diff}, Mean diff: {logits_mean_diff}'
+
+            print(f'\n=== Forward: Loss ===')
+            print(f' Single: {single_data["loss"]:.6f}, Multi: {loss.item():.6f}')
+            loss_diff = abs(single_data['loss'] - loss.item())
+            single_loss_t = torch.tensor(single_data['loss'], dtype=torch.float32)
+            multi_loss_t = torch.tensor(loss.item(), dtype=torch.float32)
+            if (not torch.allclose(single_loss_t, multi_loss_t, rtol=REL_TOL, atol=LOSS_TOL)
+                    and forward_err is None):
+                forward_err = f'Loss mismatch! Diff: {loss_diff}'
+
+        obj = [forward_err]
+        dist.broadcast_object_list(obj, src=0)
+        if obj[0] is not None:
+            raise AssertionError(obj[0])
+
+        # Compare non-expert gradients
+        print(f'\n=== Rank {rank}: Gradients (non-expert) ===')
+        verified = 0
+        seen_mapped = set()
+        for n in grad_dict:
+            mapped_n = _clean_name(n)
+            if mapped_n in seen_mapped or 'experts.' 
in mapped_n:
+                continue
+            seen_mapped.add(mapped_n)
+            m_grad = grad_dict[n]
+            s_grad = _match_single_tensor_for_compare(
+                mapped_n,
+                m_grad,
+                single_grads,
+                ep_rank,
+                fsdp_rank,
+                fsdp_world_size,
+            )
+            if s_grad is None:
+                continue
+            grad_abs_diff = (s_grad - m_grad).abs()
+            grad_max_diff = grad_abs_diff.max().item()
+            grad_mean_diff = grad_abs_diff.mean().item()
+            is_close = torch.allclose(s_grad, m_grad, rtol=REL_TOL, atol=ABS_TOL)
+            status = 'PASS' if is_close else 'FAIL'
+            print(f' [{status}] {mapped_n}: max_diff={grad_max_diff:.2e}, mean_diff={grad_mean_diff:.2e}')
+            assert is_close, f"Grad mismatch for {mapped_n}! Max diff: {grad_max_diff}, Mean diff: {grad_mean_diff}"
+            verified += 1
+        assert verified > 0, f"Error: No non-expert gradients were verified on rank {rank}!"
+
+        # Compare expert gradients
+        print(f'\n=== Rank {rank}: Gradients (expert) ===')
+        verified = 0
+        ratio_list = []
+        seen_mapped = set()
+        for n in grad_dict:
+            mapped_n = _clean_name(n)
+            if mapped_n in seen_mapped or 'experts.' not in mapped_n:
+                continue
+            seen_mapped.add(mapped_n)
+            m_grad = grad_dict[n]
+            s_grad = _match_single_tensor_for_compare(
+                mapped_n,
+                m_grad,
+                single_grads,
+                ep_rank,
+                fsdp_rank,
+                fsdp_world_size,
+            )
+            if s_grad is None:
+                continue
+            grad_abs_diff = (s_grad - m_grad).abs()
+            grad_max_diff = grad_abs_diff.max().item()
+            grad_mean_diff = grad_abs_diff.mean().item()
+            s_norm = s_grad.float().norm().item()
+            m_norm = m_grad.float().norm().item()
+            ratio = m_norm / (s_norm + 1e-12)
+            ratio_list.append(ratio)
+            is_close = torch.allclose(s_grad, m_grad, rtol=REL_TOL, atol=ABS_TOL)
+            status = 'PASS' if is_close else 'FAIL'
+            print(
+                f' [{status}] {mapped_n}: max_diff={grad_max_diff:.2e}, mean_diff={grad_mean_diff:.2e}, '
+                f'norm_ratio(ep/single)={ratio:.4f}')
+            assert is_close, (
+                f'Expert grad mismatch for {mapped_n}! '
+                f'Max diff: {grad_max_diff}, Mean diff: {grad_mean_diff}, Ratio: {ratio}')
+            verified += 1
+        if verified == 0:
+            print(f' [INFO] No expert gradients matched on rank {rank} (EP distribution is expected)')
+        else:
+            ratio_t = torch.tensor(ratio_list, dtype=torch.float32)
+            print(
+                f' [INFO] expert grad norm ratio(ep/single): '
+                f'min={ratio_t.min().item():.4f}, max={ratio_t.max().item():.4f}, mean={ratio_t.mean().item():.4f}')
+
+        # Compare updated weights
+        print(f'\n=== Rank {rank}: Updated Weights ===')
+        verified = 0
+        seen_mapped = set()
+        for n in new_weight_dict:
+            mapped_n = _clean_name(n)
+            if mapped_n in seen_mapped:
+                continue
+            seen_mapped.add(mapped_n)
+            m_w = new_weight_dict[n]
+            s_w = _match_single_tensor_for_compare(
+                mapped_n,
+                m_w,
+                single_weights,
+                ep_rank,
+                fsdp_rank,
+                fsdp_world_size,
+            )
+            if s_w is None:
+                continue
+            weight_abs_diff = (s_w - m_w).abs()
+            weight_max_diff = weight_abs_diff.max().item()
+            weight_mean_diff = weight_abs_diff.mean().item()
+            is_close = torch.allclose(s_w, m_w, rtol=REL_TOL, atol=ABS_TOL)
+            status = 'PASS' if is_close else 'FAIL'
+            print(f' [{status}] {mapped_n}: max_diff={weight_max_diff:.2e}, mean_diff={weight_mean_diff:.2e}')
+            assert is_close, (
+                f'Weight mismatch for {mapped_n}! Max diff: {weight_max_diff}, Mean diff: {weight_mean_diff}')
+            verified += 1
+        assert verified > 0, f"Error: No weights were verified on rank {rank}!"
+ + dist.barrier() + except Exception as e: + print(f'Rank {rank} error: {e}') + raise + finally: + dist.destroy_process_group() + + +class TestEPFSDPvsSingle(unittest.TestCase): + + def test_alignment(self): + if not dist.is_available() or not torch.cuda.is_available(): + self.skipTest('Need distributed + CUDA') + if torch.cuda.device_count() < 4: + self.skipTest('Need 4 GPUs') + + model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-30B-A3B-Instruct-2507') + local_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0' + + try: + _load_config(model_id, local_only) + except Exception as e: + self.skipTest(f'Model not available: {e}') + + port = _find_free_port() + snapshot_path = _single_snapshot_path(port) + + try: + # Run single GPU baseline + mp.spawn(_run_single_gpu, args=(1, port, model_id, local_only), nprocs=1, join=True) + + # Run 4-GPU EP+FSDP + mp.spawn(_run_multi_gpu, args=(4, port, model_id, local_only), nprocs=4, join=True) + finally: + if os.path.exists(snapshot_path): + os.remove(snapshot_path) + + +if __name__ == '__main__': + unittest.main() From 4cedfc6065cd18d9a585e315c301d0bdea571198 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Sat, 28 Feb 2026 09:40:32 +0800 Subject: [PATCH 6/9] wip --- cookbook/transformers/ep_fsdp_qwen3_moe.py | 9 +- .../model/transformers/moe/expert_parallel.py | 137 +++++++--- .../transformers/strategy/native_fsdp.py | 25 +- tests/moe/verify_qwen3_moe_permute_cpu.py | 242 ++++++++++++++++++ 4 files changed, 367 insertions(+), 46 deletions(-) create mode 100644 tests/moe/verify_qwen3_moe_permute_cpu.py diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py index 85a943b2..88c17d51 100644 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.py +++ b/cookbook/transformers/ep_fsdp_qwen3_moe.py @@ -24,15 +24,16 @@ MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1' -# 4 gpus, fsdp=4 (data parallel), ep_size=2 (expert parallel) +# 4 gpus, dp=1, fsdp=4 (data parallel), ep_size=2 (expert parallel) # The main mesh does NOT include 'ep' dimension - EP is handled by separate ep_fsdp_device_mesh +dp_size = 1 fsdp_size = 4 ep_size = 2 device_mesh = DeviceMesh( device_type=Platform.get_platform().device_prefix(), - mesh=np.arange(fsdp_size).reshape(fsdp_size), - mesh_dim_names=('fsdp',), + mesh=np.arange(fsdp_size * dp_size).reshape(fsdp_size, dp_size), + mesh_dim_names=('fsdp', 'dp'), ep_size=ep_size, # ep_size is stored as attribute, not a mesh dimension ) @@ -49,7 +50,7 @@ def train(): if hasattr(config, 'use_cache'): config.use_cache = False - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) try: dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) except ValueError: diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 5798bcdc..27cd52b9 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -163,7 +163,15 @@ def patch_forward( """Replace the MoE block forward with EP-aware communication flow. 
Communication pattern follows VeOmni: - preprocess → token_pre_all2all → expert_compute (F.linear loop) → tokens_post_all2all + preprocess → token_pre_all2all → expert_compute → tokens_post_all2all + + For tensor experts (gate_up_proj/down_proj), the expert compute is delegated + to block.experts(...) via nn.Module.__call__ so that FSDP2 pre/post-forward + hooks fire correctly (automatic unshard before forward, backward hook + registration, and reshard after forward). No manual unshard/reshard is needed. + + For ModuleList experts, each sub-expert is already called via __call__ inside + _run_local_experts, so the same principle applies. Args: block: The MoE block to patch. @@ -185,12 +193,18 @@ def patch_forward( orig_forward = block.forward num_experts = block._ep_num_experts experts_per_rank = block._ep_experts_per_rank + is_tensor_experts = block._ep_tensor_experts + + # For tensor experts, install an ep_forward on the experts module so we can + # call block.experts(permuted_tokens, counts, experts_per_rank) via __call__, + # letting FSDP2 manage unshard/reshard automatically. + if is_tensor_experts: + _install_ep_forward(block.experts, experts_per_rank) def forward(hidden_states: torch.Tensor, *args, **kwargs): if args or kwargs: raise RuntimeError('Expert parallel patch only supports forward(hidden_states).') - input_dtype = hidden_states.dtype orig_shape = hidden_states.shape if hidden_states.ndim == 3: batch_size, seq_len, hidden_dim = hidden_states.shape @@ -202,14 +216,15 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): else: raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}') - router_logits, routing_weights, selected_experts, cast_weights = _run_router( + router_logits, routing_weights, selected_experts = _run_router( gate=gate, hidden_states=hidden_states_2d, top_k=top_k, router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype), norm_topk_prob=getattr(block, 'norm_topk_prob', False), ) - if cast_weights: + # Keep routing weights in activation dtype before unpermute weighting. + if routing_weights.dtype != hidden_states_2d.dtype: routing_weights = routing_weights.to(hidden_states_2d.dtype) # Build expert_mask: [num_experts, top_k, num_tokens] (VeOmni convention) @@ -241,21 +256,23 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): ep_group, ) - # 3. expert_compute: F.linear loop per local expert (no routing weight here) - # When FSDP2 wraps experts, params are sharded DTensors. Manually - # unshard (all-gather) so _run_local_experts sees full tensors. - _experts_mod = block.experts - _need_unshard = hasattr(_experts_mod, 'unshard') and hasattr(_experts_mod, 'reshard') - if _need_unshard: - _experts_mod.unshard() - expert_outputs = _run_local_experts( - block, - global_permuted_hidden_states, - num_global_sum_tokens_per_local_expert, - experts_per_rank, - ) - if _need_unshard: - _experts_mod.reshard() + # 3. expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire. + # For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank) + # → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard + # For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__. 
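+        # Rough mechanics (hook names are illustrative): FSDP2 attaches its
+        # hooks through nn.Module.__call__, so block.experts(x, ...) runs
+        #     pre-forward hook  -> all-gather sharded params
+        #     ep_forward(x, ...)
+        #     post-forward hook -> reshard params, register backward re-gather
+        # whereas calling block.experts.forward(x, ...) directly would skip them.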
+ if is_tensor_experts: + expert_outputs = block.experts( + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) + else: + expert_outputs = _run_local_experts( + block, + global_permuted_hidden_states, + num_global_sum_tokens_per_local_expert, + experts_per_rank, + ) # 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight) final_hidden = tokens_post_all2all( @@ -279,7 +296,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): if len(orig_shape) == 3: final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim) - if cfg.keep_router_logits and not getattr(block, '_ep_tensor_experts', False): + if cfg.keep_router_logits: return final_hidden, router_logits return final_hidden @@ -288,6 +305,56 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): block._ep_patched = True +def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None: + if getattr(experts_mod, '_ep_forward_installed', False): + return + + def ep_forward( + self, + permuted_tokens: torch.Tensor, + num_global_sum_tokens_per_local_expert: torch.Tensor, + experts_per_rank: int, + ) -> torch.Tensor: + if permuted_tokens.numel() == 0: + return torch.empty_like(permuted_tokens) + + input_dtype = permuted_tokens.dtype + + cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) + for i in range(experts_per_rank): + cumsum[i + 1] = cumsum[i] + int(num_global_sum_tokens_per_local_expert[i].item()) + + output_chunks = [] + for i in range(experts_per_rank): + start = int(cumsum[i].item()) + end = int(cumsum[i + 1].item()) + expert_in = permuted_tokens[start:end] + if expert_in.numel() == 0: + output_chunks.append(expert_in) + continue + + gate_up = self.gate_up_proj[i] + down = self.down_proj[i] + compute_dtype = gate_up.dtype + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1) + out = self.act_fn(gate) * up + out = F.linear(out, down) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + output_chunks.append(out) + + return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty( + 0, permuted_tokens.size(-1) + ) + + import types + experts_mod.forward = types.MethodType(ep_forward, experts_mod) + experts_mod._ep_forward_installed = True + + def _get_gate(block: nn.Module): gate = getattr(block, 'gate', None) if gate is None: @@ -364,8 +431,7 @@ def _run_local_experts( num_global_sum_tokens_per_local_expert: torch.Tensor, experts_per_rank: int, ) -> torch.Tensor: - """Run local experts on permuted tokens using F.linear loop. - + """Run ModuleList experts on permuted tokens via nn.Module.__call__. Tokens are already grouped by expert (contiguous chunks), sizes given by num_global_sum_tokens_per_local_expert. No routing weight is applied here; that happens in unpermute. 
@@ -374,7 +440,6 @@ def _run_local_experts( return torch.empty_like(permuted_tokens) input_dtype = permuted_tokens.dtype - is_tensor_experts = getattr(block, '_ep_tensor_experts', False) experts = block.experts cumsum = torch.zeros(experts_per_rank + 1, dtype=torch.long) @@ -390,21 +455,11 @@ def _run_local_experts( output_chunks.append(expert_in) continue - if is_tensor_experts: - gate_up = experts.gate_up_proj[i] - down = experts.down_proj[i] - compute_dtype = gate_up.dtype - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1) - out = experts.act_fn(gate) * up - out = F.linear(out, down) - else: - expert = experts[i] - compute_dtype = _module_compute_dtype(expert, input_dtype) - if expert_in.dtype != compute_dtype: - expert_in = expert_in.to(compute_dtype) - out = expert(expert_in) + expert = experts[i] + compute_dtype = _module_compute_dtype(expert, input_dtype) + if expert_in.dtype != compute_dtype: + expert_in = expert_in.to(compute_dtype) + out = expert(expert_in) if out.dtype != input_dtype: out = out.to(input_dtype) @@ -438,15 +493,15 @@ def _run_router( top_k: int, router_dtype: torch.dtype, norm_topk_prob: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gate_out = gate(hidden_states) if isinstance(gate_out, tuple) and len(gate_out) >= 3: router_logits, routing_weights, selected_experts = gate_out[:3] - return router_logits, routing_weights, selected_experts, False + return router_logits, routing_weights, selected_experts router_logits = gate_out routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype) routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) if norm_topk_prob: routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - return router_logits, routing_weights, selected_experts, True + return router_logits, routing_weights, selected_experts diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 845e428e..7a700c46 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -60,11 +60,15 @@ def wrap_model(self, model, optimizer=None): if experts_mod is not None and ep_fsdp_mesh_1d is not None: from torch.distributed.tensor import Shard + # PreMulSum (used by set_gradient_divide_factor) only supports + # float16/float32/float64; override reduce_dtype to float32 + # when the base policy uses bfloat16. + ep_mp_policy = _build_ep_mp_policy(mp_policy) fully_shard( experts_mod, mesh=ep_fsdp_mesh_1d, reshard_after_forward=reshard_after_forward, - mp_policy=mp_policy, + mp_policy=ep_mp_policy, shard_placement_fn=lambda param: Shard(1), ) # gradient_divide_factor = world_size (VeOmni convention) @@ -126,6 +130,25 @@ def _build_mp_policy(mixed_precision: str) -> 'MixedPrecisionPolicy': ) +def _build_ep_mp_policy(base_policy: 'MixedPrecisionPolicy') -> 'MixedPrecisionPolicy': + """Build a MixedPrecisionPolicy for EP experts with reduce_dtype=float32. + + NCCL's PreMulSum (used by set_gradient_divide_factor) only supports + float16/float32/float64. When the base policy uses bfloat16 as reduce_dtype, + we must override it to float32 for the expert FSDP group. 
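+
+    For example (assuming a bf16 base policy): param_dtype=bf16 /
+    reduce_dtype=bf16 becomes param_dtype=bf16 / reduce_dtype=fp32 here,
+    while fp16 and fp32 policies pass through unchanged.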
+ """ + from torch.distributed.fsdp import MixedPrecisionPolicy + reduce_dtype = base_policy.reduce_dtype + if reduce_dtype == torch.bfloat16: + reduce_dtype = torch.float32 + return MixedPrecisionPolicy( + param_dtype=base_policy.param_dtype, + reduce_dtype=reduce_dtype, + output_dtype=base_policy.output_dtype, + cast_forward_inputs=base_policy.cast_forward_inputs, + ) + + def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]: if device_mesh is None or device_mesh.mesh_dim_names is None: return None diff --git a/tests/moe/verify_qwen3_moe_permute_cpu.py b/tests/moe/verify_qwen3_moe_permute_cpu.py new file mode 100644 index 00000000..8f1ea248 --- /dev/null +++ b/tests/moe/verify_qwen3_moe_permute_cpu.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +CPU script to verify numerical alignment between: +1) transformers original Qwen3MoeSparseMoeBlock forward +2) twinkle local path: permute -> expert compute -> unpermute + +This script does NOT use distributed init and does NOT include all_to_all communication. + +Usage: + python tests/moe/verify_qwen3_moe_permute_cpu.py + python tests/moe/verify_qwen3_moe_permute_cpu.py --transformers-root /mnt/d/workspace/transformers +""" + +from __future__ import annotations +from twinkle.model.transformers.moe.expert_parallel import _run_router +from twinkle.model.transformers.moe.ep_utils import generate_weights_idx, permute, unpermute + +import argparse +import copy +import sys +from pathlib import Path +from types import SimpleNamespace + +import torch +import torch.nn.functional as F + +# Allow running directly from repository root without installing twinkle. +REPO_ROOT = Path(__file__).resolve().parents[2] +SRC_DIR = REPO_ROOT / 'src' +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + + +def _import_qwen3_block(transformers_root: str): + try: + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + + return Qwen3MoeSparseMoeBlock + except Exception: + src_dir = Path(transformers_root).expanduser().resolve() / 'src' + if not src_dir.exists(): + raise RuntimeError( + f'Cannot import transformers qwen3_moe, and fallback path does not exist: {src_dir}' + ) + sys.path.insert(0, str(src_dir)) + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock + + return Qwen3MoeSparseMoeBlock + + +def _build_config( + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + top_k: int, + hidden_act: str, + norm_topk_prob: bool, +): + return SimpleNamespace( + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + num_experts=num_experts, + num_experts_per_tok=top_k, + hidden_act=hidden_act, + norm_topk_prob=norm_topk_prob, + _experts_implementation='eager', + ) + + +def _run_local_permute_expert_unpermute(block: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + """Reproduce Twinkle EP compute path locally without communication.""" + batch_size, seq_len, hidden_dim = hidden_states.shape + hidden_states_2d = hidden_states.view(-1, hidden_dim) + + _, routing_weights, selected_experts = _run_router( + gate=block.gate, + hidden_states=hidden_states_2d, + top_k=block.gate.top_k, + router_dtype=torch.float32, + norm_topk_prob=bool(getattr(block.gate, 'norm_topk_prob', False)), + ) + + num_experts = int(block.experts.num_experts) + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0) + routing_map = 
expert_mask.sum(dim=1) + + permuted_tokens, permutation_mapping = permute(hidden_states_2d, routing_map) + + num_tokens_per_expert = routing_map.sum(dim=1).to(dtype=torch.long) + cumsum = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=num_tokens_per_expert.device), + num_tokens_per_expert.cumsum(dim=0), + ] + ) + + outputs = [] + experts = block.experts + input_dtype = permuted_tokens.dtype + + for expert_idx in range(num_experts): + start = int(cumsum[expert_idx].item()) + end = int(cumsum[expert_idx + 1].item()) + x = permuted_tokens[start:end] + if x.numel() == 0: + outputs.append(x) + continue + + gate_up = experts.gate_up_proj[expert_idx] + down = experts.down_proj[expert_idx] + compute_dtype = gate_up.dtype + + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + gate, up = F.linear(x, gate_up).chunk(2, dim=-1) + out = experts.act_fn(gate) * up + out = F.linear(out, down) + + if out.dtype != input_dtype: + out = out.to(input_dtype) + + outputs.append(out) + + expert_outputs = ( + torch.cat(outputs, dim=0) + if outputs + else permuted_tokens.new_empty((0, permuted_tokens.size(-1))) + ) + + weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts) + final_hidden_states_2d = unpermute( + expert_outputs, + weights_idx, + hidden_states_2d.shape, + permutation_mapping, + routing_map, + ) + + return final_hidden_states_2d.view(batch_size, seq_len, hidden_dim) + + +def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: + return (a - b).abs().max().item() + + +@torch.no_grad() +def _init_block_weights(block: torch.nn.Module, seed: int, std: float = 0.02) -> None: + # Standalone Qwen3MoeSparseMoeBlock does not go through PreTrainedModel.post_init(), + # so initialize weights explicitly to avoid trivial all-zero / allocator-dependent cases. 
+ generator = torch.Generator(device='cpu') + generator.manual_seed(seed) + + experts = block.experts + gate = block.gate + experts.gate_up_proj.copy_(torch.randn(experts.gate_up_proj.shape, generator=generator) * std) + experts.down_proj.copy_(torch.randn(experts.down_proj.shape, generator=generator) * std) + gate.weight.copy_(torch.randn(gate.weight.shape, generator=generator) * std) + + +def main(): + parser = argparse.ArgumentParser(description='CPU precision check for Qwen3-MoE sparse block.') + parser.add_argument('--transformers-root', type=str, default='/mnt/d/workspace/transformers') + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--batch-size', type=int, default=2) + parser.add_argument('--seq-len', type=int, default=1024) + parser.add_argument('--hidden-size', type=int, default=64) + parser.add_argument('--moe-intermediate-size', type=int, default=32) + parser.add_argument('--num-experts', type=int, default=8) + parser.add_argument('--top-k', type=int, default=2) + parser.add_argument('--hidden-act', type=str, default='silu') + parser.add_argument('--norm-topk-prob', action='store_true') + parser.add_argument('--atol', type=float, default=1e-6) + parser.add_argument('--rtol', type=float, default=1e-6) + args = parser.parse_args() + + torch.set_num_threads(1) + torch.manual_seed(args.seed) + + Qwen3MoeSparseMoeBlock = _import_qwen3_block(args.transformers_root) + cfg = _build_config( + hidden_size=args.hidden_size, + moe_intermediate_size=args.moe_intermediate_size, + num_experts=args.num_experts, + top_k=args.top_k, + hidden_act=args.hidden_act, + norm_topk_prob=args.norm_topk_prob, + ) + + ref_block = Qwen3MoeSparseMoeBlock(cfg).cpu().float().train() + _init_block_weights(ref_block, seed=args.seed) + test_block = copy.deepcopy(ref_block).cpu().float().train() + + hidden_ref = torch.randn(args.batch_size, args.seq_len, args.hidden_size, dtype=torch.float32, requires_grad=True) + hidden_test = hidden_ref.detach().clone().requires_grad_(True) + + ref_out = ref_block(hidden_ref) + test_out = _run_local_permute_expert_unpermute(test_block, hidden_test) + + # Use identical loss form for backward alignment. 
+ proj = torch.randn(args.hidden_size, dtype=torch.float32) + ref_loss = (ref_out * proj).sum() + test_loss = (test_out * proj).sum() + + ref_loss.backward() + test_loss.backward() + + out_max_diff = _max_abs_diff(ref_out.detach(), test_out.detach()) + in_grad_max_diff = _max_abs_diff(hidden_ref.grad.detach(), hidden_test.grad.detach()) + + print('\n=== Qwen3-MoE Sparse Block CPU Alignment ===') + print(f'seed={args.seed} shape=({args.batch_size}, {args.seq_len}, {args.hidden_size})') + print(f'num_experts={args.num_experts} top_k={args.top_k} hidden_act={args.hidden_act}') + print(f'forward max_abs_diff: {out_max_diff:.8e}') + print(f'input grad max_abs_diff: {in_grad_max_diff:.8e}') + + param_ok = True + for (name_ref, p_ref), (name_test, p_test) in zip(ref_block.named_parameters(), test_block.named_parameters()): + if name_ref != name_test: + raise RuntimeError(f'Parameter name mismatch: {name_ref} vs {name_test}') + if p_ref.grad is None or p_test.grad is None: + raise RuntimeError(f'Missing grad for parameter: {name_ref}') + diff = _max_abs_diff(p_ref.grad.detach(), p_test.grad.detach()) + print(f'grad[{name_ref}] max_abs_diff: {diff:.8e}') + if not torch.allclose(p_ref.grad, p_test.grad, rtol=args.rtol, atol=args.atol): + param_ok = False + + out_ok = torch.allclose(ref_out, test_out, rtol=args.rtol, atol=args.atol) + in_grad_ok = torch.allclose(hidden_ref.grad, hidden_test.grad, rtol=args.rtol, atol=args.atol) + + print('\n=== Result ===') + print(f'forward aligned: {out_ok}') + print(f'input grad aligned: {in_grad_ok}') + print(f'param grad aligned: {param_ok}') + + if not (out_ok and in_grad_ok and param_ok): + raise SystemExit(1) + + +if __name__ == '__main__': + main() From 7c0c4bd8c382fcbef302d16ddb41fda863da4ea8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Sat, 28 Feb 2026 10:11:10 +0800 Subject: [PATCH 7/9] wip --- src/twinkle/utils/grad_clip.py | 119 ++++++----- tests/moe/verify_qwen3_moe_permute_cpu.py | 242 ---------------------- 2 files changed, 66 insertions(+), 295 deletions(-) delete mode 100644 tests/moe/verify_qwen3_moe_permute_cpu.py diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index 11ded854..efe7e9ef 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -150,7 +150,7 @@ def _ep_aware_clip_grad_norm( - non-EP params: all-reduce over fsdp_group - EP params: all-reduce over ep_fsdp_group, then ep_group - - Unified clip coefficient applied to both groups + - Unified clip coefficient applied to both groups via clip_grads_with_norm_ """ import math import torch @@ -161,71 +161,84 @@ def _ep_aware_clip_grad_norm( norm_type = float(norm_type) - # non-EP: reduce over fsdp_group - non_ep_val = _local_norm_stat(non_ep_params, norm_type) - if fsdp_group is not None: - op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM - dist.all_reduce(non_ep_val, op=op, group=fsdp_group) - - # EP: reduce over ep_fsdp_group, then ep_group - ep_val = _local_norm_stat(ep_params, norm_type) - if ep_fsdp_group is not None: - op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM - dist.all_reduce(ep_val, op=op, group=ep_fsdp_group) - if ep_group is not None: - op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM - dist.all_reduce(ep_val, op=op, group=ep_group) - - # Combine - if math.isinf(norm_type): - total_norm = torch.maximum(non_ep_val, ep_val) - else: - total_norm = (non_ep_val + ep_val) ** (1.0 / norm_type) + with torch.no_grad(): + # non-EP: reduce over fsdp_group + 
non_ep_val = _local_norm_stat(non_ep_params, norm_type) + if fsdp_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(non_ep_val, op=op, group=fsdp_group) + + # EP: reduce over ep_fsdp_group, then ep_group + ep_val = _local_norm_stat(ep_params, norm_type) + if ep_fsdp_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(ep_val, op=op, group=ep_fsdp_group) + if ep_group is not None: + op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM + dist.all_reduce(ep_val, op=op, group=ep_group) + + # Combine into total_norm tensor + if math.isinf(norm_type): + total_norm = torch.maximum(non_ep_val, ep_val) + else: + total_norm = (non_ep_val + ep_val) ** (1.0 / norm_type) - # Clip both groups with the same coefficient - clip_coef = float(max_grad_norm) / (float(total_norm.item()) + 1e-6) - if clip_coef < 1.0: - all_params = ep_params + non_ep_params - for p in all_params: - if p.grad is not None: - p.grad.mul_(clip_coef) + # Clip both groups with the same coefficient via PyTorch builtin (foreach-accelerated) + torch.nn.utils.clip_grads_with_norm_(ep_params, max_grad_norm, total_norm) + torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_grad_norm, total_norm) return float(total_norm.item()) def _local_norm_stat(params, norm_type: float): - """Compute local norm statistic: sum of p-th powers (finite p) or max (inf).""" + """Compute local norm statistic: sum of p-th powers (finite p) or max (inf). + + Uses torch._foreach_* batch kernels for finite p to reduce kernel launch overhead. + """ import math import torch from torch.distributed._tensor import DTensor - - device = None + from torch.utils._foreach_utils import ( + _device_has_foreach_support, + _group_tensors_by_device_and_dtype, + _has_foreach_support, + ) + + grads_local = [] + default_device = None for p in params: - if p.grad is not None: - g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad - if g.is_cuda or getattr(g, 'is_npu', False): - device = g.device - break - if device is None: - device = torch.device(Platform.get_local_device()) + if p.grad is None: + continue + g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + if default_device is None and (g.is_cuda or getattr(g, 'is_npu', False)): + default_device = g.device + grads_local.append(g.detach().to(torch.float32)) + + if default_device is None: + default_device = torch.device(Platform.get_local_device()) if math.isinf(norm_type): - val = torch.tensor(0.0, device=device, dtype=torch.float32) - for p in params: - if p.grad is None: - continue - g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad + val = torch.tensor(0.0, device=default_device, dtype=torch.float32) + for g in grads_local: if g.numel() == 0: continue - val = torch.maximum(val, g.detach().to(torch.float32).abs().max()) + val = torch.maximum(val, g.abs().max()) return val - else: - val = torch.tensor(0.0, device=device, dtype=torch.float32) - for p in params: - if p.grad is None: - continue - g = p.grad.to_local() if isinstance(p.grad, DTensor) else p.grad - if g.numel() == 0: - continue - val += g.detach().to(torch.float32).pow(norm_type).sum() + + p = float(norm_type) + val = torch.tensor(0.0, device=default_device, dtype=torch.float32) + if not grads_local: return val + non_empty = [g for g in grads_local if g.numel() > 0] + if not non_empty: + return val + grouped = _group_tensors_by_device_and_dtype([non_empty]) + for (device, _), 
([device_grads], _) in grouped.items():
+        if _has_foreach_support(device_grads, device) or _device_has_foreach_support(device):
+            # Batch: compute ||g||_p per grad, raise to p-th power, then sum.
+            # Use the out-of-place _foreach_pow: the in-place variant returns None.
+            out = torch._foreach_pow(torch._foreach_norm(device_grads, p), p)
+            val += torch.sum(torch.stack(out)).to(default_device)
+        else:
+            for g in device_grads:
+                val += (torch.norm(g, p=p) ** p).to(default_device)
+    return val
diff --git a/tests/moe/verify_qwen3_moe_permute_cpu.py b/tests/moe/verify_qwen3_moe_permute_cpu.py
deleted file mode 100644
index 8f1ea248..00000000
--- a/tests/moe/verify_qwen3_moe_permute_cpu.py
+++ /dev/null
@@ -1,242 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""
-CPU script to verify numerical alignment between:
-1) transformers original Qwen3MoeSparseMoeBlock forward
-2) twinkle local path: permute -> expert compute -> unpermute
-
-This script does NOT use distributed init and does NOT include all_to_all communication.
-
-Usage:
-    python tests/moe/verify_qwen3_moe_permute_cpu.py
-    python tests/moe/verify_qwen3_moe_permute_cpu.py --transformers-root /mnt/d/workspace/transformers
-"""
-
-from __future__ import annotations
-from twinkle.model.transformers.moe.expert_parallel import _run_router
-from twinkle.model.transformers.moe.ep_utils import generate_weights_idx, permute, unpermute
-
-import argparse
-import copy
-import sys
-from pathlib import Path
-from types import SimpleNamespace
-
-import torch
-import torch.nn.functional as F
-
-# Allow running directly from repository root without installing twinkle.
-REPO_ROOT = Path(__file__).resolve().parents[2]
-SRC_DIR = REPO_ROOT / 'src'
-if str(SRC_DIR) not in sys.path:
-    sys.path.insert(0, str(SRC_DIR))
-
-
-def _import_qwen3_block(transformers_root: str):
-    try:
-        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-        return Qwen3MoeSparseMoeBlock
-    except Exception:
-        src_dir = Path(transformers_root).expanduser().resolve() / 'src'
-        if not src_dir.exists():
-            raise RuntimeError(
-                f'Cannot import transformers qwen3_moe, and fallback path does not exist: {src_dir}'
-            )
-        sys.path.insert(0, str(src_dir))
-        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-        return Qwen3MoeSparseMoeBlock
-
-
-def _build_config(
-    hidden_size: int,
-    moe_intermediate_size: int,
-    num_experts: int,
-    top_k: int,
-    hidden_act: str,
-    norm_topk_prob: bool,
-):
-    return SimpleNamespace(
-        hidden_size=hidden_size,
-        moe_intermediate_size=moe_intermediate_size,
-        num_experts=num_experts,
-        num_experts_per_tok=top_k,
-        hidden_act=hidden_act,
-        norm_topk_prob=norm_topk_prob,
-        _experts_implementation='eager',
-    )
-
-
-def _run_local_permute_expert_unpermute(block: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
-    """Reproduce Twinkle EP compute path locally without communication."""
-    batch_size, seq_len, hidden_dim = hidden_states.shape
-    hidden_states_2d = hidden_states.view(-1, hidden_dim)
-
-    _, routing_weights, selected_experts = _run_router(
-        gate=block.gate,
-        hidden_states=hidden_states_2d,
-        top_k=block.gate.top_k,
-        router_dtype=torch.float32,
-        norm_topk_prob=bool(getattr(block.gate, 'norm_topk_prob', False)),
-    )
-
-    num_experts = int(block.experts.num_experts)
-    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
-    routing_map = expert_mask.sum(dim=1)
-
-    permuted_tokens, permutation_mapping = permute(hidden_states_2d, routing_map)
-
-    
num_tokens_per_expert = routing_map.sum(dim=1).to(dtype=torch.long) - cumsum = torch.cat( - [ - torch.zeros(1, dtype=torch.long, device=num_tokens_per_expert.device), - num_tokens_per_expert.cumsum(dim=0), - ] - ) - - outputs = [] - experts = block.experts - input_dtype = permuted_tokens.dtype - - for expert_idx in range(num_experts): - start = int(cumsum[expert_idx].item()) - end = int(cumsum[expert_idx + 1].item()) - x = permuted_tokens[start:end] - if x.numel() == 0: - outputs.append(x) - continue - - gate_up = experts.gate_up_proj[expert_idx] - down = experts.down_proj[expert_idx] - compute_dtype = gate_up.dtype - - if x.dtype != compute_dtype: - x = x.to(compute_dtype) - - gate, up = F.linear(x, gate_up).chunk(2, dim=-1) - out = experts.act_fn(gate) * up - out = F.linear(out, down) - - if out.dtype != input_dtype: - out = out.to(input_dtype) - - outputs.append(out) - - expert_outputs = ( - torch.cat(outputs, dim=0) - if outputs - else permuted_tokens.new_empty((0, permuted_tokens.size(-1))) - ) - - weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts) - final_hidden_states_2d = unpermute( - expert_outputs, - weights_idx, - hidden_states_2d.shape, - permutation_mapping, - routing_map, - ) - - return final_hidden_states_2d.view(batch_size, seq_len, hidden_dim) - - -def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: - return (a - b).abs().max().item() - - -@torch.no_grad() -def _init_block_weights(block: torch.nn.Module, seed: int, std: float = 0.02) -> None: - # Standalone Qwen3MoeSparseMoeBlock does not go through PreTrainedModel.post_init(), - # so initialize weights explicitly to avoid trivial all-zero / allocator-dependent cases. - generator = torch.Generator(device='cpu') - generator.manual_seed(seed) - - experts = block.experts - gate = block.gate - experts.gate_up_proj.copy_(torch.randn(experts.gate_up_proj.shape, generator=generator) * std) - experts.down_proj.copy_(torch.randn(experts.down_proj.shape, generator=generator) * std) - gate.weight.copy_(torch.randn(gate.weight.shape, generator=generator) * std) - - -def main(): - parser = argparse.ArgumentParser(description='CPU precision check for Qwen3-MoE sparse block.') - parser.add_argument('--transformers-root', type=str, default='/mnt/d/workspace/transformers') - parser.add_argument('--seed', type=int, default=42) - parser.add_argument('--batch-size', type=int, default=2) - parser.add_argument('--seq-len', type=int, default=1024) - parser.add_argument('--hidden-size', type=int, default=64) - parser.add_argument('--moe-intermediate-size', type=int, default=32) - parser.add_argument('--num-experts', type=int, default=8) - parser.add_argument('--top-k', type=int, default=2) - parser.add_argument('--hidden-act', type=str, default='silu') - parser.add_argument('--norm-topk-prob', action='store_true') - parser.add_argument('--atol', type=float, default=1e-6) - parser.add_argument('--rtol', type=float, default=1e-6) - args = parser.parse_args() - - torch.set_num_threads(1) - torch.manual_seed(args.seed) - - Qwen3MoeSparseMoeBlock = _import_qwen3_block(args.transformers_root) - cfg = _build_config( - hidden_size=args.hidden_size, - moe_intermediate_size=args.moe_intermediate_size, - num_experts=args.num_experts, - top_k=args.top_k, - hidden_act=args.hidden_act, - norm_topk_prob=args.norm_topk_prob, - ) - - ref_block = Qwen3MoeSparseMoeBlock(cfg).cpu().float().train() - _init_block_weights(ref_block, seed=args.seed) - test_block = copy.deepcopy(ref_block).cpu().float().train() - - hidden_ref = 
torch.randn(args.batch_size, args.seq_len, args.hidden_size, dtype=torch.float32, requires_grad=True) - hidden_test = hidden_ref.detach().clone().requires_grad_(True) - - ref_out = ref_block(hidden_ref) - test_out = _run_local_permute_expert_unpermute(test_block, hidden_test) - - # Use identical loss form for backward alignment. - proj = torch.randn(args.hidden_size, dtype=torch.float32) - ref_loss = (ref_out * proj).sum() - test_loss = (test_out * proj).sum() - - ref_loss.backward() - test_loss.backward() - - out_max_diff = _max_abs_diff(ref_out.detach(), test_out.detach()) - in_grad_max_diff = _max_abs_diff(hidden_ref.grad.detach(), hidden_test.grad.detach()) - - print('\n=== Qwen3-MoE Sparse Block CPU Alignment ===') - print(f'seed={args.seed} shape=({args.batch_size}, {args.seq_len}, {args.hidden_size})') - print(f'num_experts={args.num_experts} top_k={args.top_k} hidden_act={args.hidden_act}') - print(f'forward max_abs_diff: {out_max_diff:.8e}') - print(f'input grad max_abs_diff: {in_grad_max_diff:.8e}') - - param_ok = True - for (name_ref, p_ref), (name_test, p_test) in zip(ref_block.named_parameters(), test_block.named_parameters()): - if name_ref != name_test: - raise RuntimeError(f'Parameter name mismatch: {name_ref} vs {name_test}') - if p_ref.grad is None or p_test.grad is None: - raise RuntimeError(f'Missing grad for parameter: {name_ref}') - diff = _max_abs_diff(p_ref.grad.detach(), p_test.grad.detach()) - print(f'grad[{name_ref}] max_abs_diff: {diff:.8e}') - if not torch.allclose(p_ref.grad, p_test.grad, rtol=args.rtol, atol=args.atol): - param_ok = False - - out_ok = torch.allclose(ref_out, test_out, rtol=args.rtol, atol=args.atol) - in_grad_ok = torch.allclose(hidden_ref.grad, hidden_test.grad, rtol=args.rtol, atol=args.atol) - - print('\n=== Result ===') - print(f'forward aligned: {out_ok}') - print(f'input grad aligned: {in_grad_ok}') - print(f'param grad aligned: {param_ok}') - - if not (out_ok and in_grad_ok and param_ok): - raise SystemExit(1) - - -if __name__ == '__main__': - main() From 8a37a88fbccffb7978acab56768b34debc9c9958 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Sat, 28 Feb 2026 10:13:57 +0800 Subject: [PATCH 8/9] wip --- src/twinkle/model/transformers/moe/ep_utils.py | 10 +--------- src/twinkle/model/transformers/moe/expert_parallel.py | 4 ++-- src/twinkle/model/transformers/strategy/native_fsdp.py | 8 ++++---- src/twinkle/utils/grad_clip.py | 4 ++-- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py index 7446c6a6..d2c5d65b 100644 --- a/src/twinkle/model/transformers/moe/ep_utils.py +++ b/src/twinkle/model/transformers/moe/ep_utils.py @@ -1,6 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # -# Adapted from VeOmni (https://github.com/volcengine/VeOmni) # Copyright 2025 Bytedance Ltd. and/or its affiliates # Licensed under the Apache License, Version 2.0 @@ -11,9 +10,6 @@ # ========================== comm ========================== -# Ported from veomni/distributed/moe/comm.py - - class _AllToAll(torch.autograd.Function): @staticmethod def forward(ctx, group, input, output_split_sizes, input_split_sizes): @@ -98,9 +94,6 @@ def all_to_all_async(group, input, output_split_size, input_split_size): # ========================== moe_utils ========================== -# Ported from veomni/distributed/moe/moe_utils.py - - def permute(tokens: torch.Tensor, routing_map: torch.Tensor): """ Permutes the tokens according to the routing map. 
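
The permute/unpermute contract these hunks rely on can be summarized with a small reference sketch. It assumes a dense 0/1 routing_map of shape [num_experts, num_tokens]; permute_sketch and unpermute_sketch are simplified stand-ins (the real unpermute also takes the weights_idx and routing_map bookkeeping), not the ported kernels:

import torch

def permute_sketch(tokens, routing_map):
    # nonzero() walks routing_map row by row, so the gathered token copies
    # come out grouped by expert, each expert's tokens in ascending order.
    token_idx = routing_map.bool().nonzero(as_tuple=False)[:, 1]
    return tokens.index_select(0, token_idx), token_idx

def unpermute_sketch(expert_out, token_idx, probs, out_shape):
    # probs holds the routing weight of each permuted copy (same order as
    # expert_out); index_add_ sums the weighted top-k copies per token.
    out = expert_out.new_zeros(out_shape)
    out.index_add_(0, token_idx, expert_out * probs.unsqueeze(-1))
    return out

With top_k copies per token, permute_sketch returns top_k * num_tokens rows, and unpermute_sketch reduces them back to num_tokens rows.
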
@@ -185,8 +178,7 @@ def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_i # ========================== moe_layer ========================== -# Ported from veomni/distributed/moe/moe_layer.py (preprocess, token_pre_all2all, tokens_post_all2all) -# EPGroupGemm is NOT ported (requires triton group_gemm); Twinkle uses F.linear loop instead. +# EPGroupGemm is not included here; Twinkle uses an F.linear loop instead. def preprocess( diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 27cd52b9..6a9514b1 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -162,7 +162,7 @@ def patch_forward( ) -> None: """Replace the MoE block forward with EP-aware communication flow. - Communication pattern follows VeOmni: + Communication pattern: preprocess → token_pre_all2all → expert_compute → tokens_post_all2all For tensor experts (gate_up_proj/down_proj), the expert compute is delegated @@ -227,7 +227,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): if routing_weights.dtype != hidden_states_2d.dtype: routing_weights = routing_weights.to(hidden_states_2d.dtype) - # Build expert_mask: [num_experts, top_k, num_tokens] (VeOmni convention) + # Build expert_mask: [num_experts, top_k, num_tokens] expert_mask = torch.nn.functional.one_hot( selected_experts, num_classes=num_experts ).permute(2, 1, 0) # [num_experts, top_k, num_tokens] diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 7a700c46..29a6ec3f 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -51,7 +51,7 @@ def wrap_model(self, model, optimizer=None): experts_mod = _find_experts_in_layer(layer_mod, experts_map) layer_pairs.append((layer_mod, experts_mod)) - # FSDP2 wrapping per layer (mirrors VeOmni parallelize_model_fsdp2) + # FSDP2 wrapping per layer world_size = self.device_mesh.world_size ep_fsdp_mesh_1d = self.ep_fsdp_device_mesh['ep_fsdp'] if ep_enabled else None @@ -71,7 +71,7 @@ def wrap_model(self, model, optimizer=None): mp_policy=ep_mp_policy, shard_placement_fn=lambda param: Shard(1), ) - # gradient_divide_factor = world_size (VeOmni convention) + # gradient_divide_factor = world_size experts_mod.set_gradient_divide_factor(world_size) layer_mod._fsdp_modules.append(experts_mod) @@ -93,7 +93,7 @@ def wrap_model(self, model, optimizer=None): ignored_params=expert_params, ) - # Manual prefetch (mirrors VeOmni lines 396-411) + # Manual prefetch if ep_enabled and layer_pairs: _setup_manual_prefetch([lp[0] for lp in layer_pairs]) @@ -214,7 +214,7 @@ def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Modul def _setup_manual_prefetch(blocks: list) -> None: - """Configure forward/backward prefetch for FSDP modules (mirrors VeOmni).""" + """Configure forward/backward prefetch for FSDP modules.""" for i, block in enumerate(blocks): if i + 1 < len(blocks): next_fsdp_modules = getattr(blocks[i + 1], '_fsdp_modules', []) diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py index efe7e9ef..8dced18c 100644 --- a/src/twinkle/utils/grad_clip.py +++ b/src/twinkle/utils/grad_clip.py @@ -41,7 +41,7 @@ def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter], if not grads: return 0.0 - # EP-aware path (mirrors VeOmni ep_fsdp2_clip_grad_norm) + # 
EP-aware path if ep_param_groups is not None: return _ep_aware_clip_grad_norm( ep_param_groups=ep_param_groups, @@ -146,7 +146,7 @@ def _ep_aware_clip_grad_norm( ep_group=None, ep_fsdp_group=None, ) -> float: - """EP-aware gradient clipping (mirrors VeOmni ep_fsdp2_clip_grad_norm). + """EP-aware gradient clipping. - non-EP params: all-reduce over fsdp_group - EP params: all-reduce over ep_fsdp_group, then ep_group From ad51ef65dcf97945825e2683256360a81f25f5e0 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 4 Mar 2026 14:09:39 +0800 Subject: [PATCH 9/9] lint --- .../model/transformers/moe/ep_utils.py | 11 ++- .../model/transformers/moe/expert_parallel.py | 20 ++--- .../transformers/strategy/native_fsdp.py | 8 +- .../model/transformers/transformers.py | 3 +- src/twinkle/utils/device_mesh.py | 9 +- src/twinkle/utils/grad_clip.py | 11 +-- tests/moe/test_ep_fsdp_vs_single.py | 90 ++++++++++--------- 7 files changed, 74 insertions(+), 78 deletions(-) diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py index d2c5d65b..5d740afe 100644 --- a/src/twinkle/model/transformers/moe/ep_utils.py +++ b/src/twinkle/model/transformers/moe/ep_utils.py @@ -3,14 +3,14 @@ # Copyright 2025 Bytedance Ltd. and/or its affiliates # Licensed under the Apache License, Version 2.0 -from typing import Optional - import torch import torch.distributed as dist +from typing import Optional # ========================== comm ========================== class _AllToAll(torch.autograd.Function): + @staticmethod def forward(ctx, group, input, output_split_sizes, input_split_sizes): ctx.group = group @@ -48,6 +48,7 @@ def backward(ctx, *grad_output): class _AllToAll_Async(torch.autograd.Function): + @staticmethod def forward(ctx, group, input, output_split_sizes, input_split_sizes): ctx.group = group @@ -213,12 +214,10 @@ def preprocess( # [num_local_expert] num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=0).to( - torch.device("cpu"), non_blocking=True - ) + torch.device('cpu'), non_blocking=True) num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(-1, num_local_experts).to( - torch.device("cpu"), non_blocking=True - ) + torch.device('cpu'), non_blocking=True) return input_splits, output_splits, num_global_tokens_per_local_expert, num_global_sum_tokens_per_local_expert diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 6a9514b1..c9ee0fa0 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -8,12 +8,8 @@ from torch import nn from typing import Any, Dict, Iterable, List, Optional, Tuple +from twinkle.model.transformers.moe.ep_utils import preprocess, token_pre_all2all, tokens_post_all2all from twinkle.utils import DeviceMesh -from twinkle.model.transformers.moe.ep_utils import ( - preprocess, - token_pre_all2all, - tokens_post_all2all, -) @dataclass @@ -22,7 +18,7 @@ class ExpertParallelConfig: router_dtype: str = 'fp32' keep_router_logits: bool = True ignore_shared_experts: bool = False - ep_size: Optional[int] = None # consumed by TransformersModel, not used in expert_parallel logic + ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic @dataclass @@ -43,8 +39,8 @@ def apply_expert_parallel( model: nn.Module, device_mesh: DeviceMesh, config: dict[str, Any] | None = None, - ep_fsdp_device_mesh: 
Optional['torch.distributed.DeviceMesh'] = None, -) -> List[ExpertShardingSpec]: + ep_fsdp_device_mesh: torch.distributed.DeviceMesh | None = None, +) -> list[ExpertShardingSpec]: """Apply expert parallelism to all MoE blocks in the model.""" cfg = _merge_config(config) @@ -229,8 +225,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs): # Build expert_mask: [num_experts, top_k, num_tokens] expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=num_experts - ).permute(2, 1, 0) # [num_experts, top_k, num_tokens] + selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens] # 1. preprocess: compute splits and token counts ( @@ -346,9 +341,8 @@ def ep_forward( out = out.to(input_dtype) output_chunks.append(out) - return torch.cat(output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty( - 0, permuted_tokens.size(-1) - ) + return torch.cat( + output_chunks, dim=0) if output_chunks else permuted_tokens.new_empty(0, permuted_tokens.size(-1)) import types experts_mod.forward = types.MethodType(ep_forward, experts_mod) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 29a6ec3f..3ef38030 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -17,7 +17,8 @@ def __init__(self, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', fsdp_config: Dict[str, Any] = None, - enable_ep: bool = True, ep_fsdp_device_mesh: Optional[TorchDeviceMesh] = None): + enable_ep: bool = True, + ep_fsdp_device_mesh: Optional[TorchDeviceMesh] = None): self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.fsdp_config = fsdp_config or {} @@ -60,6 +61,7 @@ def wrap_model(self, model, optimizer=None): if experts_mod is not None and ep_fsdp_mesh_1d is not None: from torch.distributed.tensor import Shard + # PreMulSum (used by set_gradient_divide_factor) only supports # float16/float32/float64; override reduce_dtype to float32 # when the base policy uses bfloat16. @@ -254,9 +256,7 @@ def _ensure_moe_patched_if_needed(model: nn.Module, ep_fsdp_device_mesh: Optiona for module in model.modules(): experts = getattr(module, 'experts', None) is_moe_experts = ( - isinstance(experts, nn.ModuleList) - or (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')) - ) + isinstance(experts, nn.ModuleList) or (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'))) if is_moe_experts and not getattr(module, '_ep_patched', False): raise RuntimeError('Found MoE experts but expert parallel is not applied. ' 'Call apply_expert_parallel(model, device_mesh, config) before wrapping with FSDP2.') diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index fadcf7be..3d72a374 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -219,8 +219,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): self.device_mesh) self._expert_parallel_applied = False # Store ep_size for later use (EP mesh construction, grad clip, etc.) 
-        self._ep_size = (self._expert_parallel_config.get('ep_size')
-                         if self._expert_parallel_config else None)
+        self._ep_size = (self._expert_parallel_config.get('ep_size') if self._expert_parallel_config else None)
         if self._ep_size is None and self.device_mesh is not None:
             self._ep_size = getattr(self.device_mesh, 'ep_size', None)
         if self._ep_size is None:
diff --git a/src/twinkle/utils/device_mesh.py b/src/twinkle/utils/device_mesh.py
index 7af1dd42..dedd3012 100644
--- a/src/twinkle/utils/device_mesh.py
+++ b/src/twinkle/utils/device_mesh.py
@@ -209,13 +209,12 @@ def build_ep_fsdp_device_mesh(self, ep_size: int = None):
         if ep_size <= 1:
             return None
         world_size = self.world_size
-        assert world_size % ep_size == 0, (
-            f'world_size ({world_size}) must be divisible by ep_size ({ep_size})')
+        assert world_size % ep_size == 0, (f'world_size ({world_size}) must be divisible by ep_size ({ep_size})')
         ep_fsdp_size = world_size // ep_size
         with torch.device('cpu'):
-            mesh = (torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int)
-                    .view(ep_fsdp_size, ep_size)
-                    .transpose(0, 1))
+            mesh = (
+                torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int).view(ep_fsdp_size,
+                                                                                       ep_size).transpose(0, 1))
         return torch.distributed.DeviceMesh(self.device_type, mesh, mesh_dim_names=('ep', 'ep_fsdp'))

     @property
diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py
index 8dced18c..9a3fb685 100644
--- a/src/twinkle/utils/grad_clip.py
+++ b/src/twinkle/utils/grad_clip.py
@@ -181,7 +181,7 @@ def _ep_aware_clip_grad_norm(
         if math.isinf(norm_type):
             total_norm = torch.maximum(non_ep_val, ep_val)
         else:
-            total_norm = (non_ep_val + ep_val) ** (1.0 / norm_type)
+            total_norm = (non_ep_val + ep_val)**(1.0 / norm_type)

         # Clip both groups with the same coefficient via PyTorch builtin (foreach-accelerated)
         torch.nn.utils.clip_grads_with_norm_(ep_params, max_grad_norm, total_norm)
@@ -198,11 +198,8 @@ def _local_norm_stat(params, norm_type: float):
     import math
     import torch
     from torch.distributed._tensor import DTensor
-    from torch.utils._foreach_utils import (
-        _device_has_foreach_support,
-        _group_tensors_by_device_and_dtype,
-        _has_foreach_support,
-    )
+    from torch.utils._foreach_utils import (_device_has_foreach_support, _group_tensors_by_device_and_dtype,
+                                            _has_foreach_support)

     grads_local = []
     default_device = None
@@ -240,5 +237,5 @@ def _local_norm_stat(params, norm_type: float):
             val += torch.sum(torch.stack(out)).to(default_device)
         else:
             for g in device_grads:
-                val += (torch.norm(g, p=p) ** p).to(default_device)
+                val += (torch.norm(g, p=p)**p).to(default_device)
     return val
diff --git a/tests/moe/test_ep_fsdp_vs_single.py b/tests/moe/test_ep_fsdp_vs_single.py
index f610ce30..8dec9c09 100644
--- a/tests/moe/test_ep_fsdp_vs_single.py
+++ b/tests/moe/test_ep_fsdp_vs_single.py
@@ -20,7 +20,6 @@
 from twinkle.model.transformers.strategy import NativeFSDPStrategy
 from twinkle.utils import DeviceMesh

-
 ABS_TOL = 5e-3
 LOSS_TOL = 1e-4
 REL_TOL = 1e-4
@@ -50,16 +49,18 @@ def _load_model(model_id: str, local_only: bool, device: torch.device, num_layer
     config._experts_implementation = 'eager'

     # Disable dropout to ensure determinism
-    dropout_attrs = ['attention_dropout', 'hidden_dropout', 'classifier_dropout',
-                     'resid_pdrop', 'embd_pdrop']
+    dropout_attrs = ['attention_dropout', 'hidden_dropout', 'classifier_dropout', 'resid_pdrop', 'embd_pdrop']
     for attr in dropout_attrs:
         if hasattr(config, attr):
             setattr(config, attr, 0.0)

     model = AutoModelForCausalLM.from_pretrained(
-        model_id, config=config, torch_dtype=torch.float32,
-        low_cpu_mem_usage=True, trust_remote_code=True, local_files_only=local_only
-    )
+        model_id,
+        config=config,
+        torch_dtype=torch.float32,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        local_files_only=local_only)
     model.to(device)
     return model
@@ -94,12 +95,8 @@ def _split_range(total: int, rank: int, world_size: int) -> tuple[int, int]:
     return start, end


-def _match_single_tensor_for_compare(mapped_name: str,
-                                     multi_tensor: torch.Tensor,
-                                     single_dict: dict,
-                                     ep_rank: int,
-                                     fsdp_rank: int,
-                                     fsdp_world_size: int):
+def _match_single_tensor_for_compare(mapped_name: str, multi_tensor: torch.Tensor, single_dict: dict, ep_rank: int,
+                                     fsdp_rank: int, fsdp_world_size: int):
     """Return the single-GPU tensor aligned with multi_tensor; expert params support dim0 (ep) + dim1 (fsdp) slicing."""
     single_tensor = single_dict.get(mapped_name)
     if single_tensor is None:
@@ -149,8 +146,12 @@ def _match_single_tensor_for_compare(mapped_name: str,
 def _run_single_gpu(rank, world_size, port, model_id, local_only):
     """Single GPU baseline."""
     os.environ.update({
-        'RANK': '0', 'WORLD_SIZE': '1', 'LOCAL_RANK': '0',
-        'LOCAL_WORLD_SIZE': '1', 'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': str(port)
+        'RANK': '0',
+        'WORLD_SIZE': '1',
+        'LOCAL_RANK': '0',
+        'LOCAL_WORLD_SIZE': '1',
+        'MASTER_ADDR': '127.0.0.1',
+        'MASTER_PORT': str(port)
     })

     if not torch.cuda.is_available():
@@ -180,14 +181,15 @@ def _run_single_gpu(rank, world_size, port, model_id, local_only):
     new_weight_dict = {n: p.detach().cpu() for n, p in model.named_parameters()}

     # Save the inputs so the multi-GPU run uses identical data
-    torch.save({
-        'input_ids': input_ids.cpu(),
-        'position_ids': position_ids.cpu(),
-        'logits': outputs.logits.detach().cpu(),
-        'loss': loss.item(),
-        'grad_dict': grad_dict,
-        'new_weight_dict': new_weight_dict,
-    }, _single_snapshot_path(port))
+    torch.save(
+        {
+            'input_ids': input_ids.cpu(),
+            'position_ids': position_ids.cpu(),
+            'logits': outputs.logits.detach().cpu(),
+            'loss': loss.item(),
+            'grad_dict': grad_dict,
+            'new_weight_dict': new_weight_dict,
+        }, _single_snapshot_path(port))

     print(f'[Single] Loss={loss.item():.4f}')
@@ -195,9 +197,12 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
     """4-GPU EP+FSDP with two independent meshes."""
     os.environ.update({
-        'RANK': str(rank), 'WORLD_SIZE': str(world_size),
-        'LOCAL_RANK': str(rank), 'LOCAL_WORLD_SIZE': str(world_size),
-        'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': str(port)
+        'RANK': str(rank),
+        'WORLD_SIZE': str(world_size),
+        'LOCAL_RANK': str(rank),
+        'LOCAL_WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': '127.0.0.1',
+        'MASTER_PORT': str(port)
     })

     if not torch.cuda.is_available():
@@ -206,9 +211,12 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
     torch.cuda.set_device(device)

     dist.init_process_group(
-        backend='nccl', rank=rank, world_size=world_size,
-        init_method=f'tcp://127.0.0.1:{port}', device_id=device, timeout=timedelta(minutes=15)
-    )
+        backend='nccl',
+        rank=rank,
+        world_size=world_size,
+        init_method=f'tcp://127.0.0.1:{port}',
+        device_id=device,
+        timeout=timedelta(minutes=15))
     dist.barrier()

     try:
@@ -222,7 +230,7 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
         device_mesh = DeviceMesh(
             device_type='cuda',
             mesh=np.arange(world_size).reshape(world_size),  # 1D mesh: [0, 1, 2, 3]
-            mesh_dim_names=('fsdp',),
+            mesh_dim_names=('fsdp', ),
             ep_size=ep_size,  # ep_size as attribute, not mesh dimension
         )
@@ -243,7 +251,11 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
         apply_expert_parallel(
             getattr(model, 'model', model),
             device_mesh,
-            config={'enabled': True, 'router_dtype': 'fp32', 'keep_router_logits': False},
+            config={
+                'enabled': True,
+                'router_dtype': 'fp32',
+                'keep_router_logits': False
+            },
             ep_fsdp_device_mesh=ep_fsdp_mesh,
         )
@@ -303,8 +315,7 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
         loss_diff = abs(single_data['loss'] - loss.item())
         single_loss_t = torch.tensor(single_data['loss'], dtype=torch.float32)
         multi_loss_t = torch.tensor(loss.item(), dtype=torch.float32)
-        if (not torch.allclose(single_loss_t, multi_loss_t, rtol=REL_TOL, atol=LOSS_TOL)
-                and forward_err is None):
+        if (not torch.allclose(single_loss_t, multi_loss_t, rtol=REL_TOL, atol=LOSS_TOL) and forward_err is None):
             forward_err = f'Loss mismatch! Diff: {loss_diff}'

         obj = [forward_err]
@@ -372,20 +383,17 @@ def _run_multi_gpu(rank, world_size, port, model_id, local_only):
                 ratio_list.append(ratio)
                 is_close = torch.allclose(s_grad, m_grad, rtol=REL_TOL, atol=ABS_TOL)
                 status = 'PASS' if is_close else 'FAIL'
-                print(
-                    f'  [{status}] {mapped_n}: max_diff={grad_max_diff:.2e}, mean_diff={grad_mean_diff:.2e}, '
-                    f'norm_ratio(ep/single)={ratio:.4f}')
-                assert is_close, (
-                    f'Expert grad mismatch for {mapped_n}! '
-                    f'Max diff: {grad_max_diff}, Mean diff: {grad_mean_diff}, Ratio: {ratio}')
+                print(f'  [{status}] {mapped_n}: max_diff={grad_max_diff:.2e}, mean_diff={grad_mean_diff:.2e}, '
+                      f'norm_ratio(ep/single)={ratio:.4f}')
+                assert is_close, (f'Expert grad mismatch for {mapped_n}! '
+                                  f'Max diff: {grad_max_diff}, Mean diff: {grad_mean_diff}, Ratio: {ratio}')
                 verified += 1

         if verified == 0:
             print(f'  [INFO] No expert gradients matched on rank {rank} (EP distribution is expected)')
         else:
             ratio_t = torch.tensor(ratio_list, dtype=torch.float32)
-            print(
-                f'  [INFO] expert grad norm ratio(ep/single): '
-                f'min={ratio_t.min().item():.4f}, max={ratio_t.max().item():.4f}, mean={ratio_t.mean().item():.4f}')
+            print(f'  [INFO] expert grad norm ratio(ep/single): '
+                  f'min={ratio_t.min().item():.4f}, max={ratio_t.max().item():.4f}, mean={ratio_t.mean().item():.4f}')

         # Compare the updated weights
         print(f'\n=== Rank {rank}: Updated Weights ===')
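
A final note on the grad-clip path this test exercises: the two-group reduction in _ep_aware_clip_grad_norm is exact for finite p because sums of |g|^p are additive across disjoint parameter groups. A quick standalone check of that combination rule (tensor shapes are arbitrary):

import torch

torch.manual_seed(0)
g_non_ep = [torch.randn(4, 3), torch.randn(5)]
g_ep = [torch.randn(2, 2)]
p = 2.0

# Per-group sums of |g|^p, as _local_norm_stat accumulates them ...
sum_non_ep = sum(g.abs().pow(p).sum() for g in g_non_ep)
sum_ep = sum(g.abs().pow(p).sum() for g in g_ep)
# ... combined exactly like the finite-p branch of _ep_aware_clip_grad_norm.
combined = (sum_non_ep + sum_ep)**(1.0 / p)

reference = torch.cat([g.flatten() for g in g_non_ep + g_ep]).norm(p)
assert torch.allclose(combined, reference)
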