From 05394a27a893210b236d18cc548cc1124c3a06ff Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 23 Mar 2024 13:24:02 +0800 Subject: [PATCH] dtype checking during exponential euler method --- brainpy/_src/integrators/ode/exponential.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index e44e324e7..ec0e10701 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -106,6 +106,9 @@ """ from functools import wraps + +import jax.numpy as jnp + from brainpy import errors from brainpy._src import math as bm from brainpy._src.integrators import constants as C, utils, joint_eq @@ -356,6 +359,9 @@ def _build_integrator(self, eq): # integration function def integral(*args, **kwargs): assert len(args) > 0 + if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: + raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.' + f'But we got {args[0].dtype}.') dt = kwargs.pop(C.DT, self.dt) linear, derivative = value_and_grad(*args, **kwargs) phi = bm.exprel(dt * linear)