From 0997ad0c48a9205d971cb3c32b5553bfcf8708e2 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:22:34 -0700 Subject: [PATCH 01/10] Enable torch.export of the full SAM3 grounding pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bundle the image encoder, text encoder, encoder fusion, and decoder into a single ExportedProgram so consumers can ship one .pt2. The graph accepts dynamic batch and prompt dims; image H/W and CLIP context length are fixed. Required model patches: - sam3/model/encoder.py: fix x.dim → x.dim() (called the method object). - sam3/model/geometry_encoders.py: * Skip pin_memory + non_blocking copy under torch._dynamo.is_compiling() and on CPU; both fail under torch.export. * Convert ROIAlign input from a list-of-tensors (one per batch) to the [N, 5] batched format with batch_idx prepended. Lists of tensors are not traceable through torch.export. - sam3/model/position_encoding.py: * Skip the ndim/length assert in _encode_xy when tracing — symbolic dims from dynamic prompts trip GuardOnDataDependentSymNode. * Skip the shape-keyed cache when shapes are SymInts (looking up a SymInt in a dict of int keys raises the same guard). - sam3/model/decoder.py: * Allow 4D (bs, nheads, nq, hw) cross_attn_mask through TransformerDecoderLayer. nn.MultiheadAttention rejects 4D additive bias, so add a manual SDPA path (_cross_attn_with_rpb) that unpacks in/out projections and calls F.scaled_dot_product_attention with the per-head bias. * Forward the box-RPB matrix in 4D (drop the flatten(0, 1) that produced the (bs*nheads, nq, hw) form for MHA). * In _make_box_rpb_relative, branch is_dynamo_compiling() before the tuple-equality cache check (SymInt tuple equality trips guards) and drop the eager-only shape asserts on deltas/B that fired with dynamic dims even with the is_dynamo_compiling guard. - sam3/model_builder.py: thread num_feature_levels through build_sam3_image_model → _create_sam3_transformer/_model so callers can pin it explicitly (the production export passes 1). New: - scripts/export_sam3_full_pipeline.py — FullSam3PipelineWrapper that unifies the four sub-modules into one forward, plus a CLI to save a .pt2 with dynamic batch/prompt dims and fixed 1008x1008 / 32-token spec. - tests/export/test_full_pipeline_export.py — slow pytest that builds the real model, exports it, and runs the exported module on inputs with batch=2 and num_prompts=4 to exercise both dynamic dims. --- pyproject.toml | 1 + sam3/model/decoder.py | 124 +++++++++++----- sam3/model/encoder.py | 2 +- sam3/model/geometry_encoders.py | 19 ++- sam3/model/position_encoding.py | 15 +- sam3/model_builder.py | 21 ++- scripts/export_sam3_full_pipeline.py | 169 ++++++++++++++++++++++ tests/__init__.py | 0 tests/export/__init__.py | 0 tests/export/test_full_pipeline_export.py | 58 ++++++++ 10 files changed, 362 insertions(+), 47 deletions(-) create mode 100644 scripts/export_sam3_full_pipeline.py create mode 100644 tests/__init__.py create mode 100644 tests/export/__init__.py create mode 100644 tests/export/test_full_pipeline_export.py diff --git a/pyproject.toml b/pyproject.toml index 9df1b67..4f89532 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,3 +133,4 @@ testpaths = ["tests"] python_files = "test_*.py" python_classes = "Test*" python_functions = "test_*" +markers = ["slow: builds the full SAM3 model (heavy GPU + slow)"] diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index b90e55f..8acea03 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -156,24 +156,46 @@ def forward( tgt = tgt + self.catext_dropout(tgt2) tgt = self.catext_norm(tgt) - if presence_token is not None: - presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) - cross_attn_mask = torch.cat( - [presence_token_mask, cross_attn_mask], dim=1 - ) # (bs*nheads, 1+nq, hw) + if presence_token is not None and cross_attn_mask is not None: + # Support both 3D (bs*nheads, nq, hw) and 4D (bs, nheads, nq, hw) + # cross_attn_mask shapes. The 4D form is used by the export-friendly + # SDPA path below; nn.MultiheadAttention requires the 3D form. + if cross_attn_mask.dim() == 4: + presence_token_mask = torch.zeros_like(cross_attn_mask[:, :, :1, :]) + cross_attn_mask = torch.cat( + [presence_token_mask, cross_attn_mask], dim=2 + ) # (bs, nheads, 1+nq, hw) + else: + presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :]) + cross_attn_mask = torch.cat( + [presence_token_mask, cross_attn_mask], dim=1 + ) # (bs*nheads, 1+nq, hw) # Cross attention to image - tgt2 = self.cross_attn( - query=self.with_pos_embed(tgt, tgt_query_pos), - key=self.with_pos_embed(memory, memory_pos), - value=memory, - attn_mask=cross_attn_mask, - key_padding_mask=( - memory_key_padding_mask.transpose(0, 1) - if memory_key_padding_mask is not None - else None - ), - )[0] + key_padding_mask = ( + memory_key_padding_mask.transpose(0, 1) + if memory_key_padding_mask is not None + else None + ) + if cross_attn_mask is not None and cross_attn_mask.dim() == 4: + # nn.MultiheadAttention does not accept (bs, nheads, q, k) attention + # bias tensors. Box-RPB cross-attn passes a 4D additive bias, so + # route through a manual SDPA path that preserves the per-head bias. + tgt2 = self._cross_attn_with_rpb( + query=self.with_pos_embed(tgt, tgt_query_pos), + key=self.with_pos_embed(memory, memory_pos), + value=memory, + attn_bias=cross_attn_mask, + key_padding_mask=key_padding_mask, + ) + else: + tgt2 = self.cross_attn( + query=self.with_pos_embed(tgt, tgt_query_pos), + key=self.with_pos_embed(memory, memory_pos), + value=memory, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + )[0] tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) @@ -188,6 +210,50 @@ def forward( return tgt, presence_token_out + def _cross_attn_with_rpb( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + key_padding_mask: Optional[Tensor], + ) -> Tensor: + """ + Manual SDPA cross-attention that preserves a per-head additive bias. + Used when ``attn_bias`` is 4D (bs, nheads, q, k) — the relative box + position bias path. ``nn.MultiheadAttention`` rejects 4D masks, so we + unpack the in/out projections and call ``scaled_dot_product_attention``. + """ + mha = self.cross_attn + assert isinstance(mha, nn.MultiheadAttention) + q, k, v = torchF._in_projection_packed( + query, key, value, mha.in_proj_weight, mha.in_proj_bias + ) + tgt_len, bsz, _ = q.shape + num_heads = mha.num_heads + head_dim = mha.head_dim + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3) + k = k.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3) + v = v.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3) + src_len = k.shape[2] + bias = attn_bias + if bias.dim() == 3: + bias = bias.view(bsz, num_heads, tgt_len, src_len) + if key_padding_mask is not None: + pad = key_padding_mask[:, None, None, :].to(dtype=q.dtype) + pad = pad.masked_fill(pad > 0, float("-inf")) + bias = bias + pad + attn_output = torchF.scaled_dot_product_attention( + q, + k, + v, + attn_mask=bias, + dropout_p=mha.dropout if self.training else 0.0, + is_causal=False, + ) + attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1) + return torchF.linear(attn_output, mha.out_proj.weight, mha.out_proj.bias) + class TransformerDecoder(nn.Module): def __init__( @@ -338,11 +404,13 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device) self.compilable_stored_size = (H, W) - if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == ( - H, - W, - ): - # good, hitting the cache, will be compilable + if torch.compiler.is_dynamo_compiling(): + # When tracing, always reuse the compilable cache. The + # ``compilable_stored_size == (H, W)`` check would compare a tuple + # of SymInts against concrete ints and raise + # GuardOnDataDependentSymNode. + coords_h, coords_w = self.compilable_cord_cache + elif self.compilable_stored_size == (H, W): coords_h, coords_w = self.compilable_cord_cache else: # cache miss, will create compilation issue @@ -353,9 +421,6 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): ) coords_h, coords_w = self.coord_cache[feat_size] - assert coords_h.shape == (H,) - assert coords_w.shape == (W,) - deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2] deltas_y = deltas_y.view(bs, num_queries, -1, 2) deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2] @@ -393,20 +458,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): act_ckpt_enable=self.training and self.use_act_checkpoint, ) # bs, num_queries, H, n_heads - if not torch.compiler.is_dynamo_compiling(): - assert deltas_x.shape[:3] == (bs, num_queries, W) - assert deltas_y.shape[:3] == (bs, num_queries, H) - B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze( 2 ) # bs, num_queries, H, W, n_heads - if not torch.compiler.is_dynamo_compiling(): - assert B.shape[:4] == (bs, num_queries, H, W) B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W B = B.contiguous() # memeff attn likes ordered strides - if not torch.compiler.is_dynamo_compiling(): - assert B.shape[2:] == (num_queries, H * W) return B def forward( @@ -519,11 +576,12 @@ def forward( assert spatial_shapes.shape[0] == 1, ( "only single scale support implemented" ) + # Keep the 4D (bs, nheads, nq, H*W) shape so the export-friendly + # SDPA path in TransformerDecoderLayer can apply the bias per-head. memory_mask = self._get_rpb_matrix( reference_boxes, (spatial_shapes[0, 0], spatial_shapes[0, 1]), ) - memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W) if self.training: assert self.use_act_checkpoint, ( "Activation checkpointing not enabled in the decoder" diff --git a/sam3/model/encoder.py b/sam3/model/encoder.py index 3fc9406..49b0219 100644 --- a/sam3/model/encoder.py +++ b/sam3/model/encoder.py @@ -538,7 +538,7 @@ def forward( else None ) else: - assert all(x.dim == 4 for x in src), ( + assert all(x.dim() == 4 for x in src), ( "expected list of (bs, c, h, w) tensors" ) diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py index d60ee54..7e2e525 100644 --- a/sam3/model/geometry_encoders.py +++ b/sam3/model/geometry_encoders.py @@ -645,11 +645,26 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats): # We need to denormalize, and convert to [x, y, x, y] boxes_xyxy = box_cxcywh_to_xyxy(boxes) scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype) - scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True) + # pin_memory + non_blocking copy is unsupported under torch.export tracing. + if not torch._dynamo.is_compiling() and boxes_xyxy.device.type != "cpu": + scale = scale.pin_memory().to( + device=boxes_xyxy.device, non_blocking=True + ) + else: + scale = scale.to(device=boxes_xyxy.device) scale = scale.view(1, 1, 4) boxes_xyxy = boxes_xyxy * scale + # ROIAlign accepts a list of per-batch box tensors OR a single + # [N, 5] tensor with batch_idx prepended. The list form is hard to + # trace through torch.export, so use the batched format. + boxes_xyxy = boxes_xyxy.transpose(0, 1) # (bs, n_boxes, 4) + batch_idx = torch.arange( + bs, device=boxes_xyxy.device, dtype=boxes_xyxy.dtype + ) + batch_idx = batch_idx.view(bs, 1, 1).expand(bs, n_boxes, 1) + boxes_for_roi = torch.cat([batch_idx, boxes_xyxy], dim=-1).reshape(-1, 5) sampled = torchvision.ops.roi_align( - img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size + img_feats, boxes_for_roi.float(), self.roi_size ) assert list(sampled.shape) == [ bs * n_boxes, diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py index f38c62b..76002c0 100644 --- a/sam3/model/position_encoding.py +++ b/sam3/model/position_encoding.py @@ -58,8 +58,11 @@ def __init__( self.cache[size] = self.cache[size].clone().detach() def _encode_xy(self, x, y): - # The positions are expected to be normalized - assert len(x) == len(y) and x.ndim == y.ndim == 1 + # The positions are expected to be normalized. + # Skip the size assert when tracing — symbolic shapes from dynamic + # prompts make `len(x) == len(y)` raise GuardOnDataDependentSymNode. + if not torch._dynamo.is_compiling(): + assert len(x) == len(y) and x.ndim == y.ndim == 1 x_embed = x * self.scale y_embed = y * self.scale @@ -95,9 +98,11 @@ def encode_points(self, x, y, labels): @torch.no_grad() def forward(self, x): - cache_key = None cache_key = (x.shape[-2], x.shape[-1]) - if cache_key in self.cache: + # Skip the cache when tracing with symbolic shapes — looking up a SymInt + # key in a dict with concrete-int keys raises GuardOnDataDependentSymNode. + use_cache = all(isinstance(dim, int) for dim in cache_key) + if use_cache and cache_key in self.cache: return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) y_embed = ( torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) @@ -127,6 +132,6 @@ def forward(self, x): (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 ).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - if cache_key is not None: + if use_cache: self.cache[cache_key] = pos[0] return pos diff --git a/sam3/model_builder.py b/sam3/model_builder.py index a5dea87..c6b34a1 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -122,7 +122,9 @@ def _create_vl_backbone(vit_neck, text_encoder): return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1) -def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion: +def _create_transformer_encoder( + use_fa3: bool = False, num_feature_levels: int = 1 +) -> TransformerEncoderFusion: """Create transformer encoder with its layer.""" encoder_layer = TransformerEncoderLayer( activation="relu", @@ -153,7 +155,7 @@ def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion: layer=encoder_layer, num_layers=6, d_model=256, - num_feature_levels=1, + num_feature_levels=num_feature_levels, frozen=False, use_act_checkpoint=True, add_pooled_text_to_img_feat=False, @@ -307,6 +309,7 @@ def _create_sam3_model( dot_prod_scoring, inst_interactive_predictor, eval_mode, + num_feature_levels: int = 1, ): """Create the SAM3 image model.""" common_params = { @@ -314,7 +317,7 @@ def _create_sam3_model( "transformer": transformer, "input_geometry_encoder": input_geometry_encoder, "segmentation_head": segmentation_head, - "num_feature_levels": 1, + "num_feature_levels": num_feature_levels, "o2m_mask_predict": True, "dot_prod_scoring": dot_prod_scoring, "use_instance_query": False, @@ -527,10 +530,14 @@ def _create_vision_backbone( def _create_sam3_transformer( - has_presence_token: bool = True, use_fa3: bool = False + has_presence_token: bool = True, + use_fa3: bool = False, + num_feature_levels: int = 1, ) -> TransformerWrapper: """Create SAM3 transformer encoder and decoder.""" - encoder: TransformerEncoderFusion = _create_transformer_encoder(use_fa3=use_fa3) + encoder: TransformerEncoderFusion = _create_transformer_encoder( + use_fa3=use_fa3, num_feature_levels=num_feature_levels + ) decoder: TransformerDecoder = _create_transformer_decoder(use_fa3=use_fa3) return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256) @@ -579,6 +586,7 @@ def build_sam3_image_model( enable_segmentation=True, enable_inst_interactivity=False, compile=False, + num_feature_levels: int = 1, ): """ Build SAM3 image model @@ -613,7 +621,7 @@ def build_sam3_image_model( backbone = _create_vl_backbone(vision_encoder, text_encoder) # Create transformer components - transformer = _create_sam3_transformer() + transformer = _create_sam3_transformer(num_feature_levels=num_feature_levels) # Create dot product scoring dot_prod_scoring = _create_dot_product_scoring() @@ -641,6 +649,7 @@ def build_sam3_image_model( dot_prod_scoring, inst_predictor, eval_mode, + num_feature_levels=num_feature_levels, ) if load_from_HF and checkpoint_path is None: checkpoint_path = download_ckpt_from_hf(version="sam3") diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py new file mode 100644 index 0000000..72f9ae6 --- /dev/null +++ b/scripts/export_sam3_full_pipeline.py @@ -0,0 +1,169 @@ +""" +Export the full SAM3 image pipeline as a single ``torch.export`` program +suitable for ``torch.export.save`` / ``torch.export.load``. + +Inputs to the exported program: + images: (B, 3, 1008, 1008) float32 -- B is dynamic, H/W are fixed + token_ids: (P, 32) int64 -- P is dynamic (>= 1), L=32 fixed + +Outputs (4-tuple): + pred_logits, pred_boxes, pred_masks, presence_logit_dec + +This wrapper unifies the image encoder, text encoder, encoder fusion, and +decoder into a single graph so consumers can ship one ``.pt2`` artifact. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any, cast + +import torch +from torch.export.dynamic_shapes import Dim + +from sam3.model.data_misc import FindStage +from sam3.model.geometry_encoders import Prompt +from sam3.model_builder import build_sam3_image_model + +INPUT_SIZE = 1008 +CONTEXT_LENGTH = 32 + + +class FullSam3PipelineWrapper(torch.nn.Module): + """Single-graph wrapper over the SAM3 grounding pipeline.""" + + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + self.model = model + + def forward( + self, + images: torch.Tensor, + token_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + model = cast(Any, self.model) + device = images.device + bs = images.shape[0] * token_ids.shape[0] + + img_ids = torch.arange(images.shape[0], device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(token_ids.shape[0]) + text_ids = torch.arange(token_ids.shape[0], device=device, dtype=torch.long) + text_ids = text_ids.repeat(images.shape[0]) + + # Text-only grounding: empty geometric prompts. + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.ones(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) + + backbone_out = model.backbone.forward_image(images) + text_encoder = model.backbone.language_backbone + _, text_tokens = text_encoder.encoder(token_ids) + text_tokens = text_tokens.transpose(0, 1) + text_memory = text_encoder.resizer(text_tokens) + text_attention_mask = token_ids.ne(0).ne(1) + backbone_out["language_features"] = text_memory + backbone_out["language_mask"] = text_attention_mask + + find_input = FindStage( + img_ids=img_ids, + text_ids=text_ids, + input_boxes=box_embeddings, + input_boxes_mask=box_mask, + input_boxes_label=box_labels, + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + ) + geometric_prompt = Prompt( + box_embeddings=box_embeddings, + box_mask=box_mask, + box_labels=box_labels, + ) + out = model.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=None, + geometric_prompt=geometric_prompt, + ) + return ( + out["pred_logits"], + out["pred_boxes"], + out["pred_masks"], + out.get("presence_logit_dec"), + ) + + +def export_full_pipeline( + model: torch.nn.Module, + *, + device: torch.device, + num_export_prompts: int = 3, +) -> torch.export.ExportedProgram: + """Trace ``FullSam3PipelineWrapper`` with dynamic batch and prompt dims. + + ``num_export_prompts`` must be >= 3 so the prompt dim is treated as + dynamic during tracing (a length-2 example would let the tracer specialise + the dim away). + """ + if num_export_prompts < 3: + raise ValueError("Use >= 3 prompts so the prompt dim stays dynamic") + + wrapper = FullSam3PipelineWrapper(model).to(device).eval() + + images = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE, device=device) + token_ids = torch.zeros(num_export_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device) + token_ids[:, 0] = 49406 # <|startoftext|> so attention mask is non-empty + + num_prompts = Dim("num_prompts", min=1) + with torch.no_grad(): + return torch.export.export( + wrapper, + (images, token_ids), + dynamic_shapes={ + "images": {0: Dim.AUTO, 2: INPUT_SIZE, 3: INPUT_SIZE}, + "token_ids": {0: num_prompts, 1: CONTEXT_LENGTH}, + }, + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--out", + type=Path, + default=Path("artifacts/export/full_sam3_pipeline.pt2"), + help="Destination .pt2 path", + ) + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument( + "--num-export-prompts", + type=int, + default=3, + help="Prompts used during tracing — must be >= 3 to keep the dim dynamic", + ) + args = parser.parse_args() + + device = torch.device(args.device) + model = build_sam3_image_model( + device=str(device), + eval_mode=True, + enable_segmentation=True, + num_feature_levels=1, + ) + model.eval() + + exported = export_full_pipeline( + model, device=device, num_export_prompts=args.num_export_prompts + ) + args.out.parent.mkdir(parents=True, exist_ok=True) + torch.export.save(exported, str(args.out)) + print(f"Saved export to {args.out}") + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/export/__init__.py b/tests/export/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/export/test_full_pipeline_export.py b/tests/export/test_full_pipeline_export.py new file mode 100644 index 0000000..e1b3284 --- /dev/null +++ b/tests/export/test_full_pipeline_export.py @@ -0,0 +1,58 @@ +"""End-to-end export test for the SAM3 full pipeline. + +Marked ``slow`` because it builds the real model (heavy memory + first call +allocates a CUDA model). Run explicitly with:: + + pytest tests/export/test_full_pipeline_export.py -m slow +""" + +from __future__ import annotations + +import os + +import pytest +import torch + +from sam3.model_builder import build_sam3_image_model +from scripts.export_sam3_full_pipeline import ( + CONTEXT_LENGTH, + INPUT_SIZE, + export_full_pipeline, +) + + +def _device() -> torch.device: + if os.getenv("SAM3_EXPORT_FORCE_CPU") == "1": + return torch.device("cpu") + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +@pytest.mark.slow +def test_full_pipeline_export_traces_and_runs() -> None: + device = _device() + model = build_sam3_image_model( + device=str(device), + eval_mode=True, + enable_segmentation=True, + num_feature_levels=1, + ) + model.eval() + + exported = export_full_pipeline(model, device=device, num_export_prompts=3) + + # The graph must accept variable batch size and variable num_prompts. + eager_wrapper = exported.module() + + images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device) + token_ids = torch.zeros(4, CONTEXT_LENGTH, dtype=torch.long, device=device) + token_ids[:, 0] = 49406 + + with torch.no_grad(): + pred_logits, pred_boxes, pred_masks, presence = eager_wrapper(images, token_ids) + + bs_total = images.shape[0] * token_ids.shape[0] + assert pred_logits.shape[0] == bs_total + assert pred_boxes.shape[0] == bs_total + assert pred_masks.shape[0] == bs_total + if presence is not None: + assert presence.shape[0] == bs_total From d8252d0100733fad8281ba6da9ce56b5b1397757 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:24:38 -0700 Subject: [PATCH 02/10] =?UTF-8?q?Use=20torch.compiler.is=5Fcompiling()=20?= =?UTF-8?q?=E2=80=94=20covers=20both=20strict=20and=20non-strict=20export?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit is_dynamo_compiling() returns False during non-strict export, so the SymInt tuple-equality on compilable_stored_size still triggered GuardOnDataDependentSymNode. is_compiling() is the broader API that covers torch.compile, strict export, and non-strict export. --- sam3/model/decoder.py | 10 +++++----- sam3/model/geometry_encoders.py | 2 +- sam3/model/position_encoding.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index 8acea03..438e2e8 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -404,11 +404,11 @@ def _get_rpb_matrix(self, reference_boxes, feat_size): self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device) self.compilable_stored_size = (H, W) - if torch.compiler.is_dynamo_compiling(): - # When tracing, always reuse the compilable cache. The - # ``compilable_stored_size == (H, W)`` check would compare a tuple - # of SymInts against concrete ints and raise - # GuardOnDataDependentSymNode. + if torch.compiler.is_compiling(): + # When tracing (torch.compile or torch.export, strict or non-strict), + # always reuse the compilable cache. The ``compilable_stored_size == + # (H, W)`` check would compare a tuple of concrete ints against a + # tuple of SymInts and raise GuardOnDataDependentSymNode. coords_h, coords_w = self.compilable_cord_cache elif self.compilable_stored_size == (H, W): coords_h, coords_w = self.compilable_cord_cache diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py index 7e2e525..1c84112 100644 --- a/sam3/model/geometry_encoders.py +++ b/sam3/model/geometry_encoders.py @@ -646,7 +646,7 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats): boxes_xyxy = box_cxcywh_to_xyxy(boxes) scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype) # pin_memory + non_blocking copy is unsupported under torch.export tracing. - if not torch._dynamo.is_compiling() and boxes_xyxy.device.type != "cpu": + if not torch.compiler.is_compiling() and boxes_xyxy.device.type != "cpu": scale = scale.pin_memory().to( device=boxes_xyxy.device, non_blocking=True ) diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py index 76002c0..d9d0450 100644 --- a/sam3/model/position_encoding.py +++ b/sam3/model/position_encoding.py @@ -61,7 +61,7 @@ def _encode_xy(self, x, y): # The positions are expected to be normalized. # Skip the size assert when tracing — symbolic shapes from dynamic # prompts make `len(x) == len(y)` raise GuardOnDataDependentSymNode. - if not torch._dynamo.is_compiling(): + if not torch.compiler.is_compiling(): assert len(x) == len(y) and x.ndim == y.ndim == 1 x_embed = x * self.scale y_embed = y * self.scale From 5a5694c06a20f2aa68e939a26b5b62b57080cb15 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:26:01 -0700 Subject: [PATCH 03/10] Drop isinstance assert in _cross_attn_with_rpb cross_attn is sam3's custom MultiheadAttention (model_misc.MultiheadAttention, aliased as MultiheadAttentionWrapper), not torch.nn.MultiheadAttention. The two classes share the relevant attributes (in_proj_weight/bias, out_proj, num_heads, head_dim, dropout), so duck-type instead. --- sam3/model/decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index 438e2e8..71962fd 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -224,8 +224,10 @@ def _cross_attn_with_rpb( position bias path. ``nn.MultiheadAttention`` rejects 4D masks, so we unpack the in/out projections and call ``scaled_dot_product_attention``. """ + # Works on both nn.MultiheadAttention and sam3's custom + # MultiheadAttention (model_misc.MultiheadAttention) — both expose + # in_proj_weight/bias, out_proj, num_heads, head_dim, dropout. mha = self.cross_attn - assert isinstance(mha, nn.MultiheadAttention) q, k, v = torchF._in_projection_packed( query, key, value, mha.in_proj_weight, mha.in_proj_bias ) From a2b071ea8bab9933bbe3082851ec94ad2bd11727 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:27:05 -0700 Subject: [PATCH 04/10] Trace export with batch=2 so Dim.AUTO doesn't specialize the batch dim --- scripts/export_sam3_full_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index 72f9ae6..6f871ea 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -110,7 +110,8 @@ def export_full_pipeline( wrapper = FullSam3PipelineWrapper(model).to(device).eval() - images = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE, device=device) + # Trace with batch=2 so Dim.AUTO doesn't specialize the batch dim away. + images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device) token_ids = torch.zeros(num_export_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device) token_ids[:, 0] = 49406 # <|startoftext|> so attention mask is non-empty From ca064617fbd1a478f52fbe2155ddd340f7b33161 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:28:58 -0700 Subject: [PATCH 05/10] Use named Dim('batch', min=1) so exported model accepts batch=1 Dim.AUTO with batch=2 example specializes the batch dim to min=2, which rejects batch=1 at runtime. A named Dim with min=1 keeps the dim dynamic across the full [1, +inf) range. --- scripts/export_sam3_full_pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index 6f871ea..cfa9801 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -115,13 +115,16 @@ def export_full_pipeline( token_ids = torch.zeros(num_export_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device) token_ids[:, 0] = 49406 # <|startoftext|> so attention mask is non-empty + # Named Dim with min=1 so consumers can call with batch=1; Dim.AUTO would + # take its min from the example shape (2) and refuse batch=1 at runtime. + batch = Dim("batch", min=1) num_prompts = Dim("num_prompts", min=1) with torch.no_grad(): return torch.export.export( wrapper, (images, token_ids), dynamic_shapes={ - "images": {0: Dim.AUTO, 2: INPUT_SIZE, 3: INPUT_SIZE}, + "images": {0: batch, 2: INPUT_SIZE, 3: INPUT_SIZE}, "token_ids": {0: num_prompts, 1: CONTEXT_LENGTH}, }, strict=False, From 3736777224192a2754bf2b6ddaff4e310bdf9b5a Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:31:29 -0700 Subject: [PATCH 06/10] Cast projection weights to activation dtype in _cross_attn_with_rpb The image encoder runs under autocast and emits bf16, but the cross-attn weights stay fp32. nn.MultiheadAttention's forward casts internally; our manual SDPA path didn't, producing 'mat1 and mat2 must have the same dtype' at runtime. Cast in_proj/out_proj weights and the attn bias to match query dtype. --- sam3/model/decoder.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index 71962fd..838b0c1 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -227,9 +227,18 @@ def _cross_attn_with_rpb( # Works on both nn.MultiheadAttention and sam3's custom # MultiheadAttention (model_misc.MultiheadAttention) — both expose # in_proj_weight/bias, out_proj, num_heads, head_dim, dropout. + # Cast projection weights to the activation dtype so the call works + # under autocast (image encoder may emit bf16 while weights stay fp32); + # the upstream nn.MultiheadAttention forward does this internally. mha = self.cross_attn + in_proj_weight = mha.in_proj_weight.to(dtype=query.dtype) + in_proj_bias = ( + mha.in_proj_bias.to(dtype=query.dtype) + if mha.in_proj_bias is not None + else None + ) q, k, v = torchF._in_projection_packed( - query, key, value, mha.in_proj_weight, mha.in_proj_bias + query, key, value, in_proj_weight, in_proj_bias ) tgt_len, bsz, _ = q.shape num_heads = mha.num_heads @@ -238,7 +247,7 @@ def _cross_attn_with_rpb( k = k.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3) v = v.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3) src_len = k.shape[2] - bias = attn_bias + bias = attn_bias.to(dtype=q.dtype) if bias.dim() == 3: bias = bias.view(bsz, num_heads, tgt_len, src_len) if key_padding_mask is not None: @@ -254,7 +263,13 @@ def _cross_attn_with_rpb( is_causal=False, ) attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1) - return torchF.linear(attn_output, mha.out_proj.weight, mha.out_proj.bias) + out_w = mha.out_proj.weight.to(dtype=attn_output.dtype) + out_b = ( + mha.out_proj.bias.to(dtype=attn_output.dtype) + if mha.out_proj.bias is not None + else None + ) + return torchF.linear(attn_output, out_w, out_b) class TransformerDecoder(nn.Module): From 6cc263f3140ea37fcbeacd74b516db7c59dac0f7 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:35:01 -0700 Subject: [PATCH 07/10] Trace under bf16 autocast on CUDA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upstream vitdet MLP uses sam3.perflib.fused.addmm_act which forces bf16 internally (added in the SAM 3.1 release commit 9f22cb9). Production runs the model inside Sam3TrackingPredictor whose __init__ enters a persistent bf16 autocast — so eager production code works. Our export wrapper calls the model directly, so we have to re-enter autocast at trace time, otherwise fc2 downstream of addmm_act sees a bf16 input against fp32 weights and raises 'mat1 and mat2 must have the same dtype'. --- scripts/export_sam3_full_pipeline.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index cfa9801..4df482d 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -119,7 +119,16 @@ def export_full_pipeline( # take its min from the example shape (2) and refuse batch=1 at runtime. batch = Dim("batch", min=1) num_prompts = Dim("num_prompts", min=1) - with torch.no_grad(): + # Trace under bf16 autocast on CUDA — the ViT MLP's addmm_act path emits + # bf16, and Sam3TrackingPredictor enters bf16 autocast in __init__, so + # eager production runs already happen under autocast. Without this, fc2 + # downstream of addmm_act fails at runtime with bf16/fp32 mismatch. + autocast_ctx = ( + torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + if device.type == "cuda" + else torch.amp.autocast(device_type="cpu", enabled=False) + ) + with torch.no_grad(), autocast_ctx: return torch.export.export( wrapper, (images, token_ids), From 4073437d5dbc2a2fa21b1b97be8d1a35b4c7912b Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:37:39 -0700 Subject: [PATCH 08/10] Move bf16 autocast inside the wrapper forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tracing under outer autocast baked the autocast effects into op dtypes but missed the fp32→bf16 input boundary, causing the deserialized graph to demand bf16 input at runtime. Putting autocast inside the wrapper keeps the input/output fp32 contract while letting addmm_act's bf16 path coexist with the surrounding fp32 weights. --- scripts/export_sam3_full_pipeline.py | 87 ++++++++++++++-------------- 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index 4df482d..4920a04 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -56,40 +56,52 @@ def forward( box_mask = torch.ones(bs, 1, device=device, dtype=torch.bool) box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) - backbone_out = model.backbone.forward_image(images) - text_encoder = model.backbone.language_backbone - _, text_tokens = text_encoder.encoder(token_ids) - text_tokens = text_tokens.transpose(0, 1) - text_memory = text_encoder.resizer(text_tokens) - text_attention_mask = token_ids.ne(0).ne(1) - backbone_out["language_features"] = text_memory - backbone_out["language_mask"] = text_attention_mask - - find_input = FindStage( - img_ids=img_ids, - text_ids=text_ids, - input_boxes=box_embeddings, - input_boxes_mask=box_mask, - input_boxes_label=box_labels, - input_points=torch.zeros(0, bs, 2, device=device), - input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), - ) - geometric_prompt = Prompt( - box_embeddings=box_embeddings, - box_mask=box_mask, - box_labels=box_labels, - ) - out = model.forward_grounding( - backbone_out=backbone_out, - find_input=find_input, - find_target=None, - geometric_prompt=geometric_prompt, + # Run the actual model under bf16 autocast on CUDA. The ViT MLP uses + # sam3.perflib.fused.addmm_act which forces bf16 internally; without + # autocast around the forward, the bf16 output collides with fp32 + # weights downstream. Sam3TrackingPredictor enters this same autocast + # in __init__, so eager production runs already happen in bf16. + autocast_ctx = ( + torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + if device.type == "cuda" + else torch.amp.autocast(device_type="cpu", enabled=False) ) + with autocast_ctx: + backbone_out = model.backbone.forward_image(images) + text_encoder = model.backbone.language_backbone + _, text_tokens = text_encoder.encoder(token_ids) + text_tokens = text_tokens.transpose(0, 1) + text_memory = text_encoder.resizer(text_tokens) + text_attention_mask = token_ids.ne(0).ne(1) + backbone_out["language_features"] = text_memory + backbone_out["language_mask"] = text_attention_mask + + find_input = FindStage( + img_ids=img_ids, + text_ids=text_ids, + input_boxes=box_embeddings, + input_boxes_mask=box_mask, + input_boxes_label=box_labels, + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + ) + geometric_prompt = Prompt( + box_embeddings=box_embeddings, + box_mask=box_mask, + box_labels=box_labels, + ) + out = model.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=None, + geometric_prompt=geometric_prompt, + ) + # Cast outputs back to fp32 so downstream consumers don't have to. return ( - out["pred_logits"], - out["pred_boxes"], - out["pred_masks"], - out.get("presence_logit_dec"), + out["pred_logits"].float(), + out["pred_boxes"].float(), + out["pred_masks"].float(), + out["presence_logit_dec"].float() if out.get("presence_logit_dec") is not None else None, ) @@ -119,16 +131,7 @@ def export_full_pipeline( # take its min from the example shape (2) and refuse batch=1 at runtime. batch = Dim("batch", min=1) num_prompts = Dim("num_prompts", min=1) - # Trace under bf16 autocast on CUDA — the ViT MLP's addmm_act path emits - # bf16, and Sam3TrackingPredictor enters bf16 autocast in __init__, so - # eager production runs already happen under autocast. Without this, fc2 - # downstream of addmm_act fails at runtime with bf16/fp32 mismatch. - autocast_ctx = ( - torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) - if device.type == "cuda" - else torch.amp.autocast(device_type="cpu", enabled=False) - ) - with torch.no_grad(), autocast_ctx: + with torch.no_grad(): return torch.export.export( wrapper, (images, token_ids), From ca57718f8b171a2599c9c97c596f504a5e0e97b1 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:40:07 -0700 Subject: [PATCH 09/10] Strengthen pytest: assert eager-vs-exported equivalence + parametric dynamic shapes --- tests/export/test_full_pipeline_export.py | 54 ++++++++++++++++++----- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/export/test_full_pipeline_export.py b/tests/export/test_full_pipeline_export.py index e1b3284..e5dcba5 100644 --- a/tests/export/test_full_pipeline_export.py +++ b/tests/export/test_full_pipeline_export.py @@ -17,6 +17,7 @@ from scripts.export_sam3_full_pipeline import ( CONTEXT_LENGTH, INPUT_SIZE, + FullSam3PipelineWrapper, export_full_pipeline, ) @@ -27,8 +28,8 @@ def _device() -> torch.device: return torch.device("cuda" if torch.cuda.is_available() else "cpu") -@pytest.mark.slow -def test_full_pipeline_export_traces_and_runs() -> None: +@pytest.fixture(scope="module") +def sam3_model() -> torch.nn.Module: device = _device() model = build_sam3_image_model( device=str(device), @@ -37,22 +38,53 @@ def test_full_pipeline_export_traces_and_runs() -> None: num_feature_levels=1, ) model.eval() + return model - exported = export_full_pipeline(model, device=device, num_export_prompts=3) - - # The graph must accept variable batch size and variable num_prompts. - eager_wrapper = exported.module() +@pytest.mark.slow +def test_full_pipeline_export_matches_eager(sam3_model: torch.nn.Module) -> None: + """Exported and eager outputs must agree on the same input.""" + device = _device() + torch.manual_seed(0) images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device) - token_ids = torch.zeros(4, CONTEXT_LENGTH, dtype=torch.long, device=device) + token_ids = torch.zeros(3, CONTEXT_LENGTH, dtype=torch.long, device=device) + token_ids[:, 0] = 49406 + + wrapper = FullSam3PipelineWrapper(sam3_model).to(device).eval() + with torch.no_grad(): + eager_out = wrapper(images, token_ids) + + ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3) + with torch.no_grad(): + exported_out = ep.module()(images, token_ids) + + assert len(eager_out) == len(exported_out) == 4 + for idx, (e, x) in enumerate(zip(eager_out, exported_out)): + if e is None and x is None: + continue + assert e is not None and x is not None, f"output {idx} disagrees on None-ness" + torch.testing.assert_close(e, x, rtol=0, atol=0, msg=f"output {idx} differs") + + +@pytest.mark.slow +@pytest.mark.parametrize(("batch", "num_prompts"), [(1, 1), (1, 2), (3, 2), (2, 4)]) +def test_full_pipeline_export_supports_dynamic_shapes( + sam3_model: torch.nn.Module, batch: int, num_prompts: int +) -> None: + """The exported program must accept the dynamic batch / num_prompts grid.""" + device = _device() + ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3) + module = ep.module() + + images = torch.randn(batch, 3, INPUT_SIZE, INPUT_SIZE, device=device) + token_ids = torch.zeros(num_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device) token_ids[:, 0] = 49406 with torch.no_grad(): - pred_logits, pred_boxes, pred_masks, presence = eager_wrapper(images, token_ids) + pred_logits, pred_boxes, pred_masks, presence = module(images, token_ids) - bs_total = images.shape[0] * token_ids.shape[0] + bs_total = batch * num_prompts assert pred_logits.shape[0] == bs_total assert pred_boxes.shape[0] == bs_total assert pred_masks.shape[0] == bs_total - if presence is not None: - assert presence.shape[0] == bs_total + assert presence is not None and presence.shape[0] == bs_total From 6da0aeb6696c8a2895c70ca2f958bf3f5da30d16 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 29 Apr 2026 16:42:25 -0700 Subject: [PATCH 10/10] Add save/load roundtrip test for the .pt2 archive --- tests/export/test_full_pipeline_export.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/export/test_full_pipeline_export.py b/tests/export/test_full_pipeline_export.py index e5dcba5..e054d6f 100644 --- a/tests/export/test_full_pipeline_export.py +++ b/tests/export/test_full_pipeline_export.py @@ -66,6 +66,27 @@ def test_full_pipeline_export_matches_eager(sam3_model: torch.nn.Module) -> None torch.testing.assert_close(e, x, rtol=0, atol=0, msg=f"output {idx} differs") +@pytest.mark.slow +def test_full_pipeline_export_save_load_roundtrip( + sam3_model: torch.nn.Module, tmp_path +) -> None: + """Save the export to a .pt2, reload, and confirm it still runs.""" + import torchvision.ops # noqa: F401 -- registers roi_align before load + + device = _device() + ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3) + out_path = tmp_path / "full_sam3_pipeline.pt2" + torch.export.save(ep, str(out_path)) + + loaded = torch.export.load(str(out_path)) + images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device) + token_ids = torch.zeros(3, CONTEXT_LENGTH, dtype=torch.long, device=device) + token_ids[:, 0] = 49406 + with torch.no_grad(): + out = loaded.module()(images, token_ids) + assert out[0].shape[0] == images.shape[0] * token_ids.shape[0] + + @pytest.mark.slow @pytest.mark.parametrize(("batch", "num_prompts"), [(1, 1), (1, 2), (3, 2), (2, 4)]) def test_full_pipeline_export_supports_dynamic_shapes(