From 03c0166ef81896331eb2d193d0b6b76c7d05e96c Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 May 2026 14:24:47 +0200 Subject: [PATCH 1/9] straight check --- doctr/models/layout/lw_detr/base.py | 164 +++--- doctr/models/layout/lw_detr/layers/pytorch.py | 219 +++----- doctr/models/layout/lw_detr/loss.py | 527 ++++++++++++++++++ doctr/models/layout/lw_detr/pytorch.py | 475 ++++------------ 4 files changed, 781 insertions(+), 604 deletions(-) create mode 100644 doctr/models/layout/lw_detr/loss.py diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 2aa59e7d18..f7c926b13b 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,7 +9,6 @@ import numpy as np from doctr.models.core import BaseModel -from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -39,29 +38,36 @@ def __init__( self.topk = topk self.assume_straight_pages = assume_straight_pages - def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """Decode the predicted boxes from OBB format to polygon format - - Args: - boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta)) - - Returns: - tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2) - and angles is an array of angles in radians (N,) + def _decode_boxes(self, boxes: np.ndarray) -> np.ndarray: """ - cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] - sin, cos = boxes[:, 4], boxes[:, 5] - - angles = np.arctan2(sin, cos) + Decode cxcywh -> polygons (axis-aligned rectangles) + """ + cx = boxes[:, 0] + cy = boxes[:, 1] + w = boxes[:, 2] + h = boxes[:, 3] polys = [] + for i in range(len(boxes)): - rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) + x1 = cx[i] - w[i] / 2 + y1 = cy[i] - h[i] / 2 + x2 = cx[i] + w[i] / 2 + y2 = cy[i] + h[i] / 2 + + poly = np.array( + [ + [x1, y1], + [x2, y1], + [x2, y2], + [x1, y2], + ], + dtype=np.float32, + ) - poly = order_points(cv2.boxPoints(rect)) polys.append(poly) - return np.asarray(polys, dtype=np.float32), angles + return np.asarray(polys, dtype=np.float32) def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: """Compute the IoU between two polygons @@ -136,22 +142,21 @@ def _nms(self, polys: np.ndarray, scores: np.ndarray, labels: np.ndarray) -> lis suppressed[j] = True return keep - def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: + def __call__(self, logits: np.ndarray, boxes: np.ndarray): + logits = np.asarray(logits) boxes = np.asarray(boxes) - results: list[tuple[list[int], np.ndarray, list[float]]] = [] + results = [] for b in range(boxes.shape[0]): - # Convert logits to probabilities and get scores and labels exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - prob_fg = prob[:, :-1] # exclude background + prob_fg = prob[:, :-1] scores = prob_fg.max(axis=-1) labels = prob_fg.argmax(axis=-1) - # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: idxs = np.argsort(scores)[::-1][: self.topk] else: @@ -167,44 +172,36 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int scores_b = scores_b[mask] labels_b = labels_b[mask] - polys, _ = ( - self._decode_boxes(bboxes) - if len(bboxes) > 0 - else ( - np.zeros((0, 4, 2), dtype=np.float32), - np.zeros((0,), dtype=np.float32), - ) - ) + polys = self._decode_boxes(bboxes) if len(bboxes) > 0 else np.zeros((0, 4, 2), dtype=np.float32) keep = self._nms(polys, scores_b, labels_b) if len(polys) > 0 else [] - final_labels = [] final_boxes = [] + final_labels = [] final_scores = [] for idx in keep: - poly = polys[idx].reshape(-1).tolist() + poly = polys[idx] + if self.assume_straight_pages: - x_coords = poly[0::2] - y_coords = poly[1::2] - xmin, xmax = min(x_coords), max(x_coords) - ymin, ymax = min(y_coords), max(y_coords) + # 👉 COCO-style axis aligned box from polygon + xmin = float(np.min(poly[:, 0])) + xmax = float(np.max(poly[:, 0])) + ymin = float(np.min(poly[:, 1])) + ymax = float(np.max(poly[:, 1])) + final_boxes.append([xmin, ymin, xmax, ymax]) else: - final_boxes.append(poly) + final_boxes.append(poly.reshape(-1).tolist()) final_labels.append(int(labels_b[idx])) final_scores.append(float(scores_b[idx])) - final_boxes_arr = ( - np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2) - if not self.assume_straight_pages - else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4) - ) + final_boxes_arr = np.asarray(final_boxes, dtype=np.float32) results.append(( final_labels, - final_boxes_arr, + final_boxes_arr, # <- NOW ALWAYS CLEAN FORMAT final_scores, )) @@ -221,55 +218,12 @@ def build_target( target: list[dict[str, np.ndarray]], class_names: list[str], ) -> list[dict[str, Any]]: - """Build the target for LW-DETR training - - Args: - target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) - class_names: list of class names - - Returns: - list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) - containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta)) - and "labels" is an array of shape (num_boxes,) containing the class labels + """ + Build targets in COCO format: [xmin, ymin, w, h] """ targets = [] class_to_id = {name: i for i, name in enumerate(class_names)} - def _quad_to_obb(poly: np.ndarray): - poly = np.asarray(poly, dtype=np.float32) - - # Center point is simply the average of the relative vertices - cx, cy = np.mean(poly, axis=0) - - edges = np.stack([ - poly[1] - poly[0], - poly[2] - poly[1], - poly[3] - poly[2], - poly[0] - poly[3], - ]) - - lengths = np.linalg.norm(edges, axis=1) - i = np.argmax(lengths) - dx, dy = edges[i] - - theta = np.arctan2(dy, dx) - - # Width and height remain cleanly in relative coordinate space [0, 1] - w = np.mean([lengths[i], lengths[(i + 2) % 4]]) - h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) - - # Enforce strict unit-length normal vectors for rotation - sin_t = np.sin(theta) - cos_t = np.cos(theta) - norm = np.sqrt(sin_t**2 + cos_t**2) + 1e-8 - - return np.array( - [cx, cy, w, h, sin_t / norm, cos_t / norm], - dtype=np.float32, - ) - def to_quad(box: np.ndarray): box = np.asarray(box, dtype=np.float32) if box.shape == (4,): @@ -281,6 +235,19 @@ def to_quad(box: np.ndarray): return box.astype(np.float32) raise ValueError(f"Unsupported box shape: {box.shape}") + def quad_to_coco(poly: np.ndarray) -> np.ndarray: + xmin = float(np.min(poly[:, 0])) + xmax = float(np.max(poly[:, 0])) + ymin = float(np.min(poly[:, 1])) + ymax = float(np.max(poly[:, 1])) + + w = xmax - xmin + h = ymax - ymin + cx = xmin + w / 2.0 + cy = ymin + h / 2.0 + + return np.array([cx, cy, w, h], dtype=np.float32) + for sample in target: boxes_all = [] labels_all = [] @@ -295,20 +262,29 @@ def to_quad(box: np.ndarray): if boxes.ndim == 1: boxes = boxes[None, :] + # sanity check normalized coords + flat = boxes.ravel() + coord_vals = flat[flat > 0] + if len(coord_vals) > 0 and coord_vals.max() > 1.5: + raise ValueError("build_target expects normalized [0,1] coordinates.") + for box in boxes: poly = to_quad(box) - obb = _quad_to_obb(poly) + coco_box = quad_to_coco(poly) - # filter out degenerate boxes - if obb[2] <= 1e-5 or obb[3] <= 1e-5: + if coco_box[2] <= 1e-5 or coco_box[3] <= 1e-5: continue - boxes_all.append(obb) + boxes_all.append(coco_box) labels_all.append(cls_id) + if len(boxes_all) == 0: + boxes_all = np.zeros((0, 4), dtype=np.float32) + labels_all = np.zeros((0,), dtype=np.int64) + targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), - "labels": np.asarray(labels_all, dtype=np.int64), + "boxes": np.asarray(boxes_all, dtype=np.float32), # (N, 4) + "class_labels": np.asarray(labels_all, dtype=np.int64), }) return targets diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 08f9ea9fb4..9f6d75ab30 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -113,7 +113,7 @@ def forward( hidden_states_original = hidden_states if position_embeddings is not None: - hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states = hidden_states + position_embeddings if self.training: # at training, we use group detr technique to @@ -238,6 +238,7 @@ def forward( encoder_hidden_states=None, position_embeddings: torch.Tensor | None = None, reference_points=None, + spatial_shapes=None, spatial_shapes_list=None, ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys @@ -263,35 +264,19 @@ def forward( ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] - - if num_coordinates == 4: + if num_coordinates == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) - elif num_coordinates == 6: - ref = reference_points[:, :, None, :, None, :] # (..., 6) - - center = ref[..., :2] # (cx, cy) - wh = ref[..., 2:4] # (w, h) - sin = ref[..., 4:5] # sinθ - cos = ref[..., 5:6] # cosθ - - # normalize offsets - offsets = sampling_offsets / self.n_points * wh * 0.5 - - dx = offsets[..., 0:1] - dy = offsets[..., 1:2] - - # rotate offsets - dx_rot = dx * cos - dy * sin - dy_rot = dx * sin + dy * cos - - rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1) - - sampling_locations = center + rotated_offsets else: - raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") output = self.attn( value, @@ -361,6 +346,7 @@ def forward( hidden_states: torch.Tensor, position_embeddings: torch.Tensor | None = None, reference_points: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, spatial_shapes_list: list[tuple] | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, @@ -379,6 +365,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, position_embeddings=position_embeddings, reference_points=reference_points, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training) @@ -393,43 +380,40 @@ def forward( # function to generate sine positional embedding for 4d coordinates # Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py -def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor: - """ - This function computes position embeddings using sine and cosine functions from the input positional tensor, - which has a shape of (batch_size, num_queries, 4). - The last dimension of `pos_tensor` represents the following coordinates: - - 0: x-coord - - 1: y-coord - - 2: width - - 3: height - - The output shape is (batch_size, num_queries, 512), - where final dim (hidden_size*2 = 512) is the total embedding dimension - achieved by concatenating the sine and cosine values for each coordinate. +def encode_sinusoidal_position_embedding( + pos_tensor: torch.Tensor, + num_pos_feats: int = 128, + temperature: int = 10000, +) -> torch.Tensor: + """Sinusoidal position embeddings from normalized anchor coordinates. + + Each coordinate in `pos_tensor` is independently encoded with ``num_pos_feats`` + interleaved sin/cos components; per-coordinate embeddings are concatenated. + Handles 2-D ``(x, y)`` and N-D ``(x, y, w, h)`` inputs. For 2-D+ inputs the + x and y embeddings are swapped to follow the DETR ``[pos_y, pos_x, ...]`` convention. + + Args: + pos_tensor: Normalized coordinates in ``[0, 1]``, shape ``(..., n_coords)``. + num_pos_feats: Embedding dimension per coordinate. + temperature: Base for the frequency decay. + + Returns: + Tensor of shape ``(..., n_coords * num_pos_feats)``, same dtype as input. """ scale = 2 * math.pi - dim = hidden_size // 2 - dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) - dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) - x_embed = pos_tensor[:, :, 0] * scale - y_embed = pos_tensor[:, :, 1] * scale - pos_x = x_embed[:, :, None] / dim_t - pos_y = y_embed[:, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) - pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) - if pos_tensor.size(-1) == 4: - w_embed = pos_tensor[:, :, 2] * scale - pos_w = w_embed[:, :, None] / dim_t - pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) - - h_embed = pos_tensor[:, :, 3] * scale - pos_h = h_embed[:, :, None] / dim_t - pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) - - pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) - else: - raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") - return pos.to(pos_tensor.dtype) + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) + + coords = pos_tensor.unbind(-1) # list of (...,) tensors + embeddings = [coord[..., None] * scale / dim_t for coord in coords] # each (..., num_pos_feats) + embeddings = [ + torch.stack((e[..., 0::2].sin(), e[..., 1::2].cos()), dim=-1).flatten(-2) for e in embeddings + ] # each (..., num_pos_feats) + + if len(embeddings) >= 2: + embeddings[0], embeddings[1] = embeddings[1], embeddings[0] + + return torch.cat(embeddings, dim=-1).to(pos_tensor.dtype) class LWDETRDecoder(nn.Module): @@ -458,7 +442,6 @@ def __init__( dec_n_points: int = 2, group_detr: int = 13, dropout_prob: float = 0.0, - bbox_embed: nn.Module | None = None, ): super().__init__() self.dropout_prob = dropout_prob @@ -477,89 +460,30 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) - self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Sequential( - nn.Linear(4, self.d_model), - nn.ReLU(), - nn.Linear(self.d_model, self.d_model), - ) - def get_reference( - self, reference_points: torch.Tensor, valid_ratios: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """This function computes the reference point inputs and positional embeddings for the decoder layers. - - Args: - reference_points: (batch_size, num_queries, 6) - tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) - valid_ratios: (batch_size, num_levels, 2) - tensor containing the valid ratios for each level of the input feature maps - - Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) - tensor containing the reference point inputs for the decoder layers, - which are the normalized center coordinates, - width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps - query_pos: (batch_size, num_queries, d_model) - tensor containing the positional embeddings for the decoder layers, - which are computed from the reference points using sine and cosine functions and a linear projection - """ + def get_reference(self, reference_points, valid_ratios): + # batch_size, num_queries, batch_size, 4 obj_center = reference_points[..., :4] - spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Extract angles - angle = reference_points[..., 4:6] # (sin, cos) - angle_expanded = angle[:, :, None] - reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) - # DETR positional encoding - query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) - base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding - sin_t = angle[..., 0:1] - cos_t = angle[..., 1:2] - - angle_feat = torch.cat( - [ - sin_t, - cos_t, - 2 * sin_t * cos_t, - cos_t**2 - sin_t**2, - ], - dim=-1, - ) - - angle_emb = self.angle_proj(angle_feat) - # Combine - query_pos = base_query_pos + angle_emb - return reference_points_inputs, query_pos - - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - - # Add eps=1e-6 to avoid division-by-zero NaN creation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta + # batch_size, num_queries, num_levels, 4 + reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - # Add eps=1e-6 here too - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + # batch_size, num_queries, d_model * 2 + query_sine_embed = encode_sinusoidal_position_embedding( + reference_points_inputs[:, :, 0, :], num_pos_feats=self.d_model // 2 + ) - return torch.cat((cxcy, wh, rot), dim=-1) + # batch_size, num_queries, d_model + query_pos = self.ref_point_head(query_sine_embed) + return reference_points_inputs, query_pos def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, + spatial_shapes: torch.Tensor, spatial_shapes_list: torch.Tensor, valid_ratios: torch.Tensor, encoder_hidden_states: torch.Tensor, @@ -581,35 +505,18 @@ def forward( encoder_attention_mask=encoder_attention_mask, position_embeddings=query_pos, reference_points=reference_points_inputs, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) - hidden_states_norm = self.layernorm(hidden_states) - - # iterative refinement - if self.bbox_embed is not None: - delta = self.bbox_embed(hidden_states_norm) - - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) - - intermediate_reference_points.append(reference_points) - - reference_points_inputs, query_pos = self.get_reference( - reference_points, - valid_ratios, - ) - - intermediate.append(hidden_states_norm) - - intermediate_stack = torch.stack(intermediate) - last_hidden_state = intermediate_stack[-1] + intermediate_hidden_states = self.layernorm(hidden_states) + intermediate.append(intermediate_hidden_states) - intermediate_reference_points_stack = torch.stack(intermediate_reference_points) + intermediate = torch.stack(intermediate) + last_hidden_state = intermediate[-1] + intermediate_reference_points = torch.stack(intermediate_reference_points) - return last_hidden_state, intermediate_stack, intermediate_reference_points_stack + return last_hidden_state, intermediate, intermediate_reference_points class MultiScaleProjector(nn.Module): diff --git a/doctr/models/layout/lw_detr/loss.py b/doctr/models/layout/lw_detr/loss.py new file mode 100644 index 0000000000..ba4364764f --- /dev/null +++ b/doctr/models/layout/lw_detr/loss.py @@ -0,0 +1,527 @@ +import numpy as np +import torch +import torch.nn as nn +from scipy.optimize import linear_sum_assignment +from torch import Tensor + +__all__ = ["lw_detr_for_object_detection_loss"] + + +def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": + center_x, center_y, width, height = bboxes_center.unbind(-1) + bbox_corners = torch.stack( + # top left x, top left y, bottom right x, bottom right y + [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], + dim=-1, + ) + return bbox_corners + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs (0 for the negative class and 1 for the positive + class). + num_boxes: Normalization factor, typically the number of target boxes in the batch. This is used to scale the + loss to an absolute value, and is used in the original implementation of DETR and LW-DETR. + It doesn't have to be + exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002. + + Args: + inputs (`torch.FloatTensor` of arbitrary shape): + The predictions for each example. + targets (`torch.FloatTensor` with the same shape as `inputs`): + A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class + and 1 for the positive class). + num_boxes (`int`): + Normalization factor, typically the number of target boxes in the batch. This is used to scale the loss + to an absolute value, and is used in the original implementation of DETR and LW-DETR. It doesn't have to be + exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. + alpha (`float`, *optional*, defaults to `0.25`): + Optional weighting factor in the range (0,1) to balance positive vs. negative examples. + gamma (`int`, *optional*, defaults to `2`): + Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + # add modulating factor + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def box_area(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. + + Args: + boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): + Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 + < x2` and `0 <= y1 < y2`. + + Returns: + `torch.FloatTensor`: a tensor containing the area for each box. + """ + boxes = _upcast(boxes) + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] + inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. + + Returns: + `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): + raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") + if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): + raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") + iou, union = box_iou(boxes1, boxes2) + + top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] + area = width_height[:, :, 0] * width_height[:, :, 1] + + return iou - (area - union) / area + + +# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306 +def _max_by_axis(the_list): + # type: (list[list[int]]) -> list[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor: + def __init__(self, tensors, mask: Tensor | None): + self.tensors = tensors + self.mask = mask + + def to(self, device): + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): + if tensor_list[0].ndim == 3: + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + batch_shape = [len(tensor_list)] + max_size + batch_size, num_channels, height, width = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("Only 3-dimensional tensors are supported") + return NestedTensor(tensor, mask) + + +# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +def _set_aux_loss(outputs_class, outputs_coord): + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class LwDetrHungarianMatcher(nn.Module): + def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): + super().__init__() + + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: + raise ValueError("All costs of the Matcher can't be 0") + + @torch.no_grad() + def forward(self, outputs, targets, group_detr): + """ + Differences: + - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax + - class_cost uses alpha and gamma + """ + batch_size, num_queries = outputs["logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + target_ids = torch.cat([torch.as_tensor(v["class_labels"], dtype=torch.int64) for v in targets]).to( + out_prob.device + ) + target_bbox = torch.cat([torch.as_tensor(v["boxes"], dtype=torch.float32) for v in targets]).to(out_bbox.device) + + # Compute the classification cost. + alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] + + # Compute the L1 cost between boxes, cdist only supports float32 + dtype = out_bbox.dtype + out_bbox = out_bbox.to(torch.float32) + target_bbox = target_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) + bbox_cost = bbox_cost.to(dtype) + + # Compute the giou cost between boxes + giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) + + # Final cost matrix + cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost + cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [] + group_num_queries = num_queries // group_detr + cost_matrix_list = cost_matrix.split(group_num_queries, dim=1) + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_list[group_id] + group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))] + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]), + np.concatenate([indice1[1], indice2[1]]), + ) + for indice1, indice2 in zip(indices, group_indices) + ] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +class LwDetrImageLoss(nn.Module): + def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr): + super().__init__() + self.matcher = matcher + self.num_classes = num_classes + self.focal_alpha = focal_alpha + self.losses = losses + self.group_detr = group_detr + + # removed logging parameter, which was part of the original implementation + def loss_labels(self, outputs, targets, indices, num_boxes): + if "logits" not in outputs: + raise KeyError("No logits were found in the outputs") + source_logits = outputs["logits"] + dtype = source_logits.dtype + + idx = self._get_source_permutation_idx(indices) + target_classes_o = torch.cat([ + torch.as_tensor(np.atleast_1d(t["class_labels"][J]), dtype=torch.int64) + for t, (_, J) in zip(targets, indices) + ]).to(source_logits.device) + alpha = self.focal_alpha + gamma = 2 + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat( + [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], + dim=0, + ).to(src_boxes.device) + iou_targets = torch.diag( + box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))[0] + ) + # Convert to the same dtype as the source logits as box_iou upcasts to float32 + iou_targets = iou_targets.to(dtype) + pos_ious = iou_targets.clone().detach() + prob = source_logits.sigmoid() + # init positive weights and negative weights + pos_weights = torch.zeros_like(source_logits) + # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype + neg_weights = prob.pow(gamma).to(dtype) + pos_ind = idx + (target_classes_o,) + + pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + pos_quality = torch.clamp(pos_quality, 0.01).detach().to(dtype) + + pos_weights[pos_ind] = pos_quality + neg_weights[pos_ind] = 1 - pos_quality + loss_ce = -pos_weights * prob.log() - neg_weights * (1 - prob).log() + loss_ce = loss_ce.sum() / num_boxes + losses = {"loss_ce": loss_ce} + + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ + Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. + + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. + """ + logits = outputs["logits"] + device = logits.device + target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold) + card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1) + card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) + losses = {"cardinality_error": card_err} + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss.loss_boxes + def loss_boxes(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. + + Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes + are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + if "pred_boxes" not in outputs: + raise KeyError("No predicted boxes found in outputs") + idx = self._get_source_permutation_idx(indices) + source_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat( + [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], + dim=0, + ).to(source_boxes.device) + + loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss.loss_masks + def loss_masks(self, outputs, targets, indices, num_boxes): + """ + Compute the losses related to the masks: the focal loss and the dice loss. + + Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. + """ + if "pred_masks" not in outputs: + raise KeyError("No predicted masks found in outputs") + + source_idx = self._get_source_permutation_idx(indices) + target_idx = self._get_target_permutation_idx(indices) + source_masks = outputs["pred_masks"] + source_masks = source_masks[source_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(source_masks) + target_masks = target_masks[target_idx] + + # upsample predictions to the target size + source_masks = nn.functional.interpolate( + source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + source_masks = source_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(source_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), + "loss_dice": dice_loss(source_masks, target_masks, num_boxes), + } + return losses + + # Copied from loss.loss_for_object_detection.ImageLoss._get_source_permutation_idx + def _get_source_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) + source_idx = torch.cat([source for (source, _) in indices]) + return batch_idx, source_idx + + # Copied from loss.loss_for_object_detection.ImageLoss._get_target_permutation_idx + def _get_target_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) + target_idx = torch.cat([target for (_, target) in indices]) + return batch_idx, target_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "labels": self.loss_labels, + "cardinality": self.loss_cardinality, + "boxes": self.loss_boxes, + "masks": self.loss_masks, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. + targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ + group_detr = self.group_detr if self.training else 1 + outputs_without_aux_and_enc = { + k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs" + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr) + + # Compute the average number of target boxes across all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = num_boxes * group_detr + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + world_size = 1 + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + indices = self.matcher(auxiliary_outputs, targets, group_detr) + for loss in self.losses: + if loss == "masks": + # Intermediate masks losses are too costly to compute, we ignore them. + continue + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if "enc_outputs" in outputs: + enc_outputs = outputs["enc_outputs"] + indices = self.matcher(enc_outputs, targets, group_detr=group_detr) + for loss in self.losses: + l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) + l_dict = {k + "_enc": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +def lw_detr_for_object_detection_loss( + logits, + labels, + device, + pred_boxes, + outputs_class=None, + outputs_coord=None, + enc_outputs_class=None, + enc_outputs_coord=None, + use_aux_loss=False, + group_detr=1, + num_labels=None, + num_decoder_layers=None, + **kwargs, +): + """Loss computation for LW-DETR for object detection.""" + # First: create the matcher + matcher = LwDetrHungarianMatcher(class_cost=2.0, bbox_cost=5, giou_cost=2) + # Second: create the criterion + losses = ["labels", "boxes", "cardinality"] + criterion = LwDetrImageLoss( + matcher=matcher, + num_classes=num_labels, + focal_alpha=0.1, + losses=losses, + group_detr=group_detr, + ) + criterion.to(device) + # Third: compute the losses, based on outputs and labels + outputs_loss = {} + auxiliary_outputs = None + outputs_loss["logits"] = logits + outputs_loss["pred_boxes"] = pred_boxes + outputs_loss["enc_outputs"] = { + "logits": enc_outputs_class, + "pred_boxes": enc_outputs_coord, + } + if use_aux_loss: + auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord) + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + loss_dict = criterion(outputs_loss, labels) + # Fourth: compute total loss, as a weighted sum of the various losses + weight_dict = {"loss_ce": 1, "loss_bbox": 5} + weight_dict["loss_giou"] = 2 + if use_aux_loss: + aux_weight_dict = {} + for i in range(num_decoder_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()} + weight_dict.update(enc_weight_dict) + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) + return loss, loss_dict, auxiliary_outputs diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 848d44280e..2ca0ecee4a 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -18,6 +18,7 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector +from .loss import lw_detr_for_object_detection_loss __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -153,8 +154,8 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, - iou_thresh: float = 0.5, + score_thresh: float = 0.0, + iou_thresh: float = 0.1, d_model: int = 256, num_queries: int = 130, group_detr: int = 1, @@ -171,7 +172,7 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class + self.num_classes = len(self.class_names) + 1 # +1 for background class (NO OBJECT) self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -181,20 +182,13 @@ def __init__( self.group_detr = group_detr self.num_queries = num_queries self.d_model = d_model + self.dec_layers = dec_layers - self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) - # Initialize angle to (sin=0, cos=1) - with torch.no_grad(): - self.reference_point_embed.weight[:, 4] = 0.0 # sinθ - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ - + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4) self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) - self.class_embed = nn.Linear(self.d_model, self.num_classes) - self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) - self.decoder = LWDETRDecoder( - num_layers=dec_layers, + num_layers=self.dec_layers, d_model=d_model, sa_num_heads=sa_num_heads, ca_num_heads=ca_num_heads, @@ -202,19 +196,21 @@ def __init__( dec_n_points=dec_n_points, group_detr=group_detr, dropout_prob=dropout_prob, - bbox_embed=self.bbox_embed, ) self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)]) self.enc_out_bbox_embed = nn.ModuleList([ - LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr) + LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) for _ in range(self.group_detr) ]) self.enc_out_class_embed = nn.ModuleList([ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) ]) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) + self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, score_thresh=score_thresh, @@ -226,60 +222,34 @@ def __init__( # Don't override the initialization of the backbone if n.startswith("feat_extractor."): continue - - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): - if hasattr(m, "weight") and m.weight is not None: - nn.init.ones_(m.weight) - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) - elif isinstance(m, LWDETRMultiscaleDeformableAttention): + if isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) - - thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) + thetas = torch.arange(m.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / m.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) .view(m.n_heads, 1, 1, 2) .repeat(1, m.n_levels, m.n_points, 1) ) - for i in range(m.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) - nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.zeros_(m.value_proj.bias) - + nn.init.constant_(m.value_proj.bias, 0.0) nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.zeros_(m.output_proj.bias) - - if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + nn.init.constant_(m.output_proj.bias, 0.0) + if hasattr(m, "refpoint_embed") and m.refpoint_embed is not None: + nn.init.constant_(m.refpoint_embed.weight, 0) + if hasattr(m, "class_embed") and m.class_embed is not None: prior_prob = 0.01 bias_value = -math.log((1 - prior_prob) / prior_prob) - if m.bias is not None: - nn.init.constant_(m.bias, bias_value) - if isinstance(m, LWDETRHead): - last = m.layers[-1] - if isinstance(last, nn.Linear): - nn.init.zeros_(last.weight) - nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) + nn.init.constant_(m.class_embed.bias, bias_value) + if hasattr(m, "bbox_embed") and m.bbox_embed is not None: + nn.init.constant_(m.bbox_embed.layers[-1].weight, 0) + nn.init.constant_(m.bbox_embed.layers[-1].bias, 0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -290,79 +260,39 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ + def refine_bboxes(self, reference_points, deltas): reference_points = reference_points.to(deltas.device) - # center - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - - return torch.cat((cxcy, wh, rot), dim=-1) - - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Get the valid ratio of all feature maps. - - Args: - mask: (N, H, W) binary tensor containing 1 on padded pixels - dtype: the desired data type of the output tensor + new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2] + new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:] + new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1) + return new_reference_points - Returns: - valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch - """ + def get_valid_ratio(self, mask, dtype=torch.float32): + """Get the valid ratio of all feature maps.""" _, height, width = mask.shape - valid_height = torch.sum(~mask[:, :, 0], 1) - valid_width = torch.sum(~mask[:, 0, :], 1) + valid_height = torch.sum(mask[:, :, 0], 1) + valid_width = torch.sum(mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio - def gen_encoder_output_proposals( - self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): """Generate the encoder output proposals from encoded enc_output. Args: - enc_output: Output of the encoder - padding_mask: Padding mask for `enc_output` - spatial_shapes: Spatial shapes of the feature maps + enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. + padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. + spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. Returns: - A tuple of feature map and bbox prediction. - - object_query: Object query features. Later used to directly predict a bounding box. - - output_proposals: Normalized proposals in [0, 1] space. - Invalid positions (padding or out-of-bounds) are filled with 0. - - invalid_mask: Boolean mask that is True for invalid positions - (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). + `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. + - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to + directly predict a bounding box. (without the need of a decoder) + - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals in [0, 1] space. + Invalid positions (padding or out-of-bounds) are filled with 0. + - invalid_mask (Tensor[batch_size, sequence_length, 1]): Boolean mask that is True for invalid positions + (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). """ batch_size = enc_output.shape[0] proposals = [] @@ -394,17 +324,11 @@ def gen_encoder_output_proposals( scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - # add default rotation (sin=0, cos=1) - sin = torch.zeros_like(grid[..., :1]) - cos = torch.ones_like(grid[..., :1]) - proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) + proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) - - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) - output_proposals_valid = spatial_valid - invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -455,6 +379,7 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) @@ -467,51 +392,73 @@ def forward( group_detr = self.group_detr if self.training else 1 topk = self.num_queries - topk_coords_logits_list: list[torch.Tensor] = [] - - # encoder predictions for auxiliary losses - all_group_enc_logits: list[torch.Tensor] = [] - all_group_enc_coords: list[torch.Tensor] = [] + topk_coords_logits = [] + topk_coords_logits_undetach = [] + object_query_undetach = [] for group_id in range(group_detr): group_object_query = self.enc_output[group_id](object_query_embedding) group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - all_group_enc_logits.append(group_enc_outputs_class) - - group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) - + group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - all_group_enc_coords.append(group_enc_outputs_coord) - - group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] - + group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, - group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), ) - group_topk_coords_logits = group_topk_coords_logits_undetach - topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() + group_object_query_undetach = torch.gather( + group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + ) + + topk_coords_logits.append(group_topk_coords_logits) + topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) + object_query_undetach.append(group_object_query_undetach) + + topk_coords_logits = torch.cat(topk_coords_logits, 1) + topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) + object_query_undetach = torch.cat(object_query_undetach, 1) + + enc_outputs_class_logits = object_query_undetach + enc_outputs_boxes_logits = topk_coords_logits_undetach - topk_coords_logits = torch.cat(topk_coords_logits_list, 1) reference_points = self.refine_bboxes(topk_coords_logits, reference_points) - last_hidden_states, intermediate, intermediate_reference_points = self.decoder( + init_reference_points = reference_points + last_hidden_state, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, + spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, ) - logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) + logits = self.class_embed(last_hidden_state) + pred_boxes_delta = self.bbox_embed(last_hidden_state) pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.num_queries, dim=1) + pred_class = [] + group_detr = self.group_detr if self.training else 1 + for group_index in range(group_detr): + group_pred_class = self.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) + pred_class.append(group_pred_class) + enc_outputs_class_logits = torch.cat(pred_class, dim=1) + + if target is not None: + outputs_class, outputs_coord = None, None + intermediate_hidden_states = intermediate + outputs_coord_delta = self.bbox_embed(intermediate_hidden_states) + outputs_coord = self.refine_bboxes(intermediate_reference_points, outputs_coord_delta) + outputs_class = self.class_embed(intermediate_hidden_states) + out: dict[str, Any] = {} if self.exportable: @@ -533,224 +480,44 @@ def _postprocess(logits, boxes): if target is not None: # Build target processed_targets = self.build_target(target, self.class_names) - - # Main loss from final decoder layer (group DETR) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss = main_loss / group_detr - - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) - - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += 0.5 * (aux_loss / group_detr) - - # Auxiliary losses for encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += 0.1 * (enc_loss / group_detr) - - out["loss"] = loss + out["loss"] = self.compute_loss( + logits, + processed_targets, + pred_boxes, + outputs_class, + outputs_coord, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + ) return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] - ) -> torch.Tensor: - """ - Compute the loss for LW-DETR. The loss consists of three components: - classification loss, box regression loss, and rotation loss. - The classification loss is a cross-entropy loss between the predicted class logits and the target classes. - The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, - computed only on the positive samples. - The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, - averaged over the positive samples. - The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box, - we select the top-k queries with the lowest cost - (combination of classification cost, box regression cost, and rotation cost). - - Args: - logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) - - Returns: - loss: the computed loss value - """ - - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters - (mean and covariance). - """ - cxcy = boxes[..., :2] - - w = boxes[..., 2].clamp(min=1e-6) - h = boxes[..., 3].clamp(min=1e-6) - - sin = boxes[..., 4] - cos = boxes[..., 5] - - R = torch.stack( - [ - torch.stack([cos, -sin], dim=-1), - torch.stack([sin, cos], dim=-1), - ], - dim=-2, - ) - - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 - - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) - - S[..., 0, 0] = sx - S[..., 1, 1] = sy - - covariance = R @ S @ R.transpose(-1, -2) - return cxcy, covariance - - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted boxes and target boxes.""" - mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) - - delta = (mu1 - mu2).unsqueeze(-1) - sigma = (sigma1 + sigma2) * 0.5 - - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps - sigma_safe = sigma + eye - sigma1_safe = sigma1 + eye - sigma2_safe = sigma2 + eye - - sigma_inv = torch.linalg.inv(sigma_safe) - - mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) - - bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) - - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - - device = logits.device - B, Q, C = logits.shape - - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - total_rot = torch.tensor(0.0, device=device) - - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] - - tgt_boxes = torch.as_tensor( - targets[b]["boxes"], - device=device, - dtype=pred_boxes.dtype, - ) - tgt_cls = torch.as_tensor( - targets[b]["labels"], - device=device, - dtype=torch.long, - ) - - num_gt = len(tgt_cls) - - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - - with torch.no_grad(): - cls_prob = pred_logits.sigmoid() - alpha = 0.25 - gamma = 2.0 - - neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) - - pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) - - cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] - cost_l1 = torch.cdist( - pred_boxes_b[:, :4], - tgt_boxes[:, :4], - p=1, - ) - cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot - matching_matrix = torch.zeros( - (Q, num_gt), - dtype=torch.bool, - device=device, - ) - - center_dist = torch.cdist( - pred_boxes_b[:, :2], - tgt_boxes[:, :2], - p=2, - ) - - iou_like = torch.exp(-center_dist) - dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) - - for gt_idx in range(num_gt): - _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) - matching_matrix[candidate_idx, gt_idx] = True - - # resolve duplicate matches - multiple_match_mask = matching_matrix.sum(1) > 1 - - if multiple_match_mask.any(): - duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) - min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) - # Set all matches to False for the duplicate indices, - # then set the match with the lowest cost to True - matching_matrix[duplicate_idx] = False - matching_matrix[duplicate_idx, min_cost_idx] = True - - pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - - target_classes = torch.zeros((Q,), dtype=torch.long, device=device) - - # background = 0 - target_classes[pos_idx] = tgt_cls[gt_indices] - - total_cls += F.cross_entropy(pred_logits, target_classes) - - if len(pos_idx) == 0: - continue - - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_indices] - # L1 loss on (cx, cy, w, h) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) - # ProbIoU loss on the whole box (including rotation) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() - total_box += 2.0 * l1_loss + 0.5 * probiou_loss - # Rotation loss - cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() - rot_loss = (1 - cos_sim).mean() - total_rot += 0.5 * rot_loss - # Average the loss over the batch - return (total_cls + total_box + total_rot) / B + self, + logits, + targets, + pred_boxes, + outputs_class, + outputs_coord, + enc_outputs_class_logits, + enc_outputs_boxes_logits, + ): + + loss_calc = lw_detr_for_object_detection_loss( + logits=logits, + device=logits.device, + labels=targets, + pred_boxes=pred_boxes, + outputs_class=outputs_class, + outputs_coord=outputs_coord, + enc_outputs_class=enc_outputs_class_logits, + enc_outputs_coord=enc_outputs_boxes_logits, + use_aux_loss=True, + group_detr=self.group_detr, + num_decoder_layers=self.dec_layers, + num_labels=self.num_classes, + ) + return loss_calc[0] def _lw_detr( From 4c5629c73d34843b096d885479b7656507a01405 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 11:24:14 +0200 Subject: [PATCH 2/9] rot-check --- doctr/models/layout/lw_detr/base.py | 179 +++--- doctr/models/layout/lw_detr/layers/pytorch.py | 235 +++++--- doctr/models/layout/lw_detr/loss.py | 527 ------------------ doctr/models/layout/lw_detr/pytorch.py | 466 ++++++++++++---- 4 files changed, 626 insertions(+), 781 deletions(-) delete mode 100644 doctr/models/layout/lw_detr/loss.py diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index f7c926b13b..1a7cb452cf 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -9,6 +9,7 @@ import numpy as np from doctr.models.core import BaseModel +from doctr.utils import order_points __all__ = ["_LWDETR", "LWDETRPostProcessor"] @@ -38,36 +39,29 @@ def __init__( self.topk = topk self.assume_straight_pages = assume_straight_pages - def _decode_boxes(self, boxes: np.ndarray) -> np.ndarray: - """ - Decode cxcywh -> polygons (axis-aligned rectangles) + def _decode_boxes(self, boxes: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Decode the predicted boxes from OBB format to polygon format + + Args: + boxes: array of predicted boxes in OBB format (N, 6) (cx, cy, w, h, sin(theta), cos(theta)) + + Returns: + tuple of (polys, angles) where polys is an array of decoded polygons (N, 4, 2) + and angles is an array of angles in radians (N,) """ - cx = boxes[:, 0] - cy = boxes[:, 1] - w = boxes[:, 2] - h = boxes[:, 3] + cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + sin, cos = boxes[:, 4], boxes[:, 5] - polys = [] + angles = np.arctan2(sin, cos) + polys = [] for i in range(len(boxes)): - x1 = cx[i] - w[i] / 2 - y1 = cy[i] - h[i] / 2 - x2 = cx[i] + w[i] / 2 - y2 = cy[i] + h[i] / 2 - - poly = np.array( - [ - [x1, y1], - [x2, y1], - [x2, y2], - [x1, y2], - ], - dtype=np.float32, - ) + rect = ((float(cx[i]), float(cy[i])), (float(w[i]), float(h[i])), float(np.degrees(angles[i]))) + poly = order_points(cv2.boxPoints(rect)) polys.append(poly) - return np.asarray(polys, dtype=np.float32) + return np.asarray(polys, dtype=np.float32), angles def _iou(self, poly1: np.ndarray, poly2: np.ndarray) -> float: """Compute the IoU between two polygons @@ -142,23 +136,30 @@ def _nms(self, polys: np.ndarray, scores: np.ndarray, labels: np.ndarray) -> lis suppressed[j] = True return keep - def __call__(self, logits: np.ndarray, boxes: np.ndarray): - + def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int], np.ndarray, list[float]]]: logits = np.asarray(logits) boxes = np.asarray(boxes) - results = [] + results: list[tuple[list[int], np.ndarray, list[float]]] = [] for b in range(boxes.shape[0]): + # Convert logits to probabilities and get scores and labels exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) prob = exp / exp.sum(axis=-1, keepdims=True) - prob_fg = prob[:, :-1] - scores = prob_fg.max(axis=-1) - labels = prob_fg.argmax(axis=-1) + scores = prob.max(axis=-1) + labels = prob.argmax(axis=-1) + # treat background as invalid prediction + bg = self.num_classes - 1 + valid = labels != bg + + scores = scores * valid + + # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: - idxs = np.argsort(scores)[::-1][: self.topk] + idxs = np.argpartition(-scores, self.topk)[: self.topk] + idxs = idxs[np.argsort(-scores[idxs])] else: idxs = np.arange(len(scores)) @@ -172,36 +173,44 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray): scores_b = scores_b[mask] labels_b = labels_b[mask] - polys = self._decode_boxes(bboxes) if len(bboxes) > 0 else np.zeros((0, 4, 2), dtype=np.float32) + polys, _ = ( + self._decode_boxes(bboxes) + if len(bboxes) > 0 + else ( + np.zeros((0, 4, 2), dtype=np.float32), + np.zeros((0,), dtype=np.float32), + ) + ) keep = self._nms(polys, scores_b, labels_b) if len(polys) > 0 else [] - final_boxes = [] final_labels = [] + final_boxes = [] final_scores = [] for idx in keep: - poly = polys[idx] - + poly = polys[idx].reshape(-1).tolist() if self.assume_straight_pages: - # 👉 COCO-style axis aligned box from polygon - xmin = float(np.min(poly[:, 0])) - xmax = float(np.max(poly[:, 0])) - ymin = float(np.min(poly[:, 1])) - ymax = float(np.max(poly[:, 1])) - + x_coords = poly[0::2] + y_coords = poly[1::2] + xmin, xmax = min(x_coords), max(x_coords) + ymin, ymax = min(y_coords), max(y_coords) final_boxes.append([xmin, ymin, xmax, ymax]) else: - final_boxes.append(poly.reshape(-1).tolist()) + final_boxes.append(poly) final_labels.append(int(labels_b[idx])) final_scores.append(float(scores_b[idx])) - final_boxes_arr = np.asarray(final_boxes, dtype=np.float32) + final_boxes_arr = ( + np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4, 2) + if not self.assume_straight_pages + else np.asarray(final_boxes, dtype=np.float32).reshape(-1, 4) + ) results.append(( final_labels, - final_boxes_arr, # <- NOW ALWAYS CLEAN FORMAT + final_boxes_arr, final_scores, )) @@ -218,12 +227,55 @@ def build_target( target: list[dict[str, np.ndarray]], class_names: list[str], ) -> list[dict[str, Any]]: - """ - Build targets in COCO format: [xmin, ymin, w, h] + """Build the target for LW-DETR training + + Args: + target: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding + to class names and values corresponding to lists of boxes in either polygon format (4, 2) + or bounding box format (4,) (xmin, ymin, xmax, ymax) + class_names: list of class names + + Returns: + list of dictionaries with keys "boxes" and "labels" where "boxes" is an array of shape (num_boxes, 6) + containing the box parameters in OBB format (cx, cy, w, h, sin(theta), cos(theta)) + and "labels" is an array of shape (num_boxes,) containing the class labels """ targets = [] class_to_id = {name: i for i, name in enumerate(class_names)} + def _quad_to_obb(poly: np.ndarray): + poly = np.asarray(poly, dtype=np.float32) + + # Center point is simply the average of the relative vertices + cx, cy = np.mean(poly, axis=0) + + edges = np.stack([ + poly[1] - poly[0], + poly[2] - poly[1], + poly[3] - poly[2], + poly[0] - poly[3], + ]) + + lengths = np.linalg.norm(edges, axis=1) + i = np.argmax(lengths) + dx, dy = edges[i] + + theta = np.arctan2(dy, dx) + + # Width and height remain cleanly in relative coordinate space [0, 1] + w = np.mean([lengths[i], lengths[(i + 2) % 4]]) + h = np.mean([lengths[(i + 1) % 4], lengths[(i + 3) % 4]]) + + # Enforce strict unit-length normal vectors for rotation + sin_t = np.sin(theta) + cos_t = np.cos(theta) + norm = np.sqrt(sin_t**2 + cos_t**2) + 1e-8 + + return np.array( + [cx, cy, w, h, sin_t / norm, cos_t / norm], + dtype=np.float32, + ) + def to_quad(box: np.ndarray): box = np.asarray(box, dtype=np.float32) if box.shape == (4,): @@ -235,19 +287,6 @@ def to_quad(box: np.ndarray): return box.astype(np.float32) raise ValueError(f"Unsupported box shape: {box.shape}") - def quad_to_coco(poly: np.ndarray) -> np.ndarray: - xmin = float(np.min(poly[:, 0])) - xmax = float(np.max(poly[:, 0])) - ymin = float(np.min(poly[:, 1])) - ymax = float(np.max(poly[:, 1])) - - w = xmax - xmin - h = ymax - ymin - cx = xmin + w / 2.0 - cy = ymin + h / 2.0 - - return np.array([cx, cy, w, h], dtype=np.float32) - for sample in target: boxes_all = [] labels_all = [] @@ -262,29 +301,31 @@ def quad_to_coco(poly: np.ndarray) -> np.ndarray: if boxes.ndim == 1: boxes = boxes[None, :] - # sanity check normalized coords + # Sanity check: coordinates must be in [0, 1] normalized space. + # Values > 1.5 almost certainly indicate pixel coordinates were passed in. flat = boxes.ravel() coord_vals = flat[flat > 0] if len(coord_vals) > 0 and coord_vals.max() > 1.5: - raise ValueError("build_target expects normalized [0,1] coordinates.") + raise ValueError( + f"build_target expects normalized [0, 1] box coordinates, " + f"but found values up to {coord_vals.max():.1f} for class '{class_name}'. " + f"Divide your coordinates by image width/height before calling build_target." + ) for box in boxes: poly = to_quad(box) - coco_box = quad_to_coco(poly) + obb = _quad_to_obb(poly) - if coco_box[2] <= 1e-5 or coco_box[3] <= 1e-5: + # filter out degenerate boxes + if obb[2] <= 1e-5 or obb[3] <= 1e-5: continue - boxes_all.append(coco_box) + boxes_all.append(obb) labels_all.append(cls_id) - if len(boxes_all) == 0: - boxes_all = np.zeros((0, 4), dtype=np.float32) - labels_all = np.zeros((0,), dtype=np.int64) - targets.append({ - "boxes": np.asarray(boxes_all, dtype=np.float32), # (N, 4) - "class_labels": np.asarray(labels_all, dtype=np.int64), + "boxes": np.asarray(boxes_all, dtype=np.float32), + "labels": np.asarray(labels_all, dtype=np.int64), }) return targets diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 9f6d75ab30..bb6b503e97 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -113,7 +113,7 @@ def forward( hidden_states_original = hidden_states if position_embeddings is not None: - hidden_states = hidden_states + position_embeddings + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings if self.training: # at training, we use group detr technique to @@ -238,7 +238,6 @@ def forward( encoder_hidden_states=None, position_embeddings: torch.Tensor | None = None, reference_points=None, - spatial_shapes=None, spatial_shapes_list=None, ) -> tuple[torch.Tensor, torch.Tensor]: # add position embeddings to the hidden states before projecting to queries and keys @@ -251,7 +250,7 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask - value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 @@ -264,19 +263,35 @@ def forward( ) # batch_size, num_queries, n_heads, n_levels, n_points, 2 num_coordinates = reference_points.shape[-1] - if num_coordinates == 2: - offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) - sampling_locations = ( - reference_points[:, :, None, :, None, :] - + sampling_offsets / offset_normalizer[None, None, None, :, None, :] - ) - elif num_coordinates == 4: + + if num_coordinates == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) + elif num_coordinates == 6: + ref = reference_points[:, :, None, :, None, :] # (..., 6) + + center = ref[..., :2] # (cx, cy) + wh = ref[..., 2:4] # (w, h) + sin = ref[..., 4:5] # sinθ + cos = ref[..., 5:6] # cosθ + + # normalize offsets + offsets = sampling_offsets / self.n_points * wh * 0.5 + + dx = offsets[..., 0:1] + dy = offsets[..., 1:2] + + # rotate offsets + dx_rot = dx * cos - dy * sin + dy_rot = dx * sin + dy * cos + + rotated_offsets = torch.cat([dx_rot, dy_rot], dim=-1) + + sampling_locations = center + rotated_offsets else: - raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") + raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") output = self.attn( value, @@ -346,7 +361,6 @@ def forward( hidden_states: torch.Tensor, position_embeddings: torch.Tensor | None = None, reference_points: torch.Tensor | None = None, - spatial_shapes: torch.Tensor | None = None, spatial_shapes_list: list[tuple] | None = None, encoder_hidden_states: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, @@ -365,7 +379,6 @@ def forward( encoder_hidden_states=encoder_hidden_states, position_embeddings=position_embeddings, reference_points=reference_points, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) cross_attention_output = F.dropout(cross_attention_output, p=self.dropout, training=self.training) @@ -380,40 +393,45 @@ def forward( # function to generate sine positional embedding for 4d coordinates # Borrowed from: https://github.com/Atten4Vis/LW-DETR/blob/main/models/transformer.py -def encode_sinusoidal_position_embedding( - pos_tensor: torch.Tensor, - num_pos_feats: int = 128, - temperature: int = 10000, -) -> torch.Tensor: - """Sinusoidal position embeddings from normalized anchor coordinates. - - Each coordinate in `pos_tensor` is independently encoded with ``num_pos_feats`` - interleaved sin/cos components; per-coordinate embeddings are concatenated. - Handles 2-D ``(x, y)`` and N-D ``(x, y, w, h)`` inputs. For 2-D+ inputs the - x and y embeddings are swapped to follow the DETR ``[pos_y, pos_x, ...]`` convention. - - Args: - pos_tensor: Normalized coordinates in ``[0, 1]``, shape ``(..., n_coords)``. - num_pos_feats: Embedding dimension per coordinate. - temperature: Base for the frequency decay. - - Returns: - Tensor of shape ``(..., n_coords * num_pos_feats)``, same dtype as input. +def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 256) -> torch.Tensor: + """ + This function computes position embeddings using sine and cosine functions from the input positional tensor, + which has a shape of (batch_size, num_queries, 4). + The last dimension of `pos_tensor` represents the following coordinates: + - 0: x-coord + - 1: y-coord + - 2: width + - 3: height + + The output shape is (batch_size, num_queries, 512), + where final dim (hidden_size*2 = 512) is the total embedding dimension + achieved by concatenating the sine and cosine values for each coordinate. """ scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device) - dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) - - coords = pos_tensor.unbind(-1) # list of (...,) tensors - embeddings = [coord[..., None] * scale / dim_t for coord in coords] # each (..., num_pos_feats) - embeddings = [ - torch.stack((e[..., 0::2].sin(), e[..., 1::2].cos()), dim=-1).flatten(-2) for e in embeddings - ] # each (..., num_pos_feats) - - if len(embeddings) >= 2: - embeddings[0], embeddings[1] = embeddings[1], embeddings[0] - - return torch.cat(embeddings, dim=-1).to(pos_tensor.dtype) + dim = hidden_size // 2 + # Keep dim_t in float32 for numerical precision; cast output to match caller dtype + dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) + x_embed = pos_tensor[:, :, 0].float() * scale + y_embed = pos_tensor[:, :, 1].float() * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2].float() * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3].float() * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + # Cast back to the caller's dtype (supports bfloat16 / float16 AMP) + return pos.to(pos_tensor.dtype) class LWDETRDecoder(nn.Module): @@ -442,6 +460,7 @@ def __init__( dec_n_points: int = 2, group_detr: int = 13, dropout_prob: float = 0.0, + bbox_embed: nn.Module | None = None, ): super().__init__() self.dropout_prob = dropout_prob @@ -460,30 +479,102 @@ def __init__( for i in range(num_layers) ]) self.layernorm = nn.LayerNorm(self.d_model) + self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) + self.angle_proj = nn.Sequential( + nn.Linear(4, self.d_model), + nn.ReLU(), + nn.Linear(self.d_model, self.d_model), + ) - def get_reference(self, reference_points, valid_ratios): - # batch_size, num_queries, batch_size, 4 + def get_reference( + self, reference_points: torch.Tensor, valid_ratios: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """This function computes the reference point inputs and positional embeddings for the decoder layers. + + Args: + reference_points: (batch_size, num_queries, 6) + tensor containing the current reference points in the format (cx, cy, w, h, sinθ, cosθ) + valid_ratios: (batch_size, num_levels, 2) + tensor containing the valid ratios for each level of the input feature maps + + Returns: + reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6) + tensor containing the reference point inputs for the decoder layers, + which are the normalized center coordinates, + width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps + query_pos: (batch_size, num_queries, d_model) + tensor containing the positional embeddings for the decoder layers, + which are computed from the reference points using sine and cosine functions and a linear projection + """ obj_center = reference_points[..., :4] - - # batch_size, num_queries, num_levels, 4 - reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] - - # batch_size, num_queries, d_model * 2 - query_sine_embed = encode_sinusoidal_position_embedding( - reference_points_inputs[:, :, 0, :], num_pos_feats=self.d_model // 2 + spatial_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + # Extract angles + angle = reference_points[..., 4:6] # (sin, cos) + angle_expanded = angle[:, :, None] + reference_points_inputs = torch.cat([spatial_inputs, angle_expanded], dim=-1) + # DETR positional encoding + query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) + base_query_pos = self.ref_point_head(query_sine_embed) + # Angle embedding + sin_t = angle[..., 0:1] + cos_t = angle[..., 1:2] + + angle_feat = torch.cat( + [ + sin_t, + cos_t, + 2 * sin_t * cos_t, + cos_t**2 - sin_t**2, + ], + dim=-1, ) - # batch_size, num_queries, d_model - query_pos = self.ref_point_head(query_sine_embed) + angle_emb = self.angle_proj(angle_feat) + # Combine + query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos + def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + return torch.cat((cxcy, wh, rot), dim=-1) + def forward( self, inputs_embeds: torch.Tensor | None, reference_points: torch.Tensor, - spatial_shapes: torch.Tensor, spatial_shapes_list: torch.Tensor, valid_ratios: torch.Tensor, encoder_hidden_states: torch.Tensor, @@ -505,18 +596,34 @@ def forward( encoder_attention_mask=encoder_attention_mask, position_embeddings=query_pos, reference_points=reference_points_inputs, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, ) - intermediate_hidden_states = self.layernorm(hidden_states) - intermediate.append(intermediate_hidden_states) + hidden_states_norm = self.layernorm(hidden_states) + + # iterative refinement + if self.bbox_embed is not None: + delta = self.bbox_embed(hidden_states_norm) + + reference_points = self.refine_boxes( + reference_points.squeeze(2), + delta, + ) + intermediate_reference_points.append(reference_points) + + reference_points_inputs, query_pos = self.get_reference( + reference_points, + valid_ratios, + ) + + intermediate.append(hidden_states_norm) + + intermediate_stack = torch.stack(intermediate) + last_hidden_state = intermediate_stack[-1] - intermediate = torch.stack(intermediate) - last_hidden_state = intermediate[-1] - intermediate_reference_points = torch.stack(intermediate_reference_points) + intermediate_reference_points_stack = torch.stack(intermediate_reference_points) - return last_hidden_state, intermediate, intermediate_reference_points + return last_hidden_state, intermediate_stack, intermediate_reference_points_stack class MultiScaleProjector(nn.Module): diff --git a/doctr/models/layout/lw_detr/loss.py b/doctr/models/layout/lw_detr/loss.py deleted file mode 100644 index ba4364764f..0000000000 --- a/doctr/models/layout/lw_detr/loss.py +++ /dev/null @@ -1,527 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from scipy.optimize import linear_sum_assignment -from torch import Tensor - -__all__ = ["lw_detr_for_object_detection_loss"] - - -def center_to_corners_format(bboxes_center: "torch.Tensor") -> "torch.Tensor": - center_x, center_y, width, height = bboxes_center.unbind(-1) - bbox_corners = torch.stack( - # top left x, top left y, bottom right x, bottom right y - [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], - dim=-1, - ) - return bbox_corners - - -def dice_loss(inputs, targets, num_boxes): - """ - Compute the DICE loss, similar to generalized IOU for masks - - Args: - inputs: A float tensor of arbitrary shape. - The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary - classification label for each element in inputs (0 for the negative class and 1 for the positive - class). - num_boxes: Normalization factor, typically the number of target boxes in the batch. This is used to scale the - loss to an absolute value, and is used in the original implementation of DETR and LW-DETR. - It doesn't have to be - exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. - """ - inputs = inputs.sigmoid() - inputs = inputs.flatten(1) - numerator = 2 * (inputs * targets).sum(1) - denominator = inputs.sum(-1) + targets.sum(-1) - loss = 1 - (numerator + 1) / (denominator + 1) - return loss.sum() / num_boxes - - -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): - """ - Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002. - - Args: - inputs (`torch.FloatTensor` of arbitrary shape): - The predictions for each example. - targets (`torch.FloatTensor` with the same shape as `inputs`): - A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class - and 1 for the positive class). - num_boxes (`int`): - Normalization factor, typically the number of target boxes in the batch. This is used to scale the loss - to an absolute value, and is used in the original implementation of DETR and LW-DETR. It doesn't have to be - exactly the number of target boxes, but it should be correlated to it for the loss to be meaningful. - alpha (`float`, *optional*, defaults to `0.25`): - Optional weighting factor in the range (0,1) to balance positive vs. negative examples. - gamma (`int`, *optional*, defaults to `2`): - Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. - - Returns: - Loss tensor - """ - prob = inputs.sigmoid() - ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none") - # add modulating factor - p_t = prob * targets + (1 - prob) * (1 - targets) - loss = ce_loss * ((1 - p_t) ** gamma) - - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - loss = alpha_t * loss - - return loss.mean(1).sum() / num_boxes - - -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - -def box_area(boxes: Tensor) -> Tensor: - """ - Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. - - Args: - boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): - Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 - < x2` and `0 <= y1 < y2`. - - Returns: - `torch.FloatTensor`: a tensor containing the area for each box. - """ - boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) - - -# modified from torchvision to also return the union -def box_iou(boxes1, boxes2): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] - inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format. - - Returns: - `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): - raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") - if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): - raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") - iou, union = box_iou(boxes1, boxes2) - - top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2] - area = width_height[:, :, 0] * width_height[:, :, 1] - - return iou - (area - union) / area - - -# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306 -def _max_by_axis(the_list): - # type: (list[list[int]]) -> list[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor: - def __init__(self, tensors, mask: Tensor | None): - self.tensors = tensors - self.mask = mask - - def to(self, device): - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): - if tensor_list[0].ndim == 3: - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - batch_shape = [len(tensor_list)] + max_size - batch_size, num_channels, height, width = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((batch_size, height, width), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], : img.shape[2]] = False - else: - raise ValueError("Only 3-dimensional tensors are supported") - return NestedTensor(tensor, mask) - - -# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py -def _set_aux_loss(outputs_class, outputs_coord): - return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] - - -class LwDetrHungarianMatcher(nn.Module): - def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): - super().__init__() - - self.class_cost = class_cost - self.bbox_cost = bbox_cost - self.giou_cost = giou_cost - if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: - raise ValueError("All costs of the Matcher can't be 0") - - @torch.no_grad() - def forward(self, outputs, targets, group_detr): - """ - Differences: - - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax - - class_cost uses alpha and gamma - """ - batch_size, num_queries = outputs["logits"].shape[:2] - - # We flatten to compute the cost matrices in a batch - out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] - out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] - - # Also concat the target labels and boxes - target_ids = torch.cat([torch.as_tensor(v["class_labels"], dtype=torch.int64) for v in targets]).to( - out_prob.device - ) - target_bbox = torch.cat([torch.as_tensor(v["boxes"], dtype=torch.float32) for v in targets]).to(out_bbox.device) - - # Compute the classification cost. - alpha = 0.25 - gamma = 2.0 - neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log()) - pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) - class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids] - - # Compute the L1 cost between boxes, cdist only supports float32 - dtype = out_bbox.dtype - out_bbox = out_bbox.to(torch.float32) - target_bbox = target_bbox.to(torch.float32) - bbox_cost = torch.cdist(out_bbox, target_bbox, p=1) - bbox_cost = bbox_cost.to(dtype) - - # Compute the giou cost between boxes - giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox)) - - # Final cost matrix - cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost - cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() - - sizes = [len(v["boxes"]) for v in targets] - indices = [] - group_num_queries = num_queries // group_detr - cost_matrix_list = cost_matrix.split(group_num_queries, dim=1) - for group_id in range(group_detr): - group_cost_matrix = cost_matrix_list[group_id] - group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))] - if group_id == 0: - indices = group_indices - else: - indices = [ - ( - np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]), - np.concatenate([indice1[1], indice2[1]]), - ) - for indice1, indice2 in zip(indices, group_indices) - ] - return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] - - -class LwDetrImageLoss(nn.Module): - def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr): - super().__init__() - self.matcher = matcher - self.num_classes = num_classes - self.focal_alpha = focal_alpha - self.losses = losses - self.group_detr = group_detr - - # removed logging parameter, which was part of the original implementation - def loss_labels(self, outputs, targets, indices, num_boxes): - if "logits" not in outputs: - raise KeyError("No logits were found in the outputs") - source_logits = outputs["logits"] - dtype = source_logits.dtype - - idx = self._get_source_permutation_idx(indices) - target_classes_o = torch.cat([ - torch.as_tensor(np.atleast_1d(t["class_labels"][J]), dtype=torch.int64) - for t, (_, J) in zip(targets, indices) - ]).to(source_logits.device) - alpha = self.focal_alpha - gamma = 2 - src_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat( - [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], - dim=0, - ).to(src_boxes.device) - iou_targets = torch.diag( - box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))[0] - ) - # Convert to the same dtype as the source logits as box_iou upcasts to float32 - iou_targets = iou_targets.to(dtype) - pos_ious = iou_targets.clone().detach() - prob = source_logits.sigmoid() - # init positive weights and negative weights - pos_weights = torch.zeros_like(source_logits) - # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype - neg_weights = prob.pow(gamma).to(dtype) - pos_ind = idx + (target_classes_o,) - - pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) - pos_quality = torch.clamp(pos_quality, 0.01).detach().to(dtype) - - pos_weights[pos_ind] = pos_quality - neg_weights[pos_ind] = 1 - pos_quality - loss_ce = -pos_weights * prob.log() - neg_weights * (1 - prob).log() - loss_ce = loss_ce.sum() / num_boxes - losses = {"loss_ce": loss_ce} - - return losses - - @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ - Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. - - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients. - """ - logits = outputs["logits"] - device = logits.device - target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) - # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold) - card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1) - card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float()) - losses = {"cardinality_error": card_err} - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss.loss_boxes - def loss_boxes(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. - - Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes - are expected in format (center_x, center_y, w, h), normalized by the image size. - """ - if "pred_boxes" not in outputs: - raise KeyError("No predicted boxes found in outputs") - idx = self._get_source_permutation_idx(indices) - source_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat( - [torch.as_tensor(np.atleast_2d(t["boxes"][i]), dtype=torch.float32) for t, (_, i) in zip(targets, indices)], - dim=0, - ).to(source_boxes.device) - - loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none") - - losses = {} - losses["loss_bbox"] = loss_bbox.sum() / num_boxes - - loss_giou = 1 - torch.diag( - generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)) - ) - losses["loss_giou"] = loss_giou.sum() / num_boxes - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss.loss_masks - def loss_masks(self, outputs, targets, indices, num_boxes): - """ - Compute the losses related to the masks: the focal loss and the dice loss. - - Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]. - """ - if "pred_masks" not in outputs: - raise KeyError("No predicted masks found in outputs") - - source_idx = self._get_source_permutation_idx(indices) - target_idx = self._get_target_permutation_idx(indices) - source_masks = outputs["pred_masks"] - source_masks = source_masks[source_idx] - masks = [t["masks"] for t in targets] - # TODO use valid to mask invalid areas due to padding in loss - target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() - target_masks = target_masks.to(source_masks) - target_masks = target_masks[target_idx] - - # upsample predictions to the target size - source_masks = nn.functional.interpolate( - source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False - ) - source_masks = source_masks[:, 0].flatten(1) - - target_masks = target_masks.flatten(1) - target_masks = target_masks.view(source_masks.shape) - losses = { - "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes), - "loss_dice": dice_loss(source_masks, target_masks, num_boxes), - } - return losses - - # Copied from loss.loss_for_object_detection.ImageLoss._get_source_permutation_idx - def _get_source_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)]) - source_idx = torch.cat([source for (source, _) in indices]) - return batch_idx, source_idx - - # Copied from loss.loss_for_object_detection.ImageLoss._get_target_permutation_idx - def _get_target_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)]) - target_idx = torch.cat([target for (_, target) in indices]) - return batch_idx, target_idx - - def get_loss(self, loss, outputs, targets, indices, num_boxes): - loss_map = { - "labels": self.loss_labels, - "cardinality": self.loss_cardinality, - "boxes": self.loss_boxes, - "masks": self.loss_masks, - } - if loss not in loss_map: - raise ValueError(f"Loss {loss} not supported") - return loss_map[loss](outputs, targets, indices, num_boxes) - - def forward(self, outputs, targets): - """ - This performs the loss computation. - - Args: - outputs (`dict`, *optional*): - Dictionary of tensors, see the output specification of the model for the format. - targets (`list[dict]`, *optional*): - List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the - losses applied, see each loss' doc. - """ - group_detr = self.group_detr if self.training else 1 - outputs_without_aux_and_enc = { - k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs" - } - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr) - - # Compute the average number of target boxes across all nodes, for normalization purposes - num_boxes = sum(len(t["class_labels"]) for t in targets) - num_boxes = num_boxes * group_detr - num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - world_size = 1 - num_boxes = torch.clamp(num_boxes / world_size, min=1).item() - - # Compute all the requested losses - losses = {} - for loss in self.losses: - losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if "auxiliary_outputs" in outputs: - for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): - indices = self.matcher(auxiliary_outputs, targets, group_detr) - for loss in self.losses: - if loss == "masks": - # Intermediate masks losses are too costly to compute, we ignore them. - continue - l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes) - l_dict = {k + f"_{i}": v for k, v in l_dict.items()} - losses.update(l_dict) - - if "enc_outputs" in outputs: - enc_outputs = outputs["enc_outputs"] - indices = self.matcher(enc_outputs, targets, group_detr=group_detr) - for loss in self.losses: - l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes) - l_dict = {k + "_enc": v for k, v in l_dict.items()} - losses.update(l_dict) - - return losses - - -def lw_detr_for_object_detection_loss( - logits, - labels, - device, - pred_boxes, - outputs_class=None, - outputs_coord=None, - enc_outputs_class=None, - enc_outputs_coord=None, - use_aux_loss=False, - group_detr=1, - num_labels=None, - num_decoder_layers=None, - **kwargs, -): - """Loss computation for LW-DETR for object detection.""" - # First: create the matcher - matcher = LwDetrHungarianMatcher(class_cost=2.0, bbox_cost=5, giou_cost=2) - # Second: create the criterion - losses = ["labels", "boxes", "cardinality"] - criterion = LwDetrImageLoss( - matcher=matcher, - num_classes=num_labels, - focal_alpha=0.1, - losses=losses, - group_detr=group_detr, - ) - criterion.to(device) - # Third: compute the losses, based on outputs and labels - outputs_loss = {} - auxiliary_outputs = None - outputs_loss["logits"] = logits - outputs_loss["pred_boxes"] = pred_boxes - outputs_loss["enc_outputs"] = { - "logits": enc_outputs_class, - "pred_boxes": enc_outputs_coord, - } - if use_aux_loss: - auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord) - outputs_loss["auxiliary_outputs"] = auxiliary_outputs - loss_dict = criterion(outputs_loss, labels) - # Fourth: compute total loss, as a weighted sum of the various losses - weight_dict = {"loss_ce": 1, "loss_bbox": 5} - weight_dict["loss_giou"] = 2 - if use_aux_loss: - aux_weight_dict = {} - for i in range(num_decoder_layers - 1): - aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) - weight_dict.update(aux_weight_dict) - enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()} - weight_dict.update(enc_weight_dict) - loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict) - return loss, loss_dict, auxiliary_outputs diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 2ca0ecee4a..c8d378d27a 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -10,6 +10,7 @@ import numpy as np import torch +from scipy.optimize import linear_sum_assignment from torch import nn from torch.nn import functional as F @@ -18,7 +19,6 @@ from ...utils import load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor from .layers import LWDETRDecoder, LWDETRHead, LWDETRMultiscaleDeformableAttention, MultiScaleProjector -from .loss import lw_detr_for_object_detection_loss __all__ = ["LWDETR", "lw_detr_s", "lw_detr_m"] @@ -154,10 +154,10 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.0, - iou_thresh: float = 0.1, + score_thresh: float = 0.05, + iou_thresh: float = 0.05, d_model: int = 256, - num_queries: int = 130, + num_queries: int = 50, group_detr: int = 1, dec_layers: int = 3, sa_num_heads: int = 8, @@ -172,7 +172,7 @@ def __init__( super().__init__() self.class_names: list[str] = class_names - self.num_classes = len(self.class_names) + 1 # +1 for background class (NO OBJECT) + self.num_classes = len(self.class_names) + 1 # +1 for background class self.cfg = cfg self.exportable = exportable self.assume_straight_pages = assume_straight_pages @@ -182,13 +182,22 @@ def __init__( self.group_detr = group_detr self.num_queries = num_queries self.d_model = d_model - self.dec_layers = dec_layers - self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4) + self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) + # Initialize angle to (sin=0, cos=1) + with torch.no_grad(): + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) + self.reference_point_embed.weight[:, 2:4].fill_(0.1) + self.reference_point_embed.weight[:, 4].zero_() + self.reference_point_embed.weight[:, 5].fill_(1.0) + self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) + self.class_embed = nn.Linear(self.d_model, self.num_classes) + self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) + self.decoder = LWDETRDecoder( - num_layers=self.dec_layers, + num_layers=dec_layers, d_model=d_model, sa_num_heads=sa_num_heads, ca_num_heads=ca_num_heads, @@ -196,21 +205,19 @@ def __init__( dec_n_points=dec_n_points, group_detr=group_detr, dropout_prob=dropout_prob, + bbox_embed=self.bbox_embed, ) self.enc_output = nn.ModuleList([nn.Linear(self.d_model, self.d_model) for _ in range(self.group_detr)]) self.enc_output_norm = nn.ModuleList([nn.LayerNorm(self.d_model) for _ in range(self.group_detr)]) self.enc_out_bbox_embed = nn.ModuleList([ - LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) for _ in range(self.group_detr) + LWDETRHead(self.d_model, self.d_model, 6, num_layers=3) for _ in range(self.group_detr) ]) self.enc_out_class_embed = nn.ModuleList([ nn.Linear(self.d_model, self.num_classes) for _ in range(self.group_detr) ]) - self.class_embed = nn.Linear(self.d_model, self.num_classes) - self.bbox_embed = LWDETRHead(self.d_model, self.d_model, 4, num_layers=3) - self.postprocessor = LWDETRPostProcessor( num_classes=self.num_classes, score_thresh=score_thresh, @@ -222,9 +229,28 @@ def __init__( # Don't override the initialization of the backbone if n.startswith("feat_extractor."): continue - if isinstance(m, LWDETRMultiscaleDeformableAttention): + + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + if hasattr(m, "weight") and m.weight is not None: + nn.init.ones_(m.weight) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + # Don't overwrite the carefully seeded reference point embedding + if m is not self.reference_point_embed: + nn.init.normal_(m.weight, std=0.02) + elif isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) - thetas = torch.arange(m.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / m.n_heads) + + thetas = torch.arange(m.n_heads, dtype=torch.float32) * (2.0 * math.pi / m.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) @@ -235,21 +261,29 @@ def __init__( grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) + nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) nn.init.xavier_uniform_(m.value_proj.weight) - nn.init.constant_(m.value_proj.bias, 0.0) + nn.init.zeros_(m.value_proj.bias) nn.init.xavier_uniform_(m.output_proj.weight) - nn.init.constant_(m.output_proj.bias, 0.0) - if hasattr(m, "refpoint_embed") and m.refpoint_embed is not None: - nn.init.constant_(m.refpoint_embed.weight, 0) - if hasattr(m, "class_embed") and m.class_embed is not None: - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) - nn.init.constant_(m.class_embed.bias, bias_value) - if hasattr(m, "bbox_embed") and m.bbox_embed is not None: - nn.init.constant_(m.bbox_embed.layers[-1].weight, 0) - nn.init.constant_(m.bbox_embed.layers[-1].bias, 0) + nn.init.zeros_(m.output_proj.bias) + if isinstance(m, nn.Linear) and m.out_features == self.num_classes: + if m.bias is not None: + with torch.no_grad(): + # Focal-loss prior: foreground starts with low confidence (~0.01), + # preventing background from dominating gradients at the start of training. + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(m.bias, 0.0) + m.bias[:-1].fill_(bias_value) + if isinstance(m, LWDETRHead): + last = m.layers[-1] + if isinstance(last, nn.Linear): + nn.init.zeros_(last.weight) + nn.init.zeros_(last.bias) + if last.bias.shape[0] == 6: + nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -260,39 +294,76 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points, deltas): + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: + + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ + + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ reference_points = reference_points.to(deltas.device) - new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2] - new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:] - new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1) - return new_reference_points + cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + # size + wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + # rotation + delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_delta = delta_rot[..., 0:1] + cos_delta = delta_rot[..., 1:2] + sin_ref = reference_points[..., 4:5] + cos_ref = reference_points[..., 5:6] + # compose rotations + sin_new = sin_ref * cos_delta + cos_ref * sin_delta + cos_new = cos_ref * cos_delta - sin_ref * sin_delta + rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) + return torch.cat((cxcy, wh, rot), dim=-1) + + def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Get the valid ratio of all feature maps. + + Args: + mask: (N, H, W) binary tensor containing 1 on padded pixels + dtype: the desired data type of the output tensor - def get_valid_ratio(self, mask, dtype=torch.float32): - """Get the valid ratio of all feature maps.""" + Returns: + valid_ratio: (N, 2) tensor containing the valid ratio of width and height for each image in the batch + """ _, height, width = mask.shape - valid_height = torch.sum(mask[:, :, 0], 1) - valid_width = torch.sum(mask[:, 0, :], 1) + valid_height = torch.sum(~mask[:, :, 0], 1) + valid_width = torch.sum(~mask[:, 0, :], 1) valid_ratio_height = valid_height.to(dtype) / height valid_ratio_width = valid_width.to(dtype) / width valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1) return valid_ratio - def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes): + def gen_encoder_output_proposals( + self, enc_output: torch.Tensor, padding_mask: torch.Tensor, spatial_shapes: list[tuple[int, int]] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the encoder output proposals from encoded enc_output. Args: - enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder. - padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. - spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps. + enc_output: Output of the encoder + padding_mask: Padding mask for `enc_output` + spatial_shapes: Spatial shapes of the feature maps Returns: - `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction. - - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to - directly predict a bounding box. (without the need of a decoder) - - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals in [0, 1] space. - Invalid positions (padding or out-of-bounds) are filled with 0. - - invalid_mask (Tensor[batch_size, sequence_length, 1]): Boolean mask that is True for invalid positions - (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). + A tuple of feature map and bbox prediction. + - object_query: Object query features. Later used to directly predict a bounding box. + - output_proposals: Normalized proposals in [0, 1] space. + Invalid positions (padding or out-of-bounds) are filled with 0. + - invalid_mask: Boolean mask that is True for invalid positions + (padded pixels or proposals whose coordinates fall outside (0.01, 0.99)). """ batch_size = enc_output.shape[0] proposals = [] @@ -324,17 +395,23 @@ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes) scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2) grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale width_height = torch.ones_like(grid) * 0.05 * (2.0**level) - proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4) + # add default rotation (sin=0, cos=1) + sin = torch.zeros_like(grid[..., :1]) + cos = torch.ones_like(grid[..., :1]) + proposal = torch.cat((grid, width_height, sin, cos), -1).view(batch_size, -1, 6) proposals.append(proposal) _cur += height * width output_proposals = torch.cat(proposals, 1) - output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + + spatial_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) # assign each pixel as an object query object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, float(0)) + object_query = object_query.masked_fill(invalid_mask, 0.0) + return object_query, output_proposals, invalid_mask def forward( @@ -379,7 +456,6 @@ def forward( mask_flatten_list.append(mask) source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) - spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in feats_masks], 1) tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) @@ -392,73 +468,65 @@ def forward( group_detr = self.group_detr if self.training else 1 topk = self.num_queries - topk_coords_logits = [] - topk_coords_logits_undetach = [] - object_query_undetach = [] + topk_coords_logits_list: list[torch.Tensor] = [] + topk_content_list: list[torch.Tensor] = [] + + # encoder predictions for auxiliary losses + all_group_enc_logits: list[torch.Tensor] = [] + all_group_enc_coords: list[torch.Tensor] = [] for group_id in range(group_detr): group_object_query = self.enc_output[group_id](object_query_embedding) group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - group_enc_outputs_class = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + all_group_enc_logits.append(group_enc_outputs_class) + + group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) + group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) - group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1] + all_group_enc_coords.append(group_enc_outputs_coord) + + scores = group_enc_outputs_class_masked[..., :-1].max(-1).values + group_topk_proposals = torch.topk(scores, topk, dim=1)[1] + group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, - group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4), + group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) group_topk_coords_logits = group_topk_coords_logits_undetach.detach() - group_object_query_undetach = torch.gather( - group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model) + topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_content = torch.gather( + group_object_query, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), ) + topk_content_list.append(group_topk_content) - topk_coords_logits.append(group_topk_coords_logits) - topk_coords_logits_undetach.append(group_topk_coords_logits_undetach) - object_query_undetach.append(group_object_query_undetach) - - topk_coords_logits = torch.cat(topk_coords_logits, 1) - topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1) - object_query_undetach = torch.cat(object_query_undetach, 1) + topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + reference_points = topk_coords_logits - enc_outputs_class_logits = object_query_undetach - enc_outputs_boxes_logits = topk_coords_logits_undetach + topk_content = torch.cat(topk_content_list, 1).detach() + tgt = tgt + topk_content - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + encoder_attention_mask = mask_flatten - init_reference_points = reference_points - last_hidden_state, intermediate, intermediate_reference_points = self.decoder( + last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, - spatial_shapes=spatial_shapes, spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, - encoder_attention_mask=mask_flatten, + encoder_attention_mask=encoder_attention_mask, ) - logits = self.class_embed(last_hidden_state) - pred_boxes_delta = self.bbox_embed(last_hidden_state) + logits = self.class_embed(last_hidden_states) + pred_boxes_delta = self.bbox_embed(last_hidden_states) pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) - enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.num_queries, dim=1) - pred_class = [] - group_detr = self.group_detr if self.training else 1 - for group_index in range(group_detr): - group_pred_class = self.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index]) - pred_class.append(group_pred_class) - enc_outputs_class_logits = torch.cat(pred_class, dim=1) - - if target is not None: - outputs_class, outputs_coord = None, None - intermediate_hidden_states = intermediate - outputs_coord_delta = self.bbox_embed(intermediate_hidden_states) - outputs_coord = self.refine_bboxes(intermediate_reference_points, outputs_coord_delta) - outputs_class = self.class_embed(intermediate_hidden_states) - out: dict[str, Any] = {} if self.exportable: @@ -480,44 +548,200 @@ def _postprocess(logits, boxes): if target is not None: # Build target processed_targets = self.build_target(target, self.class_names) - out["loss"] = self.compute_loss( - logits, - processed_targets, - pred_boxes, - outputs_class, - outputs_coord, - enc_outputs_class_logits, - enc_outputs_boxes_logits, - ) + + # Main loss from final decoder layer (group DETR) + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) + + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss = main_loss / group_detr + + # Auxiliary losses from intermediate decoder layers (group DETR) + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = self.refine_bboxes(intermediate_reference_points[i + 1], aux_boxes_delta) + + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) + loss += aux_loss + + # Auxiliary losses for encoder proposals + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) + loss += enc_loss + + out["loss"] = loss return out def compute_loss( self, - logits, - targets, - pred_boxes, - outputs_class, - outputs_coord, - enc_outputs_class_logits, - enc_outputs_boxes_logits, - ): - - loss_calc = lw_detr_for_object_detection_loss( - logits=logits, - device=logits.device, - labels=targets, - pred_boxes=pred_boxes, - outputs_class=outputs_class, - outputs_coord=outputs_coord, - enc_outputs_class=enc_outputs_class_logits, - enc_outputs_coord=enc_outputs_boxes_logits, - use_aux_loss=True, - group_detr=self.group_detr, - num_decoder_layers=self.dec_layers, - num_labels=self.num_classes, - ) - return loss_calc[0] + logits: torch.Tensor, + pred_boxes: torch.Tensor, + targets: list[dict[str, np.ndarray]], + ) -> torch.Tensor: + + def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + to Gaussian distributions (mean and covariance). + The mean is simply (cx, cy), and the covariance is computed from the width, height, and rotation angle. + + Args: + boxes: (N, S, 6) tensor containing the rotated boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A tuple of (mean, covariance) where: + - mean is a (N, S, 2) tensor containing the mean (cx, cy) of the Gaussian distributions + - covariance is a (N, S, 2, 2) tensor containing the covariance matrices of the Gaussian distributions + """ + cxcy = boxes[..., :2] + + w = boxes[..., 2].clamp(min=1e-6) + h = boxes[..., 3].clamp(min=1e-6) + + sin = boxes[..., 4] + cos = boxes[..., 5] + + R = torch.stack( + [ + torch.stack([cos, -sin], dim=-1), + torch.stack([sin, cos], dim=-1), + ], + dim=-2, + ) + + # Variance for a box half-width/half-height: σ² = (w/2)² + # Using w²/12 (uniform distribution) produces ~8x smaller variance, + # which collapses Bhattacharyya distance to the clamp ceiling and kills gradients. + sx = (w / 2) ** 2 + sy = (h / 2) ** 2 + + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) + S[..., 0, 0] = sx + S[..., 1, 1] = sy + + covariance = R @ S @ R.transpose(-1, -2) + return cxcy, covariance + + def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: + mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) + mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) + + delta = (mu1 - mu2).unsqueeze(-1) + sigma = (sigma1 + sigma2) * 0.5 + + eps = 1e-6 + eye = torch.eye(2, device=sigma.device) * eps + + sigma_safe = sigma + eye + sigma1_safe = sigma1 + eye + sigma2_safe = sigma2 + eye + + sigma_inv = torch.linalg.inv(sigma_safe) + + mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) + + det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) + + bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + + bhattacharyya = torch.clamp(bhattacharyya, min=0.0, max=10.0) + probiou = torch.exp(-bhattacharyya) + return 1 - probiou + + device = logits.device + B, Q, C = logits.shape + + total_cls = torch.tensor(0.0, device=device) + total_box = torch.tensor(0.0, device=device) + + # FIX (issue #7): track total matched boxes across the batch for proper normalisation. + # Classification loss is still normalised by B (it covers all Q queries per image), + # but box/rotation losses are normalised by the actual number of matched pairs. + num_matched_total = 0 + + for b in range(B): + pred_logits = logits[b] + pred_boxes_b = pred_boxes[b] + + boxes = targets[b]["boxes"] + + if len(boxes) == 0: + # Penalize the model for any foreground boxes it guessed on this empty image + background_idx = self.num_classes - 1 + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + total_cls += F.cross_entropy(pred_logits, target_classes) + continue + + tgt_boxes = torch.as_tensor(boxes, device=device, dtype=pred_boxes.dtype) + tgt_cls = torch.as_tensor(targets[b]["labels"], device=device, dtype=torch.long) + + if tgt_boxes.ndim == 1: + tgt_boxes = tgt_boxes.unsqueeze(0) + + pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) + + with torch.no_grad(): + out_logprob = pred_logits.log_softmax(-1) + + cost_cls = -out_logprob[:, tgt_cls] # stable + cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) + cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) + + total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot + + cost_np = total_cost.detach().cpu().numpy() + row_ind, col_ind = linear_sum_assignment(cost_np) + + pos_idx = torch.as_tensor(row_ind, device=device) + gt_idx = torch.as_tensor(col_ind, device=device) + + background_idx = self.num_classes - 1 + + target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) + target_classes[pos_idx] = tgt_cls[gt_idx] + + cls_weights = torch.ones(self.num_classes, device=device) + cls_weights[background_idx] = 1.0 + + total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) + + if pos_idx.numel() == 0: + continue + + num_matched_total += pos_idx.numel() + + pred_sel = pred_boxes_b[pos_idx] + tgt_sel = tgt_boxes[gt_idx] + + # Smooth L1 (Huber) loss with beta=0.1: behaves like L2 for large errors early in + # training (gentle, stable gradient) and like L1 for small errors later (sharp + # localisation). Raw L1 with a large weight causes loss explosion when predictions + # are far from targets (e.g. 5 * 2.2 = 11.0 per pair vs cls ~2.0). + l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) + probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() + total_box += 5.0 * l1_loss + 2.0 * probiou_loss + + # FIX (issue #7): normalise box loss by total matched boxes (min 1), not by batch size. + # This prevents images with many GT boxes from dominating over images with few. + # Normalise box loss by total matched boxes, with a floor of 1 to keep + # the box/cls loss ratio stable even when images have very few GT boxes. + num_matched_total = max(num_matched_total, 1) + + loss_cls = total_cls / B + loss_box = total_box / num_matched_total + + return loss_cls + loss_box def _lw_detr( From e2491aa667a89b62161d96e4e089ab1355f22f2d Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 15:24:41 +0200 Subject: [PATCH 3/9] update --- doctr/models/layout/lw_detr/layers/pytorch.py | 5 +- doctr/models/layout/lw_detr/pytorch.py | 67 ++++++++++--------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index bb6b503e97..e9f319e664 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -605,10 +605,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) + reference_points = self.refine_boxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index c8d378d27a..e813becd0e 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -154,11 +154,11 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.05, - iou_thresh: float = 0.05, + score_thresh: float = 0.3, + iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 50, - group_detr: int = 1, + num_queries: int = 300, + group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, @@ -403,7 +403,7 @@ def gen_encoder_output_proposals( _cur += height * width output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -507,12 +507,8 @@ def forward( topk_content_list.append(group_topk_content) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = topk_coords_logits - topk_content = torch.cat(topk_content_list, 1).detach() - tgt = tgt + topk_content - - encoder_attention_mask = mask_flatten + reference_points = self.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, @@ -520,12 +516,11 @@ def forward( spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) - pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + pred_boxes = intermediate_reference_points[-1] out: dict[str, Any] = {} @@ -561,8 +556,7 @@ def _postprocess(logits, boxes): # Auxiliary losses from intermediate decoder layers (group DETR) for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i + 1], aux_boxes_delta) + aux_boxes = intermediate_reference_points[i + 1] split_aux_logits = aux_logits.chunk(group_detr, dim=1) split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) @@ -570,13 +564,13 @@ def _postprocess(logits, boxes): aux_loss: float | torch.Tensor = 0.0 for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += aux_loss + loss += aux_loss / group_detr # Auxiliary losses for encoder proposals enc_loss: float | torch.Tensor = 0.0 for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += enc_loss + loss += enc_loss / group_detr out["loss"] = loss @@ -588,6 +582,18 @@ def compute_loss( pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: + """Compute the loss between predicted logits and boxes and target labels and boxes. + + Args: + logits: (N, S, C) tensor containing the predicted class logits for each query + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length N, where each element is a dict with keys "labels" and "boxes", + containing the ground truth labels and boxes for each image in the batch. + The boxes are in (cx, cy, w, h, sinθ, cosθ) format. + + Returns: + A scalar tensor containing the computed loss. + """ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format @@ -631,6 +637,17 @@ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch return cxcy, covariance def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: + """Compute the ProbIoU loss between predicted and target boxes, + where boxes are represented as Gaussian distributions. + The ProbIoU loss is defined as 1 - exp(-Bhattacharyya distance), + where the Bhattacharyya distance is computed between the two Gaussian distributions. + + Args: + pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + tgt_boxes: (N, S, 6) tensor containing the target boxes in (cx, cy, w, h, sinθ, cosθ) format + Returns: + A (N, S) tensor containing the ProbIoU loss for each pair of predicted and target boxes + """ mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) @@ -664,9 +681,6 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te total_cls = torch.tensor(0.0, device=device) total_box = torch.tensor(0.0, device=device) - # FIX (issue #7): track total matched boxes across the batch for proper normalisation. - # Classification loss is still normalised by B (it covers all Q queries per image), - # but box/rotation losses are normalised by the actual number of matched pairs. num_matched_total = 0 for b in range(B): @@ -694,7 +708,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te with torch.no_grad(): out_logprob = pred_logits.log_softmax(-1) - cost_cls = -out_logprob[:, tgt_cls] # stable + cost_cls = -out_logprob[:, tgt_cls] cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) @@ -712,7 +726,7 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te target_classes[pos_idx] = tgt_cls[gt_idx] cls_weights = torch.ones(self.num_classes, device=device) - cls_weights[background_idx] = 1.0 + cls_weights[background_idx] = 0.1 total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) @@ -724,20 +738,11 @@ def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Te pred_sel = pred_boxes_b[pos_idx] tgt_sel = tgt_boxes[gt_idx] - # Smooth L1 (Huber) loss with beta=0.1: behaves like L2 for large errors early in - # training (gentle, stable gradient) and like L1 for small errors later (sharp - # localisation). Raw L1 with a large weight causes loss explosion when predictions - # are far from targets (e.g. 5 * 2.2 = 11.0 per pair vs cls ~2.0). l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() total_box += 5.0 * l1_loss + 2.0 * probiou_loss - # FIX (issue #7): normalise box loss by total matched boxes (min 1), not by batch size. - # This prevents images with many GT boxes from dominating over images with few. - # Normalise box loss by total matched boxes, with a floor of 1 to keep - # the box/cls loss ratio stable even when images have very few GT boxes. num_matched_total = max(num_matched_total, 1) - loss_cls = total_cls / B loss_box = total_box / num_matched_total From 4774d9fede460bcbfce0f9307cfb3cfe3b2df756 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 28 May 2026 15:28:14 +0200 Subject: [PATCH 4/9] update --- doctr/models/layout/lw_detr/base.py | 11 ----------- doctr/models/layout/lw_detr/pytorch.py | 1 - 2 files changed, 12 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 1a7cb452cf..8c461c9fc2 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -301,17 +301,6 @@ def to_quad(box: np.ndarray): if boxes.ndim == 1: boxes = boxes[None, :] - # Sanity check: coordinates must be in [0, 1] normalized space. - # Values > 1.5 almost certainly indicate pixel coordinates were passed in. - flat = boxes.ravel() - coord_vals = flat[flat > 0] - if len(coord_vals) > 0 and coord_vals.max() > 1.5: - raise ValueError( - f"build_target expects normalized [0, 1] box coordinates, " - f"but found values up to {coord_vals.max():.1f} for class '{class_name}'. " - f"Divide your coordinates by image width/height before calling build_target." - ) - for box in boxes: poly = to_quad(box) obb = _quad_to_obb(poly) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index e813becd0e..d760504378 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -244,7 +244,6 @@ def __init__( if hasattr(m, "bias") and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): - # Don't overwrite the carefully seeded reference point embedding if m is not self.reference_point_embed: nn.init.normal_(m.weight, std=0.02) elif isinstance(m, LWDETRMultiscaleDeformableAttention): From 1fa67955691b28f49793dde2a430a6b5a0ba356c Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 09:21:27 +0200 Subject: [PATCH 5/9] loss and model fixes --- doctr/models/layout/lw_detr/base.py | 19 +- doctr/models/layout/lw_detr/layers/pytorch.py | 42 +- doctr/models/layout/lw_detr/pytorch.py | 379 +++++++++--------- 3 files changed, 218 insertions(+), 222 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 8c461c9fc2..9c6890e6b0 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -144,18 +144,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int for b in range(boxes.shape[0]): # Convert logits to probabilities and get scores and labels - exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) - prob = exp / exp.sum(axis=-1, keepdims=True) + prob = 1.0 / (1.0 + np.exp(-logits[b])) scores = prob.max(axis=-1) labels = prob.argmax(axis=-1) - # treat background as invalid prediction - bg = self.num_classes - 1 - valid = labels != bg - - scores = scores * valid - # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: idxs = np.argpartition(-scores, self.topk)[: self.topk] @@ -167,11 +160,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int labels_b = labels[idxs] bboxes = boxes[b][idxs] - mask = scores_b > self.score_thresh - - bboxes = bboxes[mask] - scores_b = scores_b[mask] - labels_b = labels_b[mask] + # Filter by score threshold + thresh_mask = scores_b >= self.score_thresh + scores_b = scores_b[thresh_mask] + labels_b = labels_b[thresh_mask] + bboxes = bboxes[thresh_mask] polys, _ = ( self._decode_boxes(bboxes) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index e9f319e664..f0bc63c865 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -293,6 +293,8 @@ def forward( else: raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") + # clamp sampling locations to keep them within the valid range [0, 1] for grid sampling + sampling_locations = sampling_locations.clamp(0.0, 1.0) output = self.attn( value, spatial_shapes_list, @@ -482,11 +484,7 @@ def __init__( self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Sequential( - nn.Linear(4, self.d_model), - nn.ReLU(), - nn.Linear(self.d_model, self.d_model), - ) + self.angle_proj = nn.Linear(2, self.d_model) def get_reference( self, reference_points: torch.Tensor, valid_ratios: torch.Tensor @@ -517,26 +515,13 @@ def get_reference( # DETR positional encoding query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding - sin_t = angle[..., 0:1] - cos_t = angle[..., 1:2] - - angle_feat = torch.cat( - [ - sin_t, - cos_t, - 2 * sin_t * cos_t, - cos_t**2 - sin_t**2, - ], - dim=-1, - ) - angle_emb = self.angle_proj(angle_feat) + angle_emb = self.angle_proj(angle) # Combine query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: """Refine bounding boxes by applying the predicted deltas to the reference points. The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. The refined boxes are computed as follows: @@ -556,11 +541,20 @@ def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> refined_boxes: (N, S, 6) tensor containing the refined bounding boxes """ reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] # size - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] + wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=4.0).exp() * reference_points[..., 2:4] + wh = wh.clamp(min=1e-4, max=1.0) + # center + raw_cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + half_wh = wh / 2 + cxcy = raw_cxcy.clamp( + min=half_wh, + max=1.0 - half_wh, + ) # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + sin_d = deltas[..., 4:5] + cos_d = deltas[..., 5:6] + 1.0 + delta_rot = F.normalize(torch.cat([sin_d, cos_d], dim=-1), dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] @@ -605,7 +599,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes(reference_points, delta) + reference_points = self.refine_bboxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index d760504378..57331c8c85 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -154,17 +154,17 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, + score_thresh: float = 0.5, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 300, + num_queries: int = 100, group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, ff_dim: int = 2048, dec_n_points: int = 2, - dropout_prob: float = 0.0, + dropout_prob: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: dict[str, Any] | None = None, @@ -186,10 +186,11 @@ def __init__( self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) # Initialize angle to (sin=0, cos=1) with torch.no_grad(): - self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) - self.reference_point_embed.weight[:, 2:4].fill_(0.1) - self.reference_point_embed.weight[:, 4].zero_() - self.reference_point_embed.weight[:, 5].fill_(1.0) + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) # cx, cy + self.reference_point_embed.weight[:, 2].uniform_(0.1, 0.6) # w + self.reference_point_embed.weight[:, 3].uniform_(0.02, 0.3) # h + self.reference_point_embed.weight[:, 4].zero_() # sinθ + self.reference_point_embed.weight[:, 5].fill_(1.0) # cosθ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -281,8 +282,6 @@ def __init__( if isinstance(last, nn.Linear): nn.init.zeros_(last.weight) nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -293,41 +292,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-10.0, max=10.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Get the valid ratio of all feature maps. @@ -402,7 +366,7 @@ def gen_encoder_output_proposals( _cur += height * width output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) + spatial_valid = ((output_proposals[..., :2] > 0.01) & (output_proposals[..., :2] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) @@ -484,7 +448,7 @@ def forward( group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) - group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) + group_enc_outputs_coord = self.decoder.refine_bboxes(output_proposals, group_delta_bbox) all_group_enc_coords.append(group_enc_outputs_coord) @@ -507,7 +471,7 @@ def forward( topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + reference_points = self.decoder.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, @@ -543,33 +507,19 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer (group DETR) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss = main_loss / group_detr + # Main loss from final decoder layer + loss = self.compute_loss(logits, pred_boxes, processed_targets) - # Auxiliary losses from intermediate decoder layers (group DETR) + # Auxiliary losses from intermediate decoder layers for i in range(intermediate.shape[0] - 1): aux_logits = self.class_embed(intermediate[i]) aux_boxes = intermediate_reference_points[i + 1] - - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += aux_loss / group_detr + loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) # Auxiliary losses for encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += enc_loss / group_detr + enc_logits = torch.cat(all_group_enc_logits, dim=1) + enc_coords = torch.cat(all_group_enc_coords, dim=1) + loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) out["loss"] = loss @@ -581,36 +531,41 @@ def compute_loss( pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: - """Compute the loss between predicted logits and boxes and target labels and boxes. + """Compute the loss using Grouped Hungarian Matching + and consistent ProbIoU semantics for rotated bounding boxes. Args: - logits: (N, S, C) tensor containing the predicted class logits for each query - pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format - targets: list of length N, where each element is a dict with keys "labels" and "boxes", + logits: (B, Q, C) tensor containing the predicted class logits for each query + pred_boxes: (B, Q, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length B, where each element is a dict with keys "labels" and "boxes", containing the ground truth labels and boxes for each image in the batch. - The boxes are in (cx, cy, w, h, sinθ, cosθ) format. Returns: A scalar tensor containing the computed loss. """ + device = logits.device + dtype = logits.dtype + B, Q, C = logits.shape + + # Consistent coefficients across matcher and loss components + class_weight = 2.0 + bbox_weight = 5.0 + probiou_weight = 2.0 + rot_weight = 0.5 + + # Focal Loss Params + alpha = 0.25 + gamma = 2.0 + eps = 1e-7 + + group_detr = getattr(self, "group_detr", 1) def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format - to Gaussian distributions (mean and covariance). - The mean is simply (cx, cy), and the covariance is computed from the width, height, and rotation angle. - - Args: - boxes: (N, S, 6) tensor containing the rotated boxes in (cx, cy, w, h, sinθ, cosθ) format - Returns: - A tuple of (mean, covariance) where: - - mean is a (N, S, 2) tensor containing the mean (cx, cy) of the Gaussian distributions - - covariance is a (N, S, 2, 2) tensor containing the covariance matrices of the Gaussian distributions - """ + """Convert rotated boxes to Gaussian distributions using the true + variance of a uniform continuous rectangle (w^2 / 12).""" cxcy = boxes[..., :2] - w = boxes[..., 2].clamp(min=1e-6) h = boxes[..., 3].clamp(min=1e-6) - sin = boxes[..., 4] cos = boxes[..., 5] @@ -622,130 +577,184 @@ def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch dim=-2, ) - # Variance for a box half-width/half-height: σ² = (w/2)² - # Using w²/12 (uniform distribution) produces ~8x smaller variance, - # which collapses Bhattacharyya distance to the clamp ceiling and kills gradients. - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 + sx = (w**2) / 12.0 + sy = (h**2) / 12.0 - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device, dtype=boxes.dtype) S[..., 0, 0] = sx S[..., 1, 1] = sy covariance = R @ S @ R.transpose(-1, -2) return cxcy, covariance - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted and target boxes, - where boxes are represented as Gaussian distributions. - The ProbIoU loss is defined as 1 - exp(-Bhattacharyya distance), - where the Bhattacharyya distance is computed between the two Gaussian distributions. - - Args: - pred_boxes: (N, S, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format - tgt_boxes: (N, S, 6) tensor containing the target boxes in (cx, cy, w, h, sinθ, cosθ) format - Returns: - A (N, S) tensor containing the ProbIoU loss for each pair of predicted and target boxes - """ - mu1, sigma1 = _rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = _rotated_boxes_to_gaussian(tgt_boxes) - + def _bhattacharyya_distance( + mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor + ) -> torch.Tensor: + """Compute Bhattacharyya distance with broadcast support.""" delta = (mu1 - mu2).unsqueeze(-1) sigma = (sigma1 + sigma2) * 0.5 - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps - + eye = torch.eye(2, device=sigma.device, dtype=sigma.dtype) * 1e-6 sigma_safe = sigma + eye sigma1_safe = sigma1 + eye sigma2_safe = sigma2 + eye - sigma_inv = torch.linalg.inv(sigma_safe) + L = torch.linalg.cholesky(sigma_safe) + sigma_inv = torch.cholesky_inverse(L) mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) + det_sigma = torch.linalg.det(sigma_safe).clamp(min=1e-6) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=1e-6) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=1e-6) bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + return bhattacharyya.clamp(min=0.0) + + # Prepare targets for matching + target_labels = [] + target_boxes = [] + sizes = [] + for t in targets: + lbls = torch.as_tensor(t["labels"], device=device, dtype=torch.long) + bxs = torch.as_tensor(t["boxes"], device=device, dtype=pred_boxes.dtype) + if bxs.ndim == 1 and bxs.numel() > 0: + bxs = bxs.unsqueeze(0) + target_labels.append(lbls) + target_boxes.append(bxs) + sizes.append(len(lbls)) + + # Unified formulation for empty batches + if sum(sizes) == 0: + prob = logits.sigmoid() + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + neg_weights = prob.pow(gamma) + loss_ce = -neg_weights * (1.0 - prob_safe).log() + return class_weight * (loss_ce.sum() / (B * Q)) + + tgt_ids = torch.cat(target_labels) + tgt_bbox = torch.cat(target_boxes) + + # Matcher: Grouped Hungarian Assignment with a balanced cost matrix + with torch.no_grad(): + out_prob = logits.flatten(0, 1).sigmoid() + out_bbox = pred_boxes.flatten(0, 1) - bhattacharyya = torch.clamp(bhattacharyya, min=0.0, max=10.0) - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - - device = logits.device - B, Q, C = logits.shape - - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - - num_matched_total = 0 - - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] - - boxes = targets[b]["boxes"] - - if len(boxes) == 0: - # Penalize the model for any foreground boxes it guessed on this empty image - background_idx = self.num_classes - 1 - target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) - total_cls += F.cross_entropy(pred_logits, target_classes) - continue - - tgt_boxes = torch.as_tensor(boxes, device=device, dtype=pred_boxes.dtype) - tgt_cls = torch.as_tensor(targets[b]["labels"], device=device, dtype=torch.long) - - if tgt_boxes.ndim == 1: - tgt_boxes = tgt_boxes.unsqueeze(0) - - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - - with torch.no_grad(): - out_logprob = pred_logits.log_softmax(-1) - - cost_cls = -out_logprob[:, tgt_cls] - cost_l1 = torch.cdist(pred_boxes_b[:, :4], tgt_boxes[:, :4], p=1) - cost_rot = 1.0 - torch.abs(pred_rot @ tgt_rot.T) - - total_cost = 2.0 * cost_cls + 5.0 * cost_l1 + 2.0 * cost_rot - - cost_np = total_cost.detach().cpu().numpy() - row_ind, col_ind = linear_sum_assignment(cost_np) - - pos_idx = torch.as_tensor(row_ind, device=device) - gt_idx = torch.as_tensor(col_ind, device=device) - - background_idx = self.num_classes - 1 - - target_classes = torch.full((Q,), background_idx, device=device, dtype=torch.long) - target_classes[pos_idx] = tgt_cls[gt_idx] - - cls_weights = torch.ones(self.num_classes, device=device) - cls_weights[background_idx] = 0.1 - - total_cls += F.cross_entropy(pred_logits, target_classes, weight=cls_weights) - - if pos_idx.numel() == 0: - continue + # Classification Cost (Focal Loss based) + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + eps).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + eps).log()) + class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - num_matched_total += pos_idx.numel() + # Box L1 Cost + out_bbox_f = out_bbox.to(torch.float32) + tgt_bbox_f = tgt_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox_f[:, :4], tgt_bbox_f[:, :4], p=1).to(dtype) - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_idx] + # ProbIoU Cost + mu_pred, sig_pred = _rotated_boxes_to_gaussian(out_bbox_f) + mu_tgt, sig_tgt = _rotated_boxes_to_gaussian(tgt_bbox_f) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4], reduction="sum", beta=0.1) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).sum() - total_box += 5.0 * l1_loss + 2.0 * probiou_loss + bhat_dist = _bhattacharyya_distance( + mu_pred.unsqueeze(1), sig_pred.unsqueeze(1), mu_tgt.unsqueeze(0), sig_tgt.unsqueeze(0) + ) + probiou_cost = (1.0 - torch.exp(-bhat_dist)).to(dtype) + + # Rotation Cost + pred_rot = F.normalize(out_bbox_f[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_bbox_f[:, 4:6], dim=-1) + rot_cost = (1.0 - torch.abs(pred_rot @ tgt_rot.T)).to(dtype) + + # Total balanced Cost Matrix + cost_matrix = ( + class_weight * class_cost + + bbox_weight * bbox_cost + + probiou_weight * probiou_cost + + rot_weight * rot_cost + ) + cost_matrix = cost_matrix.view(B, Q, -1).cpu() + + # Grouped Hungarian Assignment + indices = [] + group_num_queries = Q // group_detr + cost_matrix_groups = cost_matrix.split(group_num_queries, dim=1) + + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_groups[group_id] + + # Split targets per batch element + group_indices = [] + for i, c in enumerate(group_cost_matrix.split(sizes, -1)): + if sizes[i] == 0: + group_indices.append((np.array([], dtype=np.int64), np.array([], dtype=np.int64))) + else: + row_ind, col_ind = linear_sum_assignment(c[i].numpy()) + group_indices.append((row_ind, col_ind)) + + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([idx1[0], idx2[0] + group_num_queries * group_id]), + np.concatenate([idx1[1], idx2[1]]), + ) + for idx1, idx2 in zip(indices, group_indices) + ] + + # Image lovel loss normalization: scale by the number of matched boxes, + # and the number of active groups in group DETR + # Scale denominator by the number of active assignment groups + num_boxes = max(sum(sizes) * group_detr, 1) + + batch_idx = torch.cat([torch.full((len(src),), i, dtype=torch.long) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([torch.as_tensor(src, dtype=torch.long) for (src, _) in indices]) + + flat_tgt_idx_list = [] + offset = 0 + for i, (_, tgt) in enumerate(indices): + flat_tgt_idx_list.append(torch.as_tensor(tgt, dtype=torch.long) + offset) + offset += sizes[i] + flat_tgt_idx = torch.cat(flat_tgt_idx_list) + + target_classes_o = tgt_ids[flat_tgt_idx] + src_boxes = pred_boxes[batch_idx, src_idx] + target_boxes_matched = tgt_bbox[flat_tgt_idx] + + # Label Loss with Quality Mapping + prob = logits.sigmoid() + + mu1, sig1 = _rotated_boxes_to_gaussian(src_boxes.detach().to(torch.float32)) + mu2, sig2 = _rotated_boxes_to_gaussian(target_boxes_matched.detach().to(torch.float32)) + bhat_matched = _bhattacharyya_distance(mu1, sig1, mu2, sig2) + pos_ious = torch.exp(-bhat_matched).clamp(min=0.0, max=1.0).to(dtype) + + pos_weights = torch.zeros_like(logits) + neg_weights = prob.pow(gamma) + pos_ind = (batch_idx, src_idx, target_classes_o) + + pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + pos_quality = torch.clamp(pos_quality, 0.01).detach() + + pos_weights[pos_ind] = pos_quality + neg_weights[pos_ind] = 1 - pos_quality + + # AMP safety for log computation + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() + loss_ce = loss_ce.sum() / num_boxes + + # Bounding Box Loss + loss_bbox = ( + F.smooth_l1_loss(src_boxes[:, :4], target_boxes_matched[:, :4], reduction="sum", beta=0.1) / num_boxes + ) - num_matched_total = max(num_matched_total, 1) - loss_cls = total_cls / B - loss_box = total_box / num_matched_total + # ProbIoU Loss + mu1_l, sig1_l = _rotated_boxes_to_gaussian(src_boxes.to(torch.float32)) + mu2_l, sig2_l = _rotated_boxes_to_gaussian(target_boxes_matched.to(torch.float32)) + bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) + loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes - return loss_cls + loss_box + return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou def _lw_detr( From 408797b473fef795ab98c6cd9fc75285e9b2c0f0 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 09:22:25 +0200 Subject: [PATCH 6/9] Update train script --- references/layout/train.py | 91 +++++++++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 16 deletions(-) diff --git a/references/layout/train.py b/references/layout/train.py index f3f0d2c117..b9a3a29b03 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -17,7 +17,14 @@ # The following import is required for DDP import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + MultiplicativeLR, + OneCycleLR, + PolynomialLR, + SequentialLR, +) from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort @@ -82,14 +89,14 @@ def record_lr( scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() # Update LR scheduler.step() @@ -130,14 +137,14 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() scheduler.step() @@ -472,20 +479,29 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) + backbone_params = [p for n, p in model.named_parameters() if n.startswith("feat_extractor.") and p.requires_grad] + decoder_params = [p for n, p in model.named_parameters() if not n.startswith("feat_extractor.") and p.requires_grad] + # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.95, 0.999), + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, + betas=(0.9, 0.999), eps=1e-6, - weight_decay=args.weight_decay, + weight_decay=args.weight_decay or 1e-4, ) elif args.optim == "adamw": optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - args.lr, + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=args.weight_decay or 1e-4, @@ -498,12 +514,55 @@ def main(args): return # Scheduler + total_steps = args.epochs * len(train_loader) + warmup_steps = min(1000, max(200, total_steps // 20)) + if args.sched == "cosine": - scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + cosine = CosineAnnealingLR( + optimizer, + T_max=total_steps - warmup_steps, + eta_min=args.lr * 0.01, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_steps], + ) + elif args.sched == "onecycle": - scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + scheduler = OneCycleLR( + optimizer, + max_lr=[g["lr"] for g in optimizer.param_groups], + total_steps=total_steps, + pct_start=warmup_steps / total_steps, + div_factor=100, + final_div_factor=100, + anneal_strategy="cos", + ) + elif args.sched == "poly": - scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + poly = PolynomialLR( + optimizer, + total_iters=total_steps - warmup_steps, + power=1.0, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, poly], + milestones=[warmup_steps], + ) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -690,8 +749,8 @@ def parse_args(): "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" ) parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") - parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("--lr", type=float, default=4e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") From d8f5d762fe76743976de46d2cc102c5ac0ac2b9f Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 11:02:03 +0200 Subject: [PATCH 7/9] Update layout post proc --- doctr/models/layout/lw_detr/base.py | 7 +++++-- tests/pytorch/test_models_layout.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 9c6890e6b0..a33a050397 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -146,8 +146,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int # Convert logits to probabilities and get scores and labels prob = 1.0 / (1.0 + np.exp(-logits[b])) - scores = prob.max(axis=-1) - labels = prob.argmax(axis=-1) + # Remove background class + prob_fg = prob[:, :-1] + + scores = prob_fg.max(axis=-1) + labels = prob_fg.argmax(axis=-1) # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index d0c2865411..1094eee76f 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -76,7 +76,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4) assert isinstance(results[2], list) and all(isinstance(scores, float) for scores in results[2]) # Check class idxs are in the model's num_classes - assert all(0 <= idx < model.num_classes for idx in results[0]) + assert all(0 <= idx < len(model.class_names) for idx in results[0]) # Check scores are between 0 and 1 assert all(0 <= score <= 1 for score in results[2]) # Check that the number of boxes, labels and scores are the same From f2e239eb16b6e32a265091d70320d2e0dff01983 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 11:23:12 +0200 Subject: [PATCH 8/9] Update loss --- doctr/models/layout/lw_detr/pytorch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 57331c8c85..c329ddffb3 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -754,7 +754,12 @@ def _bhattacharyya_distance( bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes - return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + # Rotation Loss + pred_rot = F.normalize(src_boxes[:, 4:6], dim=-1, eps=1e-6) + tgt_rot = F.normalize(target_boxes_matched[:, 4:6], dim=-1, eps=1e-6) + loss_rot = (1.0 - torch.abs((pred_rot * tgt_rot).sum(dim=-1))).sum() / num_boxes + + return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + rot_weight * loss_rot def _lw_detr( From ad1693c802d7dfc2d64a1f1d6589c3267b88b381 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 5 Jun 2026 12:23:29 +0200 Subject: [PATCH 9/9] amp --- doctr/models/layout/lw_detr/pytorch.py | 31 ++++++++++--------- references/classification/train_character.py | 10 +++--- .../classification/train_orientation.py | 10 +++--- references/detection/evaluate.py | 2 +- references/detection/train.py | 10 +++--- references/layout/evaluate.py | 2 +- references/layout/train.py | 10 +++--- references/recognition/evaluate.py | 2 +- references/recognition/train.py | 10 +++--- 9 files changed, 45 insertions(+), 42 deletions(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index c329ddffb3..fdb59fb2c6 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -157,7 +157,7 @@ def __init__( score_thresh: float = 0.5, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 100, + num_queries: int = 195, # This is different from the paper which uses 300 queries, but 195 queries is sufficient for document layout analysis) # noqa: E501 group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, @@ -507,21 +507,25 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer - loss = self.compute_loss(logits, pred_boxes, processed_targets) + # Disable mixed precision for loss computation to ensure numerical stability, + # especially for the Bhattacharyya distance which involves + # logarithms and determinants of covariance matrices. + with torch.autocast(device_type=logits.device.type, enabled=False): + # Main loss from final decoder layer + loss = self.compute_loss(logits.float(), pred_boxes.float(), processed_targets) - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes = intermediate_reference_points[i + 1] - loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) + # Auxiliary losses from intermediate decoder layers + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]).float() + aux_boxes = intermediate_reference_points[i + 1].float() + loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) - # Auxiliary losses for encoder proposals - enc_logits = torch.cat(all_group_enc_logits, dim=1) - enc_coords = torch.cat(all_group_enc_coords, dim=1) - loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) + # Auxiliary losses for encoder proposals + enc_logits = torch.cat(all_group_enc_logits, dim=1).float() + enc_coords = torch.cat(all_group_enc_coords, dim=1).float() + loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) - out["loss"] = loss + out["loss"] = loss return out @@ -738,7 +742,6 @@ def _bhattacharyya_distance( pos_weights[pos_ind] = pos_quality neg_weights[pos_ind] = 1 - pos_quality - # AMP safety for log computation prob_safe = prob.clamp(min=eps, max=1.0 - eps) loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() loss_ce = loss_ce.sum() / num_boxes diff --git a/references/classification/train_character.py b/references/classification/train_character.py index c266bb2f5c..22a9298ecb 100644 --- a/references/classification/train_character.py +++ b/references/classification/train_character.py @@ -66,7 +66,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -110,7 +110,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -126,7 +126,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -168,7 +168,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py index 86dc4c5931..0cbbf05616 100644 --- a/references/classification/train_orientation.py +++ b/references/classification/train_orientation.py @@ -78,7 +78,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -91,7 +91,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -122,7 +122,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -138,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -180,7 +180,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/detection/evaluate.py b/references/detection/evaluate.py index 7c5fb597aa..a8674884f4 100644 --- a/references/detection/evaluate.py +++ b/references/detection/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/detection/train.py b/references/detection/train.py index 43e4ec3a56..37b439a6e5 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -63,7 +63,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -74,7 +74,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -107,7 +107,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -120,7 +120,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, l images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py index 36fb78cf80..aa8ea788aa 100644 --- a/references/layout/evaluate.py +++ b/references/layout/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/layout/train.py b/references/layout/train.py index b9a3a29b03..bf8a3a600d 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -70,7 +70,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): imgs, padding_masks = images @@ -84,7 +84,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -117,7 +117,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -132,7 +132,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -177,7 +177,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/recognition/evaluate.py b/references/recognition/evaluate.py index 45a6b38306..22f69fa2cd 100644 --- a/references/recognition/evaluate.py +++ b/references/recognition/evaluate.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/recognition/train.py b/references/recognition/train.py index dc6b7b1b24..fd0cd1826d 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -68,7 +68,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -112,7 +112,7 @@ def record_lr( def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,7 +125,7 @@ def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, sche optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -167,7 +167,7 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False, images = images.to(device) images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True)