diff --git a/doctr/models/layout/lw_detr/base.py b/doctr/models/layout/lw_detr/base.py index 2aa59e7d18..a33a050397 100644 --- a/doctr/models/layout/lw_detr/base.py +++ b/doctr/models/layout/lw_detr/base.py @@ -144,16 +144,18 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int for b in range(boxes.shape[0]): # Convert logits to probabilities and get scores and labels - exp = np.exp(logits[b] - logits[b].max(axis=-1, keepdims=True)) - prob = exp / exp.sum(axis=-1, keepdims=True) + prob = 1.0 / (1.0 + np.exp(-logits[b])) + + # Remove background class + prob_fg = prob[:, :-1] - prob_fg = prob[:, :-1] # exclude background scores = prob_fg.max(axis=-1) labels = prob_fg.argmax(axis=-1) # Keep only topk predictions before NMS if self.topk is not None and len(scores) > self.topk: - idxs = np.argsort(scores)[::-1][: self.topk] + idxs = np.argpartition(-scores, self.topk)[: self.topk] + idxs = idxs[np.argsort(-scores[idxs])] else: idxs = np.arange(len(scores)) @@ -161,11 +163,11 @@ def __call__(self, logits: np.ndarray, boxes: np.ndarray) -> list[tuple[list[int labels_b = labels[idxs] bboxes = boxes[b][idxs] - mask = scores_b > self.score_thresh - - bboxes = bboxes[mask] - scores_b = scores_b[mask] - labels_b = labels_b[mask] + # Filter by score threshold + thresh_mask = scores_b >= self.score_thresh + scores_b = scores_b[thresh_mask] + labels_b = labels_b[thresh_mask] + bboxes = bboxes[thresh_mask] polys, _ = ( self._decode_boxes(bboxes) diff --git a/doctr/models/layout/lw_detr/layers/pytorch.py b/doctr/models/layout/lw_detr/layers/pytorch.py index 08f9ea9fb4..f0bc63c865 100644 --- a/doctr/models/layout/lw_detr/layers/pytorch.py +++ b/doctr/models/layout/lw_detr/layers/pytorch.py @@ -250,7 +250,7 @@ def forward( value = self.value_proj(encoder_hidden_states) if attention_mask is not None: # we invert the attention_mask - value = value.masked_fill(~attention_mask[..., None], float(0)) + value = value.masked_fill(attention_mask[..., None], float(0)) value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(hidden_states).view( batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2 @@ -293,6 +293,8 @@ def forward( else: raise ValueError(f"Last dim of reference_points must be 4 or 6, but got {reference_points.shape[-1]}") + # clamp sampling locations to keep them within the valid range [0, 1] for grid sampling + sampling_locations = sampling_locations.clamp(0.0, 1.0) output = self.attn( value, spatial_shapes_list, @@ -409,26 +411,28 @@ def gen_sine_position_embeddings(pos_tensor: torch.Tensor, hidden_size: int = 25 """ scale = 2 * math.pi dim = hidden_size // 2 + # Keep dim_t in float32 for numerical precision; cast output to match caller dtype dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim) - x_embed = pos_tensor[:, :, 0] * scale - y_embed = pos_tensor[:, :, 1] * scale + x_embed = pos_tensor[:, :, 0].float() * scale + y_embed = pos_tensor[:, :, 1].float() * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 4: - w_embed = pos_tensor[:, :, 2] * scale + w_embed = pos_tensor[:, :, 2].float() * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) - h_embed = pos_tensor[:, :, 3] * scale + h_embed = pos_tensor[:, :, 3].float() * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}") + # Cast back to the caller's dtype (supports bfloat16 / float16 AMP) return pos.to(pos_tensor.dtype) @@ -480,11 +484,7 @@ def __init__( self.bbox_embed = bbox_embed self.ref_point_head = LWDETRHead(2 * self.d_model, self.d_model, self.d_model, num_layers=2) - self.angle_proj = nn.Sequential( - nn.Linear(4, self.d_model), - nn.ReLU(), - nn.Linear(self.d_model, self.d_model), - ) + self.angle_proj = nn.Linear(2, self.d_model) def get_reference( self, reference_points: torch.Tensor, valid_ratios: torch.Tensor @@ -498,7 +498,7 @@ def get_reference( tensor containing the valid ratios for each level of the input feature maps Returns: - reference_points_inputs: (batch_size, num_queries, 1, num_levels, 4) + reference_points_inputs: (batch_size, num_queries, 1, num_levels, 6) tensor containing the reference point inputs for the decoder layers, which are the normalized center coordinates, width and height of the bounding boxes w.r.t. the valid ratios of the input feature maps @@ -515,45 +515,54 @@ def get_reference( # DETR positional encoding query_sine_embed = gen_sine_position_embeddings(spatial_inputs[:, :, 0, :], self.d_model) base_query_pos = self.ref_point_head(query_sine_embed) - # Angle embedding - sin_t = angle[..., 0:1] - cos_t = angle[..., 1:2] - - angle_feat = torch.cat( - [ - sin_t, - cos_t, - 2 * sin_t * cos_t, - cos_t**2 - sin_t**2, - ], - dim=-1, - ) - angle_emb = self.angle_proj(angle_feat) + angle_emb = self.angle_proj(angle) # Combine query_pos = base_query_pos + angle_emb return reference_points_inputs, query_pos - def refine_boxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - reference_points = reference_points.to(deltas.device) - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: + """Refine bounding boxes by applying the predicted deltas to the reference points. + The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. + The refined boxes are computed as follows: - # Clamp deltas to prevent exp() from shooting to Infinity during early training - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] + cx' = cx + delta_cx * w + cy' = cy + delta_cy * h + w' = w * exp(delta_w) + h' = h * exp(delta_h) + sinθ' = sinθ * cosΔ + cosθ * sinΔ + cosθ' = cosθ * cosΔ - sinθ * sinΔ - # Add eps=1e-6 to avoid division-by-zero NaN creation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) + Args: + reference_points: (N, S, 6) tensor containing the reference points + deltas: (N, S, 6) tensor containing the predicted deltas + + Returns: + refined_boxes: (N, S, 6) tensor containing the refined bounding boxes + """ + reference_points = reference_points.to(deltas.device) + # size + wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=4.0).exp() * reference_points[..., 2:4] + wh = wh.clamp(min=1e-4, max=1.0) + # center + raw_cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] + half_wh = wh / 2 + cxcy = raw_cxcy.clamp( + min=half_wh, + max=1.0 - half_wh, + ) + # rotation + sin_d = deltas[..., 4:5] + cos_d = deltas[..., 5:6] + 1.0 + delta_rot = F.normalize(torch.cat([sin_d, cos_d], dim=-1), dim=-1, eps=1e-6) sin_delta = delta_rot[..., 0:1] cos_delta = delta_rot[..., 1:2] sin_ref = reference_points[..., 4:5] cos_ref = reference_points[..., 5:6] - + # compose rotations sin_new = sin_ref * cos_delta + cos_ref * sin_delta cos_new = cos_ref * cos_delta - sin_ref * sin_delta - - # Add eps=1e-6 here too rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - return torch.cat((cxcy, wh, rot), dim=-1) def forward( @@ -590,11 +599,7 @@ def forward( if self.bbox_embed is not None: delta = self.bbox_embed(hidden_states_norm) - reference_points = self.refine_boxes( - reference_points.squeeze(2), - delta, - ) - + reference_points = self.refine_bboxes(reference_points, delta) intermediate_reference_points.append(reference_points) reference_points_inputs, query_pos = self.get_reference( diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 848d44280e..fdb59fb2c6 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -10,6 +10,7 @@ import numpy as np import torch +from scipy.optimize import linear_sum_assignment from torch import nn from torch.nn import functional as F @@ -153,17 +154,17 @@ def __init__( self, feat_extractor: LWDETRBackbone, class_names: list[str], - score_thresh: float = 0.3, + score_thresh: float = 0.5, iou_thresh: float = 0.5, d_model: int = 256, - num_queries: int = 130, - group_detr: int = 1, + num_queries: int = 195, # This is different from the paper which uses 300 queries, but 195 queries is sufficient for document layout analysis) # noqa: E501 + group_detr: int = 13, dec_layers: int = 3, sa_num_heads: int = 8, ca_num_heads: int = 16, ff_dim: int = 2048, dec_n_points: int = 2, - dropout_prob: float = 0.0, + dropout_prob: float = 0.1, assume_straight_pages: bool = True, exportable: bool = False, cfg: dict[str, Any] | None = None, @@ -185,8 +186,11 @@ def __init__( self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 6) # Initialize angle to (sin=0, cos=1) with torch.no_grad(): - self.reference_point_embed.weight[:, 4] = 0.0 # sinθ - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ + self.reference_point_embed.weight[:, 0:2].uniform_(0.05, 0.95) # cx, cy + self.reference_point_embed.weight[:, 2].uniform_(0.1, 0.6) # w + self.reference_point_embed.weight[:, 3].uniform_(0.02, 0.3) # h + self.reference_point_embed.weight[:, 4].zero_() # sinθ + self.reference_point_embed.weight[:, 5].fill_(1.0) # cosθ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, self.d_model) @@ -241,7 +245,8 @@ def __init__( if hasattr(m, "bias") and m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) + if m is not self.reference_point_embed: + nn.init.normal_(m.weight, std=0.02) elif isinstance(m, LWDETRMultiscaleDeformableAttention): nn.init.constant_(m.sampling_offsets.weight, 0.0) @@ -252,34 +257,31 @@ def __init__( .view(m.n_heads, 1, 1, 2) .repeat(1, m.n_levels, m.n_points, 1) ) - for i in range(m.n_points): grid_init[:, :, i, :] *= i + 1 - with torch.no_grad(): m.sampling_offsets.bias.copy_(grid_init.view(-1)) nn.init.constant_(m.attention_weights.weight, 0.0) nn.init.constant_(m.attention_weights.bias, 0.0) - nn.init.xavier_uniform_(m.value_proj.weight) nn.init.zeros_(m.value_proj.bias) - nn.init.xavier_uniform_(m.output_proj.weight) nn.init.zeros_(m.output_proj.bias) - if isinstance(m, nn.Linear) and m.out_features == self.num_classes: - prior_prob = 0.01 - bias_value = -math.log((1 - prior_prob) / prior_prob) if m.bias is not None: - nn.init.constant_(m.bias, bias_value) + with torch.no_grad(): + # Focal-loss prior: foreground starts with low confidence (~0.01), + # preventing background from dominating gradients at the start of training. + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + nn.init.constant_(m.bias, 0.0) + m.bias[:-1].fill_(bias_value) if isinstance(m, LWDETRHead): last = m.layers[-1] if isinstance(last, nn.Linear): nn.init.zeros_(last.weight) nn.init.zeros_(last.bias) - if last.bias.shape[0] == 6: - nn.init.constant_(last.bias[5], 1.0) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -290,44 +292,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """ load_pretrained_params(self, path_or_url, **kwargs) - def refine_bboxes(self, reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor: - """Refine bounding boxes by applying the predicted deltas to the reference points. - The reference points are in the format (cx, cy, w, h, sinθ, cosθ), and the deltas are in the same format. - The refined boxes are computed as follows: - - cx' = cx + delta_cx * w - cy' = cy + delta_cy * h - w' = w * exp(delta_w) - h' = h * exp(delta_h) - sinθ' = sinθ * cosΔ + cosθ * sinΔ - cosθ' = cosθ * cosΔ - sinθ * sinΔ - - Args: - reference_points: (N, S, 6) tensor containing the reference points - deltas: (N, S, 6) tensor containing the predicted deltas - - Returns: - refined_boxes: (N, S, 6) tensor containing the refined bounding boxes - """ - reference_points = reference_points.to(deltas.device) - # center - cxcy = deltas[..., :2] * reference_points[..., 2:4] + reference_points[..., :2] - # size - wh = torch.clamp(deltas[..., 2:4], min=-4.0, max=2.0).exp() * reference_points[..., 2:4] - # rotation - delta_rot = F.normalize(deltas[..., 4:6], dim=-1, eps=1e-6) - sin_delta = delta_rot[..., 0:1] - cos_delta = delta_rot[..., 1:2] - sin_ref = reference_points[..., 4:5] - cos_ref = reference_points[..., 5:6] - - # compose rotations - sin_new = sin_ref * cos_delta + cos_ref * sin_delta - cos_new = cos_ref * cos_delta - sin_ref * sin_delta - rot = F.normalize(torch.cat([sin_new, cos_new], dim=-1), dim=-1, eps=1e-6) - - return torch.cat((cxcy, wh, rot), dim=-1) - def get_valid_ratio(self, mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Get the valid ratio of all feature maps. @@ -402,15 +366,15 @@ def gen_encoder_output_proposals( _cur += height * width output_proposals = torch.cat(proposals, 1) - spatial_valid = ((output_proposals[..., :4] > 0.01) & (output_proposals[..., :4] < 0.99)).all(-1, keepdim=True) + spatial_valid = ((output_proposals[..., :2] > 0.01) & (output_proposals[..., :2] < 0.99)).all(-1, keepdim=True) output_proposals_valid = spatial_valid - invalid_mask = padding_mask | ~output_proposals_valid.squeeze(-1) invalid_mask = padding_mask.unsqueeze(-1) | ~output_proposals_valid output_proposals = output_proposals.masked_fill(invalid_mask, float(0)) # assign each pixel as an object query object_query = enc_output - object_query = object_query.masked_fill(invalid_mask, float(0)) + object_query = object_query.masked_fill(invalid_mask, 0.0) + return object_query, output_proposals, invalid_mask def forward( @@ -468,6 +432,7 @@ def forward( topk = self.num_queries topk_coords_logits_list: list[torch.Tensor] = [] + topk_content_list: list[torch.Tensor] = [] # encoder predictions for auxiliary losses all_group_enc_logits: list[torch.Tensor] = [] @@ -483,22 +448,30 @@ def forward( group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) - group_enc_outputs_coord = self.refine_bboxes(output_proposals, group_delta_bbox) + group_enc_outputs_coord = self.decoder.refine_bboxes(output_proposals, group_delta_bbox) all_group_enc_coords.append(group_enc_outputs_coord) - group_topk_proposals = torch.topk(group_enc_outputs_class_masked.max(-1)[0], topk, dim=1)[1] + scores = group_enc_outputs_class_masked[..., :-1].max(-1).values + group_topk_proposals = torch.topk(scores, topk, dim=1)[1] group_topk_coords_logits_undetach = torch.gather( group_enc_outputs_coord, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - group_topk_coords_logits = group_topk_coords_logits_undetach + group_topk_coords_logits = group_topk_coords_logits_undetach.detach() topk_coords_logits_list.append(group_topk_coords_logits) + group_topk_content = torch.gather( + group_object_query, + 1, + group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), + ) + topk_content_list.append(group_topk_content) topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = self.refine_bboxes(topk_coords_logits, reference_points) + + reference_points = self.decoder.refine_bboxes(topk_coords_logits, reference_points) last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, @@ -506,11 +479,11 @@ def forward( spatial_shapes_list=spatial_shapes_list, valid_ratios=valid_ratios, encoder_hidden_states=source_flatten, + encoder_attention_mask=mask_flatten, ) logits = self.class_embed(last_hidden_states) - pred_boxes_delta = self.bbox_embed(last_hidden_states) - pred_boxes = self.refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta) + pred_boxes = intermediate_reference_points[-1] out: dict[str, Any] = {} @@ -534,75 +507,69 @@ def _postprocess(logits, boxes): # Build target processed_targets = self.build_target(target, self.class_names) - # Main loss from final decoder layer (group DETR) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss = main_loss / group_detr - - # Auxiliary losses from intermediate decoder layers - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = self.refine_bboxes(intermediate_reference_points[i], aux_boxes_delta) + # Disable mixed precision for loss computation to ensure numerical stability, + # especially for the Bhattacharyya distance which involves + # logarithms and determinants of covariance matrices. + with torch.autocast(device_type=logits.device.type, enabled=False): + # Main loss from final decoder layer + loss = self.compute_loss(logits.float(), pred_boxes.float(), processed_targets) - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + # Auxiliary losses from intermediate decoder layers + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]).float() + aux_boxes = intermediate_reference_points[i + 1].float() + loss += self.compute_loss(aux_logits, aux_boxes, processed_targets) - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets) - loss += 0.5 * (aux_loss / group_detr) + # Auxiliary losses for encoder proposals + enc_logits = torch.cat(all_group_enc_logits, dim=1).float() + enc_coords = torch.cat(all_group_enc_coords, dim=1).float() + loss += 0.2 * self.compute_loss(enc_logits, enc_coords, processed_targets) - # Auxiliary losses for encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets) - loss += 0.1 * (enc_loss / group_detr) - - out["loss"] = loss + out["loss"] = loss return out def compute_loss( - self, logits: torch.Tensor, pred_boxes: torch.Tensor, targets: list[dict[str, np.ndarray]] + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + targets: list[dict[str, np.ndarray]], ) -> torch.Tensor: - """ - Compute the loss for LW-DETR. The loss consists of three components: - classification loss, box regression loss, and rotation loss. - The classification loss is a cross-entropy loss between the predicted class logits and the target classes. - The box regression loss is a Smooth L1 loss between the predicted boxes and the target boxes, - computed only on the positive samples. - The rotation loss is computed as 1 - cosine similarity between the predicted rotation and the target rotation, - averaged over the positive samples. - The positive samples are determined using a SimOTA-like assignment strategy, where for each ground truth box, - we select the top-k queries with the lowest cost - (combination of classification cost, box regression cost, and rotation cost). + """Compute the loss using Grouped Hungarian Matching + and consistent ProbIoU semantics for rotated bounding boxes. Args: logits: (B, Q, C) tensor containing the predicted class logits for each query - pred_boxes: (B, Q, 6) tensor containing the predicted boxes for each query - targets: list of dictionaries where each dictionary corresponds to a sample and has keys corresponding - to class names and values corresponding to lists of boxes in either polygon format (4, 2) - or bounding box format (4,) (xmin, ymin, xmax, ymax) + pred_boxes: (B, Q, 6) tensor containing the predicted boxes in (cx, cy, w, h, sinθ, cosθ) format + targets: list of length B, where each element is a dict with keys "labels" and "boxes", + containing the ground truth labels and boxes for each image in the batch. Returns: - loss: the computed loss value + A scalar tensor containing the computed loss. """ + device = logits.device + dtype = logits.dtype + B, Q, C = logits.shape - def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert rotated boxes in (cx, cy, w, h, sinθ, cosθ) format to Gaussian distribution parameters - (mean and covariance). - """ - cxcy = boxes[..., :2] + # Consistent coefficients across matcher and loss components + class_weight = 2.0 + bbox_weight = 5.0 + probiou_weight = 2.0 + rot_weight = 0.5 + # Focal Loss Params + alpha = 0.25 + gamma = 2.0 + eps = 1e-7 + + group_detr = getattr(self, "group_detr", 1) + + def _rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert rotated boxes to Gaussian distributions using the true + variance of a uniform continuous rectangle (w^2 / 12).""" + cxcy = boxes[..., :2] w = boxes[..., 2].clamp(min=1e-6) h = boxes[..., 3].clamp(min=1e-6) - sin = boxes[..., 4] cos = boxes[..., 5] @@ -614,143 +581,188 @@ def rotated_boxes_to_gaussian(boxes: torch.Tensor) -> tuple[torch.Tensor, torch. dim=-2, ) - sx = (w / 2) ** 2 - sy = (h / 2) ** 2 - - S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device) + sx = (w**2) / 12.0 + sy = (h**2) / 12.0 + S = torch.zeros((*boxes.shape[:-1], 2, 2), device=boxes.device, dtype=boxes.dtype) S[..., 0, 0] = sx S[..., 1, 1] = sy covariance = R @ S @ R.transpose(-1, -2) return cxcy, covariance - def _probiou_loss(pred_boxes: torch.Tensor, tgt_boxes: torch.Tensor) -> torch.Tensor: - """Compute the ProbIoU loss between predicted boxes and target boxes.""" - mu1, sigma1 = rotated_boxes_to_gaussian(pred_boxes) - mu2, sigma2 = rotated_boxes_to_gaussian(tgt_boxes) - + def _bhattacharyya_distance( + mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, sigma2: torch.Tensor + ) -> torch.Tensor: + """Compute Bhattacharyya distance with broadcast support.""" delta = (mu1 - mu2).unsqueeze(-1) sigma = (sigma1 + sigma2) * 0.5 - eps = 1e-6 - eye = torch.eye(2, device=sigma.device) * eps + eye = torch.eye(2, device=sigma.device, dtype=sigma.dtype) * 1e-6 sigma_safe = sigma + eye sigma1_safe = sigma1 + eye sigma2_safe = sigma2 + eye - sigma_inv = torch.linalg.inv(sigma_safe) + L = torch.linalg.cholesky(sigma_safe) + sigma_inv = torch.cholesky_inverse(L) mahalanobis = (delta.transpose(-1, -2) @ sigma_inv @ delta).squeeze(-1).squeeze(-1) - det_sigma = torch.linalg.det(sigma_safe).clamp(min=eps) - det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=eps) - det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=eps) + det_sigma = torch.linalg.det(sigma_safe).clamp(min=1e-6) + det_sigma1 = torch.linalg.det(sigma1_safe).clamp(min=1e-6) + det_sigma2 = torch.linalg.det(sigma2_safe).clamp(min=1e-6) bhattacharyya = 0.125 * mahalanobis + 0.5 * torch.log(det_sigma / torch.sqrt(det_sigma1 * det_sigma2)) + return bhattacharyya.clamp(min=0.0) + + # Prepare targets for matching + target_labels = [] + target_boxes = [] + sizes = [] + for t in targets: + lbls = torch.as_tensor(t["labels"], device=device, dtype=torch.long) + bxs = torch.as_tensor(t["boxes"], device=device, dtype=pred_boxes.dtype) + if bxs.ndim == 1 and bxs.numel() > 0: + bxs = bxs.unsqueeze(0) + target_labels.append(lbls) + target_boxes.append(bxs) + sizes.append(len(lbls)) + + # Unified formulation for empty batches + if sum(sizes) == 0: + prob = logits.sigmoid() + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + neg_weights = prob.pow(gamma) + loss_ce = -neg_weights * (1.0 - prob_safe).log() + return class_weight * (loss_ce.sum() / (B * Q)) + + tgt_ids = torch.cat(target_labels) + tgt_bbox = torch.cat(target_boxes) + + # Matcher: Grouped Hungarian Assignment with a balanced cost matrix + with torch.no_grad(): + out_prob = logits.flatten(0, 1).sigmoid() + out_bbox = pred_boxes.flatten(0, 1) - probiou = torch.exp(-bhattacharyya) - return 1 - probiou - - device = logits.device - B, Q, C = logits.shape + # Classification Cost (Focal Loss based) + neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + eps).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + eps).log()) + class_cost = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - total_cls = torch.tensor(0.0, device=device) - total_box = torch.tensor(0.0, device=device) - total_rot = torch.tensor(0.0, device=device) + # Box L1 Cost + out_bbox_f = out_bbox.to(torch.float32) + tgt_bbox_f = tgt_bbox.to(torch.float32) + bbox_cost = torch.cdist(out_bbox_f[:, :4], tgt_bbox_f[:, :4], p=1).to(dtype) - for b in range(B): - pred_logits = logits[b] - pred_boxes_b = pred_boxes[b] + # ProbIoU Cost + mu_pred, sig_pred = _rotated_boxes_to_gaussian(out_bbox_f) + mu_tgt, sig_tgt = _rotated_boxes_to_gaussian(tgt_bbox_f) - tgt_boxes = torch.as_tensor( - targets[b]["boxes"], - device=device, - dtype=pred_boxes.dtype, + bhat_dist = _bhattacharyya_distance( + mu_pred.unsqueeze(1), sig_pred.unsqueeze(1), mu_tgt.unsqueeze(0), sig_tgt.unsqueeze(0) ) - tgt_cls = torch.as_tensor( - targets[b]["labels"], - device=device, - dtype=torch.long, + probiou_cost = (1.0 - torch.exp(-bhat_dist)).to(dtype) + + # Rotation Cost + pred_rot = F.normalize(out_bbox_f[:, 4:6], dim=-1) + tgt_rot = F.normalize(tgt_bbox_f[:, 4:6], dim=-1) + rot_cost = (1.0 - torch.abs(pred_rot @ tgt_rot.T)).to(dtype) + + # Total balanced Cost Matrix + cost_matrix = ( + class_weight * class_cost + + bbox_weight * bbox_cost + + probiou_weight * probiou_cost + + rot_weight * rot_cost ) + cost_matrix = cost_matrix.view(B, Q, -1).cpu() + + # Grouped Hungarian Assignment + indices = [] + group_num_queries = Q // group_detr + cost_matrix_groups = cost_matrix.split(group_num_queries, dim=1) + + for group_id in range(group_detr): + group_cost_matrix = cost_matrix_groups[group_id] + + # Split targets per batch element + group_indices = [] + for i, c in enumerate(group_cost_matrix.split(sizes, -1)): + if sizes[i] == 0: + group_indices.append((np.array([], dtype=np.int64), np.array([], dtype=np.int64))) + else: + row_ind, col_ind = linear_sum_assignment(c[i].numpy()) + group_indices.append((row_ind, col_ind)) + + if group_id == 0: + indices = group_indices + else: + indices = [ + ( + np.concatenate([idx1[0], idx2[0] + group_num_queries * group_id]), + np.concatenate([idx1[1], idx2[1]]), + ) + for idx1, idx2 in zip(indices, group_indices) + ] + + # Image lovel loss normalization: scale by the number of matched boxes, + # and the number of active groups in group DETR + # Scale denominator by the number of active assignment groups + num_boxes = max(sum(sizes) * group_detr, 1) + + batch_idx = torch.cat([torch.full((len(src),), i, dtype=torch.long) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([torch.as_tensor(src, dtype=torch.long) for (src, _) in indices]) + + flat_tgt_idx_list = [] + offset = 0 + for i, (_, tgt) in enumerate(indices): + flat_tgt_idx_list.append(torch.as_tensor(tgt, dtype=torch.long) + offset) + offset += sizes[i] + flat_tgt_idx = torch.cat(flat_tgt_idx_list) + + target_classes_o = tgt_ids[flat_tgt_idx] + src_boxes = pred_boxes[batch_idx, src_idx] + target_boxes_matched = tgt_bbox[flat_tgt_idx] + + # Label Loss with Quality Mapping + prob = logits.sigmoid() + + mu1, sig1 = _rotated_boxes_to_gaussian(src_boxes.detach().to(torch.float32)) + mu2, sig2 = _rotated_boxes_to_gaussian(target_boxes_matched.detach().to(torch.float32)) + bhat_matched = _bhattacharyya_distance(mu1, sig1, mu2, sig2) + pos_ious = torch.exp(-bhat_matched).clamp(min=0.0, max=1.0).to(dtype) + + pos_weights = torch.zeros_like(logits) + neg_weights = prob.pow(gamma) + pos_ind = (batch_idx, src_idx, target_classes_o) + + pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha) + pos_quality = torch.clamp(pos_quality, 0.01).detach() + + pos_weights[pos_ind] = pos_quality + neg_weights[pos_ind] = 1 - pos_quality + + prob_safe = prob.clamp(min=eps, max=1.0 - eps) + loss_ce = -pos_weights * prob_safe.log() - neg_weights * (1.0 - prob_safe).log() + loss_ce = loss_ce.sum() / num_boxes + + # Bounding Box Loss + loss_bbox = ( + F.smooth_l1_loss(src_boxes[:, :4], target_boxes_matched[:, :4], reduction="sum", beta=0.1) / num_boxes + ) - num_gt = len(tgt_cls) - - pred_rot = F.normalize(pred_boxes_b[:, 4:6], dim=-1) - tgt_rot = F.normalize(tgt_boxes[:, 4:6], dim=-1) - - with torch.no_grad(): - cls_prob = pred_logits.sigmoid() - alpha = 0.25 - gamma = 2.0 - - neg_cost = (1 - alpha) * (cls_prob**gamma) * (-(1 - cls_prob + 1e-8).log()) - - pos_cost = alpha * ((1 - cls_prob) ** gamma) * (-(cls_prob + 1e-8).log()) - - cost_cls = pos_cost[:, tgt_cls] - neg_cost[:, tgt_cls] - cost_l1 = torch.cdist( - pred_boxes_b[:, :4], - tgt_boxes[:, :4], - p=1, - ) - cost_rot = 1 - (pred_rot @ tgt_rot.T).abs() - total_cost = 5.0 * cost_cls + 2.0 * cost_l1 + 1.0 * cost_rot - matching_matrix = torch.zeros( - (Q, num_gt), - dtype=torch.bool, - device=device, - ) - - center_dist = torch.cdist( - pred_boxes_b[:, :2], - tgt_boxes[:, :2], - p=2, - ) - - iou_like = torch.exp(-center_dist) - dynamic_k = iou_like.sum(0).int().clamp(min=1, max=10) - - for gt_idx in range(num_gt): - _, candidate_idx = torch.topk(-total_cost[:, gt_idx], k=int(dynamic_k[gt_idx].item())) - matching_matrix[candidate_idx, gt_idx] = True - - # resolve duplicate matches - multiple_match_mask = matching_matrix.sum(1) > 1 - - if multiple_match_mask.any(): - duplicate_idx = multiple_match_mask.nonzero(as_tuple=False).squeeze(1) - min_cost_idx = total_cost[duplicate_idx].argmin(dim=1) - # Set all matches to False for the duplicate indices, - # then set the match with the lowest cost to True - matching_matrix[duplicate_idx] = False - matching_matrix[duplicate_idx, min_cost_idx] = True - - pos_idx, gt_indices = matching_matrix.nonzero(as_tuple=True) - - target_classes = torch.zeros((Q,), dtype=torch.long, device=device) - - # background = 0 - target_classes[pos_idx] = tgt_cls[gt_indices] - - total_cls += F.cross_entropy(pred_logits, target_classes) + # ProbIoU Loss + mu1_l, sig1_l = _rotated_boxes_to_gaussian(src_boxes.to(torch.float32)) + mu2_l, sig2_l = _rotated_boxes_to_gaussian(target_boxes_matched.to(torch.float32)) + bhat_loss = _bhattacharyya_distance(mu1_l, sig1_l, mu2_l, sig2_l) + loss_probiou = (1.0 - torch.exp(-bhat_loss)).to(dtype).sum() / num_boxes - if len(pos_idx) == 0: - continue + # Rotation Loss + pred_rot = F.normalize(src_boxes[:, 4:6], dim=-1, eps=1e-6) + tgt_rot = F.normalize(target_boxes_matched[:, 4:6], dim=-1, eps=1e-6) + loss_rot = (1.0 - torch.abs((pred_rot * tgt_rot).sum(dim=-1))).sum() / num_boxes - pred_sel = pred_boxes_b[pos_idx] - tgt_sel = tgt_boxes[gt_indices] - # L1 loss on (cx, cy, w, h) - l1_loss = F.smooth_l1_loss(pred_sel[:, :4], tgt_sel[:, :4]) - # ProbIoU loss on the whole box (including rotation) - probiou_loss = _probiou_loss(pred_sel, tgt_sel).mean() - total_box += 2.0 * l1_loss + 0.5 * probiou_loss - # Rotation loss - cos_sim = (pred_rot[pos_idx] * tgt_rot[gt_indices]).sum(-1).abs() - rot_loss = (1 - cos_sim).mean() - total_rot += 0.5 * rot_loss - # Average the loss over the batch - return (total_cls + total_box + total_rot) / B + return class_weight * loss_ce + bbox_weight * loss_bbox + probiou_weight * loss_probiou + rot_weight * loss_rot def _lw_detr( diff --git a/references/classification/train_character.py b/references/classification/train_character.py index c266bb2f5c..22a9298ecb 100644 --- a/references/classification/train_character.py +++ b/references/classification/train_character.py @@ -66,7 +66,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -110,7 +110,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -126,7 +126,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -168,7 +168,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/classification/train_orientation.py b/references/classification/train_orientation.py index 86dc4c5931..0cbbf05616 100644 --- a/references/classification/train_orientation.py +++ b/references/classification/train_orientation.py @@ -78,7 +78,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): targets = torch.tensor(targets) @@ -91,7 +91,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -122,7 +122,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -138,7 +138,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) train_loss = cross_entropy(out, targets) scaler.scale(train_loss).backward() @@ -180,7 +180,7 @@ def evaluate(model, val_loader, batch_transforms, amp=False, log=None): targets = targets.cuda() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images) loss = cross_entropy(out, targets) else: diff --git a/references/detection/evaluate.py b/references/detection/evaluate.py index 7c5fb597aa..a8674884f4 100644 --- a/references/detection/evaluate.py +++ b/references/detection/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/detection/train.py b/references/detection/train.py index 43e4ec3a56..37b439a6e5 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -63,7 +63,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -74,7 +74,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -107,7 +107,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -120,7 +120,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, l images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/layout/evaluate.py b/references/layout/evaluate.py index 36fb78cf80..aa8ea788aa 100644 --- a/references/layout/evaluate.py +++ b/references/layout/evaluate.py @@ -39,7 +39,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) diff --git a/references/layout/train.py b/references/layout/train.py index f3f0d2c117..bf8a3a600d 100644 --- a/references/layout/train.py +++ b/references/layout/train.py @@ -17,7 +17,14 @@ # The following import is required for DDP import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + MultiplicativeLR, + OneCycleLR, + PolynomialLR, + SequentialLR, +) from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort @@ -63,7 +70,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): imgs, padding_masks = images @@ -77,19 +84,19 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() # Update LR scheduler.step() @@ -110,7 +117,7 @@ def record_lr( def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,19 +132,19 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, a optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(imgs, padding_masks, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) # Update the params scaler.step(optimizer) scaler.update() else: train_loss = model(imgs, padding_masks, targets)["loss"] train_loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1) optimizer.step() scheduler.step() @@ -170,7 +177,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non padding_masks = padding_masks.cuda() imgs = batch_transforms(imgs) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(imgs, padding_masks, targets, return_preds=True) else: out = model(imgs, padding_masks, targets, return_preds=True) @@ -472,20 +479,29 @@ def main(args): # construct DDP model model = DDP(model, device_ids=[rank]) + backbone_params = [p for n, p in model.named_parameters() if n.startswith("feat_extractor.") and p.requires_grad] + decoder_params = [p for n, p in model.named_parameters() if not n.startswith("feat_extractor.") and p.requires_grad] + # Optimizer if args.optim == "adam": optimizer = torch.optim.Adam( - [p for p in model.parameters() if p.requires_grad], - args.lr, - betas=(0.95, 0.999), + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, + betas=(0.9, 0.999), eps=1e-6, - weight_decay=args.weight_decay, + weight_decay=args.weight_decay or 1e-4, ) elif args.optim == "adamw": optimizer = torch.optim.AdamW( - [p for p in model.parameters() if p.requires_grad], - args.lr, + [ + {"params": backbone_params, "lr": 1e-5, "weight_decay": args.weight_decay or 1e-4}, + {"params": decoder_params, "lr": args.lr, "weight_decay": args.weight_decay or 1e-4}, + ], + lr=args.lr, betas=(0.9, 0.999), eps=1e-6, weight_decay=args.weight_decay or 1e-4, @@ -498,12 +514,55 @@ def main(args): return # Scheduler + total_steps = args.epochs * len(train_loader) + warmup_steps = min(1000, max(200, total_steps // 20)) + if args.sched == "cosine": - scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + cosine = CosineAnnealingLR( + optimizer, + T_max=total_steps - warmup_steps, + eta_min=args.lr * 0.01, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, cosine], + milestones=[warmup_steps], + ) + elif args.sched == "onecycle": - scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + scheduler = OneCycleLR( + optimizer, + max_lr=[g["lr"] for g in optimizer.param_groups], + total_steps=total_steps, + pct_start=warmup_steps / total_steps, + div_factor=100, + final_div_factor=100, + anneal_strategy="cos", + ) + elif args.sched == "poly": - scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=warmup_steps, + ) + poly = PolynomialLR( + optimizer, + total_iters=total_steps - warmup_steps, + power=1.0, + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup, poly], + milestones=[warmup_steps], + ) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -690,8 +749,8 @@ def parse_args(): "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" ) parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") - parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam or AdamW)") - parser.add_argument("--wd", "--weight-decay", default=0, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("--lr", type=float, default=4e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") diff --git a/references/recognition/evaluate.py b/references/recognition/evaluate.py index 45a6b38306..22f69fa2cd 100644 --- a/references/recognition/evaluate.py +++ b/references/recognition/evaluate.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): images = images.cuda() images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/references/recognition/train.py b/references/recognition/train.py index dc6b7b1b24..fd0cd1826d 100644 --- a/references/recognition/train.py +++ b/references/recognition/train.py @@ -68,7 +68,7 @@ def record_lr( loss_recorder = [] if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") for batch_idx, (images, targets) in enumerate(train_loader): if torch.cuda.is_available(): @@ -79,7 +79,7 @@ def record_lr( # Forward, Backward & update optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -112,7 +112,7 @@ def record_lr( def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): if amp: - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler("cuda") model.train() # Iterate over the batches of the dataset @@ -125,7 +125,7 @@ def fit_one_epoch(model, device, train_loader, batch_transforms, optimizer, sche optimizer.zero_grad() if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): train_loss = model(images, targets)["loss"] scaler.scale(train_loss).backward() # Gradient clipping @@ -167,7 +167,7 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False, images = images.to(device) images = batch_transforms(images) if amp: - with torch.cuda.amp.autocast(): + with torch.amp.autocast("cuda"): out = model(images, targets, return_preds=True) else: out = model(images, targets, return_preds=True) diff --git a/tests/pytorch/test_models_layout.py b/tests/pytorch/test_models_layout.py index d0c2865411..1094eee76f 100644 --- a/tests/pytorch/test_models_layout.py +++ b/tests/pytorch/test_models_layout.py @@ -76,7 +76,7 @@ def test_layout_models(arch_name, input_shape, train_mode, use_polygons): assert isinstance(results[1], np.ndarray) and results[1].shape == (len(results[0]), 4) assert isinstance(results[2], list) and all(isinstance(scores, float) for scores in results[2]) # Check class idxs are in the model's num_classes - assert all(0 <= idx < model.num_classes for idx in results[0]) + assert all(0 <= idx < len(model.class_names) for idx in results[0]) # Check scores are between 0 and 1 assert all(0 <= score <= 1 for score in results[2]) # Check that the number of boxes, labels and scores are the same