diff --git a/autoparallel/_testing/models/dsv3.py b/autoparallel/_testing/models/dsv3.py index 863becf1..5a897b71 100644 --- a/autoparallel/_testing/models/dsv3.py +++ b/autoparallel/_testing/models/dsv3.py @@ -13,21 +13,52 @@ import triton import triton.language as tl from torch import nn - -# from torchtitan.distributed.expert_parallel import expert_parallel from torch.distributed.tensor import DeviceMesh, DTensor from torch.distributed.tensor.placement_types import Partial, Replicate, Shard from torch.nn.attention import SDPBackend, sdpa_kernel from autoparallel.collectives import all_to_all, axis_size, local_map -# When True, MoE uses uniform token routing and balanced all-to-all splits, -# eliminating data-dependent ops (.tolist(), dynamic grouped_mm offsets) that -# prevent Inductor compilation. -FORCE_BALANCED_ROUTING: bool = False +_MODULE_FQN = "module_fqn" + + +def _to_compute_dtype( + x: torch.Tensor, + compute_dtype: torch.dtype | None, +) -> torch.Tensor: + if compute_dtype is None or not torch.is_floating_point(x): + return x + return x.to(compute_dtype) + + +def _linear_compute( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + compute_dtype: torch.dtype | None, +) -> torch.Tensor: + if compute_dtype is None: + return F.linear(x, weight, bias) + bias = None if bias is None else bias.to(compute_dtype) + return F.linear(x.to(compute_dtype), weight.to(compute_dtype), bias) + + +def _rms_norm_compute( + x: torch.Tensor, + norm: nn.RMSNorm, + compute_dtype: torch.dtype | None, +) -> torch.Tensor: + if compute_dtype is None: + return norm(x) + weight = None if norm.weight is None else norm.weight.to(compute_dtype) + return F.rms_norm( + x.to(compute_dtype), + norm.normalized_shape, + weight, + norm.eps, + ) -# parallelized kernel @triton.jit def _fill_indices_kernel( tokens_per_expert_group_ptr, @@ -36,52 +67,30 @@ def _fill_indices_kernel( output_ptr, experts_per_rank: tl.constexpr, num_ranks: tl.constexpr, - BLOCK_SIZE: tl.constexpr, # Number of threads per block + BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) num_programs = tl.num_programs(axis=0) - # map programs (blocks) to the experts and loop (grid stride) if needed for expert_id in range(pid, experts_per_rank, num_programs): - # read this experts write offset write_offset = tl.load(write_offsets_ptr + expert_id) for r in range(num_ranks): - # index into tokens_per_expert_group array i = r * experts_per_rank + expert_id - - # load start index and number of tokens for this expert-rank pair start_index = tl.load(start_index_values_ptr + i) length = tl.load(tokens_per_expert_group_ptr + i) - - # each thread in block processes tokens in parallel offsets = tl.arange(0, BLOCK_SIZE) - # tokens are processed in chunks of BLOCK_SIZE for chunk_start in range(0, length, BLOCK_SIZE): chunk_offsets = chunk_start + offsets - - # mask valid indices mask = chunk_offsets < length - values = start_index + chunk_offsets - - # destination dest_indices = write_offset + chunk_offsets - - # store tl.store(output_ptr + dest_indices, values, mask=mask) - # update write offset for next rank write_offset += length -# ============== -# wrapper -# ============== - - -# workaround until local_map functionalization is fixed: https://github.com/pytorch/pytorch/issues/167568 @torch.library.custom_op("autoparallel::fill_indices_functional", mutates_args=()) def fill_indices_functional( tokens_per_expert_group: torch.Tensor, @@ -91,20 +100,14 @@ def fill_indices_functional( num_ranks: int, max_len: int, block_size: int = 128, - max_blocks: int = 1024, # cap on total number of blocks to launch + max_blocks: int = 1024, ) -> torch.Tensor: - # preallocate output permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device ) - # write offsets is per local expert... num_blocks = min(experts_per_rank, max_blocks) - # grid = one block per expert unless capped and then we loop... - grid = (num_blocks,) - - # launch kernel - _fill_indices_kernel[grid]( + _fill_indices_kernel[(num_blocks,)]( tokens_per_expert_group, start_index_values, write_offsets, @@ -125,126 +128,38 @@ def _( num_ranks: int, max_len: int, block_size: int = 128, - max_blocks: int = 1024, # cap on total number of blocks to launch + max_blocks: int = 1024, ) -> torch.Tensor: return torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device ) -# reference -def fill_indices_cpu( - tokens_per_expert_group: torch.Tensor, - start_index_values: torch.Tensor, - write_offsets: torch.Tensor, - experts_per_rank: int, - num_ranks: int, - max_len: int, -): - # We need to preallocate the output - we ignore device and force it on cpu - # device = tokens_per_expert_group.device - permuted_indices = torch.full( - (max_len,), - -1, - dtype=torch.int32, - ) # device=device) - # Fill the permuted indices - # For each local expert - for e in range(experts_per_rank): - write_start = write_offsets[e].item() - assert isinstance(write_start, int) - # For each remote rank - for r in range(num_ranks): - i: int = r * experts_per_rank + e - start_index = start_index_values[i].item() - length = tokens_per_expert_group[i].item() - assert isinstance(start_index, int) - assert isinstance(length, int) - # Fill in the indices - if length > 0: - end_idx: int = min(write_start + length, max_len) - permuted_indices[write_start:end_idx] = torch.arange( - start_index, - start_index + (end_idx - write_start), - dtype=torch.int32, - ) - write_start += length - return permuted_indices - - def generate_permute_indices( tokens_per_expert_group: torch.Tensor, experts_per_rank: int, num_ranks: int, max_len: int, alignment: int, - use_cpu: bool = False, ): - """ - Prepare permutation indices and the number of tokens for each expert. - - Args: - tokens_per_expert_group: number of tokens for each expert from all ranks. - experts_per_rank: number of experts per rank. - num_ranks: number of ranks. - max_len: maximum length of the output index vector. - alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts. - use_cpu: whether to use CPU implementation. - - - Returns: - permuted_indices: Tensor of indices that map original token order to the expert-grouped order. - m_sizes: aligned number of tokens for each expert (padded to alignment boundary). - m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens. - - Explanatory details: - `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example: - From: | rank 0 | rank 1 | - To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 | - | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 | - """ - - # prefix sum to get start index of each expert (parallel scan kernel in future?) start_index_values = ( torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group ) - - # total tokens for each expert (sum over ranks) total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) - - # pad out empty experts to alignment requirement total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) - - # align the chunk sizes (cdiv) m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to( torch.int32 ) - - # additional prefix sum to get write offset of each expert in permuted_indices - # write offsets is per local expert, not global m_offsets = torch.cumsum(m_sizes, 0) write_offsets = m_offsets - m_sizes - - # Select the implementation to use - if use_cpu: - permuted_indices = fill_indices_cpu( - tokens_per_expert_group, - start_index_values, - write_offsets, - experts_per_rank, - num_ranks, - max_len, - ) - else: - permuted_indices = fill_indices_functional( - tokens_per_expert_group, - start_index_values, - write_offsets, - experts_per_rank, - num_ranks, - max_len, - ) - + permuted_indices = fill_indices_functional( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) return permuted_indices, m_sizes, m_offsets.to(torch.int32) @@ -252,89 +167,20 @@ def generate_permute_indices( def _round_up(x: int, y: int) -> int: - """Round up x to the nearest multiple of y.""" x_ceil_div_y = (x + y - 1) // y return x_ceil_div_y * y -def expert_parallel(func: Callable) -> Callable: - """ - This is a wrapper applied to the GroupedExperts computation, serving - the following three purposes: - 1. Convert parameters from DTensors to plain Tensors, to work with - dynamic-shape inputs which cannot be easily expressed as DTensors. - 2. In Expert Parallel, apply the generate_permute_indices kernel to - permute the inputs to be ordered by local experts (see the _token_dispatch - function in ExpertParallel) and permute the outputs back. - 3. In order to use torch._grouped_mm, we need to make sure the number of - tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices - kernel also helps achieve this via padding, without incurring synchronization - between device and host. Note that this will create side effects when wrapping - the for-loop implementation of GroupedExperts, as it does not need padding. - - Among the above: - 1 and 2 are needed only when expert_parallel_degree > 1. - 3 is needed even for single-device computation. - 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. - """ - - def wrapper( - w1: torch.Tensor, - w2: torch.Tensor, - w3: torch.Tensor, - x: torch.Tensor, - num_tokens_per_expert: torch.Tensor, - ) -> torch.Tensor: - global TOKEN_GROUP_ALIGN_SIZE_M - if isinstance(w1, DTensor): - assert isinstance(w2, DTensor) - assert isinstance(w3, DTensor) - w1 = w1.to_local() - w2 = w2.to_local() - w3 = w3.to_local() - - experts_per_ep_rank = w1.shape[0] - num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank - # assert ( - # num_ep_ranks == 64 - # ), f"{num_ep_ranks}, {experts_per_ep_rank}, num_tokens_per_expert.shape: {num_tokens_per_expert.shape}, x={x.ndim}, w={w1.shape}" - - # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M, - # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M. - x_padded_per_expert = ( - x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M - ) - padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M) - with torch.no_grad(): - ( - permuted_indices, - num_tokens_per_expert, - _, # offsets, - ) = generate_permute_indices( - num_tokens_per_expert, - experts_per_ep_rank, - num_ep_ranks, - padded_max_len, - TOKEN_GROUP_ALIGN_SIZE_M, - ) - - x = torch.vstack((x, x.new_zeros((x.shape[-1])))) - input_shape = x.shape - x = x[permuted_indices, :] - - out = func(w1, w2, w3, x, num_tokens_per_expert) - - out_unpermuted = out.new_empty(input_shape) - out_unpermuted[permuted_indices, :] = out - out = out_unpermuted[:-1] - - return out - - return wrapper - - -def functional_feed_forward(w1, w2, w3, x): - return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2) +def functional_feed_forward( + w1, + w2, + w3, + x, + compute_dtype: torch.dtype | None = None, +): + h1 = _linear_compute(x, w1, None, compute_dtype) + h3 = _linear_compute(x, w3, None, compute_dtype) + return _linear_compute(F.silu(h1) * h3, w2, None, compute_dtype) # can be used as dense FFN layer or shared experts in MoE layers @@ -354,14 +200,22 @@ def __init__( self, dim: int, hidden_dim: int, + compute_dtype: torch.dtype | None = None, ): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.compute_dtype = compute_dtype def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return functional_feed_forward( + self.w1.weight, + self.w2.weight, + self.w3.weight, + x, + self.compute_dtype, + ) def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) @@ -369,9 +223,6 @@ def init_weights(self, init_std: float = 0.02): nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) -# TODO: keeping this for-loop implementation for comparison -# and readability, may remove later -@expert_parallel def _run_experts_for_loop( w1: torch.Tensor, w2: torch.Tensor, @@ -379,35 +230,47 @@ def _run_experts_for_loop( x_: torch.Tensor, num_tokens_per_expert_: torch.Tensor, ) -> torch.Tensor: + if isinstance(w1, DTensor): + assert isinstance(w2, DTensor) + assert isinstance(w3, DTensor) + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + # NOTE: this would incur a synchronization between device and host num_tokens_per_expert: list[int] = num_tokens_per_expert_.tolist() - # side-effect code due to the usage of generate_permute_indices - num_padding: int = x_.shape[0] - sum(num_tokens_per_expert) - # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) x: tuple[torch.Tensor, ...] = torch.split( - x_[: sum(num_tokens_per_expert)], + x_, split_size_or_sections=num_tokens_per_expert, dim=0, ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) - h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) - h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) + compute_dtype = x_expert.dtype if torch.is_floating_point(x_expert) else None + x_expert = _to_compute_dtype(x_expert, compute_dtype) + h = F.silu( + torch.matmul( + x_expert, + _to_compute_dtype(w1[expert_idx], compute_dtype).transpose(-2, -1), + ) + ) + h = h * torch.matmul( + x_expert, + _to_compute_dtype(w3[expert_idx], compute_dtype).transpose(-2, -1), + ) + h = torch.matmul( + h, + _to_compute_dtype(w2[expert_idx], compute_dtype).transpose(-2, -1), + ) # h shape (tokens_per_expert(varying), dim) out_experts_splits.append(h) - out = torch.cat(out_experts_splits, dim=0) - - # side-effect code due to the usage of generate_permute_indices - out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) - return out + return torch.cat(out_experts_splits, dim=0) -@expert_parallel def _run_experts_grouped_mm( w1: torch.Tensor, w2: torch.Tensor, @@ -415,6 +278,13 @@ def _run_experts_grouped_mm( x: torch.Tensor, num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: + if isinstance(w1, DTensor): + assert isinstance(w2, DTensor) + assert isinstance(w3, DTensor) + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor assert x.dim() == 2 @@ -538,7 +408,12 @@ def forward( """ # scores shape (bs*slen, num_experts) # scores = self.gate(x) - scores = torch.nn.functional.linear(x, gate_weight) + scores = _linear_compute( + x, + gate_weight, + None, + x.dtype if torch.is_floating_point(x) else None, + ) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion if self.score_func == "sigmoid": @@ -636,6 +511,44 @@ def forward( ) +def _permute(routed_input, num_tokens_per_expert_group, ep_size, num_local_experts): + """Reorder tokens from rank-major to expert-major layout. + + This is the local_map-friendly version of TorchTitan's dispatcher permute: + it uses a fixed-size custom op rather than repeat_interleave with a dynamic + output shape, which local_map cannot currently capture. + """ + x_padded_per_expert = ( + routed_input.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M + ) + padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M) + with torch.no_grad(): + (permuted_indices, num_tokens_per_expert, _,) = generate_permute_indices( + num_tokens_per_expert_group, + num_local_experts, + ep_size, + padded_max_len, + TOKEN_GROUP_ALIGN_SIZE_M, + ) + + routed_input = torch.vstack( + (routed_input, routed_input.new_zeros((routed_input.shape[-1]))) + ) + return ( + routed_input.shape, + routed_input[permuted_indices, :], + permuted_indices, + num_tokens_per_expert, + ) + + +def _unpermute(routed_output, input_shape, permuted_indices): + """Reverse expert-major reordering.""" + out_unpermuted = routed_output.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = routed_output + return out_unpermuted[:-1] + + def _token_dispatch(routed_input, num_tokens_per_expert, axis_name): ep_size = axis_size(axis_name) @@ -648,24 +561,20 @@ def _token_dispatch(routed_input, num_tokens_per_expert, axis_name): axis_name, ) - if FORCE_BALANCED_ROUTING: - input_splits = None - output_splits = None - else: - with torch.no_grad(): - input_splits = ( - num_tokens_per_expert.view(ep_size, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=True) - ) - # NOTE: this would incur a device-to-host sync - output_splits = ( - num_tokens_per_expert_group.view(ep_size, -1) - .sum(dim=1) - .to(torch.device("cpu"), non_blocking=False) - ) - input_splits = input_splits.tolist() - output_splits = output_splits.tolist() + with torch.no_grad(): + input_splits = ( + num_tokens_per_expert.view(ep_size, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=True) + ) + # NOTE: this would incur a device-to-host sync + output_splits = ( + num_tokens_per_expert_group.view(ep_size, -1) + .sum(dim=1) + .to(torch.device("cpu"), non_blocking=False) + ) + input_splits = input_splits.tolist() + output_splits = output_splits.tolist() with fx_traceback.annotate({"comm_region": "token_dispatch"}): routed_input = all_to_all( routed_input, @@ -674,20 +583,29 @@ def _token_dispatch(routed_input, num_tokens_per_expert, axis_name): axis_name, ) - # NOTE: After this all-to-all, the routed input is put on proper EP rank. - # However, the num_tokens_per_expert_group is not of the final target format - # [#tokens for local expert 0, #tokens for local expert 1, ...] - # Rather, it is of the format - # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., - # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] - # We need to perform another shuffle to get the correct format -- this is done via the function - # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens - # each expert gets locally is a multiple of ALIGN_SIZE_M. - - return routed_input, num_tokens_per_expert_group, input_splits, output_splits + # Reorder from rank-major to expert-major via _permute. + # + # num_tokens_per_expert_group layout after all-to-all: + # (e0,r0), (e1,r0), ..., (e0,r1), (e1,r1), ... (rank-major) + # _permute reshuffles to: + # (e0,r0), (e0,r1), ..., (e1,r0), (e1,r1), ... (expert-major) + num_local_experts = num_tokens_per_expert_group.shape[0] // ep_size + return ( + *_permute( + routed_input, + num_tokens_per_expert_group, + ep_size, + num_local_experts, + ), + input_splits, + output_splits, + ) -def _token_combine(routed_output, input_splits, output_splits, axis_name): +def _token_combine( + routed_output, input_shape, permuted_indices, input_splits, output_splits, axis_name +): + routed_output = _unpermute(routed_output, input_shape, permuted_indices) with fx_traceback.annotate({"comm_region": "token_combine"}): routed_output = all_to_all( routed_output, @@ -709,72 +627,51 @@ def local_mapped_region( out: torch.Tensor, top_k: int, num_experts: int, + score_before_experts: bool, axis_name: str, ) -> tuple[torch.Tensor, torch.Tensor]: # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}" dim = x.shape[-1] - if FORCE_BALANCED_ROUTING: - # Uniform distribution: same number of tokens per expert. - # Eliminates data-dependent grouped_mm offsets for Inductor. - total_tokens = selected_experts_indices.numel() - num_tokens_per_expert = torch.full( - (num_experts,), - total_tokens // num_experts, - device=x.device, - dtype=torch.int32, - ) - else: - # num_tokens_per_expert = torch.ops.autoparallel.batched_histc( - num_tokens_per_expert = torch.histc( - selected_experts_indices.flatten(), - bins=num_experts, - min=0, - max=num_experts, - ) + # num_tokens_per_expert = torch.ops.autoparallel.batched_histc( + num_tokens_per_expert = torch.histc( + selected_experts_indices.flatten(), + bins=num_experts, + min=0, + max=num_experts, + ) # total_tokens_per_expert = all_reduce(num_tokens_per_expert, axis_name) total_tokens_per_expert = num_tokens_per_expert token_indices_experts_sorted = torch.argsort( - selected_experts_indices.flatten(1), dim=-1, stable=True + selected_experts_indices.view(-1), stable=True ) - # top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] - top_scores_experts_sorted = top_scores.view_as(token_indices_experts_sorted).gather( - 1, token_indices_experts_sorted - ) + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] token_indices_experts_sorted = token_indices_experts_sorted // top_k # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted[..., None].expand( - -1, -1, dim - ) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather( - x.view(token_indices_experts_sorted.shape[0], -1, dim), - dim=1, - index=token_indices_experts_sorted, - ) - routed_input = ( - routed_input.to(torch.float32) * top_scores_experts_sorted[..., None] - ).to(x.dtype) + routed_input = x[token_indices_experts_sorted] + if score_before_experts: + routed_input = ( + routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) shape = routed_input.shape dim = shape[-1] - routed_input = routed_input.view(-1, dim) num_tokens_per_expert = num_tokens_per_expert.view(-1) ( + input_shape, routed_input, + permuted_indices, num_tokens_per_expert_group, input_splits, output_splits, ) = _token_dispatch(routed_input, num_tokens_per_expert, axis_name) routed_output = _run_experts_grouped_mm( - # experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert experts_w1, experts_w2, experts_w3, @@ -783,14 +680,26 @@ def local_mapped_region( ) routed_output = _token_combine( - routed_output, input_splits, output_splits, axis_name + routed_output, + input_shape, + permuted_indices, + input_splits, + output_splits, + axis_name, ) - torch._check(routed_output.shape[0] == shape[0] * shape[1]) + torch._check(routed_output.shape[0] == shape[0]) - routed_output = routed_output.view(shape) + if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + ).to(routed_output.dtype) - out = out.scatter_add(dim=1, index=token_indices_experts_sorted, src=routed_output) + out = out.scatter_add( + dim=0, + index=token_indices_experts_sorted.reshape(-1, 1).expand(-1, dim), + src=routed_output, + ) return out, total_tokens_per_expert @@ -900,15 +809,12 @@ def _moe_forward( reorderer: TokenReorderer, mesh: Optional[DeviceMesh], axis_name: str, + score_before_experts: bool, + compute_dtype: torch.dtype | None = None, ): # x: 64, 2048, 256 bs, slen, dim = x.shape - - # local_batch_size = 4 - # num_gpus_participating = 32 * 2 - # num_experts_per_groups = local_batch_size * num_gpus_participating - # x = x.unflatten(0, (-1, num_experts_per_groups)) - # x = x.view(-1, dim) + x = x.view(-1, dim) # top_scores and selected_experts_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) @@ -943,7 +849,7 @@ def _moe_forward( # shape (bs*slen*top_k, dim) # routed_output = experts(routed_input, num_tokens_per_expert) - out = functional_feed_forward(shared_w1, shared_w2, shared_w3, x) + out = functional_feed_forward(shared_w1, shared_w2, shared_w3, x, compute_dtype) ###################################################### # This is in the local_map region @@ -969,6 +875,7 @@ def _moe_forward( None, None, None, + None, ) out, num_tokens_per_expert = local_map( @@ -991,6 +898,7 @@ def _moe_forward( out, router.top_k, router.num_experts, + score_before_experts, axis_name, ) # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}" @@ -1010,60 +918,47 @@ def _moe_forward( return out, num_tokens_per_expert -@dataclass -class MoEArgs: - num_experts: int = 8 - num_shared_experts: int = 1 - - # router - score_func: Literal["softmax", "sigmoid"] = "sigmoid" - route_norm: bool = False - route_scale: float = 1.0 - score_before_experts: bool = True - - # token-choice - top_k: int = 1 - use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation - load_balance_coeff: float | None = 1e-3 - - _debug_force_load_balance: bool = False - # if True, we force each experts get same amount of token via round-robin - mesh: Optional[DeviceMesh] = None - - class MoE(nn.Module): - def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + top_k: int, + shared_experts_hidden_dim: int, + score_func: Literal["softmax", "sigmoid"] = "sigmoid", + route_norm: bool = False, + route_scale: float = 1.0, + score_before_experts: bool = True, + use_grouped_mm: bool = True, + load_balance_coeff: float | None = 1e-3, + mesh: DeviceMesh | None = None, + compute_dtype: torch.dtype | None = None, + ): super().__init__() - num_experts = moe_args.num_experts - self.mesh = moe_args.mesh + self.mesh = mesh self.axis_name = "ep" + self.compute_dtype = compute_dtype self.experts = GroupedExperts( dim=dim, hidden_dim=hidden_dim, num_experts=num_experts, - use_grouped_mm=moe_args.use_grouped_mm, + use_grouped_mm=use_grouped_mm, ) self.router = TokenChoiceTopKRouter( dim=dim, num_experts=num_experts, - top_k=moe_args.top_k, - score_func=moe_args.score_func, - route_norm=moe_args.route_norm, - route_scale=moe_args.route_scale, - ) - self.reorderer = TokenReorderer(num_experts=num_experts, top_k=moe_args.top_k) - assert moe_args.num_shared_experts > 0 - self.shared_experts = FeedForward( - dim=dim, hidden_dim=hidden_dim * moe_args.num_shared_experts + top_k=top_k, + score_func=score_func, + route_norm=route_norm, + route_scale=route_scale, ) - self.score_before_experts = moe_args.score_before_experts + self.reorderer = TokenReorderer(num_experts=num_experts, top_k=top_k) + self.shared_experts = FeedForward(dim=dim, hidden_dim=shared_experts_hidden_dim) + self.score_before_experts = score_before_experts - # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) - # NOTE: tokens_per_expert is accumulated in the model forward pass. - # expert_bias is updated outside the model in an optimizer step pre hook - # to work with gradient accumulation. - self.load_balance_coeff = moe_args.load_balance_coeff + self.load_balance_coeff = load_balance_coeff if self.load_balance_coeff is not None: assert self.load_balance_coeff > 0.0 self.register_buffer( @@ -1104,6 +999,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.reorderer, self.mesh, self.axis_name, + self.score_before_experts, + self.compute_dtype, ) # HOPs don't support buffer mutations, keep this outside @@ -1185,92 +1082,202 @@ def build_attention( return ScaledDotProductAttention(attn_mask_type) -# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @dataclass -class DeepSeekV3ModelArgs: - """ - Data class for defining model arguments and hyperparameters. +class RoPEConfig: + dim: int = 64 + max_seq_len: int = 4096 * 4 + theta: float = 10000.0 + rope_factor: float = 40.0 + beta_fast: float = 32.0 + beta_slow: float = 1.0 + original_seq_len: int = 4096 - Attributes: - max_batch_size (int): Maximum batch size. - max_seq_len (int): Maximum sequence length. - vocab_size (int): Vocabulary size. - dim (int): Model dimension. - inter_dim (int): Intermediate dimension for MLP layers. - moe_inter_dim (int): Intermediate dimension for MoE layers. - n_layers (int): Number of transformer layers. - n_dense_layers (int): Number of dense layers in the model. - n_heads (int): Number of attention heads. - norm_eps (float): Epsilon value used for RMSNorm. - moe_args (MoEArgs): MoE configuration. - n_expert_groups (int): Number of expert groups. - n_limited_groups (int): Number of limited groups for MoE routing. - q_lora_rank (int): LoRA rank for query projections. - kv_lora_rank (int): LoRA rank for key-value projections. - qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. - qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. - v_head_dim (int): Dimension for value projections. - use_flex_attn (bool): Whether to use FlexAttention. - attn_mask_type (str): Type of attention mask. - original_seq_len (int): Original sequence length. - rope_theta (float): Base for rotary positional encoding. - rope_factor (float): Scaling factor for extended sequence lengths. - beta_fast (int): Fast beta correction factor. - beta_slow (int): Slow beta correction factor. - """ - max_batch_size: int = 8 - max_seq_len: int = 4096 * 4 - vocab_size: int = 102400 - dim: int = 2048 - inter_dim: int = 10944 - moe_inter_dim: int = 1408 - n_layers: int = 27 - n_dense_layers: int = 1 - n_heads: int = 16 - norm_eps: float = 1e-5 # eps used for RMSNorm +@dataclass +class NormConfig: + eps: float = 1e-5 + - # MoE - moe_args: MoEArgs = field(default_factory=MoEArgs) - # TODO: node-limited routing is not supported yet - n_expert_groups: int = 1 - n_limited_groups: int = 1 +@dataclass +class LinearConfig: + in_features: int = 0 + out_features: int = 0 + + +@dataclass +class SDPAConfig: + pass - # Multi-Head Latent Attention (MLA) + +@dataclass +class AttentionConfig: + n_heads: int = 16 q_lora_rank: int = 0 kv_lora_rank: int = 512 qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 - use_flex_attn: bool = False - attn_mask_type: str = "causal" - - # yarn - original_seq_len: int = 4096 - rope_theta: float = 10000.0 - rope_factor: float = 40 - beta_fast: int = 32 - beta_slow: int = 1 mscale: float = 1.0 + mask_type: str = "causal" + inner_attention: SDPAConfig = field(default_factory=SDPAConfig) -# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 -def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor: - """ - Precomputes frequency-based complex exponential values for rotary positional embeddings. +@dataclass +class TokenDispatcherConfig: + score_before_experts: bool = True - Args: - args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters. - Returns: - torch.Tensor: Precomputed complex exponential values for positional embeddings. +@dataclass +class ExpertsConfig: + hidden_dim: int = 1408 + use_grouped_mm: bool = True + token_dispatcher: TokenDispatcherConfig = field( + default_factory=TokenDispatcherConfig + ) + + +@dataclass +class RouterConfig: + top_k: int = 1 + score_func: str = "sigmoid" + route_norm: bool = False + route_scale: float = 1.0 + + +@dataclass +class FeedForwardConfig: + w1: LinearConfig = field(default_factory=LinearConfig) + + +@dataclass +class MoEConfig: + num_experts: int = 8 + experts: ExpertsConfig = field(default_factory=ExpertsConfig) + router: RouterConfig = field(default_factory=RouterConfig) + load_balance_coeff: float | None = 1e-3 + shared_experts: FeedForwardConfig = field(default_factory=FeedForwardConfig) + + +@dataclass +class LayerConfig: + attention: AttentionConfig = field(default_factory=AttentionConfig) + attention_norm: NormConfig = field(default_factory=NormConfig) + ffn_norm: NormConfig = field(default_factory=NormConfig) + feed_forward: FeedForwardConfig | None = None + moe: MoEConfig | None = None + + +@dataclass +class DeepSeekV3Config: + """Hierarchical config for DeepSeekV3Model. + + Attribute paths are compatible with torchtitan's DeepSeekV3Model.Config, + so either config type can be passed to DeepSeekV3Model. """ - dim = args.qk_rope_head_dim - seqlen = args.max_seq_len - beta_fast = args.beta_fast - beta_slow = args.beta_slow - base = args.rope_theta - factor = args.rope_factor + + dim: int = 2048 + vocab_size: int = 102400 + rope: RoPEConfig = field(default_factory=RoPEConfig) + norm: NormConfig = field(default_factory=NormConfig) + layers: list = field(default_factory=list) + + +def make_dsv3_config( + dim: int = 256, + vocab_size: int = 2048, + n_layers: int = 6, + n_dense_layers: int = 1, + n_heads: int = 16, + q_lora_rank: int = 0, + kv_lora_rank: int = 512, + qk_nope_head_dim: int = 128, + qk_rope_head_dim: int = 64, + v_head_dim: int = 128, + mscale: float = 0.70, + dense_hidden_dim: int = 1024, + moe_hidden_dim: int = 256, + num_experts: int = 8, + num_shared_experts: int = 2, + top_k: int = 3, + score_func: str = "softmax", + route_norm: bool = False, + score_before_experts: bool = False, + max_seq_len: int = 4096 * 4, + rope_theta: float = 10000.0, + rope_factor: float = 40.0, + beta_fast: float = 32.0, + beta_slow: float = 1.0, + original_seq_len: int = 4096, + load_balance_coeff: float | None = 1e-3, +) -> DeepSeekV3Config: + layers = [] + for layer_id in range(n_layers): + attn = AttentionConfig( + n_heads=n_heads, + q_lora_rank=q_lora_rank, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + mscale=mscale, + ) + if layer_id < n_dense_layers: + ff = FeedForwardConfig(w1=LinearConfig(out_features=dense_hidden_dim)) + moe = None + else: + ff = None + moe = MoEConfig( + num_experts=num_experts, + experts=ExpertsConfig( + hidden_dim=moe_hidden_dim, + token_dispatcher=TokenDispatcherConfig( + score_before_experts=score_before_experts, + ), + ), + router=RouterConfig( + top_k=top_k, + score_func=score_func, + route_norm=route_norm, + ), + load_balance_coeff=load_balance_coeff, + shared_experts=FeedForwardConfig( + w1=LinearConfig( + out_features=moe_hidden_dim * num_shared_experts, + ), + ), + ) + layers.append( + LayerConfig( + attention=attn, + feed_forward=ff, + moe=moe, + ) + ) + + return DeepSeekV3Config( + dim=dim, + vocab_size=vocab_size, + rope=RoPEConfig( + dim=qk_rope_head_dim, + max_seq_len=max_seq_len, + theta=rope_theta, + rope_factor=rope_factor, + beta_fast=beta_fast, + beta_slow=beta_slow, + original_seq_len=original_seq_len, + ), + layers=layers, + ) + + +def precompute_freqs_cis(config) -> torch.Tensor: + rope = config.rope + dim = rope.dim + seqlen = rope.max_seq_len + beta_fast = rope.beta_fast + beta_slow = rope.beta_slow + base = rope.theta + factor = rope.rope_factor def find_correction_dim( num_rotations: float, dim: int, base: float, max_seq_len: int @@ -1336,9 +1343,9 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. - if seqlen > args.original_seq_len: + if seqlen > rope.original_seq_len: low, high = find_correction_range( - beta_fast, beta_slow, dim, base, args.original_seq_len + beta_fast, beta_slow, dim, base, rope.original_seq_len ) smooth = 1 - linear_ramp_factor(low, high, dim // 2) freqs = freqs / factor * (1 - smooth) + freqs * smooth @@ -1377,29 +1384,37 @@ class Attention(nn.Module): Multi-head attention (MLA) module. """ - def __init__(self, model_args: DeepSeekV3ModelArgs): + def __init__( + self, + attn_config, + model_config, + compute_dtype: torch.dtype | None = None, + ): super().__init__() - self.dim = model_args.dim - self.n_heads = model_args.n_heads - self.q_lora_rank = model_args.q_lora_rank - self.kv_lora_rank = model_args.kv_lora_rank - self.qk_nope_head_dim = model_args.qk_nope_head_dim - self.qk_rope_head_dim = model_args.qk_rope_head_dim - self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim - self.v_head_dim = model_args.v_head_dim + self.dim = model_config.dim + self.n_heads = attn_config.n_heads + self.q_lora_rank = attn_config.q_lora_rank + self.kv_lora_rank = attn_config.kv_lora_rank + self.qk_nope_head_dim = attn_config.qk_nope_head_dim + self.qk_rope_head_dim = attn_config.qk_rope_head_dim + self.qk_head_dim = attn_config.qk_nope_head_dim + attn_config.qk_rope_head_dim + self.v_head_dim = attn_config.v_head_dim + self.compute_dtype = compute_dtype + + norm_eps = model_config.norm.eps if self.q_lora_rank == 0: self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False) else: self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False) - self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=norm_eps) self.wq_b = nn.Linear( self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False ) self.wkv_a = nn.Linear( self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False ) - self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=norm_eps) self.wkv_b = nn.Linear( self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1408,11 +1423,13 @@ def __init__(self, model_args: DeepSeekV3ModelArgs): self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False) self.softmax_scale = self.qk_head_dim**-0.5 - if model_args.max_seq_len > model_args.original_seq_len: - mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 + rope_cfg = model_config.rope + if rope_cfg.max_seq_len > rope_cfg.original_seq_len: + mscale = 0.1 * attn_config.mscale * math.log(rope_cfg.rope_factor) + 1.0 self.softmax_scale = self.softmax_scale * mscale * mscale - self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type) + use_flex_attn = "FlexAttention" in type(attn_config.inner_attention).__name__ + self.sdpa = build_attention(use_flex_attn, attn_config.mask_type) def forward( self, @@ -1433,10 +1450,25 @@ def forward( # Query projection if self.q_lora_rank == 0: - q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + q = _linear_compute( + x, + self.wq.weight, + self.wq.bias, + self.compute_dtype, + ) # (bsz, seqlen, n_heads * qk_head_dim) else: - q = self.wq_a(x) - q = self.wq_b(self.q_norm(q)) + q = _linear_compute( + x, + self.wq_a.weight, + self.wq_a.bias, + self.compute_dtype, + ) + q = _linear_compute( + _rms_norm_compute(q, self.q_norm, self.compute_dtype), + self.wq_b.weight, + self.wq_b.bias, + self.compute_dtype, + ) # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual # local heads from sizes of q and kv as TP may have sharded them after # the above linear ops. @@ -1448,15 +1480,23 @@ def forward( q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) # Key-value projection - kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv = _linear_compute( + x, + self.wkv_a.weight, + self.wkv_a.bias, + self.compute_dtype, + ) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = apply_rotary_emb( k_pe.unsqueeze(2), freqs_cis ) # (bsz, seqlen, 1, qk_rope_head_dim) - kv = self.wkv_b( - self.kv_norm(kv) + kv = _linear_compute( + _rms_norm_compute(kv, self.kv_norm, self.compute_dtype), + self.wkv_b.weight, + self.wkv_b.bias, + self.compute_dtype, ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) @@ -1467,6 +1507,9 @@ def forward( q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + q = _to_compute_dtype(q, self.compute_dtype) + k = _to_compute_dtype(k, self.compute_dtype) + v = _to_compute_dtype(v, self.compute_dtype) output = self.sdpa(q, k, v, scale=self.softmax_scale) @@ -1475,7 +1518,12 @@ def forward( 1, 2 ).contiguous() # (bsz, seqlen, n_heads, v_head_dim) output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) - return self.wo(output) # (bsz, seqlen, dim) + return _linear_compute( + output, + self.wo.weight, + self.wo.bias, + self.compute_dtype, + ) # (bsz, seqlen, dim) def init_weights(self, init_std: float): linear_list = [ @@ -1501,22 +1549,50 @@ class TransformerBlock(nn.Module): Transformer block with attention and feed-forward layers. """ - def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): - + def __init__( + self, + layer_id: int, + layer_config, + model_config, + mesh: DeviceMesh | None = None, + compute_dtype: torch.dtype | None = None, + ): super().__init__() - self.attention = Attention(model_args) - self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + dim = model_config.dim + self.compute_dtype = compute_dtype + self.attention = Attention( + layer_config.attention, + model_config, + compute_dtype=compute_dtype, + ) + self.attention_norm = nn.RMSNorm(dim, eps=layer_config.attention_norm.eps) + self.ffn_norm = nn.RMSNorm(dim, eps=layer_config.ffn_norm.eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + self.moe_enabled = layer_config.moe is not None if self.moe_enabled: + moe_cfg = layer_config.moe self.moe = MoE( - model_args.moe_args, - dim=model_args.dim, - hidden_dim=model_args.moe_inter_dim, + dim=dim, + hidden_dim=moe_cfg.experts.hidden_dim, + num_experts=moe_cfg.num_experts, + top_k=moe_cfg.router.top_k, + shared_experts_hidden_dim=moe_cfg.shared_experts.w1.out_features, + score_func=moe_cfg.router.score_func, + route_norm=moe_cfg.router.route_norm, + route_scale=moe_cfg.router.route_scale, + score_before_experts=moe_cfg.experts.token_dispatcher.score_before_experts, + use_grouped_mm=moe_cfg.experts.use_grouped_mm, + load_balance_coeff=moe_cfg.load_balance_coeff, + mesh=mesh, + compute_dtype=compute_dtype, ) else: - self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) + ff_cfg = layer_config.feed_forward + self.feed_forward = FeedForward( + dim, + ff_cfg.w1.out_features, + compute_dtype=compute_dtype, + ) self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id @@ -1532,11 +1608,16 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): Returns: torch.Tensor: Output tensor with the same shape as the input. """ - x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.attention( + _rms_norm_compute(x, self.attention_norm, self.compute_dtype), + freqs_cis, + ) if self.moe_enabled: - x = x + self.moe(self.ffn_norm(x)) + x = x + self.moe(_rms_norm_compute(x, self.ffn_norm, self.compute_dtype)) else: - x = x + self.feed_forward(self.ffn_norm(x)) + x = x + self.feed_forward( + _rms_norm_compute(x, self.ffn_norm, self.compute_dtype) + ) return x def init_weights(self, buffer_device: torch.device): @@ -1554,28 +1635,40 @@ class DeepSeekV3Model(nn.Module): DeepSeek-V3 Transformer model with attention and feed-forward layers. """ - def __init__(self, model_args: DeepSeekV3ModelArgs): + def __init__( + self, + config, + mesh: DeviceMesh | None = None, + compute_dtype: torch.dtype | None = None, + ): # Explicitly call nn.Module.__init__ to avoid MRO issues when this class # is used with multiple inheritance (e.g., with ModelProtocol in torchtitan) nn.Module.__init__(self) - self.max_seq_len = model_args.max_seq_len - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.compute_dtype = compute_dtype + self.max_seq_len = config.rope.max_seq_len + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) self.register_buffer( - "freqs_cis", precompute_freqs_cis(model_args), persistent=False + "freqs_cis", precompute_freqs_cis(config), persistent=False ) self.layers = torch.nn.ModuleDict() - for layer_id in range(model_args.n_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + for layer_id, layer_config in enumerate(config.layers): + self.layers[str(layer_id)] = TransformerBlock( + layer_id, + layer_config, + config, + mesh, + compute_dtype=compute_dtype, + ) - self.norm = nn.RMSNorm(model_args.dim) - self.output = nn.Linear( - model_args.dim, - model_args.vocab_size, + self.norm = nn.RMSNorm(config.dim, eps=config.norm.eps) + self.lm_head = nn.Linear( + config.dim, + config.vocab_size, dtype=torch.get_default_dtype(), bias=False, ) - self.model_args = model_args + self.model_args = config def init_weights( self, buffer_device: torch.device | None = None, seed: int | None = None @@ -1607,20 +1700,28 @@ def forward( """ h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + h = _to_compute_dtype(h, self.compute_dtype) for layer in self.layers.values(): h = layer(h, self.freqs_cis) - h = self.norm(h) if self.norm is not None else h - output = self.output(h) if self.output is not None else h + h = ( + _rms_norm_compute(h, self.norm, self.compute_dtype) + if self.norm is not None + else h + ) + output = ( + _linear_compute( + h, + self.lm_head.weight, + self.lm_head.bias, + self.compute_dtype, + ) + if self.lm_head is not None + else h + ) return output -def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.cross_entropy( - pred.flatten(0, 1).float(), labels.flatten(0, 1) - ) - - def _init_weights_tok_embeddings(self: DeepSeekV3Model, seed: int | None = None): if seed is not None: torch.manual_seed(seed) @@ -1648,13 +1749,45 @@ def _init_weights_layers( def _init_weights_norm_and_output(self: DeepSeekV3Model): if self.norm is not None: self.norm.reset_parameters() - if self.output is not None: + if self.lm_head is not None: final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 nn.init.trunc_normal_( - self.output.weight, + self.lm_head.weight, mean=0.0, std=final_out_std, a=-cutoff_factor * final_out_std, b=cutoff_factor * final_out_std, ) + + +def _annotate_once(fn: Callable, meta: dict): + if getattr(fn, "_graph_trainer_annotated", False): + return fn + wrapped = fx_traceback.annotate_fn(meta)(fn) + setattr(wrapped, "_graph_trainer_annotated", True) + return wrapped + + +def _annotate_module_fqns(model: nn.Module) -> None: + for fqn, submodule in model.named_modules(): + if fqn: + submodule.forward = _annotate_once( + submodule.forward, + {_MODULE_FQN: fqn}, + ) + + +def annotate_deepseekv3_for_graph_trainer(model: DeepSeekV3Model) -> None: + """Attach graph_trainer-compatible FX annotations to AP's DSv3 model.""" + global local_mapped_region + + local_mapped_region = _annotate_once( + local_mapped_region, + {"EP": "compute"}, + ) + MoE.forward = _annotate_once( # type: ignore[method-assign] + MoE.forward, + {"EP": "compute"}, + ) + _annotate_module_fqns(model) diff --git a/autoparallel/api.py b/autoparallel/api.py index d1b83dad..eca79884 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -32,6 +32,7 @@ assert_has_no_collectives, cleanup_graph, fix_scatter_on_aliased_inputs, + functionalize_fresh_index_put_mutations, update_joint_with_descriptors, ) from .input_validation import ( @@ -446,6 +447,7 @@ def _apply_placement_common(self, sharding_placement): from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(parallel_gm) + functionalize_fresh_index_put_mutations(parallel_gm) mark_fsdp_all_gather_recomputation( parallel_gm.graph, self.reshard_after_forward diff --git a/autoparallel/graph_passes/graph_utils.py b/autoparallel/graph_passes/graph_utils.py index 87845b62..d14b1f12 100644 --- a/autoparallel/graph_passes/graph_utils.py +++ b/autoparallel/graph_passes/graph_utils.py @@ -199,6 +199,29 @@ def is_collective(node: torch.fx.Node) -> bool: } +def functionalize_fresh_index_put_mutations(gm: torch.fx.GraphModule) -> bool: + """Rewrite index_put_ on fresh tensors to the functional index_put form.""" + changed = False + for node in gm.graph.nodes: + if ( + node.op != "call_function" + or node.target != torch.ops.aten.index_put_.default + ): + continue + base = node.args[0] + if not isinstance(base, torch.fx.Node): + continue + if base.op == "placeholder" or len(base.users) != 1: + continue + node.target = torch.ops.aten.index_put.default + changed = True + + if changed: + gm.graph.lint() + gm.recompile() + return changed + + def fix_scatter_on_aliased_inputs(graph: torch.fx.Graph) -> None: """Insert clone before scatter ops whose input has zero strides (aliased from expand). diff --git a/autoparallel/module_construction.py b/autoparallel/module_construction.py index ad919971..d6645b08 100644 --- a/autoparallel/module_construction.py +++ b/autoparallel/module_construction.py @@ -115,7 +115,14 @@ def _assign_attr( curr_mod.register_parameter(field, attr) elif attr_kind == _AttrKind.BUFFER: assert isinstance(attr, torch.Tensor) - curr_mod.register_buffer(field, attr) + ref_curr_mod = ref_module + for attr_name in prefix: + ref_curr_mod = getattr(ref_curr_mod, attr_name) + curr_mod.register_buffer( + field, + attr, + persistent=field not in ref_curr_mod._non_persistent_buffers_set, + ) else: setattr(curr_mod, field, attr) diff --git a/examples/example_ds3_local_map.py b/examples/example_ds3_local_map.py index cd693839..106ce80b 100644 --- a/examples/example_ds3_local_map.py +++ b/examples/example_ds3_local_map.py @@ -8,25 +8,25 @@ import torch from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Shard from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.testing._internal.distributed.fake_pg import FakeStore -from autoparallel._testing.models.dsv3 import ( - DeepSeekV3Model, - DeepSeekV3ModelArgs, - MoEArgs, -) +from autoparallel._testing.models.dsv3 import DeepSeekV3Model, make_dsv3_config from autoparallel.api import AutoParallel from autoparallel.shardings.placement_options import NumericsLogger +_DEFAULT_DTENSOR_RNG_SEED = 0 + + +def _seed_dtensor_rng(rng_seed: Optional[int]) -> None: + torch.manual_seed(_DEFAULT_DTENSOR_RNG_SEED if rng_seed is None else rng_seed) + def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): seq_len = 1024 if fake_evaluate: - # must symbolically evaluate to run on 32 dp ranks - # world_size = 2048 - world_size = 256 fake_store = FakeStore() @@ -34,42 +34,16 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): "fake", store=fake_store, rank=0, world_size=world_size ) local_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + _seed_dtensor_rng(rng_seed) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", (world_size // 64, 64), - mesh_dim_names=( - "dp", - "ep", - ), + mesh_dim_names=("dp", "ep"), ) - config = DeepSeekV3ModelArgs( - vocab_size=102400, - max_seq_len=seq_len, - dim=2048, - inter_dim=10944, - moe_inter_dim=1408, - n_layers=1, # 27, - n_dense_layers=0, # 1, - n_heads=16, - moe_args=MoEArgs( - num_experts=64, - num_shared_experts=2, - top_k=6, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, - use_flex_attn=False, - attn_mask_type="causal", - ) + config = make_dsv3_config(num_experts=64, max_seq_len=seq_len) else: dp_degree = 2 ep_degree = 2 @@ -82,49 +56,29 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): int(os.getenv("WORLD_SIZE")) == world_size ), f"Need at least {world_size} GPUs for real evaluation" local_rank = int(os.getenv("LOCAL_RANK")) - torch.distributed.init_process_group(backend="nccl") + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + _seed_dtensor_rng(rng_seed) + torch.distributed.init_process_group(backend="nccl", device_id=device) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", (dp_degree, ep_degree), - mesh_dim_names=( - "dp", - "ep", - ), + mesh_dim_names=("dp", "ep"), ) - config = DeepSeekV3ModelArgs( - vocab_size=2048, - max_seq_len=seq_len, - dim=256, - inter_dim=1024, - moe_inter_dim=256, - n_layers=4, - n_dense_layers=0, - n_heads=16, - moe_args=MoEArgs( - num_experts=4, - num_shared_experts=2, - top_k=2, - score_func="softmax", - route_norm=False, - score_before_experts=False, - mesh=mesh, - ), - q_lora_rank=0, - kv_lora_rank=512, - qk_nope_head_dim=128, - qk_rope_head_dim=64, - v_head_dim=128, - mscale=0.70, + config = make_dsv3_config( + num_experts=4, top_k=2, n_layers=4, n_dense_layers=0, max_seq_len=seq_len ) local_batch_size = 2 global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1] - device = torch.device(f"cuda:{local_rank}") - # parallelize the model with torch.device("meta"): - model = DeepSeekV3Model(config).bfloat16() + model = DeepSeekV3Model( + config, + mesh=mesh, + compute_dtype=torch.bfloat16, + ) def input_fn(): return torch.randint( @@ -137,10 +91,15 @@ def input_fn(): numerics_logger = None if rng_seed is not None: numerics_logger = NumericsLogger(logs_dir) - with AutoParallel(model, input_fn, mesh, dynamic=True) as autop: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + ) + with AutoParallel( + model, input_fn, mesh, mp_policy=mp_policy, dynamic=True + ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) - # x_sharding = (Shard(0), Replicate()) x_sharding = (Shard(0), Shard(0)) autop.add_input_constraints([x_sharding]) @@ -150,11 +109,6 @@ def input_fn(): parallel_mod = autop.apply_placement(sharding_placement) parallel_mod.to_empty(device=device) - # run weight init on our sharded DTensor params - # TODO: plumb init_std through - # parallel_mod.init_weights( - # init_std=0.02, buffer_device="cuda" - # ) # maybe not correct value parallel_mod.init_weights(buffer_device=device, seed=rng_seed) if rng_seed is not None: numerics_logger.log_model_weights(parallel_mod) @@ -174,45 +128,44 @@ def input_fn(): full_batch.to(torch.float32), prefix="full batch input" ) - # Symbolically evaluate in case you want to test running a graph bigger than your gpu - if fake_evaluate: - # all gather on the tokens takes 128 GiB (4GiB * 32 ranks) - shape_env = ShapeEnv() - with FakeTensorMode( - allow_non_fake_inputs=True, - shape_env=shape_env, - ): - # now let's run it - for x in microbatches: + with torch.autograd.set_multithreading_enabled(False): + if fake_evaluate: + shape_env = ShapeEnv() + with FakeTensorMode( + allow_non_fake_inputs=True, + shape_env=shape_env, + ): + for x in microbatches: + out = parallel_mod(x) + out.backward(torch.ones_like(out)) + else: + for i, x in enumerate(microbatches): + assert x.shape[0] == 2 out = parallel_mod(x) + assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" out.backward(torch.ones_like(out)) - else: - for i, x in enumerate(microbatches): - assert x.shape[0] == 2 - out = parallel_mod(x) - assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" - out.backward(torch.ones_like(out)) - if rng_seed is not None: - numerics_logger.log_diff(out, prefix=f"mb{i} fwd out") + if rng_seed is not None: + numerics_logger.log_diff(out, prefix=f"mb{i} fwd out") - if rng_seed is not None: - for k, v in parallel_mod.named_parameters(): - numerics_logger.log_diff(v.grad, prefix=f"grad {k}") + if rng_seed is not None: + for k, v in parallel_mod.named_parameters(): + numerics_logger.log_diff(v.grad, prefix=f"grad {k}") print("All good!") if torch.distributed.is_initialized(): - torch.distributed.barrier() - torch.cuda.synchronize() + if torch.distributed.get_backend() == torch.distributed.Backend.NCCL: + torch.distributed.barrier(device_ids=[local_rank]) + else: + torch.distributed.barrier() + torch.cuda.synchronize(device) torch.distributed.destroy_process_group() if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser( - description="Run DeepSeek V3 pipeline parallel example" - ) + parser = argparse.ArgumentParser(description="Run DeepSeek V3 local_map example") parser.add_argument( "--fake-evaluate", action="store_true", @@ -235,7 +188,6 @@ def input_fn(): if args.rng_seed is not None: torch.use_deterministic_algorithms(True) - torch.manual_seed(args.rng_seed) run_test( fake_evaluate=args.fake_evaluate, rng_seed=args.rng_seed, logs_dir=args.logs_dir diff --git a/tests/test_graph_clustering.py b/tests/test_graph_clustering.py index b4522c59..3a4c9d28 100644 --- a/tests/test_graph_clustering.py +++ b/tests/test_graph_clustering.py @@ -11,11 +11,7 @@ from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor.placement_types import Replicate, Shard -from autoparallel._testing.models.dsv3 import ( - DeepSeekV3Model, - DeepSeekV3ModelArgs, - MoEArgs, -) +from autoparallel._testing.models.dsv3 import DeepSeekV3Model, make_dsv3_config from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel from autoparallel.graph_passes.graph_clustering import get_identical_regions @@ -215,14 +211,13 @@ def input_fn(): def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2): global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1] - moe_args = MoEArgs(mesh=device_mesh_2d) - config = DeepSeekV3ModelArgs( - n_layers=n_layers, - n_dense_layers=0, - moe_args=moe_args, - ) + config = make_dsv3_config(n_layers=n_layers, n_dense_layers=0) with torch.device("meta"): - model = DeepSeekV3Model(config).bfloat16() + model = DeepSeekV3Model( + config, + mesh=device_mesh_2d, + compute_dtype=torch.bfloat16, + ) for module in model.modules(): if hasattr(module, "axis_name"): module.axis_name = device_mesh_2d.mesh_dim_names[1] @@ -231,7 +226,7 @@ def input_fn(): return torch.randint( 0, config.vocab_size, - (global_batch_size, config.max_seq_len), + (global_batch_size, config.rope.max_seq_len), device="cuda", ) diff --git a/tests/test_graph_utils.py b/tests/test_graph_utils.py index c8740012..fb76eb4b 100644 --- a/tests/test_graph_utils.py +++ b/tests/test_graph_utils.py @@ -6,13 +6,48 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx -from autoparallel.graph_passes.graph_utils import _replace_view_mm_view_with_einsum +from autoparallel.graph_passes.graph_utils import ( + _replace_view_mm_view_with_einsum, + functionalize_fresh_index_put_mutations, +) def _count_ops(gm, target): return len(gm.graph.find_nodes(op="call_function", target=target)) +def test_functionalize_fresh_index_put_mutations(): + def f(x, idx, src): + out = torch.empty_like(x) + return torch.ops.aten.index_put_.default(out, [idx], src) + + x = torch.zeros(4, 3) + idx = torch.tensor([0, 1, 2, 3]) + src = torch.randn(4, 3) + gm = make_fx(f)(x, idx, src) + + assert _count_ops(gm, torch.ops.aten.index_put_.default) == 1 + + assert functionalize_fresh_index_put_mutations(gm) + + assert _count_ops(gm, torch.ops.aten.index_put_.default) == 0 + assert _count_ops(gm, torch.ops.aten.index_put.default) == 1 + torch.testing.assert_close(gm(x, idx, src), f(x, idx, src)) + + +def test_functionalize_fresh_index_put_mutations_skips_inputs(): + def f(x, idx, src): + return torch.ops.aten.index_put_.default(x, [idx], src) + + x = torch.zeros(4, 3) + idx = torch.tensor([0, 2]) + src = torch.randn(2, 3) + gm = make_fx(f)(x, idx, src) + + assert not functionalize_fresh_index_put_mutations(gm) + assert _count_ops(gm, torch.ops.aten.index_put_.default) == 1 + + def test_forward_pattern_3d(): """view(x, [-1,K]) -> mm(_, w) -> view(_, [B,S,N]) is replaced by einsum.""" B, S, K, N = 2, 8, 16, 32 diff --git a/tests/test_module_construction.py b/tests/test_module_construction.py index a524eeb4..aff8a2ba 100644 --- a/tests/test_module_construction.py +++ b/tests/test_module_construction.py @@ -254,6 +254,8 @@ def forward(self, x): if fqn not in seen.values() ) assert mod.get_buffer(alias_fqn) is mod.get_buffer(canonical_fqn) + assert "freqs_cis" not in mod.state_dict() + assert "rope.cache" not in mod.state_dict() def test_module_alias_reestablished():