Commits
64 commits
162217c
Generalized Tensor Parallelism (GTP) init commit
fanshiqing May 25, 2026
1cab66a
fix conflicts
fanshiqing May 25, 2026
ccfee04
code clean
fanshiqing May 25, 2026
8b4041a
fix comments; defer wgrad->dgrad support in following up MR;
fanshiqing May 25, 2026
1903598
fix comments
fanshiqing May 26, 2026
9eb5007
code clean.
fanshiqing May 27, 2026
019536f
Fix GTP broadcast_params + add partial DP-CP with GTP group
fanshiqing May 27, 2026
375f09c
fix none-egpt sharded params's reduction in moe layer; fix comments
fanshiqing May 27, 2026
6872981
rename 'generalized_tensor_parallel_size' into 'generalized_tensor_pa…
fanshiqing May 28, 2026
6c28127
update README
fanshiqing May 28, 2026
02f76a7
fix comments
fanshiqing May 28, 2026
476aa05
Fix EGTP correctness on cudagraph bwd capture + main_param dedup
fanshiqing May 28, 2026
cdf5d35
fix comments
fanshiqing May 29, 2026
e885461
Merge remote-tracking branch 'adlr_github/main' into gtp_release
fanshiqing Jun 1, 2026
3c84291
Only batch with _foreach_add_ when finalizing multiple (routed) weight
fanshiqing Jun 1, 2026
392816a
gtp+gmm-fusion: support offloading(moe-act-input)
fanshiqing Jun 2, 2026
fc570d0
GTP + full-iter CG
fanshiqing Jun 3, 2026
23ed3ba
[feat]GTP: prefetch recompute-forward weight gathers via a separate c…
fanshiqing Jun 4, 2026
ecf2dd1
GTP: allocate GRAPHED buffers into CG mempool at creation; fix comments
fanshiqing Jun 4, 2026
1163b4a
fix comments
fanshiqing Jun 4, 2026
5dc0423
fix onlince checks: copyright, intallation test, build.
fanshiqing Jun 4, 2026
ab869b9
fix te min version required for GTP.
fanshiqing Jun 5, 2026
45a604f
fix online UTs; fix comments.
fanshiqing Jun 5, 2026
87a50cd
fix UTs
fanshiqing Jun 5, 2026
107c077
fix UTs
fanshiqing Jun 6, 2026
9dca05a
Fix GTP DDP bucket alignment for distributed optimizer; add correspon…
fanshiqing Jun 6, 2026
0906db0
fix formating
fanshiqing Jun 6, 2026
6cdfe5d
fix regular ddp buffer bucket misalignment when GTP params are present
fanshiqing Jun 6, 2026
4d1e2eb
add integration test for {mamba,attn,moe}+gtp; polish existing gtp an…
fanshiqing Jun 8, 2026
4695217
fix comments from Jimmy and Deepak
fanshiqing Jun 10, 2026
3725f51
feat: make (E)GTP a first-class orthogonal parallelism axis
fanshiqing Jun 11, 2026
84e6a7d
code clean
fanshiqing Jun 11, 2026
42fd06f
Merge remote-tracking branch 'adlr_github/main' into gtp_release
fanshiqing Jun 12, 2026
b8b078a
code clean
fanshiqing Jun 12, 2026
83bea9f
Generate the DDP param layerout for the GTP replicate group at it's s…
fanshiqing Jun 12, 2026
22fc6e6
fold the GTP intra DP groups into intra_dp_cp_group and intra_expt_dp…
fanshiqing Jun 12, 2026
0a55ed5
fix UTs
fanshiqing Jun 12, 2026
be22dce
[feat] GTP+DCP
fanshiqing Jun 13, 2026
70ef35d
rename gtp-exclude process group: with_gtp -> no_gtp
fanshiqing Jun 13, 2026
e430807
fix dense GTP NCCL group using stale 'ps' key
fanshiqing Jun 15, 2026
5a8c469
update README with scalability
fanshiqing Jun 17, 2026
aa40d0d
fix comments
fanshiqing Jun 18, 2026
601a658
fix comments
fanshiqing Jun 19, 2026
eddb7ba
Rename GTP remat knobs and add num-weight-shards user API
fanshiqing Jun 19, 2026
6806f43
Support GTP/EGTP in LayerWiseDistributedOptimizer and Muon (#3)
deepakn94 Jun 19, 2026
7d71c08
GTP+DCP: simplify gtp replica_ids in MambaMixer.sharded_state_dict; a…
fanshiqing Jun 19, 2026
69dae5a
GTP+Muon: fix DCP save/load; add corresponding UTs
fanshiqing Jun 19, 2026
f13a042
code clean and fix comments
fanshiqing Jun 20, 2026
0e3c3d2
Fix GTP DDP grad-ready firing before deferred wgrad accumulation
fanshiqing Jun 23, 2026
f6dca05
fix format and comments
fanshiqing Jun 24, 2026
c33667a
Merge remote-tracking branch 'adlr_github/main' into gtp_release
fanshiqing Jun 25, 2026
ce02728
fix comments
fanshiqing Jun 25, 2026
ae8a571
fix linting
fanshiqing Jun 25, 2026
71a53c6
Merge remote-tracking branch 'adlr_github/main' into gtp_release
fanshiqing Jun 25, 2026
6aeecc1
Fix optional process-group fallbacks defeated by __getattr__; Log hum…
fanshiqing Jun 25, 2026
114b6fc
Fix GTP grad norm inflated on CUDA-graph capture step; fix linting
fanshiqing Jun 25, 2026
3c7aa6c
fix online UTs
fanshiqing Jun 25, 2026
1b066a5
Simplify GTP grad-norm fix: drop unnecessary bwd-graph backup
fanshiqing Jun 26, 2026
7d7e8c3
Move GTP from megatron.experimental into megatron.core
fanshiqing Jun 26, 2026
14464b5
GTP+CG: code clean: replace GTP bwd Phase-2 completion event with a r…
fanshiqing Jun 26, 2026
083c15f
add gtp public API file
fanshiqing Jun 26, 2026
8aa2b6d
GTP: clean up generalized_tensor_parallelism after the core move
fanshiqing Jun 26, 2026
dc629cc
fix1: populate EGTP-excluded expert-DP groups in get_default_pg_colle…
fanshiqing Jun 26, 2026
00c9d20
Fix: defer global TP/DP group reads in _backfill_gtp_sharded_param_map
fanshiqing Jun 26, 2026
419 changes: 419 additions & 0 deletions docs/api-guide/core/generalized_tensor_parallel.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/api-guide/core/index.md
Expand Up @@ -16,6 +16,7 @@ Low-level API reference for core Megatron components.

transformer
tensor_parallel
generalized_tensor_parallel
pipeline_parallel
fusions
distributed
Expand Down
65 changes: 58 additions & 7 deletions megatron/core/distributed/distributed_data_parallel.py
Expand Up @@ -81,9 +81,36 @@ def __init__(
# Assign all required process groups
self.dp_group = process_group_dict['dp_group']
self.dp_cp_group = process_group_dict['dp_cp_group']
- self.intra_dp_cp_group = process_group_dict['intra_dp_cp_group']
self.expt_dp_group = process_group_dict['expt_dp_group']
- self.intra_expt_dp_group = process_group_dict['intra_expt_dp_group']
# Example process-group sizes (e.g., TP=2, GTP=64, world_size=1024 with PP=CP=EP=1 and
# single DistOpt instance):
# model_size = TP x PP x CP x GTP = 2 x 64 = 128 -> DP = 1024 / 128 = 8.
# The model weights are replicated DP (= 8) times.
# dp_cp_group (degree of batch sharding, includes GTP) = GTP x DP = 64 * 8 = 512.
# dp_cp_no_gtp_group (degree of weight replication, excludes GTP) = 8.
# gtp_group = 64.
# tp_group = 2.
#
# Data-parallel gradient reductions for each bucket are performed over dp_cp_no_gtp_group
# (GTP-excluded group). Data-parallel gradient reductions over the GTP group are completed
# separately in the model backward pass.
#
# See Section 3.2 in `docs/api-guide/core/generalized_tensor_parallel.md`
# for more details (including why average_in_collective=False).
#
# When GTP is disabled, the *_no_gtp groups alias the regular DP groups.
self.intra_dp_cp_group = process_group_dict.get(
'intra_dp_cp_no_gtp_group', process_group_dict['intra_dp_cp_group']
)
self.intra_expt_dp_group = process_group_dict.get(
'intra_expt_dp_no_egtp_group', process_group_dict['intra_expt_dp_group']
)
# Full cross-instance, GTP-peer-EXCLUDED groups for broadcast_params (init-time weight
# sync must reach all true replicas). Fall back to the full DP groups when GTP is off.
self.dp_cp_no_gtp_group = process_group_dict.get('dp_cp_no_gtp_group', self.dp_cp_group)
self.expt_dp_no_egtp_group = process_group_dict.get(
'expt_dp_no_egtp_group', self.expt_dp_group
)
self.tp_group = process_group_dict['tp_group']
self.pp_group = process_group_dict['pp_group']
self.ep_group = process_group_dict['ep_group']
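The example sizes quoted in the comment in this hunk can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: `gtp_group_sizes` is a hypothetical helper, not a Megatron-Core API, and the way CP enters the two data-parallel sizes is an assumption, since the source comment only covers PP=CP=EP=1.

```python
def gtp_group_sizes(world_size, tp, gtp, pp=1, cp=1):
    """Hypothetical helper: derive process-group sizes from parallelism degrees."""
    model_size = tp * pp * cp * gtp
    if world_size % model_size != 0:
        raise ValueError("world_size must be divisible by TP x PP x CP x GTP")
    dp = world_size // model_size  # number of full weight replicas
    return {
        "dp": dp,
        # Degree of batch sharding, includes GTP (CP factor is an assumption).
        "dp_cp_group": gtp * dp * cp,
        # Degree of weight replication, excludes GTP (CP factor is an assumption).
        "dp_cp_no_gtp_group": dp * cp,
        "gtp_group": gtp,
        "tp_group": tp,
    }


# The configuration from the comment: TP=2, GTP=64, world_size=1024, PP=CP=1.
sizes = gtp_group_sizes(world_size=1024, tp=2, gtp=64)
print(sizes)
```

With these inputs the helper yields DP = 8, a `dp_cp_group` of 512, and a `dp_cp_no_gtp_group` of 8, matching the numbers in the comment.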
Expand Down Expand Up @@ -166,6 +193,15 @@ def __init__(

self.full_param_layout = full_param_layout

# GTP needs average_in_collective=False: the per-bucket collective runs over the
# GTP-EXCLUDED group, so NCCL AVG would miss the 1/gtp factor. arguments.py guards the
# training path; this assert covers direct megatron-core users.
gtp_active = ProcessGroupCollection.is_gtp_active(process_group_dict)
assert not (gtp_active and self.ddp_config.average_in_collective), (
"GTP requires average_in_collective=False (the default); averaged collectives reduce "
"over the GTP-excluded group and would miss the 1/gtp gradient scaling factor."
)

# Compute gradient scaling factors.
if config.calculate_per_token_loss:
assert (
Expand Down Expand Up @@ -364,6 +400,15 @@ def unmap_weight_tensor(m):
self._make_backward_post_hook(param)
)
break
elif getattr(param, 'is_gtp', False) and hasattr(param, 'register_grad_accum_hook'):
# GTP: drive the post-hook from GTP's manual invocation, not autograd's
# AccumulateGrad. GTP issues the wgrad RS async and defers the main_grad add
# to a later backward node, so AccumulateGrad can fire register_grad_ready
# before the wgrad lands in main_grad, dispatching the bucket reduce-scatter on
# stale grad_data (corrupts reduce_scatter_with_fp32_accumulation for
# chain-boundary weights). GTP fires this hook from _handle_megatron_grad_accum
# after the add instead.
param.register_grad_accum_hook(None, self._make_backward_post_hook(param))
else:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
Expand Down Expand Up @@ -460,9 +505,12 @@ def hook(*unused):
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
- assert (
- param.grad is not None
- ), 'param.grad being None is not safe when overlap_grad_reduce is True'
+ # GTP params legitimately have grad=None (async RS writes wgrad straight
+ # into main_grad), so skip the assertion for them.
+ if not getattr(param, 'is_gtp', False):
+ assert (
+ param.grad is not None
+ ), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
Expand Down Expand Up @@ -585,11 +633,14 @@ def broadcast_params(self):
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
is_gtp = getattr(param, 'is_gtp', False)

# Each (E)GTP peer holds a distinct 1/N shard, so broadcast over the (E)GTP-EXCLUDED
# group — else rank-0's shard would clobber the others.
if is_expert_parallel:
- data_parallel_group = self.expt_dp_group
+ data_parallel_group = self.expt_dp_no_egtp_group if is_gtp else self.expt_dp_group
else:
- data_parallel_group = self.dp_cp_group
+ data_parallel_group = self.dp_cp_no_gtp_group if is_gtp else self.dp_cp_group
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
Expand Down
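The `broadcast_params` change turns on which ranks count as true replicas of a GTP-sharded parameter. The following pure-Python simulation is a minimal sketch of that reasoning, not the actual implementation: the rank layout (`rank = replica * GTP + gtp_rank`) and the helper `simulate_broadcast` are assumptions made for illustration, and no `torch.distributed` calls are involved.

```python
def simulate_broadcast(values, groups):
    """Hypothetical model of a broadcast: the first rank of each group is the source."""
    out = dict(values)
    for group in groups:
        src = group[0]
        for rank in group:
            out[rank] = values[src]
    return out


GTP, DP = 2, 2  # two shards per replica, two replicas (illustrative sizes)
ranks = list(range(GTP * DP))

# Before init-time sync, every rank holds its own shard; replicas may disagree.
values = {r: f"shard{r % GTP}-from-replica{r // GTP}" for r in ranks}

# Broadcasting over the full group treats all ranks as replicas of one tensor.
full_group = [ranks]
# Broadcasting over GTP-excluded groups only links ranks with the same shard index.
no_gtp_groups = [[g + rep * GTP for rep in range(DP)] for g in range(GTP)]

clobbered = simulate_broadcast(values, full_group)
preserved = simulate_broadcast(values, no_gtp_groups)
print(clobbered)
print(preserved)
```

In the full-group case every rank ends up holding shard 0, which is the clobbering the code comment warns about. With the GTP-excluded groups, each shard index is synchronized separately across replicas.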
51 changes: 51 additions & 0 deletions megatron/core/distributed/finalize_model_grads.py
Expand Up @@ -446,6 +446,49 @@ def _allreduce_non_tensor_model_parallel_grads(
_allreduce_layernorm_grads = _allreduce_non_tensor_model_parallel_grads


def _allreduce_replicated_grads_over_gtp_group(model: List[torch.nn.Module]):
"""Sum wgrads for replicated parameters over the gtp / egtp group.

The data-parallel collective already reduced wgrads over the GTP-excluded process groups with
1/full scaling, so the gtp-axis terms are still missing. A plain SUM (not AVG) over the gtp/egtp
group adds them and yields the exact full mean. No-op when GTP is inactive (group size <= 1).
"""
gtp_group = parallel_state.get_gtp_weight_remat_group(check_initialized=False)
egtp_group = parallel_state.get_expert_gtp_weight_remat_group(check_initialized=False)

dense_params, dense_grads = [], []
expert_params, expert_grads = [], []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if not param.requires_grad or getattr(param, 'is_gtp', False):
continue # GTP-sharded params: their gtp axis is handled by the RS-mean.
grad_attr = _get_main_grad_attr(param)
grad = getattr(param, grad_attr, None)
if grad is None:
continue
grad = _unshard_if_dtensor(grad)
if getattr(param, 'allreduce', True):
dense_params.append(param)
dense_grads.append(grad.data)
else:
expert_params.append(param)
expert_grads.append(grad.data)

for params, grads, group in (
(dense_params, dense_grads, gtp_group),
(expert_params, expert_grads, egtp_group),
):
if not grads or group is None or group.size() <= 1:
continue
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, op=torch.distributed.ReduceOp.SUM, group=group)
for param, buf, synced in zip(params, grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
grad_attr = _get_main_grad_attr(param)
orig_grad = getattr(param, grad_attr)
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))


def finalize_model_grads(
model: List[torch.nn.Module],
num_tokens: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -495,6 +538,12 @@ def finalize_model_grads(
pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False)
dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)

# Fence the current stream against all GTP backward grad work before the DP gradient sync.
if config.gtp_weight_remat_size > 1 or config.expert_gtp_weight_remat_size > 1:
from megatron.core.tensor_parallel.gtp import wait_for_gtp_grad_reduction_on_current_stream

wait_for_gtp_grad_reduction_on_current_stream()

# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
Expand All @@ -521,6 +570,8 @@ def finalize_model_grads(
barrier=config.barrier_with_L1_time
)
_allreduce_non_tensor_model_parallel_grads(model, config, tp_group)
# Complete the gtp-axis reduction for replicated (non-GTP) params (no-op when GTP inactive).
_allreduce_replicated_grads_over_gtp_group(model)
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce').stop()

Expand Down
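The docstring of `_allreduce_replicated_grads_over_gtp_group` states that a plain SUM over the GTP group, applied after a data-parallel reduction with 1/full scaling over the GTP-excluded group, yields the exact full mean. A small numeric check of that claim is sketched below. It uses plain Python lists in place of collectives, and the sizes and gradient values are made up for illustration.

```python
import math

GTP, DP = 4, 2
full = GTP * DP  # size of the full batch-sharding group

# grads[d][g]: local gradient of one replicated parameter on replica d, GTP rank g.
grads = [[float(1 + g + GTP * d) for g in range(GTP)] for d in range(DP)]
true_mean = sum(sum(row) for row in grads) / full

# Step 1: per GTP rank, SUM over the GTP-excluded group with 1/full pre-scaling.
partial = [sum(grads[d][g] / full for d in range(DP)) for g in range(GTP)]
# Step 2: plain SUM over the GTP group completes the reduction.
two_stage = sum(partial)

# Counter-example: an averaging collective over the GTP-excluded group divides by
# DP only, so the result after the GTP-axis SUM is too large by a factor of GTP.
avg_partial = [sum(grads[d][g] for d in range(DP)) / DP for g in range(GTP)]
avg_then_sum = sum(avg_partial)

print(true_mean, two_stage, avg_then_sum)
```

The two-stage result matches the true mean, while the averaged variant is off by exactly the GTP degree. That is consistent with the assertion in `distributed_data_parallel.py` that GTP requires `average_in_collective=False`.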
1 change: 1 addition & 0 deletions megatron/core/distributed/param_and_grad_buffer.py
Expand Up @@ -1011,6 +1011,7 @@ def __init__(
param_layout = _compute_default_per_buffer_param_layout(self.params, bucket_size)
self.param_index_map = param_layout.param_index_map
self.bucket_indices = param_layout.bucket_indices
self.num_optimizer_shards = param_layout.num_optimizer_shards
per_bucket_numel_unpadded = param_layout.per_bucket_numel_unpadded

# Check if this buffer contains NVFP4 params.
Expand Down
58 changes: 52 additions & 6 deletions megatron/core/extensions/transformer_engine.py
Expand Up @@ -19,7 +19,7 @@
from torch.nn.parameter import Parameter
from typing_extensions import override

- from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+ from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp4Recipe, Fp8Recipe
from megatron.core.model_parallel_config import ModelParallelConfig
Expand Down Expand Up @@ -381,6 +381,23 @@ def condition_init_method(config, init_method):
return init_method if config.perform_initialization else (lambda w: None)


def _maybe_setup_gtp(module, gtp_group, extra_kwargs):
"""Wire an active GTP group (size > 1) into TE's extra_kwargs and set module.gtp_size.

No-op when GTP is inactive (gtp_group is None or size 1), so module.gtp_size stays unset.
"""
if gtp_group is None or gtp_group.size() <= 1:
return
from megatron.core.tensor_parallel.gtp import HAVE_GTP

assert HAVE_GTP, (
"GTP requires TransformerEngine >= 2.17. "
"Set MEGATRON_GTP_FORCE_ENABLE=1 to bypass for custom TE builds."
)
module.gtp_size = get_pg_size(gtp_group)
extra_kwargs["gtp_group"] = gtp_group if torch.distributed.is_initialized() else None


def split_te_layernorm_column_parallel_linear(
fused_layer,
config,
Expand Down Expand Up @@ -762,6 +779,7 @@ def __init__(
symmetric_ar_type: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
name: str | None = None,
gtp_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Args:
Expand Down Expand Up @@ -895,6 +913,7 @@ def __init__(
self.te_quant_params, torch.is_grad_enabled()
)

_maybe_setup_gtp(self, gtp_group, extra_kwargs)
with init_quant_context:
super().__init__(
in_features=input_size,
Expand Down Expand Up @@ -1004,6 +1023,7 @@ def __init__(
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
gtp_group: Optional[torch.distributed.ProcessGroup] = None,
stride: int = 1,
name: str | None = None,
):
Expand Down Expand Up @@ -1101,6 +1121,7 @@ def __init__(
), "Must have at least TE version 2.3 or higher to use symmetric memory all reduce"
extra_kwargs["symmetric_ar_type"] = self.config.symmetric_ar_type

_maybe_setup_gtp(self, gtp_group, extra_kwargs)
self.stride = stride

self.te_quant_params: Optional[TEQuantizationParams] = None
Expand Down Expand Up @@ -1216,7 +1237,7 @@ def extra_repr(self) -> str:
f"in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.use_bias}, "
- f"TP={self.tp_size}"
+ f"TP={self.tp_size}" + (f", GTP={self.gtp_size}" if hasattr(self, "gtp_size") else "")
)

def backward_dw(self):
Expand All @@ -1243,6 +1264,7 @@ def __init__(
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
gtp_group: Optional[torch.distributed.ProcessGroup] = None,
stride: int = 1,
name: str | None = None,
):
Expand Down Expand Up @@ -1282,6 +1304,7 @@ def __init__(
symmetric_ar_type=config.symmetric_ar_type,
tp_group=tp_group,
name=name,
gtp_group=gtp_group,
)

# Set proper partition_stride
Expand Down Expand Up @@ -1332,7 +1355,7 @@ def extra_repr(self) -> str:
f"in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.use_bias}, "
- f"TP={self.tp_size}"
+ f"TP={self.tp_size}" + (f", GTP={self.gtp_size}" if hasattr(self, "gtp_size") else "")
)

def backward_dw(self):
Expand Down Expand Up @@ -1488,6 +1511,7 @@ def __init__(
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
name: str | None = None,
gtp_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Args:
Expand Down Expand Up @@ -1525,6 +1549,7 @@ def __init__(
symmetric_ar_type=config.symmetric_ar_type,
tp_group=tp_group,
name=name,
gtp_group=gtp_group,
)
if config.use_cpu_initialization:
world_size = get_pg_size(tp_group)
Expand Down Expand Up @@ -1571,7 +1596,7 @@ def extra_repr(self) -> str:
f"in_features={self.in_features}, "
f"out_features={self.out_features}, "
f"bias={self.use_bias}, "
- f"TP={self.tp_size}"
+ f"TP={self.tp_size}" + (f", GTP={self.gtp_size}" if hasattr(self, "gtp_size") else "")
)

def backward_dw(self):
Expand Down Expand Up @@ -1981,6 +2006,7 @@ def __init__(
self._tp_group = tp_group
tp_size = get_pg_size(tp_group)
tp_group_for_te = tp_group
gtp_group = pg_collection.expt_gtp

self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

Expand All @@ -2000,6 +2026,7 @@ def __init__(
tp_size = 1
tp_group_for_te = None

_maybe_setup_gtp(self, gtp_group, extra_kwargs)
if is_te_min_version("2.14.0"):
extra_kwargs["single_grouped_weight"] = getattr(
config, "moe_single_grouped_weight", False
Expand Down Expand Up @@ -2378,16 +2405,25 @@ def get_gemm_tensor(param_name: str, gemm_idx: int) -> torch.Tensor:
)
if self.use_bias:
sharded_state_dict[f"{prefix}bias{gemm_idx}"] = sub_sd[f"{gemm_idx}.bias"]
- # Adjust replica ids - replication along DP modulo EP
+ # Set the expert-DP replica_id, picking the group by what EGTP does to each entry:
+ # - weight ShardedTensor: SHARDED across EGTP (distinct chunks) → not replicas →
+ # use ``intra_expt_dp_no_egtp``.
+ # - _extra_state ShardedObject: REPLICATED across EGTP → need distinct replica_ids
+ # to avoid duplicate-writer collisions → use full ``expt_dp``.
+ # EGTP=1: the two groups coincide, so this is a no-op.
+ expt_dp_full = self._pg_collection.expt_dp
+ expt_dp_intra = self._pg_collection.intra_expt_dp_no_egtp
for k, sh_ten in sharded_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f"Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}"
if getattr(sh_ten, "is_data_parallel_fully_shard", False):
edp_replica_id = 0
+ elif isinstance(sh_ten, ShardedObject):
+ edp_replica_id = get_pg_rank(expt_dp_full)
else:
- edp_replica_id = get_pg_rank(self._pg_collection.expt_dp)
+ edp_replica_id = get_pg_rank(expt_dp_intra)
sh_ten.replica_id = (*replica_id[:2], edp_replica_id)
return sharded_state_dict

Expand All @@ -2399,6 +2435,16 @@ def backward_dw(self):
if self.delay_wgrad_compute:
super().backward_dw()

def __repr__(self):
gtp_str = f", GTP={self.gtp_size}" if hasattr(self, "gtp_size") else ""
return (
f"{type(self).__name__}(per expert(["
f"in={self.in_features}, out={self.out_features}]) "
f"X num_gemms={self.num_gemms}, "
f"bias={self.use_bias}, TP={self.tp_size}"
f"{gtp_str})"
)

class TEColumnParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
Expand Down
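The replica-id selection in the `sharded_state_dict` hunk of this file can be read as a three-way choice per entry. The sketch below restates that choice in isolation under stated assumptions: `FakeEntry` and `pick_replica_id` are hypothetical stand-ins, and the ranks are passed in as plain integers rather than looked up from process groups.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class FakeEntry:
    """Hypothetical stand-in for a ShardedTensor / ShardedObject entry."""

    replica_id: Tuple[int, int, int]  # (PP, TP, DP) format
    is_object: bool = False  # True models an _extra_state ShardedObject
    fully_sharded: bool = False  # models is_data_parallel_fully_shard


def pick_replica_id(entry, expt_dp_rank, intra_expt_dp_no_egtp_rank):
    """Return the replica_id with its last slot set per the rules in the hunk."""
    if len(entry.replica_id) != 3:
        raise ValueError(f"expected (PP, TP, DP) format, got {entry.replica_id}")
    if entry.fully_sharded:
        edp = 0
    elif entry.is_object:
        # Replicated across EGTP: use the rank in the full expert-DP group.
        edp = expt_dp_rank
    else:
        # Sharded across EGTP: peers hold distinct chunks, so exclude EGTP.
        edp = intra_expt_dp_no_egtp_rank
    return (*entry.replica_id[:2], edp)


weight_id = pick_replica_id(FakeEntry((0, 1, 0)), 5, 1)
extra_state_id = pick_replica_id(FakeEntry((0, 1, 0), is_object=True), 5, 1)
fsdp_id = pick_replica_id(FakeEntry((0, 1, 7), fully_sharded=True), 5, 1)
print(weight_id, extra_state_id, fsdp_id)
```

When EGTP is 1 the two ranks passed in would be equal, so both branches give the same result, as the comment in the hunk notes.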