-
Notifications
You must be signed in to change notification settings - Fork 88
Feat/moe nsp blocking all models #1016
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
divytrip3005
wants to merge
16
commits into
quic:main
Choose a base branch
from
divytrip3005:feat/moe-nsp-blocking-all-models
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
80142ca
feat: NSP-blocked MoE prefill dispatch for Qwen3MOE and GPT-OSS
vbaddi a5bd93a
nit: weights re-route fixes
vbaddi c4ef4c8
nit: weights re-route fixes v1
vbaddi 290839e
nit(0423): gpt oss moe fixed and nit
vbaddi 2804851
nit(0424): ctx batch idx cast to int32
vbaddi 6b049bc
nit(0429): qwen3_moe, gpt_oss: port cumsum scatter-gather-update MoE …
vbaddi 1ae7b23
nit(0429): update modeling files
vbaddi 96df492
Fix CtxGather3D packed-chunk shape expansion
tv-karthikeya a619175
nit(0513): fix: register moe prefill 3d custom ops for subfunction ex…
vbaddi 27c0d28
fix(0415): fix: avoid unsupported prefill MoE reductions in subfuncti…
vbaddi e605480
fix(0415): align prefill MoE chunk export with packed dispatch
vbaddi c76082e
feat: NSP-blocked MoE prefill for GPT-OSS, Qwen3-MoE and GraniteMoE
3a1873b
fix: replace torch.clamp with torch.where for int32 chunk_valid_rows
0ebf018
feat: API-driven NSP blocking for Qwen3-VL-MoE, Qwen3-MoE, GPT-OSS, G…
8c54aa1
fix: add num_cores and moe_prefill_packed_chunk_size as explicit para…
1744f90
fix: moe_prefill_num_nsp API param + GraniteMoE NSP blocking fixes
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates | |
|
|
||
| # Create indices | ||
| batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) | ||
|
|
||
| # keep index tensor types aligned for backend that require exact dtype match | ||
| batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) | ||
| ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) | ||
| indices = ops.Concat(batch_idx, ctx_idx, axis=2) | ||
|
|
||
|
|
@@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates | |
| class CtxScatterFunc3D(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): | ||
| data = data.clone() | ||
| batch_idx = torch.arange(data.shape[0]).view(-1, 1) | ||
| ctx_idx = position_ids | ||
| ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) | ||
| data[batch_idx, ctx_idx] = updates | ||
| return data | ||
|
|
||
|
|
@@ -92,9 +96,80 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat | |
| return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) | ||
|
|
||
|
|
||
| class CtxScatterFunc3DGeneralized(torch.autograd.Function): | ||
| """Scatter variant that preserves ``data`` at invalid (INT32_MAX) positions. | ||
|
|
||
| Unlike :class:`CtxScatterFunc3D`, which writes updates for invalid rows to | ||
| ``data.shape[1]-1`` (potentially clobbering valid content), this version | ||
| masks out invalid rows before scattering so ``data`` is left untouched where | ||
| ``position_ids == INT32_MAX``. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): | ||
| data = data.clone() | ||
| valid = position_ids != torch.iinfo(torch.int32).max | ||
| batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) | ||
| data[batch_idx[valid], position_ids[valid].long()] = updates[valid] | ||
| return data | ||
|
|
||
| @staticmethod | ||
| def setup_context(ctx, inputs, outputs): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: | ||
| return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) | ||
|
|
||
|
|
||
| @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) | ||
| def CtxScatter3DInt( | ||
| data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32 | ||
| ) -> onnxscript.INT32: | ||
| # Find dims | ||
| batch_size = ops.Gather(ops.Shape(data), [0]) | ||
| seq_len = ops.Gather(ops.Shape(position_ids), [1]) | ||
|
|
||
| # Expanded shape to create indices | ||
| zero = ops.Constant(value_ints=[0]) | ||
| one = ops.Constant(value_ints=[1]) | ||
| exp_shape = ops.Concat(batch_size, seq_len, one, axis=0) | ||
|
|
||
| # Create indices | ||
| batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) | ||
| batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) | ||
| ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) | ||
| indices = ops.Concat(batch_idx, ctx_idx, axis=2) | ||
|
|
||
| return ops.ScatterND(data, indices, updates) | ||
|
|
||
|
|
||
| class CtxScatterFunc3DInt(torch.autograd.Function): | ||
| """Int32-typed scatter used to build a packed->original index table.""" | ||
|
|
||
| @staticmethod | ||
| def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): | ||
| data = data.clone() | ||
| valid = position_ids != torch.iinfo(torch.int32).max | ||
| batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) | ||
| data[batch_idx[valid], position_ids[valid].long()] = updates[valid] | ||
| return data | ||
|
|
||
| @staticmethod | ||
| def setup_context(ctx, inputs, outputs): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: | ||
| return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data) | ||
|
|
||
|
|
||
| @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) | ||
| def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: | ||
| ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) | ||
| batch_size = ops.Slice(ops.Shape(data), starts=[0], ends=[1], axes=[0]) | ||
| idx_seq_len = ops.Slice(ops.Shape(ctx_indices), starts=[1], ends=[2], axes=[0]) | ||
| expand_shape = ops.Concat(batch_size, idx_seq_len, axis=0) | ||
| ctx_indices = ops.Expand(ctx_indices, expand_shape) | ||
| ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) | ||
| return ops.GatherND(data, ctx_indices, batch_dims=1) | ||
|
|
||
|
|
@@ -103,6 +178,7 @@ class CtxGatherFunc3D(torch.autograd.Function): | |
| @staticmethod | ||
| def forward(data: torch.Tensor, ctx_indices: torch.Tensor): | ||
| batch_indices = torch.arange(data.shape[0]).view(-1, 1) | ||
| ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) | ||
| return data[batch_indices, ctx_indices] | ||
|
|
||
| @staticmethod | ||
|
|
@@ -114,6 +190,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor | |
| return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data) | ||
|
|
||
|
|
||
| class CtxGatherFunc3DGeneralized(torch.autograd.Function): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Let's rebase to 1.22_tmp, these changes should catch up in there. |
||
| """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0). | ||
|
|
||
| Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but | ||
| exposed as a separate autograd op so callers using the packed/cumsum scatter | ||
| pipeline can be easily recognized and so the ONNX symbolic omits | ||
| ``setTypeAs`` (needed when the caller already has a matching dtype on | ||
| ``data`` and wants the op signature to flow through without dtype pinning). | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def forward(data: torch.Tensor, ctx_indices: torch.Tensor): | ||
| batch_indices = torch.arange(data.shape[0]).view(-1, 1) | ||
| ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) | ||
| return data[batch_indices, ctx_indices] | ||
|
|
||
| @staticmethod | ||
| def setup_context(ctx, inputs, outputs): | ||
| pass | ||
|
|
||
| @staticmethod | ||
| def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: | ||
| return g.onnxscript_op(CtxGather3D, data, ctx_indices) | ||
|
|
||
|
|
||
| @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) | ||
| def CtxGather( | ||
| data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: no need of this imo, this would be same as
num_coresalready.