diff --git a/ext/MTKBifurcationKitExt.jl b/ext/MTKBifurcationKitExt.jl index 285d348bf5..f3c42d36ae 100644 --- a/ext/MTKBifurcationKitExt.jl +++ b/ext/MTKBifurcationKitExt.jl @@ -6,6 +6,76 @@ module MTKBifurcationKitExt using ModelingToolkit, Setfield import BifurcationKit +### Observable Plotting Handling ### + +# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram. +struct ObservableRecordFromSolution{S, T} + # The equations determining the observables values. + obs_eqs::S + # The index of the observable that we wish to plot. + target_obs_idx::Int64 + # The final index in subs_vals that contains a state. + state_end_idxs::Int64 + # The final index in subs_vals that contains a param. + param_end_idxs::Int64 + # The index (in subs_vals) that contain the bifurcation parameter. + bif_par_idx::Int64 + # A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters. + subs_vals::T + + function ObservableRecordFromSolution(nsys::NonlinearSystem, + plot_var, + bif_idx, + u0_vals, + p_vals) where {S, T} + obs_eqs = observed(nsys) + target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys)) + state_end_idxs = length(states(nsys)) + param_end_idxs = state_end_idxs + length(parameters(nsys)) + + bif_par_idx = state_end_idxs + bif_idx + # Gets the (base) substitution values for states. + subs_vals_states = Pair.(states(nsys), u0_vals) + # Gets the (base) substitution values for parameters. + subs_vals_params = Pair.(parameters(nsys), p_vals) + # Gets the (base) substitution values for observables. + subs_vals_obs = [obs.lhs => substitute(obs.rhs, + [subs_vals_states; subs_vals_params]) for obs in observed(nsys)] + # Sometimes observables depend on other observables, hence we make a second upate to this vector. + subs_vals_obs = [obs.lhs => substitute(obs.rhs, + [subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)] + # During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector + subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs] + + param_end_idxs = state_end_idxs + length(parameters(nsys)) + new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs, + target_obs_idx, + state_end_idxs, + param_end_idxs, + bif_par_idx, + subs_vals) + end +end +# Functor function that computes the value. +function (orfs::ObservableRecordFromSolution)(x, p) + # Updates the state values (in subs_vals). + for state_idx in 1:(orfs.state_end_idxs) + orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx] + end + + # Updates the bifurcation parameters value (in subs_vals). + orfs.subs_vals[orfs.bif_par_idx] = orfs.subs_vals[orfs.bif_par_idx][1] => p + + # Updates the observable values (in subs_vals). + for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs) + orfs.subs_vals[orfs.param_end_idxs + obs_idx] = orfs.subs_vals[orfs.param_end_idxs + obs_idx][1] => substitute(obs_eq.rhs, + orfs.subs_vals) + end + + # Substitutes in the value for all states, parameters, and observables into the equation for the designated observable. + return substitute(orfs.obs_eqs[orfs.target_obs_idx].rhs, orfs.subs_vals) +end + ### Creates BifurcationProblem Overloads ### # When input is a NonlinearSystem. @@ -23,20 +93,37 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem, F = ofun.f J = jac ? ofun.jac : nothing - # Computes bifurcation parameter and plot var indexes. + # Converts the input state guess. + u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif, + states(nsys); + defaults = nsys.defaults) + p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults = nsys.defaults) + + # Computes bifurcation parameter and the plotting function. bif_idx = findfirst(isequal(bif_par), parameters(nsys)) if !isnothing(plot_var) - plot_idx = findfirst(isequal(plot_var), states(nsys)) - record_from_solution = (x, p) -> x[plot_idx] - end + # If the plot var is a normal state. + if any(isequal(plot_var, var) for var in states(nsys)) + plot_idx = findfirst(isequal(plot_var), states(nsys)) + record_from_solution = (x, p) -> x[plot_idx] - # Converts the input state guess. - u0_bif = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys)) - ps = ModelingToolkit.varmap_to_vars(ps, parameters(nsys)) + # If the plot var is an observed state. + elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys)) + record_from_solution = ObservableRecordFromSolution(nsys, + plot_var, + bif_idx, + u0_bif_vals, + p_vals) + + # If neither an variable nor observable, throw an error. + else + error("The plot variable ($plot_var) was neither recognised as a system state nor observable.") + end + end return BifurcationKit.BifurcationProblem(F, - u0_bif, - ps, + u0_bif_vals, + p_vals, (@lens _[bif_idx]), args...; record_from_solution = record_from_solution, diff --git a/test/extensions/bifurcationkit.jl b/test/extensions/bifurcationkit.jl index 45f0a89d57..53c7369485 100644 --- a/test/extensions/bifurcationkit.jl +++ b/test/extensions/bifurcationkit.jl @@ -1,19 +1,20 @@ using BifurcationKit, ModelingToolkit, Test -# Checks pitchfork diagram and that there are the correct number of branches (a main one and two children) +# Simple pitchfork diagram, compares solution to native BifurcationKit, checks they are identical. +# Checks using `jac=false` option. let + # Creates model. @variables t x(t) y(t) @parameters μ α eqs = [0 ~ μ * x - x^3 + α * y, 0 ~ -y] @named nsys = NonlinearSystem(eqs, [x, y], [μ, α]) + # Creates BifurcationProblem bif_par = μ p_start = [μ => -1.0, α => 1.0] u0_guess = [x => 1.0, y => 1.0] plot_var = x - - using BifurcationKit bprob = BifurcationProblem(nsys, u0_guess, p_start, @@ -21,15 +22,115 @@ let plot_var = plot_var, jac = false) + # Conputes bifurcation diagram. p_span = (-4.0, 6.0) + opts_br = ContinuationPar(max_steps = 500, p_min = p_span[1], p_max = p_span[2]) + bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true) + + # Computes bifurcation diagram using BifurcationKit directly (without going through MTK). + function f_BK(u, p) + x, y = u + μ, α = p + return [μ * x - x^3 + α * y, -y] + end + bprob_BK = BifurcationProblem(f_BK, + [1.0, 1.0], + [-1.0, 1.0], + (@lens _[1]); + record_from_solution = (x, p) -> x[1]) + bif_dia_BK = bifurcationdiagram(bprob_BK, + PALC(), + 2, + (args...) -> opts_br; + bothside = true) + + # Compares results. + @test getfield.(bif_dia.γ.branch, :x) ≈ getfield.(bif_dia_BK.γ.branch, :x) + @test getfield.(bif_dia.γ.branch, :param) ≈ getfield.(bif_dia_BK.γ.branch, :param) + @test bif_dia.γ.specialpoint[1].x == bif_dia_BK.γ.specialpoint[1].x + @test bif_dia.γ.specialpoint[1].param == bif_dia_BK.γ.specialpoint[1].param + @test bif_dia.γ.specialpoint[1].type == bif_dia_BK.γ.specialpoint[1].type +end + +# Lotka–Volterra model, checks exact position of bifurcation variable and bifurcation points. +# Checks using ODESystem input. +let + # Creates a Lotka–Volterra model. + @parameters α a b + @variables t x(t) y(t) z(t) + D = Differential(t) + eqs = [D(x) ~ -x + a * y + x^2 * y, + D(y) ~ b - a * y - x^2 * y] + @named sys = ODESystem(eqs) + + # Creates BifurcationProblem + bprob = BifurcationProblem(sys, + [x => 1.5, y => 1.0], + [a => 0.1, b => 0.5], + b; + plot_var = x) + + # Computes bifurcation diagram. + p_span = (0.0, 2.0) + opt_newton = NewtonPar(tol = 1e-9, max_iterations = 2000) + opts_br = ContinuationPar(dsmax = 0.05, + max_steps = 500, + newton_options = opt_newton, + p_min = p_span[1], + p_max = p_span[2], + n_inversion = 4) + bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true) + + # Tests that the diagram has the correct values (x = b) + all([b.x ≈ b.param for b in bif_dia.γ.branch]) + + # Tests that we get two Hopf bifurcations at the correct positions. + hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint), + :x); + by = x -> x[1]) + @test length(hopf_points) == 2 + @test hopf_points[1] ≈ [0.41998733080424205, 1.5195495712453098] + @test hopf_points[2] ≈ [0.7899715592573977, 1.0910379583813192] +end + +# Simple fold bifurcation model, checks exact position of bifurcation variable and bifurcation points. +# Checks that default parameter values are accounted for. +# Checks that observables (that depend on other observables, as in this case) are accounted for. +let + # Creates model, and uses `structural_simplify` to generate observables. + @parameters μ p=2 + @variables t x(t) y(t) z(t) + D = Differential(t) + eqs = [0 ~ μ - x^3 + 2x^2, + 0 ~ p * μ - y, + 0 ~ y - z] + @named nsys = NonlinearSystem(eqs, [x, y, z], [μ, p]) + nsys = structural_simplify(nsys) + + # Creates BifurcationProblem. + bif_par = μ + p_start = [μ => 1.0] + u0_guess = [x => 1.0, y => 0.1, z => 0.1] + plot_var = x + bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var = plot_var) + + # Computes bifurcation diagram. + p_span = (-4.3, 12.0) opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20) - opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01, - max_steps = 100, nev = 2, newton_options = opt_newton, - p_min = p_span[1], p_max = p_span[2], - detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8, - dsmin_bisection = 1e-9) + opts_br = ContinuationPar(dsmax = 0.05, + max_steps = 500, + newton_options = opt_newton, + p_min = p_span[1], + p_max = p_span[2], + n_inversion = 4) + bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true) - bf = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true) + # Tests that the diagram has the correct values (x = b) + all([b.x ≈ 2 * b.param for b in bif_dia.γ.branch]) - @test length(bf.child) == 2 + # Tests that we get two fold bifurcations at the correct positions. + fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint), + :param)) + @test length(fold_points) == 2 + @test fold_points ≈ [-1.1851851706940317, -5.6734983580551894e-6] # test that they occur at the correct parameter values). end