Skip to content

Commit

Permalink
hacky fix for GRAPE result
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick committed Mar 24, 2024
1 parent 69ef01b commit 86204a1
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 32 deletions.
19 changes: 10 additions & 9 deletions src/qutip_qoc/analytical_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,9 @@ def global_local_optimization(
optimizer = sp.optimize.basinhopping

# if not specified through optimizer_kwargs "niter"
optimizer_kwargs.setdefault( # or "max_iter"
optimizer_kwargs.setdefault(
"niter",
optimizer_kwargs.get( # use algorithm_kwargs
"max_iter", algorithm_kwargs.get("max_iter", 1)
),
optimizer_kwargs.get("max_iter", algorithm_kwargs.get("glob_max_iter", 1)),
)

if len(bounds) != 0: # realizes boundaries through minimizer
Expand All @@ -343,11 +341,9 @@ def global_local_optimization(
optimizer = sp.optimize.dual_annealing

# if not specified through optimizer_kwargs "maxiter"
optimizer_kwargs.setdefault( # or "max_iter"
optimizer_kwargs.setdefault(
"maxiter",
optimizer_kwargs.get( # use algorithm_kwargs
"max_iter", algorithm_kwargs.get("max_iter", 1000)
),
optimizer_kwargs.get("max_iter", algorithm_kwargs.get("glob_max_iter", 1)),
)

if len(bounds) != 0: # realizes boundaries through optimizer
Expand Down Expand Up @@ -399,6 +395,11 @@ def global_local_optimization(
# save runtime information in result
result.n_iters = min_res.nit
if result.message is None:
result.message = min_res.message
result.message = (
"Local minimizer: "
+ min_res["lowest_optimization_result"].message
+ " Global optimizer: "
+ min_res.message[0]
)

return result
12 changes: 8 additions & 4 deletions src/qutip_qoc/crab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import qutip_qtrl.logging_utils as logging
import qutip_qtrl.optimizer as opt
from qutip_qtrl.errors import GoalAchievedTerminate, MaxFidFuncCallTerminate
import types
import copy

Expand Down Expand Up @@ -65,15 +66,13 @@ class Multi_CRAB:
to optimize multiple objectives simultaneously
"""

crabs: list = []
grad_fun = None

def __init__(
self,
qtrl_optimizers,
):
if not isinstance(qtrl_optimizers, list):
qtrl_optimizers = [qtrl_optimizers]
self.crabs = []
for optim in qtrl_optimizers:
crab = copy.deepcopy(optim)
crab.fid_err_func_wrapper = types.MethodType(fid_err_func_wrapper, crab)
Expand All @@ -89,7 +88,12 @@ def goal_fun(self, params):
infid_sum = 0

for crab in self.crabs: # TODO: parallelize
infid = crab.fid_err_func_wrapper(params)
try:
infid = crab.fid_err_func_wrapper(params)
except (GoalAchievedTerminate, MaxFidFuncCallTerminate):
pass
except Exception as ex:
raise ex
infid_sum += infid

self.mean_infid = np.mean(infid_sum)
Expand Down
24 changes: 18 additions & 6 deletions src/qutip_qoc/grape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import numpy as np
import qutip_qtrl.logging_utils as logging
import qutip_qtrl.optimizer as opt
from qutip_qtrl.errors import (
GoalAchievedTerminate,
GradMinReachedTerminate,
MaxFidFuncCallTerminate,
)
import types
import copy

Expand Down Expand Up @@ -115,14 +120,11 @@ class Multi_GRAPE:
to optimize multiple objectives simultaneously
"""

grapes: list = []

def __init__(
self,
qtrl_optimizers,
):
if not isinstance(qtrl_optimizers, list):
qtrl_optimizers = [qtrl_optimizers]
self.grapes = []
for optim in qtrl_optimizers:
grape = copy.deepcopy(optim)
grape.fid_err_func_wrapper = types.MethodType(fid_err_func_wrapper, grape)
Expand All @@ -138,7 +140,12 @@ def goal_fun(self, params):
infid_sum = 0

for grape in self.grapes: # TODO: parallelize
infid = grape.fid_err_func_wrapper(params)
try:
infid = grape.fid_err_func_wrapper(params)
except (GoalAchievedTerminate, MaxFidFuncCallTerminate):
pass
except Exception as ex:
raise ex
infid_sum += infid

self.mean_infid = np.mean(infid_sum)
Expand All @@ -151,7 +158,12 @@ def grad_fun(self, params):
grads = 0

for g in self.grapes:
grad = g.fid_err_grad_wrapper(params)
try:
grad = g.fid_err_grad_wrapper(params)
except GradMinReachedTerminate:
pass
except Exception as ex:
raise ex
grads += grad

return grads
81 changes: 68 additions & 13 deletions src/qutip_qoc/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
__all__ = ["Result"]


class Stats:
"""
Only for backward compatibility with qtrl.
"""

def __init__(self, result):
self.result = result

def report(self):
print(self.result)


class Result:
"""
Class for storing the results of a pulse control optimization run.
Expand Down Expand Up @@ -99,11 +111,14 @@ def __init__(
self.guess_params = guess_params
self.new_params = new_params
self._optimized_params = optimized_params
self.final_states = final_states
self._final_states = final_states
self.infidelity = infidelity
self.var_time = var_time
self.qtrl_optimizers = qtrl_optimizers

# qtrl result backward compatibility
self.stats = Stats(self)

def __str__(self):
return textwrap.dedent(
r"""
Expand Down Expand Up @@ -235,14 +250,14 @@ def optimized_objectives(self):
opt_obj = []

for obj in self.objectives:
if callable(obj.H[1][1]): # GOAT/JOPT
if callable(obj.H[1][1]): # GOAT, JOPT
optimized_H = obj.H
else:
optimized_H = [obj.H[0]] # drift
for Hc, cf in zip(obj.H[1:], self.optimized_controls):
if isinstance(Hc, qt.Qobj): # parameterized CRAB
optimized_H.append([Hc, cf])
else: # discrete control as in GRAPE/CRAB
else: # discrete control as in GRAPE, CRAB
optimized_H.append([Hc[0], cf])

opt_obj.append(Objective(obj.initial, optimized_H, obj.target))
Expand All @@ -261,15 +276,14 @@ def final_states(self):
evo_time = self.time_interval.evo_time

para_keys = []
if not self.qtrl_optimizers: # GOAT/JOPT
args_dict = {}
if not self.qtrl_optimizers: # GOAT, JOPT
# extract parameter names from control functions f(t, para_key)
c_sigs = [signature(Hc[1]) for Hc in self.objectives[0].H[1:]]
c_keys = [sig.parameters.keys() for sig in c_sigs]
para_keys = [list(keys)[1] for keys in c_keys]

args_dict = {}
for key, val in zip(para_keys, self.optimized_params):
args_dict[key] = val
for key, val in zip(para_keys, self.optimized_params):
args_dict[key] = val

# choose solver method based on type of control function
if isinstance(
Expand All @@ -285,7 +299,7 @@ def final_states(self):
if args_dict
else qt.QobjEvo(obj.H, tlist=self.time_interval.tslots)
)
solver = None
# solver = None

if obj.H[0].issuper: # choose solver
solver = qt.MESolver(
Expand All @@ -307,14 +321,25 @@ def final_states(self):
states.append( # compute evolution
solver.run(obj.initial, tlist=[0.0, evo_time]).final_state
)
if (
self.qtrl_optimizers
): # GRAPE HACK: this should be the same result through evolution
if not isinstance(
self.qtrl_optimizers[0].pulse_generator, list
): # only for GRAPE
dyn = self.qtrl_optimizers[0].dynamics
a = np.hstack([c for c in self.optimized_controls])
amps = self.qtrl_optimizers[0]._get_ctrl_amps(a)
dyn.update_ctrl_amps(amps)
# fid_err = dyn.fid_computer.get_fid_err()
# grad_norm_final = dyn.fid_computer.grad_norm
# final_amps = dyn.ctrl_amps
final_evo = dyn.full_evo
states = [final_evo]

self._final_states = states
return self._final_states

@final_states.setter
def final_states(self, states):
self._final_states = states

def update(self, infidelity, parameters):
self.infidelity = infidelity
self.new_params = parameters
Expand All @@ -330,3 +355,33 @@ def load(cls, filename, objectives=None):
result = pickle.load(dump_fh)
result.objectives = objectives
return result

@property
def evo_full_final(self):
# qtrl result backward compatibility # TODO: deprecated warning
return self.final_states[0]

@property
def fid_err(self):
# qtrl result backward compatibility # TODO: deprecated warning
return self.infidelity

@property
def grad_norm_final(self):
# qtrl result backward compatibility # TODO: deprecated warning
return None # not supported

@property
def termination_reason(self):
# qtrl result backward compatibility # TODO: deprecated warning
return self.message

@property
def num_iter(self):
# qtrl result backward compatibility # TODO: deprecated warning
return self.n_iters

@property
def wall_time(self):
# qtrl result backward compatibility # TODO: deprecated warning
return self.total_seconds

0 comments on commit 86204a1

Please sign in to comment.