diff --git a/Project.toml b/Project.toml index b43b31e9d..f8542433f 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -64,7 +63,6 @@ ExplicitImports = "1.10.1" Flux = "0.14.22" ForwardDiff = "0.10.36" Functors = "0.4.12" -GPUArraysCore = "0.1.6" Integrals = "4.5" IntervalSets = "0.7.10" LineSearches = "7.3" diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index 306bf4c7d..f65f1d659 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -130,10 +130,10 @@ Contains `ahmc_bayesian_pinn_ode()` function output: - step_size - nom_step_size """ -struct BPINNstats{MC, S, ST} - mcmc_chain::MC - samples::S - statistics::ST +@concrete struct BPINNstats + mcmc_chain + samples + statistics end """ @@ -146,19 +146,12 @@ contains fields related to that). 3. `estimated_de_params` - Probabilistic Estimate of DE params from sampled unknown DE parameters. """ -struct BPINNsolution{O <: BPINNstats, E, NP, OP, P} - original::O - ensemblesol::E - estimated_nn_params::NP - estimated_de_params::OP - timepoints::P - - function BPINNsolution( - original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints) - new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params), - typeof(estimated_de_params), typeof(timepoints)}( - original, ensemblesol, estimated_nn_params, estimated_de_params, timepoints) - end +@concrete struct BPINNsolution + original <: BPINNstats + ensemblesol + estimated_nn_params + estimated_de_params + timepoints end function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt = nothing, diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index b9e1d6202..9cd40b239 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -11,14 +11,13 @@ using DocStringExtensions: FIELDS using DomainSets: DomainSets, AbstractInterval, leftendpoint, rightendpoint, ProductDomain using ForwardDiff: ForwardDiff using Functors: Functors, fmap -using GPUArraysCore: @allowscalar using Integrals: Integrals, CubatureJLh, QuadGKJL using IntervalSets: infimum, supremum using LinearAlgebra: Diagonal using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer using Lux: FromFluxAdaptor, recursive_eltype using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer -using MLDataDevices: CPUDevice, cpu_device, get_device +using MLDataDevices: CPUDevice, get_device using Optimisers: Optimisers, Adam using Optimization: Optimization using OptimizationOptimisers: OptimizationOptimisers @@ -61,8 +60,10 @@ abstract type AbstractPINN end abstract type AbstractTrainingStrategy end +const cdev = CPUDevice() + @inline safe_get_device(x) = safe_get_device(get_device(x), x) -@inline safe_get_device(::Nothing, x) = cpu_device() +@inline safe_get_device(::Nothing, x) = cdev @inline safe_get_device(dev, _) = dev @inline safe_expand(dev, x) = dev(x) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index a728e4fea..fe6a770cd 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -133,8 +133,7 @@ function (f::ODEPhi)(t, θ) end function (f::ODEPhi{<:Number})(dev, t::Number, θ) - res_vec = f.smodel(dev([t]), θ.depvar) - res = @allowscalar only(res_vec) + res = only(cdev(f.smodel(dev([t]), θ.depvar))) return f.u0 + (t - f.t0) * res end diff --git a/src/pinn_types.jl b/src/pinn_types.jl index bf1d84522..1fdd4cb43 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -38,10 +38,7 @@ function Phi(layer::AbstractLuxLayer) layer, nothing, initialstates(Random.default_rng(), layer))) end -function (f::Phi)(x::Number, θ) - res_vec = f([x], θ) - return @allowscalar only(res_vec) -end +(f::Phi)(x::Number, θ) = only(cdev(f([x], θ))) (f::Phi)(x::AbstractArray, θ) = f.smodel(safe_get_device(θ)(x), θ) diff --git a/src/rode_solve.jl b/src/rode_solve.jl index 9b823794b..33495cdf5 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -29,8 +29,7 @@ function (f::RODEPhi)(t, W, θ) end function (f::RODEPhi{<:Number})(dev, t::Number, W, θ) - res_vec = f.smodel(dev([t, W]), θ.depvar) - res = @allowscalar only(res_vec) + res = only(cdev(f.smodel(dev([t, W]), θ.depvar))) return f.u0 + (t - f.t0) * res end diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index e100b2e06..96fc17a19 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -90,28 +90,12 @@ end ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0)) luxchain = Chain(Dense(1, 5, σ), Dense(5, 1)) - opt = OptimizationOptimisers.Adam(0.1) - sol = solve(prob, NNODE(luxchain, opt), verbose = false, maxiters = 400, - abstol = 1.0f-8) - @test sol.errors[:l2] < 0.5 - - sol = solve(prob, - NNODE(luxchain, opt; batch = false, strategy = StochasticTraining(100)), - verbose = false, maxiters = 400, abstol = 1.0f-8) - @test sol.errors[:l2] < 0.5 - - sol = solve(prob, - NNODE(luxchain, opt; batch = true, strategy = StochasticTraining(100)), - verbose = false, maxiters = 400, abstol = 1.0f-8) - @test sol.errors[:l2] < 0.5 - - sol = solve(prob, NNODE(luxchain, opt; batch = false), verbose = false, - maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) - @test sol.errors[:l2] < 0.5 - - sol = solve(prob, NNODE(luxchain, opt; batch = true), verbose = false, - maxiters = 400, abstol = 1.0f-8, dt = 1 / 5.0f0) - @test sol.errors[:l2] < 0.5 + @testset for batch in (true, false), strategy in (StochasticTraining(100), nothing) + opt = OptimizationOptimisers.Adam(0.1) + sol = solve(prob, NNODE(luxchain, opt; batch, strategy), + verbose = false, maxiters = 400, abstol = 1.0f-8) + @test sol.errors[:l2] < 0.5 + end end @testset "Example 3" begin