Skip to content

Commit

Permalink
Update _compat_numpy_misc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 20, 2024
1 parent cf60b41 commit 01c75cd
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions brainunit/math/_compat_numpy_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def broadcast_shapes(*shapes):
"""
return jnp.broadcast_shapes(*shapes)


def _default_poly_einsum_handler(*operands, **kwargs):
dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
Expand All @@ -130,6 +131,7 @@ def _default_poly_einsum_handler(*operands, **kwargs):
contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
return contract_operands, contractions


def einsum(
subscripts: str,
/,
Expand Down Expand Up @@ -209,15 +211,22 @@ def einsum(
else:
if isinstance(operands[i + 1], Quantity):
unit = unit * operands[i + 1].dim
operands = [op.value if isinstance(op, Quantity) else op for op in operands]

contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
r = jnp.einsum(subscripts,
*operands,
precision=precision,
preferred_element_type=preferred_element_type,
_dot_general=_dot_general)

einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
if spec is not None:
einsum = jax.named_call(einsum, name=spec)
operands = [op.value if isinstance(op, Quantity) else op for op in operands]
r = einsum(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)
# contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
#
# einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
# if spec is not None:
# einsum = jax.named_call(einsum, name=spec)

# r = einsum(operands, contractions, precision, # type: ignore[operator]
# preferred_element_type, _dot_general)
if unit is not None:
return Quantity(r, dim=unit)
else:
Expand Down

0 comments on commit 01c75cd

Please sign in to comment.