diff --git a/brainpy/__init__.py b/brainpy/__init__.py index c8f834c6d..a3a1de694 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -__version__ = "2.4.6.post5" + +__version__ = "2.5.0" # fundamental supporting modules from brainpy import errors, check, tools diff --git a/brainpy/_src/math/einops.py b/brainpy/_src/math/einops.py new file mode 100644 index 000000000..d42026974 --- /dev/null +++ b/brainpy/_src/math/einops.py @@ -0,0 +1,728 @@ +import functools +import itertools +from collections import OrderedDict +from typing import Set, Tuple, List, Dict, Union, Callable, Optional, cast + +import jax +import numpy as np + +from . import compat_numpy as bnp +from . import others as bnp2 +from .einops_parsing import ParsedExpression, _ellipsis, AnonymousAxis, EinopsError +from .ndarray import Array + +__all__ = [ + 'ein_reduce', 'ein_rearrange', 'ein_repeat', 'ein_shape', +] + +Tensor = Union[Array, jax.Array] +ReductionCallable = Callable[[Tensor, Tuple[int, ...]], Tensor] +Reduction = Union[str, ReductionCallable] + +_reductions = ("min", "max", "sum", "mean", "prod", "any", "all") + +# magic integers are required to stay within +# traceable subset of language +_unknown_axis_length = -999999 +_expected_axis_length = -99999 + + +def _product(sequence: List[int]) -> int: + """minimalistic product that works both with numbers and symbols. Supports empty lists""" + result = 1 + for element in sequence: + result *= element + return result + + +def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int]): + if callable(reduction_type): + # custom callable + return reduction_type(tensor, tuple(reduced_axes)) + else: + # one of built-in operations + assert reduction_type in _reductions + if reduction_type == "mean": + if not bnp2.is_float_type(tensor): + raise NotImplementedError("reduce_mean is not available for non-floating tensors") + return __reduce(tensor, reduction_type, tuple(reduced_axes)) + + +def __reduce(x: Union[Array, jax.Array], operation: str, reduced_axes): + if operation == "min": + return x.min(axis=reduced_axes) + elif operation == "max": + return x.max(axis=reduced_axes) + elif operation == "sum": + return x.sum(axis=reduced_axes) + elif operation == "mean": + return x.mean(axis=reduced_axes) + elif operation == "prod": + return x.prod(axis=reduced_axes) + elif operation == "any": + return x.any(axis=reduced_axes) + elif operation == "all": + return x.all(axis=reduced_axes) + else: + raise NotImplementedError("Unknown reduction ", operation) + + +def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): + # 'collapses' neighboring axes if those participate in the result pattern in the same order + # TODO add support for added_axes + assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) + # joining consecutive axes that will be reduced + # possibly we can skip this if all backends can optimize this (not sure) + reduced_axes = tuple(sorted(reduced_axes)) + for i in range(len(reduced_axes) - 1)[::-1]: + if reduced_axes[i] + 1 == reduced_axes[i + 1]: + removed_axis = reduced_axes[i + 1] + removed_length = init_shapes[removed_axis] + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) + + # removing axes that are moved together during reshape + def build_mapping(): + init_to_final = {} + for axis in 
range(len(init_shapes)): + if axis in reduced_axes: + init_to_final[axis] = None + else: + after_reduction = sum(x is not None for x in init_to_final.values()) + init_to_final[axis] = list(axes_reordering).index(after_reduction) + return init_to_final + + init_axis_to_final_axis = build_mapping() + + for init_axis in range(len(init_shapes) - 1)[::-1]: + if init_axis_to_final_axis[init_axis] is None: + continue + if init_axis_to_final_axis[init_axis + 1] is None: + continue + if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: + removed_axis = init_axis + 1 + removed_length = init_shapes[removed_axis] + removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) + + reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + old_reordering = axes_reordering + axes_reordering = [] + for axis in old_reordering: + if axis == removed_axis_after_reduction: + pass + elif axis < removed_axis_after_reduction: + axes_reordering.append(axis) + else: + axes_reordering.append(axis - 1) + init_axis_to_final_axis = build_mapping() + + return init_shapes, reduced_axes, axes_reordering, final_shapes + + +CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int] + +# Actual type is tuple[tuple[str, int], ...] +# However torch.jit.script does not "understand" the correct type, +# and torch_specific will use list version. +HashableAxesLengths = Tuple[Tuple[str, int], ...] +FakeHashableAxesLengths = List[Tuple[str, int]] + + +class TransformRecipe: + """ + Recipe describes actual computation pathway. + Recipe can be applied to a tensor or variable. + """ + + # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) + # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided + + def __init__( + self, + # list of sizes (or just sizes) for elementary axes as they appear in left expression. + # this is what (after computing unknown parts) will be a shape after first transposition. + # This does not include any ellipsis dimensions. + elementary_axes_lengths: List[int], + # if additional axes are provided, they should be set in prev array + # This shows mapping from name to position + axis_name2elementary_axis: Dict[str, int], + # each dimension in input can help to reconstruct length of one elementary axis + # or verify one of dimensions. Each element points to element of elementary_axes_lengths. + input_composition_known_unknown: List[Tuple[List[int], List[int]]], + # permutation applied to elementary axes, if ellipsis is absent + axes_permutation: List[int], + # permutation puts reduced axes in the end, we only need to know the first position. + first_reduced_axis: int, + # at which positions which of elementary axes should appear. Axis position -> axis index. 
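+        # (illustrative) e.g. for the repeat pattern 'h w -> h w c', the new axis `c` appears only
+        # on the right side, so it is recorded here as {2: <index of c in elementary_axes_lengths>}.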
+ added_axes: Dict[int, int], + # ids of axes as they appear in result, again pointers to elementary_axes_lengths, + # only used to infer result dimensions + output_composite_axes: List[List[int]], + ): + self.elementary_axes_lengths: List[int] = elementary_axes_lengths + self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis + self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown + self.axes_permutation: List[int] = axes_permutation + + self.first_reduced_axis: int = first_reduced_axis + self.added_axes: Dict[int, int] = added_axes + self.output_composite_axes: List[List[int]] = output_composite_axes + + +def _reconstruct_from_shape_uncached( + self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths +) -> CookedRecipe: + """ + Reconstruct all actual parameters using shape. + Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet) + known axes can be integers or symbols, but not Nones. + """ + # magic number + need_init_reshape = False + + # last axis is allocated for collapsed ellipsis + axes_lengths: List[int] = list(self.elementary_axes_lengths) + for axis, dim in axes_dims: + axes_lengths[self.axis_name2elementary_axis[axis]] = dim + + for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown): + length = shape[input_axis] + if len(known_axes) == 0 and len(unknown_axes) == 1: + # shortcut for the most common case + axes_lengths[unknown_axes[0]] = length + continue + + known_product = 1 + for axis in known_axes: + known_product *= axes_lengths[axis] + + if len(unknown_axes) == 0: + if isinstance(length, int) and isinstance(known_product, int) and length != known_product: + raise EinopsError(f"Shape mismatch, {length} != {known_product}") + else: + # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out' + if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: + raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}") + + unknown_axis = unknown_axes[0] + inferred_length: int = length // known_product + axes_lengths[unknown_axis] = inferred_length + + if len(known_axes) + len(unknown_axes) != 1: + need_init_reshape = True + + # at this point all axes_lengths are computed (either have values or variables, but not Nones) + + # elementary axes are ordered as they appear in input, then all added axes + init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None + + need_final_reshape = False + final_shapes: List[int] = [] + for grouping in self.output_composite_axes: + lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] + final_shapes.append(_product(lengths)) + if len(lengths) != 1: + need_final_reshape = True + + added_axes: Dict[int, int] = { + pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items() + } + + # this list can be empty + reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation))) + + n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation) + + axes_reordering: Optional[List[int]] = self.axes_permutation + if self.axes_permutation == list(range(len(self.axes_permutation))): + axes_reordering = None + + _final_shapes = final_shapes if need_final_reshape else None + return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, 
n_axes_after_adding_axes + + +_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached) + + +def _apply_recipe( + recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths +) -> Tensor: + # this method implements actual work for all backends for 3 operations + try: + init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = ( + _reconstruct_from_shape(recipe, bnp.shape(tensor), axes_lengths)) + except TypeError: + # shape or one of passed axes lengths is not hashable (i.e. they are symbols) + _result = _reconstruct_from_shape_uncached(recipe, bnp.shape(tensor), axes_lengths) + (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result + if init_shapes is not None: + tensor = bnp.reshape(bnp.as_jax(tensor), init_shapes) + if axes_reordering is not None: + tensor = bnp.transpose(bnp.as_jax(tensor), axes_reordering) + if len(reduced_axes) > 0: + tensor = _reduce_axes(bnp.as_jax(tensor), reduction_type=reduction_type, reduced_axes=reduced_axes) + if len(added_axes) > 0: + tensor = bnp2.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) + if final_shapes is not None: + tensor = bnp.reshape(bnp.as_jax(tensor), final_shapes) + return tensor + + +def _apply_recipe_array_api( + xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths +) -> Tensor: + # completely-inline implementation + init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( + recipe, tensor.shape, axes_lengths + ) + if init_shapes is not None: + tensor = xp.reshape(tensor, init_shapes) + if axes_reordering is not None: + tensor = xp.permute_dims(tensor, axes_reordering) + if len(reduced_axes) > 0: + if callable(reduction_type): + # custom callable + tensor = reduction_type(tensor, tuple(reduced_axes)) + else: + # one of built-in operations + assert reduction_type in _reductions + tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes)) + if len(added_axes) > 0: + # we use broadcasting + for axis_position, axis_length in added_axes.items(): + tensor = xp.expand_dims(tensor, axis=axis_position) + + final_shape = list(tensor.shape) + for axis_position, axis_length in added_axes.items(): + final_shape[axis_position] = axis_length + + tensor = xp.broadcast_to(tensor, final_shape) + if final_shapes is not None: + tensor = xp.reshape(tensor, final_shapes) + return tensor + + +@functools.lru_cache(256) +def _prepare_transformation_recipe( + pattern: str, + operation: Reduction, + axes_names: Tuple[str, ...], + ndim: int, +) -> TransformRecipe: + """Perform initial parsing of pattern and provided supplementary info + axes_lengths is a tuple of tuples (axis_name, axis_length) + """ + left_str, rght_str = pattern.split("->") + left = ParsedExpression(left_str) + rght = ParsedExpression(rght_str) + + # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction + if not left.has_ellipsis and rght.has_ellipsis: + raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern)) + if left.has_ellipsis and left.has_ellipsis_parenthesized: + raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern)) + if operation == "rearrange": + if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: + raise EinopsError("Non-unitary anonymous axes are not 
supported in rearrange (exception is length 1)") + difference = set.symmetric_difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference)) + elif operation == "repeat": + difference = set.difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference)) + axes_without_size = set.difference( + {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, + {*left.identifiers, *axes_names}, + ) + if len(axes_without_size) > 0: + raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size)) + elif operation in _reductions or callable(operation): + difference = set.difference(rght.identifiers, left.identifiers) + if len(difference) > 0: + raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference)) + else: + raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions)) + + if left.has_ellipsis: + n_other_dims = len(left.composition) - 1 + if ndim < n_other_dims: + raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. Received {ndim}-dim tensor.") + ellipsis_ndim = ndim - n_other_dims + ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)] + left_composition = [] + for composite_axis in left.composition: + if composite_axis == _ellipsis: + for axis in ell_axes: + left_composition.append([axis]) + else: + left_composition.append(composite_axis) + + rght_composition = [] + for composite_axis in rght.composition: + if composite_axis == _ellipsis: + for axis in ell_axes: + rght_composition.append([axis]) + else: + group = [] + for axis in composite_axis: + if axis == _ellipsis: + group.extend(ell_axes) + else: + group.append(axis) + rght_composition.append(group) + + left.identifiers.update(ell_axes) + left.identifiers.remove(_ellipsis) + if rght.has_ellipsis: + rght.identifiers.update(ell_axes) + rght.identifiers.remove(_ellipsis) + else: + if ndim != len(left.composition): + raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. 
Received {ndim}-dim tensor.") + left_composition = left.composition + rght_composition = rght.composition + + # parsing all dimensions to find out lengths + axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict() + for composite_axis in left_composition: + for axis_name in composite_axis: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + + # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point + + repeat_axes_names = [] + for axis_name in rght.identifiers: + if axis_name not in axis_name2known_length: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + repeat_axes_names.append(axis_name) + + axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} + + # axes provided as kwargs + for elementary_axis in axes_names: + if not ParsedExpression.check_axis_name(elementary_axis): + raise EinopsError("Invalid name for an axis", elementary_axis) + if elementary_axis not in axis_name2known_length: + raise EinopsError("Axis {} is not used in transform".format(elementary_axis)) + axis_name2known_length[elementary_axis] = _expected_axis_length + + input_axes_known_unknown = [] + # some shapes are inferred later - all information is prepared for faster inference + for i, composite_axis in enumerate(left_composition): + known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} + unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} + if len(unknown) > 1: + raise EinopsError("Could not infer sizes for {}".format(unknown)) + assert len(unknown) + len(known) == len(composite_axis) + input_axes_known_unknown.append( + ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown]) + ) + + axis_position_after_reduction: Dict[str, int] = {} + for axis_name in itertools.chain(*left_composition): + if axis_name in rght.identifiers: + axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) + + result_axes_grouping: List[List[int]] = [ + [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition) + ] + + ordered_axis_left = list(itertools.chain(*left_composition)) + ordered_axis_rght = list(itertools.chain(*rght_composition)) + reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers] + order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes + axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition] + added_axes = { + i: axis_name2position[axis_name] + for i, axis_name in enumerate(ordered_axis_rght) + if axis_name not in left.identifiers + } + + first_reduced_axis = len(order_after_transposition) - len(reduced_axes) + + return TransformRecipe( + elementary_axes_lengths=list(axis_name2known_length.values()), + axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names}, + input_composition_known_unknown=input_axes_known_unknown, + axes_permutation=axes_permutation, + first_reduced_axis=first_reduced_axis, + added_axes=added_axes, + output_composite_axes=result_axes_grouping, + ) + + +def _prepare_recipes_for_all_dims( + pattern: str, operation: Reduction, axes_names: 
Tuple[str, ...] +) -> Dict[int, TransformRecipe]: + """ + Internal function, used in layers. + Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims + """ + left_str, rght_str = pattern.split("->") + left = ParsedExpression(left_str) + dims = [len(left.composition)] + if left.has_ellipsis: + dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)] + return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims} + + +def ein_reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: + """ + ``ein_reduce`` provides combination of reordering and reduction using reader-friendly notation. + + Examples for reduce operation: + + ```python + >>> x = np.random.randn(100, 32, 64) + + # perform max-reduction on the first axis + >>> y = ein_reduce(x, 't b c -> b c', 'max') + + # same as previous, but with clearer axes meaning + >>> y = ein_reduce(x, 'time batch channel -> batch channel', 'max') + + >>> x = np.random.randn(10, 20, 30, 40) + + # 2d max-pooling with kernel size = 2 * 2 for image processing + >>> y1 = ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) + + # if one wants to go back to the original height and width, depth-to-space trick can be applied + >>> y2 = ein_rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) + >>> assert ein_shape(x, 'b _ h w') == ein_shape(y2, 'b _ h w') + + # Adaptive 2d max-pooling to 3 * 4 grid + >>> ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape + (10, 20, 3, 4) + + # Global average pooling + >>> ein_reduce(x, 'b c h w -> b c', 'mean').shape + (10, 20) + + # Subtracting mean over batch for each channel + >>> y = x - ein_reduce(x, 'b c h w -> () c () ()', 'mean') + + # Subtracting per-image mean for each channel + >>> y = x - ein_reduce(x, 'b c h w -> b c () ()', 'mean') + + ``` + + Parameters: + tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, reduction pattern + reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive + alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. + This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input + """ + try: + hashable_axes_lengths = tuple(axes_lengths.items()) + shape = bnp.shape(tensor) + recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape)) + return _apply_recipe(recipe, + cast(Tensor, tensor), + reduction_type=reduction, + axes_lengths=hashable_axes_lengths) + except EinopsError as e: + message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) + if not isinstance(tensor, list): + message += "\n Input tensor shape: {}. ".format(shape) + else: + message += "\n Input is list. " + message += "Additional info: {}.".format(axes_lengths) + raise EinopsError(message + "\n {}".format(e)) + + +def ein_rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: + """ + ``ein_rearrange`` is a reader-friendly smart element reordering for multidimensional tensors. 
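+    For instance (an illustrative sketch; brainpy, JAX, or NumPy arrays all work here):
+
+    ```python
+    >>> import brainpy.math as bm
+    >>> x = bm.zeros((2, 3, 4))
+    >>> ein_rearrange(x, 'a b c -> c (a b)').shape
+    (4, 6)
+    ```
+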
+ This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, + stack, concatenate and other operations. + + Examples for rearrange operation: + + ```python + # suppose we have a set of 32 images in "h w c" format (height-width-channel) + >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] + + # stack along first (batch) axis, output is a single array + >>> ein_rearrange(images, 'b h w c -> b h w c').shape + (32, 30, 40, 3) + + # concatenate images along height (vertical axis), 960 = 32 * 30 + >>> ein_rearrange(images, 'b h w c -> (b h) w c').shape + (960, 40, 3) + + # concatenated images along horizontal axis, 1280 = 32 * 40 + >>> ein_rearrange(images, 'b h w c -> h (b w) c').shape + (30, 1280, 3) + + # reordered axes to "b c h w" format for deep learning + >>> ein_rearrange(images, 'b h w c -> b c h w').shape + (32, 3, 30, 40) + + # flattened each image into a vector, 3600 = 30 * 40 * 3 + >>> ein_rearrange(images, 'b h w c -> b (c h w)').shape + (32, 3600) + + # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 + >>> ein_rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + (128, 15, 20, 3) + + # space-to-depth operation + >>> ein_rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + (32, 15, 20, 12) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + return ein_reduce(tensor, pattern, reduction="rearrange", **axes_lengths) + + +def ein_repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: + """ + ``ein_repeat`` allows reordering elements and repeating them in arbitrary combinations. + This operation includes functionality of repeat, tile, broadcast functions. + + Examples for repeat operation: + + ```python + # a grayscale image (of shape height x width) + >>> image = np.random.randn(30, 40) + + # change it to RGB format by repeating in each channel + >>> ein_repeat(image, 'h w -> h w c', c=3).shape + (30, 40, 3) + + # repeat image 2 times along height (vertical axis) + >>> ein_repeat(image, 'h w -> (repeat h) w', repeat=2).shape + (60, 40) + + # repeat image 2 time along height and 3 times along width + >>> ein_repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape + (60, 120) + + # convert each pixel to a small square 2x2. Upsample image by 2x + >>> ein_repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (60, 80) + + # pixelate image first by downsampling by 2x, then upsampling + >>> downsampled = ein_reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) + >>> ein_repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (30, 40) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). 
+ list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + Tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + return ein_reduce(tensor, pattern, reduction="repeat", **axes_lengths) + + +def ein_shape(x, pattern: str) -> dict: + """ + Parse a tensor shape to dictionary mapping axes names to their lengths. + + ```python + # Use underscore to skip the dimension in parsing. + >>> x = np.zeros([2, 3, 5, 7]) + >>> ein_shape(x, 'batch _ h w') + {'batch': 2, 'h': 5, 'w': 7} + + # `parse_shape` output can be used to specify axes_lengths for other operations: + >>> y = np.zeros([700]) + >>> ein_rearrange(y, '(b c h w) -> b c h w', **ein_shape(x, 'b _ h w')).shape + (2, 10, 5, 7) + + ``` + + For symbolic frameworks may return symbols, not integers. + + Parameters: + x: tensor of any supported framework + pattern: str, space separated names for axes, underscore means skip axis + + Returns: + dict, maps axes names to their lengths + """ + exp = ParsedExpression(pattern, allow_underscore=True) + shape = bnp.shape(x) + if exp.has_composed_axes(): + raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}") + if len(shape) != len(exp.composition): + if exp.has_ellipsis: + if len(shape) < len(exp.composition) - 1: + raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}") + else: + raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}") + if exp.has_ellipsis: + ellipsis_idx = exp.composition.index(_ellipsis) + composition = ( + exp.composition[:ellipsis_idx] + + ["_"] * (len(shape) - len(exp.composition) + 1) + + exp.composition[ellipsis_idx + 1:] + ) + else: + composition = exp.composition + result = {} + for (axis_name,), axis_length in zip(composition, shape): # type: ignore + if axis_name != "_": + result[axis_name] = axis_length + return result + + +# _enumerate_directions is not exposed in the public API +def _enumerate_directions(x): + """ + For an n-dimensional tensor, returns tensors to enumerate each axis. + ```python + x = np.zeros([2, 3, 4]) # or any other tensor + i, j, k = _enumerate_directions(x) + result = i + 2*j + 3*k + ``` + + `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result + Works very similarly to numpy.ogrid (open indexing grid) + """ + shape = bnp.shape(x) + result = [] + for axis_id, axis_length in enumerate(shape): + shape = [1] * len(shape) + shape[axis_id] = axis_length + result.append(bnp.reshape(bnp.arange(0, axis_length), shape)) + return result diff --git a/brainpy/_src/math/einops_parsing.py b/brainpy/_src/math/einops_parsing.py new file mode 100644 index 000000000..6ce055bdb --- /dev/null +++ b/brainpy/_src/math/einops_parsing.py @@ -0,0 +1,153 @@ +import keyword +import warnings +from typing import List, Optional, Set, Tuple, Union + +_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated + + +class EinopsError(Exception): + pass + + +class AnonymousAxis(object): + """Important thing: all instances of this class are not equal to each other """ + + def __init__(self, value: str): + self.value = int(value) + if self.value <= 1: + if self.value == 1: + raise EinopsError('No need to create anonymous axis of length 1. 
Report this as an issue') + else: + raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) + + def __repr__(self): + return "{}-axis".format(str(self.value)) + + +class ParsedExpression: + """ + non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') + and keeps some information important for downstream + """ + + def __init__(self, expression: str, *, allow_underscore: bool = False, + allow_duplicates: bool = False): + self.has_ellipsis: bool = False + self.has_ellipsis_parenthesized: Optional[bool] = None + self.identifiers: Set[str] = set() + # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition + self.has_non_unitary_anonymous_axes: bool = False + # composition keeps structure of composite axes, see how different corner cases are handled in tests + self.composition: List[Union[List[str], str]] = [] + if '.' in expression: + if '...' not in expression: + raise EinopsError('Expression may contain dots only inside ellipsis (...)') + if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: + raise EinopsError( + 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') + expression = expression.replace('...', _ellipsis) + self.has_ellipsis = True + + bracket_group: Optional[List[str]] = None + + def add_axis_name(x): + if x in self.identifiers: + if not (allow_underscore and x == "_") and not allow_duplicates: + raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) + if x == _ellipsis: + self.identifiers.add(_ellipsis) + if bracket_group is None: + self.composition.append(_ellipsis) + self.has_ellipsis_parenthesized = False + else: + bracket_group.append(_ellipsis) + self.has_ellipsis_parenthesized = True + else: + is_number = str.isdecimal(x) + if is_number and int(x) == 1: + # handling the case of anonymous axis of length 1 + if bracket_group is None: + self.composition.append([]) + else: + pass # no need to think about 1s inside parenthesis + return + is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) + if not (is_number or is_axis_name): + raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) + if is_number: + x = AnonymousAxis(x) + self.identifiers.add(x) + if is_number: + self.has_non_unitary_anonymous_axes = True + if bracket_group is None: + self.composition.append([x]) + else: + bracket_group.append(x) + + current_identifier = None + for char in expression: + if char in '() ': + if current_identifier is not None: + add_axis_name(current_identifier) + current_identifier = None + if char == '(': + if bracket_group is not None: + raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") + bracket_group = [] + elif char == ')': + if bracket_group is None: + raise EinopsError('Brackets are not balanced') + self.composition.append(bracket_group) + bracket_group = None + elif str.isalnum(char) or char in ['_', _ellipsis]: + if current_identifier is None: + current_identifier = char + else: + current_identifier += char + else: + raise EinopsError("Unknown character '{}'".format(char)) + + if bracket_group is not None: + raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) + if current_identifier is not None: + add_axis_name(current_identifier) + + def flat_axes_order(self) -> List: + result = [] + for composed_axis in self.composition: + assert isinstance(composed_axis, 
list), 'does not work with ellipsis' + for axis in composed_axis: + result.append(axis) + return result + + def has_composed_axes(self) -> bool: + # this will ignore 1 inside brackets + for axes in self.composition: + if isinstance(axes, list) and len(axes) > 1: + return True + return False + + @staticmethod + def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: + if not str.isidentifier(name): + return False, 'not a valid python identifier' + elif name[0] == '_' or name[-1] == '_': + if name == '_' and allow_underscore: + return True, '' + return False, 'axis name should should not start or end with underscore' + else: + if keyword.iskeyword(name): + warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) + if name in ['axis']: + warnings.warn("It is discouraged to use 'axis' as an axis name " + "and will raise an error in future", FutureWarning) + return True, '' + + @staticmethod + def check_axis_name(name: str) -> bool: + """ + Valid axes names are python identifiers except keywords, + and additionally should not start or end with underscore + """ + is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) + return is_valid diff --git a/brainpy/_src/math/interoperability.py b/brainpy/_src/math/interoperability.py index 22fe25caf..948538371 100644 --- a/brainpy/_src/math/interoperability.py +++ b/brainpy/_src/math/interoperability.py @@ -7,7 +7,10 @@ __all__ = [ - 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', 'is_bp_array' + 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', + 'from_numpy', + + 'is_bp_array' ] @@ -99,3 +102,8 @@ def as_variable(tensor, dtype=None): """ from .object_transform.variables import Variable return Variable(tensor, dtype=dtype) + + +def from_numpy(arr, dtype=None): + return as_ndarray(arr, dtype=dtype) + diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index f3cf4f516..94aeebb16 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -1,22 +1,27 @@ # -*- coding: utf-8 -*- -from typing import Optional +from typing import Optional, Union +import jax import jax.numpy as jnp from jax.tree_util import tree_map from brainpy import check, tools from .compat_numpy import fill_diagonal from .environment import get_dt, get_int -from .ndarray import Array from .interoperability import as_jax +from .ndarray import Array __all__ = [ 'shared_args_over_time', 'remove_diag', 'clip_by_norm', 'exprel', + 'is_float_type', + # 'reduce', + 'add_axis', + 'add_axes', ] @@ -119,3 +124,21 @@ def exprel(x, threshold: float = None): else: threshold = 1e-5 return _exprel(x, threshold) + + +def is_float_type(x: Union[Array, jax.Array]): + return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") + + +def add_axis(x: Union[Array, jax.Array], new_position: int): + x = as_jax(x) + return jnp.expand_dims(x, new_position) + + +def add_axes(x: Union[Array, jax.Array], n_axes, pos2len): + x = as_jax(x) + repeats = [1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = add_axis(x, axis_position) + repeats[axis_position] = axis_length + return jnp.tile(x, repeats) diff --git a/brainpy/_src/math/tests/test_einops.py b/brainpy/_src/math/tests/test_einops.py new file mode 100644 index 000000000..2f018d973 --- /dev/null +++ b/brainpy/_src/math/tests/test_einops.py @@ -0,0 +1,331 @@ +import numpy +import pytest + +import brainpy.math as bm +from brainpy._src.math.einops import 
ein_rearrange, ein_reduce, ein_repeat, _enumerate_directions +from brainpy._src.math.einops_parsing import EinopsError + +REDUCTIONS = ("min", "max", "sum", "mean", "prod") + +identity_patterns = [ + "...->...", + "a b c d e-> a b c d e", + "a b c d e ...-> ... a b c d e", + "a b c d e ...-> a ... b c d e", + "... a b c d e -> ... a b c d e", + "a ... e-> a ... e", + "a ... -> a ... ", + "a ... c d e -> a (...) c d e", +] + +equivalent_rearrange_patterns = [ + ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "), + ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), + ("a b c d e -> a b c d e", "... -> ... "), + ("a b c d e -> (a b c d e)", "... -> (...)"), + ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"), + ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"), +] + +equivalent_reduction_patterns = [ + ("a b c d e -> ", " ... -> "), + ("a b c d e -> (e a)", "a ... e -> (e a)"), + ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "), + ("a b c d e -> (a b)", " ... c d e -> (...) "), +] + + +def test_collapsed_ellipsis_errors_out(): + x = numpy.zeros([1, 1, 1, 1, 1]) + ein_rearrange(x, "a b c d ... -> a b c ... d") + with pytest.raises(EinopsError): + ein_rearrange(x, "a b c d (...) -> a b c ... d") + + ein_rearrange(x, "... -> (...)") + with pytest.raises(EinopsError): + ein_rearrange(x, "(...) -> (...)") + + +def test_ellipsis_ops_numpy(): + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for pattern in identity_patterns: + assert numpy.array_equal(x, ein_rearrange(x, pattern)), pattern + + for pattern1, pattern2 in equivalent_rearrange_patterns: + assert numpy.array_equal(ein_rearrange(x, pattern1), ein_rearrange(x, pattern2)) + + for reduction in ["min", "max", "sum"]: + for pattern1, pattern2 in equivalent_reduction_patterns: + assert numpy.array_equal(ein_reduce(x, pattern1, reduction=reduction), + ein_reduce(x, pattern2, reduction=reduction)) + + # now just check coincidence with numpy + all_rearrange_patterns = [*identity_patterns] + for pattern_pairs in equivalent_rearrange_patterns: + all_rearrange_patterns.extend(pattern_pairs) + + +def test_rearrange_consistency_numpy(): + shape = [1, 2, 3, 5, 7, 11] + x = numpy.arange(numpy.prod(shape)).reshape(shape) + for pattern in [ + "a b c d e f -> a b c d e f", + "b a c d e f -> a b d e f c", + "a b c d e f -> f e d c b a", + "a b c d e f -> (f e) d (c b a)", + "a b c d e f -> (f e d c b a)", + ]: + result = ein_rearrange(x, pattern) + assert len(numpy.setdiff1d(x, result)) == 0 + + result = ein_rearrange(x, "a b c d e f -> a (b) (c d e) f") + assert numpy.array_equal(x.flatten(), result.flatten()) + + result = ein_rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11") + assert numpy.array_equal(x, result) + + result1 = ein_rearrange(x, "a b c d e f -> f e d c b a") + result2 = ein_rearrange(x, "f e d c b a -> a b c d e f") + assert numpy.array_equal(result1, result2) + + result = ein_rearrange(ein_rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, d=5) + assert numpy.array_equal(x, result) + + sizes = dict(zip("abcdef", shape)) + temp = ein_rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes) + result = ein_rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes) + assert numpy.array_equal(x, result) + + x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) + result = ein_rearrange(x2, "a b c -> b c a") + assert x2[1, 2, 3] == result[2, 3, 1] + assert x2[0, 1, 2] == result[1, 2, 0] + + +def test_rearrange_permutations_numpy(): + # tests random permutation of 
axes against two independent numpy ways + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = " ".join("i" + str(axis) for axis in range(n_axes)) + right_expression = " ".join("i" + str(axis) for axis in permutation) + expression = left_expression + " -> " + right_expression + result = ein_rearrange(input, expression) + + for pick in numpy.random.randint(0, 2, [10, n_axes]): + assert input[tuple(pick)] == result[tuple(pick[permutation])] + + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1]) + right_expression = " ".join("i" + str(axis) for axis in permutation[::-1]) + expression = left_expression + " -> " + right_expression + result = ein_rearrange(input, expression) + assert result.shape == input.shape + expected_result = numpy.zeros_like(input) + for original_axis, result_axis in enumerate(permutation): + expected_result |= ((input >> original_axis) & 1) << result_axis + + assert numpy.array_equal(result, expected_result) + + +def test_reduction_imperatives(): + for reduction in REDUCTIONS: + # slight redundancy for simpler order - numpy version is evaluated multiple times + input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6]) + if reduction in ["mean", "prod"]: + input = input / input.astype("float64").mean() + test_cases = [ + ["a b c d e -> ", {}, getattr(input, reduction)()], + ["a ... -> ", {}, getattr(input, reduction)()], + ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()], + [ + "a b c d e -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + [ + "a ... c d e -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + [ + "a b c d e ... -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], + ["(a a2) ... -> (a2 a) ...", dict(a2=1), input], + ] + for pattern, axes_lengths, expected_result in test_cases: + result = ein_reduce(bm.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) + result = bm.as_numpy(result) + print(reduction, pattern, expected_result, result) + assert numpy.allclose(result, expected_result), f"Failed at {pattern}" + + +def test_enumerating_directions(): + for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: + x = numpy.arange(numpy.prod(shape)).reshape(shape) + axes1 = _enumerate_directions(x) + axes2 = _enumerate_directions(bm.from_numpy(x)) + assert len(axes1) == len(axes2) == len(shape) + for ax1, ax2 in zip(axes1, axes2): + ax2 = bm.as_numpy(ax2) + assert ax1.shape == ax2.shape + assert numpy.allclose(ax1, ax2) + + +def test_concatenations_and_stacking(): + for n_arrays in [1, 2, 5]: + shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] + for shape in shapes: + arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] + arrays2 = [bm.from_numpy(array) for array in arrays1] + result0 = numpy.asarray(arrays1) + result1 = ein_rearrange(arrays1, "...->...") + result2 = ein_rearrange(arrays2, "...->...") + assert numpy.array_equal(result0, result1) + assert numpy.array_equal(result1, bm.as_numpy(result2)) + + result1 = ein_rearrange(arrays1, "b ... -> ... 
b") + result2 = ein_rearrange(arrays2, "b ... -> ... b") + assert numpy.array_equal(result1, bm.as_numpy(result2)) + + +def test_gradients_imperatives(): + # lazy - just checking reductions + for reduction in REDUCTIONS: + if reduction in ("any", "all"): + continue # non-differentiable ops + x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32") + y0 = bm.from_numpy(x) + if not hasattr(y0, "grad"): + continue + + y1 = ein_reduce(y0, "a b c -> c a", reduction=reduction) + y2 = ein_reduce(y1, "c a -> a c", reduction=reduction) + y3 = ein_reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2) + y4 = ein_reduce(y3, "... -> ", reduction=reduction) + + y4.backward() + grad = bm.as_numpy(y0.grad) + + +def test_tiling_imperatives(): + input = numpy.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5]) + test_cases = [ + (1, 1, 1, 1, 1), + (1, 2, 1, 3, 1), + (3, 1, 1, 4, 1), + ] + for repeats in test_cases: + expected = numpy.tile(input, repeats) + converted = bm.from_numpy(input) + repeated = bm.tile(converted, repeats) + result = bm.as_numpy(repeated) + assert numpy.array_equal(result, expected) + + +repeat_test_cases = [ + # all assume that input has shape [2, 3, 5] + ("a b c -> c a b", dict()), + ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)), + ("a b c -> (a copy) b c ", dict(copy=1)), + ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)), + ("a ... -> a ... copy", dict(copy=4)), + ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)), + ("... -> ... ", dict()), + (" ... -> copy1 ... copy2 ", dict(copy1=2, copy2=3)), + ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)), +] + + +def check_reversion(x, repeat_pattern, **sizes): + """Checks repeat pattern by running reduction""" + left, right = repeat_pattern.split("->") + reduce_pattern = right + "->" + left + repeated = ein_repeat(x, repeat_pattern, **sizes) + reduced_min = ein_reduce(repeated, reduce_pattern, reduction="min", **sizes) + reduced_max = ein_reduce(repeated, reduce_pattern, reduction="max", **sizes) + assert numpy.array_equal(x, reduced_min) + assert numpy.array_equal(x, reduced_max) + + +def test_repeat_numpy(): + # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + x1 = ein_repeat(x, "a b c -> copy a b c ", copy=1) + assert numpy.array_equal(x[None], x1) + for pattern, axis_dimensions in repeat_test_cases: + check_reversion(x, pattern, **axis_dimensions) + + +test_cases_repeat_anonymous = [ + # all assume that input has shape [1, 2, 4, 6] + ("a b c d -> c a d b", dict()), + ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)), + ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)), + ("1 ... -> 3 ... ", dict()), + ("() ... d -> 1 (copy1 d copy2) ... ", dict(copy1=2, copy2=3)), + ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()), +] + + +def test_anonymous_axes(): + x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) + for pattern, axis_dimensions in test_cases_repeat_anonymous: + check_reversion(x, pattern, **axis_dimensions) + + +def test_list_inputs(): + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + + assert numpy.array_equal( + ein_rearrange(list(x), "... -> (...)"), + ein_rearrange(x, "... -> (...)"), + ) + assert numpy.array_equal( + ein_reduce(list(x), "a ... e -> (...)", "min"), + ein_reduce(x, "a ... e -> (...)", "min"), + ) + assert numpy.array_equal( + ein_repeat(list(x), "... -> b (...)", b=3), + ein_repeat(x, "... 
-> b (...)", b=3), + ) + + +def bit_count(x): + return sum((x >> i) & 1 for i in range(20)) + + +def test_reduction_imperatives_booleans(): + """Checks that any/all reduction works in all frameworks""" + x_np = numpy.asarray([(bit_count(x) % 2) == 0 for x in range(2 ** 6)]).reshape([2] * 6) + + for axis in range(6): + expected_result_any = numpy.any(x_np, axis=axis, keepdims=True) + expected_result_all = numpy.all(x_np, axis=axis, keepdims=True) + assert not numpy.array_equal(expected_result_any, expected_result_all) + + axes = list("abcdef") + axes_in = list(axes) + axes_out = list(axes) + axes_out[axis] = "1" + pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out)) + + res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") + res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") + + assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) + assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) + + # expected result: any/all + expected_result_any = numpy.any(x_np, axis=(0, 1), keepdims=True) + expected_result_all = numpy.all(x_np, axis=(0, 1), keepdims=True) + pattern = "a b ... -> 1 1 ..." + res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") + res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") + assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) + assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) diff --git a/brainpy/_src/math/tests/test_einops_parsing.py b/brainpy/_src/math/tests/test_einops_parsing.py new file mode 100644 index 000000000..069c7bbac --- /dev/null +++ b/brainpy/_src/math/tests/test_einops_parsing.py @@ -0,0 +1,111 @@ +import pytest + +from brainpy._src.math.einops_parsing import EinopsError, ParsedExpression, AnonymousAxis, _ellipsis + + +class AnonymousAxisPlaceholder: + def __init__(self, value: int): + self.value = value + assert isinstance(self.value, int) + + def __eq__(self, other): + return isinstance(other, AnonymousAxis) and self.value == other.value + + +def test_anonymous_axes(): + a, b = AnonymousAxis('2'), AnonymousAxis('2') + assert a != b + c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3) + assert a == c and b == c + assert a != d and b != d + assert [a, 2, b] == [c, 2, c] + + +def test_elementary_axis_name(): + for name in ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname', + 'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']: + assert ParsedExpression.check_axis_name(name) + + for name in ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]: + assert not ParsedExpression.check_axis_name(name) + + +def test_invalid_expressions(): + # double ellipsis should raise an error + ParsedExpression('... a b c d') + with pytest.raises(EinopsError): + ParsedExpression('... a b c d ...') + with pytest.raises(EinopsError): + ParsedExpression('... a b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(... 
a) b c (d ...)') + + # double/missing/enclosed parenthesis + ParsedExpression('(a) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a)) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a) (()) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a) ((b c) (d ...))') + + # invalid identifiers + ParsedExpression('camelCase under_scored cApiTaLs ß ...') + with pytest.raises(EinopsError): + ParsedExpression('1a') + with pytest.raises(EinopsError): + ParsedExpression('_pre') + with pytest.raises(EinopsError): + ParsedExpression('...pre') + with pytest.raises(EinopsError): + ParsedExpression('pre...') + + +def test_parse_expression(): + parsed = ParsedExpression('a1 b1 c1 d1') + assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'} + assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('() () () ()') + assert parsed.identifiers == set() + assert parsed.composition == [[], [], [], []] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('1 1 1 ()') + assert parsed.identifiers == set() + assert parsed.composition == [[], [], [], []] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + aap = AnonymousAxisPlaceholder + + parsed = ParsedExpression('5 (3 4)') + assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5} + assert parsed.composition == [[aap(5)], [aap(3), aap(4)]] + assert parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('5 1 (1 4) 1') + assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5} + assert parsed.composition == [[aap(5)], [], [aap(4)], []] + + parsed = ParsedExpression('name1 ... a1 12 (name2 14)') + assert len(parsed.identifiers) == 6 + assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 + assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]] + assert parsed.has_non_unitary_anonymous_axes + assert parsed.has_ellipsis + assert not parsed.has_ellipsis_parenthesized + + parsed = ParsedExpression('(name1 ... 
a1 12) name2 14') + assert len(parsed.identifiers) == 6 + assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 + assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]] + assert parsed.has_non_unitary_anonymous_axes + assert parsed.has_ellipsis + assert parsed.has_ellipsis_parenthesized diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index cf7a766b4..02f671345 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -8,6 +8,7 @@ from .compat_numpy import * from .compat_tensorflow import * from .compat_pytorch import * +from .einops import * # functions from .activations import * diff --git a/brainpy/math/einops.py b/brainpy/math/einops.py new file mode 100644 index 000000000..5dcb4ce67 --- /dev/null +++ b/brainpy/math/einops.py @@ -0,0 +1,6 @@ +from brainpy._src.math.einops import ( + ein_repeat as ein_repeat, + ein_shape as ein_shape, + ein_reduce as ein_reduce, + ein_rearrange as ein_rearrange, +) diff --git a/brainpy/math/interoperability.py b/brainpy/math/interoperability.py index f6356bca7..6956f9ba2 100644 --- a/brainpy/math/interoperability.py +++ b/brainpy/math/interoperability.py @@ -6,6 +6,7 @@ as_ndarray as as_ndarray, as_numpy as as_numpy, as_variable as as_variable, + from_numpy as from_numpy, is_bp_array as is_bp_array, ) diff --git a/docs/tutorial_math/einops_in_brainpy.ipynb b/docs/tutorial_math/einops_in_brainpy.ipynb new file mode 100644 index 000000000..2489d6bae --- /dev/null +++ b/docs/tutorial_math/einops_in_brainpy.ipynb @@ -0,0 +1,1509 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Array operations with ``ein_rearrange``, ``ein_reduce``, and ``ein_repeat``\n", + "\n", + "We don't write \n", + "```python\n", + "y = x.transpose(0, 2, 3, 1)\n", + "```\n", + "We write comprehensible code\n", + "```python\n", + "y = bm.ein_rearrange(x, 'b c h w -> b h w c')\n", + "```\n", + "\n", + "\n", + "## What's in this tutorial?\n", + "\n", + "- fundamentals: reordering, composition and decomposition of axes\n", + "- operations: `ein_rearrange`, `ein_reduce`, `ein_repeat`\n", + "- how much you can do with a single operation!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preparations" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:51.896023200Z", + "start_time": "2024-01-09T03:16:49.966551200Z" + } + }, + "outputs": [], + "source": [ + "# Examples are given for numpy. This code also setups ipython/jupyter\n", + "# so that numpy arrays in the output are displayed as images\n", + "import numpy\n", + "\n", + "import brainpy.math as bm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load a batch of images to play with" + ] + }, + { + "cell_type": "markdown", + "source": [ + "Please download [the data](./test_images.npy)." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:51.903282300Z", + "start_time": "2024-01-09T03:16:51.898250400Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(6, 96, 96, 3) float64\n" + ] + } + ], + "source": [ + "ims = numpy.load('./test_images.npy', allow_pickle=False)\n", + "# There are 6 images of shape 96x96 with 3 color channels packed into tensor\n", + "print(ims.shape, ims.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:51.910514400Z", + "start_time": "2024-01-09T03:16:51.905419300Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3)" + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# display the first image (whole 4d tensor can't be rendered)\n", + "ims[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:51.916049400Z", + "start_time": "2024-01-09T03:16:51.912295Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3)" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# second image in a batch\n", + "ims[1].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:51.987415500Z", + "start_time": "2024-01-09T03:16:51.917288700Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3)" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# rearrange, as its name suggests, rearranges elements\n", + "# below we swapped height and width.\n", + "# In other words, transposed first two axes (dimensions)\n", + "bm.ein_rearrange(ims[0], 'h w c -> w h c').shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Composition of axes\n", + "transposition is very common and useful, but let's move to other capabilities provided by einops" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.001062900Z", + "start_time": "2024-01-09T03:16:51.984159900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(576, 96, 3)" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# einops allows seamlessly composing batch and height to a new height dimension\n", + "# We just rendered all images by collapsing to 3d tensor!\n", + "bm.ein_rearrange(ims, 'b h w c -> (b h) w c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.043645400Z", + "start_time": "2024-01-09T03:16:52.002184500Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# or compose a new dimension of batch and width\n", + "bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.044717500Z", + "start_time": 
"2024-01-09T03:16:52.032578100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# resulting dimensions are computed very simply\n", + "# length of newly composed axis is a product of components\n", + "# [6, 96, 96, 3] -> [96, (6 * 96), 3]\n", + "bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.059635400Z", + "start_time": "2024-01-09T03:16:52.039293900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(165888,)" + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# we can compose more than two axes. \n", + "# let's flatten 4d array into 1d, resulting array has as many elements as the original\n", + "bm.ein_rearrange(ims, 'b h w c -> (b h w c)').shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Decomposition of axis" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.104413Z", + "start_time": "2024-01-09T03:16:52.056324200Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(2, 3, 96, 96, 3)" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# decomposition is the inverse process - represent an axis as a combination of new axes\n", + "# several decompositions possible, so b1=2 is to decompose 6 to b1=2 and b2=3\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> b1 b2 h w c ', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.136340300Z", + "start_time": "2024-01-09T03:16:52.073847300Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 288, 3)" + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# finally, combine composition and decomposition:\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> (b1 h) (b2 w) c ', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.165079200Z", + "start_time": "2024-01-09T03:16:52.106539200Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(288, 192, 3)" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# slightly different composition: b1 is merged with width, b2 with height\n", + "# ... so letters are ordered by w then by h\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> (b2 h) (b1 w) c ', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.199903Z", + "start_time": "2024-01-09T03:16:52.144629900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 288, 3)" + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# move part of width dimension to height. \n", + "# we should call this width-to-height as image width shrunk by 2 and height doubled. 
\n", + "# but all pixels are the same!\n", + "# Can you write reverse operation (height-to-width)?\n", + "bm.ein_rearrange(ims, 'b h (w w2) c -> (h w2) (b w) c', w2=2).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Order of axes matters" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.200972800Z", + "start_time": "2024-01-09T03:16:52.190142300Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# compare with the next example\n", + "bm.ein_rearrange(ims, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.250337300Z", + "start_time": "2024-01-09T03:16:52.196592800Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# order of axes in composition is different\n", + "# rule is just as for digits in the number: leftmost digit is the most significant, \n", + "# while neighboring numbers differ in the rightmost axis.\n", + "\n", + "# you can also think of this as lexicographic sort\n", + "bm.ein_rearrange(ims, 'b h w c -> h (w b) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.277698500Z", + "start_time": "2024-01-09T03:16:52.228269800Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# what if b1 and b2 are reordered before composing to width?\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> h (b1 b2 w) c ', b1=2).shape " + ] + }, + { + "cell_type": "code", + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bm.ein_rearrange(ims, '(b1 b2) h w c -> h (b2 b1 w) c ', b1=2).shape " + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.314368100Z", + "start_time": "2024-01-09T03:16:52.262594800Z" + } + }, + "execution_count": 17 + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Meet einops.reduce\n", + "\n", + "In einops-land you don't need to guess what happened\n", + "```python\n", + "x.mean(-1)\n", + "```\n", + "Because you write what the operation does\n", + "```python\n", + "bm.ein_reduce(x, 'b h w c -> b h w', 'mean')\n", + "```\n", + "\n", + "if axis is not present in the output — you guessed it — axis was reduced." 
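, + "\n", + "Besides the built-in reductions such as 'min', 'max', 'sum', 'mean' and 'prod', `ein_reduce` also accepts a callable that receives the array and the tuple of axes to reduce (mirroring einops). A minimal sketch (the root-mean-square below is our own illustration, not part of this tutorial's data):\n", + "```python\n", + "# custom callable reduction: root-mean-square over the channel axis\n", + "rms = bm.ein_reduce(ims, 'b h w c -> b h w', lambda x, axes: (x ** 2).mean(axes) ** 0.5)\n", + "rms.shape  # expected: (6, 96, 96)\n", + "```"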
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.354728900Z", + "start_time": "2024-01-09T03:16:52.298014600Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3)" + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# average over batch\n", + "bm.ein_reduce(ims, 'b h w c -> h w c', 'mean').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.355832600Z", + "start_time": "2024-01-09T03:16:52.340237700Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3)" + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# the previous is identical to the familiar:\n", + "ims.mean(axis=0).shape\n", + "# but is so much more readable" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.408044400Z", + "start_time": "2024-01-09T03:16:52.345070800Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96)" + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Example of reducing over several axes\n", + "# besides mean, there are also min, max, sum, prod\n", + "bm.ein_reduce(ims, 'b h w c -> h w', 'min').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.438192700Z", + "start_time": "2024-01-09T03:16:52.365121Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(48, 288, 3)" + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# this is mean-pooling with a 2x2 kernel\n", + "# image is split into 2x2 patches, each patch is averaged\n", + "bm.ein_reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', 'mean', h2=2, w2=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.466068200Z", + "start_time": "2024-01-09T03:16:52.429666600Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(48, 288, 3)" + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# max-pooling is similar\n", + "# result is not as smooth as for mean-pooling\n", + "bm.ein_reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', 'max', h2=2, w2=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.508614800Z", + "start_time": "2024-01-09T03:16:52.453429200Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(288, 192)" + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# yet another example.
Can you compute the result shape?\n", + "bm.ein_reduce(ims, '(b1 b2) h w c -> (b2 h) (b1 w)', 'mean', b1=2).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "## Stack and concatenate" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.509704200Z", + "start_time": "2024-01-09T03:16:52.486964100Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<class 'list'> with 6 tensors of shape (96, 96, 3)\n" + ] + }, + { + "data": { + "text/plain": "[(96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3), (96, 96, 3)]" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# rearrange can also take care of lists of arrays with the same shape\n", + "x = list(ims)\n", + "print(type(x), 'with', len(x), 'tensors of shape', x[0].shape)\n", + "# that's how we can stack inputs\n", + "# the \"list axis\" becomes the first (\"b\" in this case), and we left it there\n", + "res = bm.ein_rearrange(x, 'b h w c -> b h w c')\n", + "\n", + "[r.shape for r in res]" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.524732200Z", + "start_time": "2024-01-09T03:16:52.495686100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 96, 3, 6)" + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# but the new axis can appear in another place:\n", + "bm.ein_rearrange(x, 'b h w c -> h w c b').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.528015200Z", + "start_time": "2024-01-09T03:16:52.511870500Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "False" + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# that's equivalent to numpy stacking, but written more explicitly\n", + "numpy.array_equal(bm.ein_rearrange(x, 'b h w c -> h w c b'), numpy.stack(x, axis=3))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.586497800Z", + "start_time": "2024-01-09T03:16:52.517938100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ... or we can concatenate along axes\n", + "bm.ein_rearrange(x, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.589607600Z", + "start_time": "2024-01-09T03:16:52.524732200Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "False" + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# which is equivalent to concatenation\n", + "numpy.array_equal(bm.ein_rearrange(x, 'b h w c -> h (b w) c'), numpy.concatenate(x, axis=1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Addition or removal of axes\n", + "\n", + "You can write 1 to create a new axis of length 1.
Similarly you can remove such axis.\n", + "\n", + "There is also a synonym `()` that you can use. That's a composition of zero axes and it also has a unit length." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.601830300Z", + "start_time": "2024-01-09T03:16:52.531696500Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(6, 1, 96, 96, 1, 3)\n", + "(6, 96, 96, 3)\n" + ] + } + ], + "source": [ + "x = bm.ein_rearrange(ims, 'b h w c -> b 1 h w 1 c') # functionality of numpy.expand_dims\n", + "print(x.shape)\n", + "print(bm.ein_rearrange(x, 'b 1 h w 1 c -> b h w c').shape) # functionality of numpy.squeeze" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.652283400Z", + "start_time": "2024-01-09T03:16:52.562649Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# compute max in each image individually, then show a difference \n", + "x = bm.ein_reduce(ims, 'b h w c -> b () () c', 'max') - ims\n", + "bm.ein_rearrange(x, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Repeating elements\n", + "\n", + "Third operation we introduce is `repeat`" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.708988500Z", + "start_time": "2024-01-09T03:16:52.634965400Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 5, 96, 3)" + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# repeat along a new axis. 
The new axis can be placed anywhere\n", + "bm.ein_repeat(ims[0], 'h w c -> h new_axis w c', new_axis=5).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.714789300Z", + "start_time": "2024-01-09T03:16:52.710069Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 5, 96, 3)" + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# shortcut\n", + "bm.ein_repeat(ims[0], 'h w c -> h 5 w c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.757633Z", + "start_time": "2024-01-09T03:16:52.714789300Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 288, 3)" + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# repeat along w (existing axis)\n", + "bm.ein_repeat(ims[0], 'h w c -> h (repeat w) c', repeat=3).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.853440Z", + "start_time": "2024-01-09T03:16:52.757633Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 192, 3)" + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# repeat along two existing axes\n", + "bm.ein_repeat(ims[0], 'h w c -> (2 h) (2 w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:52.935098900Z", + "start_time": "2024-01-09T03:16:52.853440Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 288, 3)" + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# order of axes matters as usual - you can repeat each element (pixel) 3 times\n", + "# by changing the order in the parentheses\n", + "bm.ein_repeat(ims[0], 'h w c -> h (w repeat) c', repeat=3).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: the `repeat` operation covers the functionality of `numpy.repeat` and `numpy.tile`, and actually more than that." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Reduce ⇆ repeat\n", + "\n", + "reduce and repeat are opposites of each other: the first one reduces the number of elements, the second one increases it.\n", + "\n", + "In the following example each image is repeated first, then we reduce over the new axis to get back the original tensor. Notice that the operation patterns are the \"reverse\" of each other" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.086847800Z", + "start_time": "2024-01-09T03:16:52.936595200Z" + } + }, + "outputs": [], + "source": [ + "repeated = bm.ein_repeat(ims, 'b h w c -> b h new_axis w c', new_axis=2)\n", + "reduced = bm.ein_reduce(repeated, 'b h new_axis w c -> b h w c', 'min')\n", + "\n", + "\n", + "assert bm.allclose(ims, reduced)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fancy examples in random order\n", + "\n", + "(a.k.a.
mad designer gallery)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.124865300Z", + "start_time": "2024-01-09T03:16:53.089018Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 288, 3)" + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# interweaving pixels of different pictures\n", + "# all letters are observable\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> (h b1) (w b2) c ', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.139588200Z", + "start_time": "2024-01-09T03:16:53.123858300Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 288, 3)" + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# interweaving along vertical for couples of images\n", + "bm.ein_rearrange(ims, '(b1 b2) h w c -> (h b1) (b2 w) c', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.186247700Z", + "start_time": "2024-01-09T03:16:53.140592800Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 288, 3)" + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# interweaving lines for couples of images\n", + "# exercise: achieve the same result without einops in your favourite framework\n", + "bm.ein_reduce(ims, '(b1 b2) h w c -> h (b2 w) c', 'max', b1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.232730900Z", + "start_time": "2024-01-09T03:16:53.178674500Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(144, 288)" + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# color can be also composed into dimension\n", + "# ... 
while the image is downsampled\n", + "bm.ein_reduce(ims, 'b (h 2) (w 2) c -> (c h) (b w)', 'mean').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.302503900Z", + "start_time": "2024-01-09T03:16:53.236495100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(24, 192)" + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# disproportionate resize\n", + "bm.ein_reduce(ims, 'b (h 4) (w 3) c -> (h) (b w)', 'mean').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.365480400Z", + "start_time": "2024-01-09T03:16:53.303630100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(48, 576)" + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# split each image into two halves, compute the mean of the two\n", + "bm.ein_reduce(ims, 'b (h1 h2) w c -> h2 (b w)', 'mean', h1=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.413333100Z", + "start_time": "2024-01-09T03:16:53.364414400Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# split into small patches and transpose each patch\n", + "bm.ein_rearrange(ims, 'b (h1 h2) (w1 w2) c -> (h1 w2) (b w1 h2) c', h2=8, w2=8).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.499062100Z", + "start_time": "2024-01-09T03:16:53.407925200Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# stop me someone!\n", + "bm.ein_rearrange(ims, 'b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c', h2=2, w2=2, w3=2, h3=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.546329400Z", + "start_time": "2024-01-09T03:16:53.459186600Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(192, 288, 3)" + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bm.ein_rearrange(ims, '(b1 b2) (h1 h2) (w1 w2) c -> (h1 b1 h2) (w1 b2 w2) c', h1=3, w1=3, b2=3).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.587041200Z", + "start_time": "2024-01-09T03:16:53.505732100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# patterns can be arbitrarily complicated\n", + "bm.ein_reduce(ims, '(b1 b2) (h1 h2 h3) (w1 w2 w3) c -> (h1 w1 h3) (b1 w2 h2 w3 b2) c', 'mean', \n", + " h2=2, w1=2, w3=2, h3=2, b2=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.608899300Z", + "start_time": "2024-01-09T03:16:53.556416400Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# subtract the background in each image
individually and normalize\n", + "# pay attention to () - this is a composition of zero axes, a dummy axis with 1 element.\n", + "im2 = bm.ein_reduce(ims, 'b h w c -> b () () c', 'max') - ims\n", + "im2 /= bm.ein_reduce(im2, 'b h w c -> b () () c', 'max')\n", + "bm.ein_rearrange(im2, 'b h w c -> h (b w) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.742684900Z", + "start_time": "2024-01-09T03:16:53.578494900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# pixelate: first downscale by averaging, then upscale back using the same pattern\n", + "averaged = bm.ein_reduce(ims, 'b (h h2) (w w2) c -> b h w c', 'mean', h2=6, w2=8)\n", + "bm.ein_repeat(averaged, 'b h w c -> (h h2) (b w w2) c', h2=6, w2=8).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.783169200Z", + "start_time": "2024-01-09T03:16:53.742684900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576, 3)" + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c').shape" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T03:16:53.827528Z", + "start_time": "2024-01-09T03:16:53.765960100Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "(96, 576)" + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# let's bring the color dimension in as part of the horizontal axis\n", + "# at the same time the spatial axes are downsampled by 3x\n", + "bm.ein_reduce(ims, 'b (h h2) (w w2) c -> (h w2) (b w c)', 'mean', h2=3, w2=3).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ok, numpy is fun, but how do I use einops with some other framework?\n", + "\n", + "If this is what you've done with `ims` being a numpy array:\n", + "```python\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "```\n", + "That's how you adapt the code for other frameworks:\n", + "\n", + "```python\n", + "# pytorch:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "# tensorflow:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "# chainer:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "# gluon:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "# cupy:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "# jax:\n", + "bm.ein_rearrange(ims, 'b h w c -> w (b h) c')\n", + "\n", + "...well, you got the idea.\n", + "```\n", + "\n", + "Einops allows backpropagation as if all operations were native to the framework.\n", + "Operations do not change when moving to another framework - einops notation is universal" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Summary\n", + "\n", + "- `rearrange` doesn't change the number of elements and covers different numpy functions (like `transpose`, `reshape`, `stack`, `concatenate`, `squeeze` and `expand_dims`)\n", + "- `reduce` combines the same reordering syntax with reductions (`mean`, `min`, `max`, `sum`, `prod`, and any others)\n", + "- `repeat` additionally covers repeating and tiling\n", + "- composition and decomposition of axes are a cornerstone; they can
and should be used together\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorial_math/index.rst b/docs/tutorial_math/index.rst index 6ad69939d..d5b764761 100644 --- a/docs/tutorial_math/index.rst +++ b/docs/tutorial_math/index.rst @@ -8,3 +8,4 @@ Math Foundation control_flows Numpy_like_Operations.ipynb Dedicated_Operators.ipynb + einops_in_brainpy.ipynb diff --git a/docs/tutorial_math/test_images.npy b/docs/tutorial_math/test_images.npy new file mode 100644 index 000000000..bbff7bd9b Binary files /dev/null and b/docs/tutorial_math/test_images.npy differ diff --git a/setup.py b/setup.py index b9f51dd6b..d03fd91fd 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import io import os import re +import time from setuptools import find_packages from setuptools import setup @@ -26,6 +27,7 @@ except ModuleNotFoundError: pass + # version here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'brainpy', '__init__.py'), 'r') as f: @@ -42,7 +44,7 @@ # setup setup( name='brainpy', - version=version, + version=version + '.post{}'.format(time.strftime("%Y%m%d", time.localtime())), description='BrainPy: Brain Dynamics Programming in Python', long_description=README, long_description_content_type="text/markdown",