RL Ideas #16

Closed
wants to merge 1 commit into from
Closed
112 changes: 112 additions & 0 deletions src/qutip_qoc/_rl.py
@@ -0,0 +1,112 @@
"""
This module contains ...
"""
import qutip as qt
from qutip import Qobj, QobjEvo

import numpy as np

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env


class _RL(gym.Env):  # TODO: this should be similar to your GymQubitEnv(gym.Env) implementation
    """
    Class for storing a control problem and solving it with reinforcement
    learning (a PPO agent trained on a gymnasium environment).
    """

def __init__(
self,
objective,
time_interval,
time_options,
control_parameters,
alg_kwargs,
guess_params,
**integrator_kwargs,
):
        super().__init__()  # initialize the gym.Env base class

# ------------------------------- copied from _GOAT class -------------------------------

        # TODO: you don't have to keep the following attributes if you don't
        # need them; this is just an inspiration for how to extract
        # information from the input

self._Hd = objective.H[0]
self._Hc_lst = objective.H[1:]

self._control_parameters = control_parameters
self._guess_params = guess_params
self._H = self._prepare_generator()

self._initial = objective.initial
self._target = objective.target

self._evo_time = time_interval.evo_time

# inferred attributes
self._norm_fac = 1 / self._target.norm()

# integrator options
self._integrator_kwargs = integrator_kwargs

self._rtol = self._integrator_kwargs.get("rtol", 1e-5)
self._atol = self._integrator_kwargs.get("atol", 1e-5)

# choose solver and fidelity type according to problem
if self._Hd.issuper:
self._fid_type = alg_kwargs.get("fid_type", "TRACEDIFF")
self._solver = qt.MESolver(H=self._H, options=self._integrator_kwargs)

else:
self._fid_type = alg_kwargs.get("fid_type", "PSU")
self._solver = qt.SESolver(H=self._H, options=self._integrator_kwargs)

self.infidelity = self._infid # TODO: should be used to calculate the reward

# ----------------------------------------------------------------------------------------
# TODO: set up your gym environment as you did correctly in post10
self.max_episode_time = time_interval.evo_time # maximum time for an episode
self.max_steps = time_interval.n_tslots # maximum number of steps in an episode
        self.step_duration = time_interval.tslots[-1] / time_interval.n_tslots  # step duration for mesolve()
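        # A minimal sketch of the gym setup (assumptions, not fixed by the
        # draft): one continuous action per control Hamiltonian, and an
        # observation built from the real and imaginary parts of the state.
        # The unit box is a placeholder; the bounds from control_parameters
        # could replace it.
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(len(self._Hc_lst),), dtype=np.float32
        )
        obs_dim = 2 * self._initial.full().size  # real + imaginary parts
        self.observation_space = spaces.Box(
            low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32
        )
        # episode bookkeeping used by the sketched step()/reset() below
        self._state = self._initial
        self._current_step = 0
        self._fid_err_targ = alg_kwargs.get("fid_err_targ", 0.01)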
...


# ----------------------------------------------------------------------------------------

def _infid(self, params):
"""
Calculate infidelity to be minimized
"""
X = self._solver.run(
self._initial, [0.0, self._evo_time], args={"p": params}
).final_state

if self._fid_type == "TRACEDIFF":
diff = X - self._target
# to prevent if/else in qobj.dag() and qobj.tr()
diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims)
g = 1 / 2 * (diff_dag * diff).data.trace()
infid = np.real(self._norm_fac * g)
else:
g = self._norm_fac * self._target.overlap(X)
if self._fid_type == "PSU": # f_PSU (drop global phase)
infid = 1 - np.abs(g)
elif self._fid_type == "SU": # f_SU (incl global phase)
infid = 1 - np.real(g)

return infid

# TODO: don't hesitate to add the required methods for your rl environment
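
    def reset(self, seed=None, options=None):
        # A minimal sketch (assumption): restart the episode from the initial
        # state, following the gymnasium reset API.
        super().reset(seed=seed)
        self._state = self._initial
        self._current_step = 0
        return self._get_obs(), {}

    def _get_obs(self):
        # Hypothetical helper (not in the original draft): flatten the current
        # state into the real-valued vector promised by observation_space.
        amps = self._state.full().ravel()
        return np.concatenate([amps.real, amps.imag]).astype(np.float32)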

    def step(self, action):
        # A minimal sketch (assumptions: the action vector is the pulse
        # amplitude set "p", held constant for one slot of length
        # step_duration; the reward is the negative PSU infidelity of the
        # propagated state against the target).
        self._current_step += 1
        self._state = self._solver.run(
            self._state, [0.0, self.step_duration], args={"p": action}
        ).final_state
        infid = 1.0 - np.abs(self._norm_fac * self._target.overlap(self._state))
        terminated = bool(infid <= self._fid_err_targ)
        truncated = bool(self._current_step >= self.max_steps)
        return self._get_obs(), float(-infid), terminated, truncated, {}

    def train(self):
        # A minimal sketch (assumptions): validate the gym interface, then
        # train a PPO agent; the policy type and timestep budget are
        # placeholders, not tuned values.
        check_env(self)
        self._model = PPO("MlpPolicy", self, verbose=0)
        self._model.learn(total_timesteps=100_000)

    def result(self):
        # TODO: return qoc.Result object with the optimized pulse amplitudes
        # A minimal sketch (assumption): roll out the trained policy once and
        # collect the per-step actions as the optimized pulse amplitudes.
        obs, _ = self.reset()
        actions, done = [], False
        while not done:
            action, _ = self._model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _ = self.step(action)
            actions.append(action)
            done = terminated or truncated
        return actions  # TODO: wrap these in a qoc.Result
17 changes: 17 additions & 0 deletions src/qutip_qoc/pulse_optim.py
@@ -10,6 +10,7 @@

from qutip_qoc._optimizer import _global_local_optimization
from qutip_qoc._time import _TimeInterval
from qutip_qoc._rl import _RL

__all__ = ["optimize_pulses"]

@@ -348,6 +349,22 @@ def optimize_pulses(

qtrl_optimizers.append(qtrl_optimizer)

    # TODO: we can deal with proper handling later
    if alg == "RL":
        # NOTE: this call does not yet match _RL.__init__, which expects
        # (objective, time_interval, time_options, control_parameters,
        # alg_kwargs, guess_params, **integrator_kwargs); align the two
        # interfaces before merging.
        rl_env = _RL(
            objectives,
            control_parameters,
            time_interval,
            time_options,
            algorithm_kwargs,
            optimizer_kwargs,
            minimizer_kwargs,
            integrator_kwargs,
            qtrl_optimizers,
        )
        rl_env.train()
        return rl_env.result()

return _global_local_optimization(
objectives,
control_parameters,
37 changes: 37 additions & 0 deletions tests/test_result.py
@@ -152,6 +152,41 @@ def sin_z_jax(t, r, **kwargs):
algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01, "fix_frequency": False},
)

# ----------------------- RL --------------------
# TODO: this is the input for the optimize_pulses() function;
# you can use this routine to test your implementation

# state to state transfer
initial = qt.basis(2, 0)
target = qt.basis(2, 1)

H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians

w, d, y = 0.1, 1.0, 0.1
H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian

H = [H_d] + H_c # total Hamiltonian

state2state_rl = Case(
objectives=[Objective(initial, H, target)],
control_parameters={"bounds": [-13, 13]}, # TODO: for now only consider bounds
tlist=np.linspace(0, 10, 100), # TODO: derive single step duration and max evo time / max num steps from this
algorithm_kwargs={
"fid_err_targ": 0.01,
"alg": "RL",
"max_iter": 100,
}
)
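
# Hypothetical direct invocation (a sketch only; the pytest fixture below is
# what actually exercises these cases):
#
#     result = optimize_pulses(
#         state2state_rl.objectives,
#         state2state_rl.control_parameters,
#         state2state_rl.tlist,
#         algorithm_kwargs=state2state_rl.algorithm_kwargs,
#     )
#     assert result.infidelity <= 0.01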

# TODO: essentially the same setup works for unitary evolution

initial = qt.qeye(2) # Identity
target = qt.gates.hadamard_transform()

unitary_rl = state2state_rl._replace(
objectives=[Objective(initial, H, target)],
)


@pytest.fixture(
params=[
@@ -160,6 +195,8 @@
pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
pytest.param(state2state_goat, id="State to state (GOAT)"),
pytest.param(state2state_jax, id="State to state (JAX)"),
pytest.param(state2state_rl, id="State to state (RL)"),
pytest.param(unitary_rl, id="Unitary (RL)"),
]
)
def tst(request):