Skip to content

Commit

Permalink
refactor callbacks and problem templates to use rollout fidelity and …
Browse files Browse the repository at this point in the history
…system return values
  • Loading branch information
aarontrowbridge committed Jan 14, 2025
1 parent 45b6d36 commit f72f5fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
22 changes: 11 additions & 11 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ function best_rollout_callback(
end

function best_rollout_fidelity_callback(prob::QuantumControlProblem)
return best_rollout_callback(prob, fidelity)
return best_rollout_callback(prob, rollout_fidelity)
end

function best_unitary_rollout_fidelity_callback(prob::QuantumControlProblem)
return best_rollout_callback(prob, unitary_fidelity)
return best_rollout_callback(prob, unitary_rollout_fidelity)
end

function trajectory_history_callback(prob::QuantumControlProblem)
Expand All @@ -58,13 +58,13 @@ end
const MOI = MathOptInterface
include("../test/test_utils.jl")

prob = smooth_quantum_state_problem()
prob, sys = smooth_quantum_state_problem(return_system=true)

my_callback = (kwargs...) -> false

initial = fidelity(prob)
initial = rollout_fidelity(prob, sys)
solve!(prob, max_iter=20, callback=my_callback)
final = fidelity(prob)
final = rollout_fidelity(prob, sys)

# callback forces problem to exit early as per Ipopt documentation
@test MOI.get(prob.optimizer, MOI.TerminationStatus()) == MOI.INTERRUPTED
Expand All @@ -78,7 +78,7 @@ end
const MOI = MathOptInterface
include("../test/test_utils.jl")

prob = smooth_quantum_state_problem()
prob, sys = smooth_quantum_state_problem(return_system=true)

callback, trajectory_history = trajectory_history_callback(prob)

Expand All @@ -98,15 +98,15 @@ end
@test length(best_trajs) == 0

# measure fidelity
before = fidelity(prob)
before = rollout_fidelity(prob, sys)
solve!(prob, max_iter=20, callback=callback)

# length must increase if iterations are made
@test length(best_trajs) > 0
@test best_trajs[end] isa NamedTrajectory

# fidelity ranking
after = fidelity(prob)
after = rollout_fidelity(prob, sys)
best = fidelity(best_trajs[end], system)

@test before < after
Expand All @@ -126,15 +126,15 @@ end
@test length(best_trajs) == 0

# measure fidelity
before = unitary_fidelity(prob)
before = unitary_rollout_fidelity(prob, sys)
solve!(prob, max_iter=20, callback=callback)

# length must increase if iterations are made
@test length(best_trajs) > 0
@test best_trajs[end] isa NamedTrajectory

# fidelity ranking
after = unitary_fidelity(prob)
after = unitary_rollout_fidelity(prob, sys)
best = unitary_fidelity(best_trajs[end], system)

@test before < after
Expand All @@ -148,7 +148,7 @@ end
const MOI = MathOptInterface
include("../test/test_utils.jl")

prob = smooth_quantum_state_problem()
prob, sys = smooth_quantum_state_problem(return_system=true)

obj_vals = []
function get_history_callback(
Expand Down
6 changes: 3 additions & 3 deletions src/problem_templates/_problem_templates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ function apply_piccolo_options!(
constraints::AbstractVector{<:AbstractConstraint},
piccolo_options::PiccoloOptions,
traj::NamedTrajectory,
operator::Union{Nothing, AbstractPiccoloOperator},
state_name::Symbol,
operators::Union{Nothing, AbstractVector{<:AbstractPiccoloOperator}},
state_names::AbstractVector{Symbol},
timestep_name::Symbol
)
# TODO: This should be changed to leakage indices (more general, for states)
Expand Down Expand Up @@ -85,7 +85,7 @@ function apply_piccolo_options!(
piccolo_options::PiccoloOptions,
traj::NamedTrajectory,
operator::Union{Nothing, AbstractPiccoloOperator},
state_names::AbstractVector{<:Symbol},
state_name::Symbol,
timestep_name::Symbol
)
state_names = [
Expand Down

0 comments on commit f72f5fd

Please sign in to comment.