Skip to content

Commit

Permalink
added _backup_result to keep track of the result even if the algorithm continues to search for solutions with shorter pulses
Browse files Browse the repository at this point in the history
  • Loading branch information
LegionAtol committed Sep 30, 2024
1 parent 4485f9d commit 3889a63
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
64 changes: 50 additions & 14 deletions src/qutip_qoc/_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,19 @@ def create_pulse_func(idx):
n_iters = 0, # Number of iterations(episodes) until convergence
iter_seconds = [], # list containing the time taken for each iteration(episode) of the optimization
var_time = True, # Whether the optimization was performed with variable time
guess_params=[]
)

self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid<target_infid
objectives = objectives,
time_interval = time_interval,
start_local_time = time.localtime(),
n_iters = 0,
iter_seconds = [],
var_time = True,
guess_params=[]
)
self._use_backup_result = False # if true, use self._backup_result as the final optimization result

#for the reward
self._step_penalty = 1
Expand Down Expand Up @@ -228,21 +240,39 @@ def reset(self, seed=None):
self._state = self._initial
return self._get_obs(), {}

def result(self):
def _save_result(self):
"""
Retrieve the results of the optimization process, including the optimized
Save the results of the optimization process, including the optimized
pulse sequences, final states, and performance metrics.
"""
self._result.end_local_time = time.localtime()
self._result.n_iters = len(self._result.iter_seconds)
self._result.optimized_params = self._actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
self._result._optimized_controls = self._actions.copy()
self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time) # Convert to a string
self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time) # Convert to a string
self._result._guess_controls = []
self._result._optimized_H = [self._H]
self._result.guess_params = []
return self._result
result_obj = self._backup_result if self._use_backup_result else self._result

if(self._use_backup_result):
self._backup_result.iter_seconds = self._result.iter_seconds.copy()
self._backup_result._final_states = self._result._final_states.copy()
self._backup_result.infidelity = self._result.infidelity

result_obj.end_local_time = time.localtime()
result_obj.n_iters = len(self._result.iter_seconds)
result_obj.optimized_params = self._actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
result_obj._optimized_controls = self._actions.copy()
result_obj._guess_controls = []
result_obj._optimized_H = [self._H]


def result(self):
    """
    Final conversions and return of optimization results.

    When ``self._use_backup_result`` is set, the previously saved
    ``self._backup_result`` is returned as-is. Otherwise
    ``self._save_result()`` is called first and ``self._result`` is returned.

    In both cases ``start_local_time`` and ``end_local_time`` of the returned
    object are converted to ``"%Y-%m-%d %H:%M:%S"`` strings. The conversion is
    skipped for values that are already strings, so calling this method more
    than once does not raise ``TypeError`` from ``time.strftime``.

    Returns
    -------
    The selected result object, with string-formatted timestamps.
    """
    if self._use_backup_result:
        result_obj = self._backup_result
    else:
        self._save_result()
        result_obj = self._result

    for attr in ("start_local_time", "end_local_time"):
        stamp = getattr(result_obj, attr)
        # A previous call may already have formatted this timestamp;
        # time.strftime only accepts a tuple / struct_time.
        if not isinstance(stamp, str):
            setattr(result_obj, attr, time.strftime("%Y-%m-%d %H:%M:%S", stamp))
    return result_obj

def train(self):
"""
Expand All @@ -266,7 +296,6 @@ class EarlyStopTraining(BaseCallback):
"""
def __init__(self, verbose: int = 0):
super(EarlyStopTraining, self).__init__(verbose)
self.stop_train = False

def _on_step(self) -> bool:
"""
Expand All @@ -279,12 +308,18 @@ def _on_step(self) -> bool:

# Check if we need to stop training
if env.current_episode >= env.max_episodes:
env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
if env._use_backup_result == True:
env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Return the last founded episode with infid < target_infid"
else:
env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
return False # Stop training
elif (env._result.infidelity <= env._fid_err_targ) and not(env.shorter_pulses):
env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found"
return False # Stop training
elif env.shorter_pulses:
if(env._result.infidelity <= env._fid_err_targ): # if it finds an episode with infidelity lower than target infidelity, I'll save it in the meantime
env._use_backup_result = True
env._save_result()
if len(env._episode_info) >= 100:
last_100_episodes = env._episode_info[-100:]

Expand All @@ -293,6 +328,7 @@ def _on_step(self) -> bool:
infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes)

if steps_condition and infid_condition:
env._use_backup_result = False
env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
return False # Stop training
return True # Continue training
12 changes: 12 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def sin_z_jax(t, r, **kwargs):

# Test case derived from state2state_rl via _replace(): the objectives,
# control_parameters and algorithm_kwargs fields are overridden below.
unitary_rl = state2state_rl._replace(
objectives=[Objective(initial, H, target)],
# Bounds for the control parameter "p".
control_parameters = {
"p": {"bounds": [(-13, 13)]},
# The "__time__" guess and bounds are placeholder (dummy) values.
"__time__": {
"guess": np.array([0.0]), #dummy value
"bounds": [(0.0, 0.0)] #dummy value
}
},
# Select the RL algorithm with a target infidelity of 0.01 and max_iter 300.
algorithm_kwargs={
"fid_err_targ": 0.01,
"alg": "RL",
"max_iter": 300,
},
)


Expand Down

0 comments on commit 3889a63

Please sign in to comment.