forked from pytorch/rl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
85 lines (66 loc) · 2.15 KB
/
setup.py
File metadata and controls
85 lines (66 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import distutils.command.clean
import os
import shutil
import subprocess
from pathlib import Path
from build_tools import setup_helpers
from setuptools import setup, find_packages
def _get_pytorch_version():
if "PYTORCH_VERSION" in os.environ:
return f"torch=={os.environ['PYTORCH_VERSION']}"
return "torch"
def _get_packages():
exclude = [
"build*",
"test*",
"torchrl.csrc*",
"third_party*",
"tools*",
]
return find_packages(exclude=exclude)
ROOT_DIR = Path(__file__).parent.resolve()
class clean(distutils.command.clean.clean):
def run(self):
# Run default behavior first
distutils.command.clean.clean.run(self)
# Remove torchrl extension
for path in (ROOT_DIR / "torchrl").glob("**/*.so"):
print(f"removing '{path}'")
path.unlink()
# Remove build directory
build_dirs = [
ROOT_DIR / "build",
]
for path in build_dirs:
if path.exists():
print(f"removing '{path}' (and everything under it)")
shutil.rmtree(str(path), ignore_errors=True)
def _run_cmd(cmd):
try:
return subprocess.check_output(cmd, cwd=ROOT_DIR).decode("ascii").strip()
except Exception:
return None
def _main():
pytorch_package_dep = _get_pytorch_version()
print("-- PyTorch dependency:", pytorch_package_dep)
# branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
# tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
setup(
name="torchrl",
version="0.1",
author="torchrl contributors",
author_email="vmoens@fb.com",
packages=_get_packages(),
ext_modules=setup_helpers.get_ext_modules(),
cmdclass={
"build_ext": setup_helpers.CMakeBuild,
"clean": clean,
},
install_requires=[pytorch_package_dep, "numpy"],
)
if __name__ == "__main__":
_main()