Skip to content

Commit

Permalink
fix: final round of cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 16, 2024
1 parent 9efad5e commit c483d1c
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 52 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -64,7 +63,6 @@ ExplicitImports = "1.10.1"
Flux = "0.14.22"
ForwardDiff = "0.10.36"
Functors = "0.4.12"
GPUArraysCore = "0.1.6"
Integrals = "4.5"
IntervalSets = "0.7.10"
LineSearches = "7.3"
Expand Down
27 changes: 10 additions & 17 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ Contains `ahmc_bayesian_pinn_ode()` function output:
- step_size
- nom_step_size
"""
struct BPINNstats{MC, S, ST}
mcmc_chain::MC
samples::S
statistics::ST
@concrete struct BPINNstats
mcmc_chain
samples
statistics
end

"""
Expand All @@ -146,19 +146,12 @@ contains fields related to that).
3. `estimated_de_params` - Probabilistic Estimate of DE params from sampled unknown DE
parameters.
"""
struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
original::O
ensemblesol::E
estimated_nn_params::NP
estimated_de_params::OP
timepoints::P

function BPINNsolution(
original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints)
new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params),
typeof(estimated_de_params), typeof(timepoints)}(
original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints)
end
@concrete struct BPINNsolution
original <: BPINNstats
ensemblesol
estimated_nn_params
estimated_de_params
timepoints
end

function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing,
Expand Down
7 changes: 4 additions & 3 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ using DocStringExtensions: FIELDS
using DomainSets: DomainSets, AbstractInterval, leftendpoint, rightendpoint, ProductDomain
using ForwardDiff: ForwardDiff
using Functors: Functors, fmap
using GPUArraysCore: @allowscalar
using Integrals: Integrals, CubatureJLh, QuadGKJL
using IntervalSets: infimum, supremum
using LinearAlgebra: Diagonal
using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer
using Lux: FromFluxAdaptor, recursive_eltype
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using MLDataDevices: CPUDevice, cpu_device, get_device
using MLDataDevices: CPUDevice, get_device
using Optimisers: Optimisers, Adam
using Optimization: Optimization
using OptimizationOptimisers: OptimizationOptimisers
Expand Down Expand Up @@ -61,8 +60,10 @@ abstract type AbstractPINN end

abstract type AbstractTrainingStrategy end

const cdev = CPUDevice()

@inline safe_get_device(x) = safe_get_device(get_device(x), x)
@inline safe_get_device(::Nothing, x) = cpu_device()
@inline safe_get_device(::Nothing, x) = cdev
@inline safe_get_device(dev, _) = dev

@inline safe_expand(dev, x) = dev(x)
Expand Down
3 changes: 1 addition & 2 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ function (f::ODEPhi)(t, θ)
end

function (f::ODEPhi{<:Number})(dev, t::Number, θ)
res_vec = f.smodel(dev([t]), θ.depvar)
res = @allowscalar only(res_vec)
res = only(cdev(f.smodel(dev([t]), θ.depvar)))
return f.u0 + (t - f.t0) * res
end

Expand Down
5 changes: 1 addition & 4 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ function Phi(layer::AbstractLuxLayer)
layer, nothing, initialstates(Random.default_rng(), layer)))
end

function (f::Phi)(x::Number, θ)
res_vec = f([x], θ)
return @allowscalar only(res_vec)
end
(f::Phi)(x::Number, θ) = only(cdev(f([x], θ)))

(f::Phi)(x::AbstractArray, θ) = f.smodel(safe_get_device(θ)(x), θ)

Expand Down
3 changes: 1 addition & 2 deletions src/rode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ function (f::RODEPhi)(t, W, θ)
end

function (f::RODEPhi{<:Number})(dev, t::Number, W, θ)
res_vec = f.smodel(dev([t, W]), θ.depvar)
res = @allowscalar only(res_vec)
res = only(cdev(f.smodel(dev([t, W]), θ.depvar)))
return f.u0 + (t - f.t0) * res
end

Expand Down
28 changes: 6 additions & 22 deletions test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,12 @@ end
ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
luxchain = Chain(Dense(1, 5, σ), Dense(5, 1))

opt = OptimizationOptimisers.Adam(0.1)
sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400,
abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5

sol = solve(prob,
NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)),
verbose = false, maxiters = 400, abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5

sol = solve(prob,
NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)),
verbose = false, maxiters = 400, abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5

sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false,
maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0)
@test sol.errors[:l2] < 0.5

sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false,
maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0)
@test sol.errors[:l2] < 0.5
@testset for batch in (true, false), strategy in (StochasticTraining(100), nothing)
opt = OptimizationOptimisers.Adam(0.1)
sol = solve(prob, NNODE(luxchain, opt; batch, strategy),
verbose = false, maxiters = 400, abstol = 1.0f-8)
@test sol.errors[:l2] < 0.5
end
end

@testset "Example 3" begin
Expand Down

0 comments on commit c483d1c

Please sign in to comment.