Skip to content

Commit

Permalink
ADTypes + ADgradient Performance (#727)
Browse files Browse the repository at this point in the history
* ADTypes interop

* Improve comment

* Bump patch version

* Formatting

* Formatting

* Improve documentation

* Testing infrastructure

* Remove extras from main Project toml

* Apply some basic tests

* Locate tests better

* Internal _make_ad_gradient

* Mark failing tests as broken

* Formatting

* Update Project.toml

* Updates

* Bump patch version

* Bump patch again

---------

Co-authored-by: Penelope Yong <[email protected]>
  • Loading branch information
willtebbutt and penelopeysm authored Dec 7, 2024
1 parent 5a58571 commit f0c31f0
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 44 deletions.
13 changes: 1 addition & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.2"
version = "0.31.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -30,15 +30,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand All @@ -63,15 +61,6 @@ MacroTools = "0.5.6"
OrderedCollections = "1"
Random = "1.6"
Requires = "1"
ReverseDiff = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
26 changes: 0 additions & 26 deletions ext/DynamicPPLReverseDiffExt.jl

This file was deleted.

16 changes: 16 additions & 0 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,19 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
end
# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

# This is important for performance -- one needs to provide `ADGradient` with a vector of
# parameters, or DifferentiationInterface will not have sufficient information to e.g.
# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate
# a tape when using ReverseDiff.jl.
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
end
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
end
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -17,6 +18,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -43,6 +45,7 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6.0.4"
MacroTools = "0.5.6"
Mooncake = "0.4.50"
ReverseDiff = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Expand Down
21 changes: 15 additions & 6 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "AD: ForwardDiff and ReverseDiff" begin
@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
Expand All @@ -17,11 +17,20 @@
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)

@testset "ReverseDiff with compile=$compile" for compile in (false, true)
adtype = ADTypes.AutoReverseDiff(; compile=compile)
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
@testset "$adtype" for adtype in [
ADTypes.AutoReverseDiff(; compile=false),
ADTypes.AutoReverseDiff(; compile=true),
ADTypes.AutoMooncake(; config=nothing),
]
# Mooncake can't currently handle something that is going on in
# SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
@test_broken 1 == 0
else
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
end
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ using DynamicPPL
using AbstractMCMC
using AbstractPPL
using Bijectors
using DifferentiationInterface
using Distributions
using DistributionsAD
using Documenter
using ForwardDiff
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MCMCChains
using Mooncake: Mooncake
using Tracker
using ReverseDiff
using Zygote
Expand Down

2 comments on commit f0c31f0

@penelopeysm
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120878

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.31.3 -m "<description of version>" f0c31f045ccd9b69374a68ea24d3823d632a5234
git push origin v0.31.3

Please sign in to comment.