Skip to content

Commit

Permalink
dtype checking during exponential euler method
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 23, 2024
1 parent e0ee142 commit 05394a2
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions brainpy/_src/integrators/ode/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@
"""

from functools import wraps

import jax.numpy as jnp

from brainpy import errors
from brainpy._src import math as bm
from brainpy._src.integrators import constants as C, utils, joint_eq
Expand Down Expand Up @@ -356,6 +359,9 @@ def _build_integrator(self, eq):
# integration function
def integral(*args, **kwargs):
assert len(args) > 0
if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.'
f'But we got {args[0].dtype}.')
dt = kwargs.pop(C.DT, self.dt)
linear, derivative = value_and_grad(*args, **kwargs)
phi = bm.exprel(dt * linear)
Expand Down

0 comments on commit 05394a2

Please sign in to comment.