Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""PyTorch related extensions."""

import importlib.util
import os
from pathlib import Path
from importlib import metadata
Expand Down Expand Up @@ -58,6 +59,25 @@ def setup_pytorch_extension(
]
)

# apache-tvm-ffi: headers for the C++ API (Module / Function / TensorView)
# and libtvm_ffi.so for symbol resolution. Used by tvm_ffi_bridge.h /
# applyTVMFunction. Python registers AOT-compiled CuTeDSL kernels into
# the global registry; TE C++ looks them up via Function::GetGlobalRequired.
tvm_ffi_spec = importlib.util.find_spec("tvm_ffi")
if tvm_ffi_spec is None or not tvm_ffi_spec.submodule_search_locations:
raise RuntimeError(
"apache-tvm-ffi package not found; install it (e.g. "
"`pip install apache-tvm-ffi`) — required for the TVM FFI bridge."
)
tvm_ffi_root = Path(tvm_ffi_spec.submodule_search_locations[0])
tvm_ffi_include = tvm_ffi_root / "include"
tvm_ffi_lib_dir = tvm_ffi_root / "lib"
if not tvm_ffi_include.is_dir() or not (tvm_ffi_lib_dir / "libtvm_ffi.so").exists():
raise RuntimeError(
f"apache-tvm-ffi assets missing at {tvm_ffi_root} (need include/ and lib/libtvm_ffi.so)"
)
include_dirs.append(tvm_ffi_include)

# Compiler flags
cxx_flags = ["-O3", "-fvisibility=hidden"]
if debug_build_enabled():
Expand All @@ -77,8 +97,11 @@ def setup_pytorch_extension(

setup_mpi_flags(include_dirs, cxx_flags)

library_dirs = []
libraries = []
library_dirs = [tvm_ffi_lib_dir]
libraries = ["tvm_ffi"]
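# Embed the absolute path of the tvm_ffi lib directory located at build time
# as an rpath, so the loader finds libtvm_ffi.so without LD_LIBRARY_PATH at
# runtime.
# NOTE(review): this path is resolved from whatever environment runs the
# build; if that is an isolated/temporary build environment the directory may
# not exist at runtime — confirm, or consider an $ORIGIN-relative rpath.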
extra_link_args = [f"-Wl,-rpath,{tvm_ffi_lib_dir}"]
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
Expand All @@ -102,6 +125,7 @@ def setup_pytorch_extension(
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={"cxx": cxx_flags},
extra_link_args=extra_link_args,
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# apache-tvm-ffi is required at configure/compile/link time: the common C++
# library finds it via find_package(tvm_ffi) and links libtvm_ffi.so (the
# CuTeDSL quant backend bridge). It is also a runtime dependency (see setup.py).
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1", "apache-tvm-ffi>=0.1.12"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
11 changes: 10 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
"importlib-metadata>=1.0",
"packaging",
cusolvermp_pypi_package_name(),
# The core C++ library links libtvm_ffi.so (CuTeDSL quant backend bridge),
# so apache-tvm-ffi is required at runtime by every TE install.
"apache-tvm-ffi>=0.1.12",
]
test_reqs: List[str] = ["pytest>=8.2.1"]

Expand Down Expand Up @@ -363,13 +366,19 @@ def git_check_submodules() -> None:
"core_cu13": [f"transformer_engine_cu13=={__version__}"],
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"cutedsl": [
"nvidia-cutlass-dsl>=4.2.0"
], # TODO: document the "cutedsl" extra when shipping, e.g. `pip3 install --no-build-isolation '.[cutedsl]'`.
}
else:
install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
package_data = {"": ["VERSION.txt"]}
include_package_data = True
extras_require = {"test": test_requires}
extras_require = {
"test": test_requires,
"cutedsl": ["nvidia-cutlass-dsl>=4.2.0"],
}

if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
Expand Down
Loading
Loading