forked from LMCache/LMCache
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
169 lines (145 loc) · 5.13 KB
/
setup.py
File metadata and controls
169 lines (145 loc) · 5.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# SPDX-License-Identifier: Apache-2.0
# Standard
from pathlib import Path
import os
import sys
# Third Party
from setuptools import find_packages, setup
# Directory that contains this setup.py; the hipify paths below hang off it.
ROOT_DIR = Path(__file__).parent
# Directory passed to hipify as the project directory, and the directory
# passed to it as the output directory.
HIPIFY_DIR = os.path.join(ROOT_DIR, "csrc/")
HIPIFY_OUT_DIR = os.path.join(ROOT_DIR, "csrc_hip/")

# `python -m build --sdist` ends up running
# `python setup.py sdist --dist-dir dist`, so an "sdist" argument marks a
# source-distribution build. Setting NO_CUDA_EXT=1 turns the same flag on.
BUILDING_SDIST = os.getenv("NO_CUDA_EXT", "0") == "1" or "sdist" in sys.argv

# BUILD_WITH_HIP=1 makes the script pick the ROCm extension over the CUDA one.
BUILD_WITH_HIP = os.getenv("BUILD_WITH_HIP", "0") == "1"
# Controls the value of -D_GLIBCXX_USE_CXX11_ABI for the CUDA build; on unless
# ENABLE_CXX11_ABI is set to something other than "1".
ENABLE_CXX11_ABI = os.getenv("ENABLE_CXX11_ABI", "1") == "1"


def hipify_wrapper() -> None:
    """Run torch's ``hipify`` over the files that sit directly in ``HIPIFY_DIR``.

    Every regular file found at the top level of ``HIPIFY_DIR`` is passed to
    ``hipify`` as an extra file, with ``HIPIFY_OUT_DIR`` given as the output
    directory. The per-file result is then resolved to either the hipified
    path reported by ``hipify`` or, when none is reported, the original path.
    The resolved list is only used for a length sanity check; nothing is
    returned.
    """
    # Third Party
    from torch.utils.hipify.hipify_python import hipify

    print("Hipifying sources ")

    # Absolute paths of the regular files in HIPIFY_DIR (sub-directories are
    # skipped).
    source_paths = []
    for entry in os.listdir(HIPIFY_DIR):
        candidate = os.path.join(HIPIFY_DIR, entry)
        if os.path.isfile(candidate):
            source_paths.append(os.path.abspath(candidate))

    results = hipify(
        project_directory=HIPIFY_DIR,
        output_directory=HIPIFY_OUT_DIR,
        header_include_dirs=[],
        includes=[],
        extra_files=source_paths,
        show_detailed=True,
        is_pytorch_extension=True,
        hipify_extra_files_only=True,
    )

    # Map each input file to its hipified counterpart, falling back to the
    # input path when hipify has no entry or no output path for it.
    resolved_paths = []
    for path in source_paths:
        key = os.path.abspath(path)
        target = key
        if key in results:
            hipified = results[key].hipified_path
            if hipified is not None:
                target = hipified
        resolved_paths.append(target)

    assert len(resolved_paths) == len(source_paths)


def cuda_extension() -> tuple[list, dict]:
    """Describe the CUDA build of the ``lmcache.c_ops`` extension.

    The ``_GLIBCXX_USE_CXX11_ABI`` define passed to both the host compiler
    and nvcc follows the module-level ``ENABLE_CXX11_ABI`` flag.

    Returns:
        A ``(ext_modules, cmdclass)`` pair: a one-element list holding the
        ``CUDAExtension`` and a mapping that routes ``build_ext`` to torch's
        ``BuildExtension``.
    """
    # Third Party
    from torch.utils import cpp_extension  # Import here

    print("Building CUDA extensions")

    abi_define = (
        "-D_GLIBCXX_USE_CXX11_ABI=1"
        if ENABLE_CXX11_ABI
        else "-D_GLIBCXX_USE_CXX11_ABI=0"
    )

    sources = [
        "csrc/pybind.cpp",
        "csrc/mem_kernels.cu",
        "csrc/cal_cdf.cu",
        "csrc/ac_enc.cu",
        "csrc/ac_dec.cu",
        "csrc/pos_kernels.cu",
        "csrc/mem_alloc.cpp",
        "csrc/utils.cpp",
    ]

    # Each compiler gets its own list object carrying the same ABI define.
    compile_args = {tool: [abi_define] for tool in ("cxx", "nvcc")}

    extension = cpp_extension.CUDAExtension(
        "lmcache.c_ops",
        sources=sources,
        extra_compile_args=compile_args,
    )
    return [extension], {"build_ext": cpp_extension.BuildExtension}


def rocm_extension() -> tuple[list, dict]:
    """Describe the ROCm/HIP build of the ``lmcache.c_ops`` extension.

    Runs ``hipify_wrapper()`` first, then declares a ``CppExtension`` over
    the HIP source list. ROCm include and library directories are taken from
    the ``ROCM_PATH`` environment variable, defaulting to ``/opt/rocm``.

    Returns:
        A ``(ext_modules, cmdclass)`` pair: a one-element list holding the
        ``CppExtension`` and a mapping that routes ``build_ext`` to torch's
        ``BuildExtension``.
    """
    # Third Party
    from torch.utils import cpp_extension  # Import here

    print("Building ROCM extensions")
    hipify_wrapper()

    rocm_root = os.environ.get("ROCM_PATH", "/opt/rocm")

    sources = [
        "csrc/pybind_hip.cpp",  # Use the hipified pybind
        "csrc/mem_kernels.hip",
        "csrc/cal_cdf.hip",
        "csrc/ac_enc.hip",
        "csrc/ac_dec.hip",
        "csrc/pos_kernels.hip",
    ]

    # For HIP, CppExtension is used and hipcc is expected to do the work, so
    # the CXX environment variable should point at hipcc for this build,
    # e.g. `CXX=hipcc python setup.py install`.
    #
    # Only a "cxx" entry is supplied (no "nvcc" key). HIP-specific flags can
    # be appended here if needed, for example '--offload-arch=gfx942'
    # (replace with your target arch), '-x hip' to force files to be treated
    # as HIP, or '-D_GLIBCXX_USE_CXX11_ABI=0'.
    compile_args = {"cxx": ["-O3"]}

    macros = [("__HIP_PLATFORM_HCC__", "1"), ("USE_ROCM", "1")]

    extension = cpp_extension.CppExtension(
        "lmcache.c_ops",
        sources=sources,
        extra_compile_args=compile_args,
        # ROCm headers and libraries may not be found automatically, so the
        # locations are passed explicitly.
        include_dirs=[os.path.join(rocm_root, "include")],
        library_dirs=[os.path.join(rocm_root, "lib")],
        # libraries=['amdhip64'] or other HIP libs could be added if needed.
        define_macros=macros,
    )
    return [extension], {"build_ext": cpp_extension.BuildExtension}


def source_dist_extension() -> tuple[list, dict]:
    """Return an empty extension description for source-distribution builds.

    Returns:
        ``([], {})`` -- no extension modules and no custom commands.
    """
    print("Not building CUDA/HIP extensions for sdist")
    no_ext_modules: list = []
    no_cmdclass: dict = {}
    return no_ext_modules, no_cmdclass


if __name__ == "__main__":
    # Choose how the native extension is described: none at all for an
    # sdist, otherwise the ROCm variant when BUILD_WITH_HIP is set and the
    # CUDA variant by default.
    if not BUILDING_SDIST:
        get_extension = rocm_extension if BUILD_WITH_HIP else cuda_extension
    else:
        get_extension = source_dist_extension

    ext_modules, cmdclass = get_extension()

    # Ensure csrc is excluded if it only contains sources.
    python_packages = find_packages(exclude=("csrc",))

    setup(
        packages=python_packages,
        ext_modules=ext_modules,
        cmdclass=cmdclass,
        include_package_data=True,
    )