forked from rusty1s/pytorch_spline_conv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build.py
40 lines (33 loc) · 1.07 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os.path as osp
import subprocess
import torch
from torch.utils.ffi import create_extension
# Directory containing this build script. The CUDA build step runs from here,
# and the object it produces is referenced by absolute path, so the build no
# longer depends on the caller's current working directory.
_root = osp.dirname(osp.abspath(__file__))

# CPU (TH) backend: one header/source pair per entry in ``files``.
files = ['Basis', 'Weighting']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
sources = ['aten/TH/TH{}.c'.format(f) for f in files]
include_dirs = ['aten/TH']
define_macros = []
extra_objects = []
extra_compile_args = ['-std=c99']
with_cuda = False

if torch.cuda.is_available():
    # Build the CUDA kernels (THC.so) first; the torch package directory is
    # passed to build.sh as its argument. check_call raises
    # CalledProcessError on a non-zero exit, so a failed kernel build stops
    # here instead of surfacing later as a confusing missing-object link error.
    subprocess.check_call(['./build.sh', osp.dirname(torch.__file__)],
                          cwd=_root)
    headers += ['aten/THCC/THCC{}.h'.format(f) for f in files]
    sources += ['aten/THCC/THCC{}.c'.format(f) for f in files]
    include_dirs += ['aten/THCC']
    # Presumably guards the CUDA-specific code paths in the C sources.
    define_macros += [('WITH_CUDA', None)]
    # Absolute path: the extension may be linked from a temporary build
    # directory, where a relative object path would not resolve.
    extra_objects += [osp.join(_root, 'torch_spline_conv/_ext/THC.so')]
    with_cuda = True
# Assemble the cffi-based extension from the configuration gathered above.
# ``relative_to=__file__`` tells torch.utils.ffi where this build script
# lives; it serves as the base for the relative header/source paths
# (confirm against the installed torch version).
ffi = create_extension(
    name='torch_spline_conv._ext.ffi',
    headers=headers,
    sources=sources,
    include_dirs=include_dirs,
    define_macros=define_macros,
    extra_objects=extra_objects,
    extra_compile_args=extra_compile_args,
    with_cuda=with_cuda,
    package=True,
    relative_to=__file__,
)

if __name__ == '__main__':
    # ffi.build() runs only when this file is executed as a script; importing
    # the module just exposes ``ffi`` for a caller to build.
    ffi.build()