From d016feae8edb94a9a41b532d1674d99ebe09c0b0 Mon Sep 17 00:00:00 2001 From: Patrick <> Date: Thu, 18 Apr 2024 19:39:37 +0200 Subject: [PATCH] fix dual_annealing bug and add tests --- src/qutip_qoc/optimizer.py | 20 +++++++++++++++----- src/qutip_qoc/pulse_optim.py | 9 +++++---- tests/test_fidelity.py | 6 ++++++ tests/test_result.py | 15 +++++++-------- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/qutip_qoc/optimizer.py b/src/qutip_qoc/optimizer.py index 990e84e..413b330 100644 --- a/src/qutip_qoc/optimizer.py +++ b/src/qutip_qoc/optimizer.py @@ -320,10 +320,17 @@ def global_local_optimization( optimizer = sp.optimize.dual_annealing # if not specified through optimizer_kwargs "maxiter" - optimizer_kwargs.setdefault( - "maxiter", - optimizer_kwargs.get("max_iter", algorithm_kwargs.get("glob_max_iter", 0)), + keys = ["maxiter", "max_iter", "glob_max_iter", "niter"] + value = next( + (optimizer_kwargs.pop(key, None) or algorithm_kwargs.get(key)) + for key in keys + if optimizer_kwargs.get(key) or algorithm_kwargs.get(key) ) + optimizer_kwargs.setdefault("maxiter", value) + + # remove remaining keys + for key in keys[1:]: + optimizer_kwargs.pop(key, None) if len(bounds) != 0: # realizes boundaries through optimizer optimizer_kwargs.setdefault("bounds", np.concatenate(bounds)) @@ -375,8 +382,11 @@ def global_local_optimization( result.n_iters = min_res.nit if result.message is None: result.message = ( - "Local minimizer: " - + min_res["lowest_optimization_result"].message + ( + "Local minimizer: " + min_res["lowest_optimization_result"].message + if opt_method == "basinhopping" + else "" # dual_annealing does not return a local minimizer message + ) + " Global optimizer: " + min_res.message[0] ) diff --git a/src/qutip_qoc/pulse_optim.py b/src/qutip_qoc/pulse_optim.py index 210becf..83ba312 100644 --- a/src/qutip_qoc/pulse_optim.py +++ b/src/qutip_qoc/pulse_optim.py @@ -122,6 +122,10 @@ def optimize_pulses( integrator_kwargs = {} time_options = control_parameters.pop("__time__", {}) + if time_options: # convert to list of bounds if not already + if not isinstance(time_options["bounds"][0], (list, tuple)): + time_options["bounds"] = [time_options["bounds"]] + alg = algorithm_kwargs.get("alg", "GRAPE") # works with most input types Hd_lst, Hc_lst = [], [] @@ -313,13 +317,10 @@ def optimize_pulses( else: # Set the initial parameters - pgen.set_optim_var_vals(np.array(x0[j])) + pgen.init_pulse(init_param_vals=np.array(x0[j])) init_amps[:, j] = pgen.gen_pulse() # Initialise the starting amplitudes - # NOTE: any initial CRAB guess pulse is only used - # to modulate or offset the initial amplitudes - # depending on init_pulse_params: pulse_action dyn.initialize_controls(init_amps) # And store the (random) initial parameters init_params = qtrl_optimizer._get_optim_var_vals() diff --git a/tests/test_fidelity.py b/tests/test_fidelity.py index 116e4b8..66a8aee 100644 --- a/tests/test_fidelity.py +++ b/tests/test_fidelity.py @@ -74,6 +74,10 @@ def grad_sin(t, p, idx): }, ) +PSU_state2state_dual_annealing = PSU_state2state._replace( + optimizer_kwargs={"method": "dual_annealing", "niter": 5}, +) + # SU (must depend on global phase) state to state transfer SU_state2state = PSU_state2state._replace( objectives=[Objective(initial, H, initial)], @@ -174,6 +178,8 @@ def sin_jax(t, p): pytest.param(PSU_unitary_jax, id="PSU unitary gate (JAX)"), pytest.param(SU_unitary_jax, id="SU unitary gate (JAX)"), pytest.param(TRCDIFF_map_jax, id="TRACEDIFF map synthesis (JAX)"), + # Options + pytest.param(PSU_state2state_dual_annealing, id="Dual annealing (GOAT)"), ] ) def tst(request): diff --git a/tests/test_result.py b/tests/test_result.py index d23b059..37ba16e 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -77,7 +77,13 @@ def grad_sin(t, p, idx): "seed": 0, }, ) +# ----------------------- CRAB -------------------- +# state to state transfer with initial parameters (not amplitudes) +state2state_param_crab = state2state_goat._replace( + objectives=[Objective(initial, H, target)], + algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01}, +) # ----------------------- JAX --------------------- @@ -110,12 +116,6 @@ def sin_z_jax(t, r, **kwargs): algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01}, ) -# state to state transfer with initial parameters -# instead of initial amplitudes -state2state_param_crab = state2state_goat._replace( - objectives=[Objective(initial, H, target)], - algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01}, -) # ------------------- discrete CRAB / GRAPE control ------------------------ @@ -157,8 +157,7 @@ def sin_z_jax(t, r, **kwargs): params=[ pytest.param(state2state_grape, id="State to state (GRAPE)"), pytest.param(state2state_crab, id="State to state (CRAB)"), - # TODO: reactivate test after qutip_qtrl PR was merged - # pytest.param(state2state_param_crab, id="State to state (param. CRAB)"), + pytest.param(state2state_param_crab, id="State to state (param. CRAB)"), pytest.param(state2state_goat, id="State to state (GOAT)"), pytest.param(state2state_jax, id="State to state (JAX)"), ]