Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f4fc92e
add rocjpeg support
xytpai Jan 8, 2026
a371c3e
update rocjpeg utils
xytpai Jan 8, 2026
e4c4fd0
rm cout
xytpai Jan 8, 2026
3d9041c
refine code
xytpai Jan 16, 2026
1d29986
rm unused file
xytpai Jan 16, 2026
15d8f11
refine code 2
xytpai Jan 16, 2026
fb2f9fc
Merge branch 'main' into xyt/rocjpeg_upstream
xytpai Jan 27, 2026
bc8c702
Merge branch 'main' into xyt/rocjpeg_upstream
zy1git Jan 28, 2026
173c23d
Merge branch 'main' into xyt/rocjpeg_upstream
zy1git Jan 28, 2026
09c589d
Merge branch 'main' into xyt/rocjpeg_upstream
xytpai Feb 9, 2026
8d4f6ff
Merge branch 'main' into xyt/rocjpeg_upstream
xytpai Jun 12, 2026
b68f0ef
full format support
xytpai Jun 13, 2026
e113fcc
remove stream dependency
xytpai Jun 13, 2026
85b55f1
make batch-size dynamic
xytpai Jun 13, 2026
dd23f0e
resolve remaining comments
xytpai Jun 13, 2026
722a4af
[ROCm] Clean up rocJPEG decode and share GPU JPEG scaffolding (#2)
jeffdaily Jun 18, 2026
a319739
refine IMAGE_READ_MODE_UNCHANGED
xytpai Jun 18, 2026
4b71908
rm dead code & refine comment
xytpai Jun 18, 2026
7ce968f
recover nv path
xytpai Jun 18, 2026
248894c
resolve comments
xytpai Jun 18, 2026
802cac2
apply clang-format
xytpai Jun 18, 2026
d942228
Separate rocJPEG and nvJPEG setup blocks
xytpai Jun 19, 2026
7581393
add _ suffix for private class members
xytpai Jun 19, 2026
a4073b0
just return padded tensor in its original layout
xytpai Jun 19, 2026
a2572c8
rm unnecessary sync
xytpai Jun 22, 2026
b413e54
refine code
xytpai Jun 22, 2026
0fe060a
add rocjpeg doc link
xytpai Jun 22, 2026
2e4047d
refine code
xytpai Jun 23, 2026
156b4ec
Split rocjpeg code (#3)
xytpai Jun 23, 2026
be0b6e5
Merge branch 'main' into xyt/rocjpeg_upstream
xytpai Jun 23, 2026
6620f0f
Merge branch 'main' of github.com:pytorch/vision into xyt/rocjpeg_ups…
NicolasHug Jun 25, 2026
3f2f94f
Let ROCm CI job test the jpeg decoder
NicolasHug Jun 25, 2026
81233bf
Merge branch 'main' into xyt/rocjpeg_upstream
NicolasHug Jun 25, 2026
89f074b
Merge branch 'pytorch:main' into xyt/rocjpeg_upstream
xytpai Jun 26, 2026
6903c5b
add stable abi support
xytpai Jun 27, 2026
004ebb5
add support for make_image_extension
xytpai Jun 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
echo '::endgroup::'
echo '::group::Run image tests'
pytest --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25 test/test_image.py -k "not cuda"
pytest --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25 test/test_image.py
echo '::endgroup::'
unittests-macos:
Expand Down
61 changes: 43 additions & 18 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)

TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "")
Expand All @@ -45,6 +46,7 @@
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_NVJPEG = }")
print(f"{USE_ROCJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{TORCHVISION_INCLUDE = }")
print(f"{TORCHVISION_LIBRARY = }")
Expand Down Expand Up @@ -162,10 +164,12 @@ def get_macros_and_flags():
CSRS_DIR / "ops/cpu/nms_kernel.cpp",
CSRS_DIR / "ops/mps/nms_kernel.mm",
CSRS_DIR / "ops/quantized/cpu/qnms_kernel.cpp",
CSRS_DIR / "io/image/cuda/decode_jpegs_cuda.cpp",
CSRS_DIR / "io/image/common_stable.cpp",
}
STABLE_SOURCES.add(CSRS_DIR / ("ops/hip/nms_kernel.hip" if IS_ROCM else "ops/cuda/nms_kernel.cu"))
STABLE_SOURCES.add(
CSRS_DIR / ("io/image/hip/decode_jpegs_cuda.cpp" if IS_ROCM else "io/image/cuda/decode_jpegs_cuda.cpp")
)


def _not_stable(paths):
Expand Down Expand Up @@ -440,18 +444,31 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

if nvjpeg_found:
print("Building torchvision with NVJPEG image support")
libraries.append("nvjpeg")
define_macros += [("NVJPEG_FOUND", 1)]
Extension = CUDAExtension
else:
if IS_ROCM:
if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA):
rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists()
if rocjpeg_found:
print("Building torchvision with ROCJPEG image support")
libraries.append("rocjpeg")
define_macros += [("ROCJPEG_FOUND", 1)]
Extension = CUDAExtension
else:
warnings.warn("Building torchvision without ROCJPEG support")
elif USE_ROCJPEG:
warnings.warn("Building torchvision without ROCJPEG support")
else:
if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just leave the previous if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): block exactly like it was, and just have a separate (independent) ROCm-specific block below it? They should be mutually exclusive?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can rely on USE_NVJPEG and USE_ROCJPEG being mutually exclusive because both default to true. On ROCm builds, the nvJPEG block would still run and warn unless it is guarded by not IS_ROCM. I think the mutual exclusion should be based on the backend (not IS_ROCM for nvJPEG, IS_ROCM for rocJPEG), not on the two USE flags.

nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

if nvjpeg_found:
print("Building torchvision with NVJPEG image support")
libraries.append("nvjpeg")
define_macros += [("NVJPEG_FOUND", 1)]
Extension = CUDAExtension
else:
warnings.warn("Building torchvision without NVJPEG support")
elif USE_NVJPEG:
warnings.warn("Building torchvision without NVJPEG support")
elif USE_NVJPEG:
warnings.warn("Building torchvision without NVJPEG support")

return Extension(
name="torchvision.image",
Expand Down Expand Up @@ -481,12 +498,20 @@ def make_image_stable_extension():
)

Extension = CppExtension
if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()
if nvjpeg_found:
libraries.append("nvjpeg")
define_macros += [("NVJPEG_FOUND", 1)]
Extension = CUDAExtension
if IS_ROCM:
if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA):
rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists()
if rocjpeg_found:
libraries.append("rocjpeg")
define_macros += [("ROCJPEG_FOUND", 1)]
Extension = CUDAExtension
else:
if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA):
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()
if nvjpeg_found:
libraries.append("nvjpeg")
define_macros += [("NVJPEG_FOUND", 1)]
Extension = CUDAExtension

return Extension(
name="torchvision.image_stable",
Expand Down
5 changes: 4 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,15 @@ def test_decode_jpegs_cuda(mode, scripted):
futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
decoded_images_threaded = [future.result() for future in futures]
assert len(decoded_images_threaded) == num_workers
# rocJPEG's color conversion differs slightly from nvJPEG, so it needs a
# looser tolerance against the CPU reference.
tol = 2.5 if torch.version.hip is not None else 2
for decoded_images in decoded_images_threaded:
assert len(decoded_images) == len(encoded_images)
for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu):
assert decoded_image_cuda.shape == decoded_image_cpu.shape
assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8
assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2
assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < tol


@needs_cuda
Expand Down
9 changes: 7 additions & 2 deletions torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,9 @@ std::vector<torch::stable::Tensor> CUDAJpegDecoder::decode_images(
i < output_tensors.size();
++i) {
if (channels[i] == 1) {
output_tensors[i] = torch::stable::clone(torch::stable::unsqueeze(
torch::stable::select(output_tensors[i], 0, 0), 0));
output_tensors[i] = torch::stable::clone(
torch::stable::unsqueeze(
torch::stable::select(output_tensors[i], 0, 0), 0));
}
}
}
Expand All @@ -618,9 +619,13 @@ STABLE_TORCH_LIBRARY_FRAGMENT(image, m) {
"decode_jpegs_cuda(Tensor[] encoded_images, int mode, Device device) -> Tensor[]");
}

// In ROCm builds, the hand-written rocJPEG implementation registers this op.
// Keep this registration for nvJPEG and the no-GPU-JPEG fallback only.
#if !ROCJPEG_FOUND
STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) {
m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda));
}
#endif

} // namespace image
} // namespace vision
210 changes: 210 additions & 0 deletions torchvision/csrc/io/image/hip/decode_jpegs_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
#include "decode_jpegs_cuda.h"

#include <torch/csrc/stable/library.h>
#include <torch/headeronly/util/Exception.h>

#if ROCJPEG_FOUND
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/DeviceType.h>
#include <torch/headeronly/core/ScalarType.h>
#include <cstdlib>
#include <exception>
#include <memory>
#include <mutex>
#include <optional>

namespace vision {
namespace image {

namespace {
// Rounds `value` up to the next multiple of the 16-byte pitch alignment used
// for the rocJPEG output buffers. Values that are already aligned are
// returned unchanged.
uint32_t align_up(uint32_t value) {
  constexpr uint32_t kRocJpegPitchAlignment = 16;
  // Overshoot by (alignment - 1), then drop the remainder.
  const uint32_t bumped = value + (kRocJpegPitchAlignment - 1);
  return bumped - (bumped % kRocJpegPitchAlignment);
}

// Locked on entry by decode_jpegs_cuda, so calls into the shared decoder
// below run one at a time.
std::mutex decoderMutex;
// File-local decoder instance. decode_jpegs_cuda creates it on first use and
// replaces it when a different target device is requested.
std::unique_ptr<RocJpegDecoder> rocJpegDecoder;
} // namespace

// Decodes a batch of encoded JPEG buffers on the GPU through rocJPEG.
//
// encoded_images: CPU tensors; each must be a non-empty 1-D torch.uint8
//   tensor. Every input is passed through contiguous() before decoding.
// mode: requested read mode, forwarded unchanged to the decoder.
// device: must be a cuda device. A negative index selects the current device.
// Returns the tensors produced by RocJpegDecoder::decode_images.
// Fails through STD_TORCH_CHECK on invalid arguments; any std::exception
// raised while decoding is re-reported with a "while decoding" prefix.
std::vector<torch::stable::Tensor> decode_jpegs_cuda(
    const std::vector<torch::stable::Tensor>& encoded_images,
    vision::image::ImageReadMode mode,
    torch::stable::Device device) {
  // The decoder is shared file-level state, so the whole call holds the lock.
  std::lock_guard<std::mutex> lock(decoderMutex);

  STD_TORCH_CHECK(
      device.is_cuda(), "Expected the device parameter to be a cuda device");

  std::vector<torch::stable::Tensor> contig_images;
  contig_images.reserve(encoded_images.size());
  for (const auto& jpeg_bytes : encoded_images) {
    STD_TORCH_CHECK(
        jpeg_bytes.scalar_type() == torch::headeronly::ScalarType::Byte,
        "Expected a torch.uint8 tensor");
    STD_TORCH_CHECK(
        !jpeg_bytes.is_cuda(), "The input tensor must be on CPU");
    STD_TORCH_CHECK(
        jpeg_bytes.dim() == 1 && jpeg_bytes.numel() > 0,
        "Expected a non empty 1-dimensional tensor");
    // rocJPEG needs contiguous input; contiguous() costs nothing when the
    // tensor already is.
    contig_images.push_back(torch::stable::contiguous(jpeg_bytes));
  }

  // Resolve an unspecified (negative) index to the current device.
  auto target_device = device.index() < 0
      ? torch::stable::Device(
            torch::headeronly::DeviceType::CUDA,
            torch::stable::accelerator::getCurrentDeviceIndex())
      : device;
  torch::stable::accelerator::DeviceGuard device_guard(target_device.index());

  // Build the decoder on first use, and rebuild it when the device changes.
  // The replacement is fully constructed before the previous one is released.
  const bool first_use = (rocJpegDecoder == nullptr);
  if (first_use || target_device != rocJpegDecoder->target_device) {
    rocJpegDecoder = std::make_unique<RocJpegDecoder>(target_device);
    if (first_use) {
      // Registered once: tears the decoder down at process exit.
      std::atexit([]() { rocJpegDecoder.reset(); });
    }
  }

  try {
    return rocJpegDecoder->decode_images(contig_images, mode);
  } catch (const std::exception& err) {
    STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", err.what());
  }
}

// Records the target device and creates a rocJPEG decoder handle on it using
// the hardware backend. HIP and rocJPEG failures are reported through
// CHECK_HIP / CHECK_ROCJPEG.
RocJpegDecoder::RocJpegDecoder(const torch::stable::Device& target_device)
    : target_device(target_device) {
  const auto device_index = target_device.index();
  // Make the target device current while the handle is created.
  torch::stable::accelerator::DeviceGuard device_guard(device_index);
  CHECK_HIP(hipSetDevice(device_index));
  CHECK_ROCJPEG(
      rocJpegCreate(ROCJPEG_BACKEND_HARDWARE, device_index, &rocjpeg_handle_));
}

// Destroys the decoder handle first, then every stream handle created by
// decode_images. The status codes returned by rocJPEG are not inspected.
RocJpegDecoder::~RocJpegDecoder() {
  rocJpegDestroy(rocjpeg_handle_);
  for (std::size_t idx = 0; idx < rocjpeg_stream_handles_.size(); ++idx) {
    rocJpegStreamDestroy(rocjpeg_stream_handles_[idx]);
  }
}

// Decodes `encoded_images` with a single rocJPEG batched call and returns one
// tensor per input, in input order.
//
// encoded_images: encoded JPEG byte buffers. decode_jpegs_cuda passes
//   contiguous CPU uint8 tensors.
// mode: IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY or
//   IMAGE_READ_MODE_RGB; any other mode is rejected.
// Returns: uint8 tensors allocated on target_device with shape
//   (channels, height, width). Each is a narrowed view of a larger buffer
//   whose height and row pitch are rounded up by align_up, so the results are
//   generally not contiguous. An empty input yields an empty vector.
// Fails through STD_TORCH_CHECK when an image is smaller than 64x64, uses an
//   unsupported chroma subsampling, or `mode` is unsupported; rocJPEG call
//   failures are reported through CHECK_ROCJPEG.
std::vector<torch::stable::Tensor> RocJpegDecoder::decode_images(
    const std::vector<torch::stable::Tensor>& encoded_images,
    vision::image::ImageReadMode mode) {
  const std::size_t num_images = encoded_images.size();
  // Nothing to decode: return early instead of handing rocJpegDecodeBatched
  // empty parameter arrays and a batch size of zero.
  if (num_images == 0) {
    return {};
  }

  // Reuse existing rocJPEG stream handles and create only the missing ones.
  while (rocjpeg_stream_handles_.size() < num_images) {
    RocJpegStreamHandle stream_handle{};
    CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle));
    rocjpeg_stream_handles_.push_back(stream_handle);
  }

  // vector(n) value-initialises every element, so any field this function
  // does not set explicitly starts out zeroed.
  std::vector<RocJpegDecodeParams> decode_params(num_images);
  std::vector<RocJpegImage> output_images(num_images);
  std::vector<torch::stable::Tensor> output_tensors;
  output_tensors.reserve(num_images);

  for (std::size_t i = 0; i < num_images; ++i) {
    CHECK_ROCJPEG(rocJpegStreamParse(
        encoded_images[i].const_data_ptr<uint8_t>(),
        encoded_images[i].numel(),
        rocjpeg_stream_handles_[i]));

    uint8_t num_components = 0;
    RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN;
    uint32_t widths[ROCJPEG_MAX_COMPONENT] = {};
    uint32_t heights[ROCJPEG_MAX_COMPONENT] = {};
    CHECK_ROCJPEG(rocJpegGetImageInfo(
        rocjpeg_handle_,
        rocjpeg_stream_handles_[i],
        &num_components,
        &subsampling,
        widths,
        heights));

    // The first component's dimensions are used as the image size.
    const uint32_t width = widths[0];
    const uint32_t height = heights[0];
    STD_TORCH_CHECK(
        width >= 64 && height >= 64,
        "Image resolution ",
        width,
        "x",
        height,
        " is below the rocJPEG hardware JPEG decoder minimum of 64x64");
    STD_TORCH_CHECK(
        subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN,
        "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder");

    // Initialised up front so no path can read an indeterminate value; the
    // switch below either assigns both or fails.
    RocJpegOutputFormat image_output_format{};
    uint32_t num_channels = 0;
    switch (mode) {
      case vision::image::IMAGE_READ_MODE_UNCHANGED:
        // torchvision's UNCHANGED mode is expected to match the CPU/nvJPEG
        // behavior: grayscale JPEGs return one channel, while color JPEGs
        // return RGB.
        if (num_components == 1) {
          image_output_format = ROCJPEG_OUTPUT_Y;
          num_channels = 1;
        } else {
          image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR;
          num_channels = 3;
        }
        break;
      case vision::image::IMAGE_READ_MODE_GRAY:
        image_output_format = ROCJPEG_OUTPUT_Y;
        num_channels = 1;
        break;
      case vision::image::IMAGE_READ_MODE_RGB:
        image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR;
        num_channels = 3;
        break;
      default:
        STD_TORCH_CHECK(
            false,
            "The provided mode is not supported for JPEG decoding on GPU");
    }

    // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer
    // padded to that alignment and return a view of the valid region.
    uint32_t pitch = align_up(width);
    auto buffer = torch::stable::empty(
        {int64_t(num_channels), int64_t(align_up(height)), int64_t(pitch)},
        torch::headeronly::ScalarType::Byte,
        std::nullopt,
        target_device);

    decode_params[i].output_format = image_output_format;
    // One destination plane per output channel, all sharing the same pitch.
    for (uint32_t c = 0; c < num_channels; ++c) {
      output_images[i].channel[c] =
          torch::stable::select(buffer, 0, c).mutable_data_ptr<uint8_t>();
      output_images[i].pitch[c] = pitch;
    }
    // The returned view keeps the padded buffer alive and hides the padding.
    auto valid_height = torch::stable::narrow(buffer, 1, 0, height);
    output_tensors.push_back(torch::stable::narrow(valid_height, 2, 0, width));
  }

  // Choosing a batch size that is a multiple of the available JPEG cores is
  // recommended.
  CHECK_ROCJPEG(rocJpegDecodeBatched(
      rocjpeg_handle_,
      rocjpeg_stream_handles_.data(),
      static_cast<int>(num_images),
      decode_params.data(),
      output_images.data()));
  // rocJpegDecodeBatched synchronizes rocJPEG's internal HIP stream before
  // returning, so the decoded output tensors are ready for PyTorch streams.
  // NOTE(review): nothing in this function orders the decode after work
  // already queued on PyTorch's current stream -- confirm that is safe when
  // the allocator returns a recycled memory block.

  return output_tensors;
}

// Registers the rocJPEG-backed decode_jpegs_cuda defined in this file as the
// implementation of the `image` library's "decode_jpegs_cuda" op under the
// CompositeExplicitAutograd key.
STABLE_TORCH_LIBRARY_IMPL(image, CompositeExplicitAutograd, m) {
m.impl("decode_jpegs_cuda", TORCH_BOX(&decode_jpegs_cuda));
}

} // namespace image
} // namespace vision

#endif
Loading