Skip to content

Commit

Permalink
Fix surrogate gradient function and numpy 2.0 compatibility (#679)
Browse files Browse the repository at this point in the history
* fix surrogate batching

* fix numpy 2.0 compatible issue

* fix numpy 2.0 compatible issue

* updates

* fix numpy2.0 compatible issue

* Skip the operators tests for GitHub action server

* Update

* Update

* Update

* Update test_taichi_based.py

* Update test_get_weight_matrix.py

---------

Co-authored-by: He Sichao <[email protected]>
  • Loading branch information
chaoming0625 and Routhleck authored Jun 18, 2024
1 parent b9461eb commit 5f33a66
Show file tree
Hide file tree
Showing 19 changed files with 124 additions and 88 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/
export IS_GITHUB_ACTIONS=1 && pytest _src/
test_macos:
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/
export IS_GITHUB_ACTIONS=1 && pytest _src/
test_windows:
Expand Down Expand Up @@ -113,4 +113,4 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/ -p no:faulthandler
set IS_GITHUB_ACTIONS=1 && pytest _src/
3 changes: 2 additions & 1 deletion brainpy/_src/losses/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def update(self, input, target):


def nll_loss(input, target, reduction: str = 'mean'):
r"""The negative log likelihood loss.
r"""
The negative log likelihood loss.
The negative log likelihood loss. It is useful to train a classification
problem with `C` classes.
Expand Down
35 changes: 6 additions & 29 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .interoperability import *
from .ndarray import Array


__all__ = [
'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
Expand Down Expand Up @@ -92,9 +91,8 @@
'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete',

# unique
'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray',
'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt',
'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt',
'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat',

]
Expand Down Expand Up @@ -204,11 +202,12 @@ def ascontiguousarray(a, dtype=None, order=None):
return asarray(a, dtype=dtype, order=order)


def asfarray(a, dtype=np.float_):
def asfarray(a, dtype=None):
if not np.issubdtype(dtype, np.inexact):
dtype = np.float_
dtype = np.float64
return asarray(a, dtype=dtype)


def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
del assume_unique
ar1_flat = ravel(ar1)
Expand All @@ -227,6 +226,7 @@ def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
else:
return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1))


# Others
# ------
meshgrid = _compatible_with_brainpy_array(jnp.meshgrid)
Expand Down Expand Up @@ -454,7 +454,6 @@ def msort(a):
sometrue = any



def shape(a):
"""
Return the shape of an array.
Expand Down Expand Up @@ -648,7 +647,6 @@ def size(a, axis=None):
finfo = jnp.finfo
iinfo = jnp.iinfo


can_cast = _compatible_with_brainpy_array(jnp.can_cast)
choose = _compatible_with_brainpy_array(jnp.choose)
copy = _compatible_with_brainpy_array(jnp.copy)
Expand Down Expand Up @@ -678,23 +676,6 @@ def size(a, axis=None):
# Unique APIs
# -----------

add_docstring = np.add_docstring
add_newdoc = np.add_newdoc
add_newdoc_ufunc = np.add_newdoc_ufunc


def array2string(a, max_line_width=None, precision=None,
suppress_small=None, separator=' ', prefix="",
style=np._NoValue, formatter=None, threshold=None,
edgeitems=None, sign=None, floatmode=None, suffix="",
legacy=None):
a = as_numpy(a)
return array2string(a, max_line_width=max_line_width, precision=precision,
suppress_small=suppress_small, separator=separator, prefix=prefix,
style=style, formatter=formatter, threshold=threshold,
edgeitems=edgeitems, sign=sign, floatmode=floatmode, suffix=suffix,
legacy=legacy)


def asscalar(a):
return a.item()
Expand Down Expand Up @@ -731,13 +712,9 @@ def common_type(*arrays):
return array_type[0][precision]


disp = np.disp

genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))

info = np.info
issubclass_ = np.issubclass_


def place(arr, mask, vals):
Expand Down
14 changes: 13 additions & 1 deletion brainpy/_src/math/event/tests/test_event_csrmm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
# -*- coding: utf-8 -*-

import os
from functools import partial

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('gpu')

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234


Expand Down
7 changes: 5 additions & 2 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


import os
from functools import partial

import jax
Expand All @@ -19,6 +18,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234

Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_event_matvec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os

import jax
import jax.numpy as jnp
Expand All @@ -16,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]

Expand Down
12 changes: 10 additions & 2 deletions brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import os

import jax.numpy as jnp
import pytest
from absl.testing import parameterized
Expand All @@ -12,8 +14,14 @@
import platform

force_test = False # turn on to force test on windows locally
# if platform.system() == 'Windows' and not force_test:
# pytest.skip('skip windows', allow_module_level=True)
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)


shapes = [
(2, 2),
Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_matvec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os

import jax
import jax.numpy as jnp
Expand All @@ -16,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]

Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

ti = import_taichi(error_if_not_found=False)
if ti is None:
pytest.skip('no taichi', allow_module_level=True)
Expand Down
16 changes: 15 additions & 1 deletion brainpy/_src/math/sparse/tests/test_csrmm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# -*- coding: utf-8 -*-


import os
from functools import partial

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('gpu')

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234


Expand Down Expand Up @@ -133,7 +146,8 @@ def test_homo_grad(self, transpose, shape, homo_data):
argnums=0)
r1 = dense_f1(homo_data)
r2 = jax.grad(sum_op(bm.sparse.csrmm))(
bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
bm.asarray([homo_data]), indices, indptr, matrix,
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
transpose=transpose)

self.assertTrue(bm.allclose(r1, r2))
Expand Down
6 changes: 5 additions & 1 deletion brainpy/_src/math/sparse/tests/test_csrmv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import os
from functools import partial

import jax
Expand All @@ -17,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234

Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/math/surrogate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


from .base import *
from ._one_input_new import *
from ._two_inputs import *
11 changes: 10 additions & 1 deletion brainpy/_src/math/surrogate/_one_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from .base import Surrogate

__all__ = [
'sigmoid',
Expand All @@ -32,6 +31,16 @@
]


class Surrogate(object):
"""The base surrograte gradient function."""

def __call__(self, *args, **kwargs):
raise NotImplementedError

def __repr__(self):
return f'{self.__class__.__name__}()'


class _OneInpSurrogate(Surrogate):
def __init__(self, forward_use_surrogate=False):
self.forward_use_surrogate = forward_use_surrogate
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/math/surrogate/_one_input_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from brainpy._src.math.ndarray import Array

__all__ = [
'Surrogate',
'Sigmoid',
'sigmoid',
'PiecewiseQuadratic',
Expand Down Expand Up @@ -61,7 +62,7 @@ def _heaviside_imp(x, dx):


def _heaviside_batching(args, axes):
return heaviside_p.bind(*args), axes
return heaviside_p.bind(*args), [axes[0]]


def _heaviside_jvp(primals, tangents):
Expand Down
19 changes: 0 additions & 19 deletions brainpy/_src/math/surrogate/base.py

This file was deleted.

6 changes: 0 additions & 6 deletions brainpy/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,20 +327,14 @@
sort_complex as sort_complex,
unpackbits as unpackbits,
delete as delete,
add_docstring as add_docstring,
add_newdoc as add_newdoc,
add_newdoc_ufunc as add_newdoc_ufunc,
array2string as array2string,
asanyarray as asanyarray,
ascontiguousarray as ascontiguousarray,
asfarray as asfarray,
asscalar as asscalar,
common_type as common_type,
disp as disp,
genfromtxt as genfromtxt,
loadtxt as loadtxt,
info as info,
issubclass_ as issubclass_,
place as place,
polydiv as polydiv,
put as put,
Expand Down
Loading

0 comments on commit 5f33a66

Please sign in to comment.