
Commit

Merge branch 'master' into updates
chaoming0625 authored Apr 3, 2022
2 parents 822d56f + 47b7539 commit e35b09d
Showing 7 changed files with 78 additions and 56 deletions.
71 changes: 41 additions & 30 deletions brainpy/nn/nodes/ANN/batch_norm.py
@@ -1,11 +1,5 @@
# -*- coding: utf-8 -*-

- """
- adapted from jax.example_libraries.stax.BatchNorm
- https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
- """
-
-
from typing import Union

import jax.nn
@@ -29,14 +23,23 @@ class BatchNorm(Node):
Most commonly, the first axis of the data is the batch, and the last is
the channel. However, users can specify the axes to be normalized.
+ adapted from jax.example_libraries.stax.BatchNorm
+ https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm
Parameters
----------
- axis: axes where the data will be normalized. The axis of channels should be excluded.
- epsilon: a value added to the denominator for numerical stability. Default: 1e-5
- translate: whether to translate data in refactoring
- scale: whether to scale data in refactoring
- beta_init: an initializer generating the original translation matrix
- gamma_init: an initializer generating the original scaling matrix
+ axis: int, tuple, list
+   axes where the data will be normalized. The axis of channels should be excluded.
+ epsilon: float
+   a value added to the denominator for numerical stability. Default: 1e-5
+ translate: bool
+   whether to translate data in refactoring
+ scale: bool
+   whether to scale data in refactoring
+ beta_init: brainpy.init.Initializer
+   an initializer generating the original translation matrix
+ gamma_init: brainpy.init.Initializer
+   an initializer generating the original scaling matrix
"""
def __init__(self,
             axis: Union[int, tuple, list],
@@ -86,10 +89,14 @@ class BatchNorm1d(BatchNorm):
  axes where the data will be normalized. The axis of channels should be excluded.
epsilon: float
  a value added to the denominator for numerical stability. Default: 1e-5
- translate: whether to translate data in refactoring
- scale: whether to scale data in refactoring
- beta_init: an initializer generating the original translation matrix
- gamma_init: an initializer generating the original scaling matrix
+ translate: bool
+   whether to translate data in refactoring
+ scale: bool
+   whether to scale data in refactoring
+ beta_init: brainpy.init.Initializer
+   an initializer generating the original translation matrix
+ gamma_init: brainpy.init.Initializer
+   an initializer generating the original scaling matrix
"""
def __init__(self, axis=(0, 1), **kwargs):
  super(BatchNorm1d, self).__init__(axis=axis, **kwargs)
@@ -138,20 +145,24 @@ def _check_input_dim(self):

class BatchNorm3d(BatchNorm):
"""3-D batch normalization.
- The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
- `h` is the height dimension, `w` is the width dimension, `d` is the depth
- dimension, and `c` is the channel dimension.
- Parameters
- ----------
- axis: int, tuple, list
-   axes where the data will be normalized. The axis of channels should be excluded.
- epsilon: float
-   a value added to the denominator for numerical stability. Default: 1e-5
- translate: whether to translate data in refactoring
- scale: whether to scale data in refactoring
- beta_init: an initializer generating the original translation matrix
- gamma_init: an initializer generating the original scaling matrix
+ The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
+ `h` is the height dimension, `w` is the width dimension, `d` is the depth
+ dimension, and `c` is the channel dimension.
+ Parameters
+ ----------
+ axis: int, tuple, list
+   axes where the data will be normalized. The axis of channels should be excluded.
+ epsilon: float
+   a value added to the denominator for numerical stability. Default: 1e-5
+ translate: bool
+   whether to translate data in refactoring
+ scale: bool
+   whether to scale data in refactoring
+ beta_init: brainpy.init.Initializer
+   an initializer generating the original translation matrix
+ gamma_init: brainpy.init.Initializer
+   an initializer generating the original scaling matrix
"""
def __init__(self, axis=(0, 1, 2, 3), **kwargs):
  super(BatchNorm3d, self).__init__(axis=axis, **kwargs)
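For orientation, here is a minimal usage sketch of the `BatchNorm` node documented above. It is modeled directly on the API exercised in `test_batchnorm.py` below (`bp.nn.Input`, `>>` composition, `model.initialize(num_batch=...)`); the argument values are illustrative only, and `epsilon` is simply the documented default.

import brainpy as bp

# Normalize over axes (0, 2), i.e. treat axis 1 as the channel axis,
# as in test_batchnorm2 below.
i = bp.nn.Input((3, 4))
b = bp.nn.BatchNorm(axis=(0, 2), epsilon=1e-5)  # epsilon defaults to 1e-5 per the docstring
f = bp.nn.Reshape((-1, 12))
o = bp.nn.GeneralDense(2)

model = i >> b >> f >> o
model.initialize(num_batch=2)

inputs = bp.math.ones((2, 3, 4))
print(model(inputs))  # forward pass on a dummy batch of two samples

The 1-D and 3-D variants are used the same way, with `BatchNorm1d(axis=(0, 1))` and `BatchNorm3d(axis=(0, 1, 2, 3))` as their defaults.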
42 changes: 21 additions & 21 deletions brainpy/nn/nodes/ANN/tests/test_batchnorm.py
@@ -76,26 +76,26 @@ def test_batchnorm1(self):

print(model(inputs))

- # def test_batchnorm2(self):
- # i = bp.nn.Input((3, 4))
- # b = bp.nn.BatchNorm(axis=(0, 2)) # channel axis: 1
- # f = bp.nn.Reshape((-1, 12))
- # o = bp.nn.GeneralDense(2)
- # model = i >> b >>f >> o
- # model.initialize(num_batch=2)
- #
- # inputs = bp.math.ones((2, 3, 4))
- # inputs[0, 0, :] = 2.
- # inputs[0, 1, 0] = 5.
- # # print(inputs)
- # print(model(inputs))
- #
- #
- # X = bp.math.random.random((1000, 10, 3, 4))
- # Y = bp.math.random.randint(0, 2, (1000, 10, 2))
- # trainer = bp.nn.BPTT(model,
- # loss=bp.losses.cross_entropy_loss,
- # optimizer=bp.optim.Adam(lr=1e-3))
- # trainer.fit([X, Y])
+ def test_batchnorm2(self):
+   i = bp.nn.Input((3, 4))
+   b = bp.nn.BatchNorm(axis=(0, 2))  # channel axis: 1
+   f = bp.nn.Reshape((-1, 12))
+   o = bp.nn.GeneralDense(2)
+   model = i >> b >> f >> o
+   model.initialize(num_batch=2)
+
+   inputs = bp.math.ones((2, 3, 4))
+   inputs[0, 0, :] = 2.
+   inputs[0, 1, 0] = 5.
+   # print(inputs)
+   print(model(inputs))
+
+
+   X = bp.math.random.random((1000, 10, 3, 4))
+   Y = bp.math.random.randint(0, 2, (1000, 10, 2))
+   trainer = bp.nn.BPTT(model,
+                        loss=bp.losses.cross_entropy_loss,
+                        optimizer=bp.optim.Adam(lr=1e-3))
+   trainer.fit([X, Y])


2 changes: 1 addition & 1 deletion extensions/brainpylib/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "0.0.4"
__version__ = "0.0.5"

# IMPORTANT, must import first
from . import register_custom_calls
11 changes: 11 additions & 0 deletions extensions/changelog.rst
@@ -1,6 +1,17 @@
Release notes (brainpylib)
##########################

+ Version 0.0.5
+ =============
+
+ - Support operator customization on GPU by ``numba``
+
+
+ Version 0.0.4
+ =============
+
+ - Support operator customization on CPU by ``numba``
+

Version 0.0.3
=============
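As background for the new changelog entries: customizing an operator on the GPU with ``numba`` amounts to writing a CUDA kernel via ``numba.cuda.jit``. The sketch below is a generic element-wise kernel only; it deliberately does not show how brainpylib registers such a kernel as a JAX primitive, since that registration API is not part of this diff.

import numpy as np
from numba import cuda

@cuda.jit
def add_vectors(x, y, out):
    # One GPU thread per element, with a bounds check for stray threads.
    i = cuda.grid(1)
    if i < out.size:
        out[i] = x[i] + y[i]

x = np.arange(1024, dtype=np.float32)
y = np.ones(1024, dtype=np.float32)
out = np.zeros(1024, dtype=np.float32)

threads = 256
blocks = (out.size + threads - 1) // threads
# NumPy arrays passed to the kernel are copied to the device and back automatically.
add_vectors[blocks, threads](x, y, out)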
2 changes: 1 addition & 1 deletion extensions/setup.py
@@ -34,7 +34,7 @@
author_email='[email protected]',
packages=find_packages(exclude=['lib*']),
include_package_data=True,
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8"],
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
extras_require={"test": "pytest"},
python_requires='>=3.7',
url='https://github.com/PKU-NIP-Lab/BrainPy',
2 changes: 1 addition & 1 deletion extensions/setup_cuda.py
@@ -90,7 +90,7 @@ def build_extension(self, ext):
author_email='[email protected]',
packages=find_packages(exclude=['lib*']),
include_package_data=True,
install_requires=["jax", "jaxlib"],
install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
extras_require={"test": "pytest"},
python_requires='>=3.7',
url='https://github.com/PKU-NIP-Lab/BrainPy',
4 changes: 2 additions & 2 deletions extensions/setup_mac.py
@@ -22,7 +22,7 @@
sources=["lib/cpu_ops.cc"] + glob.glob("lib/*_cpu.cc"),
cxx_std=11,
# extra_link_args=["-rpath", "/Users/ztqakita/miniforge3/lib"], # m1
extra_link_args=["-rpath", "/Users/ztqakita/miniforge3/lib"], # intel
extra_link_args=["-rpath", "/Users/ztqakita/opt/miniconda3/lib"], # intel
define_macros=[('VERSION_INFO', __version__)]),
]

@@ -36,7 +36,7 @@
author_email='[email protected]',
packages=find_packages(exclude=['lib*']),
include_package_data=True,
- install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8"],
+ install_requires=["jax", "jaxlib", "pybind11>=2.6, <2.8", "cffi", "numba"],
extras_require={"test": "pytest"},
python_requires='>=3.7',
url='https://github.com/PKU-NIP-Lab/BrainPy',
