diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7b950b0c45b..ad61dd234e4 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -63,7 +63,7 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode: _is_pil_image = F_pil._is_pil_image -def get_dimensions(img: Tensor) -> list[int]: +def get_dimensions(img: Union[Tensor, PILImage]) -> list[int]: """Returns the dimensions of an image as [channels, height, width]. Args: @@ -80,7 +80,7 @@ def get_dimensions(img: Tensor) -> list[int]: return F_pil.get_dimensions(img) -def get_image_size(img: Tensor) -> list[int]: +def get_image_size(img: Union[Tensor, PILImage]) -> list[int]: """Returns the size of an image as [width, height]. Args: @@ -97,7 +97,7 @@ def get_image_size(img: Tensor) -> list[int]: return F_pil.get_image_size(img) -def get_image_num_channels(img: Tensor) -> int: +def get_image_num_channels(img: Union[Tensor, PILImage]) -> int: """Returns the number of channels of an image. Args: @@ -178,7 +178,7 @@ def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor: return img -def pil_to_tensor(pic: Any) -> Tensor: +def pil_to_tensor(pic: PILImage) -> Tensor: """Convert a ``PIL Image`` to a tensor of the same type. This function does not support torchscript. @@ -385,12 +385,12 @@ def _compute_resized_output_size( def resize( - img: Tensor, + img: Union[Tensor, PILImage], size: list[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, antialias: Optional[bool] = True, -) -> Tensor: +) -> Union[Tensor, PILImage]: r"""Resize the input image to the given size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -479,7 +479,7 @@ def resize( return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) -def pad(img: Tensor, padding: list[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Tensor: +def pad(img: Union[Tensor, PILImage], padding: list[int], fill: Union[int, float] = 0, padding_mode: str = "constant") -> Union[Tensor, PILImage]: r"""Pad the given image on all sides with the given "pad" value. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric, @@ -528,7 +528,7 @@ def pad(img: Tensor, padding: list[int], fill: Union[int, float] = 0, padding_mo return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) -def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: +def crop(img: Union[Tensor, PILImage], top: int, left: int, height: int, width: int) -> Union[Tensor, PILImage]: """Crop the given image at specified location and output size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -651,7 +651,7 @@ def resized_crop( return img -def hflip(img: Tensor) -> Tensor: +def hflip(img: Union[Tensor, PILImage]) -> Union[Tensor, PILImage]: """Horizontally flip the given image. Args: @@ -705,12 +705,12 @@ def _get_perspective_coeffs(startpoints: list[list[int]], endpoints: list[list[i def perspective( - img: Tensor, + img: Union[Tensor, PILImage], startpoints: list[list[int]], endpoints: list[list[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[list[float]] = None, -) -> Tensor: +) -> Union[Tensor, PILImage]: """Perform perspective transform of the given image. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -754,7 +754,7 @@ def perspective( return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill) -def vflip(img: Tensor) -> Tensor: +def vflip(img: Union[Tensor, PILImage]) -> Union[Tensor, PILImage]: """Vertically flip the given image. Args: @@ -865,7 +865,7 @@ def ten_crop( return first_five + second_five -def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: +def adjust_brightness(img: Union[Tensor, PILImage], brightness_factor: float) -> Union[Tensor, PILImage]: """Adjust brightness of an image. Args: @@ -887,7 +887,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: return F_t.adjust_brightness(img, brightness_factor) -def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: +def adjust_contrast(img: Union[Tensor, PILImage], contrast_factor: float) -> Union[Tensor, PILImage]: """Adjust contrast of an image. Args: @@ -909,7 +909,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: return F_t.adjust_contrast(img, contrast_factor) -def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: +def adjust_saturation(img: Union[Tensor, PILImage], saturation_factor: float) -> Union[Tensor, PILImage]: """Adjust color saturation of an image. Args: @@ -931,7 +931,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: return F_t.adjust_saturation(img, saturation_factor) -def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: +def adjust_hue(img: Union[Tensor, PILImage], hue_factor: float) -> Union[Tensor, PILImage]: """Adjust hue of an image. The image hue is adjusted by converting the image to HSV and @@ -970,7 +970,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: return F_t.adjust_hue(img, hue_factor) -def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: +def adjust_gamma(img: Union[Tensor, PILImage], gamma: float, gain: float = 1) -> Union[Tensor, PILImage]: r"""Perform gamma correction on an image. Also known as Power Law Transform. Intensities in RGB mode are adjusted @@ -1064,13 +1064,13 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, + img: Union[Tensor, PILImage], angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, center: Optional[list[int]] = None, fill: Optional[list[float]] = None, -) -> Tensor: +) -> Union[Tensor, PILImage]: """Rotate the image by angle. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1133,7 +1133,7 @@ def rotate( def affine( - img: Tensor, + img: Union[Tensor, PILImage], angle: float, translate: list[int], scale: float, @@ -1141,7 +1141,7 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[list[float]] = None, center: Optional[list[int]] = None, -) -> Tensor: +) -> Union[Tensor, PILImage]: """Apply affine transformation on the image keeping image center invariant. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. @@ -1264,7 +1264,7 @@ def to_grayscale(img, num_output_channels=1): raise TypeError("Input should be PIL Image") -def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: +def rgb_to_grayscale(img: Union[Tensor, PILImage], num_output_channels: int = 1) -> Union[Tensor, PILImage]: """Convert RGB image to grayscale version of image. If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions @@ -1315,7 +1315,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return F_t.erase(img, i, j, h, w, v, inplace=inplace) -def gaussian_blur(img: Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> Tensor: +def gaussian_blur(img: Union[Tensor, PILImage], kernel_size: list[int], sigma: Optional[list[float]] = None) -> Union[Tensor, PILImage]: """Performs Gaussian blurring on the image by given kernel The convolution will be using reflection padding corresponding to the kernel size, to maintain the input shape. @@ -1384,7 +1384,7 @@ def gaussian_blur(img: Tensor, kernel_size: list[int], sigma: Optional[list[floa return output -def invert(img: Tensor) -> Tensor: +def invert(img: Union[Tensor, PILImage]) -> Union[Tensor, PILImage]: """Invert the colors of an RGB/grayscale image. Args: @@ -1404,7 +1404,7 @@ def invert(img: Tensor) -> Tensor: return F_t.invert(img) -def posterize(img: Tensor, bits: int) -> Tensor: +def posterize(img: Union[Tensor, PILImage], bits: int) -> Union[Tensor, PILImage]: """Posterize an image by reducing the number of bits for each color channel. Args: @@ -1428,7 +1428,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: return F_t.posterize(img, bits) -def solarize(img: Tensor, threshold: float) -> Tensor: +def solarize(img: Union[Tensor, PILImage], threshold: float) -> Union[Tensor, PILImage]: """Solarize an RGB/grayscale image by inverting all pixel values above a threshold. Args: @@ -1448,7 +1448,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: return F_t.solarize(img, threshold) -def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: +def adjust_sharpness(img: Union[Tensor, PILImage], sharpness_factor: float) -> Union[Tensor, PILImage]: """Adjust the sharpness of an image. Args: @@ -1470,7 +1470,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: return F_t.adjust_sharpness(img, sharpness_factor) -def autocontrast(img: Tensor) -> Tensor: +def autocontrast(img: Union[Tensor, PILImage]) -> Union[Tensor, PILImage]: """Maximize contrast of an image by remapping its pixels per channel so that the lowest becomes black and the lightest becomes white. @@ -1492,7 +1492,7 @@ def autocontrast(img: Tensor) -> Tensor: return F_t.autocontrast(img) -def equalize(img: Tensor) -> Tensor: +def equalize(img: Union[Tensor, PILImage]) -> Union[Tensor, PILImage]: """Equalize the histogram of an image by applying a non-linear mapping to the input in order to create a uniform distribution of grayscale values in the output. @@ -1516,7 +1516,7 @@ def equalize(img: Tensor) -> Tensor: def elastic_transform( - img: Tensor, + img: Union[Tensor, PILImage], displacement: Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[list[float]] = None,