Skip to content

Commit

Permalink
fully wrap GRAPE in global optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick committed Mar 24, 2024
1 parent 6a4e655 commit 69ef01b
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 236 deletions.
38 changes: 9 additions & 29 deletions src/qutip_qoc/analytical_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from qutip_qoc.crab import Multi_CRAB
from qutip_qoc.grape import Multi_GRAPE

__all__ = ["optimize_pulses"]
__all__ = ["global_local_optimization"]


def get_init_and_bounds_from_options(lst, input):
Expand Down Expand Up @@ -173,7 +173,7 @@ def opt_callback(self, x, f, accept):
return terminate


def optimize_pulses(
def global_local_optimization(
objectives,
pulse_options,
time_interval,
Expand All @@ -182,7 +182,7 @@ def optimize_pulses(
optimizer_kwargs,
minimizer_kwargs,
integrator_kwargs,
qtrl_optimizer=None,
qtrl_optimizers=None,
):
"""
Optimize a pulse sequence to implement a given target unitary by optimizing
Expand Down Expand Up @@ -281,7 +281,7 @@ def optimize_pulses(
integrator_kwargs["normalize_output"] = False
integrator_kwargs.setdefault("progress_bar", False)

# extract initial and boundary values
# extract initial and boundary values for global and local optimizer
x0, bounds = [], []
for key in pulse_options.keys():
get_init_and_bounds_from_options(x0, pulse_options[key].get("guess"))
Expand All @@ -290,8 +290,7 @@ def optimize_pulses(
get_init_and_bounds_from_options(x0, time_options.get("guess", None))
get_init_and_bounds_from_options(bounds, time_options.get("bounds", None))

if len(x0) != 0:
optimizer_kwargs.setdefault("x0", np.concatenate(x0))
optimizer_kwargs["x0"] = np.concatenate(x0)

# algorithm specific settings
if algorithm_kwargs.get("alg") == "JOPT":
Expand All @@ -316,29 +315,10 @@ def optimize_pulses(
**integrator_kwargs,
)
elif algorithm_kwargs.get("alg") == "CRAB":
multi_objective = Multi_CRAB(
qtrl_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs,
)
minimizer_kwargs.setdefault("method", "Nelder-Mead")
multi_objective = Multi_CRAB(qtrl_optimizers)

elif algorithm_kwargs.get("alg") == "GRAPE":
multi_objective = Multi_GRAPE(
qtrl_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs,
)
multi_objective = Multi_GRAPE(qtrl_optimizers)

# optimizer specific settings
opt_method = optimizer_kwargs.get(
Expand All @@ -352,7 +332,7 @@ def optimize_pulses(
optimizer_kwargs.setdefault( # or "max_iter"
"niter",
optimizer_kwargs.get( # use algorithm_kwargs
"max_iter", algorithm_kwargs.get("max_iter", 1000)
"max_iter", algorithm_kwargs.get("max_iter", 1)
),
)

Expand Down Expand Up @@ -386,7 +366,7 @@ def optimize_pulses(
time_interval,
guess_params=x0,
var_time=var_t,
qtrl_optimizer=qtrl_optimizer,
qtrl_optimizers=qtrl_optimizers,
)

# Callback instance for termination and logging
Expand Down
13 changes: 4 additions & 9 deletions src/qutip_qoc/crab.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,11 @@ class Multi_CRAB:

def __init__(
self,
qtrl_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
alg_kwargs,
guess_params,
**integrator_kwargs,
qtrl_optimizers,
):
for optim in list(qtrl_optimizer):
if not isinstance(qtrl_optimizers, list):
qtrl_optimizers = [qtrl_optimizers]
for optim in qtrl_optimizers:
crab = copy.deepcopy(optim)
crab.fid_err_func_wrapper = types.MethodType(fid_err_func_wrapper, crab)
# Stack for each objective
Expand Down
15 changes: 4 additions & 11 deletions src/qutip_qoc/grape.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,11 @@ class Multi_GRAPE:

def __init__(
self,
qtrl_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
alg_kwargs,
guess_params,
**integrator_kwargs,
qtrl_optimizers,
):
if not isinstance(qtrl_optimizer, list):
qtrl_optimizer = [qtrl_optimizer]
for optim in qtrl_optimizer:
if not isinstance(qtrl_optimizers, list):
qtrl_optimizers = [qtrl_optimizers]
for optim in qtrl_optimizers:
grape = copy.deepcopy(optim)
grape.fid_err_func_wrapper = types.MethodType(fid_err_func_wrapper, grape)
# Stack for each objective
Expand Down
Loading

0 comments on commit 69ef01b

Please sign in to comment.