-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Create tests for functions in qutip #64
Conversation
tests/test_qutip/test_metrics.py
Outdated
|
||
with qutip.CoreOptions(default_dtype="jax"): | ||
X = qutip.sigmax() | ||
I = qutip.qeye(2) | ||
CNOT = qutip.tensor(qutip.basis(2, 0) * qutip.basis(2, 0).dag(), I) + qutip.tensor(qutip.basis(2, 1) * qutip.basis(2, 1).dag(), X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CNOT
is not used, let's remove it.
tests/test_qutip/test_entropy.py
Outdated
print(f"{name} (original):", result) | ||
print(f"{name} (JIT):", result_jit) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No print in tests.
print(f"{name} (original):", result) | |
print(f"{name} (JIT):", result_jit) |
tests/test_qutip/test_entropy.py
Outdated
with qutip.CoreOptions(default_dtype="jax"): | ||
basis_0 = qutip.basis(2, 0) | ||
basis_1 = qutip.basis(2, 1) | ||
bell_state = (qutip.tensor(basis_0, basis_1) + qutip.tensor(basis_1, basis_0)).unit() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bell_state = (qutip.tensor(basis_0, basis_1) + qutip.tensor(basis_1, basis_0)).unit() | |
bell_state = qutip.bell_state("10") |
tests/test_qutip/test_entropy.py
Outdated
density_matrix = bell_state * bell_state.dag() | ||
dm = qutip.rand_dm([5, 5], distribution="pure") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bell_dm
and rand_dm
?
tests/test_qutip/test_entropy.py
Outdated
with qutip.CoreOptions(default_dtype="jax"): | ||
X = qutip.sigmax() | ||
I = qutip.qeye(2) | ||
CNOT = qutip.tensor(qutip.basis(2, 0) * qutip.basis(2, 0).dag(), I) + qutip.tensor(qutip.basis(2, 1) * qutip.basis(2, 1).dag(), X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not used. We have it as qutip.gates.cnot
instead of building it yourself.
tests/test_qutip/test_mcsolve.py
Outdated
# Pytest test case for gradient computation | ||
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0]) | ||
def test_gradient_mcsolve(omega_val): | ||
H, state, tlist, c_ops, e_ops = setup_system(size=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the system has a continuous variable harmonic oscillator model (defined by the ladder operator a
inside setup_system
), would it be better to use a Hilbert space size (defined by the parameter size
here) greater than or equal to 10?
tests/test_qutip/test_mcsolve.py
Outdated
|
||
# Test setup for gradient calculation | ||
def setup_system(size=2): | ||
a = qt.destroy(size).to("jax") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the Hamiltonian contains two sub-systems, I think the operators would be tensor products with corresponding identities for the other operators:
a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')
Or alternatively,
a = qt.destroy(size).to('jaxdia') & qt.qeye(2).to('jaxdia')
Same goes for sm
:
sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')
Accordingly, the initial state would be:
state = qt.basis(size, size - 1).to('jax') & qutip.basis(2, 1).to('jax')
Or alternatively,
state = qt.basis([size, 2], [size - 1, 1]).to('jax')
Pull Request Test Coverage Report for Build 10554916907Warning: This coverage report may be inaccurate.This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Details
💛 - Coveralls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall look good.
Let's just make the mcsolve
test faster.
tests/test_qutip/test_mcsolve.py
Outdated
return result.expect[0][-1].real | ||
|
||
# Pytest test case for gradient computation | ||
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is the slowest and grad functionality is not affected by the value of omega, so let's run it only once.
We could loop over options (improved_sampling
, store_states
, keep_runs_results
) or mixed state input later.
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0]) | |
@pytest.mark.parametrize("omega_val", [2.0]) |
tests/test_qutip/test_mcsolve.py
Outdated
e_ops = [a.dag() * a, sm.dag() * sm] | ||
|
||
# Time list | ||
tlist = jnp.linspace(0.0, 10.0, 101) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's speed up the tests, the actual range does not affect the validity.
tlist = jnp.linspace(0.0, 10.0, 101) | |
tlist = jnp.linspace(0.0, 1.0, 101) |
tests/test_qutip/test_mcsolve.py
Outdated
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega}) | ||
|
||
result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing args
to a solver should overwrite the existing values.
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega}) | |
result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"}) | |
result = mcsolve( | |
H, state, tlist, c_ops, e_ops, ntraj=10, | |
args={"omega": omega}, | |
options={"method": "diffrax"} | |
) |
In this task, we aim to create comprehensive tests for the functions within the QuTiP library to check their compatibility with
jax.grad
andjax.jit
. These tests will ensure the correctness and robustness of the implemented functions and will cover a wide range of scenarios.