diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py
index 80113b8..96ec68c 100644
--- a/brainunit/math/_compat_numpy_misc.py
+++ b/brainunit/math/_compat_numpy_misc.py
@@ -121,6 +121,7 @@ def broadcast_shapes(*shapes):
   """
   return jnp.broadcast_shapes(*shapes)
 
+
 def _default_poly_einsum_handler(*operands, **kwargs):
   dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
   dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
@@ -130,6 +131,7 @@ def _default_poly_einsum_handler(*operands, **kwargs):
   contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
   return contract_operands, contractions
 
+
 def einsum(
     subscripts: str,
     /,
@@ -209,15 +211,22 @@ def einsum(
       else:
         if isinstance(operands[i + 1], Quantity):
           unit = unit * operands[i + 1].dim
+  operands = [op.value if isinstance(op, Quantity) else op for op in operands]
 
-  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
+  r = jnp.einsum(subscripts,
+                 *operands,
+                 precision=precision,
+                 preferred_element_type=preferred_element_type,
+                 _dot_general=_dot_general)
 
-  einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
-  if spec is not None:
-    einsum = jax.named_call(einsum, name=spec)
-  operands = [op.value if isinstance(op, Quantity) else op for op in operands]
-  r = einsum(operands, contractions, precision,  # type: ignore[operator]
-             preferred_element_type, _dot_general)
+  # contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
+  #
+  # einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
+  # if spec is not None:
+  #   einsum = jax.named_call(einsum, name=spec)
+
+  # r = einsum(operands, contractions, precision,  # type: ignore[operator]
+  #            preferred_element_type, _dot_general)
   if unit is not None:
     return Quantity(r, dim=unit)
   else:
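
For reviewers, a minimal usage sketch of the refactored path: unit bookkeeping stays in brainunit's einsum wrapper, while the numerical contraction is now delegated directly to jnp.einsum on the stripped values. Only Quantity, .dim, and brainunit.math.einsum are confirmed by the patch itself; the SI unit names (meter, second), the array-times-unit construction of a Quantity, and the re-export of einsum from brainunit.math are assumptions for illustration.

# Minimal sketch, not part of the patch (assumed API: brainunit exposes units such as
# `meter`/`second`, multiplying an array by a unit yields a Quantity, and `einsum`
# is re-exported from `brainunit.math`).
import jax.numpy as jnp
import brainunit as bu
import brainunit.math as bm

a = jnp.ones((2, 3)) * bu.meter    # Quantity carrying the `meter` unit (assumed constructor)
b = jnp.ones((3, 4)) * bu.second   # Quantity carrying the `second` unit (assumed constructor)

# The wrapper accumulates the product of the operand dimensions, strips the raw
# arrays out of the Quantity operands, and calls jnp.einsum on them directly.
c = bm.einsum('ij,jk->ik', a, b)
print(c.dim)  # expected: the product dimension of meter and second (length * time)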