diff --git a/src/qutip_qoc/_rl.py b/src/qutip_qoc/_rl.py index 66b8474..3bdb1fc 100644 --- a/src/qutip_qoc/_rl.py +++ b/src/qutip_qoc/_rl.py @@ -88,7 +88,19 @@ def create_pulse_func(idx): n_iters = 0, # Number of iterations(episodes) until convergence iter_seconds = [], # list containing the time taken for each iteration(episode) of the optimization var_time = True, # Whether the optimization was performed with variable time + guess_params=[] ) + + self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid bool: """ @@ -279,12 +308,18 @@ def _on_step(self) -> bool: # Check if we need to stop training if env.current_episode >= env.max_episodes: - env._result.message = f"Reached {env.max_episodes} episodes, stopping training." + if env._use_backup_result == True: + env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Return the last founded episode with infid < target_infid" + else: + env._result.message = f"Reached {env.max_episodes} episodes, stopping training." return False # Stop training elif (env._result.infidelity <= env._fid_err_targ) and not(env.shorter_pulses): env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found" return False # Stop training elif env.shorter_pulses: + if(env._result.infidelity <= env._fid_err_targ): # if it finds an episode with infidelity lower than target infidelity, I'll save it in the meantime + env._use_backup_result = True + env._save_result() if len(env._episode_info) >= 100: last_100_episodes = env._episode_info[-100:] @@ -293,6 +328,7 @@ def _on_step(self) -> bool: infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes) if steps_condition and infid_condition: + env._use_backup_result = False env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid." return False # Stop training return True # Continue training \ No newline at end of file diff --git a/tests/test_result.py b/tests/test_result.py index b74a211..9bb5496 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -190,6 +190,18 @@ def sin_z_jax(t, r, **kwargs): unitary_rl = state2state_rl._replace( objectives=[Objective(initial, H, target)], + control_parameters = { + "p": {"bounds": [(-13, 13)]}, + "__time__": { + "guess": np.array([0.0]), #dummy value + "bounds": [(0.0, 0.0)] #dummy value + } + }, + algorithm_kwargs={ + "fid_err_targ": 0.01, + "alg": "RL", + "max_iter": 300, + }, )