diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 7c2598d..80113b8 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -15,6 +15,7 @@ from __future__ import annotations +import collections from collections.abc import Sequence from typing import (Callable, Union, Tuple, Any, Optional) @@ -120,6 +121,14 @@ def broadcast_shapes(*shapes): """ return jnp.broadcast_shapes(*shapes) +def _default_poly_einsum_handler(*operands, **kwargs): + dummy = collections.namedtuple('dummy', ['shape', 'dtype']) + dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype) + if hasattr(x, 'dtype') else x for x in operands] + mapping = {id(d): i for i, d in enumerate(dummies)} + out_dummies, contractions = opt_einsum.contract_path(*dummies, **kwargs) + contract_operands = [operands[mapping[id(d)]] for d in out_dummies] + return contract_operands, contractions def einsum( subscripts: str, @@ -176,7 +185,6 @@ def einsum( if not non_constant_dim_types: contract_path = opt_einsum.contract_path else: - from jax._src.numpy.lax_numpy import _default_poly_einsum_handler contract_path = _default_poly_einsum_handler operands, contractions = contract_path(*operands, einsum_call=True, use_blas=True, optimize=optimize)