Skip to content

Commit

Permalink
[math] compatible brainpy.math.trapz
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Nov 5, 2023
1 parent d75af21 commit 4c4b9df
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 16 deletions.
7 changes: 5 additions & 2 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round',
'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod',
'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum',
'cumprod', 'cumsum', 'ediff1d', 'cross', 'trapz', 'isfinite', 'isinf',
'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf',
'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve',
'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside',
'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle',
Expand Down Expand Up @@ -381,7 +381,10 @@ def msort(a):
nansum = _compatible_with_brainpy_array(jnp.nansum)
ediff1d = _compatible_with_brainpy_array(jnp.ediff1d)
cross = _compatible_with_brainpy_array(jnp.cross)
trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid)
if jax.__version__ >= '0.4.18':
trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid)
else:
trapz = _compatible_with_brainpy_array(jnp.trapz)
isfinite = _compatible_with_brainpy_array(jnp.isfinite)
isinf = _compatible_with_brainpy_array(jnp.isinf)
isnan = _compatible_with_brainpy_array(jnp.isnan)
Expand Down
13 changes: 0 additions & 13 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import inspect
import os
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional

import jax
import numpy as np
import taichi as ti
from jax.interpreters import xla, batching, ad, mlir
from numba.core.dispatcher import Dispatcher

Expand Down Expand Up @@ -230,13 +227,3 @@ def _transform_to_array(a):
def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)


def _set_taichi_envir():
# find the path of taichi in python site_packages
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
taichi_lib_dir = os.path.join(taichi_path, '_lib', 'runtime')
os.environ.update({
'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir,
'TI_LIB_DIR': taichi_lib_dir
})
2 changes: 1 addition & 1 deletion brainpy/math/compat_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from brainpy._src.math.compat_numpy import (
trapz as trapz,
fill_diagonal as fill_diagonal,
empty as empty,
empty_like as empty_like,
Expand Down Expand Up @@ -95,7 +96,6 @@
cumsum as cumsum,
ediff1d as ediff1d,
cross as cross,
trapz as trapz,
isfinite as isfinite,
isinf as isinf,
isnan as isnan,
Expand Down

0 comments on commit 4c4b9df

Please sign in to comment.