diff --git a/Project.toml b/Project.toml index 1144486..7f6f480 100644 --- a/Project.toml +++ b/Project.toml @@ -4,14 +4,20 @@ authors = ["lassepe and contributors"] version = "0.1.15" [deps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +ChainRulesCoreExt = "ChainRulesCore" +ForwardDiffExt = "ForwardDiff" + [compat] ChainRulesCore = "1" FastDifferentiation = "0.4" diff --git a/ext/ChainRulesCoreExt.jl b/ext/ChainRulesCoreExt.jl new file mode 100644 index 0000000..30fc6e2 --- /dev/null +++ b/ext/ChainRulesCoreExt.jl @@ -0,0 +1,25 @@ +module ChainRulesCoreExt + +using ParametricMCPs: ParametricMCPs +using ChainRulesCore: ChainRulesCore + +function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...) + solution = ParametricMCPs.solve(problem, θ; kwargs...) + project_to_θ = ChainRulesCore.ProjectTo(θ) + + function solve_pullback(∂solution) + no_grad_args = (; ∂self = ChainRulesCore.NoTangent(), ∂problem = ChainRulesCore.NoTangent()) + + ∂θ = ChainRulesCore.@thunk let + ∂z∂θ = ParametricMCPs.InternalAutoDiffUtils.solve_jacobian_θ(problem, solution, θ) + ∂l∂z = ∂solution.z + project_to_θ(∂z∂θ' * ∂l∂z) + end + + no_grad_args..., ∂θ + end + + solution, solve_pullback +end + +end diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl new file mode 100644 index 0000000..018d415 --- /dev/null +++ b/ext/ForwardDiffExt.jl @@ -0,0 +1,26 @@ +module ForwardDiffExt + +using ParametricMCPs: ParametricMCPs +using ForwardDiff: ForwardDiff + +function ParametricMCPs.solve( + problem::ParametricMCPs.ParametricMCP, + θ::AbstractVector{<:ForwardDiff.Dual{T}}; + kwargs..., +) where {T} + # strip off the duals: + θ_v = ForwardDiff.value.(θ) + θ_p = ForwardDiff.partials.(θ) + # forward pass + solution = ParametricMCPs.solve(problem, θ_v; kwargs...) + # backward pass + ∂z∂θ = ParametricMCPs.InternalAutoDiffUtils.solve_jacobian_θ(problem, solution, θ_v) + # downstream gradient + z_p = ∂z∂θ * θ_p + # glue forward and backward pass together into dual number types + z_d = ForwardDiff.Dual{T}.(solution.z, z_p) + + (; z = z_d, solution.status, solution.info) +end + +end diff --git a/src/AutoDiff.jl b/src/AutoDiff.jl deleted file mode 100644 index b024c44..0000000 --- a/src/AutoDiff.jl +++ /dev/null @@ -1,87 +0,0 @@ -module AutoDiff - -using ..ParametricMCPs: ParametricMCPs, get_problem_size, get_result_buffer, get_parameter_dimension -using ChainRulesCore: ChainRulesCore -using ForwardDiff: ForwardDiff -using SparseArrays: SparseArrays -using LinearAlgebra: LinearAlgebra - -function _solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3) - (; jacobian_z!, jacobian_θ!, lower_bounds, upper_bounds) = problem - z_star = solution.z - - !isnothing(jacobian_θ!) || throw( - ArgumentError( - "Missing sensitivities. Set `compute_sensitivities = true` when constructing the ParametricMCP.", - ), - ) - - inactive_indices = let - lower_inactive = z_star .>= (lower_bounds .+ active_tolerance) - upper_inactive = z_star .<= (upper_bounds .- active_tolerance) - findall(lower_inactive .& upper_inactive) - end - - ∂z∂θ = SparseArrays.spzeros(get_problem_size(problem), get_parameter_dimension(problem)) - if isempty(inactive_indices) - return ∂z∂θ - end - - ∂f_reduce∂θ = let - ∂f∂θ = get_result_buffer(jacobian_θ!) - jacobian_θ!(∂f∂θ, z_star, θ) - ∂f∂θ[inactive_indices, :] - end - - ∂f_reduced∂z_reduced = let - ∂f∂z = get_result_buffer(jacobian_z!) - jacobian_z!(∂f∂z, z_star, θ) - ∂f∂z[inactive_indices, inactive_indices] - end - - ∂z∂θ[inactive_indices, :] = - LinearAlgebra.qr(-collect(∂f_reduced∂z_reduced), LinearAlgebra.ColumnNorm()) \ - collect(∂f_reduce∂θ) - ∂z∂θ -end - -function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...) - solution = ParametricMCPs.solve(problem, θ; kwargs...) - project_to_θ = ChainRulesCore.ProjectTo(θ) - - function solve_pullback(∂solution) - no_grad_args = (; ∂self = ChainRulesCore.NoTangent(), ∂problem = ChainRulesCore.NoTangent()) - - ∂θ = ChainRulesCore.@thunk let - ∂z∂θ = _solve_jacobian_θ(problem, solution, θ) - ∂l∂z = ∂solution.z - project_to_θ(∂z∂θ' * ∂l∂z) - end - - no_grad_args..., ∂θ - end - - solution, solve_pullback -end - -function ParametricMCPs.solve( - problem::ParametricMCPs.ParametricMCP, - θ::AbstractVector{<:ForwardDiff.Dual{T}}; - kwargs..., -) where {T} - # strip off the duals: - θ_v = ForwardDiff.value.(θ) - θ_p = ForwardDiff.partials.(θ) - # forward pass - solution = ParametricMCPs.solve(problem, θ_v; kwargs...) - # backward pass - ∂z∂θ = _solve_jacobian_θ(problem, solution, θ_v) - # downstream gradient - z_p = ∂z∂θ * θ_p - # glue forward and backward pass together into dual number types - z_d = ForwardDiff.Dual{T}.(solution.z, z_p) - - (; z = z_d, solution.status, solution.info) -end - -end diff --git a/src/InternalAutoDiffUtils.jl b/src/InternalAutoDiffUtils.jl new file mode 100644 index 0000000..2d4f9fd --- /dev/null +++ b/src/InternalAutoDiffUtils.jl @@ -0,0 +1,46 @@ +module InternalAutoDiffUtils + +using ..ParametricMCPs: get_problem_size, get_result_buffer, get_parameter_dimension +using SparseArrays: SparseArrays +using LinearAlgebra: LinearAlgebra + +function solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3) + (; jacobian_z!, jacobian_θ!, lower_bounds, upper_bounds) = problem + z_star = solution.z + + !isnothing(jacobian_θ!) || throw( + ArgumentError( + "Missing sensitivities. Set `compute_sensitivities = true` when constructing the ParametricMCP.", + ), + ) + + inactive_indices = let + lower_inactive = z_star .>= (lower_bounds .+ active_tolerance) + upper_inactive = z_star .<= (upper_bounds .- active_tolerance) + findall(lower_inactive .& upper_inactive) + end + + ∂z∂θ = SparseArrays.spzeros(get_problem_size(problem), get_parameter_dimension(problem)) + if isempty(inactive_indices) + return ∂z∂θ + end + + ∂f_reduce∂θ = let + ∂f∂θ = get_result_buffer(jacobian_θ!) + jacobian_θ!(∂f∂θ, z_star, θ) + ∂f∂θ[inactive_indices, :] + end + + ∂f_reduced∂z_reduced = let + ∂f∂z = get_result_buffer(jacobian_z!) + jacobian_z!(∂f∂z, z_star, θ) + ∂f∂z[inactive_indices, inactive_indices] + end + + ∂z∂θ[inactive_indices, :] = + LinearAlgebra.qr(-collect(∂f_reduced∂z_reduced), LinearAlgebra.ColumnNorm()) \ + collect(∂f_reduce∂θ) + ∂z∂θ +end + +end diff --git a/src/ParametricMCPs.jl b/src/ParametricMCPs.jl index ef35c72..9450ee5 100644 --- a/src/ParametricMCPs.jl +++ b/src/ParametricMCPs.jl @@ -16,5 +16,5 @@ export ParametricMCP, get_parameter_dimension, get_problem_size include("solver.jl") export solve -include("AutoDiff.jl") +include("InternalAutoDiffUtils.jl") end