diff --git a/doc/source/autodiff.rst b/doc/source/autodiff.rst index d7c97b9..7d37117 100644 --- a/doc/source/autodiff.rst +++ b/doc/source/autodiff.rst @@ -41,4 +41,60 @@ should work: result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w":w}) return result.e_data[0][1].real - jax.grad(f)(0.5, solver) \ No newline at end of file + jax.grad(f)(0.5, solver) + + +Auto differentiation in ``mcsolve`` +=================================== + +.. note:: + + The functionality demonstrated in this example is currently available only in + the development (`dev.major`) branch of QuTiP. Ensure you are using the appropriate + version if you wish to replicate these results. + + +.. code-block:: python + import qutip_jax as qjax + import qutip as qt + import jax + import jax.numpy as jnp + from functools import partial + from qutip import mcsolve + # Use JAX backend for QuTiP + qjax.use_jax_backend() + # Define time-dependent functions + @partial(jax.jit, static_argnames=("omega",)) + def H_1_coeff(t, omega): + return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t) + # Define operators and states + size = 2 + a = qt.destroy(size).to("jax") # Annihilation operator + sm = qt.sigmax().to("jax") # Example spin operator + # Define the Hamiltonian + H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm + H_1_op = sm * a.dag() + sm.dag() * a + # Initialize the Hamiltonian with time-dependent coefficients + H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]] + # Define initial states + pure_state = qt.basis(size, size-1).to("jax") + mixed_state = qt.maximally_mixed_dm(size).to("jax") + state = pure_state + # Define collapse operators and observables + c_ops = [jnp.sqrt(0.1) * a] + e_ops = [a.dag() * a, sm.dag() * sm] + # Time list + tlist = jnp.linspace(0.0, 10.0, 101) + # Define the function for which we want to compute the gradient + def f(omega): + # Update the Hamiltonian with the new coefficient + H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega}) + + # Run the Monte Carlo solver + result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"}) + + # Return the expectation value of the number operator at the final time + return result.expect[0][-1].real + # Compute the gradient + gradient = jax.grad(f)(1.0) + print("Gradient:", gradient) \ No newline at end of file