Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions doctr/models/layout/lw_detr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,28 +144,30 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int

for b in range(boxes.shape[0]):
# Convert logits to probabilities and get scores and labels
exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True))
prob = exp / exp.sum(axis=-1, keepdims=True)
prob = 1.0 / (1.0 + np.exp(-logits[b]))

# Remove background class
prob_fg = prob[:, :-1]

prob_fg = prob[:, :-1] # exclude background
scores = prob_fg.max(axis=-1)
labels = prob_fg.argmax(axis=-1)

# Keep only topk predictions before NMS
if self.topk is not None and len(scores) > self.topk:
idxs = np.argsort(scores)[::-1][: self.topk]
idxs = np.argpartition(-scores, self.topk)[: self.topk]
idxs = idxs[np.argsort(-scores[idxs])]
else:
idxs = np.arange(len(scores))

scores_b = scores[idxs]
labels_b = labels[idxs]
bboxes = boxes[b][idxs]

mask = scores_b > self.score_thresh

bboxes = bboxes[mask]
scores_b = scores_b[mask]
labels_b = labels_b[mask]
# Filter by score threshold
thresh_mask = scores_b >= self.score_thresh
scores_b = scores_b[thresh_mask]
labels_b = labels_b[thresh_mask]
bboxes = bboxes[thresh_mask]

polys, _ = (
self._decode_boxes(bboxes)
Expand Down
87 changes: 46 additions & 41 deletions doctr/models/layout/lw_detr/layers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
Expand Down Expand Up @@ -293,6 +293,8 @@
else:
raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}")

# clamp sampling locations to keep them within the valid range [0, 1] for grid sampling
sampling_locations = sampling_locations.clamp(0.0, 1.0)
output = self.attn(
value,
spatial_shapes_list,
Expand Down Expand Up @@ -409,26 +411,28 @@
"""
scale = 2 * math.pi
dim = hidden_size // 2
# Keep dim_t in float32 for numerical precision; cast output to match caller dtype
dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
x_embed = pos_tensor[:, :, 0].float() * scale
y_embed = pos_tensor[:, :, 1].float() * scale
pos_x = x_embed[:, :, None] / dim_t
pos_y = y_embed[:, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
if pos_tensor.size(-1) == 4:
w_embed = pos_tensor[:, :, 2] * scale
w_embed = pos_tensor[:, :, 2].float() * scale
pos_w = w_embed[:, :, None] / dim_t
pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)

h_embed = pos_tensor[:, :, 3] * scale
h_embed = pos_tensor[:, :, 3].float() * scale
pos_h = h_embed[:, :, None] / dim_t
pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)

pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
else:
raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
# Cast back to the caller's dtype (supports bfloat16 / float16 AMP)
return pos.to(pos_tensor.dtype)


Expand Down Expand Up @@ -480,11 +484,7 @@
self.bbox_embed = bbox_embed

self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2)
self.angle_proj = nn.Sequential(
nn.Linear(4, self.d_model),
nn.ReLU(),
nn.Linear(self.d_model, self.d_model),
)
self.angle_proj = nn.Linear(2, self.d_model)

def get_reference(
self, reference_points: torch.Tensor, valid_ratios: torch.Tensor
Expand All @@ -498,7 +498,7 @@
tensor containing the valid ratios for each level of the input feature maps

Returns:
reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4)
reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6)
tensor containing the reference point inputs for the decoder layers,
which are the normalized center coordinates,
width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps
Expand All @@ -515,45 +515,54 @@
# DETR positional encoding
query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model)
base_query_pos = self.ref_point_head(query_sine_embed)
# Angle embedding
sin_t = angle[..., 0:1]
cos_t = angle[..., 1:2]

angle_feat = torch.cat(
[
sin_t,
cos_t,
2 * sin_t * cos_t,
cos_t**2 - sin_t**2,
],
dim=-1,
)

angle_emb = self.angle_proj(angle_feat)
angle_emb = self.angle_proj(angle)
# Combine
query_pos = base_query_pos + angle_emb
return reference_points_inputs, query_pos

def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
reference_points = reference_points.to(deltas.device)
cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
"""Refine bounding boxes by applying the predicted deltas to the reference points.

Check notice on line 525 in doctr/models/layout/lw_detr/layers/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/layers/pytorch.py#L525

Missing dashed underline after section ('Returns') (D407)

Check notice on line 525 in doctr/models/layout/lw_detr/layers/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/layers/pytorch.py#L525

Section name should end with a newline ('Returns', not 'Returns:') (D406)
The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format.
The refined boxes are computed as follows:

# Clamp deltas to prevent exp() from shooting to Infinity during early training
wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4]
cx' = cx + delta_cx * w
cy' = cy + delta_cy * h
w' = w * exp(delta_w)
h' = h * exp(delta_h)
sinθ' = sinθ * cosΔ + cosθ * sinΔ
cosθ' = cosθ * cosΔ - sinθ * sinΔ

# Add eps=1e-6 to avoid division-by-zero NaN creation
delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6)
Args:
reference_points: (N, S, 6) tensor containing the reference points
deltas: (N, S, 6) tensor containing the predicted deltas

Returns:
refined_boxes: (N, S, 6) tensor containing the refined bounding boxes
"""
reference_points = reference_points.to(deltas.device)
# size
wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=4.0).exp() * reference_points[..., 2:4]
wh = wh.clamp(min=1e-4, max=1.0)
# center
raw_cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2]
half_wh = wh / 2
cxcy = raw_cxcy.clamp(
min=half_wh,
max=1.0 - half_wh,
)
# rotation
sin_d = deltas[..., 4:5]
cos_d = deltas[..., 5:6] + 1.0
delta_rot = F.normalize(torch.cat([sin_d, cos_d], dim=-1), dim=-1, eps=1e-6)
sin_delta = delta_rot[..., 0:1]
cos_delta = delta_rot[..., 1:2]
sin_ref = reference_points[..., 4:5]
cos_ref = reference_points[..., 5:6]

# compose rotations
sin_new = sin_ref * cos_delta + cos_ref * sin_delta
cos_new = cos_ref * cos_delta - sin_ref * sin_delta

# Add eps=1e-6 here too
rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6)

return torch.cat((cxcy, wh, rot), dim=-1)

def forward(
Expand Down Expand Up @@ -590,11 +599,7 @@
if self.bbox_embed is not None:
delta = self.bbox_embed(hidden_states_norm)

reference_points = self.refine_boxes(
reference_points.squeeze(2),
delta,
)

reference_points = self.refine_bboxes(reference_points, delta)
intermediate_reference_points.append(reference_points)

reference_points_inputs, query_pos = self.get_reference(
Expand Down
Loading
Loading