Skip to content

Commit

Permalink
filter in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Giguere committed Jan 11, 2024
1 parent 6af5a26 commit efbf4a7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
# to see if something hung.
timeout-minutes: 60
run: |
pytest --durations=0 --durations-min=1.0 --verbosity=1 --cov=qutip_jax --cov-report= --color=yes -W ignore::UserWarning:qutip
pytest --durations=0 --durations-min=1.0 --verbosity=1 --cov=qutip_jax --cov-report= --color=yes -W ignore::UserWarning:qutip -W "ignore:Complex dtype:UserWarning"
# Above flags are:
# --durations=0 --durations-min=1.0
# at the end, show a list of all the tests that took longer than a
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/weekly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
# to see if something hung.
timeout-minutes: 60
run: |
pytest --durations=0 --durations-min=1.0 --verbosity=1 --color=yes -W ignore::UserWarning:qutip
pytest --durations=0 --durations-min=1.0 --verbosity=1 --color=yes -W ignore::UserWarning:qutip -W "ignore:Complex dtype:UserWarning"
# Above flags are:
# --durations=0 --durations-min=1.0
# at the end, show a list of all the tests that took longer than a
Expand Down
27 changes: 10 additions & 17 deletions src/qutip_jax/ode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import diffrax
import warnings
from qutip.solver.integrator import Integrator
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -65,22 +64,16 @@ def get_state(self, copy=False):
return self.t, JaxArray(_float2cplx(self.state))

def integrate(self, t, copy=False, **kwargs):
with warnings.catch_warnings():
# Diffrax added partial support for complex number, but raise a
# warning when it find a complex anywhere in the tree.
warnings.filterwarnings("ignore",
message="Complex dtype support is work in progress,"
)
sol = diffrax.diffeqsolve(
self.ODEsystem,
t0=self.t,
t1=t,
y0=self.state,
saveat=diffrax.SaveAt(t1=True, solver_state=True),
solver_state=self.solver_state,
args=(self.system, kwargs),
**self._options,
)
sol = diffrax.diffeqsolve(
self.ODEsystem,
t0=self.t,
t1=t,
y0=self.state,
saveat=diffrax.SaveAt(t1=True, solver_state=True),
solver_state=self.solver_state,
args=(self.system, kwargs),
**self._options,
)
self.t = t
self.state = sol.ys[0, :]
self.solver_state = sol.solver_state
Expand Down

0 comments on commit efbf4a7

Please sign in to comment.