Make the SAM3 image pipeline exportable as a single torch.export program.

Summary of the hunks below:
  * pyproject.toml: register the `slow` pytest marker.
  * sam3/model/decoder.py: keep the box-RPB attention bias 4D and apply it
    through a manual SDPA path (`_cross_attn_with_rpb`); always reuse the
    compilable coordinate cache when `torch.compiler.is_compiling()`; drop
    the shape asserts in `_get_rpb_matrix`.
  * sam3/model/encoder.py: call `x.dim()` instead of comparing the bound
    method `x.dim` to 4.
  * sam3/model/geometry_encoders.py: only pin memory when not compiling and
    not on CPU; pass ROIAlign a single [N, 5] box tensor instead of a list.
  * sam3/model/position_encoding.py: skip the size assert while compiling and
    bypass the cache when the shape is not made of plain ints.
  * sam3/model_builder.py: thread `num_feature_levels` through the builders.
  * scripts/export_sam3_full_pipeline.py, tests/export/: new export script and
    slow end-to-end tests.

NOTE(review): this patch was restored from a copy whose newlines and
indentation had been collapsed. The per-hunk line counts match the @@ headers,
but leading whitespace (especially on context and removed lines) is inferred
from the code structure -- confirm against the target tree, or apply with
`git apply --ignore-whitespace`.

diff --git a/pyproject.toml b/pyproject.toml
index 9df1b67..4f89532 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -133,3 +133,4 @@ testpaths = ["tests"]
 python_files = "test_*.py"
 python_classes = "Test*"
 python_functions = "test_*"
+markers = ["slow: builds the full SAM3 model (heavy GPU + slow)"]
diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py
index b90e55f..838b0c1 100644
--- a/sam3/model/decoder.py
+++ b/sam3/model/decoder.py
@@ -156,24 +156,46 @@ def forward(
             tgt = tgt + self.catext_dropout(tgt2)
             tgt = self.catext_norm(tgt)
 
-        if presence_token is not None:
-            presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
-            cross_attn_mask = torch.cat(
-                [presence_token_mask, cross_attn_mask], dim=1
-            )  # (bs*nheads, 1+nq, hw)
+        if presence_token is not None and cross_attn_mask is not None:
+            # Support both 3D (bs*nheads, nq, hw) and 4D (bs, nheads, nq, hw)
+            # cross_attn_mask shapes. The 4D form is used by the export-friendly
+            # SDPA path below; nn.MultiheadAttention requires the 3D form.
+            if cross_attn_mask.dim() == 4:
+                presence_token_mask = torch.zeros_like(cross_attn_mask[:, :, :1, :])
+                cross_attn_mask = torch.cat(
+                    [presence_token_mask, cross_attn_mask], dim=2
+                )  # (bs, nheads, 1+nq, hw)
+            else:
+                presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
+                cross_attn_mask = torch.cat(
+                    [presence_token_mask, cross_attn_mask], dim=1
+                )  # (bs*nheads, 1+nq, hw)
 
         # Cross attention to image
-        tgt2 = self.cross_attn(
-            query=self.with_pos_embed(tgt, tgt_query_pos),
-            key=self.with_pos_embed(memory, memory_pos),
-            value=memory,
-            attn_mask=cross_attn_mask,
-            key_padding_mask=(
-                memory_key_padding_mask.transpose(0, 1)
-                if memory_key_padding_mask is not None
-                else None
-            ),
-        )[0]
+        key_padding_mask = (
+            memory_key_padding_mask.transpose(0, 1)
+            if memory_key_padding_mask is not None
+            else None
+        )
+        if cross_attn_mask is not None and cross_attn_mask.dim() == 4:
+            # nn.MultiheadAttention does not accept (bs, nheads, q, k) attention
+            # bias tensors. Box-RPB cross-attn passes a 4D additive bias, so
+            # route through a manual SDPA path that preserves the per-head bias.
+            tgt2 = self._cross_attn_with_rpb(
+                query=self.with_pos_embed(tgt, tgt_query_pos),
+                key=self.with_pos_embed(memory, memory_pos),
+                value=memory,
+                attn_bias=cross_attn_mask,
+                key_padding_mask=key_padding_mask,
+            )
+        else:
+            tgt2 = self.cross_attn(
+                query=self.with_pos_embed(tgt, tgt_query_pos),
+                key=self.with_pos_embed(memory, memory_pos),
+                value=memory,
+                attn_mask=cross_attn_mask,
+                key_padding_mask=key_padding_mask,
+            )[0]
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
 
@@ -188,6 +210,67 @@ def forward(
 
         return tgt, presence_token_out
 
+    def _cross_attn_with_rpb(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attn_bias: Tensor,
+        key_padding_mask: Optional[Tensor],
+    ) -> Tensor:
+        """
+        Manual SDPA cross-attention that preserves a per-head additive bias.
+        Used when ``attn_bias`` is 4D (bs, nheads, q, k) — the relative box
+        position bias path. ``nn.MultiheadAttention`` rejects 4D masks, so we
+        unpack the in/out projections and call ``scaled_dot_product_attention``.
+        """
+        # Works on both nn.MultiheadAttention and sam3's custom
+        # MultiheadAttention (model_misc.MultiheadAttention) — both expose
+        # in_proj_weight/bias, out_proj, num_heads, head_dim, dropout.
+        # Cast projection weights to the activation dtype so the call works
+        # under autocast (image encoder may emit bf16 while weights stay fp32);
+        # the upstream nn.MultiheadAttention forward does this internally.
+        mha = self.cross_attn
+        in_proj_weight = mha.in_proj_weight.to(dtype=query.dtype)
+        in_proj_bias = (
+            mha.in_proj_bias.to(dtype=query.dtype)
+            if mha.in_proj_bias is not None
+            else None
+        )
+        q, k, v = torchF._in_projection_packed(
+            query, key, value, in_proj_weight, in_proj_bias
+        )
+        tgt_len, bsz, _ = q.shape
+        num_heads = mha.num_heads
+        head_dim = mha.head_dim
+        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
+        k = k.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
+        src_len = k.shape[2]
+        bias = attn_bias.to(dtype=q.dtype)
+        if bias.dim() == 3:
+            bias = bias.view(bsz, num_heads, tgt_len, src_len)
+        if key_padding_mask is not None:
+            pad = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
+            pad = pad.masked_fill(pad > 0, float("-inf"))
+            bias = bias + pad
+        attn_output = torchF.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=bias,
+            dropout_p=mha.dropout if self.training else 0.0,
+            is_causal=False,
+        )
+        attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1)
+        out_w = mha.out_proj.weight.to(dtype=attn_output.dtype)
+        out_b = (
+            mha.out_proj.bias.to(dtype=attn_output.dtype)
+            if mha.out_proj.bias is not None
+            else None
+        )
+        return torchF.linear(attn_output, out_w, out_b)
+
 
 class TransformerDecoder(nn.Module):
     def __init__(
@@ -338,11 +421,13 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
             self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
             self.compilable_stored_size = (H, W)
 
-        if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
-            H,
-            W,
-        ):
-            # good, hitting the cache, will be compilable
+        if torch.compiler.is_compiling():
+            # When tracing (torch.compile or torch.export, strict or non-strict),
+            # always reuse the compilable cache. The ``compilable_stored_size ==
+            # (H, W)`` check would compare a tuple of concrete ints against a
+            # tuple of SymInts and raise GuardOnDataDependentSymNode.
+            coords_h, coords_w = self.compilable_cord_cache
+        elif self.compilable_stored_size == (H, W):
             coords_h, coords_w = self.compilable_cord_cache
         else:
             # cache miss, will create compilation issue
@@ -353,9 +438,6 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
                 )
             coords_h, coords_w = self.coord_cache[feat_size]
 
-        assert coords_h.shape == (H,)
-        assert coords_w.shape == (W,)
-
         deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
         deltas_y = deltas_y.view(bs, num_queries, -1, 2)
         deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
@@ -393,20 +475,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
                 act_ckpt_enable=self.training and self.use_act_checkpoint,
             )  # bs, num_queries, H, n_heads
 
-        if not torch.compiler.is_dynamo_compiling():
-            assert deltas_x.shape[:3] == (bs, num_queries, W)
-            assert deltas_y.shape[:3] == (bs, num_queries, H)
-
         B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
             2
         )  # bs, num_queries, H, W, n_heads
-        if not torch.compiler.is_dynamo_compiling():
-            assert B.shape[:4] == (bs, num_queries, H, W)
         B = B.flatten(2, 3)  # bs, num_queries, H*W, n_heads
         B = B.permute(0, 3, 1, 2)  # bs, n_heads, num_queries, H*W
         B = B.contiguous()  # memeff attn likes ordered strides
-        if not torch.compiler.is_dynamo_compiling():
-            assert B.shape[2:] == (num_queries, H * W)
         return B
 
     def forward(
@@ -519,11 +593,12 @@ def forward(
                 assert spatial_shapes.shape[0] == 1, (
                     "only single scale support implemented"
                 )
+                # Keep the 4D (bs, nheads, nq, H*W) shape so the export-friendly
+                # SDPA path in TransformerDecoderLayer can apply the bias per-head.
                 memory_mask = self._get_rpb_matrix(
                     reference_boxes,
                     (spatial_shapes[0, 0], spatial_shapes[0, 1]),
                 )
-                memory_mask = memory_mask.flatten(0, 1)  # (bs*n_heads, nq, H*W)
             if self.training:
                 assert self.use_act_checkpoint, (
                     "Activation checkpointing not enabled in the decoder"
diff --git a/sam3/model/encoder.py b/sam3/model/encoder.py
index 3fc9406..49b0219 100644
--- a/sam3/model/encoder.py
+++ b/sam3/model/encoder.py
@@ -538,7 +538,7 @@ def forward(
                     else None
                 )
             else:
-                assert all(x.dim == 4 for x in src), (
+                assert all(x.dim() == 4 for x in src), (
                     "expected list of (bs, c, h, w) tensors"
                 )
 
diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py
index d60ee54..1c84112 100644
--- a/sam3/model/geometry_encoders.py
+++ b/sam3/model/geometry_encoders.py
@@ -645,11 +645,26 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
             # We need to denormalize, and convert to [x, y, x, y]
             boxes_xyxy = box_cxcywh_to_xyxy(boxes)
             scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
-            scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
+            # pin_memory + non_blocking copy is unsupported under torch.export tracing.
+            if not torch.compiler.is_compiling() and boxes_xyxy.device.type != "cpu":
+                scale = scale.pin_memory().to(
+                    device=boxes_xyxy.device, non_blocking=True
+                )
+            else:
+                scale = scale.to(device=boxes_xyxy.device)
             scale = scale.view(1, 1, 4)
             boxes_xyxy = boxes_xyxy * scale
+            # ROIAlign accepts a list of per-batch box tensors OR a single
+            # [N, 5] tensor with batch_idx prepended. The list form is hard to
+            # trace through torch.export, so use the batched format.
+            boxes_xyxy = boxes_xyxy.transpose(0, 1)  # (bs, n_boxes, 4)
+            batch_idx = torch.arange(
+                bs, device=boxes_xyxy.device, dtype=boxes_xyxy.dtype
+            )
+            batch_idx = batch_idx.view(bs, 1, 1).expand(bs, n_boxes, 1)
+            boxes_for_roi = torch.cat([batch_idx, boxes_xyxy], dim=-1).reshape(-1, 5)
             sampled = torchvision.ops.roi_align(
-                img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
+                img_feats, boxes_for_roi.float(), self.roi_size
             )
             assert list(sampled.shape) == [
                 bs * n_boxes,
diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py
index f38c62b..d9d0450 100644
--- a/sam3/model/position_encoding.py
+++ b/sam3/model/position_encoding.py
@@ -58,8 +58,11 @@ def __init__(
                 self.cache[size] = self.cache[size].clone().detach()
 
     def _encode_xy(self, x, y):
-        # The positions are expected to be normalized
-        assert len(x) == len(y) and x.ndim == y.ndim == 1
+        # The positions are expected to be normalized.
+        # Skip the size assert when tracing — symbolic shapes from dynamic
+        # prompts make `len(x) == len(y)` raise GuardOnDataDependentSymNode.
+        if not torch.compiler.is_compiling():
+            assert len(x) == len(y) and x.ndim == y.ndim == 1
         x_embed = x * self.scale
         y_embed = y * self.scale
 
@@ -95,9 +98,11 @@ def encode_points(self, x, y, labels):
 
     @torch.no_grad()
     def forward(self, x):
-        cache_key = None
         cache_key = (x.shape[-2], x.shape[-1])
-        if cache_key in self.cache:
+        # Skip the cache when tracing with symbolic shapes — looking up a SymInt
+        # key in a dict with concrete-int keys raises GuardOnDataDependentSymNode.
+        use_cache = all(isinstance(dim, int) for dim in cache_key)
+        if use_cache and cache_key in self.cache:
             return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
         y_embed = (
             torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
@@ -127,6 +132,6 @@ def forward(self, x):
             (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
         ).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        if cache_key is not None:
+        if use_cache:
             self.cache[cache_key] = pos[0]
         return pos
diff --git a/sam3/model_builder.py b/sam3/model_builder.py
index a5dea87..c6b34a1 100644
--- a/sam3/model_builder.py
+++ b/sam3/model_builder.py
@@ -122,7 +122,9 @@ def _create_vl_backbone(vit_neck, text_encoder):
     return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1)
 
 
-def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
+def _create_transformer_encoder(
+    use_fa3: bool = False, num_feature_levels: int = 1
+) -> TransformerEncoderFusion:
     """Create transformer encoder with its layer."""
     encoder_layer = TransformerEncoderLayer(
         activation="relu",
@@ -153,7 +155,7 @@ def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
         layer=encoder_layer,
         num_layers=6,
         d_model=256,
-        num_feature_levels=1,
+        num_feature_levels=num_feature_levels,
         frozen=False,
         use_act_checkpoint=True,
         add_pooled_text_to_img_feat=False,
@@ -307,6 +309,7 @@ def _create_sam3_model(
     dot_prod_scoring,
     inst_interactive_predictor,
     eval_mode,
+    num_feature_levels: int = 1,
 ):
     """Create the SAM3 image model."""
     common_params = {
@@ -314,7 +317,7 @@ def _create_sam3_model(
         "transformer": transformer,
         "input_geometry_encoder": input_geometry_encoder,
         "segmentation_head": segmentation_head,
-        "num_feature_levels": 1,
+        "num_feature_levels": num_feature_levels,
         "o2m_mask_predict": True,
         "dot_prod_scoring": dot_prod_scoring,
         "use_instance_query": False,
@@ -527,10 +530,14 @@ def _create_vision_backbone(
 
 
 def _create_sam3_transformer(
-    has_presence_token: bool = True, use_fa3: bool = False
+    has_presence_token: bool = True,
+    use_fa3: bool = False,
+    num_feature_levels: int = 1,
 ) -> TransformerWrapper:
     """Create SAM3 transformer encoder and decoder."""
-    encoder: TransformerEncoderFusion = _create_transformer_encoder(use_fa3=use_fa3)
+    encoder: TransformerEncoderFusion = _create_transformer_encoder(
+        use_fa3=use_fa3, num_feature_levels=num_feature_levels
+    )
     decoder: TransformerDecoder = _create_transformer_decoder(use_fa3=use_fa3)
     return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
 
@@ -579,6 +586,7 @@ def build_sam3_image_model(
     enable_segmentation=True,
     enable_inst_interactivity=False,
     compile=False,
+    num_feature_levels: int = 1,
 ):
     """
     Build SAM3 image model
@@ -613,7 +621,7 @@ def build_sam3_image_model(
     backbone = _create_vl_backbone(vision_encoder, text_encoder)
 
     # Create transformer components
-    transformer = _create_sam3_transformer()
+    transformer = _create_sam3_transformer(num_feature_levels=num_feature_levels)
 
     # Create dot product scoring
     dot_prod_scoring = _create_dot_product_scoring()
@@ -641,6 +649,7 @@ def build_sam3_image_model(
         dot_prod_scoring,
         inst_predictor,
         eval_mode,
+        num_feature_levels=num_feature_levels,
     )
     if load_from_HF and checkpoint_path is None:
         checkpoint_path = download_ckpt_from_hf(version="sam3")
diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py
new file mode 100644
index 0000000..4920a04
--- /dev/null
+++ b/scripts/export_sam3_full_pipeline.py
@@ -0,0 +1,185 @@
+"""
+Export the full SAM3 image pipeline as a single ``torch.export`` program
+suitable for ``torch.export.save`` / ``torch.export.load``.
+
+Inputs to the exported program:
+    images: (B, 3, 1008, 1008) float32 -- B is dynamic, H/W are fixed
+    token_ids: (P, 32) int64 -- P is dynamic (>= 1), L=32 fixed
+
+Outputs (4-tuple):
+    pred_logits, pred_boxes, pred_masks, presence_logit_dec
+
+This wrapper unifies the image encoder, text encoder, encoder fusion, and
+decoder into a single graph so consumers can ship one ``.pt2`` artifact.
+"""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from typing import Any, cast
+
+import torch
+from torch.export.dynamic_shapes import Dim
+
+from sam3.model.data_misc import FindStage
+from sam3.model.geometry_encoders import Prompt
+from sam3.model_builder import build_sam3_image_model
+
+INPUT_SIZE = 1008
+CONTEXT_LENGTH = 32
+
+
+class FullSam3PipelineWrapper(torch.nn.Module):
+    """Single-graph wrapper over the SAM3 grounding pipeline."""
+
+    def __init__(self, model: torch.nn.Module) -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        images: torch.Tensor,
+        token_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        model = cast(Any, self.model)
+        device = images.device
+        bs = images.shape[0] * token_ids.shape[0]
+
+        img_ids = torch.arange(images.shape[0], device=device, dtype=torch.long)
+        img_ids = img_ids.repeat_interleave(token_ids.shape[0])
+        text_ids = torch.arange(token_ids.shape[0], device=device, dtype=torch.long)
+        text_ids = text_ids.repeat(images.shape[0])
+
+        # Text-only grounding: empty geometric prompts.
+        box_embeddings = torch.zeros(1, bs, 4, device=device)
+        box_mask = torch.ones(bs, 1, device=device, dtype=torch.bool)
+        box_labels = torch.zeros(1, bs, device=device, dtype=torch.long)
+
+        # Run the actual model under bf16 autocast on CUDA. The ViT MLP uses
+        # sam3.perflib.fused.addmm_act which forces bf16 internally; without
+        # autocast around the forward, the bf16 output collides with fp32
+        # weights downstream. Sam3TrackingPredictor enters this same autocast
+        # in __init__, so eager production runs already happen in bf16.
+        autocast_ctx = (
+            torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+            if device.type == "cuda"
+            else torch.amp.autocast(device_type="cpu", enabled=False)
+        )
+        with autocast_ctx:
+            backbone_out = model.backbone.forward_image(images)
+            text_encoder = model.backbone.language_backbone
+            _, text_tokens = text_encoder.encoder(token_ids)
+            text_tokens = text_tokens.transpose(0, 1)
+            text_memory = text_encoder.resizer(text_tokens)
+            text_attention_mask = token_ids.ne(0).ne(1)
+            backbone_out["language_features"] = text_memory
+            backbone_out["language_mask"] = text_attention_mask
+
+            find_input = FindStage(
+                img_ids=img_ids,
+                text_ids=text_ids,
+                input_boxes=box_embeddings,
+                input_boxes_mask=box_mask,
+                input_boxes_label=box_labels,
+                input_points=torch.zeros(0, bs, 2, device=device),
+                input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
+            )
+            geometric_prompt = Prompt(
+                box_embeddings=box_embeddings,
+                box_mask=box_mask,
+                box_labels=box_labels,
+            )
+            out = model.forward_grounding(
+                backbone_out=backbone_out,
+                find_input=find_input,
+                find_target=None,
+                geometric_prompt=geometric_prompt,
+            )
+        # Cast outputs back to fp32 so downstream consumers don't have to.
+        return (
+            out["pred_logits"].float(),
+            out["pred_boxes"].float(),
+            out["pred_masks"].float(),
+            out["presence_logit_dec"].float() if out.get("presence_logit_dec") is not None else None,
+        )
+
+
+def export_full_pipeline(
+    model: torch.nn.Module,
+    *,
+    device: torch.device,
+    num_export_prompts: int = 3,
+) -> torch.export.ExportedProgram:
+    """Trace ``FullSam3PipelineWrapper`` with dynamic batch and prompt dims.
+
+    ``num_export_prompts`` must be >= 3 so the prompt dim is treated as
+    dynamic during tracing (a length-2 example would let the tracer specialise
+    the dim away).
+    """
+    if num_export_prompts < 3:
+        raise ValueError("Use >= 3 prompts so the prompt dim stays dynamic")
+
+    wrapper = FullSam3PipelineWrapper(model).to(device).eval()
+
+    # Trace with batch=2: torch.export specializes dims whose example size is 0/1.
+    images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device)
+    token_ids = torch.zeros(num_export_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device)
+    token_ids[:, 0] = 49406  # <|startoftext|> so attention mask is non-empty
+
+    # Named Dim with min=1 so consumers can call with batch=1; Dim.AUTO would
+    # take its min from the example shape (2) and refuse batch=1 at runtime.
+    batch = Dim("batch", min=1)
+    num_prompts = Dim("num_prompts", min=1)
+    with torch.no_grad():
+        return torch.export.export(
+            wrapper,
+            (images, token_ids),
+            dynamic_shapes={
+                "images": {0: batch, 2: INPUT_SIZE, 3: INPUT_SIZE},
+                "token_ids": {0: num_prompts, 1: CONTEXT_LENGTH},
+            },
+            strict=False,
+            prefer_deferred_runtime_asserts_over_guards=True,
+        )
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--out",
+        type=Path,
+        default=Path("artifacts/export/full_sam3_pipeline.pt2"),
+        help="Destination .pt2 path",
+    )
+    parser.add_argument(
+        "--device",
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    parser.add_argument(
+        "--num-export-prompts",
+        type=int,
+        default=3,
+        help="Prompts used during tracing — must be >= 3 to keep the dim dynamic",
+    )
+    args = parser.parse_args()
+
+    device = torch.device(args.device)
+    model = build_sam3_image_model(
+        device=str(device),
+        eval_mode=True,
+        enable_segmentation=True,
+        num_feature_levels=1,
+    )
+    model.eval()
+
+    exported = export_full_pipeline(
+        model, device=device, num_export_prompts=args.num_export_prompts
+    )
+    args.out.parent.mkdir(parents=True, exist_ok=True)
+    torch.export.save(exported, str(args.out))
+    print(f"Saved export to {args.out}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/export/__init__.py b/tests/export/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/export/test_full_pipeline_export.py b/tests/export/test_full_pipeline_export.py
new file mode 100644
index 0000000..e054d6f
--- /dev/null
+++ b/tests/export/test_full_pipeline_export.py
@@ -0,0 +1,111 @@
+"""End-to-end export test for the SAM3 full pipeline.
+
+Marked ``slow`` because it builds the real model (heavy memory + first call
+allocates a CUDA model). Run explicitly with::
+
+    pytest tests/export/test_full_pipeline_export.py -m slow
+"""
+
+from __future__ import annotations
+
+import os
+
+import pytest
+import torch
+
+from sam3.model_builder import build_sam3_image_model
+from scripts.export_sam3_full_pipeline import (
+    CONTEXT_LENGTH,
+    INPUT_SIZE,
+    FullSam3PipelineWrapper,
+    export_full_pipeline,
+)
+
+
+def _device() -> torch.device:
+    if os.getenv("SAM3_EXPORT_FORCE_CPU") == "1":
+        return torch.device("cpu")
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.fixture(scope="module")
+def sam3_model() -> torch.nn.Module:
+    device = _device()
+    model = build_sam3_image_model(
+        device=str(device),
+        eval_mode=True,
+        enable_segmentation=True,
+        num_feature_levels=1,
+    )
+    model.eval()
+    return model
+
+
+@pytest.mark.slow
+def test_full_pipeline_export_matches_eager(sam3_model: torch.nn.Module) -> None:
+    """Exported and eager outputs must agree on the same input."""
+    device = _device()
+    torch.manual_seed(0)
+    images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device)
+    token_ids = torch.zeros(3, CONTEXT_LENGTH, dtype=torch.long, device=device)
+    token_ids[:, 0] = 49406
+
+    wrapper = FullSam3PipelineWrapper(sam3_model).to(device).eval()
+    with torch.no_grad():
+        eager_out = wrapper(images, token_ids)
+
+    ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3)
+    with torch.no_grad():
+        exported_out = ep.module()(images, token_ids)
+
+    assert len(eager_out) == len(exported_out) == 4
+    for idx, (e, x) in enumerate(zip(eager_out, exported_out)):
+        if e is None and x is None:
+            continue
+        assert e is not None and x is not None, f"output {idx} disagrees on None-ness"
+        torch.testing.assert_close(e, x, rtol=0, atol=0, msg=f"output {idx} differs")
+
+
+@pytest.mark.slow
+def test_full_pipeline_export_save_load_roundtrip(
+    sam3_model: torch.nn.Module, tmp_path
+) -> None:
+    """Save the export to a .pt2, reload, and confirm it still runs."""
+    import torchvision.ops  # noqa: F401 -- registers roi_align before load
+
+    device = _device()
+    ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3)
+    out_path = tmp_path / "full_sam3_pipeline.pt2"
+    torch.export.save(ep, str(out_path))
+
+    loaded = torch.export.load(str(out_path))
+    images = torch.randn(2, 3, INPUT_SIZE, INPUT_SIZE, device=device)
+    token_ids = torch.zeros(3, CONTEXT_LENGTH, dtype=torch.long, device=device)
+    token_ids[:, 0] = 49406
+    with torch.no_grad():
+        out = loaded.module()(images, token_ids)
+    assert out[0].shape[0] == images.shape[0] * token_ids.shape[0]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(("batch", "num_prompts"), [(1, 1), (1, 2), (3, 2), (2, 4)])
+def test_full_pipeline_export_supports_dynamic_shapes(
+    sam3_model: torch.nn.Module, batch: int, num_prompts: int
+) -> None:
+    """The exported program must accept the dynamic batch / num_prompts grid."""
+    device = _device()
+    ep = export_full_pipeline(sam3_model, device=device, num_export_prompts=3)
+    module = ep.module()
+
+    images = torch.randn(batch, 3, INPUT_SIZE, INPUT_SIZE, device=device)
+    token_ids = torch.zeros(num_prompts, CONTEXT_LENGTH, dtype=torch.long, device=device)
+    token_ids[:, 0] = 49406
+
+    with torch.no_grad():
+        pred_logits, pred_boxes, pred_masks, presence = module(images, token_ids)
+
+    bs_total = batch * num_prompts
+    assert pred_logits.shape[0] == bs_total
+    assert pred_boxes.shape[0] == bs_total
+    assert pred_masks.shape[0] == bs_total
+    assert presence is not None and presence.shape[0] == bs_total