Skip to content

Commit

Permalink
fix dual_annealing bug and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick committed Apr 18, 2024
1 parent 1aee32a commit d016fea
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 17 deletions.
20 changes: 15 additions & 5 deletions src/qutip_qoc/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,17 @@ def global_local_optimization(
optimizer = sp.optimize.dual_annealing

# if not specified through optimizer_kwargs "maxiter"
optimizer_kwargs.setdefault(
"maxiter",
optimizer_kwargs.get("max_iter", algorithm_kwargs.get("glob_max_iter", 0)),
keys = ["maxiter", "max_iter", "glob_max_iter", "niter"]
value = next(
(optimizer_kwargs.pop(key, None) or algorithm_kwargs.get(key))
for key in keys
if optimizer_kwargs.get(key) or algorithm_kwargs.get(key)
)
optimizer_kwargs.setdefault("maxiter", value)

# remove remaining keys
for key in keys[1:]:
optimizer_kwargs.pop(key, None)

if len(bounds) != 0: # realizes boundaries through optimizer
optimizer_kwargs.setdefault("bounds", np.concatenate(bounds))
Expand Down Expand Up @@ -375,8 +382,11 @@ def global_local_optimization(
result.n_iters = min_res.nit
if result.message is None:
result.message = (
"Local minimizer: "
+ min_res["lowest_optimization_result"].message
(
"Local minimizer: " + min_res["lowest_optimization_result"].message
if opt_method == "basinhopping"
else "" # dual_annealing does not return a local minimizer message
)
+ " Global optimizer: "
+ min_res.message[0]
)
Expand Down
9 changes: 5 additions & 4 deletions src/qutip_qoc/pulse_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def optimize_pulses(
integrator_kwargs = {}

time_options = control_parameters.pop("__time__", {})
if time_options: # convert to list of bounds if not already
if not isinstance(time_options["bounds"][0], (list, tuple)):
time_options["bounds"] = [time_options["bounds"]]

alg = algorithm_kwargs.get("alg", "GRAPE") # works with most input types

Hd_lst, Hc_lst = [], []
Expand Down Expand Up @@ -313,13 +317,10 @@ def optimize_pulses(

else:
# Set the initial parameters
pgen.set_optim_var_vals(np.array(x0[j]))
pgen.init_pulse(init_param_vals=np.array(x0[j]))
init_amps[:, j] = pgen.gen_pulse()

# Initialise the starting amplitudes
# NOTE: any initial CRAB guess pulse is only used
# to modulate or offset the initial amplitudes
# depending on init_pulse_params: pulse_action
dyn.initialize_controls(init_amps)
# And store the (random) initial parameters
init_params = qtrl_optimizer._get_optim_var_vals()
Expand Down
6 changes: 6 additions & 0 deletions tests/test_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def grad_sin(t, p, idx):
},
)

PSU_state2state_dual_annealing = PSU_state2state._replace(
optimizer_kwargs={"method": "dual_annealing", "niter": 5},
)

# SU (must depend on global phase) state to state transfer
SU_state2state = PSU_state2state._replace(
objectives=[Objective(initial, H, initial)],
Expand Down Expand Up @@ -174,6 +178,8 @@ def sin_jax(t, p):
pytest.param(PSU_unitary_jax, id="PSU unitary gate (JAX)"),
pytest.param(SU_unitary_jax, id="SU unitary gate (JAX)"),
pytest.param(TRCDIFF_map_jax, id="TRACEDIFF map synthesis (JAX)"),
# Options
pytest.param(PSU_state2state_dual_annealing, id="Dual annealing (GOAT)"),
]
)
def tst(request):
Expand Down
15 changes: 7 additions & 8 deletions tests/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def grad_sin(t, p, idx):
"seed": 0,
},
)
# ----------------------- CRAB --------------------

# state to state transfer with initial parameters (not amplitudes)
state2state_param_crab = state2state_goat._replace(
objectives=[Objective(initial, H, target)],
algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01},
)
# ----------------------- JAX ---------------------


Expand Down Expand Up @@ -110,12 +116,6 @@ def sin_z_jax(t, r, **kwargs):
algorithm_kwargs={"alg": "JOPT", "fid_err_targ": 0.01},
)

# state to state transfer with initial parameters
# instead of initial amplitudes
state2state_param_crab = state2state_goat._replace(
objectives=[Objective(initial, H, target)],
algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01},
)

# ------------------- discrete CRAB / GRAPE control ------------------------

Expand Down Expand Up @@ -157,8 +157,7 @@ def sin_z_jax(t, r, **kwargs):
params=[
pytest.param(state2state_grape, id="State to state (GRAPE)"),
pytest.param(state2state_crab, id="State to state (CRAB)"),
# TODO: reactivate test after qutip_qtrl PR was merged
# pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
pytest.param(state2state_goat, id="State to state (GOAT)"),
pytest.param(state2state_jax, id="State to state (JAX)"),
]
Expand Down

0 comments on commit d016fea

Please sign in to comment.