diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index ef73fa9f..997da4a8 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -29,8 +29,8 @@ numba_install_info = ('We need numba. Please install numba by pip . \n' '> pip install numba') cupy_install_info = ('We need cupy. Please install cupy by pip . \n' - 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' - 'For CUDA v12.x > pip install cupy-cuda12x\n') + 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' + 'For CUDA v12.x > pip install cupy-cuda12x\n') os.environ["TI_LOG_LEVEL"] = "error" @@ -105,9 +105,11 @@ def import_cupy(error_if_not_found=True): return None return cupy + def raise_cupy_not_found(): raise ModuleNotFoundError(cupy_install_info) + def is_brainpylib_gpu_installed(): return False if brainpylib_gpu_ops is None else True diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index e13382cb..ca0624a4 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -127,7 +127,7 @@ def __init__( gpu_checked = False if gpu_kernel is None: gpu_checked = True - elif isinstance(gpu_kernel, str): # cupy + elif isinstance(gpu_kernel, str): # cupy register_cupy_gpu_translation_rule(self.primitive, gpu_kernel) gpu_checked = True elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py index 8073c428..b7656c43 100644 --- a/brainpy/_src/math/op_register/tests/test_cupy_based.py +++ b/brainpy/_src/math/op_register/tests/test_cupy_based.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp - import pytest + import brainpy.math as bm from brainpy._src.dependency_check import import_cupy @@ -43,7 +43,7 @@ def test_cupy_based(): # n = jnp.asarray([N**2,], dtype=jnp.int32) - y = prim(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0] + y = prim(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0] print(y) assert jnp.allclose(y, x1 + x2) @@ -57,5 +57,4 @@ def test_cupy_based(): # ker_times((N,), (N,), (x1, x2, y, N**2)) # y = x1 * x2 # assert cp.allclose(y, x1 * x2) - # test_cupy_based()