From 0e4051b351d60b0b5a8c6c63fc996057f4251639 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 6 Feb 2026 01:03:07 -0800 Subject: [PATCH] Preserve the tensor shape and only remove the batch dim when it was added earlier --- src/openpi/shared/image_tools.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/openpi/shared/image_tools.py b/src/openpi/shared/image_tools.py index 8cde353520..b55bbdee6b 100644 --- a/src/openpi/shared/image_tools.py +++ b/src/openpi/shared/image_tools.py @@ -71,16 +71,19 @@ def resize_with_pad_torch( Resized and padded tensor with same shape format as input """ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + batch_dim_added = False if images.shape[-1] <= 4: # Assume channels-last format channels_last = True # Convert to channels-first for torch operations if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension + batch_dim_added = True images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] else: channels_last = False if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension + batch_dim_added = True batch_size, channels, cur_height, cur_width = images.shape @@ -120,7 +123,7 @@ def resize_with_pad_torch( # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - if batch_size == 1 and images.shape[0] == 1: + if batch_dim_added: padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added return padded_images