Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def _forward_expert_blocked(
packed_chunk_size=self.expert_blocking_packed_chunk_size,
)

return expert_out.sum(dim=0)
return torch.einsum("nth->th", expert_out)

def forward(self, hidden_states):
residuals = hidden_states
Expand Down
5 changes: 3 additions & 2 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def forward(self, hidden: torch.Tensor):
packed_chunk_size=packed_chunk_size,
)

return expert_out.sum(dim=0).view(B, S, H), router_logits
expert_out_sum = torch.einsum("nth->th", expert_out)
return expert_out_sum.view(B, S, H), router_logits


class QEffPrefillOnlyGptOssMLP(GptOssMLP):
Expand Down Expand Up @@ -460,7 +461,7 @@ def forward_weights_as_activation(self, hidden_states):

# Apply routing weights AFTER expert computation (Llama4 applies them BEFORE)
experts_out = experts_out * router_top_value.unsqueeze(-1)
experts_out = experts_out.sum(dim=1)
experts_out = torch.einsum("bnd->bd", experts_out)

return experts_out, router_logits

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
act_fn=act_fn,
packed_chunk_size=packed_chunk_size,
)
return expert_out.sum(dim=0).view(B, S, H), router_logits
expert_out_sum = torch.einsum("nth->th", expert_out)
return expert_out_sum.view(B, S, H), router_logits


class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
mxfp6_matmul=True,
mxint8_kv_cache=True,
num_devices=1,
split_retained_state_io=True,
# split_retained_state_io=True,
mos=1,
aic_enable_depth_first=False,
user_tiled=True,
Expand Down
6 changes: 3 additions & 3 deletions examples/disagg_serving/gpt_oss_disagg_mode_with_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
subfunc_npi_file_path = os.path.join(dir_path, "subfunction_120b_npi.yaml")
non_subfunc_npi_file_path = os.path.join(dir_path, "non_subfunction_120b_npi.yaml")

model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32
model_id = "tiny-random/gpt-oss-bf16" # weights are not required to convert to fp32

prompt = """
Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
Expand All @@ -32,8 +32,8 @@
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
PREFILL_SEQ_LEN = 512
CTX_LEN = 8192
NUM_CORES = 16
CTX_LEN = 1024
NUM_CORES = 4
MOE_PREFILL_PACKED_CHUNK_SIZE = 256

qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
num_speculative_tokens=None,
prefill_only=True,
enable_chunking=True,
use_onnx_subfunctions=False,
use_onnx_subfunctions=True,
)


Expand Down
105 changes: 105 additions & 0 deletions tests/unit_test/models/test_model_quickcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,51 @@
TINY_SEQ_CLASSIFICATION_MODEL_ID = "ydshieh/tiny-random-BertForSequenceClassification"
TINY_AWQ_MODEL_ID = "optimum-intel-internal-testing/tiny-mixtral-AWQ-4bit"

# Per-architecture config overrides, keyed by model type. Each entry is
# unpacked as keyword arguments when the MoE prefill subfunction export test
# builds its config, so the values are kept deliberately small.
TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS = {
    "glm4_moe": {
        "max_position_embeddings": 128,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "hidden_size": 64,
        "intermediate_size": 128,
        "moe_intermediate_size": 32,
        "vocab_size": 127,
        "num_key_value_heads": 2,
        "n_routed_experts": 4,
        "num_experts_per_tok": 2,
        "first_k_dense_replace": 0,
        "n_group": 1,
        "topk_group": 1,
        "head_dim": 16,
    },
    "qwen3_moe": {
        "max_position_embeddings": 128,
        "num_hidden_layers": 1,
        "num_attention_heads": 4,
        "hidden_size": 64,
        "intermediate_size": 128,
        "moe_intermediate_size": 32,
        "vocab_size": 127,
        "num_key_value_heads": 2,
        "num_experts": 4,
        "num_experts_per_tok": 2,
    },
    "gpt_oss": {
        "max_position_embeddings": 128,
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "hidden_size": 32,
        "intermediate_size": 32,
        "vocab_size": 127,
        "num_key_value_heads": 2,
        "num_local_experts": 4,
        "num_experts_per_tok": 2,
        "head_dim": 16,
        "sliding_window": 128,
        "rope_scaling": None,
    },
}

MODEL_KWARGS = {"attn_implementation": "eager"}
PREFIX_CACHING_MODEL_ID = "hf-internal-testing/tiny-random-GPT2LMHeadModel"

Expand Down Expand Up @@ -202,6 +247,30 @@ def _count_decoder_block_subfunctions(onnx_model, qeff_model) -> int:
return sum(any(block_name in func.name for block_name in block_names) for func in onnx_model.functions)


def _decoder_block_subfunction_names(onnx_model, qeff_model) -> Set[str]:
    """Return the names of the ONNX functions that belong to decoder blocks.

    The block classes are obtained from
    ``qeff_model.model.get_submodules_for_export()``. An ONNX function is
    selected when the ``__name__`` of any of those classes occurs as a
    substring of the function's name.

    Args:
        onnx_model: Object exposing ``functions``, an iterable of items that
            each have a ``name`` attribute.
        qeff_model: Wrapper whose ``model`` attribute provides a callable
            ``get_submodules_for_export``.

    Returns:
        The set of matching function names (empty when nothing matches).

    Raises:
        AssertionError: If the hook is missing or not callable, returns a
            falsy value, or yields no object with a ``__name__``.
    """
    get_submodules = getattr(qeff_model.model, "get_submodules_for_export", None)
    assert callable(get_submodules), "model has no callable get_submodules_for_export()"

    submodules = get_submodules()
    assert submodules, "get_submodules_for_export() returned no submodules"

    # The hook may hand back a single class instead of a collection; normalise
    # so the comprehension below can always iterate.
    if not isinstance(submodules, (set, frozenset, list, tuple)):
        submodules = [submodules]

    block_names = {module.__name__ for module in submodules if hasattr(module, "__name__")}
    assert block_names, "none of the export submodules exposes a __name__"

    return {
        func.name
        for func in onnx_model.functions
        if any(block_name in func.name for block_name in block_names)
    }


def _function_op_types(onnx_model, function_names: Set[str]) -> Set[str]:
    """Collect the distinct op types used inside the selected ONNX functions.

    Args:
        onnx_model: Object exposing ``functions``; each function has a
            ``name`` and a ``node`` iterable whose items carry ``op_type``.
        function_names: Names of the functions whose nodes are inspected.
            Functions with any other name are ignored.

    Returns:
        The set of ``op_type`` values found across all selected functions
        (empty when no function is selected or the selected ones have no
        nodes).
    """
    return {
        node.op_type
        for function_proto in onnx_model.functions
        if function_proto.name in function_names
        for node in function_proto.node
    }


def _assert_has_retained_state_outputs(onnx_path: Path) -> None:
onnx_model = onnx.load(onnx_path, load_external_data=False)
retained_outputs = [output.name for output in onnx_model.graph.output if output.name.endswith("_RetainedState")]
Expand Down Expand Up @@ -509,6 +578,42 @@ def test_causal_subfunction_export_smoke(tmp_path):
assert not any("QEffGPT2Block" in name for name in without_names)


@pytest.mark.llm_model
@pytest.mark.parametrize(
    ("model_type", "config_kwargs"),
    sorted(TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS.items()),
    ids=sorted(TINY_MOE_PREFILL_SUBFUNCTION_CONFIGS),
)
def test_moe_prefill_subfunction_export_uses_einsum_reductions(model_type, config_kwargs, tmp_path):
    """Check the op types inside decoder-block subfunctions of a prefill export.

    For each parametrized model type, a model is built from a tiny config
    (no pretrained checkpoint is loaded) and exported in prefill-only,
    chunked mode with ONNX subfunctions enabled. The test then verifies that:

    * the number of decoder-block functions equals ``config.num_hidden_layers``;
    * those functions contain ``Einsum``, ``CtxGather3D``, ``CtxScatter3D``
      and ``CtxScatter3DInt`` nodes.
    """
    config = AutoConfig.for_model(model_type, **config_kwargs)
    model_hf = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS)
    model_hf.eval()
    qeff_model = QEFFAutoModelForCausalLM(model_hf, continuous_batching=False)

    onnx_path = _exported_onnx_path(
        qeff_model.export(
            tmp_path / f"{model_type}-prefill-subfunctions",
            prefill_only=True,
            enable_chunking=True,
            prefill_seq_len=64,
            num_cores=2,
            moe_prefill_packed_chunk_size=32,
            use_onnx_subfunctions=True,
            offload_pt_weights=False,
        )
    )

    # Only the graph structure is inspected, so external weight data is not loaded.
    onnx_model = onnx.load(onnx_path, load_external_data=False)
    decoder_function_names = _decoder_block_subfunction_names(onnx_model, qeff_model)
    decoder_op_types = _function_op_types(onnx_model, decoder_function_names)

    assert len(decoder_function_names) == config.num_hidden_layers, (
        f"expected {config.num_hidden_layers} decoder functions, got {sorted(decoder_function_names)}"
    )
    for expected_op in ("Einsum", "CtxGather3D", "CtxScatter3D", "CtxScatter3DInt"):
        assert expected_op in decoder_op_types, (
            f"{expected_op} missing from decoder subfunctions; found {sorted(decoder_op_types)}"
        )


@pytest.mark.llm_model
@pytest.mark.parametrize(
("model_type", "model_id"),
Expand Down
Loading