Skip to content

Commit

Permalink
fix: simplenonlinearsolve in cuda kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 7, 2024
1 parent ecdbc78 commit 4d9c30e
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 30 deletions.
28 changes: 28 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@ steps:
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

- label: "Julia 1 (SimpleNonlinearSolve)"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
command: |
julia --color=yes --code-coverage=user --depwarn=yes --project=lib/SimpleNonlinearSolve -e '
import Pkg;
Pkg.Registry.update();
# Install packages present in subdirectories
dev_pks = Pkg.PackageSpec[];
for path in ("lib/NonlinearSolveBase", "lib/BracketingNonlinearSolve")
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks);
Pkg.instantiate();
Pkg.test(; coverage=true)'
agents:
queue: "juliagpu"
cuda: "*"
timeout_in_minutes: 60
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

env:
GROUP: CUDA
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/src/bisection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
fl, fr = f(left), f(right)

abstol = NonlinearSolveBase.get_tolerance(
abstol, promote_type(eltype(left), eltype(right)))
left, abstol, promote_type(eltype(left), eltype(right)))

if iszero(fl)
return SciMLBase.build_solution(
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/src/brent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
ϵ = eps(convert(typeof(fl), 1))

abstol = NonlinearSolveBase.get_tolerance(
abstol, promote_type(eltype(left), eltype(right)))
left, abstol, promote_type(eltype(left), eltype(right)))

if iszero(fl)
return SciMLBase.build_solution(
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/src/falsi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
fl, fr = f(left), f(right)

abstol = NonlinearSolveBase.get_tolerance(
abstol, promote_type(eltype(left), eltype(right)))
left, abstol, promote_type(eltype(left), eltype(right)))

if iszero(fl)
return SciMLBase.build_solution(
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/src/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
fl, fr = f(left), f(right)

abstol = NonlinearSolveBase.get_tolerance(
abstol, promote_type(eltype(left), eltype(right)))
left, abstol, promote_type(eltype(left), eltype(right)))

if iszero(fl)
return SciMLBase.build_solution(
Expand Down
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/src/ridder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
fl, fr = f(left), f(right)

abstol = NonlinearSolveBase.get_tolerance(
abstol, promote_type(eltype(left), eltype(right)))
left, abstol, promote_type(eltype(left), eltype(right)))

if iszero(fl)
return SciMLBase.build_solution(
Expand Down
12 changes: 6 additions & 6 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ function SciMLBase.init(
du, u, mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
abstol = nothing, reltol = nothing, kwargs...)
T = promote_type(eltype(du), eltype(u))
abstol = get_tolerance(abstol, T)
reltol = get_tolerance(reltol, T)
abstol = get_tolerance(u, abstol, T)
reltol = get_tolerance(u, reltol, T)
TT = typeof(abstol)

u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
Expand Down Expand Up @@ -90,8 +90,8 @@ function SciMLBase.reinit!(
cache.u = u_unaliased
cache.retcode = ReturnCode.Default

cache.abstol = get_tolerance(abstol, T)
cache.reltol = get_tolerance(reltol, T)
cache.abstol = get_tolerance(u, abstol, T)
cache.reltol = get_tolerance(u, reltol, T)
cache.nsteps = 0
TT = typeof(cache.abstol)

Expand Down Expand Up @@ -274,8 +274,8 @@ end
function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du,
u, tc::AbstractNonlinearTerminationMode, ::Val)
T = promote_type(eltype(du), eltype(u))
abstol = get_tolerance(abstol, T)
reltol = get_tolerance(reltol, T)
abstol = get_tolerance(u, abstol, T)
reltol = get_tolerance(u, reltol, T)
cache = SciMLBase.init(du, u, tc; abstol, reltol)
return abstol, reltol, cache
end
31 changes: 29 additions & 2 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,35 +37,62 @@ SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
ADTypes = "1.2"
Accessors = "0.1"
AllocCheck = "0.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.16"
BracketingNonlinearSolve = "1"
CUDA = "5.3"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.155"
DifferentiationInterface = "0.6.1"
Enzyme = "0.13"
ExplicitImports = "1.9"
FastClosures = "0.3.2"
FiniteDiff = "2.24.0"
ForwardDiff = "0.10.36"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
LineSearch = "0.1.3"
LinearAlgebra = "1.10"
MaybeInplace = "0.1.4"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
Pkg = "1.10"
PolyesterForwardDiff = "0.1"
PrecompileTools = "1.2"
Random = "1.10"
Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.50"
SciMLSensitivity = "7.68"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
Zygote = "0.6.70"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["InteractiveUtils", "Test", "TestItemRunner"]
test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
7 changes: 6 additions & 1 deletion lib/SimpleNonlinearSolve/src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,12 @@ end
for i in 1:threshold
static_idx, static_idx_p1 = Val(i - 1), Val(i)
push!(calls, quote
α = ls_cache === nothing ? true : ls_cache(xo, δx)
if ls_cache === nothing
α = true
else
ls_sol = solve!(ls_cache, xo, δx)
α = ls_sol.step_size # Ignores the return code for now
end
x = xo .+ α .* δx
fx = prob.f(x, prob.p)
δf = fx - fo
Expand Down
61 changes: 47 additions & 14 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Utils

using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra, I, diagind
Expand Down Expand Up @@ -116,25 +117,35 @@ restructure(::Number, x::Number) = x
safe_vec(x::AbstractArray) = vec(x)
safe_vec(x::Number) = x

abstract type AbstractJacobianMode end

struct AnalyticJacobian <: AbstractJacobianMode end
@concrete struct DIExtras <: AbstractJacobianMode
prep
end
struct DINoPreparation <: AbstractJacobianMode end

# While we could run prep in other cases, we don't since we need it completely
# non-allocating for running inside GPU kernels
function prepare_jacobian(prob, autodiff, _, x::Number)
if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
return nothing
return AnalyticJacobian()
end
return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
# return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
return DINoPreparation()
end
function prepare_jacobian(prob, autodiff, fx, x)
if SciMLBase.has_jac(prob.f)
return nothing
end
SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
if SciMLBase.isinplace(prob.f)
return DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p))
return DIExtras(DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p)))
else
x isa SArray && return DINoPreparation()
return DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p))
end
end

function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
if extras === nothing
if extras isa AnalyticJacobian
if SciMLBase.has_jac(prob.f)
return prob.f.jac(x, prob.p)
elseif SciMLBase.has_vjp(prob.f)
Expand All @@ -143,11 +154,15 @@ function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
return prob.f.jvp(one(x), x, prob.p)
end
end
return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
if extras isa DIExtras
return DI.derivative(prob.f, extras.prep, autodiff, x, Constant(prob.p))
else
return DI.derivative(prob.f, autodiff, x, Constant(prob.p))
end
end
function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
if J === nothing
if extras === nothing
if extras isa AnalyticJacobian
if SciMLBase.isinplace(prob.f)
J = similar(fx, length(fx), length(x))
prob.f.jac(J, x, prob.p)
Expand All @@ -157,12 +172,17 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
end
end
if SciMLBase.isinplace(prob)
return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
@assert extras isa DIExtras
return DI.jacobian(prob.f, fx, extras.prep, autodiff, x, Constant(prob.p))
else
return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
if extras isa DIExtras
return DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
else
return DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
end
end
end
if extras === nothing
if extras isa AnalyticJacobian
if SciMLBase.isinplace(prob)
prob.jac(J, x, prob.p)
return J
Expand All @@ -171,9 +191,22 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
end
end
if SciMLBase.isinplace(prob)
DI.jacobian!(prob.f, fx, J, extras, autodiff, x, Constant(prob.p))
@assert extras isa DIExtras
DI.jacobian!(prob.f, fx, J, extras.prep, autodiff, x, Constant(prob.p))
else
DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
if ArrayInterface.can_setindex(J)
if extras isa DIExtras
DI.jacobian!(prob.f, J, extras.prep, autodiff, x, Constant(prob.p))
else
DI.jacobian!(prob.f, J, autodiff, x, Constant(prob.p))
end
else
if extras isa DIExtras
J = DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
else
J = DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
end
end
end
return J
end
Expand Down
90 changes: 90 additions & 0 deletions lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
@testitem "Solving on CUDA" tags=[:cuda] begin
using StaticArrays, CUDA, SimpleNonlinearSolve

if CUDA.functional()
CUDA.allowscalar(false)

f(u, p) = u .* u .- 2
f!(du, u, p) = (du .= u .* u .- 2)

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(),
SimpleDFSane(),
SimpleTrustRegion(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleBroyden(),
SimpleLimitedMemoryBroyden(),
SimpleKlement(),
SimpleHalley(),
SimpleBroyden(; linesearch = Val(true)),
SimpleLimitedMemoryBroyden(; linesearch = Val(true))
)
# Static Arrays
u0 = @SVector[1.0f0, 1.0f0]
probN = NonlinearProblem{false}(f, u0)
sol = solve(probN, alg; abstol = 1.0f-6)
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, sol.resid) 1.0f-6

# Regular Arrays
u0 = [1.0, 1.0]
probN = NonlinearProblem{false}(f, u0)
sol = solve(probN, alg; abstol = 1.0f-6)
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, sol.resid) 1.0f-6

# Regular Arrays Inplace
if !(alg isa SimpleHalley)
u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(f!, u0)
sol = solve(probN, alg; abstol = 1.0f-6)
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, sol.resid) 1.0f-6
end
end
end
end

@testitem "CUDA Kernel Launch Test" tags=[:cuda] begin
using StaticArrays, CUDA, SimpleNonlinearSolve
using NonlinearSolveBase: ImmutableNonlinearProblem

if CUDA.functional()
CUDA.allowscalar(false)

f(u, p) = u .* u .- p

function kernel_function(prob, alg)
solve(prob, alg)
return nothing
end

@testset for u0 in (1.0f0, @SVector[1.0f0, 1.0f0])
prob = convert(ImmutableNonlinearProblem, NonlinearProblem{false}(f, u0, 2.0f0))

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(),
SimpleDFSane(),
SimpleTrustRegion(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
SimpleBroyden(),
SimpleLimitedMemoryBroyden(),
SimpleKlement(),
SimpleHalley(),
SimpleBroyden(; linesearch = Val(true)),
SimpleLimitedMemoryBroyden(; linesearch = Val(true))
)
@test begin
try
@cuda kernel_function(prob, alg)
@info "Successfully launched kernel for $(alg)."
true
catch err
@error "Kernel Launch failed for $(alg)."
false
end
end broken=(alg isa SimpleHalley && u0 isa StaticArray)
end
end
end
end
Loading

0 comments on commit 4d9c30e

Please sign in to comment.