diff --git a/src/openpi/shared/image_tools.py b/src/openpi/shared/image_tools.py index 8cde353520..b55bbdee6b 100644 --- a/src/openpi/shared/image_tools.py +++ b/src/openpi/shared/image_tools.py @@ -71,16 +71,19 @@ def resize_with_pad_torch( Resized and padded tensor with same shape format as input """ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + batch_dim_added = False if images.shape[-1] <= 4: # Assume channels-last format channels_last = True # Convert to channels-first for torch operations if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension + batch_dim_added = True images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] else: channels_last = False if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension + batch_dim_added = True batch_size, channels, cur_height, cur_width = images.shape @@ -120,7 +123,7 @@ def resize_with_pad_torch( # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - if batch_size == 1 and images.shape[0] == 1: + if batch_dim_added: padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added return padded_images