diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f8164e615..5f06b4e67 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -915,8 +915,8 @@ def _valid_jaxtype(arg): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with output element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with output element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " @@ -937,8 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x): def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " @@ -970,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"jacfwd with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError("jacfwd with holomorphic=True requires inputs with complex "