1 change: 1 addition & 0 deletions pyproject.toml
@@ -133,3 +133,4 @@ testpaths = ["tests"]
python_files = "test_*.py"
python_classes = "Test*"
python_functions = "test_*"
markers = ["slow: builds the full SAM3 model (heavy GPU + slow)"]
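The new marker lets the heavy end-to-end build tests be tagged and deselected. A minimal sketch of how a test opts in (test and file names hypothetical):

# tests/test_build.py
import pytest

@pytest.mark.slow
def test_build_full_sam3_model():
    ...  # builds the full SAM3 model (heavy GPU + slow)

Running pytest -m "not slow" then skips these tests on machines without a suitable GPU.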
141 changes: 108 additions & 33 deletions sam3/model/decoder.py
@@ -156,24 +156,46 @@ def forward(
tgt = tgt + self.catext_dropout(tgt2)
tgt = self.catext_norm(tgt)

if presence_token is not None:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=1
) # (bs*nheads, 1+nq, hw)
if presence_token is not None and cross_attn_mask is not None:
# Support both 3D (bs*nheads, nq, hw) and 4D (bs, nheads, nq, hw)
# cross_attn_mask shapes. The 4D form is used by the export-friendly
# SDPA path below; nn.MultiheadAttention requires the 3D form.
if cross_attn_mask.dim() == 4:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=2
) # (bs, nheads, 1+nq, hw)
else:
presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
cross_attn_mask = torch.cat(
[presence_token_mask, cross_attn_mask], dim=1
) # (bs*nheads, 1+nq, hw)

# Cross attention to image
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=(
memory_key_padding_mask.transpose(0, 1)
if memory_key_padding_mask is not None
else None
),
)[0]
key_padding_mask = (
memory_key_padding_mask.transpose(0, 1)
if memory_key_padding_mask is not None
else None
)
if cross_attn_mask is not None and cross_attn_mask.dim() == 4:
# nn.MultiheadAttention does not accept (bs, nheads, q, k) attention
# bias tensors. Box-RPB cross-attn passes a 4D additive bias, so
# route through a manual SDPA path that preserves the per-head bias.
tgt2 = self._cross_attn_with_rpb(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_bias=cross_attn_mask,
key_padding_mask=key_padding_mask,
)
else:
tgt2 = self.cross_attn(
query=self.with_pos_embed(tgt, tgt_query_pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=cross_attn_mask,
key_padding_mask=key_padding_mask,
)[0]

tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
@@ -188,6 +210,67 @@ def forward(

return tgt, presence_token_out

def _cross_attn_with_rpb(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_bias: Tensor,
key_padding_mask: Optional[Tensor],
) -> Tensor:
"""
Manual SDPA cross-attention that preserves a per-head additive bias.
Used when ``attn_bias`` is 4D (bs, nheads, q, k) — the relative box
position bias path. ``nn.MultiheadAttention`` rejects 4D masks, so we
unpack the in/out projections and call ``scaled_dot_product_attention``.
"""
# Works on both nn.MultiheadAttention and sam3's custom
# MultiheadAttention (model_misc.MultiheadAttention) — both expose
# in_proj_weight/bias, out_proj, num_heads, head_dim, dropout.
# Cast projection weights to the activation dtype so the call works
# under autocast (image encoder may emit bf16 while weights stay fp32);
# the upstream nn.MultiheadAttention forward does this internally.
mha = self.cross_attn
in_proj_weight = mha.in_proj_weight.to(dtype=query.dtype)
in_proj_bias = (
mha.in_proj_bias.to(dtype=query.dtype)
if mha.in_proj_bias is not None
else None
)
q, k, v = torchF._in_projection_packed(
query, key, value, in_proj_weight, in_proj_bias
)
tgt_len, bsz, _ = q.shape
num_heads = mha.num_heads
head_dim = mha.head_dim
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
k = k.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
v = v.contiguous().view(-1, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
src_len = k.shape[2]
bias = attn_bias.to(dtype=q.dtype)
if bias.dim() == 3:
bias = bias.view(bsz, num_heads, tgt_len, src_len)
if key_padding_mask is not None:
pad = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
pad = pad.masked_fill(pad > 0, float("-inf"))
bias = bias + pad
attn_output = torchF.scaled_dot_product_attention(
q,
k,
v,
attn_mask=bias,
dropout_p=mha.dropout if self.training else 0.0,
is_causal=False,
)
attn_output = attn_output.permute(2, 0, 1, 3).reshape(tgt_len, bsz, -1)
out_w = mha.out_proj.weight.to(dtype=attn_output.dtype)
out_b = (
mha.out_proj.bias.to(dtype=attn_output.dtype)
if mha.out_proj.bias is not None
else None
)
return torchF.linear(attn_output, out_w, out_b)


class TransformerDecoder(nn.Module):
def __init__(
@@ -338,11 +421,13 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
self.compilable_stored_size = (H, W)

if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
H,
W,
):
# good, hitting the cache, will be compilable
if torch.compiler.is_compiling():
# When tracing (torch.compile or torch.export, strict or non-strict),
# always reuse the compilable cache. The ``compilable_stored_size ==
# (H, W)`` check would compare a tuple of concrete ints against a
# tuple of SymInts and raise GuardOnDataDependentSymNode.
coords_h, coords_w = self.compilable_cord_cache
elif self.compilable_stored_size == (H, W):
coords_h, coords_w = self.compilable_cord_cache
else:
# cache miss, will create compilation issue
@@ -353,9 +438,6 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
)
coords_h, coords_w = self.coord_cache[feat_size]

assert coords_h.shape == (H,)
assert coords_w.shape == (W,)

deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
deltas_y = deltas_y.view(bs, num_queries, -1, 2)
deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
@@ -393,20 +475,12 @@ def _get_rpb_matrix(self, reference_boxes, feat_size):
act_ckpt_enable=self.training and self.use_act_checkpoint,
) # bs, num_queries, H, n_heads

if not torch.compiler.is_dynamo_compiling():
assert deltas_x.shape[:3] == (bs, num_queries, W)
assert deltas_y.shape[:3] == (bs, num_queries, H)

B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
2
) # bs, num_queries, H, W, n_heads
if not torch.compiler.is_dynamo_compiling():
assert B.shape[:4] == (bs, num_queries, H, W)
B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
B = B.contiguous() # mem-efficient attn likes ordered strides
if not torch.compiler.is_dynamo_compiling():
assert B.shape[2:] == (num_queries, H * W)
return B

def forward(
@@ -519,11 +593,12 @@ def forward(
assert spatial_shapes.shape[0] == 1, (
"only single scale support implemented"
)
# Keep the 4D (bs, nheads, nq, H*W) shape so the export-friendly
# SDPA path in TransformerDecoderLayer can apply the bias per-head.
memory_mask = self._get_rpb_matrix(
reference_boxes,
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
)
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
if self.training:
assert self.use_act_checkpoint, (
"Activation checkpointing not enabled in the decoder"
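The decoder change hinges on one API difference: nn.MultiheadAttention accepts only 2D (nq, hw) or 3D (bs*nheads, nq, hw) attn_mask tensors, while scaled_dot_product_attention takes a 4D (bs, nheads, nq, hw) additive bias and adds it to the pre-softmax logits of each head. A self-contained sketch of that per-head bias path (all shapes illustrative):

import torch
import torch.nn.functional as F

bs, nheads, nq, hw, head_dim = 2, 8, 10, 64, 32
q = torch.randn(bs, nheads, nq, head_dim)
k = torch.randn(bs, nheads, hw, head_dim)
v = torch.randn(bs, nheads, hw, head_dim)
bias = torch.randn(bs, nheads, nq, hw)  # e.g. the Box-RPB additive bias
# nn.MultiheadAttention would reject this 4D mask; SDPA broadcasts it onto
# q @ k.transpose(-2, -1) / sqrt(head_dim) before the softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert out.shape == (bs, nheads, nq, head_dim)

Flattening the bias to (bs*nheads, nq, hw) for nn.MultiheadAttention, as the old code did, is numerically identical; keeping it 4D simply avoids the reshape inside the traced graph.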
2 changes: 1 addition & 1 deletion sam3/model/encoder.py
@@ -538,7 +538,7 @@ def forward(
else None
)
else:
assert all(x.dim == 4 for x in src), (
assert all(x.dim() == 4 for x in src), (
"expected list of (bs, c, h, w) tensors"
)

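The one-character encoder fix matters more than it looks: without the parentheses, x.dim is a bound method, so x.dim == 4 is always False and the assert fired even on well-formed (bs, c, h, w) input. A quick illustration:

import torch

x = torch.zeros(2, 3, 8, 8)
print(x.dim == 4)    # False: compares a bound method to an int
print(x.dim() == 4)  # True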
19 changes: 17 additions & 2 deletions sam3/model/geometry_encoders.py
@@ -645,11 +645,26 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
# We need to denormalize, and convert to [x, y, x, y]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
# pin_memory + non_blocking copy is unsupported under torch.export tracing.
if not torch.compiler.is_compiling() and boxes_xyxy.device.type != "cpu":
scale = scale.pin_memory().to(
device=boxes_xyxy.device, non_blocking=True
)
else:
scale = scale.to(device=boxes_xyxy.device)
scale = scale.view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
# ROIAlign accepts a list of per-batch box tensors OR a single
# [N, 5] tensor with batch_idx prepended. The list form is hard to
# trace through torch.export, so use the batched format.
boxes_xyxy = boxes_xyxy.transpose(0, 1) # (bs, n_boxes, 4)
batch_idx = torch.arange(
bs, device=boxes_xyxy.device, dtype=boxes_xyxy.dtype
)
batch_idx = batch_idx.view(bs, 1, 1).expand(bs, n_boxes, 1)
boxes_for_roi = torch.cat([batch_idx, boxes_xyxy], dim=-1).reshape(-1, 5)
sampled = torchvision.ops.roi_align(
img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
img_feats, boxes_for_roi.float(), self.roi_size
)
assert list(sampled.shape) == [
bs * n_boxes,
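Both roi_align box formats compute the same pooling; the (bs*n_boxes, 5) form with the batch index prepended is simply friendlier to torch.export. A minimal sketch of the equivalence (shapes illustrative):

import torch
import torchvision

feats = torch.randn(2, 16, 32, 32)  # (bs, c, h, w)
boxes = torch.tensor(
    [[[0.0, 0.0, 8.0, 8.0]], [[4.0, 4.0, 16.0, 16.0]]]
)  # (bs, n_boxes, 4) in absolute xyxy coordinates
# List form: one (n_boxes, 4) tensor per image.
out_list = torchvision.ops.roi_align(feats, list(boxes.unbind(0)), output_size=7)
# Batched form: prepend the batch index and flatten to (bs * n_boxes, 5).
idx = torch.arange(2, dtype=boxes.dtype).view(2, 1, 1).expand(2, 1, 1)
flat = torch.cat([idx, boxes], dim=-1).reshape(-1, 5)
out_flat = torchvision.ops.roi_align(feats, flat, output_size=7)
assert torch.allclose(out_list, out_flat)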
15 changes: 10 additions & 5 deletions sam3/model/position_encoding.py
@@ -58,8 +58,11 @@ def __init__(
self.cache[size] = self.cache[size].clone().detach()

def _encode_xy(self, x, y):
# The positions are expected to be normalized
assert len(x) == len(y) and x.ndim == y.ndim == 1
# The positions are expected to be normalized.
# Skip the size assert when tracing — symbolic shapes from dynamic
# prompts make `len(x) == len(y)` raise GuardOnDataDependentSymNode.
if not torch.compiler.is_compiling():
assert len(x) == len(y) and x.ndim == y.ndim == 1
x_embed = x * self.scale
y_embed = y * self.scale

@@ -95,9 +98,11 @@ def encode_points(self, x, y, labels):

@torch.no_grad()
def forward(self, x):
cache_key = None
cache_key = (x.shape[-2], x.shape[-1])
if cache_key in self.cache:
# Skip the cache when tracing with symbolic shapes — looking up a SymInt
# key in a dict with concrete-int keys raises GuardOnDataDependentSymNode.
use_cache = all(isinstance(dim, int) for dim in cache_key)
if use_cache and cache_key in self.cache:
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
y_embed = (
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
@@ -127,6 +132,6 @@ def forward(self, x):
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
if cache_key is not None:
if use_cache:
self.cache[cache_key] = pos[0]
return pos
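The same pattern guards both caches touched here (decoder.py and position_encoding.py): a cache keyed on concrete shapes must be bypassed when the shapes are symbolic, since under torch.compile/torch.export x.shape[-1] can be a SymInt rather than an int. A hedged sketch of the guard, with compute_pos standing in for the embedding math:

import torch

_cache = {}

def cached_pos(x):
    key = (x.shape[-2], x.shape[-1])
    # Under tracing these may be SymInts, not ints; probing a dict of
    # concrete-int keys would guard on data-dependent values.
    concrete = all(isinstance(d, int) for d in key)
    if concrete and key in _cache:
        return _cache[key]
    pos = compute_pos(x)  # hypothetical expensive builder
    if concrete:
        _cache[key] = pos
    return pos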
21 changes: 15 additions & 6 deletions sam3/model_builder.py
@@ -122,7 +122,9 @@ def _create_vl_backbone(vit_neck, text_encoder):
return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1)


def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
def _create_transformer_encoder(
use_fa3: bool = False, num_feature_levels: int = 1
) -> TransformerEncoderFusion:
"""Create transformer encoder with its layer."""
encoder_layer = TransformerEncoderLayer(
activation="relu",
@@ -153,7 +155,7 @@ def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
layer=encoder_layer,
num_layers=6,
d_model=256,
num_feature_levels=1,
num_feature_levels=num_feature_levels,
frozen=False,
use_act_checkpoint=True,
add_pooled_text_to_img_feat=False,
@@ -307,14 +309,15 @@ def _create_sam3_model(
dot_prod_scoring,
inst_interactive_predictor,
eval_mode,
num_feature_levels: int = 1,
):
"""Create the SAM3 image model."""
common_params = {
"backbone": backbone,
"transformer": transformer,
"input_geometry_encoder": input_geometry_encoder,
"segmentation_head": segmentation_head,
"num_feature_levels": 1,
"num_feature_levels": num_feature_levels,
"o2m_mask_predict": True,
"dot_prod_scoring": dot_prod_scoring,
"use_instance_query": False,
@@ -527,10 +530,14 @@ def _create_vision_backbone(


def _create_sam3_transformer(
has_presence_token: bool = True, use_fa3: bool = False
has_presence_token: bool = True,
use_fa3: bool = False,
num_feature_levels: int = 1,
) -> TransformerWrapper:
"""Create SAM3 transformer encoder and decoder."""
encoder: TransformerEncoderFusion = _create_transformer_encoder(use_fa3=use_fa3)
encoder: TransformerEncoderFusion = _create_transformer_encoder(
use_fa3=use_fa3, num_feature_levels=num_feature_levels
)
decoder: TransformerDecoder = _create_transformer_decoder(use_fa3=use_fa3)

return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
@@ -579,6 +586,7 @@ def build_sam3_image_model(
enable_segmentation=True,
enable_inst_interactivity=False,
compile=False,
num_feature_levels: int = 1,
):
"""
Build SAM3 image model
@@ -613,7 +621,7 @@
backbone = _create_vl_backbone(vision_encoder, text_encoder)

# Create transformer components
transformer = _create_sam3_transformer()
transformer = _create_sam3_transformer(num_feature_levels=num_feature_levels)

# Create dot product scoring
dot_prod_scoring = _create_dot_product_scoring()
@@ -641,6 +649,7 @@
dot_prod_scoring,
inst_predictor,
eval_mode,
num_feature_levels=num_feature_levels,
)
if load_from_HF and checkpoint_path is None:
checkpoint_path = download_ckpt_from_hf(version="sam3")
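With the threading above, the new knob flows from the public builder through _create_sam3_transformer and _create_transformer_encoder into the model. A hypothetical call, assuming the remaining builder arguments all have defaults:

from sam3.model_builder import build_sam3_image_model

model = build_sam3_image_model(num_feature_levels=1)  # default, unchanged
multi = build_sam3_image_model(num_feature_levels=2)  # opt-in multi-scale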