-
-
Notifications
You must be signed in to change notification settings - Fork 209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better BPINN ode Solver #853
Changes from 10 commits
5648819
b8182c4
400cdb7
80679bb
c3f9366
6fb20d5
8538e59
3c43177
36905a8
c071f40
7a2cfa2
044bd83
70581be
4a00341
7111f48
1fef5b2
4de5691
b25be13
3102215
fab83e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# suggested extra loss function for ODE solver case | ||
function L2loss2(Tar::LogTargetDensity, θ) | ||
f = Tar.prob.f | ||
|
||
# parameter estimation chosen or not | ||
if Tar.extraparams > 0 | ||
autodiff = Tar.autodiff | ||
# Timepoints to enforce Physics | ||
t = Tar.dataset[end] | ||
u1 = Tar.dataset[2] | ||
û = Tar.dataset[1] | ||
|
||
nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff) | ||
|
||
ode_params = Tar.extraparams == 1 ? | ||
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] : | ||
θ[((length(θ) - Tar.extraparams) + 1):length(θ)] | ||
|
||
if length(Tar.prob.u0) == 1 | ||
physsol = [f(û[i], | ||
ode_params, | ||
t[i]) | ||
for i in 1:length(û[:, 1])] | ||
else | ||
physsol = [f([û[i], u1[i]], | ||
ode_params, | ||
t[i]) | ||
for i in 1:length(û)] | ||
end | ||
#form of NN output matrix output dim x n | ||
deri_physsol = reduce(hcat, physsol) | ||
|
||
physlogprob = 0 | ||
for i in 1:length(Tar.prob.u0) | ||
# can add phystd[i] for u[i] | ||
physlogprob += logpdf(MvNormal(deri_physsol[i, :], | ||
LinearAlgebra.Diagonal(map(abs2, | ||
(Tar.l2std[i] * 4.0) .* | ||
ones(length(nnsol[i, :]))))), | ||
nnsol[i, :]) | ||
end | ||
return physlogprob | ||
else | ||
return 0 | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,8 +44,8 @@ Random.seed!(100) | |
# testing points | ||
t = time | ||
# Mean of last 500 sampled parameter's curves[Ensemble predictions] | ||
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:2500] | ||
luxar = [chainlux(t', θ[i], st)[1] for i in 1:500] | ||
θ = [vector_to_parameters(fhsamples[i], θinit) for i in 2000:length(fhsamples)] | ||
luxar = [chainlux(t', θ[i], st)[1] for i in eachindex(θ)] | ||
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] | ||
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean | ||
|
||
|
@@ -54,8 +54,8 @@ Random.seed!(100) | |
@test mean(abs.(physsol1 .- meanscurve)) < 0.005 | ||
|
||
#--------------------- solve() call | ||
@test mean(abs.(x̂1 .- sol1lux.ensemblesol[1])) < 0.05 | ||
@test mean(abs.(physsol0_1 .- sol1lux.ensemblesol[1])) < 0.05 | ||
@test mean(abs.(x̂1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025 | ||
@test mean(abs.(physsol0_1 .- pmean(sol1lux.ensemblesol[1]))) < 0.025 | ||
end | ||
|
||
@testset "Example 2 - with parameter estimation" begin | ||
|
@@ -111,19 +111,20 @@ end | |
# testing points | ||
t = time | ||
# Mean of last 500 sampled parameter's curves(flux and lux chains)[Ensemble predictions] | ||
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit) for i in 2000:2500] | ||
luxar = [chainlux1(t', θ[i], st)[1] for i in 1:500] | ||
θ = [vector_to_parameters(fhsamples[i][1:(end - 1)], θinit) | ||
for i in 2000:length(fhsamples)] | ||
luxar = [chainlux1(t', θ[i], st)[1] for i in eachindex(θ)] | ||
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] | ||
meanscurve = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean | ||
|
||
# --------------------- ahmc_bayesian_pinn_ode() call | ||
@test mean(abs.(physsol1 .- meanscurve)) < 0.15 | ||
|
||
# ESTIMATED ODE PARAMETERS (NN1 AND NN2) | ||
@test abs(p - mean([fhsamples[i][23] for i in 2000:2500])) < abs(0.35 * p) | ||
@test abs(p - mean([fhsamples[i][23] for i in 2000:length(fhsamples)])) < abs(0.35 * p) | ||
|
||
#-------------------------- solve() call | ||
@test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2 | ||
@test mean(abs.(physsol1_1 .- pmean(sol2lux.ensemblesol[1]))) < 8e-2 | ||
|
||
# ESTIMATED ODE PARAMETERS (NN1 AND NN2) | ||
@test abs(p - sol2lux.estimated_de_params[1]) < abs(0.15 * p) | ||
|
@@ -193,13 +194,15 @@ end | |
t = sol.t | ||
#------------------------------ ahmc_bayesian_pinn_ode() call | ||
# Mean of last 500 sampled parameter's curves(lux chains)[Ensemble predictions] | ||
θ = [vector_to_parameters(fhsampleslux12[i], θinit) for i in 1000:1500] | ||
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500] | ||
θ = [vector_to_parameters(fhsampleslux12[i], θinit) | ||
for i in 1000:length(fhsampleslux12)] | ||
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)] | ||
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] | ||
meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean | ||
|
||
θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit) for i in 1000:1500] | ||
luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500] | ||
θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 1)], θinit) | ||
for i in 1000:length(fhsampleslux22)] | ||
luxar = [chainlux12(t', θ[i], st)[1] for i in eachindex(θ)] | ||
luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)] | ||
meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean | ||
|
||
|
@@ -209,12 +212,12 @@ end | |
@test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2 | ||
|
||
# estimated parameters(lux chain) | ||
param1 = mean(i[62] for i in fhsampleslux22[1000:1500]) | ||
param1 = mean(i[62] for i in fhsampleslux22[1000:length(fhsampleslux22)]) | ||
@test abs(param1 - p) < abs(0.3 * p) | ||
|
||
#-------------------------- solve() call | ||
# (lux chain) | ||
@test mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 0.15 | ||
@test mean(abs.(physsol2 .- pmean(sol3lux_pestim.ensemblesol[1]))) < 0.15 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nope, the mean is required as the solution's standard deviation are different at domain points, sometimes these uncertainties can be large enough for the tests to fail. so i just take the means for testing. |
||
# estimated parameters(lux chain) | ||
param1 = sol3lux_pestim.estimated_de_params[1] | ||
@test abs(param1 - p) < abs(0.45 * p) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this a separate file/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed