Skip to content

Commit

Permalink
explicit path of third_party nvfuser
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Apr 19, 2023
1 parent ecca2f7 commit 7642c1c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions csrc/instance_norm_nvfuser_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aten/src/ATen/native/utils/ParamsHash.h>

#if NVFUSER_THIRDPARTY
#include <fusion.h>
#include <kernel_cache.h>
#include <ops/all_ops.h>
#else
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,15 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
os.path.join(d) for d in os.listdir(os.path.join(PYTORCH_HOME, "third_party"))
if os.path.isdir(os.path.join(os.path.join(PYTORCH_HOME, "third_party"), d))
)
import nvfuser # NOQA
print(PYTORCH_HOME)
include_dirs = [PYTORCH_HOME]
library_dirs = []
extra_link_args = []
if nvfuser_is_refactored:
include_dirs.append(os.path.join(PYTORCH_HOME, "third_party/nvfuser/csrc"))
library_dirs = nvfuser.__path__
extra_link_args.append("-lnvfuser")
ext_modules.append(
CUDAExtension(
name='instance_norm_nvfuser_cuda',
Expand All @@ -377,6 +382,8 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
'csrc/instance_norm_nvfuser_kernel.cu',
],
include_dirs=include_dirs,
library_dirs=library_dirs,
extra_link_args=extra_link_args,
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
"nvcc": ["-O3"] + version_dependent_macros + [f"-DNVFUSER_THIRDPARTY={int(nvfuser_is_refactored)}"],
Expand Down

0 comments on commit 7642c1c

Please sign in to comment.