diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 53d7f3cf303..ddea0082be6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -90,7 +90,7 @@ jobs: echo '::endgroup::' echo '::group::Run image tests' - pytest --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25 test/test_image.py -k "not cuda" + pytest --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25 test/test_image.py echo '::endgroup::' unittests-macos: diff --git a/setup.py b/setup.py index f01bb062d24..2f56a29db74 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "") @@ -45,6 +46,7 @@ print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") +print(f"{USE_ROCJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{TORCHVISION_INCLUDE = }") print(f"{TORCHVISION_LIBRARY = }") @@ -162,10 +164,12 @@ def get_macros_and_flags(): CSRS_DIR / "ops/cpu/nms_kernel.cpp", CSRS_DIR / "ops/mps/nms_kernel.mm", CSRS_DIR / "ops/quantized/cpu/qnms_kernel.cpp", - CSRS_DIR / "io/image/cuda/decode_jpegs_cuda.cpp", CSRS_DIR / "io/image/common_stable.cpp", } STABLE_SOURCES.add(CSRS_DIR / ("ops/hip/nms_kernel.hip" if IS_ROCM else "ops/cuda/nms_kernel.cu")) +STABLE_SOURCES.add( + CSRS_DIR / ("io/image/hip/decode_jpegs_cuda.cpp" if IS_ROCM else "io/image/cuda/decode_jpegs_cuda.cpp") +) def _not_stable(paths): @@ -440,18 +444,31 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): - nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - - if nvjpeg_found: - print("Building torchvision with NVJPEG image support") - libraries.append("nvjpeg") - define_macros += [("NVJPEG_FOUND", 1)] - Extension = CUDAExtension - else: + if IS_ROCM: + if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA): + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() + if rocjpeg_found: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without ROCJPEG support") + elif USE_ROCJPEG: + warnings.warn("Building torchvision without ROCJPEG support") + else: + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + + if nvjpeg_found: + print("Building torchvision with NVJPEG image support") + libraries.append("nvjpeg") + define_macros += [("NVJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without NVJPEG support") + elif USE_NVJPEG: warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", @@ -481,12 +498,20 @@ def make_image_stable_extension(): ) Extension = CppExtension - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): - nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - if nvjpeg_found: - libraries.append("nvjpeg") - define_macros += [("NVJPEG_FOUND", 1)] - Extension = CUDAExtension + if IS_ROCM: + if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA): + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() + if rocjpeg_found: + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + if nvjpeg_found: + libraries.append("nvjpeg") + define_macros += [("NVJPEG_FOUND", 1)] + Extension = CUDAExtension return Extension( name="torchvision.image_stable", diff --git a/test/test_image.py b/test/test_image.py index 2d9880a9d64..a396ffcaf6a 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -426,12 +426,15 @@ def test_decode_jpegs_cuda(mode, scripted): futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers + # rocJPEG's color conversion differs slightly from nvJPEG, so it needs a + # looser tolerance against the CPU reference. + tol = 2.5 if torch.version.hip is not None else 2 for decoded_images in decoded_images_threaded: assert len(decoded_images) == len(encoded_images) for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < tol @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 8839ed832e7..938e0cc8b0d 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -596,8 +596,9 @@ std::vector CUDAJpegDecoder::decode_images( i < output_tensors.size(); ++i) { if (channels[i] == 1) { - output_tensors[i] = torch::stable::clone(torch::stable::unsqueeze( - torch::stable::select(output_tensors[i], 0, 0), 0)); + output_tensors[i] = torch::stable::clone( + torch::stable::unsqueeze( + torch::stable::select(output_tensors[i], 0, 0), 0)); } } } @@ -618,9 +619,13 @@ STABLE_TORCH_LIBRARY_FRAGMENT(image, m) { "decode_jpegs_cuda(Tensor[] encoded_images, int mode, Device device) -> Tensor[]"); } +// In ROCm builds, the hand-written rocJPEG implementation registers this op. +// Keep this registration for nvJPEG and the no-GPU-JPEG fallback only. +#if !ROCJPEG_FOUND STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) { m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); } +#endif } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp new file mode 100644 index 00000000000..043560a36d8 --- /dev/null +++ b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp @@ -0,0 +1,210 @@ +#include "decode_jpegs_cuda.h" + +#include +#include + +#if ROCJPEG_FOUND +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vision { +namespace image { + +namespace { +uint32_t align_up(uint32_t value) { + constexpr uint32_t kRocJpegPitchAlignment = 16; + return (value + kRocJpegPitchAlignment - 1) & ~(kRocJpegPitchAlignment - 1); +} + +std::mutex decoderMutex; +std::unique_ptr rocJpegDecoder; +} // namespace + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::stable::Device device) { + std::lock_guard lock(decoderMutex); + + STD_TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + for (auto& encoded_image : encoded_images) { + STD_TORCH_CHECK( + encoded_image.scalar_type() == torch::headeronly::ScalarType::Byte, + "Expected a torch.uint8 tensor"); + STD_TORCH_CHECK( + !encoded_image.is_cuda(), "The input tensor must be on CPU"); + STD_TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + // rocJPEG requires contiguous input; contiguous() is a no-op when it + // already is. + contig_images.push_back(torch::stable::contiguous(encoded_image)); + } + + auto target_device = device.index() >= 0 + ? device + : torch::stable::Device( + torch::headeronly::DeviceType::CUDA, + torch::stable::accelerator::getCurrentDeviceIndex()); + torch::stable::accelerator::DeviceGuard device_guard(target_device.index()); + + if (rocJpegDecoder == nullptr || + target_device != rocJpegDecoder->target_device) { + if (rocJpegDecoder != nullptr) { + rocJpegDecoder.reset(new RocJpegDecoder(target_device)); + } else { + rocJpegDecoder = std::make_unique(target_device); + std::atexit([]() { rocJpegDecoder.reset(); }); + } + } + + try { + return rocJpegDecoder->decode_images(contig_images, mode); + } catch (const std::exception& e) { + STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } +} + +RocJpegDecoder::RocJpegDecoder(const torch::stable::Device& target_device) + : target_device{target_device} { + torch::stable::accelerator::DeviceGuard device_guard(target_device.index()); + CHECK_HIP(hipSetDevice(target_device.index())); + CHECK_ROCJPEG(rocJpegCreate( + ROCJPEG_BACKEND_HARDWARE, target_device.index(), &rocjpeg_handle_)); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle_); + for (auto stream_handle : rocjpeg_stream_handles_) { + rocJpegStreamDestroy(stream_handle); + } +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { + const std::size_t num_images = encoded_images.size(); + // Reuse existing rocJPEG stream handles and create only the missing ones. + while (rocjpeg_stream_handles_.size() < num_images) { + RocJpegStreamHandle stream_handle; + CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); + rocjpeg_stream_handles_.push_back(stream_handle); + } + + std::vector decode_params(num_images); + std::vector output_images(num_images); + std::vector output_tensors; + output_tensors.reserve(num_images); + + for (std::size_t i = 0; i < num_images; ++i) { + CHECK_ROCJPEG(rocJpegStreamParse( + encoded_images[i].const_data_ptr(), + encoded_images[i].numel(), + rocjpeg_stream_handles_[i])); + + uint8_t num_components = 0; + RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN; + uint32_t widths[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t heights[ROCJPEG_MAX_COMPONENT] = {}; + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle_, + rocjpeg_stream_handles_[i], + &num_components, + &subsampling, + widths, + heights)); + + const uint32_t width = widths[0]; + const uint32_t height = heights[0]; + STD_TORCH_CHECK( + width >= 64 && height >= 64, + "Image resolution ", + width, + "x", + height, + " is below the rocJPEG hardware JPEG decoder minimum of 64x64"); + STD_TORCH_CHECK( + subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, + "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder"); + + RocJpegOutputFormat image_output_format; + uint32_t num_channels; + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // torchvision's UNCHANGED mode is expected to match the CPU/nvJPEG + // behavior: grayscale JPEGs return one channel, while color JPEGs + // return RGB. + if (num_components == 1) { + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + } else { + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + } + break; + case vision::image::IMAGE_READ_MODE_GRAY: + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + break; + case vision::image::IMAGE_READ_MODE_RGB: + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + break; + default: + STD_TORCH_CHECK( + false, + "The provided mode is not supported for JPEG decoding on GPU"); + } + + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer + // padded to that alignment and return a view of the valid region. + uint32_t pitch = align_up(width); + auto buffer = torch::stable::empty( + {int64_t(num_channels), int64_t(align_up(height)), int64_t(pitch)}, + torch::headeronly::ScalarType::Byte, + std::nullopt, + target_device); + + decode_params[i].output_format = image_output_format; + for (uint32_t c = 0; c < num_channels; ++c) { + output_images[i].channel[c] = + torch::stable::select(buffer, 0, c).mutable_data_ptr(); + output_images[i].pitch[c] = pitch; + } + auto valid_height = torch::stable::narrow(buffer, 1, 0, height); + output_tensors.push_back(torch::stable::narrow(valid_height, 2, 0, width)); + } + + // Choosing a batch size that is a multiple of the available JPEG cores is + // recommended. + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle_, + rocjpeg_stream_handles_.data(), + static_cast(num_images), + decode_params.data(), + output_images.data())); + // rocJpegDecodeBatched synchronizes rocJPEG's internal HIP stream before + // returning, so the decoded output tensors are ready for PyTorch streams. + + return output_tensors; +} + +STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) { + m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda)); +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h new file mode 100644 index 00000000000..4623ac79d1b --- /dev/null +++ b/torchvision/csrc/io/image/hip/decode_jpegs_cuda.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include + +#include "../common_stable.h" + +#if ROCJPEG_FOUND + +#include +#include + +// rocJPEG decode API documentation: +// https://rocm.docs.amd.com/projects/rocJPEG/en/latest/how-to/rocjpeg-decoding-a-jpeg-stream.html + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::stable::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + ImageReadMode mode); + + const torch::stable::Device target_device; + + private: + std::vector rocjpeg_stream_handles_; + RocJpegHandle rocjpeg_handle_; +}; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + ImageReadMode mode, + torch::stable::Device device); + +} // namespace image +} // namespace vision + +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + STD_TORCH_CHECK( \ + rocjpeg_status == ROCJPEG_STATUS_SUCCESS, \ + #call, \ + " returned ", \ + rocJpegGetErrorName(rocjpeg_status)); \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + STD_TORCH_CHECK( \ + hip_status == hipSuccess, \ + #call, \ + " failed with status: ", \ + hipGetErrorName(hip_status)); \ + } + +#endif