Skip to content

Commit

Permalink
create mcsolve grad doc
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Aug 19, 2024
1 parent 2b0c79c commit f4946ce
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion doc/source/autodiff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,60 @@ should work:
result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w":w})
return result.e_data[0][1].real
jax.grad(f)(0.5, solver)
jax.grad(f)(0.5, solver)
Auto differentiation in ``mcsolve``
===================================

.. note::

The functionality demonstrated in this example is currently available only in
the development (`dev.major`) branch of QuTiP. Ensure you are using the appropriate
version if you wish to replicate these results.


.. code-block:: python
import qutip_jax as qjax
import qutip as qt
import jax
import jax.numpy as jnp
from functools import partial
from qutip import mcsolve
# Use JAX backend for QuTiP
qjax.use_jax_backend()
# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)
# Define operators and states
size = 2
a = qt.destroy(size).to("jax") # Annihilation operator
sm = qt.sigmax().to("jax") # Example spin operator
# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a
# Initialize the Hamiltonian with time-dependent coefficients
H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]]
# Define initial states
pure_state = qt.basis(size, size-1).to("jax")
mixed_state = qt.maximally_mixed_dm(size).to("jax")
state = pure_state
# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]
# Time list
tlist = jnp.linspace(0.0, 10.0, 101)
# Define the function for which we want to compute the gradient
def f(omega):
# Update the Hamiltonian with the new coefficient
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega})
# Run the Monte Carlo solver
result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"})
# Return the expectation value of the number operator at the final time
return result.expect[0][-1].real
# Compute the gradient
gradient = jax.grad(f)(1.0)
print("Gradient:", gradient)

0 comments on commit f4946ce

Please sign in to comment.