diff --git a/tests/models/automodel_test.py b/tests/models/automodel_test.py index fe3bcddd6..1bce19bc7 100644 --- a/tests/models/automodel_test.py +++ b/tests/models/automodel_test.py @@ -101,7 +101,7 @@ def _get_all_models_test_parameters(): dict(testcase_name="qwen3-14b", model_name="qwen3-14b"), dict(testcase_name="qwen3-30b-a3b", model_name="qwen3-30b-a3b"), dict(testcase_name="qwen3-32b", model_name="qwen3-32b"), - dict(testcase_name="Qwen3-32B", model_name="Qwen3-32B"), + dict(testcase_name="qwen3-235b-a22b", model_name="qwen3-235b-a22b"), ) diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py index 6d5b1d979..e8c070bdc 100644 --- a/tunix/models/qwen3/model.py +++ b/tunix/models/qwen3/model.py @@ -289,6 +289,22 @@ def qwen3_32b(cls): # qwen3-32B rope_theta=1_000_000, ) + @classmethod + def qwen3_235b_a22b(cls): # qwen3-235B-A22B + return cls( + num_layers=94, + vocab_size=151936, + embed_dim=4096, + hidden_dim=1536, + num_heads=64, + head_dim=128, + num_kv_heads=4, + norm_eps=1e-06, + rope_theta=1_000_000, + num_experts=128, + num_experts_per_tok=8, + ) + def shard(x: jnp.ndarray, s: Tuple[str, ...]): mesh = pxla.thread_resources.env.physical_mesh @@ -1073,7 +1089,9 @@ def __call__( self.config.remat_config == RematConfig.DECODER or self.config.remat_config == RematConfig.DECODER.value ): - return nnx.remat(self.block.__func__)(self, x, segment_pos, cache, attn_mask) + return nnx.remat(self.block.__func__)( + self, x, segment_pos, cache, attn_mask + ) else: return self.block(x, segment_pos, cache, attn_mask)