Skip to content

Commit

Permalink
turn numpy to jax numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Jul 31, 2024
1 parent 684da39 commit 54976f4
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/qutip_jax/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from . import JaxArray
import qutip
from functools import partial
import numpy as np

__all__ = [
"reshape_jaxarray",
Expand Down Expand Up @@ -64,14 +63,14 @@ def split_columns_jaxarray(matrix, copy=None):


def _parse_ptrace_inputs(dims, sel, shape):
dims = np.atleast_1d(dims).ravel()
sel = np.atleast_1d(sel)
dims = jnp.atleast_1d(dims).ravel()
sel = jnp.atleast_1d(sel)
sel.sort()

if shape[0] != shape[1]:
raise ValueError("ptrace is only defined for square density matrices")

if shape[0] != np.prod(dims, dtype=int):
if shape[0] != jnp.prod(dims, dtype=int):
raise ValueError(
f"the input matrix shape, {shape} and the"
f" dimension argument, {dims}, are not compatible."
Expand Down Expand Up @@ -110,11 +109,11 @@ def ptrace_jaxarray(matrix, dims, sel):
nd = dims.shape[0]
dims2 = tuple(list(dims) * 2)
sel = list(sel)
qtrace = list(set(np.arange(nd)) - set(sel))
qtrace = list(set(jnp.arange(nd)) - set(sel))


dkeep = np.prod([dims[x] for x in sel], dtype=int)
dtrace = np.prod([dims[x] for x in qtrace], dtype=int)
dkeep = jnp.prod([dims[x] for x in sel], dtype=int)
dtrace = jnp.prod([dims[x] for x in qtrace], dtype=int)

transpose_idx = tuple(
qtrace + [nd + q for q in qtrace]
Expand Down

0 comments on commit 54976f4

Please sign in to comment.