Skip to content

Commit

Permalink
drop gradient for CRAB and change optimize routine, started working on…
Browse files Browse the repository at this point in the history
… init params for CRAB
  • Loading branch information
Patrick committed Feb 28, 2024
1 parent 35ee13a commit 623a295
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 728 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ repos:
- id: ruff
args: ["--fix", "--show-fixes"]
types_or: [python, pyi, jupyter]
exclude: "src/qutip_qoc/examples.ipynb"
- id: ruff-format
types_or: [python, pyi, jupyter]
exclude: "src/qutip_qoc/examples.ipynb"

# Also run Black on examples in the documentation
- repo: https://github.com/adamchainz/blacken-docs
Expand Down
58 changes: 35 additions & 23 deletions src/qutip_qoc/analytical_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,16 @@ def opt_callback(self, x, f, accept):


def optimize_pulses(
objectives,
pulse_options,
time_interval,
time_options,
algorithm_kwargs,
optimizer_kwargs,
minimizer_kwargs,
integrator_kwargs,
crab_optimizer=None):
objectives,
pulse_options,
time_interval,
time_options,
algorithm_kwargs,
optimizer_kwargs,
minimizer_kwargs,
integrator_kwargs,
crab_optimizer=None,
):
"""
Optimize a pulse sequence to implement a given target unitary by optimizing
the parameters of the pulse functions. The algorithm is a two-layered
Expand Down Expand Up @@ -304,16 +305,27 @@ def optimize_pulses(
**integrator_kwargs,
)
elif algorithm_kwargs.get("alg") == "GOAT":
multi_objective = Multi_GOAT(objectives, time_interval, time_options,
pulse_options, algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs)
multi_objective = Multi_GOAT(
objectives,
time_interval,
time_options,
pulse_options,
algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs,
)
elif algorithm_kwargs.get("alg") == "CRAB":
multi_objective = Multi_CRAB(crab_optimizer,
objectives, time_interval, time_options,
pulse_options, algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs)
multi_objective = Multi_CRAB(
crab_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
algorithm_kwargs,
guess_params=optimizer_kwargs["x0"],
**integrator_kwargs,
)
minimizer_kwargs.setdefault("method", "Nelder-Mead")

# optimizer specific settings
opt_method = optimizer_kwargs.get(
Expand All @@ -331,7 +343,7 @@ def optimize_pulses(
),
)

if len(bounds) != 0:# realizes boundaries through minimizer
if len(bounds) != 0: # realizes boundaries through minimizer
minimizer_kwargs.setdefault("bounds", np.concatenate(bounds))

elif opt_method == "dual_annealing":
Expand All @@ -345,7 +357,7 @@ def optimize_pulses(
),
)

if len(bounds) != 0:# realizes boundaries through optimizer
if len(bounds) != 0: # realizes boundaries through optimizer
optimizer_kwargs.setdefault("bounds", np.concatenate(bounds))

# remove overload from optimizer_kwargs
Expand All @@ -369,9 +381,9 @@ def optimize_pulses(
min_res = optimizer(
func=multi_objective.goal_fun,
minimizer_kwargs={
'jac': None,#multi_objective.grad_fun,
'callback': cllbck.min_callback,
**minimizer_kwargs
"jac": multi_objective.grad_fun,
"callback": cllbck.min_callback,
**minimizer_kwargs,
},
callback=cllbck.opt_callback,
**optimizer_kwargs,
Expand Down
53 changes: 24 additions & 29 deletions src/qutip_qoc/crab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import qutip_qtrl.optimizer as opt
import types
import copy

logger = logging.get_logger()


class CRAB(opt.OptimizerCrab):
def __init__(self, cfg, dyn, params, termination_conditions):
    """
    Build a CRAB optimizer that is ready to run immediately.

    The first three arguments are forwarded unchanged to
    ``qutip_qtrl.optimizer.OptimizerCrab``; the extra
    ``termination_conditions`` argument is then applied via
    ``init_optim`` so callers do not have to perform that second
    setup step themselves.

    Parameters
    ----------
    cfg :
        Configuration object, passed through to ``OptimizerCrab``.
    dyn :
        Dynamics object, passed through to ``OptimizerCrab``.
    params :
        Optimizer parameters, passed through to ``OptimizerCrab``.
    termination_conditions :
        Handed to ``self.init_optim`` — presumably a qutip_qtrl
        ``TerminationConditions`` instance (fidelity-error target,
        call limits, …); confirm against the caller.
    """
    super().__init__(cfg, dyn, params)
    # Run after the base constructor so the termination settings are
    # applied to a fully initialized optimizer.
    self.init_optim(termination_conditions)


# overwrite
def fid_err_func_wrapper(self, *args):
"""
Get the fidelity error achieved using the ctrl amplitudes passed
Expand All @@ -32,15 +33,12 @@ def fid_err_func_wrapper(self, *args):
self.stats.num_fidelity_func_calls = self.num_fid_func_calls
if self.log_level <= logging.DEBUG:
logger.debug(
"fidelity error call {}".format(
self.stats.num_fidelity_func_calls
)
"fidelity error call {}".format(self.stats.num_fidelity_func_calls)
)

amps = self._get_ctrl_amps(args[0].copy())
self.dynamics.update_ctrl_amps(amps)

tc = self.termination_conditions
err = self.dynamics.fid_computer.get_fid_err()

if self.iter_summary:
Expand All @@ -54,13 +52,12 @@ def fid_err_func_wrapper(self, *args):
"""
if err <= tc.fid_err_targ:
raise errors.GoalAchievedTerminate(err)
if self.num_fid_func_calls > tc.max_fid_func_calls:
raise errors.MaxFidFuncCallTerminate()
"""
return err

# overwrite

def fid_err_grad_wrapper(self, *args):
"""
Expand All @@ -84,9 +81,7 @@ def fid_err_grad_wrapper(self, *args):
if self.stats is not None:
self.stats.num_grad_func_calls = self.num_grad_func_calls
if self.log_level <= logging.DEBUG:
logger.debug(
"gradient call {}".format(self.stats.num_grad_func_calls)
)
logger.debug("gradient call {}".format(self.stats.num_grad_func_calls))
amps = self._get_ctrl_amps(args[0].copy())
self.dynamics.update_ctrl_amps(amps)
fid_comp = self.dynamics.fid_computer
Expand Down Expand Up @@ -114,19 +109,30 @@ def fid_err_grad_wrapper(self, *args):
return grad.flatten()


class Multi_CRAB():
class Multi_CRAB:
"""
    Composite class for multiple CRAB instances
to optimize multiple objectives simultaneously
"""

def __init__(self, crab_optimizer, objectives, time_interval, time_options, pulse_options,
alg_kwargs, guess_params, **integrator_kwargs):

self.crabs = []

for obj in objectives:
crab = copy.deepcopy(crab_optimizer)
crabs: list = []
grad_fun = None

def __init__(
self,
crab_optimizer,
objectives,
time_interval,
time_options,
pulse_options,
alg_kwargs,
guess_params,
**integrator_kwargs,
):
if not isinstance(crab_optimizer, list):
crab_optimizer = [crab_optimizer]
for optim in crab_optimizer:
crab = copy.deepcopy(optim)
crab.fid_err_func_wrapper = types.MethodType(fid_err_func_wrapper, crab)
# Stack for each objective
self.crabs.append(crab)
Expand All @@ -145,14 +151,3 @@ def goal_fun(self, params):

self.mean_infid = np.mean(infid_sum)
return self.mean_infid

def grad_fun(self, params):
"""
Calculates the sum of gradients over all objectives
"""
grads = 0

for c in self.crabs:
grads += c.fid_err_grad_wrapper(params)

return grads
Loading

0 comments on commit 623a295

Please sign in to comment.