Skip to content

Commit

Permalink
Merge pull request #2337 from TorkelE/bif_extension
Browse files Browse the repository at this point in the history
Improve BifurcationKit extension
  • Loading branch information
YingboMa authored Nov 7, 2023
2 parents 70c3252 + b7ab44f commit 397ab12
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 19 deletions.
105 changes: 96 additions & 9 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,76 @@ module MTKBifurcationKitExt
using ModelingToolkit, Setfield
import BifurcationKit

### Observable Plotting Handling ###

# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram.
struct ObservableRecordFromSolution{S, T}
# The equations determining the observables values.
obs_eqs::S
# The index of the observable that we wish to plot.
target_obs_idx::Int64
# The final index in subs_vals that contains a state.
state_end_idxs::Int64
# The final index in subs_vals that contains a param.
param_end_idxs::Int64
# The index (in subs_vals) that contain the bifurcation parameter.
bif_par_idx::Int64
# A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters.
subs_vals::T

function ObservableRecordFromSolution(nsys::NonlinearSystem,
plot_var,
bif_idx,
u0_vals,
p_vals) where {S, T}
obs_eqs = observed(nsys)
target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys))
state_end_idxs = length(states(nsys))
param_end_idxs = state_end_idxs + length(parameters(nsys))

bif_par_idx = state_end_idxs + bif_idx
# Gets the (base) substitution values for states.
subs_vals_states = Pair.(states(nsys), u0_vals)
# Gets the (base) substitution values for parameters.
subs_vals_params = Pair.(parameters(nsys), p_vals)
# Gets the (base) substitution values for observables.
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
[subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
# Sometimes observables depend on other observables, hence we make a second upate to this vector.
subs_vals_obs = [obs.lhs => substitute(obs.rhs,
[subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
# During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector
subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs]

param_end_idxs = state_end_idxs + length(parameters(nsys))
new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs,
target_obs_idx,
state_end_idxs,
param_end_idxs,
bif_par_idx,
subs_vals)
end
end
# Functor function that computes the value.
function (orfs::ObservableRecordFromSolution)(x, p)
# Updates the state values (in subs_vals).
for state_idx in 1:(orfs.state_end_idxs)
orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx]
end

# Updates the bifurcation parameters value (in subs_vals).
orfs.subs_vals[orfs.bif_par_idx] = orfs.subs_vals[orfs.bif_par_idx][1] => p

# Updates the observable values (in subs_vals).
for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs)
orfs.subs_vals[orfs.param_end_idxs + obs_idx] = orfs.subs_vals[orfs.param_end_idxs + obs_idx][1] => substitute(obs_eq.rhs,
orfs.subs_vals)
end

# Substitutes in the value for all states, parameters, and observables into the equation for the designated observable.
return substitute(orfs.obs_eqs[orfs.target_obs_idx].rhs, orfs.subs_vals)
end

### Creates BifurcationProblem Overloads ###

# When input is a NonlinearSystem.
Expand All @@ -23,20 +93,37 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
F = ofun.f
J = jac ? ofun.jac : nothing

# Computes bifurcation parameter and plot var indexes.
# Converts the input state guess.
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif,
states(nsys);
defaults = nsys.defaults)
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults = nsys.defaults)

# Computes bifurcation parameter and the plotting function.
bif_idx = findfirst(isequal(bif_par), parameters(nsys))
if !isnothing(plot_var)
plot_idx = findfirst(isequal(plot_var), states(nsys))
record_from_solution = (x, p) -> x[plot_idx]
end
# If the plot var is a normal state.
if any(isequal(plot_var, var) for var in states(nsys))
plot_idx = findfirst(isequal(plot_var), states(nsys))
record_from_solution = (x, p) -> x[plot_idx]

# Converts the input state guess.
u0_bif = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys))
ps = ModelingToolkit.varmap_to_vars(ps, parameters(nsys))
# If the plot var is an observed state.
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
record_from_solution = ObservableRecordFromSolution(nsys,
plot_var,
bif_idx,
u0_bif_vals,
p_vals)

# If neither an variable nor observable, throw an error.
else
error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")
end
end

return BifurcationKit.BifurcationProblem(F,
u0_bif,
ps,
u0_bif_vals,
p_vals,
(@lens _[bif_idx]),
args...;
record_from_solution = record_from_solution,
Expand Down
121 changes: 111 additions & 10 deletions test/extensions/bifurcationkit.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,136 @@
using BifurcationKit, ModelingToolkit, Test

# Checks pitchfork diagram and that there are the correct number of branches (a main one and two children)
# Simple pitchfork diagram, compares solution to native BifurcationKit, checks they are identical.
# Checks using `jac=false` option.
let
# Creates model.
@variables t x(t) y(t)
@parameters μ α
eqs = [0 ~ μ * x - x^3 + α * y,
0 ~ -y]
@named nsys = NonlinearSystem(eqs, [x, y], [μ, α])

# Creates BifurcationProblem
bif_par = μ
p_start ==> -1.0, α => 1.0]
u0_guess = [x => 1.0, y => 1.0]
plot_var = x

using BifurcationKit
bprob = BifurcationProblem(nsys,
u0_guess,
p_start,
bif_par;
plot_var = plot_var,
jac = false)

# Conputes bifurcation diagram.
p_span = (-4.0, 6.0)
opts_br = ContinuationPar(max_steps = 500, p_min = p_span[1], p_max = p_span[2])
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

# Computes bifurcation diagram using BifurcationKit directly (without going through MTK).
function f_BK(u, p)
x, y = u
μ, α = p
return* x - x^3 + α * y, -y]
end
bprob_BK = BifurcationProblem(f_BK,
[1.0, 1.0],
[-1.0, 1.0],
(@lens _[1]);
record_from_solution = (x, p) -> x[1])
bif_dia_BK = bifurcationdiagram(bprob_BK,
PALC(),
2,
(args...) -> opts_br;
bothside = true)

# Compares results.
@test getfield.(bif_dia.γ.branch, :x) getfield.(bif_dia_BK.γ.branch, :x)
@test getfield.(bif_dia.γ.branch, :param) getfield.(bif_dia_BK.γ.branch, :param)
@test bif_dia.γ.specialpoint[1].x == bif_dia_BK.γ.specialpoint[1].x
@test bif_dia.γ.specialpoint[1].param == bif_dia_BK.γ.specialpoint[1].param
@test bif_dia.γ.specialpoint[1].type == bif_dia_BK.γ.specialpoint[1].type
end

# Lotka–Volterra model, checks exact position of bifurcation variable and bifurcation points.
# Checks using ODESystem input.
let
# Creates a Lotka–Volterra model.
@parameters α a b
@variables t x(t) y(t) z(t)
D = Differential(t)
eqs = [D(x) ~ -x + a * y + x^2 * y,
D(y) ~ b - a * y - x^2 * y]
@named sys = ODESystem(eqs)

# Creates BifurcationProblem
bprob = BifurcationProblem(sys,
[x => 1.5, y => 1.0],
[a => 0.1, b => 0.5],
b;
plot_var = x)

# Computes bifurcation diagram.
p_span = (0.0, 2.0)
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 2000)
opts_br = ContinuationPar(dsmax = 0.05,
max_steps = 500,
newton_options = opt_newton,
p_min = p_span[1],
p_max = p_span[2],
n_inversion = 4)
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

# Tests that the diagram has the correct values (x = b)
all([b.x b.param for b in bif_dia.γ.branch])

# Tests that we get two Hopf bifurcations at the correct positions.
hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint),
:x);
by = x -> x[1])
@test length(hopf_points) == 2
@test hopf_points[1] [0.41998733080424205, 1.5195495712453098]
@test hopf_points[2] [0.7899715592573977, 1.0910379583813192]
end

# Simple fold bifurcation model, checks exact position of bifurcation variable and bifurcation points.
# Checks that default parameter values are accounted for.
# Checks that observables (that depend on other observables, as in this case) are accounted for.
let
# Creates model, and uses `structural_simplify` to generate observables.
@parameters μ p=2
@variables t x(t) y(t) z(t)
D = Differential(t)
eqs = [0 ~ μ - x^3 + 2x^2,
0 ~ p * μ - y,
0 ~ y - z]
@named nsys = NonlinearSystem(eqs, [x, y, z], [μ, p])
nsys = structural_simplify(nsys)

# Creates BifurcationProblem.
bif_par = μ
p_start ==> 1.0]
u0_guess = [x => 1.0, y => 0.1, z => 0.1]
plot_var = x
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var = plot_var)

# Computes bifurcation diagram.
p_span = (-4.3, 12.0)
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20)
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
max_steps = 100, nev = 2, newton_options = opt_newton,
p_min = p_span[1], p_max = p_span[2],
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8,
dsmin_bisection = 1e-9)
opts_br = ContinuationPar(dsmax = 0.05,
max_steps = 500,
newton_options = opt_newton,
p_min = p_span[1],
p_max = p_span[2],
n_inversion = 4)
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

bf = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
# Tests that the diagram has the correct values (x = b)
all([b.x 2 * b.param for b in bif_dia.γ.branch])

@test length(bf.child) == 2
# Tests that we get two fold bifurcations at the correct positions.
fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint),
:param))
@test length(fold_points) == 2
@test fold_points [-1.1851851706940317, -5.6734983580551894e-6] # test that they occur at the correct parameter values).
end

0 comments on commit 397ab12

Please sign in to comment.