Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
808ee01
Multi-Device TensorRT Runtime with Native NCCL Collectives
apbose Apr 1, 2026
aaa6557
removing the try-except block in TRTengine.cpp and correcting the typos
apbose Apr 8, 2026
4c1e68d
Redesign distributed inference API: auto-detect rank, lazy NCCL setup…
apbose Apr 9, 2026
7cfa40b
remove nccl.h dependency
apbose Apr 9, 2026
ac96255
clean up import and add comment
apbose Apr 9, 2026
fe1c6f4
moving setup_nccl_library call to example script
apbose Apr 9, 2026
b658c7a
work on the save/load export part-add is_md flag, guard export tracin…
apbose Apr 10, 2026
a35dfe6
refactor: Adjusting how we use NCCL
narendasan Apr 10, 2026
3def3f7
fix: enable torch.compile(backend='tensorrt') for LLMs with dynamic s…
narendasan Apr 10, 2026
2aa8f14
test: add torch.compile(backend='tensorrt') integration test for Llam…
narendasan Apr 10, 2026
6f81a66
feat: llama3.2 working with MD-TRT
narendasan Apr 10, 2026
0d2d61c
feat: Support exported and serialization workflows for MD-TRT
narendasan Apr 12, 2026
e08b0c5
ci: fix nccl builds in CI
narendasan Apr 12, 2026
754b62b
chore: Some reorg and cleaning the constructor
narendasan Apr 14, 2026
bf432ad
fix: thread the MD-TRT requirement through the conversion system
narendasan Apr 14, 2026
f4e77ad
fix: DeviceMesh FakeScriptObjects get passed in as arguments into tor…
narendasan Apr 16, 2026
6ba00cf
fix: Address segfaults when a distributed context is manually destroy…
narendasan Apr 16, 2026
edf6518
replacing torchrun with torchtrtrun for right .so
apbose Apr 16, 2026
9e390eb
chore: apply linting
narendasan Apr 16, 2026
1b4e559
use correct group for dummy all_reduce
apbose Apr 16, 2026
df51acf
Broaden NCCL skip guards to include native TRT collectives and fix di…
apbose Apr 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .github/workflows/build-test-linux-x86_64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ jobs:
test-infra-ref: main
build-matrix: ${{ needs.filter-matrix.outputs.matrix }}
pre-script: ${{ matrix.pre-script }}
runner: linux.g4dn.12xlarge.nvidia.gpu
script: |
set -euo pipefail
export USE_HOST_DEPS=1
Expand All @@ -526,7 +527,12 @@ jobs:
pushd .
cd tests/py
cd dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml distributed/test_nccl_ops.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/l2_dynamo_distributed_test_results.xml \
distributed/test_nccl_ops.py \
distributed/test_native_nccl.py \
distributed/test_export_save_load.py
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_native_nccl.py --multirank
python -m torch_tensorrt.distributed.run --nproc_per_node=2 distributed/test_export_save_load.py --multirank
popd

concurrency:
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/linux-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ on:
default: false
type: boolean
required: false
runner:
description: "Override the runner label (e.g. linux.g4dn.12xlarge.nvidia.gpu for multi-GPU jobs). Defaults to matrix.validation_runner."
default: ""
type: string
required: false

jobs:
test:
Expand All @@ -76,7 +81,7 @@ jobs:
USE_TRT_RTX: ${{ inputs.use-rtx }}
DOWNLOAD_ARTIFACT_NAME: pytorch_tensorrt_${{ matrix.tensorrt.version }}_${{ matrix.python_version }}_${{ matrix.desired_cuda }}_${{ inputs.architecture }}
name: ${{ inputs.job-name }}-${{ matrix.tensorrt.version }}-${{ matrix.python_version }}-${{ matrix.desired_cuda }}
runs-on: ${{ matrix.validation_runner }}
runs-on: ${{ inputs.runner != '' && inputs.runner || matrix.validation_runner }}
container:
image: ${{ matrix.container_image }}
options: ${{ matrix.gpu_arch_type == 'cuda' && '--gpus all --shm-size=1g' || ' ' }}
Expand Down
5 changes: 5 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
bazel_dep(name = "platforms", version = "0.0.11")
bazel_dep(name = "rules_cc", version = "0.1.1")
bazel_dep(name = "rules_python", version = "1.3.0")
bazel_dep(name = "bazel_skylib", version = "1.7.1")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(
Expand All @@ -26,6 +27,10 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.

local_torch = use_repo_rule("//toolchains:local_torch.bzl", "local_torch")

torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")

torch_nccl_detect(name = "torch_nccl")

# External dependency for torch_tensorrt if you already have precompiled binaries.
new_local_repository(
name = "torch_tensorrt",
Expand Down
7 changes: 5 additions & 2 deletions core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")
load("//toolchains/torch_nccl:defs.bzl", "if_torch_nccl")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -77,13 +79,14 @@ cc_library(
"TRTEngineProfiler.h",
"runtime.h",
],
copts = if_torch_nccl(["-DUSE_C10D_NCCL"]),
linkopts = [
"-lstdc++fs",
],
deps = [
"//core/plugins:torch_tensorrt_plugins",
"//core/util:prelude",
] + select({
] + if_torch_nccl(["@torch_nccl//:nccl_headers"]) + select({
":jetpack": ["@tensorrt_l4t//:nvinfer"],
":rtx_win": ["@tensorrt_rtx_win//:nvinfer"],
":rtx_x86_64": ["@tensorrt_rtx//:nvinfer"],
Expand Down Expand Up @@ -121,6 +124,6 @@ pkg_tar(
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
visibility = ["//visibility:public"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
124 changes: 123 additions & 1 deletion core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include "core/util/prelude.h"
#include "torch/torch.h"

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
#include "torch/csrc/distributed/c10d/GroupRegistry.hpp"
#include "torch/csrc/distributed/c10d/NCCLUtils.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroup.hpp"
#include "torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp"
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -88,7 +95,12 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
serialized_info[SERIALIZED_METADATA_IDX],
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]))
? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {}
: ResourceAllocationStrategy::kStatic)) {
this->is_md = std::stoi(serialized_info[IS_MD_ENGINE_IDX]);
if (this->is_md) {
LOG_INFO("Loaded distributed TRT engine (contains NCCL collectives); NCCL comm will be bound on first execution");
}
}

TRTEngine::TRTEngine(
const std::string& mod_name,
Expand Down Expand Up @@ -261,6 +273,18 @@ TRTEngine::TRTEngine(
this->enable_profiling();
#endif
LOG_DEBUG(*this);

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// Attempt to bind the NCCL communicator immediately after exec_ctx is ready.
// This handles the common case where dist.init_process_group() and an initial
// collective have already been called before the engine is constructed.
// If the communicator isn't available yet (e.g. engine constructed before the
// first collective), bind_nccl_comm returns false and execute_engine() will
// retry on its first invocation.
if (this->is_md) {
bind_nccl_comm();
}
#endif
}

TRTEngine::~TRTEngine() {
Expand Down Expand Up @@ -383,6 +407,13 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
if (profile_execution) {
enable_profiling();
}
#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// exec_ctx was recreated — re-bind the NCCL communicator if this is a
// distributed engine that has already been set up.
if (nccl_initialized) {
bind_nccl_comm();
}
#endif
// Indicates to reevaluate the runtime settings
runtime_states.context_changed = true;

Expand Down Expand Up @@ -428,6 +459,7 @@ std::string TRTEngine::to_str() const {
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
ss << " Multi-Device Engine: " << (is_md) << std::endl;
// clang-format on
return ss.str();
}
Expand Down Expand Up @@ -497,6 +529,8 @@ std::vector<std::string> TRTEngine::serialize() {
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
serialized_info[IS_MD_ENGINE_IDX] = this->is_md ? "1" : "0";
// rank/world_size are runtime facts (may differ at load time); not serialized.

return serialized_info;
}
Expand All @@ -519,6 +553,94 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
}
}

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
bool TRTEngine::bind_nccl_comm() {
// When group_name is empty (e.g. engine loaded from a serialized
// ExportedProgram where the Python TorchTensorRTModule wrapper was
// inlined and set_group_name() was never called), auto-resolve the
// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only do this if there is one available group. If there are multiple NCCL groups available we should tell the user to manually select

// PyTorch assigns sequential numeric names ("0", "1", ...) to process
// groups. In practice most jobs create fewer than 10 groups; we probe
// up to 20 to allow for destroyed-and-recreated groups.
for (int i = 0; i < 20; ++i) {
auto candidate = std::to_string(i);
auto probe = c10d::resolve_process_group(candidate);
if (probe != nullptr && probe->getBackendType() == c10d::ProcessGroup::BackendType::NCCL) {
this->group_name = candidate;
LOG_INFO("Auto-resolved distributed group name to '" << candidate << "'");
break;
}
}
if (this->group_name.empty()) {
LOG_WARNING(
"This TRT engine requires NCCL (is_md=true) but no NCCL process group "
"was found in the c10d registry. Ensure dist.init_process_group(backend='nccl') "
"has been called before loading the engine. You can also set the group name "
"manually via: engine.set_group_name(NCCL_GROUP_NAME)");
}
}

// Soft-return when the process group isn't available yet (e.g. at engine
// construction time when the caller hasn't called dist.init_process_group()).
auto pg = c10d::resolve_process_group(this->group_name);
if (pg == nullptr) {
LOG_DEBUG("ProcessGroup '" << this->group_name << "' not yet registered in c10d; NCCL bind deferred.");
return false;
}

this->rank = pg->getRank();
this->world_size = pg->getSize();

auto backend = pg->getBackend(c10d::ProcessGroup::BackendType::NCCL);
TORCHTRT_CHECK(backend != nullptr, "ProcessGroup '" << this->group_name << "' has no NCCL backend");

auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
TORCHTRT_CHECK(nccl_pg != nullptr, "Backend is not ProcessGroupNCCL");

at::cuda::set_device(this->device_info.id);

int64_t comm_ptr = nccl_pg->getCommPtr();
// Soft-return when NCCL hasn't run a collective yet. The communicator is
// created lazily by PyTorch on the first collective — callers should ensure
// at least one collective (e.g. dist.barrier()) has been issued before the
// first TRT forward pass.
if (comm_ptr == 0) {
LOG_DEBUG(
"NCCL communicator not yet initialized for device " << this->device_info.id
<< "; NCCL bind deferred until first execute_engine call.");
return false;
}

TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Cannot bind NCCL communicator: execution context is null");
exec_ctx->setCommunicator(reinterpret_cast<void*>(comm_ptr));
this->nccl_initialized = true;
LOG_INFO("NCCL comm bound (rank=" << this->rank << ", device=" << this->device_info.id << ")");
return true;
}

// Detach the bound NCCL communicator by destroying and recreating the TRT
// execution context (preserving the engine's allocation strategy). After this
// call the process group can be destroyed without the engine holding a stale
// communicator pointer; a later execution can re-bind against a new group.
void TRTEngine::release_nccl_comm() {
  // Nothing to do unless a communicator is currently bound.
  if (!this->nccl_initialized) {
    return;
  }
  LOG_INFO("Releasing NCCL communicator from engine '" << this->name << "'");
  // Drain in-flight work on this device before tearing down the context that
  // still references the communicator.
  torch::cuda::synchronize(device_info.id);
  this->exec_ctx.reset();
  const bool dynamic_alloc = (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic);
  this->exec_ctx = dynamic_alloc
      ? make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED))
      : make_trt(cuda_engine->createExecutionContext());
  TORCHTRT_CHECK(
      (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm");
  this->nccl_initialized = false;
  LOG_INFO("NCCL communicator released from engine '" << this->name << "'");
}
#endif // ENABLE_TRT_NCCL_COLLECTIVES

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
36 changes: 36 additions & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,25 @@
#include "ATen/core/function_schema.h"
#include "ATen/cuda/CUDAGraph.h"
#include "NvInfer.h"
#include "NvInferVersion.h"
#include "c10/cuda/CUDAStream.h"
#include "torch/custom_class.h"

#include "core/runtime/TRTEngineProfiler.h"
#include "core/util/prelude.h"

// TensorRT 10.16+ has native NCCL collective support via IExecutionContext::setCommunicator()
#if NV_TENSORRT_MAJOR > 10 || (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR >= 16)
#define TRT_HAS_NATIVE_NCCL 1
#endif

// Full TRT NCCL collectives support requires both:
// 1. PyTorch built with NCCL (USE_C10D_NCCL defined via Bazel)
// 2. TensorRT 10.16+ (TRT_HAS_NATIVE_NCCL defined above)
#if defined(USE_C10D_NCCL) && defined(TRT_HAS_NATIVE_NCCL)
#define ENABLE_TRT_NCCL_COLLECTIVES 1
#endif

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -196,6 +209,29 @@ struct TRTEngine : torch::CustomClassHolder {
bool use_output_allocator_outputs = false; // users specify to use output allocator
std::shared_ptr<DynamicOutputAllocator> output_allocator;

// Member variables for distributed inference
bool is_md = false; // compile-time flag: engine contains NCCL collectives
int64_t rank = -1; // populated at runtime by setup_nccl_comm()
int64_t world_size = -1; // populated at runtime by setup_nccl_comm()
std::string group_name = ""; // c10d registry name; "" = default world group

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
bool nccl_initialized = false; // guards lazy one-shot NCCL setup in execute_engine

// Resolve ProcessGroup via group_name, fetch the NCCL comm from PyTorch,
// and bind it to exec_ctx. Returns true on success. Returns false (without
// throwing) when the process group or NCCL communicator is not yet available
// so callers can retry later. Throws on hard misconfiguration (wrong backend).
bool bind_nccl_comm();

// Detach the NCCL communicator from the execution context by recreating it.
// After this call the process group can be safely destroyed without causing a
// use-after-free in the TRT engine destructor. If the engine is used again
// later (with a new PG), execute_engine() will see nccl_initialized=false
// and re-bind automatically.
void release_nccl_comm();
#endif

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

Expand Down
25 changes: 22 additions & 3 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ void setup_input_tensors(
// Get tensor address, using placeholder for empty tensors
// TensorRT requires non-null address even if numel() = 0
// empty_tensor_placeholder is pre-allocated in TRTEngine constructor
void* input_addr = (final_input.numel() == 0 || final_input.data_ptr() == nullptr)
? compiled_engine->empty_tensor_placeholder
: final_input.data_ptr();
void* input_addr = final_input.numel() == 0 ? compiled_engine->empty_tensor_placeholder : final_input.data_ptr();

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), input_addr),
Expand Down Expand Up @@ -209,6 +207,27 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// All inputs are expected to be on CUDA. Warn and move any that are not.
for (auto& inp : inputs) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to remove this but didnt have time to check if the device operations in python suppress this correctly

if (inp.defined() && !inp.is_cuda()) {
LOG_WARNING(
"Input tensor is not on a CUDA device. Moving it to CUDA automatically. "
"For best performance, ensure all inputs are on the correct CUDA device before "
"calling the TensorRT engine (e.g. tensor.cuda() or tensor.to(device)).");
inp = inp.cuda();
}
}

#ifdef ENABLE_TRT_NCCL_COLLECTIVES
// Lazy one-shot NCCL bind: fires on the first real execute_engine call when
// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not entirely sure this is necessary

compiled_engine->bind_nccl_comm();
}
#endif

torch::Tensor dynamic_workspace;
if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
dynamic_workspace = torch::empty(compiled_engine->cuda_engine->getDeviceMemorySizeV2(), {torch::kCUDA});
Expand Down
Loading
Loading