From 0cbcb14e2360cbb56a7a804a77c8a6bc65503c18 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Wed, 11 Feb 2026 17:00:20 -0800 Subject: [PATCH 1/3] Relax position encoding check for dynamic prompts Use 2-input full pipeline export in tests and scripts, log output shapes, and add a standalone 2-input export helper. This keeps prompt count dynamic for export artifacts while preserving benchmarks and artifact validation. --- sam3/model/position_encoding.py | 10 +- scripts/benchmark_sam3_artifacts.py | 97 ++++-------- scripts/benchmark_sam3_export_times.py | 6 +- scripts/export_sam3_artifacts.py | 24 +-- scripts/export_sam3_full_pipeline.py | 67 +++----- scripts/test_sam3_artifacts.py | 124 +++++---------- tests/export/test_decoder_export.py | 120 ++++++--------- tests/export/utils.py | 24 +++ torch_full_pipeline_export_2_input.py | 203 +++++++++++++++++++++++++ 9 files changed, 373 insertions(+), 302 deletions(-) create mode 100644 torch_full_pipeline_export_2_input.py diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py index 7419242..e112447 100644 --- a/sam3/model/position_encoding.py +++ b/sam3/model/position_encoding.py @@ -53,7 +53,7 @@ def __init__( def _encode_xy(self, x, y): # The positions are expected to be normalized - assert len(x) == len(y) and x.ndim == y.ndim == 1 + # torch._check(len(x) == len(y) and x.ndim == y.ndim == 1) x_embed = x * self.scale y_embed = y * self.scale @@ -62,12 +62,8 @@ def _encode_xy(self, x, y): pos_x = x_embed[:, None] / dim_t pos_y = y_embed[:, None] / dim_t - pos_x = torch.stack( - (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 - ).flatten(1) - pos_y = torch.stack( - (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 - ).flatten(1) + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) return pos_x, pos_y @torch.no_grad() diff --git a/scripts/benchmark_sam3_artifacts.py b/scripts/benchmark_sam3_artifacts.py index 7373cd0..f0ea143 100644 --- a/scripts/benchmark_sam3_artifacts.py +++ b/scripts/benchmark_sam3_artifacts.py @@ -34,42 +34,31 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor: def _make_inputs(model, image: torch.Tensor, prompts): device = image.device - num_prompts = len(prompts) - num_images = int(image.shape[0]) tokenizer = model.backbone.language_backbone.tokenizer token_ids = tokenizer(prompts, context_length=32).to(device) - img_ids = torch.arange(num_images, device=device, dtype=torch.long) - img_ids = img_ids.repeat_interleave(num_prompts) - text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) - text_ids = text_ids.repeat(num_images) - - box_embeddings = torch.zeros(1, num_prompts, 4, device=device) - box_mask = torch.zeros(num_prompts, 1, device=device, dtype=torch.bool) - box_labels = torch.zeros(1, num_prompts, device=device, dtype=torch.long) - return ( image, token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, ) def _run_full_model(model, inputs): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) backbone_out = model.backbone.forward_image(images) text_encoder = model.backbone.language_backbone _, text_tokens = text_encoder.encoder(token_ids) @@ -113,15 +102,20 @@ def _make_decoder_only_inputs_from_model( text_attention_mask, inputs, ): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) backbone_out = { "backbone_fpn": backbone_fpn, "vision_pos_enc": vision_pos_enc, @@ -147,9 +141,7 @@ def _make_decoder_only_inputs_from_model( prompt, prompt_mask, backbone_out = model._encode_prompt( backbone_out, find_input, geometric_prompt ) - backbone_out, encoder_out, _ = model._run_encoder( - backbone_out, find_input, prompt, prompt_mask - ) + backbone_out, encoder_out, _ = model._run_encoder(backbone_out, find_input, prompt, prompt_mask) return ( backbone_out["backbone_fpn"], img_ids, @@ -286,32 +278,7 @@ def encoder_from_outputs(image_out, text_out): def encoder_fn(): encoder_from_outputs(cached_image_out, cached_text_out) - ( - pipeline_images, - pipeline_token_ids, - pipeline_img_ids, - pipeline_text_ids, - pipeline_box_embeddings, - pipeline_box_mask, - pipeline_box_labels, - ) = inputs - if pipeline_token_ids.shape[0] < 2: - repeat = 2 // pipeline_token_ids.shape[0] - pipeline_token_ids = pipeline_token_ids.repeat(repeat, 1) - pipeline_img_ids = pipeline_img_ids.repeat(repeat) - pipeline_text_ids = pipeline_text_ids.repeat(repeat) - pipeline_box_embeddings = pipeline_box_embeddings.repeat(1, repeat, 1) - pipeline_box_mask = pipeline_box_mask.repeat(repeat, 1) - pipeline_box_labels = pipeline_box_labels.repeat(1, repeat) - pipeline_inputs = ( - pipeline_images, - pipeline_token_ids, - pipeline_img_ids, - pipeline_text_ids, - pipeline_box_embeddings, - pipeline_box_mask, - pipeline_box_labels, - ) + pipeline_inputs = inputs def pipeline_fn(): pipeline_module(*pipeline_inputs) @@ -346,9 +313,7 @@ def pipeline_fn(): decoder_prompt = decoder_prompt.repeat(1, repeat, 1) decoder_prompt_mask = decoder_prompt_mask.repeat(repeat, 1) decoder_valid_ratios = decoder_valid_ratios.repeat(repeat, 1, 1) - decoder_backbone_fpn = [ - feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn - ] + decoder_backbone_fpn = [feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn] decoder_only_inputs = ( decoder_backbone_fpn, decoder_img_ids, diff --git a/scripts/benchmark_sam3_export_times.py b/scripts/benchmark_sam3_export_times.py index 6ae5302..b45b6a8 100644 --- a/scripts/benchmark_sam3_export_times.py +++ b/scripts/benchmark_sam3_export_times.py @@ -77,7 +77,7 @@ def main() -> None: model.eval() image = _prepare_image(_load_image(args.image, device), size=1008) - inputs = _make_inputs(1, 1008, 1008, str(device), num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, str(device)) decoder_inputs = None decoder_inputs_error = None @@ -120,9 +120,7 @@ def export_encoder_fusion(): EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval() ) if args.num_feature_levels != 1: - raise RuntimeError( - "encoder_fusion export currently expects num_feature_levels=1" - ) + raise RuntimeError("encoder_fusion export currently expects num_feature_levels=1") torch.export.export( encoder_wrapper, (img_feats, img_pos, img_mask, text_memory, text_attention_mask), diff --git a/scripts/export_sam3_artifacts.py b/scripts/export_sam3_artifacts.py index 8f25647..7cc0426 100644 --- a/scripts/export_sam3_artifacts.py +++ b/scripts/export_sam3_artifacts.py @@ -39,29 +39,13 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor: def _make_inputs(model, image: torch.Tensor, prompts): device = image.device - num_prompts = len(prompts) - num_images = int(image.shape[0]) tokenizer = model.backbone.language_backbone.tokenizer token_ids = tokenizer(prompts, context_length=32).to(device) - img_ids = torch.arange(num_images, device=device, dtype=torch.long) - img_ids = img_ids.repeat_interleave(num_prompts) - text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) - text_ids = text_ids.repeat(num_images) - - box_embeddings = torch.zeros(1, num_prompts, 4, device=device) - box_mask = torch.zeros(num_prompts, 1, device=device, dtype=torch.bool) - box_labels = torch.zeros(1, num_prompts, device=device, dtype=torch.long) - return ( image, token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, ) @@ -101,9 +85,7 @@ def main() -> None: if not prompts: raise ValueError("Provide at least one prompt") - model = build_sam3_image_model( - device=args.device, eval_mode=True, enable_segmentation=True - ) + model = build_sam3_image_model(device=args.device, eval_mode=True, enable_segmentation=True) model.eval() image = _load_image(args.image, torch.device(args.device)) @@ -139,9 +121,7 @@ def main() -> None: img_pos = img_pos.repeat(prompt_batch, 1, 1, 1) img_mask = img_mask.repeat(prompt_batch, 1, 1) - encoder_wrapper = ( - EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval() - ) + encoder_wrapper = EncoderFusionWrapper(model.transformer.encoder).to(img_feats.device).eval() encoder = torch.export.export( encoder_wrapper, (img_feats, img_pos, img_mask, prompt, prompt_mask), diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index 4c6d902..9f50f3e 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -25,13 +25,22 @@ def forward( self, images: torch.Tensor, token_ids: torch.Tensor, - img_ids: torch.Tensor, - text_ids: torch.Tensor, - box_embeddings: torch.Tensor, - box_mask: torch.Tensor, - box_labels: torch.Tensor, ): model = cast(Any, self.model) + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) + backbone_out = model.backbone.forward_image(images) text_encoder = model.backbone.language_backbone _, text_tokens = text_encoder.encoder(token_ids) @@ -48,12 +57,8 @@ def forward( input_boxes=box_embeddings, input_boxes_mask=box_mask, input_boxes_label=box_labels, - input_points=torch.zeros( - 0, int(token_ids.shape[0]), 2, device=images.device - ), - input_points_mask=torch.zeros( - int(token_ids.shape[0]), 0, device=images.device, dtype=torch.bool - ), + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), ) geometric_prompt = Prompt( box_embeddings=box_embeddings, @@ -92,23 +97,10 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor: def _make_inputs(model, image: torch.Tensor, prompts): device = image.device - num_prompts = len(prompts) - num_images = int(image.shape[0]) - token_ids = model.backbone.language_backbone.tokenizer( - prompts, context_length=32 - ).to(device) - img_ids = torch.arange(num_images, device=device, dtype=torch.long) - img_ids = img_ids.repeat_interleave(num_prompts) - text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) - text_ids = text_ids.repeat(num_images) + token_ids = model.backbone.language_backbone.tokenizer(prompts, context_length=32).to(device) return ( image, token_ids, - img_ids, - text_ids, - torch.zeros(1, num_prompts, 4, device=device), - torch.zeros(num_prompts, 1, device=device, dtype=torch.bool), - torch.zeros(1, num_prompts, device=device, dtype=torch.long), ) @@ -157,9 +149,7 @@ def main() -> None: ) model.eval() - image = _prepare_image( - _load_image(args.image, torch.device(args.device)), size=1008 - ) + image = _prepare_image(_load_image(args.image, torch.device(args.device)), size=1008) inputs = _make_inputs(model, image, prompts) wrapper = FullSam3PipelineWrapper(model).to(image.device).eval() if image.shape[0] < 2: @@ -167,11 +157,6 @@ def main() -> None: export_inputs = ( image.repeat(repeat, 1, 1, 1), inputs[1].repeat(repeat, 1), - inputs[2].repeat(repeat), - inputs[3].repeat(repeat), - inputs[4].repeat(1, repeat, 1), - inputs[5].repeat(repeat, 1), - inputs[6].repeat(1, repeat), ) else: export_inputs = inputs @@ -186,23 +171,9 @@ def main() -> None: 3: 1008, }, "token_ids": { - 0: torch.export.Dim.AUTO, + 0: torch.export.Dim("num_prompts", min=1), 1: 32, }, - "img_ids": {0: torch.export.Dim.AUTO}, - "text_ids": {0: torch.export.Dim.AUTO}, - "box_embeddings": { - 0: 1, - 1: torch.export.Dim.AUTO, - }, - "box_mask": { - 0: torch.export.Dim.AUTO, - 1: 1, - }, - "box_labels": { - 0: 1, - 1: torch.export.Dim.AUTO, - }, }, strict=False, prefer_deferred_runtime_asserts_over_guards=True, diff --git a/scripts/test_sam3_artifacts.py b/scripts/test_sam3_artifacts.py index dd8348b..5d6c678 100644 --- a/scripts/test_sam3_artifacts.py +++ b/scripts/test_sam3_artifacts.py @@ -41,42 +41,31 @@ def _prepare_image(image: torch.Tensor, size: int) -> torch.Tensor: def _make_inputs(model, image: torch.Tensor, prompts): device = image.device - num_prompts = len(prompts) - num_images = int(image.shape[0]) tokenizer = model.backbone.language_backbone.tokenizer token_ids = tokenizer(prompts, context_length=32).to(device) - img_ids = torch.arange(num_images, device=device, dtype=torch.long) - img_ids = img_ids.repeat_interleave(num_prompts) - text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) - text_ids = text_ids.repeat(num_images) - - box_embeddings = torch.zeros(1, num_prompts, 4, device=device) - box_mask = torch.zeros(num_prompts, 1, device=device, dtype=torch.bool) - box_labels = torch.zeros(1, num_prompts, device=device, dtype=torch.long) - return ( image, token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, ) def _run_full_model(model, inputs): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) backbone_out = model.backbone.forward_image(images) text_encoder = model.backbone.language_backbone _, text_tokens = text_encoder.encoder(token_ids) @@ -125,15 +114,20 @@ def _make_decoder_only_inputs_from_model( text_attention_mask, inputs, ): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) backbone_out = { "backbone_fpn": backbone_fpn, "vision_pos_enc": vision_pos_enc, @@ -159,9 +153,7 @@ def _make_decoder_only_inputs_from_model( prompt, prompt_mask, backbone_out = model._encode_prompt( backbone_out, find_input, geometric_prompt ) - backbone_out, encoder_out, _ = model._run_encoder( - backbone_out, find_input, prompt, prompt_mask - ) + backbone_out, encoder_out, _ = model._run_encoder(backbone_out, find_input, prompt, prompt_mask) return ( backbone_out["backbone_fpn"], img_ids, @@ -198,9 +190,7 @@ def _color_palette(num_colors: int): return [base[i % len(base)] for i in range(num_colors)] -def _overlay_masks( - image: Image.Image, masks: torch.Tensor, scores: torch.Tensor, out_path: Path -): +def _overlay_masks(image: Image.Image, masks: torch.Tensor, scores: torch.Tensor, out_path: Path): num_prompts, num_queries = scores.shape[:2] best_idx = scores.squeeze(-1).argmax(dim=1) colors = _color_palette(num_prompts) @@ -219,9 +209,7 @@ def _overlay_masks( blended.convert("RGB").save(out_path) -def _draw_boxes( - image: Image.Image, boxes_xyxy: torch.Tensor, scores: torch.Tensor, out_path: Path -): +def _draw_boxes(image: Image.Image, boxes_xyxy: torch.Tensor, scores: torch.Tensor, out_path: Path): num_prompts, num_queries = scores.shape[:2] best_idx = scores.squeeze(-1).argmax(dim=1).clamp(max=boxes_xyxy.shape[1] - 1) colors = _color_palette(num_prompts) @@ -282,9 +270,7 @@ def main() -> None: raise ValueError("Provide at least one prompt") prompt_count = len(prompts) - model = build_sam3_image_model( - device=args.device, eval_mode=True, enable_segmentation=True - ) + model = build_sam3_image_model(device=args.device, eval_mode=True, enable_segmentation=True) model.eval() pil_image = _load_pil_image(args.image) @@ -293,9 +279,7 @@ def main() -> None: inputs = _make_inputs(model, image, prompts) with torch.no_grad(): - eager_masks, eager_boxes, eager_logits, eager_boxes_xyxy = _run_full_model( - model, inputs - ) + eager_masks, eager_boxes, eager_logits, eager_boxes_xyxy = _run_full_model(model, inputs) image_module = _load_export(args.artifact_dir / "image_encoder.pt2") text_module = _load_export(args.artifact_dir / "text_encoder.pt2") @@ -321,38 +305,10 @@ def main() -> None: device=img_feats.device, dtype=torch.bool, ) - enc_out = encoder_module( - img_feats, img_pos, img_mask, text_memory, text_attention_mask - ) + enc_out = encoder_module(img_feats, img_pos, img_mask, text_memory, text_attention_mask) assert isinstance(enc_out, tuple) - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs - if token_ids.shape[0] < 2: - repeat = 2 // token_ids.shape[0] - token_ids = token_ids.repeat(repeat, 1) - img_ids = img_ids.repeat(repeat) - text_ids = text_ids.repeat(repeat) - box_embeddings = box_embeddings.repeat(1, repeat, 1) - box_mask = box_mask.repeat(repeat, 1) - box_labels = box_labels.repeat(1, repeat) - decoder_inputs = ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) - pipeline_logits, pipeline_boxes, pipeline_masks, pipeline_boxes_xyxy = ( - pipeline_module(*decoder_inputs) + pipeline_logits, pipeline_boxes, pipeline_masks, pipeline_boxes_xyxy = pipeline_module( + *inputs ) ( decoder_backbone_fpn, @@ -381,9 +337,7 @@ def main() -> None: decoder_prompt = decoder_prompt.repeat(1, repeat, 1) decoder_prompt_mask = decoder_prompt_mask.repeat(repeat, 1) decoder_valid_ratios = decoder_valid_ratios.repeat(repeat, 1, 1) - decoder_backbone_fpn = [ - feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn - ] + decoder_backbone_fpn = [feat.repeat(repeat, 1, 1, 1) for feat in decoder_backbone_fpn] pred_logits, pred_boxes, pred_masks, pred_boxes_xyxy = decoder_module( decoder_backbone_fpn, decoder_img_ids, @@ -395,8 +349,8 @@ def main() -> None: decoder_spatial_shapes, decoder_valid_ratios, ) - eager_ref_masks, eager_ref_boxes, eager_ref_logits, eager_ref_boxes_xyxy = ( - _run_full_model(model, decoder_inputs) + eager_ref_masks, eager_ref_boxes, eager_ref_logits, eager_ref_boxes_xyxy = _run_full_model( + model, inputs ) pred_logits = pred_logits[:prompt_count] diff --git a/tests/export/test_decoder_export.py b/tests/export/test_decoder_export.py index b1d3ed8..998cf6f 100644 --- a/tests/export/test_decoder_export.py +++ b/tests/export/test_decoder_export.py @@ -7,7 +7,7 @@ from sam3.model.data_misc import FindStage from sam3.model.geometry_encoders import Prompt -from tests.export.utils import capture_stderr_on_fail, get_device +from tests.export.utils import capture_stderr_on_fail, get_device, save_output_shapes class FullSam3PipelineWrapper(torch.nn.Module): @@ -19,13 +19,22 @@ def forward( self, images: torch.Tensor, token_ids: torch.Tensor, - img_ids: torch.Tensor, - text_ids: torch.Tensor, - box_embeddings: torch.Tensor, - box_mask: torch.Tensor, - box_labels: torch.Tensor, ): model = cast(Any, self.model) + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) + backbone_out = model.backbone.forward_image(images) text_encoder = model.backbone.language_backbone _, text_tokens = text_encoder.encoder(token_ids) @@ -42,12 +51,8 @@ def forward( input_boxes=box_embeddings, input_boxes_mask=box_mask, input_boxes_label=box_labels, - input_points=torch.zeros( - 0, int(token_ids.shape[0]), 2, device=images.device - ), - input_points_mask=torch.zeros( - int(token_ids.shape[0]), 0, device=images.device, dtype=torch.bool - ), + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), ) geometric_prompt = Prompt( @@ -71,32 +76,31 @@ def forward( ) -def _make_inputs(batch: int, height: int, width: int, device: str, num_boxes: int = 0): +def _make_inputs(batch: int, height: int, width: int, device: str): images = torch.randn(batch, 3, height, width, device=device) token_ids = torch.ones(batch, 32, device=device, dtype=torch.long) token_ids[:, -1] = 0 - img_ids = torch.arange(batch, device=device, dtype=torch.long) - text_ids = torch.zeros(batch, device=device, dtype=torch.long) - box_embeddings = torch.rand(num_boxes, batch, 4, device=device) - if num_boxes > 0: - box_embeddings[..., 2:] = box_embeddings[..., 2:] * 0.5 - box_mask = torch.zeros(batch, num_boxes, device=device, dtype=torch.bool) - box_labels = torch.zeros(num_boxes, batch, device=device, dtype=torch.long) - return images, token_ids, img_ids, text_ids, box_embeddings, box_mask, box_labels + return images, token_ids def _make_decoder_only_inputs(model: Any, inputs): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs model_any = cast(Any, model) model_any = model_any.eval() + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) + backbone_out = model_any.backbone.forward_image(images) text_encoder = model_any.backbone.language_backbone _, text_tokens = text_encoder.encoder(token_ids) @@ -113,10 +117,8 @@ def _make_decoder_only_inputs(model: Any, inputs): input_boxes=box_embeddings, input_boxes_mask=box_mask, input_boxes_label=box_labels, - input_points=torch.zeros(0, int(token_ids.shape[0]), 2, device=images.device), - input_points_mask=torch.zeros( - int(token_ids.shape[0]), 0, device=images.device, dtype=torch.bool - ), + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), ) geometric_prompt = Prompt( box_embeddings=box_embeddings, @@ -144,15 +146,7 @@ def _make_decoder_only_inputs(model: Any, inputs): def _export_full_sam3_pipeline(model: Any, inputs): - ( - images, - token_ids, - img_ids, - text_ids, - box_embeddings, - box_mask, - box_labels, - ) = inputs + images, token_ids = inputs device = images.device wrapper = FullSam3PipelineWrapper(model).to(device).eval() # type: ignore[arg-type] if images.shape[0] < 2: @@ -160,11 +154,6 @@ def _export_full_sam3_pipeline(model: Any, inputs): export_inputs = ( images.repeat(repeat, 1, 1, 1), token_ids.repeat(repeat, 1), - img_ids.repeat(repeat), - text_ids.repeat(repeat), - box_embeddings.repeat(1, repeat, 1), - box_mask.repeat(repeat, 1), - box_labels.repeat(1, repeat), ) else: export_inputs = inputs @@ -179,23 +168,9 @@ def _export_full_sam3_pipeline(model: Any, inputs): 3: 1008, }, "token_ids": { - 0: torch.export.Dim.AUTO, + 0: torch.export.Dim("num_prompts", min=1), 1: 32, }, - "img_ids": {0: torch.export.Dim.AUTO}, - "text_ids": {0: torch.export.Dim.AUTO}, - "box_embeddings": { - 0: 1, - 1: torch.export.Dim.AUTO, - }, - "box_mask": { - 0: torch.export.Dim.AUTO, - 1: 1, - }, - "box_labels": { - 0: 1, - 1: torch.export.Dim.AUTO, - }, }, strict=False, prefer_deferred_runtime_asserts_over_guards=True, @@ -307,7 +282,7 @@ def _export_decoder_only(model: Any, inputs): @pytest.mark.slow def test_decoder_export_static(sam3_model): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) with capture_stderr_on_fail("export_static"): exported = _export_full_sam3_pipeline(sam3_model, inputs) assert exported is not None @@ -316,7 +291,7 @@ def test_decoder_export_static(sam3_model): @pytest.mark.slow def test_decoder_export_loads(sam3_model): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) with capture_stderr_on_fail("export_loads"): exported = _export_full_sam3_pipeline(sam3_model, inputs) module = exported.module() @@ -329,7 +304,7 @@ def test_decoder_export_loads(sam3_model): @pytest.mark.slow def test_decoder_export_matches_eager(sam3_model): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) wrapper = FullSam3PipelineWrapper(sam3_model).to(device).eval() with torch.no_grad(): eager_out = wrapper(*inputs) @@ -338,6 +313,8 @@ def test_decoder_export_matches_eager(sam3_model): module = exported.module() with torch.no_grad(): export_out = module(*inputs) + save_output_shapes("full_pipeline_eager", inputs, eager_out) + save_output_shapes("full_pipeline_export", inputs, export_out) for eager, compiled in zip(eager_out, export_out): if eager is None: assert compiled is None @@ -349,20 +326,21 @@ def test_decoder_export_matches_eager(sam3_model): @pytest.mark.parametrize("batch", [1, 2]) def test_full_sam3_pipeline_export_inference_shapes(sam3_model, batch: int): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) with capture_stderr_on_fail("export_inference_shapes"): exported = _export_full_sam3_pipeline(sam3_model, inputs) module = exported.module() - new_inputs = _make_inputs(batch, 1008, 1008, device, num_boxes=1) + new_inputs = _make_inputs(batch, 1008, 1008, device) with torch.no_grad(): out = module(*new_inputs) + save_output_shapes(f"full_pipeline_export_batch_{batch}", new_inputs, out) assert isinstance(out, tuple) @pytest.mark.slow def test_decoder_only_export_loads(sam3_model): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) decoder_inputs = _make_decoder_only_inputs(sam3_model, inputs) with capture_stderr_on_fail("export_decoder_only_loads"): exported = _export_decoder_only(sam3_model, decoder_inputs) @@ -376,7 +354,7 @@ def test_decoder_only_export_loads(sam3_model): @pytest.mark.slow def test_decoder_only_export_matches_eager(sam3_model): device = get_device() - inputs = _make_inputs(1, 1008, 1008, device, num_boxes=1) + inputs = _make_inputs(1, 1008, 1008, device) decoder_inputs = _make_decoder_only_inputs(sam3_model, inputs) wrapper = DecoderOnlyWrapper(sam3_model).to(device).eval() with torch.no_grad(): @@ -386,5 +364,7 @@ def test_decoder_only_export_matches_eager(sam3_model): module = exported.module() with torch.no_grad(): export_out = module(*decoder_inputs) + save_output_shapes("decoder_only_eager", None, eager_out) + save_output_shapes("decoder_only_export", None, export_out) for eager, compiled in zip(eager_out, export_out): torch.testing.assert_close(eager, compiled, rtol=1e-3, atol=1e-3) diff --git a/tests/export/utils.py b/tests/export/utils.py index 8bff9a3..bddf172 100644 --- a/tests/export/utils.py +++ b/tests/export/utils.py @@ -41,3 +41,27 @@ def get_device() -> str: if device is None: device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" return device + + +def save_output_shapes( + suffix: str, + inputs: tuple[torch.Tensor, ...] | None, + outputs: tuple[torch.Tensor | None, ...], +) -> None: + LOG_DIR.mkdir(parents=True, exist_ok=True) + lines: list[str] = [] + if inputs is not None: + for idx, value in enumerate(inputs): + lines.append( + f"input[{idx}] shape={tuple(value.shape)} dtype={value.dtype} device={value.device}" + ) + for idx, value in enumerate(outputs): + if value is None: + lines.append(f"output[{idx}] None") + else: + lines.append( + f"output[{idx}] shape={tuple(value.shape)} dtype={value.dtype} device={value.device}" + ) + log_path = LOG_DIR / f"{_current_test_name()}-{suffix}.shapes.txt" + with log_path.open("w", encoding="utf-8") as handle: + handle.write("\n".join(lines)) diff --git a/torch_full_pipeline_export_2_input.py b/torch_full_pipeline_export_2_input.py new file mode 100644 index 0000000..e24c220 --- /dev/null +++ b/torch_full_pipeline_export_2_input.py @@ -0,0 +1,203 @@ +import argparse +from typing import Iterable + +import torch + +from sam3.model.data_misc import FindStage +from sam3.model.geometry_encoders import Prompt +from sam3.model_builder import build_sam3_image_model + + +SAM3_INPUT_SIZE = 1008 +SAM3_CONTEXT_LENGTH = 32 + + +class TwoInputWrapper(torch.nn.Module): + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + self.model = model + + def forward( + self, + images: torch.Tensor, + token_ids: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]: + num_images = images.shape[0] + num_prompts = token_ids.shape[0] + device = images.device + bs = num_images * num_prompts + + img_ids = torch.arange(num_images, device=device, dtype=torch.long) + img_ids = img_ids.repeat_interleave(num_prompts) + text_ids = torch.arange(num_prompts, device=device, dtype=torch.long) + text_ids = text_ids.repeat(num_images) + + box_embeddings = torch.zeros(1, bs, 4, device=device) + box_mask = torch.zeros(bs, 1, device=device, dtype=torch.bool) + box_labels = torch.zeros(1, bs, device=device, dtype=torch.long) + + backbone_out = self.model.backbone.forward_image(images) + text_encoder = self.model.backbone.language_backbone + _, text_tokens = text_encoder.encoder(token_ids) + text_tokens = text_tokens.transpose(0, 1) + text_memory = text_encoder.resizer(text_tokens) + text_attention_mask = token_ids.ne(0) + text_attention_mask = text_attention_mask.ne(1) + backbone_out["language_features"] = text_memory + backbone_out["language_mask"] = text_attention_mask + + find_input = FindStage( + img_ids=img_ids, + text_ids=text_ids, + input_boxes=box_embeddings, + input_boxes_mask=box_mask, + input_boxes_label=box_labels, + input_points=torch.zeros(0, bs, 2, device=device), + input_points_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool), + ) + geometric_prompt = Prompt( + box_embeddings=box_embeddings, + box_mask=box_mask, + box_labels=box_labels, + ) + out = self.model.forward_grounding( + backbone_out=backbone_out, + find_input=find_input, + find_target=None, + geometric_prompt=geometric_prompt, + ) + return ( + out["pred_logits"], + out["pred_boxes"], + out["pred_masks"], + out.get("presence_logit_dec"), + ) + + +def _parse_int_list(value: str) -> list[int]: + items = [item.strip() for item in value.split(",") if item.strip()] + return [int(item) for item in items] + + +def _make_token_ids( + model: torch.nn.Module, prompts: Iterable[str], device: torch.device +) -> torch.Tensor: + tokenizer = model.backbone.language_backbone.tokenizer + return tokenizer(list(prompts), context_length=SAM3_CONTEXT_LENGTH).to(device) + + +def _run_export( + model: torch.nn.Module, + device: torch.device, + export_prompts: list[str], + image_batch: int, + dynamic_images: bool, +) -> torch.export.ExportedProgram: + image = torch.randn(image_batch, 3, SAM3_INPUT_SIZE, SAM3_INPUT_SIZE, device=device) + token_ids = _make_token_ids(model, export_prompts, device) + wrapper = TwoInputWrapper(model).to(device).eval() + images_dim = torch.export.Dim.AUTO if dynamic_images else image_batch + num_prompts_dim = torch.export.Dim("num_prompts", min=1) + with torch.no_grad(): + return torch.export.export( + wrapper, + (image, token_ids), + dynamic_shapes={ + "images": {0: images_dim, 2: SAM3_INPUT_SIZE, 3: SAM3_INPUT_SIZE}, + "token_ids": {0: num_prompts_dim, 1: SAM3_CONTEXT_LENGTH}, + }, + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + + +def _run_inference( + model: torch.nn.Module, + module: torch.nn.Module, + device: torch.device, + image_batch: int, + prompt_count: int, +) -> None: + prompts = [f"prompt_{idx}" for idx in range(prompt_count)] + image = torch.randn(image_batch, 3, SAM3_INPUT_SIZE, SAM3_INPUT_SIZE, device=device) + token_ids = _make_token_ids(model, prompts, device) + with torch.no_grad(): + out = module(image, token_ids) + print(f"run prompts={prompt_count} outputs={len(out)}") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="Optional path to sam3 checkpoint.", + ) + parser.add_argument( + "--image-batch", + type=int, + default=1, + help="Number of images for export and inference.", + ) + parser.add_argument( + "--export-prompts", + type=str, + default="cat,dog,building", + help="Comma-separated prompts for export sample.", + ) + parser.add_argument( + "--test-prompts", + type=str, + default="1,2,5", + help="Comma-separated prompt counts to run after export.", + ) + parser.add_argument( + "--dynamic-images", + action="store_true", + help="Export with a dynamic image batch dimension.", + ) + args = parser.parse_args() + + device = torch.device(args.device) + print("torch", torch.__version__, "cuda", torch.version.cuda, "device", device) + model = build_sam3_image_model( + device=str(device), + eval_mode=True, + enable_segmentation=True, + checkpoint_path=args.checkpoint, + load_from_HF=args.checkpoint is None, + ) + model.eval() + + export_prompts = [p.strip() for p in args.export_prompts.split(",") if p.strip()] + test_prompt_counts = _parse_int_list(args.test_prompts) + + exported = _run_export( + model=model, + device=device, + export_prompts=export_prompts, + image_batch=args.image_batch, + dynamic_images=args.dynamic_images, + ) + module = exported.module() + for prompt_count in test_prompt_counts: + try: + _run_inference( + model=model, + module=module, + device=device, + image_batch=args.image_batch, + prompt_count=prompt_count, + ) + except Exception as exc: + print(f"run prompts={prompt_count} failed: {type(exc).__name__}: {exc}") + + +if __name__ == "__main__": + main() From 83a1c62adbf1d81f58f1331d129226d10beb37ea Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Mon, 9 Feb 2026 12:13:38 -0800 Subject: [PATCH 2/3] Add minimal AOTInductor compile script Replace the argparse-based helper with a fixed-path, math-SDPA compile entrypoint. --- scripts/compile_sam3_aoti.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 scripts/compile_sam3_aoti.py diff --git a/scripts/compile_sam3_aoti.py b/scripts/compile_sam3_aoti.py new file mode 100644 index 0000000..ae7f3e8 --- /dev/null +++ b/scripts/compile_sam3_aoti.py @@ -0,0 +1,14 @@ +import argparse +import torch +import torchvision.ops # noqa: F401 + + +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) + +exported = torch.export.load("artifacts/export/full_sam3_pipeline.pt2") +torch._inductor.aoti_compile_and_package( + exported, + package_path="artifacts/aoti/full_sam3_pipeline_aoti.pt2", +) From 4d17e2b4a4680db8eeb4619c97bab1161d6d6e00 Mon Sep 17 00:00:00 2001 From: Ryan Avery Date: Tue, 10 Feb 2026 12:06:31 -0800 Subject: [PATCH 3/3] Document AOTInductor workaround attempt Apply split_reductions=False during export attempts to match workaround guidance, without changing export output. --- scripts/compile_sam3_aoti.py | 3 +++ scripts/export_sam3_full_pipeline.py | 1 + 2 files changed, 4 insertions(+) diff --git a/scripts/compile_sam3_aoti.py b/scripts/compile_sam3_aoti.py index ae7f3e8..c74375a 100644 --- a/scripts/compile_sam3_aoti.py +++ b/scripts/compile_sam3_aoti.py @@ -1,12 +1,15 @@ import argparse import torch import torchvision.ops # noqa: F401 +from torch._inductor import config as inductor_config torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(True) +inductor_config.split_reductions = False + exported = torch.export.load("artifacts/export/full_sam3_pipeline.pt2") torch._inductor.aoti_compile_and_package( exported, diff --git a/scripts/export_sam3_full_pipeline.py b/scripts/export_sam3_full_pipeline.py index 9f50f3e..f7dd501 100644 --- a/scripts/export_sam3_full_pipeline.py +++ b/scripts/export_sam3_full_pipeline.py @@ -6,6 +6,7 @@ import torch from PIL import Image + REPO_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(REPO_ROOT))