Skip to content

Commit 9022e03

Browse files
committed
ci: fix nccl builds in CI
1 parent 0d2d61c commit 9022e03

7 files changed

Lines changed: 22 additions & 8 deletions

File tree

docker/MODULE.bazel.docker

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
88
bazel_dep(name = "platforms", version = "0.0.11")
99
bazel_dep(name = "rules_cc", version = "0.1.1")
1010
bazel_dep(name = "rules_python", version = "1.3.0")
11+
bazel_dep(name = "bazel_skylib", version = "1.7.1")
1112

1213
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
1314
python.toolchain(
@@ -24,9 +25,12 @@ git_override(
2425

2526
local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.bzl", "local_repository")
2627

27-
2828
new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.bzl", "new_local_repository")
2929

30+
torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")
31+
32+
torch_nccl_detect(name = "torch_nccl")
33+
3034
# CUDA should be installed on the system locally
3135
new_local_repository(
3236
name = "cuda",

docker/MODULE.bazel.ngc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.14.0")
88
bazel_dep(name = "platforms", version = "0.0.10")
99
bazel_dep(name = "rules_cc", version = "0.0.9")
1010
bazel_dep(name = "rules_python", version = "0.34.0")
11+
bazel_dep(name = "bazel_skylib", version = "1.7.1")
1112

1213
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
1314
python.toolchain(
@@ -24,9 +25,12 @@ git_override(
2425

2526
local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.bzl", "local_repository")
2627

27-
2828
new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.bzl", "new_local_repository")
2929

30+
torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")
31+
32+
torch_nccl_detect(name = "torch_nccl")
33+
3034

3135
# External dependency for torch_tensorrt if you already have precompiled binaries.
3236
new_local_repository(

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ def _register_with_torch() -> None:
9999
from torch_tensorrt.dynamo import backend # noqa: F401
100100
from torch_tensorrt import dynamo # noqa: F401
101101

102-
from torch_tensorrt import distributed # noqa: F401
103102
from torch_tensorrt._compile import * # noqa: F403
104103
from torch_tensorrt.distributed._distributed import ( # noqa: F401
105104
distributed_group,
106105
set_distributed_group,
107106
)
107+
from torch_tensorrt import distributed # noqa: F401
108108
from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
109109
MutableTorchTensorRTModule,
110110
)

setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ def run(self):
472472

473473
dynamo_packages = [
474474
"torch_tensorrt",
475+
"torch_tensorrt.distributed",
476+
"torch_tensorrt.distributed.run",
475477
"torch_tensorrt.dynamo",
476478
"torch_tensorrt.dynamo.backend",
477479
"torch_tensorrt.dynamo.conversion",
@@ -506,6 +508,8 @@ def run(self):
506508

507509
dynamo_package_dir = {
508510
"torch_tensorrt": "py/torch_tensorrt",
511+
"torch_tensorrt.distributed": "py/torch_tensorrt/distributed",
512+
"torch_tensorrt.distributed.run": "py/torch_tensorrt/distributed/run",
509513
"torch_tensorrt.dynamo": "py/torch_tensorrt/dynamo",
510514
"torch_tensorrt.dynamo.backend": "py/torch_tensorrt/dynamo/backend",
511515
"torch_tensorrt.dynamo.conversion": "py/torch_tensorrt/dynamo/conversion",

tests/py/dynamo/backend/test_backend_compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def forward(self, x, y):
7878
unexpected_ops = {torch.ops.aten.add.Tensor}
7979

8080
inputs = [
81-
torch.randint(-40, 40, (16, 7, 5), dtype=torch.int).cuda(),
82-
torch.randint(1, 40, (16, 7, 5), dtype=torch.int).cuda(),
81+
torch.randn(16, 7, 5, dtype=torch.float).cuda(),
82+
torch.randn(16, 7, 5, dtype=torch.float).cuda(),
8383
]
8484

8585
(

toolchains/ci_workspaces/MODULE.bazel.tmpl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bazel_dep(name = "googletest", version = "1.16.0")
88
bazel_dep(name = "platforms", version = "0.0.11")
99
bazel_dep(name = "rules_cc", version = "0.1.1")
1010
bazel_dep(name = "rules_python", version = "1.3.0")
11+
bazel_dep(name = "bazel_skylib", version = "1.7.1")
1112

1213
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
1314
python.toolchain(
@@ -24,6 +25,10 @@ git_override(
2425

2526
local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.bzl", "local_repository")
2627

28+
torch_nccl_detect = use_repo_rule("//toolchains/torch_nccl:defs.bzl", "torch_nccl_detect")
29+
30+
torch_nccl_detect(name = "torch_nccl")
31+
2732
# External dependency for torch_tensorrt if you already have precompiled binaries.
2833
local_repository(
2934
name = "torch_tensorrt",

uv.lock

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)