-
Notifications
You must be signed in to change notification settings - Fork 281
/
Copy pathsetup.py
87 lines (72 loc) · 3.08 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
import re
import setuptools
# Absolute path of the directory that contains this setup script.
this_dir = os.path.split(os.path.abspath(__file__))[0]
def fetch_requirements(requirements_path: str = "requirements.txt") -> list:
    """Read the install requirements from a pip requirements file.

    Args:
        requirements_path: Path of the requirements file. Defaults to
            ``requirements.txt``; a relative path is resolved against the
            current working directory.

    Returns:
        One requirement specifier per entry, in file order, with surrounding
        whitespace removed. Blank lines and ``#`` comment lines are skipped,
        so an empty file yields ``[]`` rather than ``[""]``.

    Raises:
        OSError: If the requirements file cannot be opened.
    """
    with open(requirements_path) as f:
        stripped_lines = (line.strip() for line in f)
        # Empty strings and comments are not valid requirement specifiers,
        # so they must not be handed to ``install_requires``.
        return [line for line in stripped_lines if line and not line.startswith("#")]
# https://packaging.python.org/guides/single-sourcing-package-version/
def find_version(version_file_path) -> str:
    """Return the dotted version string declared in a version file.

    The file is searched for a line that starts with
    ``__version_tuple__ = `` and the rest of that line is parsed as a Python
    literal (for example ``(0, 4, 13)``).

    Args:
        version_file_path: Path of the file that defines ``__version_tuple__``.

    Returns:
        The tuple's items converted with ``str`` and joined with ``"."``.

    Raises:
        RuntimeError: If no ``__version_tuple__`` assignment is found.
        ValueError: If the assigned value is not a plain Python literal.
        SyntaxError: If the assigned value is not valid Python syntax.
        OSError: If the file cannot be opened.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    with open(version_file_path) as version_file:
        version_match = re.search(r"^__version_tuple__ = (.*)", version_file.read(), re.M)

    if version_match is None:
        raise RuntimeError("Unable to find version tuple.")

    # literal_eval only accepts literals, so text in the version file can
    # never run arbitrary code the way eval() would.
    ver_tup = ast.literal_eval(version_match.group(1).strip())
    return ".".join(str(x) for x in ver_tup)
# Build settings consumed by the setup() call; these defaults describe a
# build without any compiled extension.
extensions = []
cmdclass = {}
setup_requires = []

# The CUDA extension is opt-in: it is configured only when the environment
# variable BUILD_CUDA_EXTENSIONS is set to exactly "1".
if os.environ.get("BUILD_CUDA_EXTENSIONS", "0") == "1":
    # torch is imported inside this branch, so it is needed only for CUDA builds.
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup_requires = ["ninja"]
    extensions.append(
        CUDAExtension(
            name="fairscale.fused_adam_cuda",
            include_dirs=[os.path.join(this_dir, "fairscale/clib/fused_adam_cuda")],
            sources=[
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp",
                "fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"]},
        )
    )
    cmdclass["build_ext"] = BuildExtension
if __name__ == "__main__":
    # Static descriptive metadata, kept apart from the values that are
    # computed or configured at build time.
    metadata = {
        "name": "fairscale",
        "description": "FairScale: A PyTorch library for large-scale and high-performance training.",
        "python_requires": ">=3.8",
        "author": "Foundational AI Research @ Meta AI",
        # NOTE(review): this value looks like an email-obfuscation placeholder
        # rather than a real address — confirm against the repository.
        "author_email": "[email protected]",
        "long_description": (
            "FairScale is a PyTorch extension library for high performance and "
            "large scale training on one or multiple machines/nodes. This library "
            "extends basic PyTorch capabilities while adding new experimental ones."
        ),
        "long_description_content_type": "text/markdown",
        "entry_points": {"console_scripts": ["wgit = fairscale.experimental.wgit.__main__:main"]},
        "classifiers": [
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "License :: OSI Approved :: BSD License",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
            "Operating System :: OS Independent",
        ],
    }

    setuptools.setup(
        version=find_version("fairscale/version.py"),
        setup_requires=setup_requires,
        install_requires=fetch_requirements(),
        include_package_data=True,
        # Only include code within fairscale.
        packages=setuptools.find_packages(include=["fairscale*"]),
        ext_modules=extensions,
        cmdclass=cmdclass,
        **metadata,
    )