Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] remove the hard requirement of taichi #531

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,49 @@
import platform
import ctypes

import taichi as ti
from jax.lib import xla_client

taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Does not found {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Does not found {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None


def import_taichi():
if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if ti.__version__ < (1, 7, 0):
raise RuntimeError(
'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
return ti


if ti is None:
is_taichi_installed = False
else:
is_taichi_installed = True
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

# Register the CPU XLA custom calls
try:
Expand Down
51 changes: 21 additions & 30 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from functools import partial, reduce
from typing import Any

import jax.numpy as jnp
import numpy as np
import taichi as ti
from jax.interpreters import xla
from jax.lib import xla_client

import brainpy.math as bm
from .utils import _shape_to_layout
from ..brainpylib_check import import_taichi


### UTILS ###

Expand Down Expand Up @@ -122,25 +122,25 @@ def check_kernel_exist(source_md5_encode: str) -> bool:

### KERNEL AOT BUILD ###

# jnp dtype to taichi type
type_map4template = {
jnp.dtype("bool"): bool,
jnp.dtype("int8"): ti.int8,
jnp.dtype("int16"): ti.int16,
jnp.dtype("int32"): ti.int32,
jnp.dtype("int64"): ti.int64,
jnp.dtype("uint8"): ti.uint8,
jnp.dtype("uint16"): ti.uint16,
jnp.dtype("uint32"): ti.uint32,
jnp.dtype("uint64"): ti.uint64,
jnp.dtype("float16"): ti.float16,
jnp.dtype("float32"): ti.float32,
jnp.dtype("float64"): ti.float64,
}


def _array_to_field(dtype, shape) -> Any:
return ti.field(dtype=type_map4template[dtype], shape=shape)
ti = import_taichi()
if dtype == np.bool_:
dtype = bool
elif dtype == np.int8: dtype= ti.int8
elif dtype == np.int16: dtype= ti.int16
elif dtype == np.int32: dtype= ti.int32
elif dtype == np.int64: dtype= ti.int64
elif dtype == np.uint8: dtype= ti.uint8
elif dtype == np.uint16: dtype= ti.uint16
elif dtype == np.uint32: dtype= ti.uint32
elif dtype == np.uint64: dtype= ti.uint64
elif dtype == np.float16: dtype= ti.float16
elif dtype == np.float32: dtype= ti.float32
elif dtype == np.float64: dtype= ti.float64
else:
raise TypeError
return ti.field(dtype=dtype, shape=shape)


# build aot kernel
Expand All @@ -151,6 +151,8 @@ def build_kernel(
outs: dict,
device: str
):
ti = import_taichi()

# init arch
arch = None
if device == 'cpu':
Expand Down Expand Up @@ -191,17 +193,6 @@ def build_kernel(
int: 0,
float: 1,
bool: 2,
ti.int32: 0,
ti.float32: 1,
ti.u8: 3,
ti.u16: 4,
ti.u32: 5,
ti.u64: 6,
ti.i8: 7,
ti.i16: 8,
ti.i64: 9,
ti.f16: 10,
ti.f64: 11,
np.dtype('int32'): 0,
np.dtype('float32'): 1,
np.dtype('bool'): 2,
Expand Down
18 changes: 0 additions & 18 deletions brainpy/_src/tools/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@
except (ImportError, ModuleNotFoundError):
brainpylib = None

try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None

__all__ = [
'import_numba',
Expand All @@ -27,20 +23,6 @@
]


def import_taichi():
if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if ti.__version__ < (1, 7, 0):
raise RuntimeError(
'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
return ti


def import_numba():
if numba is None:
raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n'
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ jax
tqdm
msgpack
numba
taichi