Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 21, 2023
1 parent 5aed3bd commit 8da1271
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,8 +915,8 @@ def _valid_jaxtype(arg):

def _check_output_dtype_revderiv(name, holomorphic, x):
aval = core.get_aval(x)
if jnp.issubdtype(aval.dtype, dtypes.extended):
raise TypeError(f"{name} with output element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with output element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
Expand All @@ -937,8 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x):
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
_check_arg(x)
aval = core.get_aval(x)
if jnp.issubdtype(aval.dtype, dtypes.extended):
raise TypeError(f"{name} with input element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"{name} with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
Expand Down Expand Up @@ -970,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x):
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
_check_arg(x)
aval = core.get_aval(x)
if jnp.issubdtype(aval.dtype, dtypes.extended):
raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
# if jnp.issubdtype(aval.dtype, dtypes.extended):
# raise TypeError(f"jacfwd with input element type {aval.dtype.name}")
if holomorphic:
if not dtypes.issubdtype(aval.dtype, np.complexfloating):
raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
Expand Down

0 comments on commit 8da1271

Please sign in to comment.