diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index c29c512e8cf..f2745b820c5 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -77,6 +77,35 @@ if [[ $GPU_ARCH_TYPE == 'cuda' || $GPU_ARCH_TYPE == 'rocm' ]]; then fi echo '::endgroup::' +if [[ $GPU_ARCH_TYPE == 'rocm' ]]; then + echo '::group::Install rocJPEG SDK' + # rocJPEG is shipped as a separate SDK package and isn't in the base ROCm + # builder image. Without its header ($ROCM_HOME/include/rocjpeg/rocjpeg.h) + # setup.py silently builds the HIP jpeg ops as stubs ("not compiled with + # nvJPEG support"), so install it before building torchvision. + # + # rocjpeg-devel requires libva-devel >= 2.16.0 or libva-amdgpu-devel. The base + # image's libva-devel is too old and libva-amdgpu-devel lives in AMD's separate + # "graphics" repo (not the rocm repo), so add that repo first. Derive the ROCm + # version from the existing rocm repo config, falling back to 7.1.1. + rocm_ver=$(grep -rhoE 'repo\.radeon\.com/rocm/[^/]+/[0-9][0-9.]*' /etc/yum.repos.d/ \ + | grep -oE '[0-9][0-9.]*$' | head -1) + rocm_ver=${rocm_ver:-7.1.1} + cat > /etc/yum.repos.d/amdgpu-graphics.repo < CUDAJpegDecoder::decode_images( i < output_tensors.size(); ++i) { if (channels[i] == 1) { - output_tensors[i] = torch::stable::clone(torch::stable::unsqueeze( - torch::stable::select(output_tensors[i], 0, 0), 0)); + output_tensors[i] = torch::stable::clone( + torch::stable::unsqueeze( + torch::stable::select(output_tensors[i], 0, 0), 0)); } } } @@ -618,9 +619,13 @@ STABLE_TORCH_LIBRARY_FRAGMENT(image, m) { "decode_jpegs_cuda(Tensor[] encoded_images, int mode, Device device) -> Tensor[]"); } +// In ROCm builds, the hand-written rocJPEG implementation registers this op. +// Keep this registration for nvJPEG and the no-GPU-JPEG fallback only. +#if !ROCJPEG_FOUND STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) { m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); } +#endif } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp new file mode 100644 index 00000000000..043560a36d8 --- /dev/null +++ b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp @@ -0,0 +1,210 @@ +#include "decode_jpegs_cuda.h" + +#include +#include + +#if ROCJPEG_FOUND +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vision { +namespace image { + +namespace { +uint32_t align_up(uint32_t value) { + constexpr uint32_t kRocJpegPitchAlignment = 16; + return (value + kRocJpegPitchAlignment - 1) & ~(kRocJpegPitchAlignment - 1); +} + +std::mutex decoderMutex; +std::unique_ptr rocJpegDecoder; +} // namespace + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::stable::Device device) { + std::lock_guard lock(decoderMutex); + + STD_TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + for (auto& encoded_image : encoded_images) { + STD_TORCH_CHECK( + encoded_image.scalar_type() == torch::headeronly::ScalarType::Byte, + "Expected a torch.uint8 tensor"); + STD_TORCH_CHECK( + !encoded_image.is_cuda(), "The input tensor must be on CPU"); + STD_TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + // rocJPEG requires contiguous input; contiguous() is a no-op when it + // already is. + contig_images.push_back(torch::stable::contiguous(encoded_image)); + } + + auto target_device = device.index() >= 0 + ? device + : torch::stable::Device( + torch::headeronly::DeviceType::CUDA, + torch::stable::accelerator::getCurrentDeviceIndex()); + torch::stable::accelerator::DeviceGuard device_guard(target_device.index()); + + if (rocJpegDecoder == nullptr || + target_device != rocJpegDecoder->target_device) { + if (rocJpegDecoder != nullptr) { + rocJpegDecoder.reset(new RocJpegDecoder(target_device)); + } else { + rocJpegDecoder = std::make_unique(target_device); + std::atexit([]() { rocJpegDecoder.reset(); }); + } + } + + try { + return rocJpegDecoder->decode_images(contig_images, mode); + } catch (const std::exception& e) { + STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } +} + +RocJpegDecoder::RocJpegDecoder(const torch::stable::Device& target_device) + : target_device{target_device} { + torch::stable::accelerator::DeviceGuard device_guard(target_device.index()); + CHECK_HIP(hipSetDevice(target_device.index())); + CHECK_ROCJPEG(rocJpegCreate( + ROCJPEG_BACKEND_HARDWARE, target_device.index(), &rocjpeg_handle_)); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle_); + for (auto stream_handle : rocjpeg_stream_handles_) { + rocJpegStreamDestroy(stream_handle); + } +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { + const std::size_t num_images = encoded_images.size(); + // Reuse existing rocJPEG stream handles and create only the missing ones. + while (rocjpeg_stream_handles_.size() < num_images) { + RocJpegStreamHandle stream_handle; + CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); + rocjpeg_stream_handles_.push_back(stream_handle); + } + + std::vector decode_params(num_images); + std::vector output_images(num_images); + std::vector output_tensors; + output_tensors.reserve(num_images); + + for (std::size_t i = 0; i < num_images; ++i) { + CHECK_ROCJPEG(rocJpegStreamParse( + encoded_images[i].const_data_ptr(), + encoded_images[i].numel(), + rocjpeg_stream_handles_[i])); + + uint8_t num_components = 0; + RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN; + uint32_t widths[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t heights[ROCJPEG_MAX_COMPONENT] = {}; + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle_, + rocjpeg_stream_handles_[i], + &num_components, + &subsampling, + widths, + heights)); + + const uint32_t width = widths[0]; + const uint32_t height = heights[0]; + STD_TORCH_CHECK( + width >= 64 && height >= 64, + "Image resolution ", + width, + "x", + height, + " is below the rocJPEG hardware JPEG decoder minimum of 64x64"); + STD_TORCH_CHECK( + subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, + "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder"); + + RocJpegOutputFormat image_output_format; + uint32_t num_channels; + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // torchvision's UNCHANGED mode is expected to match the CPU/nvJPEG + // behavior: grayscale JPEGs return one channel, while color JPEGs + // return RGB. + if (num_components == 1) { + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + } else { + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + } + break; + case vision::image::IMAGE_READ_MODE_GRAY: + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + break; + case vision::image::IMAGE_READ_MODE_RGB: + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + break; + default: + STD_TORCH_CHECK( + false, + "The provided mode is not supported for JPEG decoding on GPU"); + } + + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer + // padded to that alignment and return a view of the valid region. + uint32_t pitch = align_up(width); + auto buffer = torch::stable::empty( + {int64_t(num_channels), int64_t(align_up(height)), int64_t(pitch)}, + torch::headeronly::ScalarType::Byte, + std::nullopt, + target_device); + + decode_params[i].output_format = image_output_format; + for (uint32_t c = 0; c < num_channels; ++c) { + output_images[i].channel[c] = + torch::stable::select(buffer, 0, c).mutable_data_ptr(); + output_images[i].pitch[c] = pitch; + } + auto valid_height = torch::stable::narrow(buffer, 1, 0, height); + output_tensors.push_back(torch::stable::narrow(valid_height, 2, 0, width)); + } + + // Choosing a batch size that is a multiple of the available JPEG cores is + // recommended. + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle_, + rocjpeg_stream_handles_.data(), + static_cast(num_images), + decode_params.data(), + output_images.data())); + // rocJpegDecodeBatched synchronizes rocJPEG's internal HIP stream before + // returning, so the decoded output tensors are ready for PyTorch streams. + + return output_tensors; +} + +STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) { + m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h new file mode 100644 index 00000000000..4623ac79d1b --- /dev/null +++ b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include + +#include "../common_stable.h" + +#if ROCJPEG_FOUND + +#include +#include + +// rocJPEG decode API documentation: +// https://rocm.docs.amd.com/projects/rocJPEG/en/latest/how-to/rocjpeg-decoding-a-jpeg-stream.html + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::stable::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + ImageReadMode mode); + + const torch::stable::Device target_device; + + private: + std::vector rocjpeg_stream_handles_; + RocJpegHandle rocjpeg_handle_; +}; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + ImageReadMode mode, + torch::stable::Device device); + +} // namespace image +} // namespace vision + +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + STD_TORCH_CHECK( \ + rocjpeg_status == ROCJPEG_STATUS_SUCCESS, \ + #call, \ + " returned ", \ + rocJpegGetErrorName(rocjpeg_status)); \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + STD_TORCH_CHECK( \ + hip_status == hipSuccess, \ + #call, \ + " failed with status: ", \ + hipGetErrorName(hip_status)); \ + } + +#endif