diff --git a/src/qutip_jax/__init__.py b/src/qutip_jax/__init__.py index 0857251..d92cd5e 100644 --- a/src/qutip_jax/__init__.py +++ b/src/qutip_jax/__init__.py @@ -1,27 +1,33 @@ import qutip -from qutip_jax.jaxarray import JaxArray +from .jaxarray import JaxArray +from .jaxdia import JaxDia -from .convert import is_jax_array, jax_from_dense, dense_from_jax +from .convert import * from .version import version as __version__ # Register the data layer for JAX qutip.data.to.add_conversions( [ - (JaxArray, qutip.data.Dense, jax_from_dense), - (qutip.data.Dense, JaxArray, dense_from_jax, 2), + (JaxArray, qutip.data.Dense, jaxarray_from_dense), + (qutip.data.Dense, JaxArray, dense_from_jaxarray, 2), + (JaxArray, JaxDia, jaxarray_from_jaxdia), + (JaxDia, JaxArray, jaxdia_from_jaxarray), + (qutip.data.Dia, JaxDia, dia_from_jaxdia), + (JaxDia, qutip.data.Dia, jaxdia_from_dia), ] ) # User friendly name for conversion with `to` or Qobj creation functions: qutip.data.to.register_aliases(["jax", "JaxArray"], JaxArray) +qutip.data.to.register_aliases(["jaxdia", "JaxDia"], JaxDia) qutip.data.create.add_creators( [ (is_jax_array, JaxArray, 85), ] ) - +del is_jax_array from .binops import * from .unary import * diff --git a/src/qutip_jax/binops.py b/src/qutip_jax/binops.py index f9d762d..9bef827 100644 --- a/src/qutip_jax/binops.py +++ b/src/qutip_jax/binops.py @@ -1,14 +1,26 @@ import qutip from .jaxarray import JaxArray +from .jaxdia import JaxDia, clean_dia import jax.numpy as jnp +import jax +from jax import vmap, jit +from functools import partial __all__ = [ "add_jaxarray", + "add_jaxdia", "sub_jaxarray", + "sub_jaxdia", "mul_jaxarray", + "mul_jaxdia", "matmul_jaxarray", + "matmul_jaxdia", + "matmul_jaxdia_jaxarray_jaxarray", + "matmul_jaxarray_jaxdia_jaxarray", "multiply_jaxarray", + "multiply_jaxdia", "kron_jaxarray", + "kron_jaxdia", "pow_jaxarray", ] @@ -16,16 +28,16 @@ def _check_same_shape(left, right): if left.shape != right.shape: raise ValueError( - f"""Incompatible shapes for addition of two matrices: - left={left.shape} and right={right.shape}""" + "Incompatible shapes for addition of two matrices: " + f"left={left.shape} and right={right.shape}" + "" ) def _check_matmul_shape(left, right, out): if left.shape[1] != right.shape[0]: raise ValueError( - "incompatible matrix shapes " + str(left.shape) - + " and " + str(right.shape) + f"incompatible matrix shapes {left.shape} and {right.shape}" ) if ( out is not None @@ -33,14 +45,13 @@ def _check_matmul_shape(left, right, out): and out.shape[1] != right.shape[1] ): raise ValueError( - "incompatible output shape, got " - + str(out.shape) - + " but needed " + f"incompatible output shape, got {out.shape}, but needed " + str((left.shape[0], right.shape[1])) ) -def add_jaxarray(left, right, scale=1): +@jit +def add_jaxarray(left, right, scale=None): """ Perform the operation left + scale*right @@ -49,7 +60,7 @@ def add_jaxarray(left, right, scale=1): """ _check_same_shape(left, right) - if scale == 1 and isinstance(scale, int): + if scale == None: out = JaxArray._fast_constructor( left._jxa + right._jxa, shape=left.shape ) @@ -60,6 +71,51 @@ def add_jaxarray(left, right, scale=1): return out +@jit +def add_jaxdia(left, right, scale=None): + """ + Perform the operation + left + scale*right + where `left` and `right` are matrices, and `scale` is an optional complex + scalar. + """ + _check_same_shape(left, right) + diag_left = 0 + diag_right = 0 + data = [] + offsets = [] + + all_diag = set(left.offsets) | set(right.offsets) + + for diag in all_diag: + if diag in left.offsets and diag in right.offsets: + diag_left = left.offsets.index(diag) + diag_right = right.offsets.index(diag) + offsets.append(diag) + if scale is None: + data.append(left.data[diag_left, :] + right.data[diag_right, :]) + else: + data.append( + left.data[diag_left, :] + right.data[diag_right, :] * scale + ) + + elif diag in left.offsets: + diag_left = left.offsets.index(diag) + offsets.append(diag) + data.append(left.data[diag_left, :]) + + elif diag in right.offsets: + diag_right = right.offsets.index(diag) + offsets.append(diag) + if scale is None: + data.append(right.data[diag_right, :]) + else: + data.append(right.data[diag_right, :] * scale) + + return JaxDia((jnp.array(data), tuple(offsets),), left.shape, False) + + +@jit def sub_jaxarray(left, right): """ Perform the operation @@ -69,6 +125,17 @@ def sub_jaxarray(left, right): return add_jaxarray(left, right, -1) +@jit +def sub_jaxdia(left, right): + """ + Perform the operation + left - right + where `left` and `right` are matrices. + """ + return add_jaxdia(left, right, -1) + + +@jit def mul_jaxarray(matrix, value): """Multiply a matrix element-wise by a scalar.""" # We don't want to check values type in case jax pass a tracer etc. @@ -77,6 +144,24 @@ def mul_jaxarray(matrix, value): return JaxArray(matrix._jxa * value) +@partial(jit, donate_argnums=[0]) +def imul_jaxarray(matrix, value): + """Multiply a matrix element-wise by a scalar.""" + # We don't want to check values type in case jax pass a tracer etc. + # But we want to ensure the output is a matrix, thus don't use the + # fast constructor. + return JaxArray(matrix._jxa * value) + + +@jit +def mul_jaxdia(matrix, value): + """Multiply a matrix element-wise by a scalar.""" + return JaxDia._fast_constructor( + matrix.offsets, matrix.data * value, matrix.shape + ) + + +@partial(jit, donate_argnums=(3,)) def matmul_jaxarray(left, right, scale=1, out=None): """ Compute the matrix multiplication of two matrices, with the operation @@ -108,12 +193,135 @@ def matmul_jaxarray(left, right, scale=1, out=None): out._jxa = result + out._jxa +@partial(jit, donate_argnums=(3,)) +def matmul_jaxdia(left, right, scale=1.0, out=None): + _check_matmul_shape(left, right, out) + out_dict = {} + + for diag_left in range(left.num_diags): + for diag_right in range(right.num_diags): + off_out = left.offsets[diag_left] + right.offsets[diag_right] + if off_out <= -left.shape[0] or off_out >= right.shape[1]: + continue + + start_left = ( + max(0, left.offsets[diag_left]) + right.offsets[diag_right] + ) + start_right = max(0, right.offsets[diag_right]) + start_out = max(0, off_out) + start = max(start_left, start_right, start_out) + + end_left = ( + min(left.shape[1], left.shape[0] + left.offsets[diag_left]) + + right.offsets[diag_right] + ) + end_right = min( + right.shape[1], right.shape[0] + right.offsets[diag_right] + ) + end_out = min(right.shape[1], left.shape[0] + off_out) + end = min(end_left, end_right, end_out) + + left_shift = -right.offsets[diag_right] + data = jnp.zeros(right.shape[1], dtype=jnp.complex128) + data = data.at[start:end].set( + scale + * left.data[diag_left, left_shift + start : left_shift + end] + * right.data[diag_right, start:end] + ) + + if off_out in out_dict: + out_dict[off_out] = out_dict[off_out] + data + else: + out_dict[off_out] = data + + out_dia = JaxDia._fast_constructor( + tuple(out_dict.keys()), + jnp.array(list(out_dict.values())), + (left.shape[0], right.shape[1]), + ) + if out is not None: + out_dia = add_jaxdia(out, out_dia) + return out_dia + + +@partial(jit, donate_argnums=(3,)) +def matmul_jaxdia_jaxarray_jaxarray(left, right, scale=None, out=None): + _check_matmul_shape(left, right, out) + mul = vmap(jnp.multiply, (0, 0)) + if out is None: + out = jnp.zeros((left.shape[0], right.shape[1]), dtype=jnp.complex128) + else: + out = out._jxa + + for offset, data in zip(left.offsets, left.data): + start = max(0, offset) + end = min(left.shape[1], left.shape[0] + offset) + top = max(0, -offset) + bottom = top + end - start + + if scale is not None: + out = out.at[top:bottom, :].add( + mul(data[start:end], right._jxa[start:end, :]) * scale + ) + else: + out = out.at[top:bottom, :].add( + mul(data[start:end], right._jxa[start:end, :]) + ) + + return JaxArray(out, shape=(left.shape[0], right.shape[1]), copy=False) + + +@partial(jit, donate_argnums=(3,)) +def matmul_jaxarray_jaxdia_jaxarray(left, right, scale=1.0, out=None): + _check_matmul_shape(left, right, out) + mul = vmap(jnp.multiply, (1, 0)) + if out is None: + out = jnp.zeros((left.shape[0], right.shape[1]), dtype=jnp.complex128) + else: + out = out._jxa + + for offset, data in zip(right.offsets, right.data): + start = max(0, offset) + end = min(right.shape[1], right.shape[0] + offset) + top = max(0, -offset) + bottom = top + end - start + + out = out.at[:, start:end].add( + mul(left._jxa[:, top:bottom], data[start:end]).T * scale + ) + + return JaxArray(out, shape=(left.shape[0], right.shape[1]), copy=False) + + +@jit def multiply_jaxarray(left, right): """Element-wise multiplication of matrices.""" _check_same_shape(left, right) return JaxArray._fast_constructor(left._jxa * right._jxa, shape=left.shape) +@jit +def multiply_jaxdia(left, right): + """Element-wise multiplication of matrices.""" + _check_same_shape(left, right) + diag_left = 0 + diag_right = 0 + data = [] + offsets = [] + + for i, diag in enumerate(left.offsets): + if diag not in right.offsets: + continue + j = right.offsets.index(diag) + offsets.append(diag) + data.append(left.data[i, :] * right.data[j, :]) + + out = JaxDia._fast_constructor(tuple(offsets), jnp.array(data), left.shape) + + return out + + +@jit def kron_jaxarray(left, right): """ Compute the Kronecker product of two matrices. This is used to represent @@ -122,6 +330,78 @@ def kron_jaxarray(left, right): return JaxArray(jnp.kron(left._jxa, right._jxa)) +@jit +def _multiply_outer(left, right): + return vmap(vmap(jnp.multiply, (None, 0)), (0, None))(left, right).ravel() + + +@jit +def kron_jaxdia(left, right): + """ + Compute the Kronecker product of two matrices. This is used to represent + quantum tensor products of vector spaces. + """ + nrows = left.shape[0] * right.shape[0] + ncols = left.shape[1] * right.shape[1] + left = clean_dia(left) + right = clean_dia(right) + out = {} + + if right.shape[0] == right.shape[1]: + for diag_left in range(left.num_diags): + for diag_right in range(right.num_diags): + out_diag = ( + left.offsets[diag_left] * right.shape[0] + + right.offsets[diag_right] + ) + out_data = _multiply_outer( + left.data[diag_left], right.data[diag_right] + ) + if out_diag in out: + out[out_diag] = out[out_diag] + out_data + else: + out[out_diag] = out_data + + else: + delta = right.shape[0] - right.shape[1] + for diag_left in range(left.num_diags): + start_left = max(0, left.offsets[diag_left]) + end_left = min( + left.shape[1], left.shape[0] + left.offsets[diag_left] + ) + for diag_right in range(right.num_diags): + start_right = max(0, right.offsets[diag_right]) + end_right = min( + right.shape[1], right.shape[0] + right.offsets[diag_right] + ) + + for col_left in range(start_left, end_left): + out_diag = ( + left.offsets[diag_left] * right.shape[0] + + right.offsets[diag_right] + - col_left * delta + ) + data = jnp.zeros(ncols, dtype=jnp.complex128) + data = data.at[ + col_left * right.shape[1] : col_left * right.shape[1] + + right.shape[1] + ].set( + left.data[diag_left, col_left] * right.data[diag_right] + ) + + if out_diag in out: + out[out_diag] = out[out_diag] + data + else: + out[out_diag] = data + + out = JaxDia( + (jnp.array(list(out.values())), tuple(out.keys())), + shape=(nrows, ncols) + ) + out = clean_dia(out) + return out + + def pow_jaxarray(matrix, n): """ Compute the integer matrix power of the square input matrix. The power @@ -142,29 +422,57 @@ def pow_jaxarray(matrix, n): qutip.data.add.add_specialisations( - [(JaxArray, JaxArray, JaxArray, add_jaxarray),] + [ + (JaxArray, JaxArray, JaxArray, add_jaxarray), + (JaxDia, JaxDia, JaxDia, add_jaxdia), + ] ) qutip.data.sub.add_specialisations( - [(JaxArray, JaxArray, JaxArray, sub_jaxarray),] + [ + (JaxArray, JaxArray, JaxArray, sub_jaxarray), + (JaxDia, JaxDia, JaxDia, sub_jaxdia), + ] +) + +qutip.data.imul.add_specialisations( + [ + (JaxArray, JaxArray, imul_jaxarray), + ] ) qutip.data.mul.add_specialisations( - [(JaxArray, JaxArray, mul_jaxarray),] + [ + (JaxArray, JaxArray, mul_jaxarray), + (JaxDia, JaxDia, mul_jaxdia), + ] ) qutip.data.matmul.add_specialisations( - [(JaxArray, JaxArray, JaxArray, matmul_jaxarray),] + [ + (JaxArray, JaxArray, JaxArray, matmul_jaxarray), + (JaxDia, JaxDia, JaxDia, matmul_jaxdia), + (JaxDia, JaxArray, JaxArray, matmul_jaxdia_jaxarray_jaxarray), + (JaxArray, JaxDia, JaxArray, matmul_jaxarray_jaxdia_jaxarray), + ] ) qutip.data.multiply.add_specialisations( - [(JaxArray, JaxArray, JaxArray, multiply_jaxarray),] + [ + (JaxArray, JaxArray, JaxArray, multiply_jaxarray), + (JaxDia, JaxDia, JaxDia, multiply_jaxdia), + ] ) qutip.data.kron.add_specialisations( - [(JaxArray, JaxArray, JaxArray, kron_jaxarray),] + [ + (JaxArray, JaxArray, JaxArray, kron_jaxarray), + (JaxDia, JaxDia, JaxDia, kron_jaxdia), + ] ) qutip.data.pow.add_specialisations( - [(JaxArray, JaxArray, pow_jaxarray),] + [ + (JaxArray, JaxArray, pow_jaxarray), + ] ) diff --git a/src/qutip_jax/convert.py b/src/qutip_jax/convert.py index d5b77ce..3229564 100644 --- a/src/qutip_jax/convert.py +++ b/src/qutip_jax/convert.py @@ -1,19 +1,71 @@ import qutip from .jaxarray import JaxArray +from .jaxdia import JaxDia import jax import jax.numpy as jnp import numpy as np +from qutip import settings + + +__all__ = [ + "is_jax_array", + "jaxarray_from_dense", + "dense_from_jaxarray", + "jaxdia_from_jaxarray", + "jaxarray_from_jaxdia", + "jaxdia_from_dia", + "dia_from_jaxdia", +] -__all__ = ["is_jax_array", "jax_from_dense", "dense_from_jax"] # Conversion function -def jax_from_dense(dense): +def jaxarray_from_dense(dense): return JaxArray(dense.to_array(), copy=False) -def dense_from_jax(jax_array): +def dense_from_jaxarray(jax_array): return qutip.data.Dense(jax_array.to_array(), copy=False) +def jaxdia_from_dia(dia_mat): + as_scipy = dia_mat.as_scipy() + return JaxDia((as_scipy.data, as_scipy.offsets), shape=dia_mat.shape) + + +def dia_from_jaxdia(jaxdiag): + return qutip.data.Dia((jaxdiag.data, jaxdiag.offsets), shape=jaxdiag.shape) + + +def jaxdia_from_jaxarray(jax_array): + tol = settings.core["auto_tidyup_atol"] + data = {} + + for row in range(jax_array.shape[0]): + for col in range(jax_array.shape[1]): + if jnp.abs(jax_array._jxa[row, col]) <= tol: + continue + diag = col - row + if diag not in data: + data[diag] = jnp.zeros(jax_array.shape[1], dtype=np.complex128) + data[diag] = data[diag].at[col].set(jax_array._jxa[row, col]) + + offsets = tuple(data.keys()) + data = jnp.array(list(data.values())) + return JaxDia((data, offsets), shape=jax_array.shape, copy=False) + + +@jax.jit +def jaxarray_from_jaxdia(matrix): + out = jnp.zeros(matrix.shape, dtype=np.complex128) + + for diag, data in zip(matrix.offsets, matrix.data): + start = max(diag, 0) + end = min(matrix.shape[1], diag + matrix.shape[0]) + for col in range(start, end): + out = out.at[(col - diag), col].set(data[col]) + + return JaxArray(out, copy=False) + + def is_jax_array(data): return isinstance(data, jax.Array) diff --git a/src/qutip_jax/create.py b/src/qutip_jax/create.py index 29e8c60..3c2f7c0 100644 --- a/src/qutip_jax/create.py +++ b/src/qutip_jax/create.py @@ -1,21 +1,28 @@ import jax.numpy as jnp - +from jax import jit from .jaxarray import JaxArray -from .convert import jax_from_dense +from .jaxdia import JaxDia +from .convert import jaxarray_from_dense import numpy as np +from functools import partial import qutip __all__ = [ "zeros_jaxarray", + "zeros_jaxdia", "identity_jaxarray", + "identity_jaxdia", "diag_jaxarray", + "diag_jaxdia", "one_element_jaxarray", + "one_element_jaxdia", ] +@partial(jit, static_argnames=["rows", "cols"]) def zeros_jaxarray(rows, cols): """ Creates a matrix representation of zeros with the given dimensions. @@ -25,7 +32,24 @@ def zeros_jaxarray(rows, cols): rows, cols : int The number of rows and columns in the output matrix. """ - return JaxArray(jnp.zeros((rows, cols), dtype=jnp.complex128)) + return JaxArray._fast_constructor( + jnp.zeros((rows, cols), dtype=jnp.complex128), (rows, cols) + ) + + +@partial(jit, static_argnames=["rows", "cols"]) +def zeros_jaxdia(rows, cols): + """ + Creates a matrix representation of zeros with the given dimensions. + + Parameters + ---------- + rows, cols : int + The number of rows and columns in the output matrix. + """ + return JaxDia._fast_constructor( + (), jnp.zeros((0, cols), dtype=jnp.complex128), (rows, cols) + ) def identity_jaxarray(dimensions, scale=None): @@ -47,6 +71,30 @@ def identity_jaxarray(dimensions, scale=None): return JaxArray(jnp.eye(dimensions, dtype=jnp.complex128) * scale) +@partial(jit, static_argnums=(0,)) +def identity_jaxdia(dimensions, scale=1.0): + """ + Creates a square identity matrix of the given dimension. + + Optionally, the `scale` can be given, where all the diagonal elements will + be that instead of 1. + + Parameters + ---------- + dimension : int + The dimension of the square output identity matrix. + scale : complex, optional + The element which should be placed on the diagonal. + """ + if scale is None: + scale = 1.0 + return JaxDia._fast_constructor( + (0,), + jnp.ones((1, dimensions), dtype=jnp.complex128) * scale, + (dimensions, dimensions), + ) + + def diag_jaxarray(diagonals, offsets=None, shape=None): """ Constructs a matrix from diagonals and their offsets. @@ -113,13 +161,99 @@ def diag_jaxarray(diagonals, offsets=None, shape=None): out += jnp.diag(jnp.array(diag), offset) out = JaxArray(out) else: - out = jax_from_dense( + out = jaxarray_from_dense( qutip.core.data.dense.diags(diagonals, offsets, shape) ) return out +def diag_jaxdia(diagonals, offsets=None, shape=None): + """ + Constructs a matrix from diagonals and their offsets. + + Using this function in single-argument form produces a square matrix with + the given values on the main diagonal. With lists of diagonals and offsets, + the matrix will be the smallest possible square matrix if shape is not + given, but in all cases the diagonals must fit exactly with no extra or + missing elements. Duplicated diagonals will be summed together in the + output. + + Parameters + ---------- + diagonals : sequence of array_like of complex or array_like of complex + The entries (including zeros) that should be placed on the diagonals in + the output matrix. Each entry must have enough entries in it to fill + the relevant diagonal and no more. + offsets : sequence of integer or integer, optional + The indices of the diagonals. `offsets[i]` is the location of the + values `diagonals[i]`. An offset of 0 is the main diagonal, positive + values are above the main diagonal and negative ones are below the main + diagonal. + shape : tuple, optional + The shape of the output as (``rows``, ``columns``). The result does + not need to be square, but the diagonals must be of the correct length + to fit in exactly. + """ + try: + diagonals = list(diagonals) + # Can this be replaced with pure jnp and lax conditionals? + if diagonals and np.isscalar(diagonals[0]): + # Catch the case where we're being called as (for example) + # diags([1, 2, 3], 0) + # with a single diagonal and offset. + diagonals = [diagonals] + except TypeError: + raise TypeError("diagonals must be a list of arrays of complex") + + if offsets is None: + if len(diagonals) == 0: + offsets = [] + elif len(diagonals) == 1: + offsets = [0] + else: + raise TypeError( + "offsets must be supplied if passing more than one diagonal" + ) + + offsets = np.atleast_1d(offsets) + if offsets.ndim > 1: + raise ValueError("offsets must be a 1D array of integers") + if len(diagonals) != len(offsets): + raise ValueError("number of diagonals does not match number of offsets") + if len(diagonals) == 0: + if shape is None: + raise ValueError( + "cannot construct matrix with no diagonals without a shape" + ) + else: + n_rows, n_cols = shape + return zeros(n_rows, n_cols) + + if shape: + n_rows, n_cols = shape + else: + n_rows = n_cols = abs(offsets[0]) + len(diagonals[0]) + + out = {} + for offset, data in zip(offsets, diagonals): + start = max(0, offset) + end = min(n_cols, n_rows + offset) + out_data = jnp.zeros(n_cols, dtype=jnp.complex128) + out_data = out_data.at[start:end].set(data) + if offset in out: + out[offset] = out[offset] + out_data + else: + out[offset] = out_data + + out = JaxDia( + (jnp.array(list(out.values())), tuple(out.keys()),), + shape=(n_rows, n_cols), + copy=False, + ) + return out + + def one_element_jaxarray(shape, position, value=None): """ Creates a matrix with only one nonzero element. @@ -145,22 +279,62 @@ def one_element_jaxarray(shape, position, value=None): return JaxArray(out.at[position].set(value)) +def one_element_jaxdia(shape, position, value=None): + """ + Creates a matrix with only one nonzero element. + + Parameters + ---------- + shape : tuple + The shape of the output as (``rows``, ``columns``). + + position : tuple + The position of the non zero in the matrix as (``rows``, ``columns``). + + value : complex, optional + The value of the non-null element. + """ + if not (0 <= position[0] < shape[0] and 0 <= position[1] < shape[1]): + raise ValueError( + "Position of the elements out of bound: " + + str(position) + + " in " + + str(shape) + ) + if value is None: + value = 1.0 + row, col = position + return JaxDia._fast_constructor( + (col - row,), + jnp.zeros((1, shape[1]), dtype=jnp.complex128).at[0, col].set(value), + shape, + ) + + qutip.data.zeros.add_specialisations( [ (JaxArray, zeros_jaxarray), + (JaxDia, zeros_jaxdia), ] ) -qutip.data.identity.add_specialisations([(JaxArray, identity_jaxarray)]) +qutip.data.identity.add_specialisations( + [ + (JaxArray, identity_jaxarray), + (JaxDia, identity_jaxdia), + ] +) qutip.data.diag.add_specialisations( [ (JaxArray, diag_jaxarray), + (JaxDia, diag_jaxdia), ] ) qutip.data.one_element.add_specialisations( [ (JaxArray, one_element_jaxarray), + (JaxDia, one_element_jaxdia), ] ) diff --git a/src/qutip_jax/jaxarray.py b/src/qutip_jax/jaxarray.py index 5845bdc..33c5dca 100644 --- a/src/qutip_jax/jaxarray.py +++ b/src/qutip_jax/jaxarray.py @@ -1,7 +1,8 @@ import jax.numpy as jnp -from jax import tree_util, config +from jax import tree_util, config, jit import numbers import numpy as np +from functools import partial config.update("jax_enable_x64", True) @@ -80,6 +81,7 @@ def __matmul__(self, other): return NotImplemented @classmethod + @partial(jit, static_argnames=["cls", "shape"]) def _fast_constructor(cls, array, shape): out = cls.__new__(cls) Data.__init__(out, shape) diff --git a/src/qutip_jax/jaxdia.py b/src/qutip_jax/jaxdia.py new file mode 100644 index 0000000..196b959 --- /dev/null +++ b/src/qutip_jax/jaxdia.py @@ -0,0 +1,175 @@ +import jax.numpy as jnp +import numpy as np +from jax import tree_util, jit, config +from qutip.core.data.extract import extract +import qutip.core.data as _data +import numpy as np +from qutip.core.data.base import Data +import numbers + + +config.update("jax_enable_x64", True) + +__all__ = ["JaxDia"] + + +class JaxDia(Data): + data: jnp.ndarray + offsets: tuple + shape: tuple + + def __init__(self, arg, shape=None, copy=None): + data, offsets = arg + offsets = tuple(np.atleast_1d(offsets).astype(jnp.int64)) + data = jnp.atleast_2d(data).astype(jnp.complex128) + + if not ( + isinstance(shape, tuple) + and len(shape) == 2 + and isinstance(shape[0], numbers.Integral) + and isinstance(shape[1], numbers.Integral) + and shape[0] > 0 + and shape[1] > 0 + ): + raise ValueError( + """Shape must be a 2-tuple of positive ints, but is """ + + repr(shape) + ) + + self.data = data + self.offsets = offsets + self.num_diags = len(offsets) + super().__init__(shape) + + def copy(self): + return self.__class__((self.data, self.offsets), self.shape, copy=True) + + def to_array(self): + from .convert import jaxarray_from_jaxdia + + return jaxarray_from_jaxdia(self).to_array() + + def trace(self): + from .measurements import trace_jaxdia + + return trace_jaxdia(self) + + def conj(self): + from .unary import conj_jaxdia + + return conj_jaxdia(self) + + def transpose(self): + from .unary import transpose_jaxdia + + return transpose_jaxdia(self) + + def adjoint(self): + from .unary import adjoint_jaxdia + + return adjoint_jaxdia(self) + + @classmethod + def _fast_constructor(cls, offsets, data, shape): + out = cls.__new__(cls) + Data.__init__(out, shape) + out.data = data + out.offsets = offsets + out.num_diags = len(out.offsets) + return out + + def _tree_flatten(self): + children = (self.data,) # arrays / dynamic values + aux_data = { + "shape": self.shape, + "offsets": self.offsets, + } # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + # unflatten should not check data validity + # jax can pass tracer, object, etc. + out = cls.__new__(cls) + out.data = children[0] + out.offsets = aux_data["offsets"] + out.num_diags = len(out.offsets) + shape = aux_data["shape"] + Data.__init__(out, shape) + return out + + +tree_util.register_pytree_node( + JaxDia, JaxDia._tree_flatten, JaxDia._tree_unflatten +) + + +@jit +def clean_dia(matrix): + idx = np.argsort(matrix.offsets) + new_offset = tuple(matrix.offsets[i] for i in idx) + new_data = matrix.data[idx, :] + + for i in range(len(new_offset)): + start = max(0, new_offset[i]) + end = min(matrix.shape[1], matrix.shape[0] + new_offset[i]) + new_data = new_data.at[i, :start].set(0) + new_data = new_data.at[i, end:].set(0) + + return JaxDia._fast_constructor(new_offset, new_data, matrix.shape) + + +def tidyup_jaxdia(matrix, tol, _=None): + matrix = clean_dia(matrix) + new_offset = [] + new_data = [] + for offset, data in zip(matrix.offsets, matrix.data): + real = data.real + mask_r = real < tol + imag = data.imag + mask_i = imag < tol + if jnp.all(mask_r) and jnp.all(mask_i): + continue + data = real.at[mask_r].set(0) + 1j * imag.at[mask_i].set(0) + new_offset.append(offset) + new_data.append(data) + new_offset = tuple(new_offset) + new_data = jnp.array(new_data) + return JaxDia._fast_constructor(new_offset, new_data, matrix.shape) + + +_data.tidyup.add_specialisations([(JaxDia, tidyup_jaxdia)], _defer=True) + + +def extract_jaxdia(matrix, format=None, _=None): + """ + Return ``jaxdia_matrix`` as a pair of offsets and diagonals. + + It can be extracted as either a dict of the offset to the diagonal or a + tuple of ``(offsets, diagonals)``. + The diagonal are the lenght of the number of columns. + Each entry is at the position of the column. + + The element ``A[3, 5]`` is at ``extract_jaxdia(A, "dict")[5-3][5]``. + + Parameters + ---------- + matrix : Data + The matrix to convert to common type. + + format : str, {"dict"} + Type of the output. + """ + if format in ["dict", None]: + out = {} + for offset, data in zip(matrix.offsets, matrix.data): + out[offset] = data + + elif format in ["tuple"]: + out = (matrix.offsets, matrix.data) + else: + raise ValueError("Dia can only be extracted to 'dict' or 'tuple'") + return out + + +extract.add_specialisations([(JaxDia, extract_jaxdia)], _defer=True) diff --git a/src/qutip_jax/measurements.py b/src/qutip_jax/measurements.py index a7f75f1..53a94a9 100644 --- a/src/qutip_jax/measurements.py +++ b/src/qutip_jax/measurements.py @@ -1,5 +1,7 @@ import jax.numpy as jnp from .jaxarray import JaxArray +from .jaxdia import JaxDia +from .binops import matmul_jaxdia_jaxarray_jaxarray import qutip from jax import jit from functools import partial @@ -7,10 +9,13 @@ __all__ = [ "expect_jaxarray", + "expect_jaxdia_jaxarray", "expect_super_jaxarray", + "expect_super_jaxdia_jaxarray", "inner_jaxarray", "inner_op_jaxarray", "trace_jaxarray", + "trace_jaxdia", "trace_oper_ket_jaxarray", ] @@ -122,7 +127,10 @@ def expect_jaxarray(op, state): if ( op._jxa.shape[0] != op._jxa.shape[1] or op._jxa.shape[1] != state._jxa.shape[0] - or not (state._jxa.shape[1] == 1 or state._jxa.shape[0] == state._jxa.shape[1]) + or not ( + state._jxa.shape[1] == 1 + or state._jxa.shape[0] == state._jxa.shape[1] + ) ): raise ValueError( f"incompatible matrix shapes {op.shape} and {state.shape}" @@ -134,6 +142,80 @@ def expect_jaxarray(op, state): return out +@jit +def expect_jaxdia_jaxarray(op, state): + """Computes the expectation value between op and state assuming they are + operators and state representations (density matrix/ket). + + Parameters + ---------- + op, state : :class:`qutip.Qobj` + Quantum objects from which the underlying JAX array can be accessed. + + Returns + ------- + out : jax.interpreters.xla.DeviceArray + The complex valued output. + """ + if ( + op.shape[0] != op.shape[1] + or op.shape[1] != state.shape[0] + or not (state.shape[1] == 1 or state.shape[0] == state.shape[1]) + ): + raise ValueError( + f"incompatible matrix shapes {op.shape} and {state.shape}" + ) + out = 0 + if state.shape[0] == state.shape[1]: + for offset, data in zip(op.offsets, op.data): + if offset >= 0: + out += jnp.sum( + data[offset:] + * state._jxa.ravel()[ + offset * op.shape[0] :: (op.shape[0] + 1) + ] + ) + else: + out += jnp.sum( + data[:offset] + * state._jxa.ravel()[ + -offset : (offset * op.shape[0]) : (op.shape[0] + 1) + ] + ) + else: + out = ( + state._jxa.T.conj() + @ matmul_jaxdia_jaxarray_jaxarray(op, state)._jxa + )[0, 0] + return out + + +@jit +def expect_super_jaxdia_jaxarray(op, state): + """Computes the expectation value between op and state assuming they + represent a superoperator and a state (vectorized). + + Parameters + ---------- + op, state : :class:`qutip.Qobj` + Quantum objects from which the underlying JAX array can be accessed. + + Returns + ------- + out : jax.interpreters.xla.DeviceArray + The complex valued output. + """ + if state.shape[1] != 1: + raise ValueError("expected a column-stacked matrix") + if not (op.shape[0] == op.shape[1] and op.shape[1] == state.shape[0]): + raise ValueError( + f"incompatible matrix shapes {op.shape} and {state.shape}" + ) + + N = int(state._jxa.shape[0] ** 0.5) + return jnp.sum(matmul_jaxdia_jaxarray_jaxarray(op, state)._jxa[:: N + 1]) + + @jit def expect_super_jaxarray(op, state): """Computes the expectation value between op and state assuming they @@ -173,6 +255,19 @@ def trace_jaxarray(matrix): return jnp.trace(matrix._jxa) +@jit +def trace_jaxdia(matrix): + """Compute the trace (sum of digaonal elements) of a square matrix.""" + if matrix.shape[0] != matrix.shape[1]: + raise ValueError( + f"matrix {matrix.shape} is not a square matrix." + ) + if 0 not in matrix.offsets: + return 0.0 + idx = matrix.offsets.index(0) + return jnp.sum(matrix.data[idx]) + + @jit def trace_oper_ket_jaxarray(matrix): """ @@ -186,7 +281,6 @@ def trace_oper_ket_jaxarray(matrix): return jnp.sum(matrix._jxa[:: N + 1]) - qutip.data.inner.add_specialisations( [ (JaxArray, JaxArray, inner_jaxarray), @@ -204,6 +298,7 @@ def trace_oper_ket_jaxarray(matrix): qutip.data.expect.add_specialisations( [ (JaxArray, JaxArray, expect_jaxarray), + (JaxDia, JaxArray, expect_jaxdia_jaxarray), ] ) @@ -211,6 +306,7 @@ def trace_oper_ket_jaxarray(matrix): qutip.data.expect_super.add_specialisations( [ (JaxArray, JaxArray, expect_super_jaxarray), + (JaxDia, JaxArray, expect_super_jaxdia_jaxarray), ] ) @@ -218,6 +314,7 @@ def trace_oper_ket_jaxarray(matrix): qutip.data.trace.add_specialisations( [ (JaxArray, trace_jaxarray), + (JaxDia, trace_jaxdia), ] ) diff --git a/src/qutip_jax/ode.py b/src/qutip_jax/ode.py index 7047c2c..f4d6e8f 100644 --- a/src/qutip_jax/ode.py +++ b/src/qutip_jax/ode.py @@ -22,6 +22,14 @@ def _float2cplx(arr): return arr[0] + 1j * arr[1] +@jax.jit +def dstate(t, y, args): + state = _float2cplx(y) + H, = args + d_state = H.matmul_data(t, JaxArray(state)) + return _cplx2float(d_state._jxa) + + class DiffraxIntegrator(Integrator): method: str = "diffrax" supports_blackbox: bool = False # No feedback support @@ -38,20 +46,13 @@ def __init__(self, system, options): self._is_set = False # get_state can be used and return a valid state. self._options = self.integrator_options.copy() self.options = options - self.ODEsystem = diffrax.ODETerm(self.dstate) + self.ODEsystem = diffrax.ODETerm(dstate) self.solver_state = None self.name = f"{self.method}: {self.options['solver']}" def _prepare(self): pass - @staticmethod - def dstate(t, y, args): - state = _float2cplx(y) - H, kwargs = args - d_state = H.matmul_data(t, JaxArray(state), **kwargs) - return _cplx2float(d_state._jxa) - def set_state(self, t, state0): self.solver_state = None self.t = t @@ -64,6 +65,8 @@ def get_state(self, copy=False): return self.t, JaxArray(_float2cplx(self.state)) def integrate(self, t, copy=False, **kwargs): + if kwargs: + self.arguments(kwargs) sol = diffrax.diffeqsolve( self.ODEsystem, t0=self.t, @@ -71,7 +74,7 @@ def integrate(self, t, copy=False, **kwargs): y0=self.state, saveat=diffrax.SaveAt(t1=True, solver_state=True), solver_state=self.solver_state, - args=(self.system, kwargs), + args=(self.system,), **self._options, ) self.t = t diff --git a/src/qutip_jax/properties.py b/src/qutip_jax/properties.py index a50ed0d..33dac84 100644 --- a/src/qutip_jax/properties.py +++ b/src/qutip_jax/properties.py @@ -1,11 +1,20 @@ import jax.numpy as jnp from .jaxarray import JaxArray +from .jaxdia import JaxDia, clean_dia import qutip from jax import jit from functools import partial +import numpy as np -__all__ = ["isherm_jaxarray", "isdiag_jaxarray", "iszero_jaxarray"] +__all__ = [ + "isherm_jaxarray", + "isdiag_jaxarray", + "iszero_jaxarray", + "isherm_jaxdia", + "isdiag_jaxdia", + "iszero_jaxdia", +] @partial(jit, static_argnames=["tol"]) @@ -23,21 +32,77 @@ def isherm_jaxarray(matrix, tol=None): return _isherm(matrix._jxa, tol) +def _is_zero(vec, tol): + return jnp.allclose(vec, 0.0, atol=tol, rtol=0) + + +def _is_conj(vec1, vec2, tol): + return jnp.allclose(vec1, vec2.conj(), atol=tol, rtol=0) + + +def isherm_jaxdia(matrix, tol=None): + if matrix.shape[0] != matrix.shape[1]: + return False + tol = tol or qutip.settings.core["atol"] + done = [] + for offset, data in zip(matrix.offsets, matrix.data): + if offset in done: + continue + start = max(0, offset) + end = min(matrix.shape[1], matrix.shape[0] + offset) + if -offset not in matrix.offsets: + if not _is_zero(data[start:end], tol): + return False + else: + idx = matrix.offsets.index(-offset) + done.append(-offset) + st = max(0, -offset) + et = min(matrix.shape[1], matrix.shape[0] - offset) + if not _is_conj(data[start:end], matrix.data[idx, st:et], tol): + return False + return True + + @jit def isdiag_jaxarray(matrix): mat_abs = jnp.abs(matrix._jxa) return jnp.trace(mat_abs) == jnp.sum(mat_abs) +def isdiag_jaxdia(matrix): + if matrix.num_diags == 0 or ( + matrix.num_diags == 1 and matrix.offsets[0] == 0 + ): + return True + for offset, data in zip(matrix.offsets, matrix.data): + if offset == 0: + continue + start = max(0, offset) + end = min(matrix.shape[1], matrix.shape[0] + offset) + if not jnp.all(data[start:end] == 0): + return False + return True + + def iszero_jaxarray(matrix, tol=None): if tol is None: tol = qutip.settings.core["atol"] return jnp.allclose(matrix._jxa, 0.0, atol=tol) +def iszero_jaxdia(matrix, tol=None): + if tol is None: + tol = qutip.settings.core["atol"] + if matrix.num_diags == 0: + return True + # We must ensure the values outside the dims are not included + return jnp.allclose(clean_dia(matrix).data, 0.0, atol=tol) + + qutip.data.isherm.add_specialisations( [ (JaxArray, isherm_jaxarray), + (JaxDia, isherm_jaxdia), ] ) @@ -45,6 +110,7 @@ def iszero_jaxarray(matrix, tol=None): qutip.data.iszero.add_specialisations( [ (JaxArray, iszero_jaxarray), + (JaxDia, iszero_jaxdia), ] ) @@ -52,5 +118,6 @@ def iszero_jaxarray(matrix, tol=None): qutip.data.isdiag.add_specialisations( [ (JaxArray, isdiag_jaxarray), + (JaxDia, isdiag_jaxdia), ] ) diff --git a/src/qutip_jax/qobjevo.py b/src/qutip_jax/qobjevo.py index 11a6c68..65391d4 100644 --- a/src/qutip_jax/qobjevo.py +++ b/src/qutip_jax/qobjevo.py @@ -1,29 +1,39 @@ -import equinox as eqx import jaxlib import jax +from jax import jit import jax.numpy as jnp import numpy as np from .jaxarray import JaxArray +from .jaxdia import JaxDia +from .binops import matmul_jaxdia_jaxarray_jaxarray +from .create import zeros_jaxdia, zeros_jaxarray from qutip.core.coefficient import coefficient_builders from qutip.core.cy.coefficient import Coefficient from qutip import Qobj +from qutip.core.data.matmul import matmul +from functools import partial __all__ = [] class JaxJitCoeff(Coefficient): - func: callable = eqx.static_field() + func: callable + static_argnames: tuple args: dict - def __init__(self, func, args={}, **_): + def __init__(self, func, args={}, static_argnames=(), **_): self.func = func + self.static_argnames = static_argnames Coefficient.__init__(self, args) + self.jit_call = jit(self._caller, static_argnames=self.static_argnames) - @eqx.filter_jit def __call__(self, t, _args=None, **kwargs): if _args: kwargs.update(_args) + return self.jit_call(t, **kwargs) + + def _caller(self, t, **kwargs): args = self.args.copy() for key in kwargs: if key in args: @@ -33,38 +43,66 @@ def __call__(self, t, _args=None, **kwargs): def replace_arguments(self, _args=None, **kwargs): if _args: kwargs.update(_args) - return JaxJitCoeff(self.func, {**self.args, **kwargs}) + args = self.args.copy() + for key in kwargs: + if key in args: + args[key] = kwargs[key] + return JaxJitCoeff(self.func, args=args) def __add__(self, other): if isinstance(other, JaxJitCoeff): + merge_static = tuple( + set(self.static_argnames) | set(other.static_argnames) + ) def f(t, **kwargs): - return self(t, **kwargs) + other(t, **kwargs) - return JaxJitCoeff(eqx.filter_jit(f), {}) + return self._caller(t, **kwargs) + other._caller(t, **kwargs) + + return JaxJitCoeff( + jit(f, static_argnames=merge_static), + args={**self.args, **other.args}, + static_argnames=merge_static, + ) return NotImplemented def __mul__(self, other): if isinstance(other, JaxJitCoeff): + merge_static = tuple( + set(self.static_argnames) | set(other.static_argnames) + ) def f(t, **kwargs): - return self(t, **kwargs) * other(t, **kwargs) - return JaxJitCoeff(eqx.filter_jit(f), {}) + return self._caller(t, **kwargs) * self._caller(t, **kwargs) + + return JaxJitCoeff( + jit(f, static_argnames=merge_static), + args={**self.args, **other.args}, + static_argnames=merge_static, + ) return NotImplemented def conj(self): def f(t, **kwargs): - return jnp.conj(self(t, **kwargs)) + return jnp.conj(self._caller(t, **kwargs)) - return JaxJitCoeff(eqx.filter_jit(f), {}) + return JaxJitCoeff( + jit(f, static_argnames=self.static_argnames), + args=self.args, + static_argnames=self.static_argnames, + ) def _cdc(self): def f(t, **kwargs): val = self(t, **kwargs) return jnp.conj(val) * val - return JaxJitCoeff(eqx.filter_jit(f), {}) + return JaxJitCoeff( + jit(f, static_argnames=self.static_argnames), + args=self.args, + static_argnames=self.static_argnames, + ) def copy(self): return self @@ -73,28 +111,48 @@ def __reduce__(self): # Jitted function cannot be pickled. # Extract the original function and re-jit it. # This can fail depending on the wrapped object. - return (self.restore, (self.func.__wrapped__, self.args)) + return ( + self.restore, + (self.func.__wrapped__, self.args, self.static_argnames) + ) @classmethod - def restore(cls, func, args): - return cls(eqx.filter_jit(func), args) + def restore(cls, func, args, static_argnames): + return cls( + jit(func, static_argnames=static_argnames), + args, + static_argnames + ) def flatten(self): - return (self.args,), (self.func,) + static_args = { + key: val for key, val in self.args.items() + if key in self.static_argnames + } + jax_args = { + key: val for key, val in self.args.items() + if key not in self.static_argnames + } + return (jax_args,), (self.func, static_args, self.static_argnames) @classmethod def unflatten(cls, aux_data, children): - return JaxJitCoeff(*aux_data, *children) + func, static_args, static_argnames = aux_data + + return JaxJitCoeff( + func, + args={**children[0], **static_args}, + static_argnames=static_argnames + ) -coefficient_builders[eqx._jit._JitWrapper] = JaxJitCoeff coefficient_builders[jaxlib.xla_extension.PjitFunction] = JaxJitCoeff jax.tree_util.register_pytree_node( JaxJitCoeff, JaxJitCoeff.flatten, JaxJitCoeff.unflatten ) -class JaxQobjEvo(eqx.Module): +class JaxQobjEvo: """ Pytree friendly QobjEvo for the Diffrax integrator. @@ -102,66 +160,142 @@ class JaxQobjEvo(eqx.Module): """ batched_data: jnp.ndarray + sparse_part: list coeffs: list - dims: object = eqx.static_field() + sparse_part: list + shape: tuple + dims: object def __init__(self, qobjevo): as_list = qobjevo.to_list() - self.coeffs = [] + coeffs = [] qobjs = [] self.dims = qobjevo.dims + self.shape = qobjevo.shape + self.coeffs = [] + self.sparse_part = [] + self.batched_data = None - constant = JaxJitCoeff(eqx.filter_jit(lambda t, **_: 1.0)) + constant = JaxJitCoeff(jit(lambda t, **_: 1.0)) for part in as_list: if isinstance(part, Qobj): qobjs.append(part) - self.coeffs.append(constant) + coeffs.append(constant) elif ( isinstance(part, list) and isinstance(part[0], Qobj) ): qobjs.append(part[0]) - self.coeffs.append(part[1]) + coeffs.append(part[1]) else: - # TODO: raise NotImplementedError( "Function based QobjEvo are not supported" ) - if qobjs: - shape = qobjs[0].shape + dense_part = [] + for qobj, coeff in zip(qobjs, coeffs): + if type(qobj.data) in [JaxDia]: + # TODO: CSR also? + self.sparse_part.append((qobj.data, coeff)) + else: + dense_part.append((qobj, coeff)) + + if dense_part: self.batched_data = jnp.zeros( - shape + (len(qobjs),), dtype=np.complex128 + self.shape + (len(dense_part),), dtype=np.complex128 ) - for i, qobj in enumerate(qobjs): + for i, (qobj, coeff) in enumerate(dense_part): self.batched_data = self.batched_data.at[:, :, i].set( qobj.to("jax").data._jxa ) + self.coeffs.append(coeff) - @eqx.filter_jit - def _coeff(self, t, **args): - list_coeffs = [coeff(t, **args) for coeff in self.coeffs] + @jit + def _coeff(self, t): + list_coeffs = [coeff(t) for coeff in self.coeffs] return jnp.array(list_coeffs, dtype=np.complex128) - def __call__(self, t, **kwargs): - return Qobj(self.data(t, **kwargs), dims=self.dims) - - @eqx.filter_jit - def data(self, t, **kwargs): - coeff = self._coeff(t, **kwargs) - data = jnp.dot(self.batched_data, coeff) - return JaxArray(data) + def __call__(self, t, _args=None, **kwargs): + if args is not None: + kwargs.update(_args) + if kwargs: + caller = self.arguments(kwargs) + else: + caller = self + return Qobj(caller.data(t), dims=self.dims) + + @jit + def data(self, t): + if self.batched_data is not None: + coeff = self._coeff(t) + data = jnp.dot(self.batched_data, coeff) + out = JaxArray(data) + else: + out = zeros_jaxdia(*self.shape) + for data, coeff in self.sparse_part: + out = out + data * coeff(t) + return out - @eqx.filter_jit - def matmul_data(self, t, y, **kwargs): - coeffs = self._coeff(t, **kwargs) - out = JaxArray(jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa)) + @partial(jax.jit, donate_argnums=(3,)) + def matmul_data(self, t, y, out=None): + if out is None and self.batched_data is not None: + coeffs = self._coeff(t) + out = JaxArray._fast_constructor( + jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa), + y.shape + ) + elif type(out) is JaxArray and self.batched_data is not None: + coeffs = self._coeff(t) + out = JaxArray._fast_constructor( + jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa) + out._jxa, + y.shape + ) + elif self.batched_data is not None: + out = JaxArray._fast_constructor( + jnp.dot(jnp.dot(self.batched_data, coeffs), y._jxa), + y.shape + ) + out + elif out is None: + out = zeros_jaxarray(*y.shape) + + for data, coeff in self.sparse_part: + if isinstance(y, JaxArray): + out = matmul_jaxdia_jaxarray_jaxarray(data, y, coeff(t), out) + else: + out = out + matmul(data, y, coeff(t)) return out def arguments(self, args): out = JaxQobjEvo.__new__(JaxQobjEvo) coeffs = [coeff.replace_arguments(args) for coeff in self.coeffs] + sparse_part = [ + (data, coeff.replace_arguments(args)) + for data, coeff in self.sparse_part + ] object.__setattr__(out, "coeffs", coeffs) + object.__setattr__(out, "sparse_part", sparse_part) object.__setattr__(out, "batched_data", self.batched_data) object.__setattr__(out, "dims", self.dims) + object.__setattr__(out, "shape", self.shape) + return out + + def flatten(self): + return ( + (self.batched_data, self.coeffs, self.sparse_part), + {"dims": self.dims, "shape": self.shape} + ) + + @classmethod + def unflatten(cls, aux_data, children): + out = cls.__new__(cls) + out.batched_data = children[0] + out.coeffs = children[1] + out.sparse_part = children[2] + out.dims = aux_data["dims"] + out.shape = aux_data["shape"] return out + + +jax.tree_util.register_pytree_node( + JaxQobjEvo, JaxQobjEvo.flatten, JaxQobjEvo.unflatten +) diff --git a/src/qutip_jax/qutip_trees.py b/src/qutip_jax/qutip_trees.py index ae06a11..e9ce969 100644 --- a/src/qutip_jax/qutip_trees.py +++ b/src/qutip_jax/qutip_trees.py @@ -25,9 +25,7 @@ def qobj_tree_unflatten(aux_data, children): return out -tree_util.register_pytree_node( - Qobj, qobj_tree_flatten, qobj_tree_unflatten -) +tree_util.register_pytree_node(Qobj, qobj_tree_flatten, qobj_tree_unflatten) def _QobjEvo_flatten(qevo): @@ -35,7 +33,7 @@ def _QobjEvo_flatten(qevo): # But it is automatically generated by cython. # I am not sure if the order is constant across version/OS. state = qevo._getstate() - children = state.pop("elements"), + children = (state.pop("elements"),) return children, state @@ -43,9 +41,7 @@ def _QobjEvo_unflatten(aux_data, children): return QobjEvo._restore(children[0], **aux_data) -tree_util.register_pytree_node( - QobjEvo, _QobjEvo_flatten, _QobjEvo_unflatten -) +tree_util.register_pytree_node(QobjEvo, _QobjEvo_flatten, _QobjEvo_unflatten) ############################################# @@ -106,7 +102,7 @@ def _MapElement_flatten(element): def _MapElement_unflatten(aux_data, children): base, coeff = children - transform, = aux_data + (transform,) = aux_data return _MapElement(base, transform, coeff) diff --git a/src/qutip_jax/unary.py b/src/qutip_jax/unary.py index d3b601b..ea11d20 100644 --- a/src/qutip_jax/unary.py +++ b/src/qutip_jax/unary.py @@ -1,14 +1,21 @@ import qutip from .jaxarray import JaxArray -from .binops import mul_jaxarray +from .jaxdia import JaxDia +from .binops import mul_jaxarray, mul_jaxdia import jax.scipy.linalg as linalg from jax import jit +import numpy as np +import jax.numpy as jnp __all__ = [ "neg_jaxarray", + "neg_jaxdia", "adjoint_jaxarray", + "adjoint_jaxdia", "transpose_jaxarray", + "transpose_jaxdia", "conj_jaxarray", + "conj_jaxdia", "inv_jaxarray", "expm_jaxarray", "project_jaxarray", @@ -28,6 +35,12 @@ def neg_jaxarray(matrix): return mul_jaxarray(matrix, -1) +@jit +def neg_jaxdia(matrix): + """Unary element-wise negation of a matrix.""" + return mul_jaxdia(matrix, -1) + + @jit def adjoint_jaxarray(matrix): """Hermitian adjoint (matrix conjugate transpose).""" @@ -44,6 +57,50 @@ def conj_jaxarray(matrix): return JaxArray._fast_constructor(matrix._jxa.conj(), matrix.shape) +@jit +def conj_jaxdia(matrix): + """Element-wise conjugation of a matrix.""" + return JaxDia._fast_constructor( + matrix.offsets, matrix.data.conj(), matrix.shape + ) + + +@jit +def transpose_jaxdia(matrix): + """Transpose of a matrix.""" + new_offset = tuple(-diag for diag in matrix.offsets[::-1]) + new_data = jnp.zeros( + (matrix.data.shape[0], matrix.shape[0]), dtype=jnp.complex128 + ) + for i, diag in enumerate(matrix.offsets): + old_start = max(0, diag) + old_end = min(matrix.shape[1], matrix.shape[0] + diag) + new_start = max(0, -diag) + new_end = min(matrix.shape[0], matrix.shape[1] - diag) + new_data = new_data.at[-i - 1, new_start:new_end].set( + matrix.data[i, old_start:old_end] + ) + return JaxDia._fast_constructor(new_offset, new_data, matrix.shape[::-1]) + + +@jit +def adjoint_jaxdia(matrix): + """Hermitian adjoint (matrix conjugate transpose).""" + new_offset = tuple(-diag for diag in matrix.offsets[::-1]) + new_data = jnp.zeros( + (matrix.data.shape[0], matrix.shape[0]), dtype=jnp.complex128 + ) + for i, diag in enumerate(matrix.offsets): + old_start = max(0, diag) + old_end = min(matrix.shape[1], matrix.shape[0] + diag) + new_start = max(0, -diag) + new_end = min(matrix.shape[0], matrix.shape[1] - diag) + new_data = new_data.at[-i - 1, new_start:new_end].set( + matrix.data[i, old_start:old_end].conj() + ) + return JaxDia._fast_constructor(new_offset, new_data, matrix.shape[::-1]) + + def expm_jaxarray(matrix): """Matrix exponential `e**A` for a matrix `A`.""" _check_square_shape(matrix) @@ -85,6 +142,7 @@ def project_jaxarray(state): qutip.data.neg.add_specialisations( [ (JaxArray, JaxArray, neg_jaxarray), + (JaxDia, JaxDia, neg_jaxdia), ] ) @@ -92,6 +150,7 @@ def project_jaxarray(state): qutip.data.adjoint.add_specialisations( [ (JaxArray, JaxArray, adjoint_jaxarray), + (JaxDia, JaxDia, adjoint_jaxdia), ] ) @@ -99,6 +158,7 @@ def project_jaxarray(state): qutip.data.transpose.add_specialisations( [ (JaxArray, JaxArray, transpose_jaxarray), + (JaxDia, JaxDia, transpose_jaxdia), ] ) @@ -106,6 +166,7 @@ def project_jaxarray(state): qutip.data.conj.add_specialisations( [ (JaxArray, JaxArray, conj_jaxarray), + (JaxDia, JaxDia, conj_jaxdia), ] ) diff --git a/tests/conftest.py b/tests/conftest.py index 66ca23e..2fd8ab2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,20 @@ from jax import random import qutip_jax +import numpy as np key = random.PRNGKey(1234) + def _random_cplx(shape): return qutip_jax.JaxArray( - random.normal(key, shape) + 1j*random.normal(key, shape) + random.normal(key, shape) + 1j * random.normal(key, shape) ) + + +def _random_dia(shape): + offsets = np.arange(-shape[0] + 1, shape[1]) + np.random.shuffle(offsets) + offsets = tuple(offsets[: min(3, shape[0] + shape[1] - 1)]) + data_shape = len(offsets), shape[1] + data = np.random.rand(*data_shape) + 1j * np.random.rand(*data_shape) + return qutip_jax.JaxDia((data, offsets), shape=shape) diff --git a/tests/test_binops.py b/tests/test_binops.py index 4449a97..3e7c845 100644 --- a/tests/test_binops.py +++ b/tests/test_binops.py @@ -1,15 +1,18 @@ import qutip.tests.core.data.test_mathematics as testing import qutip_jax +from qutip_jax import JaxArray, JaxDia import pytest from . import conftest testing._ALL_CASES = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } testing._RANDOM = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } @@ -17,10 +20,16 @@ class TestAdd(testing.TestAdd): specialisations = [ pytest.param( qutip_jax.add_jaxarray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - ) + JaxArray, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.add_jaxdia, + JaxDia, + JaxDia, + JaxDia, + ), ] @@ -28,16 +37,23 @@ class TestSub(testing.TestSub): specialisations = [ pytest.param( qutip_jax.sub_jaxarray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - ) + JaxArray, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.sub_jaxdia, + JaxDia, + JaxDia, + JaxDia, + ), ] class TestMul(testing.TestMul): specialisations = [ - pytest.param(qutip_jax.mul_jaxarray, qutip_jax.JaxArray, qutip_jax.JaxArray) + pytest.param(qutip_jax.mul_jaxarray, JaxArray, JaxArray), + pytest.param(qutip_jax.mul_jaxdia, JaxDia, JaxDia), ] @@ -45,10 +61,28 @@ class TestMatmul(testing.TestMatmul): specialisations = [ pytest.param( qutip_jax.matmul_jaxarray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - ) + JaxArray, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.matmul_jaxdia, + JaxDia, + JaxDia, + JaxDia, + ), + pytest.param( + qutip_jax.matmul_jaxdia_jaxarray_jaxarray, + JaxDia, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.matmul_jaxarray_jaxdia_jaxarray, + JaxArray, + JaxDia, + JaxArray, + ), ] @@ -56,10 +90,16 @@ class TestMultiply(testing.TestMultiply): specialisations = [ pytest.param( qutip_jax.multiply_jaxarray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - ) + JaxArray, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.multiply_jaxdia, + JaxDia, + JaxDia, + JaxDia, + ), ] @@ -67,14 +107,18 @@ class TestKron(testing.TestKron): specialisations = [ pytest.param( qutip_jax.kron_jaxarray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - qutip_jax.JaxArray, - ) + JaxArray, + JaxArray, + JaxArray, + ), + pytest.param( + qutip_jax.kron_jaxdia, + JaxDia, + JaxDia, + JaxDia, + ), ] class TestPow(testing.TestPow): - specialisations = [ - pytest.param(qutip_jax.pow_jaxarray, qutip_jax.JaxArray, qutip_jax.JaxArray) - ] + specialisations = [pytest.param(qutip_jax.pow_jaxarray, JaxArray, JaxArray)] diff --git a/tests/test_convert.py b/tests/test_convert.py new file mode 100644 index 0000000..431a114 --- /dev/null +++ b/tests/test_convert.py @@ -0,0 +1,81 @@ +import jax +import jax.numpy as jnp +from jax import jit + +import numpy as np +from numpy.testing import assert_array_almost_equal +import pytest + +import qutip_jax +from qutip_jax import JaxArray, JaxDia +import qutip + + +@pytest.mark.parametrize( + "to_", + [ + pytest.param(qutip.data.Dense, id="to Dense type"), + pytest.param(qutip.data.CSR, id="to CSR type"), + pytest.param(JaxDia, id="to JaxDia type"), + ], +) +@pytest.mark.parametrize( + "back_", + [ + pytest.param("jax", id="from str (1)"), + pytest.param("JaxArray", id="from str (2)"), + pytest.param(JaxArray, id="from type"), + ], +) +def test_convert_explicit_jaxarray(to_, back_): + """Test that it can convert to and from other types""" + arr = JaxArray(jnp.arange(0, 3, 11)) + converted = qutip.data.to(to_, arr) + assert isinstance(converted, to_) + back = qutip.data.to[back_](converted) + assert isinstance(back, JaxArray) + assert back == arr + + +@pytest.mark.parametrize( + "to_", + [ + pytest.param(qutip.data.Dense, id="to Dense type"), + pytest.param(qutip.data.CSR, id="to CSR type"), + pytest.param(qutip.data.Dia, id="to Dia type"), + pytest.param(JaxArray, id="to JaxArray type"), + ], +) +@pytest.mark.parametrize( + "back_", + [ + pytest.param("JaxDia", id="from str"), + pytest.param(JaxDia, id="from type"), + ], +) +def test_convert_explicit_jaxdia(to_, back_): + """Test that it can convert to and from other types""" + arr = JaxDia((jnp.arange(3), (0,)), shape=(3, 3)) + converted = qutip.data.to(to_, arr) + assert isinstance(converted, to_) + back = qutip.data.to[back_](converted) + assert isinstance(back, JaxDia) + assert back == arr + + +def test_convert(): + """Tests if the conversions from Qobj to JaxArray work""" + ones = jnp.ones((3, 3)) + qobj = qutip.Qobj(ones) + prod = qobj * jnp.array(0.5) + assert_array_almost_equal(prod.data.to_array(), ones * jnp.array([0.5])) + + sx = qutip.qeye(5, dtype="csr") + assert isinstance(sx.data, qutip.core.data.CSR) + assert isinstance(sx.to("jax").data, JaxArray) + + sx = qutip.qeye(5, dtype="JaxArray") + assert isinstance(sx.data, JaxArray) + + sx = qutip.qeye(5, dtype="JaxDia") + assert isinstance(sx.data, JaxDia) diff --git a/tests/test_create.py b/tests/test_create.py index fda015c..8d02bd9 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,4 +1,4 @@ -from qutip_jax import JaxArray +from qutip_jax import JaxArray, JaxDia from qutip_jax.create import * import pytest import numpy as np @@ -15,11 +15,12 @@ pytest.param((2, 10), id="wide"), ], ) -def test_zeros(shape): +@pytest.mark.parametrize("func", [zeros_jaxarray, zeros_jaxdia]) +def test_zeros(func, shape): """Tests the function that creates zero JAX arrays.""" - base = zeros_jaxarray(shape[0], shape[1]) + base = func(shape[0], shape[1]) nd = base.to_array() - assert isinstance(base, JaxArray) + # assert isinstance(base, JaxArray) assert base.shape == shape assert nd.shape == shape assert np.count_nonzero(nd) == 0 @@ -31,19 +32,16 @@ def test_zeros(shape): [None, 2, -0.1, 1.5, 1.5 + 1j], ids=["none", "int", "negative", "float", "complex"], ) -def test_identity(dimension, scale): +@pytest.mark.parametrize("func", [identity_jaxarray, identity_jaxdia]) +def test_identity(func, dimension, scale): """Tests the function that creates identity JAX arrays.""" # scale=None is testing that the default value returns the identity. - base = ( - identity_jaxarray(dimension) - if scale is None - else identity_jaxarray(dimension, scale) - ) + base = func(dimension) if scale is None else func(dimension, scale) nd = base.to_array() numpy_test = np.eye(dimension, dtype=np.complex128) if scale is not None: numpy_test *= scale - assert isinstance(base, JaxArray) + # assert isinstance(base, JaxArray) assert base.shape == (dimension, dimension) assert np.count_nonzero(nd - numpy_test) == 0 @@ -56,7 +54,9 @@ def test_identity(dimension, scale): pytest.param([[0.2j, 0.3]], None, None, id="main diagonal list"), pytest.param([0.2j, 0.3], 2, None, id="superdiagonal"), pytest.param([0.2j, 0.3], -2, None, id="subdiagonal"), - pytest.param([[0.2, 0.3, 0.4], [0.1, 0.9]], [-2, 3], None, id="two diagonals"), + pytest.param( + [[0.2, 0.3, 0.4], [0.1, 0.9]], [-2, 3], None, id="two diagonals" + ), pytest.param([1, 2, 3], 0, (3, 5), id="main wide"), pytest.param([1, 2, 3], 0, (5, 3), id="main tall"), pytest.param([[1, 2, 3], [4, 5]], [-1, -2], (4, 8), id="two wide sub"), @@ -68,16 +68,25 @@ def test_identity(dimension, scale): [[1, 2, 3, 4], [4, 5, 4j, 1j]], [-1, -2], (8, 4), id="two tall sub" ), pytest.param( - [[1, 2, 3], [4, 5, 6], [1, 2]], [1, -1, -2], (4, 4), id="out of order" + [[1, 2, 3], [4, 5, 6], [1, 2]], + [1, -1, -2], + (4, 4), + id="out of order", ), pytest.param( - [[1, 2, 3], [4, 5, 6], [1, 2]], [1, 1, -2], (4, 4), id="sum duplicates" + [[1, 2, 3], [4, 5, 6], [1, 2]], + [1, 1, -2], + (4, 4), + id="sum duplicates", ), ], ) -def test_diags(diagonals, offsets, shape): +@pytest.mark.parametrize( + ["func", "dtype"], [(diag_jaxarray, JaxArray), (diag_jaxdia, JaxDia)] +) +def test_diags(func, dtype, diagonals, offsets, shape): """Tests the function that creates diagonal JAX arrays.""" - base = diag_jaxarray(diagonals, offsets, shape) + base = func(diagonals, offsets, shape) # Build numpy version test. if not isinstance(diagonals[0], list): diagonals = [diagonals] @@ -88,8 +97,9 @@ def test_diags(diagonals, offsets, shape): test = np.zeros(shape, dtype=np.complex128) for diagonal, offset in zip(diagonals, offsets): test[np.where(np.eye(*shape, k=offset) == 1)] += diagonal - assert isinstance(base, JaxArray) + # assert isinstance(base, JaxArray) assert base.shape == shape + assert isinstance(base, dtype) np.testing.assert_allclose(base.to_array(), test, rtol=1e-10) @@ -106,16 +116,17 @@ def test_diags(diagonals, offsets, shape): pytest.param((2, 10), (1, 5), 10, id="wide"), ], ) -def test_one_element(shape, position, value): +@pytest.mark.parametrize("func", [one_element_jaxarray, one_element_jaxdia]) +def test_one_element(func, shape, position, value): """Tests the function that creates single element JAX arrays.""" test = np.zeros(shape, dtype=np.complex128) if value is None: - base = one_element_jaxarray(shape, position) + base = func(shape, position) test[position] = 1.0 + 0.0j else: - base = one_element_jaxarray(shape, position, value) + base = func(shape, position, value) test[position] = value - assert isinstance(base, JaxArray) + # assert isinstance(base, JaxArray) assert base.shape == shape assert np.allclose(base.to_array(), test, atol=1e-10) @@ -129,8 +140,11 @@ def test_one_element(shape, position, value): pytest.param((10, 10), (5, -1), 2.0, id="outside neg"), ], ) -def test_one_element_error(shape, position, value): +@pytest.mark.parametrize("func", [one_element_jaxarray, one_element_jaxdia]) +def test_one_element_error(func, shape, position, value): """Tests for wrong inputs to the one_element_jaxarray function.""" with pytest.raises(ValueError) as exc: - base = one_element_jaxarray(shape, position, value) - assert str(exc.value).startswith("Position of the elements" " out of bound: ") + base = func(shape, position, value) + assert str(exc.value).startswith( + "Position of the elements" " out of bound: " + ) diff --git a/tests/test_eigen.py b/tests/test_eigen.py index 1056b5a..c5a164a 100644 --- a/tests/test_eigen.py +++ b/tests/test_eigen.py @@ -19,18 +19,21 @@ def test_eigen_known_oper(): np.testing.assert_allclose(spvals, expected, atol=1e-13) -@pytest.mark.parametrize(["rand"], [ - pytest.param(qutip.rand_herm, id="hermitian"), - pytest.param(qutip.rand_unitary, id="non-hermitian"), -]) -@pytest.mark.parametrize("order", ['low', 'high']) +@pytest.mark.parametrize( + ["rand"], + [ + pytest.param(qutip.rand_herm, id="hermitian"), + pytest.param(qutip.rand_unitary, id="non-hermitian"), + ], +) +@pytest.mark.parametrize("order", ["low", "high"]) def test_eigen_rand_oper(rand, order): mat = rand(10, dtype="jax").data isherm = rand is qutip.rand_herm kw = {"isherm": isherm, "sort": order} spvals, spvecs = qutip_jax.eigs_jaxarray(mat, vecs=True, **kw) sp_energies = qutip_jax.eigs_jaxarray(mat, vecs=False, **kw) - if order == 'low': + if order == "low": assert np.all(np.diff(spvals).real >= 0) else: assert np.all(np.diff(spvals).real <= 0) @@ -38,11 +41,14 @@ def test_eigen_rand_oper(rand, order): np.testing.assert_allclose(spvals, sp_energies, atol=5e-15) -@pytest.mark.parametrize("rand", [ - pytest.param(qutip.rand_herm, id="hermitian"), - pytest.param(qutip.rand_unitary, id="non-hermitian"), -]) -@pytest.mark.parametrize("order", ['low', 'high']) +@pytest.mark.parametrize( + "rand", + [ + pytest.param(qutip.rand_herm, id="hermitian"), + pytest.param(qutip.rand_unitary, id="non-hermitian"), + ], +) +@pytest.mark.parametrize("order", ["low", "high"]) @pytest.mark.parametrize("N", [1, 5, 8, 9]) def test_eigvals_parameter(rand, order, N): mat = rand(10, dtype="jax").data @@ -54,7 +60,7 @@ def test_eigvals_parameter(rand, order, N): assert np.allclose(all_spvals[:N], spvals) assert np.allclose(all_spvals[:N], sp_energies) assert_eigen_set(mat._jxa, spvals, spvecs._jxa) - if order == 'low': + if order == "low": assert np.all(np.diff(spvals).real >= 0) else: assert np.all(np.diff(spvals).real <= 0) diff --git a/tests/test_jaxarray.py b/tests/test_jaxarray.py index 5c1b391..330e79e 100644 --- a/tests/test_jaxarray.py +++ b/tests/test_jaxarray.py @@ -1,12 +1,7 @@ -import jax import jax.numpy as jnp from jax import jit - import numpy as np -from numpy.testing import assert_array_almost_equal import pytest - -import qutip_jax from qutip_jax.jaxarray import JaxArray import qutip @@ -15,7 +10,7 @@ "backend", [pytest.param(jnp, id="jnp"), pytest.param(np, id="np")], ) -@pytest.mark.parametrize("shape", [(1,1), (10,), (3, 3), (1, 10)]) +@pytest.mark.parametrize("shape", [(1, 1), (10,), (3, 3), (1, 10)]) @pytest.mark.parametrize("dtype", [int, float, complex]) def test_init(backend, shape, dtype): """Tests creation of JaxArrays from NumPy and JAX-Numpy arrays""" @@ -23,16 +18,17 @@ def test_init(backend, shape, dtype): array = backend.array(array) jax_a = JaxArray(array) assert isinstance(jax_a, JaxArray) - assert jax_a._jxa.dtype == jax.numpy.complex128 + assert jax_a._jxa.dtype == jnp.complex128 if len(shape) == 1: shape = shape + (1,) assert jax_a.shape == shape -@pytest.mark.parametrize("build", +@pytest.mark.parametrize( + "build", [ pytest.param(qutip.Qobj, id="Qobj"), - pytest.param(qutip.data.create, id="create") + pytest.param(qutip.data.create, id="create"), ], ) def test_create(build): @@ -53,47 +49,4 @@ def func(arr): return arr.trace() tr = func(arr) - assert isinstance(tr, jax.Array) - - -@pytest.mark.parametrize("to_", - [ - pytest.param(qutip.data.Dense, id="to Dense type"), - pytest.param(qutip.data.CSR, id="to CSR type"), - ], -) -@pytest.mark.parametrize("back_", - [ - pytest.param("jax", id="from str (1)"), - pytest.param("JaxArray", id="from str (2)"), - pytest.param(JaxArray, id="from type"), - ], -) -def test_convert_explicit(to_, back_): - """ Test that it can convert to and from other types """ - arr = JaxArray(jnp.linspace(0, 3, 11)) - converted = qutip.data.to(to_, arr) - assert isinstance(converted, to_) - back = qutip.data.to[back_](converted) - assert isinstance(back, JaxArray) - - -def test_convert(): - """Tests if the conversions from Qobj to JaxArray work""" - ones = jnp.ones((3, 3)) - qobj = qutip.Qobj(ones) - prod = qobj * jnp.array(0.5) - assert_array_almost_equal(prod.data.to_array(), ones * jnp.array(0.5)) - - sx = qutip.qeye(5, dtype="csr") - assert isinstance(sx.data, qutip.core.data.CSR) - assert isinstance(sx.to('jax').data, JaxArray) - - sx = qutip.qeye(5, dtype="JaxArray") - assert isinstance(sx.data, JaxArray) - - -def test_extract(): - ones = jnp.ones((3, 3)) - qobj = qutip.Qobj(ones) - assert isinstance(qobj.data_as("JaxArray"), jax.Array) + assert isinstance(tr, jnp.ndarray) diff --git a/tests/test_jaxdia.py b/tests/test_jaxdia.py new file mode 100644 index 0000000..8283952 --- /dev/null +++ b/tests/test_jaxdia.py @@ -0,0 +1,49 @@ +import jax.numpy as jnp +from jax import jit +import numpy as np +import pytest +from qutip_jax.jaxdia import JaxDia, tidyup_jaxdia, clean_dia +import qutip + + +@pytest.mark.parametrize( + "backend", + [pytest.param(jnp, id="jnp"), pytest.param(np, id="np")], +) +@pytest.mark.parametrize("shape", [(1, 1), (10, 1), (3, 3), (1, 10)]) +@pytest.mark.parametrize("dtype", [int, float, complex]) +def test_init(backend, shape, dtype): + """Tests creation of JaxArrays from NumPy and JAX-Numpy arrays""" + array = np.array(np.random.rand(1, shape[1]), dtype=dtype) + array = backend.array(array) + jax_a = JaxDia((array, (0,)), shape=shape) + assert isinstance(jax_a, JaxDia) + assert jax_a.data.dtype == jnp.complex128 + assert jax_a.shape == shape + + +def test_jit(): + """Tests JIT of JaxArray methods""" + arr = JaxDia((jnp.arange(3), (0,)), shape=(3, 3)) + + # Some function of that we would like to JIT. + @jit + def func(arr): + return arr.trace() + + tr = func(arr) + assert isinstance(tr, jnp.ndarray) + + +def test_tidyup(): + big = JaxDia((jnp.arange(3), (0,)), shape=(3, 3)) + small = JaxDia((jnp.arange(3) * 1e-10, (1,)), shape=(3, 3)) + data = big + small + assert data.num_diags == 2 + assert tidyup_jaxdia(data, 1e-5).num_diags == 1 + + +def test_clean(): + data = clean_dia(JaxDia((jnp.ones((2, 3)), (0, -1)), shape=(3, 3))) + assert data.offsets == (-1, 0) + assert data.data[0, 2] == 0.0 diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4e92462..b18ff70 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -29,8 +29,12 @@ def test_mathematically_correct_JaxArray(self, method): test1 = _data.solve(A, b, method) assert test.shape == expected.shape - np.testing.assert_allclose(test.to_array(), expected, atol=1e-7, rtol=1e-7) - np.testing.assert_allclose(test1.to_array(), expected, atol=1e-7, rtol=1e-7) + np.testing.assert_allclose( + test.to_array(), expected, atol=1e-7, rtol=1e-7 + ) + np.testing.assert_allclose( + test1.to_array(), expected, atol=1e-7, rtol=1e-7 + ) def test_incorrect_shape_non_square(self): key = jax.random.PRNGKey(1) diff --git a/tests/test_measurements.py b/tests/test_measurements.py index d7f4caf..139d232 100644 --- a/tests/test_measurements.py +++ b/tests/test_measurements.py @@ -8,21 +8,29 @@ testing._ALL_CASES = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + qutip_jax.JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } testing._RANDOM = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + qutip_jax.JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } class TestExpect(testing_expect.TestExpect): specialisations = [ + pytest.param( + qutip_jax.expect_jaxdia_jaxarray, + qutip_jax.JaxDia, + qutip_jax.JaxArray, + object, + ), pytest.param( qutip_jax.expect_jaxarray, qutip_jax.JaxArray, qutip_jax.JaxArray, object, - ) + ), ] @@ -33,7 +41,13 @@ class TestExpectSuper(testing_expect.TestExpectSuper): qutip_jax.JaxArray, qutip_jax.JaxArray, object, - ) + ), + pytest.param( + qutip_jax.expect_super_jaxdia_jaxarray, + qutip_jax.JaxDia, + qutip_jax.JaxArray, + object, + ), ] diff --git a/tests/test_ode.py b/tests/test_ode.py index a959b10..53b4946 100644 --- a/tests/test_ode.py +++ b/tests/test_ode.py @@ -1,13 +1,20 @@ from qutip import ( - coefficient, num, destroy, create, sesolve, MESolver, basis, settings, QobjEvo + coefficient, + num, + destroy, + create, + sesolve, + MESolver, + basis, + settings, + QobjEvo, + CoreOptions, ) import qutip_jax import pytest import jax import numpy as np -settings.core["default_dtype"] = "jax" - @jax.jit def fp(t, w): @@ -21,7 +28,7 @@ def fm(t, w): @jax.jit def pulse(t, A, u, sigma): - return A * jax.numpy.exp(-(t-u)**2 / sigma) / (sigma * np.pi)**0.5 + return A * jax.numpy.exp(-((t - u) ** 2) / sigma) / (sigma * np.pi) ** 0.5 @jax.jit @@ -29,14 +36,16 @@ def cte(t, A): return A -def test_ode_run(): - H = ( - num(3) - + create(3) * coefficient(fp, args={"w": 3.1415}) - + destroy(3) * coefficient(fm, args={"w": 3.1415}) - ) +@pytest.mark.parametrize("dtype", ("jax", "jaxdia")) +def test_ode_run(dtype): + with CoreOptions(default_dtype=dtype): + H = ( + num(3) + + create(3) * coefficient(fp, args={"w": 3.1415}) + + destroy(3) * coefficient(fm, args={"w": 3.1415}) + ) - ket = basis(3) + ket = basis(3, dtype="jax") result = sesolve( H, ket, [0, 1, 2], e_ops=[num(3)], options={"method": "diffrax"} @@ -48,16 +57,18 @@ def test_ode_run(): np.testing.assert_allclose(result.expect[0], expected.expect[0], atol=1e-6) -def test_ode_step(): - H = ( - num(3) - + create(3) * coefficient(fp, args={"w": 3.1415}) - + destroy(3) * coefficient(fm, args={"w": 3.1415}) - ) +@pytest.mark.parametrize("dtype", ("jax", "jaxdia")) +def test_ode_step(dtype): + with CoreOptions(default_dtype=dtype): + H = ( + num(3) + + create(3) * coefficient(fp, args={"w": 3.1415}) + + destroy(3) * coefficient(fm, args={"w": 3.1415}) + ) - c_ops = [destroy(3)] + c_ops = [destroy(3)] - ket = basis(3) + ket = basis(3, dtype="jax") solver = MESolver(H, c_ops, options={"method": "diffrax"}) ref_solver = MESolver(H, c_ops, options={"method": "adams"}) @@ -68,21 +79,28 @@ def test_ode_step(): assert (solver.step(1) - ref_solver.step(1)).norm() <= 1e-6 -def test_ode_grad(): - H = num(10) - c_ops = [QobjEvo([destroy(10), cte], args={"A": 1.0})] +@pytest.mark.parametrize("dtype", ("jax", "jaxdia")) +def test_ode_grad(dtype): + with CoreOptions(default_dtype=dtype): + H = num(10) + c_ops = [QobjEvo([destroy(10), cte], args={"A": 1.0})] options = {"method": "diffrax", "normalize_output": False} solver = MESolver(H, c_ops, options=options) def f(solver, t, A): - result = solver.run(basis(10, 9), [0, t], e_ops=num(10), args={"A": A}) + result = solver.run( + basis(10, 9, dtype="jax"), + [0, t], + e_ops=num(10, dtype="jaxdia"), + args={"A": A} + ) return result.e_data[0][-1].real df = jax.value_and_grad(f, argnums=[1, 2]) val, (dt, dA) = df(solver, 0.2, 0.5) - assert val == pytest.approx(9 * np.exp(- 0.2 * 0.5)) - assert dt == pytest.approx(9 * np.exp(- 0.2 * 0.5) * -0.5) - assert dA == pytest.approx(9 * np.exp(- 0.2 * 0.5) * -0.2) + assert val == pytest.approx(9 * np.exp(-0.2 * 0.5**2)) + assert dt == pytest.approx(9 * np.exp(-0.2 * 0.5**2) * -0.5**2) + assert dA == pytest.approx(9 * np.exp(-0.2 * 0.5**2) * -0.2) diff --git a/tests/test_permute.py b/tests/test_permute.py index 18751e2..9219c6b 100644 --- a/tests/test_permute.py +++ b/tests/test_permute.py @@ -1,7 +1,8 @@ import qutip from qutip_jax import JaxArray -class TestPermute(): + +class TestPermute: def test_psi(self): A = qutip.basis(3, 0, dtype="jax") B = qutip.basis(5, 4, dtype="jax") @@ -55,8 +56,10 @@ def test_oper(self): rho_vec_bra = qutip.operator_to_vector(rho).dag() rho2_vec_bra = rho_vec_bra.permute([[1, 0, 2], [4, 3, 5]]) - assert (rho2_vec_bra - == qutip.operator_to_vector(qutip.tensor(B, A, C)).dag()) + assert ( + rho2_vec_bra + == qutip.operator_to_vector(qutip.tensor(B, A, C)).dag() + ) assert isinstance(rho2_vec_bra.data, JaxArray) def test_super(self): diff --git a/tests/test_properties.py b/tests/test_properties.py index d2a8791..1296bf4 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,59 +1,92 @@ +import qutip.tests.core.data.test_mathematics as testing import qutip_jax import pytest import numbers import numpy as np +import qutip.core.data as _data from . import conftest +testing._ALL_CASES = { + qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + qutip_jax.JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], +} +testing._RANDOM = { + qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + qutip_jax.JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], +} + + @pytest.mark.parametrize("N", (1, 10)) -def test_isherm(N): - A = conftest._random_cplx((N, N)) +@pytest.mark.parametrize( + ["func", "maker"], + [ + (qutip_jax.isherm_jaxarray, conftest._random_cplx), + (qutip_jax.isherm_jaxdia, conftest._random_dia), + ], +) +def test_isherm(func, maker, N): + A = maker((N, N)) A = A + A.adjoint() - assert qutip_jax.isherm_jaxarray(A) + assert func(A) -@pytest.mark.parametrize("shape", - [(1, 10), (10, 1), (2, 5), (5, 2)] +@pytest.mark.parametrize("shape", [(1, 10), (10, 1), (2, 5), (5, 2)]) +@pytest.mark.parametrize( + ["func", "maker"], + [ + (qutip_jax.isherm_jaxarray, conftest._random_cplx), + (qutip_jax.isherm_jaxdia, conftest._random_dia), + ], ) -def test_isherm_non_square(shape): - A = conftest._random_cplx(shape) - assert not qutip_jax.isherm_jaxarray(A) +def test_isherm_non_square(func, maker, shape): + A = maker(shape) + assert not func(A) -def test_isherm_cplxdiag(): +def test_isherm_nonherm(): A = conftest._random_cplx((10, 10)) - A = A + qutip_jax.JaxArray(np.diag(np.arange(10)*1j)) + A = A + qutip_jax.JaxArray(np.diag(np.arange(10) * 1j)) assert not qutip_jax.isherm_jaxarray(A) -def test_isherm_nonherm(): +def test_isherm_nonherm_dia(): A = conftest._random_cplx((10, 10)) - A = A + A.adjoint() - A = A + qutip_jax.JaxArray(np.diag(np.arange(9), 1)) + A = A + qutip_jax.identity_jaxdia(10) * 1j assert not qutip_jax.isherm_jaxarray(A) -def test_isherm_tol(): - A = conftest._random_cplx((10, 10)) +@pytest.mark.parametrize( + ["func", "maker"], + [ + (qutip_jax.isherm_jaxarray, conftest._random_cplx), + (qutip_jax.isherm_jaxdia, conftest._random_dia), + ], +) +def test_isherm_tol(func, maker): + A = maker((10, 10)) A = A + A.adjoint() - A = A + conftest._random_cplx((10, 10)) * 1e-10 - assert qutip_jax.isherm_jaxarray(A, 1e-5) - assert not qutip_jax.isherm_jaxarray(A, 1e-15) - - -@pytest.mark.parametrize("shape", - [(1, 10), (10, 1), (2, 5), (5, 2), (5, 5)] + A = A + maker((10, 10)) * 1e-10 + assert func(A, 1e-5) + assert not func(A, 1e-15) + + +@pytest.mark.parametrize("shape", [(1, 10), (10, 1), (2, 5), (5, 2), (5, 5)]) +@pytest.mark.parametrize( + ["func", "maker"], + [ + (qutip_jax.iszero_jaxarray, conftest._random_cplx), + (qutip_jax.iszero_jaxdia, conftest._random_dia), + ], ) -def test_iszero(shape): - A = conftest._random_cplx(shape) * 1e-10 - assert qutip_jax.iszero_jaxarray(A, 1e-5) - assert not qutip_jax.isherm_jaxarray(A, 1e-15) +def test_iszero(func, maker, shape): + A = maker(shape) * 1e-10 + assert func(A, 1e-5) + assert not func(A, 1e-15) -@pytest.mark.parametrize("shape", - [(10, 1), (2, 5), (5, 2), (5, 5)] -) +@pytest.mark.parametrize("shape", [(10, 1), (2, 5), (5, 2), (5, 5)]) def test_isdiag(shape): mat = np.zeros(shape) # empty matrices are diagonal @@ -64,3 +97,46 @@ def test_isdiag(shape): mat[1, 0] = 1 assert not qutip_jax.isdiag_jaxarray(qutip_jax.JaxArray(mat)) + + +@pytest.mark.parametrize("shape", [(10, 1), (2, 5), (5, 2), (5, 5)]) +def test_isdiag_dia(shape): + mat = np.zeros(shape) + # empty matrices are diagonal + assert qutip_jax.isdiag_jaxdia(_data.to("jaxdia", qutip_jax.JaxArray(mat))) + + mat[0, 0] = 1 + assert qutip_jax.isdiag_jaxdia(_data.to("jaxdia", qutip_jax.JaxArray(mat))) + + mat[1, 0] = 1 + assert not qutip_jax.isdiag_jaxdia( + _data.to("jaxdia", qutip_jax.JaxArray(mat)) + ) + + +class TestTrace(testing.TestTrace): + specialisations = [ + pytest.param( + qutip_jax.trace_jaxarray, + qutip_jax.JaxArray, + qutip_jax.JaxArray, + object, + ), + pytest.param( + qutip_jax.trace_jaxdia, + qutip_jax.JaxDia, + qutip_jax.JaxDia, + object, + ), + ] + + +class TestTrace_oper_ket(testing.TestTrace_oper_ket): + specialisations = [ + pytest.param( + qutip_jax.trace_oper_ket_jaxarray, + qutip_jax.JaxArray, + qutip_jax.JaxArray, + object, + ) + ] diff --git a/tests/test_unary.py b/tests/test_unary.py index 88294f0..cccf7f3 100644 --- a/tests/test_unary.py +++ b/tests/test_unary.py @@ -1,6 +1,6 @@ import qutip.tests.core.data.test_mathematics as testing import qutip_jax -from qutip_jax import JaxArray +from qutip_jax import JaxArray, JaxDia import pytest from qutip.core import data @@ -8,44 +8,49 @@ testing._ALL_CASES = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } testing._RANDOM = { - qutip_jax.JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)] + JaxArray: lambda shape: [lambda: conftest._random_cplx(shape)], + JaxDia: lambda shape: [lambda: conftest._random_dia(shape)], } class TestNeg(testing.TestNeg): specialisations = [ pytest.param(qutip_jax.neg_jaxarray, JaxArray, JaxArray), + pytest.param(qutip_jax.neg_jaxdia, JaxDia, JaxDia), ] class TestAdjoint(testing.TestAdjoint): specialisations = [ pytest.param(qutip_jax.adjoint_jaxarray, JaxArray, JaxArray), - pytest.param(lambda mat: mat.adjoint(), JaxArray, JaxArray) + pytest.param(lambda mat: mat.adjoint(), JaxArray, JaxArray), + pytest.param(qutip_jax.adjoint_jaxdia, JaxDia, JaxDia), ] class TestConj(testing.TestConj): specialisations = [ pytest.param(qutip_jax.conj_jaxarray, JaxArray, JaxArray), - pytest.param(lambda mat: mat.conj(), JaxArray, JaxArray) + pytest.param(lambda mat: mat.conj(), JaxArray, JaxArray), + pytest.param(qutip_jax.conj_jaxdia, JaxDia, JaxDia), ] class TestTranspose(testing.TestTranspose): specialisations = [ pytest.param(qutip_jax.transpose_jaxarray, JaxArray, JaxArray), - pytest.param(lambda mat: mat.transpose(), JaxArray, JaxArray) + pytest.param(lambda mat: mat.transpose(), JaxArray, JaxArray), + pytest.param(qutip_jax.transpose_jaxdia, JaxDia, JaxDia), ] class TestExpm(testing.TestExpm): specialisations = [ - pytest.param( - qutip_jax.expm_jaxarray, JaxArray, JaxArray) + pytest.param(qutip_jax.expm_jaxarray, JaxArray, JaxArray) ] @@ -54,19 +59,18 @@ def _inv_jax(matrix): return qutip_jax.inv_jaxarray( data.add( matrix, - data.diag([1.1]*matrix.shape[0], shape=matrix.shape, dtype='JaxArray') + data.diag( + [1.1] * matrix.shape[0], shape=matrix.shape, dtype="JaxArray" + ), ) ) + class TestInv(testing.TestInv): - specialisations = [ - pytest.param( - _inv_jax, JaxArray, JaxArray) - ] + specialisations = [pytest.param(_inv_jax, JaxArray, JaxArray)] class TestProject(testing.TestProject): specialisations = [ - pytest.param( - qutip_jax.project_jaxarray, JaxArray, JaxArray) + pytest.param(qutip_jax.project_jaxarray, JaxArray, JaxArray) ]