From 4d9c30ecdf9e30e7acbdc030fdc59416a7ae36f4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 6 Oct 2024 21:04:50 -0400
Subject: [PATCH] fix: simplenonlinearsolve in cuda kernels

* Pass the state (`left` / `u`) as the first argument at every
  `get_tolerance` call site.
* Tag the Jacobian strategy with `AnalyticJacobian`, `DIExtras` (holds the
  DifferentiationInterface preparation) or `DINoPreparation` instead of using
  `nothing`. Preparation is skipped for scalar and `SArray` states so those
  paths stay non-allocating inside GPU kernels.
* Wrap the out-of-place preparation in `DIExtras` as well; returned bare it
  never matches `extras isa DIExtras` and `compute_jacobian!!` ignores it.
* Use `solve!(ls_cache, xo, δx)` and its `step_size` in the limited-memory
  Broyden line search.
* Add a Buildkite CUDA job, CUDA test items, and `GROUP`-based test-item
  filtering; a failed kernel launch now logs the caught exception.

---
 .buildkite/pipeline.yml                       | 28 ++++++
 lib/BracketingNonlinearSolve/src/bisection.jl |  2 +-
 lib/BracketingNonlinearSolve/src/brent.jl     |  2 +-
 lib/BracketingNonlinearSolve/src/falsi.jl     |  2 +-
 lib/BracketingNonlinearSolve/src/itp.jl       |  2 +-
 lib/BracketingNonlinearSolve/src/ridder.jl    |  2 +-
 .../src/termination_conditions.jl             | 12 +--
 lib/SimpleNonlinearSolve/Project.toml         | 31 ++++++-
 lib/SimpleNonlinearSolve/src/lbroyden.jl      |  7 +-
 lib/SimpleNonlinearSolve/src/utils.jl         | 63 ++++++++++---
 .../test/gpu/cuda_tests.jl                    | 90 +++++++++++++++++++
 lib/SimpleNonlinearSolve/test/runtests.jl     | 10 ++-
 12 files changed, 220 insertions(+), 31 deletions(-)
 create mode 100644 lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl

diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 8ddf4f4ff..f6b4ed6eb 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -27,6 +27,34 @@ steps:
     # Don't run Buildkite if the commit message includes the text [skip tests]
     if: build.message !~ /\[skip tests\]/
 
+  - label: "Julia 1 (SimpleNonlinearSolve)"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1"
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+            - ext
+    command: |
+      julia --color=yes --code-coverage=user --depwarn=yes --project=lib/SimpleNonlinearSolve -e '
+        import Pkg;
+        Pkg.Registry.update();
+        # Install packages present in subdirectories
+        dev_pks = Pkg.PackageSpec[];
+        for path in ("lib/NonlinearSolveBase", "lib/BracketingNonlinearSolve")
+            push!(dev_pks, Pkg.PackageSpec(; path))
+        end
+        Pkg.develop(dev_pks);
+        Pkg.instantiate();
+        Pkg.test(; coverage=true)'
+    agents:
+      queue: "juliagpu"
+      cuda: "*"
+    timeout_in_minutes: 60
+    # Don't run Buildkite if the commit message includes the text [skip tests]
+    if: build.message !~ /\[skip tests\]/
+
 env:
   GROUP: CUDA
   JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
diff --git a/lib/BracketingNonlinearSolve/src/bisection.jl b/lib/BracketingNonlinearSolve/src/bisection.jl
index 1611ad34e..e51416145 100644
--- a/lib/BracketingNonlinearSolve/src/bisection.jl
+++ b/lib/BracketingNonlinearSolve/src/bisection.jl
@@ -28,7 +28,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Bisection,
     fl, fr = f(left), f(right)
 
     abstol = NonlinearSolveBase.get_tolerance(
-        abstol, promote_type(eltype(left), eltype(right)))
+        left, abstol, promote_type(eltype(left), eltype(right)))
 
     if iszero(fl)
         return SciMLBase.build_solution(
diff --git a/lib/BracketingNonlinearSolve/src/brent.jl b/lib/BracketingNonlinearSolve/src/brent.jl
index fea2ce3f4..fb3740e98 100644
--- a/lib/BracketingNonlinearSolve/src/brent.jl
+++ b/lib/BracketingNonlinearSolve/src/brent.jl
@@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
     ϵ = eps(convert(typeof(fl), 1))
 
     abstol = NonlinearSolveBase.get_tolerance(
-        abstol, promote_type(eltype(left), eltype(right)))
+        left, abstol, promote_type(eltype(left), eltype(right)))
 
     if iszero(fl)
         return SciMLBase.build_solution(
diff --git a/lib/BracketingNonlinearSolve/src/falsi.jl b/lib/BracketingNonlinearSolve/src/falsi.jl
index 8c62b95c6..f56155ef7 100644
--- a/lib/BracketingNonlinearSolve/src/falsi.jl
+++ b/lib/BracketingNonlinearSolve/src/falsi.jl
@@ -15,7 +15,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
     fl, fr = f(left), f(right)
 
     abstol = NonlinearSolveBase.get_tolerance(
-        abstol, promote_type(eltype(left), eltype(right)))
+        left, abstol, promote_type(eltype(left), eltype(right)))
 
     if iszero(fl)
         return SciMLBase.build_solution(
diff --git a/lib/BracketingNonlinearSolve/src/itp.jl b/lib/BracketingNonlinearSolve/src/itp.jl
index 4798f9030..821047a5a 100644
--- a/lib/BracketingNonlinearSolve/src/itp.jl
+++ b/lib/BracketingNonlinearSolve/src/itp.jl
@@ -65,7 +65,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
     fl, fr = f(left), f(right)
 
     abstol = NonlinearSolveBase.get_tolerance(
-        abstol, promote_type(eltype(left), eltype(right)))
+        left, abstol, promote_type(eltype(left), eltype(right)))
 
     if iszero(fl)
         return SciMLBase.build_solution(
diff --git a/lib/BracketingNonlinearSolve/src/ridder.jl b/lib/BracketingNonlinearSolve/src/ridder.jl
index e4b67a7c7..d988c9dc5 100644
--- a/lib/BracketingNonlinearSolve/src/ridder.jl
+++ b/lib/BracketingNonlinearSolve/src/ridder.jl
@@ -14,7 +14,7 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
     fl, fr = f(left), f(right)
 
     abstol = NonlinearSolveBase.get_tolerance(
-        abstol, promote_type(eltype(left), eltype(right)))
+        left, abstol, promote_type(eltype(left), eltype(right)))
 
     if iszero(fl)
         return SciMLBase.build_solution(
diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl
index 8bb58f8eb..7978c19b8 100644
--- a/lib/NonlinearSolveBase/src/termination_conditions.jl
+++ b/lib/NonlinearSolveBase/src/termination_conditions.jl
@@ -34,8 +34,8 @@ function SciMLBase.init(
         du, u, mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
         abstol = nothing, reltol = nothing, kwargs...)
     T = promote_type(eltype(du), eltype(u))
-    abstol = get_tolerance(abstol, T)
-    reltol = get_tolerance(reltol, T)
+    abstol = get_tolerance(u, abstol, T)
+    reltol = get_tolerance(u, reltol, T)
     TT = typeof(abstol)
 
     u_unaliased = mode isa AbstractSafeBestNonlinearTerminationMode ?
@@ -90,8 +90,8 @@ function SciMLBase.reinit!(
     cache.u = u_unaliased
     cache.retcode = ReturnCode.Default
 
-    cache.abstol = get_tolerance(abstol, T)
-    cache.reltol = get_tolerance(reltol, T)
+    cache.abstol = get_tolerance(u, abstol, T)
+    cache.reltol = get_tolerance(u, reltol, T)
     cache.nsteps = 0
     TT = typeof(cache.abstol)
 
@@ -274,8 +274,8 @@ end
 function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du, u,
         tc::AbstractNonlinearTerminationMode, ::Val)
     T = promote_type(eltype(du), eltype(u))
-    abstol = get_tolerance(abstol, T)
-    reltol = get_tolerance(reltol, T)
+    abstol = get_tolerance(u, abstol, T)
+    reltol = get_tolerance(u, reltol, T)
     cache = SciMLBase.init(du, u, tc; abstol, reltol)
     return abstol, reltol, cache
 end
diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml
index 3a20a18b2..fae63544a 100644
--- a/lib/SimpleNonlinearSolve/Project.toml
+++ b/lib/SimpleNonlinearSolve/Project.toml
@@ -37,35 +37,62 @@ SimpleNonlinearSolveTrackerExt = "Tracker"
 
 [compat]
 ADTypes = "1.2"
+Accessors = "0.1"
+AllocCheck = "0.1.1"
+Aqua = "0.8.7"
 ArrayInterface = "7.16"
 BracketingNonlinearSolve = "1"
+CUDA = "5.3"
 ChainRulesCore = "1.24"
 CommonSolve = "0.2.4"
 ConcreteStructs = "0.2.3"
 DiffEqBase = "6.155"
 DifferentiationInterface = "0.6.1"
+Enzyme = "0.13"
+ExplicitImports = "1.9"
 FastClosures = "0.3.2"
 FiniteDiff = "2.24.0"
 ForwardDiff = "0.10.36"
 InteractiveUtils = "<0.0.1, 1"
-LinearAlgebra = "1.10"
 LineSearch = "0.1.3"
+LinearAlgebra = "1.10"
 MaybeInplace = "0.1.4"
+NonlinearProblemLibrary = "0.1.2"
 NonlinearSolveBase = "1"
+Pkg = "1.10"
+PolyesterForwardDiff = "0.1"
 PrecompileTools = "1.2"
+Random = "1.10"
 Reexport = "1.2"
 ReverseDiff = "1.15"
 SciMLBase = "2.50"
+SciMLSensitivity = "7.68"
+StaticArrays = "1.9"
 StaticArraysCore = "1.4.3"
 Test = "1.10"
 TestItemRunner = "1"
 Tracker = "0.2.35"
+Zygote = "0.6.70"
 julia = "1.10"
 
 [extras]
+AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
+Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["InteractiveUtils", "Test", "TestItemRunner"]
+test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
diff --git a/lib/SimpleNonlinearSolve/src/lbroyden.jl b/lib/SimpleNonlinearSolve/src/lbroyden.jl
index 2a9ace6dc..1ab200f74 100644
--- a/lib/SimpleNonlinearSolve/src/lbroyden.jl
+++ b/lib/SimpleNonlinearSolve/src/lbroyden.jl
@@ -204,7 +204,12 @@ end
     for i in 1:threshold
         static_idx, static_idx_p1 = Val(i - 1), Val(i)
         push!(calls, quote
-            α = ls_cache === nothing ? true : ls_cache(xo, δx)
+            if ls_cache === nothing
+                α = true
+            else
+                ls_sol = solve!(ls_cache, xo, δx)
+                α = ls_sol.step_size # Ignores the return code for now
+            end
             x = xo .+ α .* δx
             fx = prob.f(x, prob.p)
             δf = fx - fo
diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl
index 19d44cf21..011788a1c 100644
--- a/lib/SimpleNonlinearSolve/src/utils.jl
+++ b/lib/SimpleNonlinearSolve/src/utils.jl
@@ -2,6 +2,7 @@ module Utils
 
 using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
 using ArrayInterface: ArrayInterface
+using ConcreteStructs: @concrete
 using DifferentiationInterface: DifferentiationInterface, Constant
 using FastClosures: @closure
 using LinearAlgebra: LinearAlgebra, I, diagind
@@ -116,25 +117,35 @@ restructure(::Number, x::Number) = x
 safe_vec(x::AbstractArray) = vec(x)
 safe_vec(x::Number) = x
 
+abstract type AbstractJacobianMode end
+
+struct AnalyticJacobian <: AbstractJacobianMode end
+@concrete struct DIExtras <: AbstractJacobianMode
+    prep
+end
+struct DINoPreparation <: AbstractJacobianMode end
+
+# While we could run prep in other cases, we don't since we need it completely
+# non-allocating for running inside GPU kernels
 function prepare_jacobian(prob, autodiff, _, x::Number)
     if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
-        return nothing
+        return AnalyticJacobian()
     end
-    return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
+    # return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
+    return DINoPreparation()
 end
 function prepare_jacobian(prob, autodiff, fx, x)
-    if SciMLBase.has_jac(prob.f)
-        return nothing
-    end
+    SciMLBase.has_jac(prob.f) && return AnalyticJacobian()
     if SciMLBase.isinplace(prob.f)
-        return DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p))
+        return DIExtras(DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p)))
     else
-        return DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p))
+        x isa SArray && return DINoPreparation()
+        return DIExtras(DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p)))
     end
 end
 
 function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
-    if extras === nothing
+    if extras isa AnalyticJacobian
         if SciMLBase.has_jac(prob.f)
             return prob.f.jac(x, prob.p)
         elseif SciMLBase.has_vjp(prob.f)
@@ -143,11 +154,15 @@ function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
             return prob.f.jvp(one(x), x, prob.p)
         end
     end
-    return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
+    if extras isa DIExtras
+        return DI.derivative(prob.f, extras.prep, autodiff, x, Constant(prob.p))
+    else
+        return DI.derivative(prob.f, autodiff, x, Constant(prob.p))
+    end
 end
 function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
     if J === nothing
-        if extras === nothing
+        if extras isa AnalyticJacobian
             if SciMLBase.isinplace(prob.f)
                 J = similar(fx, length(fx), length(x))
                 prob.f.jac(J, x, prob.p)
@@ -157,12 +172,17 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
             end
         end
         if SciMLBase.isinplace(prob)
-            return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
+            @assert extras isa DIExtras
+            return DI.jacobian(prob.f, fx, extras.prep, autodiff, x, Constant(prob.p))
         else
-            return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
+            if extras isa DIExtras
+                return DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
+            else
+                return DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
+            end
         end
     end
-    if extras === nothing
+    if extras isa AnalyticJacobian
         if SciMLBase.isinplace(prob)
             prob.jac(J, x, prob.p)
             return J
@@ -171,9 +191,22 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
         end
     end
     if SciMLBase.isinplace(prob)
-        DI.jacobian!(prob.f, fx, J, extras, autodiff, x, Constant(prob.p))
+        @assert extras isa DIExtras
+        DI.jacobian!(prob.f, fx, J, extras.prep, autodiff, x, Constant(prob.p))
     else
-        DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
+        if ArrayInterface.can_setindex(J)
+            if extras isa DIExtras
+                DI.jacobian!(prob.f, J, extras.prep, autodiff, x, Constant(prob.p))
+            else
+                DI.jacobian!(prob.f, J, autodiff, x, Constant(prob.p))
+            end
+        else
+            if extras isa DIExtras
+                J = DI.jacobian(prob.f, extras.prep, autodiff, x, Constant(prob.p))
+            else
+                J = DI.jacobian(prob.f, autodiff, x, Constant(prob.p))
+            end
+        end
     end
     return J
 end
diff --git a/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl b/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl
new file mode 100644
index 000000000..8ecee2fed
--- /dev/null
+++ b/lib/SimpleNonlinearSolve/test/gpu/cuda_tests.jl
@@ -0,0 +1,90 @@
+@testitem "Solving on CUDA" tags=[:cuda] begin
+    using StaticArrays, CUDA, SimpleNonlinearSolve
+
+    if CUDA.functional()
+        CUDA.allowscalar(false)
+
+        f(u, p) = u .* u .- 2
+        f!(du, u, p) = (du .= u .* u .- 2)
+
+        @testset "$(nameof(typeof(alg)))" for alg in (
+            SimpleNewtonRaphson(),
+            SimpleDFSane(),
+            SimpleTrustRegion(),
+            SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
+            SimpleBroyden(),
+            SimpleLimitedMemoryBroyden(),
+            SimpleKlement(),
+            SimpleHalley(),
+            SimpleBroyden(; linesearch = Val(true)),
+            SimpleLimitedMemoryBroyden(; linesearch = Val(true))
+        )
+            # Static Arrays
+            u0 = @SVector[1.0f0, 1.0f0]
+            probN = NonlinearProblem{false}(f, u0)
+            sol = solve(probN, alg; abstol = 1.0f-6)
+            @test SciMLBase.successful_retcode(sol)
+            @test maximum(abs, sol.resid) ≤ 1.0f-6
+
+            # Regular Arrays
+            u0 = [1.0, 1.0]
+            probN = NonlinearProblem{false}(f, u0)
+            sol = solve(probN, alg; abstol = 1.0f-6)
+            @test SciMLBase.successful_retcode(sol)
+            @test maximum(abs, sol.resid) ≤ 1.0f-6
+
+            # Regular Arrays Inplace
+            if !(alg isa SimpleHalley)
+                u0 = [1.0, 1.0]
+                probN = NonlinearProblem{true}(f!, u0)
+                sol = solve(probN, alg; abstol = 1.0f-6)
+                @test SciMLBase.successful_retcode(sol)
+                @test maximum(abs, sol.resid) ≤ 1.0f-6
+            end
+        end
+    end
+end
+
+@testitem "CUDA Kernel Launch Test" tags=[:cuda] begin
+    using StaticArrays, CUDA, SimpleNonlinearSolve
+    using NonlinearSolveBase: ImmutableNonlinearProblem
+
+    if CUDA.functional()
+        CUDA.allowscalar(false)
+
+        f(u, p) = u .* u .- p
+
+        function kernel_function(prob, alg)
+            solve(prob, alg)
+            return nothing
+        end
+
+        @testset for u0 in (1.0f0, @SVector[1.0f0, 1.0f0])
+            prob = convert(ImmutableNonlinearProblem, NonlinearProblem{false}(f, u0, 2.0f0))
+
+            @testset "$(nameof(typeof(alg)))" for alg in (
+                SimpleNewtonRaphson(),
+                SimpleDFSane(),
+                SimpleTrustRegion(),
+                SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
+                SimpleBroyden(),
+                SimpleLimitedMemoryBroyden(),
+                SimpleKlement(),
+                SimpleHalley(),
+                SimpleBroyden(; linesearch = Val(true)),
+                SimpleLimitedMemoryBroyden(; linesearch = Val(true))
+            )
+                @test begin
+                    try
+                        @cuda kernel_function(prob, alg)
+                        @info "Successfully launched kernel for $(alg)."
+                        true
+                    catch err
+                        @error "Kernel Launch failed for $(alg)." exception=(err, catch_backtrace())
+                        false
+                    end
+                end broken=(alg isa SimpleHalley && u0 isa StaticArray)
+            end
+        end
+    end
+end
diff --git a/lib/SimpleNonlinearSolve/test/runtests.jl b/lib/SimpleNonlinearSolve/test/runtests.jl
index 6ea6326b0..dde4bacf4 100644
--- a/lib/SimpleNonlinearSolve/test/runtests.jl
+++ b/lib/SimpleNonlinearSolve/test/runtests.jl
@@ -1,5 +1,11 @@
-using TestItemRunner, InteractiveUtils
+using TestItemRunner, InteractiveUtils, Pkg
 
 @info sprint(InteractiveUtils.versioninfo)
 
-@run_package_tests
+const GROUP = lowercase(get(ENV, "GROUP", "All"))
+
+if GROUP == "all"
+    @run_package_tests
+else
+    @run_package_tests filter=ti->(Symbol(GROUP) in ti.tags)
+end