Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support makie backends #181

Merged
merged 7 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -29,14 +28,15 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[weakdeps]
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[extensions]
GLMakieExt = "GLMakie"
MakieExt = "Makie"
JLD2Ext = "JLD2"
PlotsExt = "Plots"
PlotsExt = ["Plots", "RecipesBase"]

[compat]
Accessors = "0.1"
Expand All @@ -47,7 +47,7 @@ DataStructures = "0.17, 0.18"
DocStringExtensions = "^0.8, ^0.9"
FastGaussQuadrature = "^0.4, ^0.5, 1"
ForwardDiff = "^0.10"
GLMakie = "0.10"
Makie = "^0.21"
IterativeSolvers = "0.8.4, 0.8.5, ^0.9"
JLD2 = "0.4, 0.5"
KrylovKit = "^0.7, ^0.8"
Expand Down
2 changes: 1 addition & 1 deletion examples/SH3d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const BK = BifurcationKit

Makie.inline!(true)

contour3dMakie(x; k...) = GLMakie.contour(x; k...)
contour3dMakie(x; k...) = Makie.contour(x; k...)
contour3dMakie(x::AbstractVector; k...) = contour3dMakie(reshape(x,Nx,Ny,Nz); k...)
contour3dMakie(ax, x; k...) = (contour(ax, x; k...))
contour3dMakie(ax, x::AbstractVector; k...) = contour3dMakie(ax, reshape(x,Nx,Ny,Nz); k...)
Expand Down
9 changes: 5 additions & 4 deletions ext/GLMakieExt/GLMakieExt.jl → ext/MakieExt/MakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module GLMakieExt
using GLMakie, BifurcationKit
module MakieExt
using Makie, BifurcationKit
import BifurcationKit: _plot_backend,
plot,
plot,
plot!,
hasbranch,
plot_branch_cont,
Expand All @@ -22,13 +22,14 @@ module GLMakieExt
get_color,
colorbif,
get_plot_backend,
set_plot_backend!,
BK_Makie,
plotAllDCBranch,
plot_DCont_branch
include("plot.jl")

function __init__()
_plot_backend[] = BK_Makie()
set_plot_backend!(BK_Makie())
return nothing
end
end
174 changes: 84 additions & 90 deletions ext/GLMakieExt/plot.jl → ext/MakieExt/plot.jl
Original file line number Diff line number Diff line change
@@ -1,141 +1,135 @@
using GLMakie: Point2f0
using Makie: Point2f0

function GLMakie.convert_arguments(::PointBased, contres::AbstractBranchResult, vars = nothing, applytoY = identity, applytoX = identity)
function Makie.convert_arguments(::PointBased, contres::AbstractBranchResult, vars = nothing, applytoY = identity, applytoX = identity)
ind1, ind2 = get_plot_vars(contres, vars)
return ([Point2f0(i, j) for (i, j) in zip(map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)))],)
end

function plot!(ax1, contres::AbstractBranchResult;
plotfold = false,
plotstability = true,
plotspecialpoints = true,
putspecialptlegend = true,
filterspecialpoints = false,
vars = nothing,
linewidthunstable = 1.0,
linewidthstable = 3.0linewidthunstable,
plotcirclesbif = true,
branchlabel = "",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My coding style is to have spaces around =

applytoY = identity,
applytoX = identity)

function isplit(x::AbstractVector{T}, indices::AbstractVector{<:Integer}, splitval::Bool = true) where {T<:Real}
# Adapt behavior for CairoMakie only
if !isempty(indices) && isdefined(Main, :CairoMakie) && Makie.current_backend() == Main.CairoMakie
xx = similar(x, length(x) + 2 * (length(indices)))
for (i, ind) in enumerate(indices)
if ind == first(indices)
xx[1:ind] .= @views x[1:ind]
else
xx[(2*(i-1)).+(indices[i-1]+1:ind)] .= @views x[(indices[i-1]+1:ind)]
end
if !splitval
xx[2*(i-1)+ind] = x[ind-1]
end
# Add a NaN is necessary, otherwise continue with same value as before (useful for linewidth)
xx[2*(i-1)+ind+1] = splitval ? NaN : x[ind-1]
# Repeat last value before NaN, but adapt for linewidth
xx[2*(i-1)+ind+2] = splitval ? x[ind] : x[ind+1]
end
# Fill the rest of the extended array
xx[last(indices)+2*length(indices)+1:end] .= @views x[last(indices)+1:end]
return xx
else
return x
end
end

function plot!(ax1, contres::AbstractBranchResult; plotfold = false, plotstability = true, plotspecialpoints = true, putspecialptlegend = true, filterspecialpoints = false, vars = nothing, linewidthunstable = 1.0, linewidthstable = 3.0linewidthunstable, plotcirclesbif = true, branchlabel = nothing, applytoY = identity, applytoX = identity)

# names for axis labels
ind1, ind2 = get_plot_vars(contres, vars)
xlab, ylab = get_axis_labels(ind1, ind2, contres)

# stability linewidth
linewidth = linewidthunstable
indices = [sp.idx for sp in contres.specialpoint if sp.type !== :endpoint]
# isplit required to work with CairoMakie due to change of linewidth for stability
if _hasstability(contres) && plotstability
linewidth = map(x -> isodd(x) ? linewidthstable : linewidthunstable, contres.stable)
end
if branchlabel == ""
lines!(ax1, map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)); linewidth)
else
lines!(ax1, map(applytoX, getproperty(contres.branch, ind1)), map(applytoY, getproperty(contres.branch, ind2)), linewidth = linewidth, label = branchlabel)
linewidth = isplit(map(x -> x ? linewidthstable : linewidthunstable, contres.stable), indices, false)
end
xbranch = isplit(map(applytoX, getproperty(contres.branch, ind1)), indices)
ybranch = isplit(map(applytoY, getproperty(contres.branch, ind2)), indices)
lines!(ax1, xbranch, ybranch, linewidth = linewidth, label = branchlabel)
ax1.xlabel = xlab
ax1.ylabel = ylab

# display bifurcation points
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres)-1), contres.specialpoint)
bifpt = filter(x -> (x.type != :none) && (x.type != :endpoint) && (plotfold || x.type != :fold) && (x.idx <= length(contres) - 1), contres.specialpoint)
if length(bifpt) >= 1 && plotspecialpoints #&& (ind1 == :param)
if filterspecialpoints == true
bifpt = filterBifurcations(bifpt)
end
scatter!(ax1,
[applytoX(getproperty(contres[pt.idx], ind1)) for pt in bifpt],
[applytoY(getproperty(contres[pt.idx], ind2)) for pt in bifpt];
marker = map(x -> (x.status == :guess) && (plotcirclesbif==false) ? :rect : :circle, bifpt),
markersize = 10,
color = map(x -> get_color(x.type), bifpt),
)
end

scatter!(ax1, [applytoX(getproperty(contres[pt.idx], ind1)) for pt in bifpt], [applytoY(getproperty(contres[pt.idx], ind2)) for pt in bifpt]; marker = map(x -> (x.status == :guess) && (plotcirclesbif == false) ? :rect : :circle, bifpt), markersize = 10, color = map(x -> get_color(x.type), bifpt))
end

# add legend for bifurcation points
if putspecialptlegend && length(bifpt) >= 1
bps = unique(x -> x.type, [pt for pt in bifpt if (pt.type != :none && (plotfold || pt.type != :fold))])
(length(bps) == 0) && return
for pt in bps
scatter!(ax1,
[applytoX(getproperty(contres[pt.idx], ind1))],
[applytoY(getproperty(contres[pt.idx], ind2))];
color = get_color(pt.type),
markersize = 10,
label = "$(pt.type)")
scatter!(ax1, [applytoX(getproperty(contres[pt.idx], ind1))], [applytoY(getproperty(contres[pt.idx], ind2))]; color = get_color(pt.type), markersize = 10, label = "$(pt.type)")
end
GLMakie.axislegend(ax1, merge = true, unique = true)
Makie.axislegend(ax1, merge = true, unique = true)
end
ax1
end

function plot_branch_cont(contres::ContResult,
state,
iter,
plotuserfunction;
plotfold = false,
plotstability = true,
plotspecialpoints = true,
putspecialptlegend = true,
filterspecialpoints = false,
linewidthunstable = 1.0,
linewidthstable = 3.0linewidthunstable,
plotcirclesbif = true,
applytoY = identity,
applytoX = identity)
function plot_branch_cont(contres::ContResult, state, iter, plotuserfunction; plotfold = false, plotstability = true, plotspecialpoints = true, putspecialptlegend = true, filterspecialpoints = false, linewidthunstable = 1.0, linewidthstable = 3.0linewidthunstable, plotcirclesbif = true, applytoY = identity, applytoX = identity)
sol = getsolution(state)
if length(contres) == 0; return ; end

if length(contres) == 0
return
end

# names for axis labels
ind1, ind2 = get_plot_vars(contres, nothing)
xlab, ylab = get_axis_labels(ind1, ind2, contres)

# stability linewidth
linewidth = linewidthunstable
if _hasstability(contres) && plotstability
linewidth = map(x -> isodd(x) ? linewidthstable : linewidthunstable, contres.stable)
end

fig = Figure(size = (1200, 700))
ax1 = fig[1:2, 1] = Axis(fig, xlabel = String(xlab), ylabel = String(ylab), tellheight = true)

ax2 = fig[1, 2] = Axis(fig, xlabel = "step [$(state.step)]", ylabel = String(xlab))
lines!(ax2, contres.step, contres.param, linewidth = linewidth)

if compute_eigenelements(iter)
eigvals = contres.eig[end].eigenvals
ax_ev = fig[3, 1:2] = Axis(fig, xlabel = "ℜ", ylabel = "ℑ")
scatter!(ax_ev, real.(eigvals), imag.(eigvals), strokewidth = 0, markersize = 10, color = :black)
# add stability boundary
maxIm = maximum(imag, eigvals)
minIm = minimum(imag, eigvals)
if maxIm-minIm < 1e-6
if maxIm - minIm < 1e-6
maxIm, minIm = 1, -1
end
lines!(ax_ev, [0, 0], [maxIm, minIm], color = :blue, linewidth = linewidthunstable)
end

# plot arrow to indicate the order of computation
if length(contres) > 1
x = contres.branch[end].param
y = getproperty(contres.branch,1)[end]
y = getproperty(contres.branch, 1)[end]
u = contres.branch[end].param - contres.branch[end-1].param
v = getproperty(contres.branch,1)[end] - getproperty(contres.branch,1)[end-1]
GLMakie.arrows!(ax1, [x], [y], [u], [v], color = :green, arrowsize = 20,)
v = getproperty(contres.branch, 1)[end] - getproperty(contres.branch, 1)[end-1]
Makie.arrows!(ax1, [x], [y], [u], [v], color = :green, arrowsize = 20)
end

plot!(ax1, contres; plotfold, plotstability, plotspecialpoints, putspecialptlegend, filterspecialpoints, linewidthunstable, linewidthstable, plotcirclesbif, applytoY, applytoX)

if isnothing(plotuserfunction) == false
ax_perso = fig[2, 2] = Axis(fig, tellheight = true)
plotuserfunction(ax_perso, sol.u, sol.p; ax1 = ax1)
end

display(fig)
fig
end

function plot(contres::AbstractBranchResult; kP...)
if length(contres) == 0; return ;end
if length(contres) == 0
return
end

ind1, ind2 = get_plot_vars(contres, nothing)
xlab, ylab = get_axis_labels(ind1, ind2, contres)
Expand All @@ -150,17 +144,17 @@ end

plot(brdc::DCResult; kP...) = plot(brdc.branches...; kP...)

function plot(brs::AbstractBranchResult...;
branchlabel = ["$i" for i=1:length(brs)],
kP...)
if length(brs) == 0; return ;end
function plot(brs::AbstractBranchResult...; branchlabel = ["$i" for i = 1:length(brs)], kP...)
if length(brs) == 0
return
end
fig = Figure()
ax1 = fig[1, 1] = Axis(fig)

for (id, contres) in pairs(brs)
plot!(ax1, contres; branchlabel = branchlabel[id], kP...)
end
GLMakie.axislegend(ax1, merge = true, unique = true)
Makie.axislegend(ax1, merge = true, unique = true)
display(fig)
fig, ax1
end
Expand All @@ -186,14 +180,14 @@ function plot_periodic_potrap(outpof, n, M; ratio = 2)
@assert ratio > 0 "You need at least one component"
outpo = reshape(outpof[1:end-1], ratio * n, M)
if ratio == 1
heatmap(outpo[1:n,:]', ylabel="Time", color=:viridis)
heatmap(outpo[1:n, :]', ylabel = "Time", color = :viridis)
else
fig = GLMakie.Figure()
ax1 = Axis(fig[1,1], ylabel="Time")
ax2 = Axis(fig[1,2], ylabel="Time")
# GLMakie.heatmap!(ax1, rand(2,2))
GLMakie.heatmap!(ax1, outpo[1:n,:]')
GLMakie.heatmap!(ax2, outpo[n+2:end,:]')
fig = Makie.Figure()
ax1 = Axis(fig[1, 1], ylabel = "Time")
ax2 = Axis(fig[1, 2], ylabel = "Time")
# Makie.heatmap!(ax1, rand(2,2))
Makie.heatmap!(ax1, outpo[1:n, :]')
Makie.heatmap!(ax2, outpo[n+2:end, :]')
fig
end
end
Expand All @@ -211,7 +205,9 @@ end
####################################################################################################
# plot recipes for the bifurcation diagram
function plot(bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
if ~hasbranch(bd); return; end
if ~hasbranch(bd)
return
end

fig = Figure()
ax = fig[1, 1] = Axis(fig)
Expand All @@ -223,7 +219,9 @@ function plot(bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
end

function _plot_bifdiag_makie!(ax, bd::BifDiagNode; code = (), level = (-Inf, Inf), k...)
if ~hasbranch(bd); return; end
if ~hasbranch(bd)
return
end

_bd = get_branch(bd, code)
_plot_bifdiag_makie!(ax, _bd.child; code = (), level = level, k...)
Expand All @@ -236,16 +234,12 @@ end

function _plot_bifdiag_makie!(ax, bd::Vector{BifDiagNode}; code = (), level = (-Inf, Inf), k...)
for b in bd
_plot_bifdiag_makie!(ax, b; code, level, k... )
_plot_bifdiag_makie!(ax, b; code, level, k...)
end
end
####################################################################################################
plotAllDCBranch(branches) = plot(branches...)

function plot_DCont_branch(::BK_Makie,
branches,
nbrs::Int,
nactive::Int,
nstep::Int)
function plot_DCont_branch(::BK_Makie, branches, nbrs::Int, nactive::Int, nstep::Int)
plot(branches...)
end
6 changes: 4 additions & 2 deletions ext/PlotsExt/PlotsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module PlotsExt
using Plots, BifurcationKit
import BifurcationKit: _plot_backend,
plot_branch_cont,
plot_branch_cont,
plot_periodic_potrap,
plot_periodic_shooting!,
plot_periodic_shooting,
Expand All @@ -18,6 +18,8 @@ module PlotsExt
filter_bifurcations,
get_color,
AbstractResult,
get_plot_backend,
set_plot_backend!,
BK_NoPlot, BK_Plots,
plotAllDCBranch,
plot_DCont_branch,
Expand All @@ -28,7 +30,7 @@ module PlotsExt
include("plot.jl")

function __init__()
_plot_backend[] = BK_Plots()
set_plot_backend!(BK_Plots())
return nothing
end
end
2 changes: 1 addition & 1 deletion src/BifurcationKit.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module BifurcationKit
using Printf, Dates, LinearMaps, BlockArrays, RecipesBase, StructArrays
using Printf, Dates, LinearMaps, BlockArrays, StructArrays
using Reexport
@reexport using Accessors: setproperties, @set, @reset, PropertyLens, getall, set, @optic, IndexLens, ComposedOptic
using Parameters: @with_kw, @unpack, @with_kw_noshow
Expand Down
Loading
Loading