Skip to content

Commit

Permalink
Merge pull request #20 from Ericgig/feature.jaxdiag
Browse files Browse the repository at this point in the history
Add a jax backend based on dia sparse matrix, `JaxDia`.
  • Loading branch information
Ericgig authored Apr 1, 2024
2 parents de36e0d + 368d4a5 commit 5f1c2bf
Show file tree
Hide file tree
Showing 25 changed files with 1,642 additions and 290 deletions.
16 changes: 11 additions & 5 deletions src/qutip_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
import qutip
from qutip_jax.jaxarray import JaxArray
from .jaxarray import JaxArray
from .jaxdia import JaxDia

from .convert import is_jax_array, jax_from_dense, dense_from_jax
from .convert import *
from .version import version as __version__


# Register the data layer for JAX
qutip.data.to.add_conversions(
[
(JaxArray, qutip.data.Dense, jax_from_dense),
(qutip.data.Dense, JaxArray, dense_from_jax, 2),
(JaxArray, qutip.data.Dense, jaxarray_from_dense),
(qutip.data.Dense, JaxArray, dense_from_jaxarray, 2),
(JaxArray, JaxDia, jaxarray_from_jaxdia),
(JaxDia, JaxArray, jaxdia_from_jaxarray),
(qutip.data.Dia, JaxDia, dia_from_jaxdia),
(JaxDia, qutip.data.Dia, jaxdia_from_dia),
]
)

# User friendly name for conversion with `to` or Qobj creation functions:
qutip.data.to.register_aliases(["jax", "JaxArray"], JaxArray)
qutip.data.to.register_aliases(["jaxdia", "JaxDia"], JaxDia)

qutip.data.create.add_creators(
[
(is_jax_array, JaxArray, 85),
]
)

del is_jax_array

from .binops import *
from .unary import *
Expand Down
Loading

0 comments on commit 5f1c2bf

Please sign in to comment.