Skip to content

Commit

Permalink
Re-introduce extensions for AD backends
Browse files Browse the repository at this point in the history
  • Loading branch information
lassepe committed Dec 5, 2024
1 parent 57ce8d5 commit d5d01da
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 90 deletions.
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@ authors = ["lassepe <[email protected]> and contributors"]
version = "0.1.15"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
ChainRulesCoreExt = "ChainRulesCore"
ForwardDiffExt = "ForwardDiff"

[compat]
ChainRulesCore = "1"
FastDifferentiation = "0.4"
Expand Down
25 changes: 25 additions & 0 deletions ext/ChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module ChainRulesCoreExt

using ParametricMCPs: ParametricMCPs
using ChainRulesCore: ChainRulesCore

function ChainRulesCore.rrule(::typeof(ParametricMCPs.solve), problem, θ; kwargs...)
solution = ParametricMCPs.solve(problem, θ; kwargs...)
project_to_θ = ChainRulesCore.ProjectTo(θ)

function solve_pullback(∂solution)
no_grad_args = (; ∂self = ChainRulesCore.NoTangent(), ∂problem = ChainRulesCore.NoTangent())

∂θ = ChainRulesCore.@thunk let
∂z∂θ = ParametricMCPs.InternalAutoDiffUtils.solve_jacobian_θ(problem, solution, θ)
∂l∂z = ∂solution.z
project_to_θ(∂z∂θ' * ∂l∂z)
end

no_grad_args..., ∂θ
end

solution, solve_pullback
end

end
26 changes: 26 additions & 0 deletions ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ForwardDiffExt

using ParametricMCPs: ParametricMCPs
using ForwardDiff: ForwardDiff

function ParametricMCPs.solve(
problem::ParametricMCPs.ParametricMCP,
θ::AbstractVector{<:ForwardDiff.Dual{T}};
kwargs...,
) where {T}
# strip off the duals:
θ_v = ForwardDiff.value.(θ)
θ_p = ForwardDiff.partials.(θ)
# forward pass
solution = ParametricMCPs.solve(problem, θ_v; kwargs...)
# backward pass
∂z∂θ = ParametricMCPs.InternalAutoDiffUtils.solve_jacobian_θ(problem, solution, θ_v)
# downstream gradient
z_p = ∂z∂θ * θ_p
# glue forward and backward pass together into dual number types
z_d = ForwardDiff.Dual{T}.(solution.z, z_p)

(; z = z_d, solution.status, solution.info)
end

end
87 changes: 0 additions & 87 deletions src/AutoDiff.jl

This file was deleted.

46 changes: 46 additions & 0 deletions src/InternalAutoDiffUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module InternalAutoDiffUtils

using ..ParametricMCPs: get_problem_size, get_result_buffer, get_parameter_dimension
using SparseArrays: SparseArrays
using LinearAlgebra: LinearAlgebra

function solve_jacobian_θ(problem, solution, θ; active_tolerance = 1e-3)
(; jacobian_z!, jacobian_θ!, lower_bounds, upper_bounds) = problem
z_star = solution.z

!isnothing(jacobian_θ!) || throw(
ArgumentError(
"Missing sensitivities. Set `compute_sensitivities = true` when constructing the ParametricMCP.",
),
)

inactive_indices = let
lower_inactive = z_star .>= (lower_bounds .+ active_tolerance)
upper_inactive = z_star .<= (upper_bounds .- active_tolerance)
findall(lower_inactive .& upper_inactive)
end

∂z∂θ = SparseArrays.spzeros(get_problem_size(problem), get_parameter_dimension(problem))
if isempty(inactive_indices)
return ∂z∂θ
end

∂f_reduce∂θ = let
∂f∂θ = get_result_buffer(jacobian_θ!)
jacobian_θ!(∂f∂θ, z_star, θ)
∂f∂θ[inactive_indices, :]
end

∂f_reduced∂z_reduced = let
∂f∂z = get_result_buffer(jacobian_z!)
jacobian_z!(∂f∂z, z_star, θ)
∂f∂z[inactive_indices, inactive_indices]
end

∂z∂θ[inactive_indices, :] =
LinearAlgebra.qr(-collect(∂f_reduced∂z_reduced), LinearAlgebra.ColumnNorm()) \
collect(∂f_reduce∂θ)
∂z∂θ
end

end
2 changes: 1 addition & 1 deletion src/ParametricMCPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ export ParametricMCP, get_parameter_dimension, get_problem_size
include("solver.jl")
export solve

include("AutoDiff.jl")
include("InternalAutoDiffUtils.jl")
end

2 comments on commit d5d01da

@lassepe
Copy link
Member Author

@lassepe lassepe commented on d5d01da Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.1.15 already exists

Please sign in to comment.