Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better BPINN ODE Solver #853

Merged
merged 20 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ struct BNNODE{C, K, IT <: NamedTuple,
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
numensemble::Int64
estim_collocate::Bool
autodiff::Bool
progress::Bool
verbose::Bool
Expand All @@ -112,6 +114,8 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
numensemble = floor(Int, draw_samples / 3),
estim_collocate = false,
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
Expand All @@ -120,6 +124,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
numensemble, estim_collocate,
autodiff, progress, verbose)
end

Expand Down Expand Up @@ -186,7 +191,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
Expand All @@ -211,7 +217,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)
verbose = verbose,
estim_collocate = estim_collocate)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
ninv = length(param)
Expand All @@ -220,7 +227,8 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
if chain isa Lux.AbstractExplicitLayer
θinit, st = Lux.setup(Random.default_rng(), chain)
θ = [vector_to_parameters(samples[i][1:(end - ninv)], θinit)
for i in (draw_samples - numensemble):draw_samples]
for i in 1:max(draw_samples - draw_samples ÷ 10, draw_samples - 1000)]

luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
Expand Down
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")
include("collocated_estim.jl")

export NNODE, NNDAE,
PhysicsInformedNN, discretize,
Expand Down
31 changes: 20 additions & 11 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
physdt::Float64
extraparams::Int
init_params::I
estim_collocate::Bool

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
init_params::AbstractVector, estim_collocate)
new{
typeof(chain),
Nothing,
Expand All @@ -39,12 +40,13 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, strategy,
dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
init_params::NamedTuple, estim_collocate)
new{
typeof(chain),
typeof(st),
Expand All @@ -60,7 +62,8 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I,
autodiff,
physdt,
extraparams,
init_params)
init_params,
estim_collocate)
end
end

Expand All @@ -83,7 +86,12 @@ end
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
    # Total log-density at parameter vector θ: physics log-likelihood
    # + log-prior over network weights (and ODE params) + data log-likelihood.
    # When collocation-based estimation is enabled, the extra L2loss2 term
    # is folded in as well.
    ll = physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
    if Tar.estim_collocate
        ll += L2loss2(Tar, θ)
    end
    return ll
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
Expand Down Expand Up @@ -247,7 +255,7 @@ function innerdiff(Tar::LogTargetDensity, f, autodiff::Bool, t::AbstractVector,

vals = nnsol .- physsol

# N dimensional vector if N outputs for NN(each row has logpdf of i[i] where u is vector of dependant variables)
# N dimensional vector if N outputs for NN(each row has logpdf of u[i] where u is vector of dependent variables)
return [logpdf(
MvNormal(vals[i, :],
LinearAlgebra.Diagonal(abs2.(Tar.phystd[i] .*
Expand Down Expand Up @@ -442,7 +450,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)
progress = false, verbose = false,
estim_collocate = false)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
Expand All @@ -467,7 +476,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
# Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer neural networks are supported")
error("Only Lux.AbstractExplicitLayer Neural networks are supported")
end

if nchains > Threads.nthreads()
Expand Down Expand Up @@ -500,7 +509,7 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
t0 = prob.tspan[1]
# dimensions would be total no of params,initial_nnθ for Lux namedTuples
ℓπ = LogTargetDensity(nparameters, prob, recon, st, strategy, dataset, priors,
phystd, l2std, autodiff, physdt, ninv, initial_nnθ)
phystd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
Expand Down Expand Up @@ -569,8 +578,8 @@ function ahmc_bayesian_pinn_ode(prob::SciMLBase.ODEProblem, chain;
L2LossData(ℓπ, samples[end]))

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end
46 changes: 46 additions & 0 deletions src/collocated_estim.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Suggested extra loss term for the BPINN ODE solver: a collocation-style
# log-likelihood that compares the NN time-derivative against the ODE RHS
# evaluated on the *dataset* values (rather than the NN prediction).
#
# Arguments:
#   Tar — the LogTargetDensity problem wrapper (provides prob, dataset, stds).
#   θ   — full parameter vector: NN weights followed by `extraparams` ODE params.
#
# Returns the accumulated log-probability (Float64), or 0.0 when no ODE
# parameters are being estimated (the term is then a no-op).
function L2loss2(Tar::LogTargetDensity, θ)
    f = Tar.prob.f

    # Only meaningful when parameter estimation is chosen.
    if Tar.extraparams > 0
        autodiff = Tar.autodiff
        # Timepoints to enforce physics.
        # assumes dataset layout [û, u1, ..., t] — TODO confirm against callers
        t = Tar.dataset[end]
        u1 = Tar.dataset[2]
        û = Tar.dataset[1]

        # NN derivative at the dataset timepoints (NN weights only, no ODE params)
        nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

        # Trailing entries of θ are the ODE parameters; scalar when only one.
        ode_params = Tar.extraparams == 1 ?
                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
                     θ[((length(θ) - Tar.extraparams) + 1):length(θ)]

        if length(Tar.prob.u0) == 1
            # Scalar ODE: evaluate f on the observed values directly.
            physsol = [f(û[i], ode_params, t[i]) for i in eachindex(û)]
        else
            # NOTE(review): this branch hard-codes a 2-state system ([û, u1]);
            # systems with more states would be silently mishandled — TODO generalize.
            physsol = [f([û[i], u1[i]], ode_params, t[i]) for i in eachindex(û)]
        end
        # Form of NN output matrix: output dim x n
        deri_physsol = reduce(hcat, physsol)

        # Accumulate as Float64 (was `0`, an Int) so the loop is type-stable.
        physlogprob = 0.0
        for i in eachindex(Tar.prob.u0)
            # can add phystd[i] for u[i]
            # NOTE(review): the 4.0 widening factor on l2std looks empirical — confirm.
            physlogprob += logpdf(
                MvNormal(deri_physsol[i, :],
                    LinearAlgebra.Diagonal(abs2.((Tar.l2std[i] * 4.0) .*
                                                 ones(length(nnsol[i, :]))))),
                nnsol[i, :])
        end
        return physlogprob
    else
        # Return Float64 zero to keep the function's return type consistent
        # with the estimating branch.
        return 0.0
    end
end
24 changes: 12 additions & 12 deletions test/BPINN_PDEinvsol_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ComponentArrays

Random.seed!(100)

@testset "Example 1: 2D Periodic System with parameter estimation" begin
@testset "Example 1: 1D Periodic System with parameter estimation" begin
# Cos(pi*t) periodic curve
@parameters t, p
@variables u(..)
Expand Down Expand Up @@ -59,17 +59,17 @@ Random.seed!(100)
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])

discretization = BayesianPINN([chainl], QuadratureTraining(), param_estim = true,
dataset = [dataset, nothing])

ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 1500,
bcstd = [0.05],
phystd = [0.01], l2std = [0.01],
priorsNNw = (0.0, 1.0),
saveats = [1 / 50.0],
param = [LogNormal(6.0, 0.5)])
# discretization = BayesianPINN([chainl], QuadratureTraining(), param_estim = true,
# dataset = [dataset, nothing])

# ahmc_bayesian_pinn_pde(pde_system,
# discretization;
# draw_samples = 1500,
# bcstd = [0.05],
# phystd = [0.01], l2std = [0.01],
# priorsNNw = (0.0, 1.0),
# saveats = [1 / 50.0],
# param = [LogNormal(6.0, 0.5)])
AstitvaAggarwal marked this conversation as resolved.
Show resolved Hide resolved

discretization = BayesianPINN([chainl], GridTraining([0.02]), param_estim = true,
dataset = [dataset, nothing])
Expand Down
31 changes: 17 additions & 14 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ Random.seed!(100)
# testing points
t = time
# Mean of last 500 sampled parameter's curves[Ensemble predictions]
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:2500]
luxar = [chainlux(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:length(fhsamples)]
luxar = [chainlux(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

Expand All @@ -54,8 +54,8 @@ Random.seed!(100)
@test mean(abs.(physsol1 .- meanscurve)) < 0.005

#--------------------- solve() call
@test mean(abs.(x̂1 .- sol1lux.ensemblesol[1])) < 0.05
@test mean(abs.(physsol0_1 .- sol1lux.ensemblesol[1])) < 0.05
@test mean(abs.(x̂1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
@test mean(abs.(physsol0_1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025
end

@testset "Example 2 - with parameter estimation" begin
Expand Down Expand Up @@ -111,19 +111,20 @@ end
# testing points
t = time
# Mean of last 500 sampled parameter's curves(flux and lux chains)[Ensemble predictions]
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit) for i in 2000:2500]
luxar = [chainlux1(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit)
for i in 2000:length(fhsamples)]
luxar = [chainlux1(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

# --------------------- ahmc_bayesian_pinn_ode() call
@test mean(abs.(physsol1 .- meanscurve)) < 0.15

# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - mean([fhsamples[i][23] for i in 2000:2500])) < abs(0.35 * p)
@test abs(p - mean([fhsamples[i][23] for i in 2000:length(fhsamples)])) < abs(0.35 * p)

#-------------------------- solve() call
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2
@test mean(abs.(physsol1_1 .- pmean(sol2lux.ensemblesol[1]))) < 8e-2

# ESTIMATED ODE PARAMETERS (NN1 AND NN2)
@test abs(p - sol2lux.estimated_de_params[1]) < abs(0.15 * p)
Expand Down Expand Up @@ -193,13 +194,15 @@ end
t = sol.t
#------------------------------ ahmc_bayesian_pinn_ode() call
# Mean of last 500 sampled parameter's curves(lux chains)[Ensemble predictions]
θ = [vector_to_parameters(fhsampleslux12[i], θinit) for i in 1000:1500]
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsampleslux12[i], θinit)
for i in 1000:length(fhsampleslux12)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit) for i in 1000:1500]
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit)
for i in 1000:length(fhsampleslux22)]
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)]
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean

Expand All @@ -209,12 +212,12 @@ end
@test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2

# estimated parameters(lux chain)
param1 = mean(i[62] for i in fhsampleslux22[1000:1500])
param1 = mean(i[62] for i in fhsampleslux22[1000:length(fhsampleslux22)])
@test abs(param1 - p) < abs(0.3 * p)

#-------------------------- solve() call
# (lux chain)
@test mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 0.15
@test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 0.15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pmean typo?

Copy link
Contributor Author

@AstitvaAggarwal AstitvaAggarwal Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope, the mean is required as the solution's standard deviation are different at domain points, sometimes these uncertainties can be large enough for the tests to fail. so i just take the means for testing.

# estimated parameters(lux chain)
param1 = sol3lux_pestim.estimated_de_params[1]
@test abs(param1 - p) < abs(0.45 * p)
Expand Down
Loading
Loading