Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create tests for functions in qutip #64

Merged
merged 11 commits into from
Aug 26, 2024
44 changes: 44 additions & 0 deletions tests/test_qutip/test_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import bell_state
from qutip.entropy import (entropy_vn, entropy_linear, entropy_mutual, concurrence,
entropy_conditional, participation_ratio)
import qutip.settings
import qutip_jax

qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

with qutip.CoreOptions(default_dtype="jax"):
bell_state = bell_state("10")
bell_dm = bell_state * bell_state.dag()
dm = qutip.rand_dm([5, 5], distribution="pure")

@pytest.mark.parametrize("func, name, args", [
(entropy_vn, "entropy_vn", (bell_dm,)),
(entropy_linear, "entropy_linear", (bell_dm,)),
(concurrence, "concurrence", (bell_dm,)),
(participation_ratio, "participation_ratio", (bell_dm,))
])

def test_jit(func, name, args):
func_jit = jit(func)
result = func(*args)
result_jit = func_jit(*args)
assert jnp.abs(result - result_jit) < tol

@pytest.mark.parametrize("func, name, args", [
(entropy_vn, "entropy_vn", (bell_dm,)),
(entropy_linear, "entropy_linear", (bell_dm,)),
(entropy_mutual, "entropy_mutual", (dm, [0], [1])),
(concurrence, "concurrence", (bell_dm,)),
(entropy_conditional, "entropy_conditional", (bell_dm, 0)),
])
def test_grad(func, name, args):
func_grad = grad(func)
result_grad = func_grad(*args)
assert result_grad is not None


61 changes: 61 additions & 0 deletions tests/test_qutip/test_mcsolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax as qjax
from qutip import mcsolve
from functools import partial

# Use JAX backend for QuTiP
qjax.set_as_default()

# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)

# Test setup for gradient calculation
def setup_system(size=2):
a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')
sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')

# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a

H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]]

state = qt.basis(size, size - 1).to('jax') & qt.basis(2, 1).to('jax')

# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

# Time list
tlist = jnp.linspace(0.0, 1.0, 101)

return H, state, tlist, c_ops, e_ops

# Function for which we want to compute the gradient
def f(omega, H, state, tlist, c_ops, e_ops):
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)

return result.expect[0][-1].real

# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [2.0])
def test_gradient_mcsolve(omega_val):
H, state, tlist, c_ops, e_ops = setup_system(size=10)

# Compute the gradient with respect to omega
grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops))
gradient = grad_func(omega_val)

# Check if the gradient is not None and has the correct shape
assert gradient is not None
assert gradient.shape == ()
assert jnp.isfinite(gradient)
43 changes: 43 additions & 0 deletions tests/test_qutip/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import basis
from qutip.core.metrics import (fidelity, tracedist, bures_dist, bures_angle,
hellinger_dist, hilbert_dist)
import qutip.settings
import qutip_jax

qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

with qutip.CoreOptions(default_dtype="jax"):
rho1 = qutip.rand_dm(dimensions=5)
rho2 = qutip.rand_dm(dimensions=5)
ket_state = basis(2, 0)
oper_state = qutip.rand_dm(2)

@pytest.mark.parametrize("func, name, args", [
(fidelity, "fidelity", (rho1, rho2)),
(tracedist, "tracedist", (rho1, rho2)),
(bures_dist, "bures_dist", (rho1, rho2)),
(bures_angle, "bures_angle", (rho1, rho2)),
(hellinger_dist, "hellinger_dist", (rho1, rho2)),
(hilbert_dist, "hilbert_dist", (rho1, rho2)),
])
def test_jit(func, name, args):
func_jit = jit(func)
result = func(*args)
result_jit = func_jit(*args)
assert jnp.abs(result - result_jit) < tol

@pytest.mark.parametrize("func, name, args", [
(fidelity, "fidelity", (ket_state, oper_state)),
(tracedist, "tracedist", (rho1, rho2)),
(hellinger_dist, "hellinger_dist", (ket_state, oper_state)),
])
def test_grad(func, name, args):
func_grad = grad(func)
result = func(*args)
result_grad = func_grad(*args)
assert result_grad is not None
96 changes: 96 additions & 0 deletions tests/test_qutip/test_qobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import Qobj, basis, rand_dm, sigmax, identity, tensor, expect
import qutip.settings
import qutip_jax

# Set JAX backend for QuTiP
qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

# Initialize quantum objects for testing
with qutip.CoreOptions(default_dtype="jax"):
ket = basis(2, 0)
bra = ket.dag()
op1 = rand_dm(2)
identity_op = identity(2)
composite_op = tensor(op1, identity_op)


# Test case for Qobj functions with jax.jit
@pytest.mark.parametrize("func_name, func", [
("copy", lambda x: x.copy()),
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("eigenenergies", lambda x: x.eigenenergies()),
("expm", lambda x: x.expm()),
("inv", lambda x: x.inv()),
("matrix_element", lambda x: x.matrix_element(ket, ket)),
("norm", lambda x: x.norm()),
("overlap", lambda x: x.overlap(op1)),
("ptrace", lambda x: x.ptrace([0])),
("purity", lambda x: x.purity()),
("sinm", lambda x: x.sinm()),
("sqrtm", lambda x: x.sqrtm()),
("tr", lambda x: x.tr()),
("trans", lambda x: x.trans()),
("transform", lambda x: x.transform(identity_op)),
("unit", lambda x: x.unit())
])
def test_qobj_jit(func_name, func):
# Create a jitted function using the given Qobj function
def jit_func(op):
return func(op)

# Apply jit to the function
func_jit = jit(jit_func)
result_jit = func_jit(op1)

# Check if jit result is not None
assert result_jit is not None

@pytest.mark.parametrize("func_name, func", [
("eigenenergies", lambda x: jnp.sum(x.eigenenergies())),
("overlap", lambda x: x.overlap(Qobj(jnp.eye(x.shape[0])))),
("purity", lambda x: x.purity()),
("tr", lambda x: x.tr()),
])
def test_qobj_grad_complex(func_name, func):
def grad_func(op1):
result = func(op1)
return jnp.real(result)

# Apply grad to the function
grad_func = grad(grad_func)
grad_result = grad_func(op1)

assert grad_result is not None


@pytest.mark.parametrize("func_name, func", [
("copy", lambda x: x.copy()),
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("expm", lambda x: x.expm()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("inv", lambda x: x.inv()),
("sinm", lambda x: x.sinm()),
("trans", lambda x: x.trans()),
])
def test_qobj_grad_differentiable(func_name, func):
def grad_func(op1):
result = func(op1)
return jnp.real(result.tr())

# Apply grad to the function
grad_func = grad(grad_func)
grad_result = grad_func(op1)

assert grad_result is not None


Loading