Skip to content

Commit

Permalink
try to remove hard dependency with taichi and numba
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 17, 2024
1 parent df04a81 commit 55c69bb
Show file tree
Hide file tree
Showing 25 changed files with 2,318 additions and 5,098 deletions.
17 changes: 11 additions & 6 deletions brainpy/_src/math/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@
# '''Default integer data type.'''
int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32

# '''Default integer data type in Taichi.'''
ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32

# '''Default float data type.'''
float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32

# '''Default float data type in Taichi.'''
ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32

# '''Default complex data type.'''
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64


if ti is not None:
# '''Default integer data type in Taichi.'''
ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32

# '''Default float data type in Taichi.'''
ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32

else:
ti_int = None
ti_float = None
21 changes: 14 additions & 7 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,16 @@ def set_float(dtype: type):
"""
if dtype in [jnp.float16, 'float16', 'f16']:
defaults.__dict__['float_'] = jnp.float16
defaults.__dict__['ti_float'] = ti.float16
if ti is not None:
defaults.__dict__['ti_float'] = ti.float16
elif dtype in [jnp.float32, 'float32', 'f32']:
defaults.__dict__['float_'] = jnp.float32
defaults.__dict__['ti_float'] = ti.float32
if ti is not None:
defaults.__dict__['ti_float'] = ti.float32
elif dtype in [jnp.float64, 'float64', 'f64']:
defaults.__dict__['float_'] = jnp.float64
defaults.__dict__['ti_float'] = ti.float64
if ti is not None:
defaults.__dict__['ti_float'] = ti.float64
else:
raise NotImplementedError

Expand All @@ -448,16 +451,20 @@ def set_int(dtype: type):
"""
if dtype in [jnp.int8, 'int8', 'i8']:
defaults.__dict__['int_'] = jnp.int8
defaults.__dict__['ti_int'] = ti.int8
if ti is not None:
defaults.__dict__['ti_int'] = ti.int8
elif dtype in [jnp.int16, 'int16', 'i16']:
defaults.__dict__['int_'] = jnp.int16
defaults.__dict__['ti_int'] = ti.int16
if ti is not None:
defaults.__dict__['ti_int'] = ti.int16
elif dtype in [jnp.int32, 'int32', 'i32']:
defaults.__dict__['int_'] = jnp.int32
defaults.__dict__['ti_int'] = ti.int32
if ti is not None:
defaults.__dict__['ti_int'] = ti.int32
elif dtype in [jnp.int64, 'int64', 'i64']:
defaults.__dict__['int_'] = jnp.int64
defaults.__dict__['ti_int'] = ti.int64
if ti is not None:
defaults.__dict__['ti_int'] = ti.int64
else:
raise NotImplementedError

Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/math/event/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from ._info_collection import *
from ._csr_matvec import *

Loading

0 comments on commit 55c69bb

Please sign in to comment.