Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] add ein_rearrange, ein_reduce, and ein_repeat functions #590

Merged
merged 4 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6.post5"

__version__ = "2.5.0"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
728 changes: 728 additions & 0 deletions brainpy/_src/math/einops.py

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions brainpy/_src/math/einops_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import keyword
import warnings
from typing import List, Optional, Set, Tuple, Union

_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated


class EinopsError(Exception):
pass


class AnonymousAxis(object):
"""Important thing: all instances of this class are not equal to each other """

def __init__(self, value: str):
self.value = int(value)
if self.value <= 1:
if self.value == 1:
raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
else:
raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))

def __repr__(self):
return "{}-axis".format(str(self.value))


class ParsedExpression:
"""
non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
and keeps some information important for downstream
"""

def __init__(self, expression: str, *, allow_underscore: bool = False,
allow_duplicates: bool = False):
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[str] = set()
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes: bool = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition: List[Union[List[str], str]] = []
if '.' in expression:
if '...' not in expression:
raise EinopsError('Expression may contain dots only inside ellipsis (...)')
if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
raise EinopsError(
'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ')
expression = expression.replace('...', _ellipsis)
self.has_ellipsis = True

bracket_group: Optional[List[str]] = None

def add_axis_name(x):
if x in self.identifiers:
if not (allow_underscore and x == "_") and not allow_duplicates:
raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
if x == _ellipsis:
self.identifiers.add(_ellipsis)
if bracket_group is None:
self.composition.append(_ellipsis)
self.has_ellipsis_parenthesized = False
else:
bracket_group.append(_ellipsis)
self.has_ellipsis_parenthesized = True
else:
is_number = str.isdecimal(x)
if is_number and int(x) == 1:
# handling the case of anonymous axis of length 1
if bracket_group is None:
self.composition.append([])
else:
pass # no need to think about 1s inside parenthesis
return
is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
if not (is_number or is_axis_name):
raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
if is_number:
x = AnonymousAxis(x)
self.identifiers.add(x)
if is_number:
self.has_non_unitary_anonymous_axes = True
if bracket_group is None:
self.composition.append([x])
else:
bracket_group.append(x)

current_identifier = None
for char in expression:
if char in '() ':
if current_identifier is not None:
add_axis_name(current_identifier)
current_identifier = None
if char == '(':
if bracket_group is not None:
raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
bracket_group = []
elif char == ')':
if bracket_group is None:
raise EinopsError('Brackets are not balanced')
self.composition.append(bracket_group)
bracket_group = None
elif str.isalnum(char) or char in ['_', _ellipsis]:
if current_identifier is None:
current_identifier = char
else:
current_identifier += char
else:
raise EinopsError("Unknown character '{}'".format(char))

if bracket_group is not None:
raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
if current_identifier is not None:
add_axis_name(current_identifier)

def flat_axes_order(self) -> List:
result = []
for composed_axis in self.composition:
assert isinstance(composed_axis, list), 'does not work with ellipsis'
for axis in composed_axis:
result.append(axis)
return result

def has_composed_axes(self) -> bool:
# this will ignore 1 inside brackets
for axes in self.composition:
if isinstance(axes, list) and len(axes) > 1:
return True
return False

@staticmethod
def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
if not str.isidentifier(name):
return False, 'not a valid python identifier'
elif name[0] == '_' or name[-1] == '_':
if name == '_' and allow_underscore:
return True, ''
return False, 'axis name should should not start or end with underscore'
else:
if keyword.iskeyword(name):
warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
if name in ['axis']:
warnings.warn("It is discouraged to use 'axis' as an axis name "
"and will raise an error in future", FutureWarning)
return True, ''

@staticmethod
def check_axis_name(name: str) -> bool:
"""
Valid axes names are python identifiers except keywords,
and additionally should not start or end with underscore
"""
is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
return is_valid
10 changes: 9 additions & 1 deletion brainpy/_src/math/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@


__all__ = [
'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', 'is_bp_array'
'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable',
'from_numpy',

'is_bp_array'
]


Expand Down Expand Up @@ -99,3 +102,8 @@ def as_variable(tensor, dtype=None):
"""
from .object_transform.variables import Variable
return Variable(tensor, dtype=dtype)


def from_numpy(arr, dtype=None):
return as_ndarray(arr, dtype=dtype)

27 changes: 25 additions & 2 deletions brainpy/_src/math/others.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
# -*- coding: utf-8 -*-


from typing import Optional
from typing import Optional, Union

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

from brainpy import check, tools
from .compat_numpy import fill_diagonal
from .environment import get_dt, get_int
from .ndarray import Array
from .interoperability import as_jax
from .ndarray import Array

__all__ = [
'shared_args_over_time',
'remove_diag',
'clip_by_norm',
'exprel',
'is_float_type',
# 'reduce',
'add_axis',
'add_axes',
]


Expand Down Expand Up @@ -119,3 +124,21 @@ def exprel(x, threshold: float = None):
else:
threshold = 1e-5
return _exprel(x, threshold)


def is_float_type(x: Union[Array, jax.Array]):
return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")


def add_axis(x: Union[Array, jax.Array], new_position: int):
x = as_jax(x)
return jnp.expand_dims(x, new_position)


def add_axes(x: Union[Array, jax.Array], n_axes, pos2len):
x = as_jax(x)
repeats = [1] * n_axes
for axis_position, axis_length in pos2len.items():
x = add_axis(x, axis_position)
repeats[axis_position] = axis_length
return jnp.tile(x, repeats)
Loading