Skip to content

Commit

Permalink
Update _compat_numpy_misc.py
Browse files Browse the repository at this point in the history
Routhleck committed Jun 14, 2024
1 parent 0e54af2 commit cf60b41
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion brainunit/math/_compat_numpy_misc.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@

from __future__ import annotations

import collections
from collections.abc import Sequence
from typing import (Callable, Union, Tuple, Any, Optional)

@@ -120,6 +121,14 @@ def broadcast_shapes(*shapes):
"""
return jnp.broadcast_shapes(*shapes)

def _default_poly_einsum_handler(*operands, **kwargs):
dummy = collections.namedtuple('dummy', ['shape', 'dtype'])
dummies = [dummy(tuple(d if type(d) is int else 8 for d in x.shape), x.dtype)
if hasattr(x, 'dtype') else x for x in operands]
mapping = {id(d): i for i, d in enumerate(dummies)}
out_dummies, contractions = opt_einsum.contract_path(*dummies, **kwargs)
contract_operands = [operands[mapping[id(d)]] for d in out_dummies]
return contract_operands, contractions

def einsum(
subscripts: str,
@@ -176,7 +185,6 @@ def einsum(
if not non_constant_dim_types:
contract_path = opt_einsum.contract_path
else:
from jax._src.numpy.lax_numpy import _default_poly_einsum_handler
contract_path = _default_poly_einsum_handler

operands, contractions = contract_path(*operands, einsum_call=True, use_blas=True, optimize=optimize)

0 comments on commit cf60b41

Please sign in to comment.