
Commit

Docstring fixes, use of __time__ parameter instead of shorter_pulses, added underscore for internal variables, args is now passed as a parameter to _infid(), changes in callback function
LegionAtol committed Sep 14, 2024
1 parent 318081f commit 4485f9d
Showing 3 changed files with 69 additions and 76 deletions.
126 changes: 60 additions & 66 deletions src/qutip_qoc/_rl.py
@@ -52,27 +52,34 @@ def __init__(
self._Hd_lst.append(objective.H[0])
self._Hc_lst.append([H[0] if isinstance(H, list) else H for H in objective.H[1:]])

# create the QobjEvo with Hd, Hc and controls(args)
self.args = {f"alpha{i+1}": (1) for i in range(len(self._Hc_lst[0]))} # set the control parameters to 1 for all the Hc
def create_pulse_func(idx):
"""
Create a control pulse lambda function for a given index.
"""
return lambda t, args: self._pulse(t, args, idx+1)

# create the QobjEvo with Hd, Hc and controls(args)
self._H_lst = [self._Hd_lst[0]]
dummy_args = {f'alpha{i+1}': 1.0 for i in range(len(self._Hc_lst[0]))}
for i, Hc in enumerate(self._Hc_lst[0]):
self._H_lst.append([Hc, lambda t, args: self.pulse(t, self.args, i+1)])
self._H = qt.QobjEvo(self._H_lst, self.args)
self._H_lst.append([Hc, create_pulse_func(i)])
self._H = qt.QobjEvo(self._H_lst, args=dummy_args)
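The create_pulse_func factory exists to bind idx at creation time; a bare lambda written directly in the loop (as in the removed line above) captures the loop variable by reference. A minimal, standalone illustration of the difference, with made-up names that are not part of the commit:

# Late binding: every lambda ends up reading the final value of i.
pulses = [lambda t, args: args[f"alpha{i+1}"] for i in range(3)]
print(all(p(0.0, {"alpha3": 7.0}) == 7.0 for p in pulses))  # True

# Factory function: idx is fixed when each pulse is created, as create_pulse_func does.
def make_pulse(idx):
    return lambda t, args: args[f"alpha{idx+1}"]

pulses = [make_pulse(i) for i in range(3)]
print(pulses[0](0.0, {"alpha1": 1.0, "alpha2": 2.0, "alpha3": 3.0}))  # 1.0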

self.shorter_pulses = False if time_options == {} else True # lengthen the training to look for pulses of shorter duration, therefore episodes with fewer steps

# extract bounds for control_parameters
bounds = []
for key in control_parameters.keys():
bounds.append(control_parameters[key].get("bounds"))
self.lbound = [b[0][0] for b in bounds]
self.ubound = [b[0][1] for b in bounds]
self._lbound = [b[0][0] for b in bounds]
self._ubound = [b[0][1] for b in bounds]

self._alg_kwargs = alg_kwargs
self.shorter_pulses = self._alg_kwargs.get("shorter_pulses", False) # lengthen the training to look for pulses of shorter duration, therefore episodes with fewer steps

self._initial = objectives[0].initial
self._target = objectives[0].target
self.state = None
self.dim = self._initial.shape[0]
self._state = None
self._dim = self._initial.shape[0]

self._result = Result(
objectives = objectives,
@@ -87,19 +94,19 @@ def __init__(
self._step_penalty = 1

# To check if it exceeds the maximum number of steps in an episode
self.current_step = 0
self._current_step = 0

self.terminated = False
self.truncated = False
self.episode_info = [] # to contain some information from the latest episode
self._episode_info = [] # to contain some information from the latest episode

self._fid_err_targ = alg_kwargs["fid_err_targ"]

# inferred attributes
self._norm_fac = 1 / self._target.norm()

self.temp_actions = [] # temporary list to save episode actions
self.actions = [] # list of actions(lists) of the last episode
self._temp_actions = [] # temporary list to save episode actions
self._actions = [] # list of actions(lists) of the last episode

# integrator options
self._integrator_kwargs = integrator_kwargs
@@ -108,16 +115,16 @@ def __init__(

self.max_episode_time = time_interval.evo_time # maximum time for an episode
self.max_steps = time_interval.n_tslots # maximum number of steps in an episode
self.step_duration = time_interval.tslots[-1] / time_interval.n_tslots # step duration for mesolve
self._step_duration = time_interval.tslots[-1] / time_interval.n_tslots # step duration for mesolve
self.max_episodes = alg_kwargs["max_iter"] # maximum number of episodes for training
self.total_timesteps = self.max_episodes * self.max_steps # for learn() of gym
self._total_timesteps = self.max_episodes * self.max_steps # for learn() of gym
self.current_episode = 0 # To keep track of the current episode

# Define action and observation spaces (Gym)
if self._initial.isket:
obs_shape = (2 * self.dim,)
obs_shape = (2 * self._dim,)
else: # for unitary operators
obs_shape = (2 * self.dim * self.dim,)
obs_shape = (2 * self._dim * self._dim,)
self.action_space = spaces.Box(low=-1, high=1, shape=(len(self._Hc_lst[0]),), dtype=np.float32) # Continuous action space from -1 to +1, as suggested from gym
self.observation_space = spaces.Box(low=-1, high=1, shape=obs_shape, dtype=np.float32) # Observation space
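As a rough sketch of what these Gymnasium spaces look like for a single control Hamiltonian and a two-level ket, with made-up sizes that are not taken from the commit:

import numpy as np
from gymnasium import spaces

n_controls, dim = 1, 2  # hypothetical: one Hc, qubit state
action_space = spaces.Box(low=-1, high=1, shape=(n_controls,), dtype=np.float32)
observation_space = spaces.Box(low=-1, high=1, shape=(2 * dim,), dtype=np.float32)
print(action_space.sample().shape, observation_space.shape)  # (1,) (4,)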

@@ -129,13 +136,14 @@ def __init__(
self._fid_type = self._alg_kwargs.get("fid_type", "PSU")
self._solver = qt.SESolver(H=self._H, options=self._integrator_kwargs)

def pulse(self, t, args, idx):
def _pulse(self, t, args, idx):
"""
Returns the control pulse value at time t for a given index.
"""
return 1*args[f"alpha{idx}"]
alpha = args[f"alpha{idx}"]
return alpha
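A stripped-down sketch of the args mechanism this pulse function relies on, mirroring the QobjEvo/SESolver pattern used above but with made-up operators: the coefficient reads "alpha1" from args, and the value supplied to run(..., args=...) is what actually drives the evolution.

import qutip as qt

H = qt.QobjEvo(
    [qt.sigmaz(), [qt.sigmax(), lambda t, args: args["alpha1"]]],
    args={"alpha1": 1.0},  # placeholder value, overridden at run time
)
solver = qt.SESolver(H)
final = solver.run(qt.basis(2, 0), [0.0, 0.1], args={"alpha1": 0.5}).final_state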

def save_episode_info(self):
def _save_episode_info(self):
"""
Save the information of the last episode before resetting the environment.
"""
@@ -144,19 +152,19 @@ def save_episode_info(self):
"final_infidelity": self._result.infidelity,
"terminated": self.terminated,
"truncated": self.truncated,
"steps_used": self.current_step,
"steps_used": self._current_step,
"elapsed_time": time.mktime(time.localtime())
}
self.episode_info.append(episode_data)
self._episode_info.append(episode_data)

def _infid(self, params=None):
def _infid(self, args):
"""
The agent performs a step, then calculates the infidelity of the current state with respect to the target state; this is the quantity to be minimized.
"""
X = self._solver.run(
self.state, [0.0, self.step_duration], args={"p": params}
self._state, [0.0, self._step_duration], args=args
).final_state
self.state = X
self._state = X

if self._fid_type == "TRACEDIFF":
diff = X - self._target
@@ -177,20 +185,18 @@ def step(self, action):
Perform a single time step in the environment, applying the scaled action (control pulse)
chosen by the RL agent. Updates the system's state and computes the reward.
"""
alphas = [((action[i] + 1) / 2 * (self.ubound[0] - self.lbound[0])) + self.lbound[0] for i in range(len(action))]

for i, value in enumerate(alphas):
self.args[f"alpha{i+1}"] = value
alphas = [((action[i] + 1) / 2 * (self._ubound[0] - self._lbound[0])) + self._lbound[0] for i in range(len(action))]

infidelity = self._infid()
args = {f"alpha{i+1}": value for i, value in enumerate(alphas)}
_infidelity = self._infid(args)

self.current_step += 1
self.temp_actions.append(alphas)
self._result.infidelity = infidelity
reward = (1 - infidelity) - self._step_penalty
self._current_step += 1
self._temp_actions.append(alphas)
self._result.infidelity = _infidelity
reward = (1 - _infidelity) - self._step_penalty

self.terminated = infidelity <= self._fid_err_targ # the episode ended reaching the goal
self.truncated = self.current_step >= self.max_steps # if the episode ended without reaching the goal
self.terminated = _infidelity <= self._fid_err_targ # the episode ended reaching the goal
self.truncated = self._current_step >= self.max_steps # if the episode ended without reaching the goal

observation = self._get_obs()
return observation, reward, bool(self.terminated), bool(self.truncated), {}
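The scaling on the first line of step is a plain affine map from the agent's action range [-1, 1] onto [lbound, ubound]. A quick numeric check with made-up bounds (the same values as the test case further down):

lbound, ubound = -13.0, 13.0
scale = lambda a: (a + 1) / 2 * (ubound - lbound) + lbound
print(scale(-1.0), scale(0.0), scale(1.0))  # -13.0 0.0 13.0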
@@ -200,26 +206,26 @@ def _get_obs(self):
Get the current state observation for the RL agent. Converts the system's
quantum state or matrix into a real-valued NumPy array suitable for RL algorithms.
"""
rho = self.state.full().flatten()
rho = self._state.full().flatten()
obs = np.concatenate((np.real(rho), np.imag(rho)))
return obs.astype(np.float32) # Gymnasium expects the observation to be of type float32
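For a concrete sense of the observation layout, here is the same real/imaginary flattening applied to a made-up two-level ket (illustrative only, not part of the commit):

import numpy as np
import qutip as qt

psi = (qt.basis(2, 0) + 1j * qt.basis(2, 1)).unit()  # hypothetical state
rho = psi.full().flatten()
obs = np.concatenate((np.real(rho), np.imag(rho))).astype(np.float32)
print(obs)  # approx. [0.707 0. 0. 0.707], shape (4,) == 2 * dim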

def reset(self, seed=None):
"""
Reset the environment to the initial state, preparing for a new episode.
"""
self.save_episode_info()
self._save_episode_info()

time_diff = self.episode_info[-1]["elapsed_time"] - (self.episode_info[-2]["elapsed_time"] if len(self.episode_info) > 1 else time.mktime(self._result.start_local_time))
time_diff = self._episode_info[-1]["elapsed_time"] - (self._episode_info[-2]["elapsed_time"] if len(self._episode_info) > 1 else time.mktime(self._result.start_local_time))
self._result.iter_seconds.append(time_diff)
self.current_step = 0 # Reset the step counter
self._current_step = 0 # Reset the step counter
self.current_episode += 1 # Increment episode counter
self.actions = self.temp_actions.copy()
self._actions = self._temp_actions.copy()
self.terminated = False
self.truncated = False
self.temp_actions = []
self._result._final_states = [self.state]
self.state = self._initial
self._temp_actions = []
self._result._final_states = [self._state]
self._state = self._initial
return self._get_obs(), {}

def result(self):
@@ -229,8 +235,8 @@ def result(self):
"""
self._result.end_local_time = time.localtime()
self._result.n_iters = len(self._result.iter_seconds)
self._result.optimized_params = self.actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
self._result._optimized_controls = self.actions.copy()
self._result.optimized_params = self._actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
self._result._optimized_controls = self._actions.copy()
self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time) # Convert to a string
self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time) # Convert to a string
self._result._guess_controls = []
@@ -252,7 +258,7 @@ def train(self):
stop_callback = EarlyStopTraining(verbose=1)

# Train the model
model.learn(total_timesteps = self.total_timesteps, callback=stop_callback)
model.learn(total_timesteps = self._total_timesteps, callback=stop_callback)

class EarlyStopTraining(BaseCallback):
"""
@@ -267,38 +273,26 @@ def _on_step(self) -> bool:
This method is required by the BaseCallback class. We use it to stop the training.
- Stop training if the maximum number of episodes is reached.
- Stop training if an episode reaches an infidelity <= the target infidelity
- If all of the last 100 episodes have infidelity below the target and use the same number of steps, stop training.
"""
env = self.training_env.envs[0].unwrapped

# Check if we need to stop training
if self.stop_train:
return False # Stop training
elif env.current_episode >= env.max_episodes:
if env.current_episode >= env.max_episodes:
env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
return False # Stop training
elif (env._result.infidelity <= env._fid_err_targ) and not(env.shorter_pulses):
env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found"
return False # Stop training
return True # Continue training

def _on_rollout_start(self) -> None:
"""
This method is called before the rollout starts (before collecting new samples).
Checks:
- If all of the last 100 episodes have infidelity below the target and use the same number of steps, stop training.
"""
#could be moved to on_step

env = self.training_env.envs[0].unwrapped
#Only if specified in alg_kwargs, the algorithm will search for shorter pulses, resulting in episodes with fewer steps.
if env.shorter_pulses:
if len(env.episode_info) >= 100:
last_100_episodes = env.episode_info[-100:]
elif env.shorter_pulses:
if len(env._episode_info) >= 100:
last_100_episodes = env._episode_info[-100:]

min_steps = min(info['steps_used'] for info in last_100_episodes)
steps_condition = all(ep['steps_used'] == min_steps for ep in last_100_episodes)
infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes)

if steps_condition and infid_condition:
env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
self.stop_train = True # Stop training
return False # Stop training
return True # Continue training
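The early-stopping logic above leans on the Stable-Baselines3 callback contract: learn() keeps calling _on_step() after each environment step and halts as soon as it returns False. A toy callback showing only that contract, unrelated to the infidelity checks used here:

from stable_baselines3.common.callbacks import BaseCallback

class StopAfterNSteps(BaseCallback):
    """Stop training after a fixed number of environment steps."""
    def __init__(self, max_calls=1000, verbose=0):
        super().__init__(verbose)
        self.max_calls = max_calls

    def _on_step(self) -> bool:
        # Returning False tells model.learn() to stop collecting samples.
        return self.num_timesteps < self.max_calls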
10 changes: 3 additions & 7 deletions src/qutip_qoc/pulse_optim.py
@@ -41,16 +41,17 @@ def optimize_pulses(
control_id : dict
- guess: ndarray, shape (n,)
For RL you don't need to specify the guess.
Initial guess. Array of real elements of size (n,),
where ``n`` is the number of independent variables.
- bounds : sequence, optional
For RL the bounds are still required: they set the range onto which the agent's actions are scaled.
Sequence of ``(min, max)`` pairs for each element in
`guess`. None is used to specify no bound.
__time__ : dict, optional
Only supported by GOAT and JOPT.
Only supported by GOAT, JOPT and RL.
For RL the values of guess and bounds are not relevant.
If given, the pulse duration is treated as an optimization parameter.
It must specify both:
@@ -84,11 +85,6 @@
Global steps default to 0 (no global optimization).
Can be overridden by specifying in minimizer_kwargs.
- shorter_pulses : bool, optional
If set to True, allows the algorithm to search for shorter control
pulses that can achieve the desired fidelity target using fewer steps.
By default, it is set to False, only attempting to reach the target infidelity.
Algorithm-specific keywords for GRAPE, CRAB can be found in
:func:`qutip_qtrl.pulseoptim.optimize_pulse`.
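For orientation, a hedged end-to-end sketch of how the __time__ entry is passed for the RL algorithm. The operators, bounds and iteration counts are placeholders, and the import path is an assumption based on this repository's layout; compare the test case in the next file.

import numpy as np
import qutip as qt
from qutip_qoc import Objective, optimize_pulses  # assumed public API

initial, target = qt.basis(2, 0), qt.basis(2, 1)  # made-up state-to-state problem
H = [qt.sigmaz(), qt.sigmax()]                    # drift + one control

result = optimize_pulses(
    objectives=[Objective(initial, H, target)],
    control_parameters={
        "p": {"bounds": [(-13, 13)]},
        "__time__": {"guess": np.array([0.0]), "bounds": [(0.0, 0.0)]},  # dummy values for RL
    },
    tlist=np.linspace(0, 10, 100),
    algorithm_kwargs={"alg": "RL", "fid_err_targ": 0.01, "max_iter": 20000},
)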
9 changes: 6 additions & 3 deletions tests/test_result.py
@@ -168,14 +168,17 @@ def sin_z_jax(t, r, **kwargs):
state2state_rl = Case(
objectives=[Objective(initial, H, target)],
control_parameters = {
"p": {"bounds": [(-13, 13)],}
"p": {"bounds": [(-13, 13)]},
"__time__": {
"guess": np.array([0.0]), #dummy value
"bounds": [(0.0, 0.0)] #dummy value
}
},
tlist=np.linspace(0, 10, 100),
algorithm_kwargs={
"fid_err_targ": 0.01,
"alg": "RL",
"max_iter": 300,
"shorter_pulses": True,
"max_iter": 20000,
},
optimizer_kwargs={}
)
