From 8da1271701d6af67553ad7b73cf630e646069d09 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 21 Sep 2023 17:34:50 +0800 Subject: [PATCH] updates --- brainpy/_src/math/object_transform/autograd.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f8164e615..5f06b4e67 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -915,8 +915,8 @@ def _valid_jaxtype(arg): def _check_output_dtype_revderiv(name, holomorphic, x): aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with output element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with output element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, " @@ -937,8 +937,8 @@ def _check_output_dtype_revderiv(name, holomorphic, x): def _check_input_dtype_revderiv(name, holomorphic, allow_int, x): _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"{name} with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"{name} with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, " @@ -970,8 +970,8 @@ def _check_output_dtype_jacfwd(holomorphic, x): def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None: _check_arg(x) aval = core.get_aval(x) - if jnp.issubdtype(aval.dtype, dtypes.extended): - raise TypeError(f"jacfwd with input element type {aval.dtype.name}") + # if jnp.issubdtype(aval.dtype, dtypes.extended): + # raise TypeError(f"jacfwd with input element type {aval.dtype.name}") if holomorphic: if not dtypes.issubdtype(aval.dtype, np.complexfloating): raise TypeError("jacfwd with holomorphic=True requires inputs with complex "