Fix bugs
Routhleck committed Nov 23, 2024
1 parent bbb922c commit ec39076
Showing 6 changed files with 12 additions and 42 deletions.
9 changes: 5 additions & 4 deletions brainpy/_src/dnn/tests/test_linear.py
@@ -1,6 +1,7 @@
import pytest
from absl.testing import absltest
from absl.testing import parameterized
+import jax.numpy as jnp

import brainpy as bp
import brainpy.math as bm
@@ -100,11 +101,11 @@ def test_CSRLinear(self, conn):
bm.random.seed()
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
-y = f(x)
+y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))

x = bm.random.random((100,))
-y = f(x)
+y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()

@@ -119,10 +120,10 @@ def test_EventCSRLinear(self, conn):
bm.random.seed()
f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
-y = f(x)
+y = f(jnp.asarray(x))
self.assertTrue(y.shape == (16, 100))
x = bm.random.random((100,))
-y = f(x)
+y = f(jnp.asarray(x))
self.assertTrue(y.shape == (100,))
bm.clear_buffer_memory()

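The test change above converts the BrainPy random Array to a plain JAX array with jnp.asarray before it reaches the sparse layer. A minimal sketch of the updated call pattern, assuming a FixedProb connector for illustration (the real test parameterizes over several connectors):

import jax.numpy as jnp
import brainpy as bp
import brainpy.math as bm

bm.random.seed()
conn = bp.conn.FixedProb(0.1)(pre_size=100, post_size=100)  # assumed connector setup
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())

x = bm.random.random((16, 100))   # BrainPy Array
y = f(jnp.asarray(x))             # hand the layer a plain jax.numpy array
assert y.shape == (16, 100)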
2 changes: 1 addition & 1 deletion brainpy/_src/math/__init__.py
@@ -44,7 +44,7 @@
from .compat_numpy import *
from .compat_tensorflow import *
from .others import *
-from . import random, linalg, fft, tifunc
+from . import random, linalg, fft

# operators
from .op_register import *
14 changes: 6 additions & 8 deletions brainpy/_src/math/op_register/base.py
@@ -9,20 +9,16 @@
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject

+is_version_right = False
if jax.__version__ >= '0.4.16':
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
-from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
+from braintaichi._primitive._mlir_translation_rule import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import (
register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
-else:
-from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
-from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
-register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
-from .cupy_based import (
-register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
-register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
+is_version_right = True

from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

@@ -73,6 +69,8 @@ def __init__(
outs: Optional[Callable] = None,
name: str = None,
):
+if not is_version_right:
+raise ImportError('XLA Custom Op is only supported in JAX>=0.4.16')
super().__init__(name)

# set cpu_kernel and gpu_kernel
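The guard above replaces the old pre-0.4.16 XLA fallback branch: the braintaichi-based MLIR registration is imported only when the installed JAX is new enough, and XLACustomOp refuses to construct otherwise. A standalone sketch of that pattern, with simplified names rather than the actual brainpy module:

import jax

_MIN_JAX = '0.4.16'
is_version_right = False
if jax.__version__ >= _MIN_JAX:   # same string comparison as the diff above
    # the MLIR translation-rule imports from braintaichi would go here
    is_version_right = True

class CustomOpSketch:
    """Illustrative stand-in for brainpy's XLACustomOp."""
    def __init__(self, name=None):
        if not is_version_right:
            raise ImportError(f'XLA Custom Op is only supported in JAX>={_MIN_JAX}')
        self.name = name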
2 changes: 0 additions & 2 deletions brainpy/math/__init__.py
@@ -33,8 +33,6 @@
from . import linalg
from . import random

-# taichi operations
-from . import tifunc

# others
from . import sharding
2 changes: 0 additions & 2 deletions brainpy/math/op_register.py
@@ -2,8 +2,6 @@
from brainpy._src.math.op_register import (
CustomOpByNumba,
compile_cpu_signature_with_numba,
-clear_taichi_aot_caches,
-count_taichi_aot_kernels,
)

from brainpy._src.math.op_register.base import XLACustomOp
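Together with the tifunc.py deletion below, the taichi helpers disappear from the public brainpy.math surface; only the Numba and XLA custom-op entry points remain re-exported from this module. A quick import check, assuming a brainpy build that includes this commit:

from brainpy.math.op_register import CustomOpByNumba, XLACustomOp  # still exported
# clear_taichi_aot_caches / count_taichi_aot_kernels are no longer re-exported here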
25 changes: 0 additions & 25 deletions brainpy/math/tifunc.py

This file was deleted.
