Skip to content

Commit

Permalink
Merge pull request #535 from chaoming0625/master
Browse files Browse the repository at this point in the history
[brainpy.share] add category shared info
  • Loading branch information
chaoming0625 authored Nov 5, 2023
2 parents c014976 + a97771a commit 1d35e2e
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 245 deletions.
15 changes: 15 additions & 0 deletions brainpy/_src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self):
# -------------

self._arguments = DotDict()
self._category = dict()

@property
def dt(self):
Expand Down Expand Up @@ -95,5 +96,19 @@ def clear(self) -> None:
"""Clear all shared data in this computation context."""
self._arguments.clear()

def save_category(self, category, **kwargs):
if category not in self._category:
self._category[category] = dict()
self._category[category].update(**kwargs)

def clear_category(self, category=None):
if category is None:
self._category.clear()
else:
self._category.pop(category)

def get_category(self, category):
return self._category[category]


share = _ShareContext()
7 changes: 5 additions & 2 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round',
'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod',
'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum',
'cumprod', 'cumsum', 'ediff1d', 'cross', 'trapz', 'isfinite', 'isinf',
'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf',
'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve',
'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside',
'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle',
Expand Down Expand Up @@ -381,7 +381,10 @@ def msort(a):
nansum = _compatible_with_brainpy_array(jnp.nansum)
ediff1d = _compatible_with_brainpy_array(jnp.ediff1d)
cross = _compatible_with_brainpy_array(jnp.cross)
trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid)
if jax.__version__ >= '0.4.18':
trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid)
else:
trapz = _compatible_with_brainpy_array(jnp.trapz)
isfinite = _compatible_with_brainpy_array(jnp.isfinite)
isinf = _compatible_with_brainpy_array(jnp.isinf)
isnan = _compatible_with_brainpy_array(jnp.isnan)
Expand Down
12 changes: 0 additions & 12 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import inspect
import os
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional

import jax
import numpy as np
import taichi as ti
from jax.interpreters import xla, batching, ad, mlir
from numba.core.dispatcher import Dispatcher

Expand Down Expand Up @@ -231,12 +228,3 @@ def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)


def _set_taichi_envir():
# find the path of taichi in python site_packages
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
taichi_lib_dir = os.path.join(taichi_path, '_lib', 'runtime')
os.environ.update({
'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir,
'TI_LIB_DIR': taichi_lib_dir
})
2 changes: 1 addition & 1 deletion brainpy/math/compat_numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from brainpy._src.math.compat_numpy import (
trapz as trapz,
fill_diagonal as fill_diagonal,
empty as empty,
empty_like as empty_like,
Expand Down Expand Up @@ -95,7 +96,6 @@
cumsum as cumsum,
ediff1d as ediff1d,
cross as cross,
trapz as trapz,
isfinite as isfinite,
isinf as isinf,
isnan as isnan,
Expand Down
Loading

0 comments on commit 1d35e2e

Please sign in to comment.