Skip to content

Align DeepSeek V3 test config with TorchTitan shape#441

Merged
sanketpurandare merged 1 commit into
mainfrom
sanketpurandare/stack/3
May 8, 2026
Merged

Align DeepSeek V3 test config with TorchTitan shape#441
sanketpurandare merged 1 commit into
mainfrom
sanketpurandare/stack/3

Conversation

@sanketpurandare

@sanketpurandare sanketpurandare commented May 1, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


Align DeepSeek V3 test config with TorchTitan shape

This refactors AutoParallel's DeepSeek V3 test model around a hierarchical config shape that mirrors TorchTitan's DeepSeek V3 configuration while keeping make_dsv3_config() as the lightweight constructor for tests and examples. The model now reads layer, attention, RoPE, norm, FFN, and MoE settings through config objects, uses lm_head naming, accepts an optional mesh and compute_dtype, and preserves graph-trainer annotations for module FQNs and expert-parallel regions.

The DS3 implementation now performs linear, RMSNorm, attention, dense FFN, shared expert, and output computations through explicit compute-dtype helpers so the debug model can run in the same bfloat16 style expected by the TorchTitan comparison path. The old FORCE_BALANCED_ROUTING and CPU fill-index path are removed, expert execution unwraps DTensor weights locally where needed, and the local_map example uses MixedPrecisionPolicy plus make_dsv3_config() instead of maintaining a separate flat DeepSeekV3ModelArgs/MoEArgs construction path.

The local_map example also now binds each rank to its CUDA device before mesh/DTensor work, seeds DTensor RNG state explicitly, initializes NCCL with device_id, runs backward with autograd multithreading disabled, initializes weights on the rank device, and uses a device-specific final barrier. Those changes clean up the DTensor RNG sync, CUDA context/cuBLAS, and NCCL barrier warnings while keeping the example aligned with the real 2D dp/ep sharding constraints.

Two supporting graph/module fixes are included because they are required by the updated DS3 path. AutoParallel functionalizes index_put_ mutations when the mutation target is a fresh non-input tensor before AOT compilation, with tests that ensure input mutations are left alone. Parallel module construction now preserves non-persistent buffer registration when rebuilding sharded modules, so aliased RoPE buffers such as freqs_cis and rope.cache do not reappear in state_dict().

Authored with Claude.

sanketpurandare added a commit that referenced this pull request May 1, 2026
The flat DeepSeekV3ModelArgs and MoEArgs dataclasses are replaced by a
tree of small dataclasses (DeepSeekV3Config -> LayerConfig ->
AttentionConfig / MoEConfig / ...) whose attribute paths match
torchtitan's DeepSeekV3Model.Config. Because the model reads config
attributes via duck typing (no torchtitan import), either autoparallel's
own DeepSeekV3Config or torchtitan's Config can be passed in.

Concrete changes in dsv3.py:
- Deleted DeepSeekV3ModelArgs and MoEArgs.
- Added config dataclasses: DeepSeekV3Config, LayerConfig,
  AttentionConfig, MoEConfig, RoPEConfig, NormConfig, etc.
- Added make_dsv3_config() factory that builds the config tree from
  scalar hyperparameters (same role as torchtitan's _debugmodel()).
- MoE.__init__ now takes explicit keyword args instead of MoEArgs.
  The DeviceMesh (needed by local_map) is a constructor parameter
  threaded through DeepSeekV3Model -> TransformerBlock -> MoE.
- Attention.__init__ takes (attn_config, model_config) and derives
  use_flex_attn from inner_attention type name instead of a flag.
- precompute_freqs_cis reads from config.rope.*.

example_ds3_local_map.py is updated to use make_dsv3_config().

Validated: pytest tests/ passes (327 tests, 1 xfail). Model
construction verified with both autoparallel's DeepSeekV3Config
and torchtitan's Config via duck typing.

stack-info: PR: #441, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/2 branch from 1f1a12c to 0ef0c2d Compare May 1, 2026 02:52
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 96ea142 to 2c6c7e7 Compare May 1, 2026 02:52
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 1, 2026
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 4, 2026 03:14
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 2c6c7e7 to 2a2aec7 Compare May 4, 2026 03:14
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 4, 2026 03:14
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/2 branch from 0ef0c2d to a8d2d18 Compare May 4, 2026 03:18
sanketpurandare added a commit that referenced this pull request May 4, 2026
The flat DeepSeekV3ModelArgs and MoEArgs dataclasses are replaced by a
tree of small dataclasses (DeepSeekV3Config -> LayerConfig ->
AttentionConfig / MoEConfig / ...) whose attribute paths match
torchtitan's DeepSeekV3Model.Config. Because the model reads config
attributes via duck typing (no torchtitan import), either autoparallel's
own DeepSeekV3Config or torchtitan's Config can be passed in.

Concrete changes in dsv3.py:
- Deleted DeepSeekV3ModelArgs and MoEArgs.
- Added config dataclasses: DeepSeekV3Config, LayerConfig,
  AttentionConfig, MoEConfig, RoPEConfig, NormConfig, etc.
- Added make_dsv3_config() factory that builds the config tree from
  scalar hyperparameters (same role as torchtitan's _debugmodel()).
- MoE.__init__ now takes explicit keyword args instead of MoEArgs.
  The DeviceMesh (needed by local_map) is a constructor parameter
  threaded through DeepSeekV3Model -> TransformerBlock -> MoE.
- Attention.__init__ takes (attn_config, model_config) and derives
  use_flex_attn from inner_attention type name instead of a flag.
- precompute_freqs_cis reads from config.rope.*.

example_ds3_local_map.py is updated to use make_dsv3_config().

Validated: pytest tests/ passes (327 tests, 1 xfail). Model
construction verified with both autoparallel's DeepSeekV3Config
and torchtitan's Config via duck typing.

stack-info: PR: #441, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 2a2aec7 to 7425c44 Compare May 4, 2026 03:18
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 03:21
@sanketpurandare sanketpurandare requested review from aditvenk and xmfan May 4, 2026 03:22
@sanketpurandare sanketpurandare marked this pull request as draft May 4, 2026 03:32
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 4, 2026 03:32
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 7425c44 to 77d855b Compare May 4, 2026 03:32
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 4, 2026 03:32
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 03:32
@sanketpurandare sanketpurandare marked this pull request as draft May 4, 2026 04:02
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 4, 2026 04:02
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch 2 times, most recently from 87a7c6c to 9afe651 Compare May 4, 2026 04:07
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 4, 2026 04:07
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 04:08
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 20:00
@sanketpurandare

sanketpurandare commented May 4, 2026

Copy link
Copy Markdown
Contributor Author

does this now match titan's new dsv3 definition?

Yes it does and we can directly pass the config now from TorchTitan, no patching needed, same for annotations. Also by adding a compute type we don't have to force the entire model to be bfloat16, it interfaces nicely with optimizer in torchtitan as well.

@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/2 branch from 215f0c2 to 9308ab5 Compare May 4, 2026 20:28
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from d1ce828 to 6bb34f3 Compare May 4, 2026 20:28
sanketpurandare added a commit that referenced this pull request May 4, 2026
The flat DeepSeekV3ModelArgs and MoEArgs dataclasses are replaced by a
tree of small dataclasses (DeepSeekV3Config -> LayerConfig ->
AttentionConfig / MoEConfig / ...) whose attribute paths match
torchtitan's DeepSeekV3Model.Config. Because the model reads config
attributes via duck typing (no torchtitan import), either autoparallel's
own DeepSeekV3Config or torchtitan's Config can be passed in.

Concrete changes in dsv3.py:
- Deleted DeepSeekV3ModelArgs and MoEArgs.
- Added config dataclasses: DeepSeekV3Config, LayerConfig,
  AttentionConfig, MoEConfig, RoPEConfig, NormConfig, etc.
- Added make_dsv3_config() factory that builds the config tree from
  scalar hyperparameters (same role as torchtitan's _debugmodel()).
- MoE.__init__ now takes explicit keyword args instead of MoEArgs.
  The DeviceMesh (needed by local_map) is a constructor parameter
  threaded through DeepSeekV3Model -> TransformerBlock -> MoE.
- Attention.__init__ takes (attn_config, model_config) and derives
  use_flex_attn from inner_attention type name instead of a flag.
- precompute_freqs_cis reads from config.rope.*.

example_ds3_local_map.py is updated to use make_dsv3_config().

Validated: pytest tests/ passes (327 tests, 1 xfail). Model
construction verified with both autoparallel's DeepSeekV3Config
and torchtitan's Config via duck typing.

stack-info: PR: #441, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare marked this pull request as draft May 4, 2026 23:50
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 4, 2026 23:50
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 6bb34f3 to 562992a Compare May 4, 2026 23:50
@sanketpurandare sanketpurandare changed the title Replace flat DeepSeekV3ModelArgs with hierarchical config Align DeepSeek V3 test config with torchtitan shape May 4, 2026
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 4, 2026 23:50
@sanketpurandare sanketpurandare marked this pull request as ready for review May 4, 2026 23:50
@sanketpurandare sanketpurandare marked this pull request as draft May 8, 2026 00:23
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 8, 2026 00:23
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 562992a to 13740d1 Compare May 8, 2026 00:23
@sanketpurandare sanketpurandare changed the title Align DeepSeek V3 test config with torchtitan shape Align DeepSeek V3 test config with TorchTitan shape May 8, 2026
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 8, 2026 00:23
@sanketpurandare sanketpurandare marked this pull request as ready for review May 8, 2026 00:23
@sanketpurandare sanketpurandare marked this pull request as draft May 8, 2026 00:30
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 8, 2026 00:30
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 13740d1 to 539c35b Compare May 8, 2026 00:30
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 May 8, 2026 00:30
@sanketpurandare sanketpurandare marked this pull request as ready for review May 8, 2026 00:31
sanketpurandare added a commit that referenced this pull request May 8, 2026
This refactors AutoParallel's DeepSeek V3 test model around a hierarchical config shape that mirrors TorchTitan's DeepSeek V3 configuration while keeping make_dsv3_config() as the lightweight constructor for tests and examples. The model now reads layer, attention, RoPE, norm, FFN, and MoE settings through config objects, uses lm_head naming, accepts an optional mesh and compute_dtype, and preserves graph-trainer annotations for module FQNs and expert-parallel regions.

The DS3 implementation now performs linear, RMSNorm, attention, dense FFN, shared expert, and output computations through explicit compute-dtype helpers so the debug model can run in the same bfloat16 style expected by the TorchTitan comparison path. The old FORCE_BALANCED_ROUTING and CPU fill-index path are removed, expert execution unwraps DTensor weights locally where needed, and the local_map example uses MixedPrecisionPolicy plus make_dsv3_config() instead of maintaining a separate flat DeepSeekV3ModelArgs/MoEArgs construction path.

The local_map example also now binds each rank to its CUDA device before mesh/DTensor work, seeds DTensor RNG state explicitly, initializes NCCL with device_id, runs backward with autograd multithreading disabled, initializes weights on the rank device, and uses a device-specific final barrier. Those changes clean up the DTensor RNG sync, CUDA context/cuBLAS, and NCCL barrier warnings while keeping the example aligned with the real 2D dp/ep sharding constraints.

Two supporting graph/module fixes are included because they are required by the updated DS3 path. AutoParallel functionalizes index_put_ mutations when the mutation target is a fresh non-input tensor before AOT compilation, with tests that ensure input mutations are left alone. Parallel module construction now preserves non-persistent buffer registration when rebuilding sharded modules, so aliased RoPE buffers such as freqs_cis and rope.cache do not reappear in state_dict().

Authored with Claude.

stack-info: PR: #441, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 539c35b to f6f6c35 Compare May 8, 2026 00:32
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main May 8, 2026 00:32
This refactors AutoParallel's DeepSeek V3 test model around a hierarchical config shape that mirrors TorchTitan's DeepSeek V3 configuration while keeping make_dsv3_config() as the lightweight constructor for tests and examples. The model now reads layer, attention, RoPE, norm, FFN, and MoE settings through config objects, uses lm_head naming, accepts an optional mesh and compute_dtype, and preserves graph-trainer annotations for module FQNs and expert-parallel regions.

The DS3 implementation now performs linear, RMSNorm, attention, dense FFN, shared expert, and output computations through explicit compute-dtype helpers so the debug model can run in the same bfloat16 style expected by the TorchTitan comparison path. The old FORCE_BALANCED_ROUTING and CPU fill-index path are removed, expert execution unwraps DTensor weights locally where needed, and the local_map example uses MixedPrecisionPolicy plus make_dsv3_config() instead of maintaining a separate flat DeepSeekV3ModelArgs/MoEArgs construction path.

The local_map example also now binds each rank to its CUDA device before mesh/DTensor work, seeds DTensor RNG state explicitly, initializes NCCL with device_id, runs backward with autograd multithreading disabled, initializes weights on the rank device, and uses a device-specific final barrier. Those changes clean up the DTensor RNG sync, CUDA context/cuBLAS, and NCCL barrier warnings while keeping the example aligned with the real 2D dp/ep sharding constraints.

Two supporting graph/module fixes are included because they are required by the updated DS3 path. AutoParallel functionalizes index_put_ mutations when the mutation target is a fresh non-input tensor before AOT compilation, with tests that ensure input mutations are left alone. Parallel module construction now preserves non-persistent buffer registration when rebuilding sharded modules, so aliased RoPE buffers such as freqs_cis and rope.cache do not reappear in state_dict().

Authored with Claude.

stack-info: PR: #441, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare marked this pull request as draft May 8, 2026 00:33
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from f6f6c35 to d6e0f18 Compare May 8, 2026 00:33
@sanketpurandare sanketpurandare marked this pull request as ready for review May 8, 2026 00:34
@sanketpurandare sanketpurandare merged commit 8d3c8d9 into main May 8, 2026
6 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants