AOT-Inductor compile of the full SAM3 pipeline#10
rbavery wants to merge 2 commits into
Loads the exported `.pt2` produced by export_sam3_full_pipeline.py and runs `torch._inductor.aoti_compile_and_package`. No workarounds (e.g. `split_reductions=False`); this is a baseline run to see what 2.10 / current main does.
Plain `torch._inductor.aoti_compile_and_package` + `aoti_load_package` both need an explicit `import torch._inductor.codecache` on torch 2.10; `torch.export.pt2_archive._package._load_aoti` hits an AttributeError without it. Also documents the nvidia-cuda-nvcc + nvidia-cuda-cccl install path, since torch wheels don't bundle nvcc.
Stacks on #9. Adds a one-file script that loads the exported `.pt2` and packages it with `torch._inductor.aoti_compile_and_package`.

Result
Plain AOTI compile worked on torch 2.10.0+cu128: no `split_reductions=False`, no math-SDPA forcing, no other workarounds. The crash from pytorch/pytorch#174608 didn't reproduce on this graph.

Likely why this works where #5 didn't:

- No `enable_math_sdp(True)` needed: PR #9 (Enable torch.export of the full SAM3 grounding pipeline) routes RPB cross-attn through `_cross_attn_with_rpb`, which lets the SDPA dispatcher pick efficient-attention by default and sidesteps the `triton_red_fused__safe_softmax_*` codegen path that crashed.
- Explicit `Dim("batch", min=1)` and `Dim("num_prompts", min=1)` instead of `Dim.AUTO` from a length-1 example.

Verification on RTX 3090
Dynamic shapes still work after AOTI compile
Performance (bs=2, np=3 on a 3090)
Numerical drift on real images
- truck.jpg + "truck"
- groceries.jpg + "fruit"

Top-1 boxes agree to 4 decimal places on both images. Random-Gaussian inputs produced much larger drift (max diff ~41 on masks), but real images stay in-distribution and the bf16 internals don't move scores past any reasonable confidence threshold.
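The agreement check boils down to a max-abs-diff over matched outputs; a minimal helper (the function name is mine, not from the script):

```python
import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Max elementwise absolute difference, the drift metric quoted above."""
    return (a.float() - b.float()).abs().max().item()

# Toy tensors standing in for eager vs AOTI boxes/scores
ref = torch.tensor([1.0, 2.0, 3.0])
aoti = torch.tensor([1.0, 2.5, 3.0])
assert max_abs_diff(ref, aoti) == 0.5
```

Casting to float32 before subtracting avoids bf16 rounding in the metric itself.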
Setup notes (CUDA toolkit not installed system-wide)
The torch wheel doesn't bundle `nvcc`, so on a fresh box install it from PyPI (`nvidia-cuda-nvcc` plus `nvidia-cuda-cccl`). Plus a torch-2.10 quirk in the loader: `aoti_load_package` needs an explicit `import torch._inductor.codecache` first, or `torch.export.pt2_archive._package._load_aoti` hits an AttributeError.
The script does this automatically; documented in the module docstring.
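A sketch of the fresh-box install (the `-cu12` package suffix and the pip-installed nvcc path are assumptions; match them to your torch wheel and verify locally):

```shell
# Install nvcc + CCCL headers from PyPI (cu12 variants assumed here)
pip install nvidia-cuda-nvcc-cu12 nvidia-cuda-cccl-cu12

# Point the toolchain at the pip-installed nvcc (path layout is an assumption)
export CUDA_HOME="$(python -c 'import os, nvidia.cuda_nvcc; print(os.path.dirname(nvidia.cuda_nvcc.__file__))')"
export PATH="$CUDA_HOME/bin:$PATH"
```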
Test plan
- `python scripts/compile_sam3_aoti.py --in artifacts/export/full_sam3_pipeline.pt2 --out artifacts/aoti/full_sam3_pipeline_aoti.pt2`
- `PT2ModelLoader` (it already handles `aoti_runners` first, `exported_programs` second — see model_loaders.py)
- `artifacts/aoti_compare/` for sanity (from the export-pipeline-minimal branch validation)

Follow-ups (out of scope here)
- The `.pt2` packages both `model` and `transforms` keys; this PR only AOTI-compiles `model`. `transforms` is just an `F.interpolate` so it's cheap, but a follow-up could AOTI-compile both and bundle them into one archive that drops in for the production `.pt2`.
- A test under `tests/export/` that compiles a tiny graph (not the full 3.5 GB pipeline) to keep CI honest.