From dd4f33705805b4f47da45d5fe290bcb4241bc323 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 00:30:22 +0200 Subject: [PATCH 01/13] Fix typo --- src/context_implementations.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8f914a2f8..5fe778988 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -302,7 +302,8 @@ function get_and_set_val!( else r = init(dist, spl, n) for i in 1:n - push!(vi, vns[i], r[:,i], dist, spl) + vn = vns[i] + push!(vi, vn, r[:,i], dist, spl) settrans!(vi, false, vn) end end From 75eb4e981050b57323be5c72e71e202f315311c6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 00:30:39 +0200 Subject: [PATCH 02/13] Implement `set_order!` --- src/threadsafe.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2e71cd410..ceab89e44 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -45,6 +45,7 @@ syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) setgid!(vi.varinfo, gid, vn) end +setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) From 9f0c204f13dd74b8e504709dbf8a04b5974eb272 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 00:40:22 +0200 Subject: [PATCH 03/13] Initial try at fixing incompatibility of Zygote with multithreading --- src/compat/ad.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 88a28f840..75d5b42eb 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -9,3 +9,16 @@ ZygoteRules.@adjoint function push!( ) return push!(vi, vn, r, dist, gidset), _ -> nothing end + +# Multithreaded evaluation is not compatible with Zygote. +ZygoteRules.@adjoint function (model::Model)( + vi::AbstractVarInfo, + spl::AbstractSampler, + ctx::AbstractContext +) + function evaluate(vi, spl, ctx) + return evaluate_singlethreaded(model, vi, spl, ctx) + end + return ZygoteRules.pullback(evaluate, vi, spl, ctx) +end + From 462c32fe2759509994eb57243c08a7ed618f417a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 01:49:57 +0200 Subject: [PATCH 04/13] Add basic AD test --- Project.toml | 3 ++- test/compat/ad.jl | 67 +++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 test/compat/ad.jl diff --git a/Project.toml b/Project.toml index 88bee3794..a54a1a905 100644 --- a/Project.toml +++ b/Project.toml @@ -45,6 +45,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"] +test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] diff --git a/test/compat/ad.jl b/test/compat/ad.jl new file mode 100644 index 000000000..e801647dc --- /dev/null +++ b/test/compat/ad.jl @@ -0,0 +1,67 @@ +using DynamicPPL +using Distributions + +using ForwardDiff +using Zygote +using Tracker + +@testset "logp" begin + @model function admodel() + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + return s, m + end + + model = admodel() + vi = VarInfo(model) + model(vi, SampleFromPrior()) + x = [vi[@varname(s)], vi[@varname(m)]] + + dist_s = InverseGamma(2,3) + + # Hand-written log probabilities for vector `x = [s, m]`. + function logp_manual(x) + s = x[1] + m = x[2] + dist = Normal(m, sqrt(s)) + + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + + logpdf(dist, 1.5) + logpdf(dist, 2.0) + end + + # Log probabilities for vector `x = [s, m]` using the model. + function logp_model(x) + new_vi = VarInfo(vi, SampleFromPrior(), x) + model(new_vi, SampleFromPrior()) + return getlogp(new_vi) + end + + # Check that both functions return the same values. + lp = logp_manual(x) + @test logp_model(x) ≈ lp + + # Gradients based on the manual implementation. + grad = ForwardDiff.gradient(logp_manual, x) + + y, back = Tracker.forward(logp_manual, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_manual, x) + @test y ≈ lp + @test back(1)[1] ≈ grad + + # Gradients based on the model. + @test ForwardDiff.gradient(logp_model, x) ≈ grad + + y, back = Tracker.forward(logp_model, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.gradient(logp_model, x) + @test y ≈ lp + @test back(1) ≈ grad +end + diff --git a/test/runtests.jl b/test/runtests.jl index 154bf6dcf..84e6d38af 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ turnprogress(false) @testset "DynamicPPL.jl" begin include("utils.jl") include("compiler.jl") + include("compat/ad.jl") include("varinfo.jl") include("sampler.jl") include("prob_macro.jl") From 7c642976d15feabeda5bb32d455460c591992908 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 02:05:00 +0200 Subject: [PATCH 05/13] Ensure that both `evaluate_singlethreaded` and `evaluate_multithreaded` are tested --- test/compiler.jl | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 440c40dcd..9b134ec87 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -581,7 +581,7 @@ end @test all(iszero(model()) for _ in 1:1000) end @testset "threading" begin - @info "Peforming threading tests with $(Threads.nthreads()) threads" + println("Peforming threading tests with $(Threads.nthreads()) threads") x = rand(10_000) @@ -596,9 +596,19 @@ end wthreads(x)(vi) lp_w_threads = getlogp(vi) - println("With threading:") + println("With `@threads`:") + println(" default:") @time wthreads(x)(vi) + # Ensure that we use `ThreadSafeVarInfo`. + @test getlogp(vi) ≈ lp_w_threads + DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + + println(" evaluate_multithreaded:") + @time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + @model function wothreads(x) x[1] ~ Normal(0, 1) for i in 2:length(x) @@ -610,9 +620,19 @@ end wothreads(x)(vi) lp_wo_threads = getlogp(vi) - println("Without threading:") + println("Without `@threads`:") + println(" default:") @time wothreads(x)(vi) @test lp_w_threads ≈ lp_wo_threads + + # Ensure that we use `VarInfo`. + DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) + @test getlogp(vi) ≈ lp_w_threads + + println(" evaluate_singlethreaded:") + @time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) end end From b81363aca343df434c0cce1ec8e56d60a5ce8a9d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 02:21:31 +0200 Subject: [PATCH 06/13] Fix typo --- test/compat/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index e801647dc..6fb218016 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -60,7 +60,7 @@ using Tracker @test Tracker.data(y) ≈ lp @test Tracker.data(back(1)[1]) ≈ grad - y, back = Zygote.gradient(logp_model, x) + y, back = Zygote.pullback(logp_model, x) @test y ≈ lp @test back(1) ≈ grad end From 924f635634f4854f1692d9016bebcbcf75573548 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 3 May 2020 22:00:41 +1000 Subject: [PATCH 07/13] fix Zygote support --- src/compat/ad.jl | 18 ++++++------------ src/threadsafe.jl | 29 ++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 75d5b42eb..a1f716c89 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -7,18 +7,12 @@ ZygoteRules.@adjoint function push!( dist::Distribution, gidset::Set{Selector} ) - return push!(vi, vn, r, dist, gidset), _ -> nothing + return push!(vi, vn, r, dist, gidset), _ -> ntuple(_ -> nothing, 5) end -# Multithreaded evaluation is not compatible with Zygote. -ZygoteRules.@adjoint function (model::Model)( - vi::AbstractVarInfo, - spl::AbstractSampler, - ctx::AbstractContext -) - function evaluate(vi, spl, ctx) - return evaluate_singlethreaded(model, vi, spl, ctx) - end - return ZygoteRules.pullback(evaluate, vi, spl, ctx) +ZygoteRules.@adjoint function Threads.nthreads() + Threads.nthreads(), _ -> (nothing,) +end +ZygoteRules.@adjoint function Threads.threadid() + Threads.threadid(), _ -> (nothing,) end - diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ceab89e44..32579f694 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -1,3 +1,30 @@ +################# +# VectorOfLogps # +################# + +struct VectorOfLogps{T1, T2 <: Vector{Base.RefValue{T1}}} + v::T2 +end +VectorOfLogps(::Type{T}, n::Int) where {T} = VectorOfLogps(zero(T), n) +function VectorOfLogps(val::T, n::Int) where {T} + v = [val for i in 1:Threads.nthreads()] + return VectorOfLogps(v) +end +VectorOfLogps(v::Vector) = VectorOfLogps(Ref.(v)) +Base.getindex(v::VectorOfLogps, i::Integer) = v.v[i][] +function Base.setindex!(v::VectorOfLogps, val, i::Integer) + v.v[i][] = val + return v +end +Base.sum(v::VectorOfLogps) = sum(v -> v[], v.v) +function Base.fill!(v::VectorOfLogps, val) + for i in 1:length(v.v) + v.v[i][] = val + end + return v +end + + """ ThreadSafeVarInfo @@ -9,7 +36,7 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo logps::L end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()]) + return ThreadSafeVarInfo(vi, VectorOfLogps(zero(getlogp(vi)), Threads.nthreads())) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi From 4bc1103195607241af896cb67234fd9c771490bc Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 3 May 2020 22:51:49 +1000 Subject: [PATCH 08/13] fix DPPL test --- test/compat/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 6fb218016..d7afc867c 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -62,6 +62,6 @@ using Tracker y, back = Zygote.pullback(logp_model, x) @test y ≈ lp - @test back(1) ≈ grad + @test back(1)[1] ≈ grad end From 10039f020c2b959eaf8343746a4cb5aabd9cc226 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 3 May 2020 22:26:18 +1000 Subject: [PATCH 09/13] Update src/compat/ad.jl Co-authored-by: David Widmann --- src/compat/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compat/ad.jl b/src/compat/ad.jl index a1f716c89..e37052ce9 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -7,7 +7,7 @@ ZygoteRules.@adjoint function push!( dist::Distribution, gidset::Set{Selector} ) - return push!(vi, vn, r, dist, gidset), _ -> ntuple(_ -> nothing, 5) + return push!(vi, vn, r, dist, gidset), _ -> nothing end ZygoteRules.@adjoint function Threads.nthreads() From 08dac12e73de2db92fb86560071dc494979fef62 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 3 May 2020 22:26:29 +1000 Subject: [PATCH 10/13] Update src/compat/ad.jl Co-authored-by: David Widmann --- src/compat/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compat/ad.jl b/src/compat/ad.jl index e37052ce9..9d8cf751a 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -11,8 +11,8 @@ ZygoteRules.@adjoint function push!( end ZygoteRules.@adjoint function Threads.nthreads() - Threads.nthreads(), _ -> (nothing,) + return Threads.nthreads(), _ -> nothing end ZygoteRules.@adjoint function Threads.threadid() - Threads.threadid(), _ -> (nothing,) + return Threads.threadid(), _ -> nothing end From 008392aa0d6b9975a398a91c8088a3e414958016 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Mon, 4 May 2020 00:55:41 +1000 Subject: [PATCH 11/13] simplify VectorOfLogps constructor --- src/threadsafe.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 32579f694..1992742f6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -5,12 +5,10 @@ struct VectorOfLogps{T1, T2 <: Vector{Base.RefValue{T1}}} v::T2 end -VectorOfLogps(::Type{T}, n::Int) where {T} = VectorOfLogps(zero(T), n) function VectorOfLogps(val::T, n::Int) where {T} - v = [val for i in 1:Threads.nthreads()] + v = [Ref(val) for i in 1:n] return VectorOfLogps(v) end -VectorOfLogps(v::Vector) = VectorOfLogps(Ref.(v)) Base.getindex(v::VectorOfLogps, i::Integer) = v.v[i][] function Base.setindex!(v::VectorOfLogps, val, i::Integer) v.v[i][] = val From 16fd886d28b856a868fadfdb51ee239b72a1557d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 21:49:13 +0200 Subject: [PATCH 12/13] Add more tests of AD and threading, independently of Turing --- test/compat/ad.jl | 90 +++++++++++++-------------------------------- test/compiler.jl | 55 ---------------------------- test/runtests.jl | 19 +++++++++- test/test_util.jl | 38 +++++++++++++++++++ test/threadsafe.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 171 insertions(+), 122 deletions(-) create mode 100644 test/test_util.jl create mode 100644 test/threadsafe.jl diff --git a/test/compat/ad.jl b/test/compat/ad.jl index d7afc867c..814a3d245 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -1,67 +1,27 @@ -using DynamicPPL -using Distributions - -using ForwardDiff -using Zygote -using Tracker - -@testset "logp" begin - @model function admodel() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - 1.5 ~ Normal(m, sqrt(s)) - 2.0 ~ Normal(m, sqrt(s)) - return s, m - end - - model = admodel() - vi = VarInfo(model) - model(vi, SampleFromPrior()) - x = [vi[@varname(s)], vi[@varname(m)]] - - dist_s = InverseGamma(2,3) - - # Hand-written log probabilities for vector `x = [s, m]`. - function logp_manual(x) - s = x[1] - m = x[2] - dist = Normal(m, sqrt(s)) - - return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + - logpdf(dist, 1.5) + logpdf(dist, 2.0) - end - - # Log probabilities for vector `x = [s, m]` using the model. - function logp_model(x) - new_vi = VarInfo(vi, SampleFromPrior(), x) - model(new_vi, SampleFromPrior()) - return getlogp(new_vi) +@testset "ad.jl" begin + @testset "logp" begin + # Hand-written log probabilities for vector `x = [s, m]`. + function logp_gdemo_default(x) + s = x[1] + m = x[2] + dist = Normal(m, sqrt(s)) + + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + + logpdf(dist, 1.5) + logpdf(dist, 2.0) + end + + test_model_ad(gdemo_default, logp_gdemo_default) + + @model function wishart_ad() + v ~ Wishart(7, [1 0.5; 0.5 1]) + end + + # Hand-written log probabilities for `x = [v]`. + function logp_wishart_ad(x) + dist = Wishart(7, [1 0.5; 0.5 1]) + return logpdf(dist, reshape(x, 2, 2)) + end + + test_model_ad(wishart_ad(), logp_wishart_ad) end - - # Check that both functions return the same values. - lp = logp_manual(x) - @test logp_model(x) ≈ lp - - # Gradients based on the manual implementation. - grad = ForwardDiff.gradient(logp_manual, x) - - y, back = Tracker.forward(logp_manual, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_manual, x) - @test y ≈ lp - @test back(1)[1] ≈ grad - - # Gradients based on the model. - @test ForwardDiff.gradient(logp_model, x) ≈ grad - - y, back = Tracker.forward(logp_model, x) - @test Tracker.data(y) ≈ lp - @test Tracker.data(back(1)[1]) ≈ grad - - y, back = Zygote.pullback(logp_model, x) - @test y ≈ lp - @test back(1)[1] ≈ grad end - diff --git a/test/compiler.jl b/test/compiler.jl index 9b134ec87..83f47ce42 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -580,59 +580,4 @@ end model = demo() @test all(iszero(model()) for _ in 1:1000) end - @testset "threading" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - - x = rand(10_000) - - @model function wthreads(x) - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) - end - end - - vi = VarInfo() - wthreads(x)(vi) - lp_w_threads = getlogp(vi) - - println("With `@threads`:") - println(" default:") - @time wthreads(x)(vi) - - # Ensure that we use `ThreadSafeVarInfo`. - @test getlogp(vi) ≈ lp_w_threads - DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), - DefaultContext()) - - println(" evaluate_multithreaded:") - @time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), - DefaultContext()) - - @model function wothreads(x) - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) - end - end - - vi = VarInfo() - wothreads(x)(vi) - lp_wo_threads = getlogp(vi) - - println("Without `@threads`:") - println(" default:") - @time wothreads(x)(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), - DefaultContext()) - @test getlogp(vi) ≈ lp_w_threads - - println(" evaluate_singlethreaded:") - @time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), - DefaultContext()) - end end diff --git a/test/runtests.jl b/test/runtests.jl index 84e6d38af..d1c457737 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,32 @@ -using Test, DynamicPPL +using DynamicPPL +using Distributions +using ForwardDiff +using Tracker +using Zygote + +using Random +using Test + dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] include(dir*"/test/Turing/Turing.jl") using .Turing turnprogress(false) +include("test_util.jl") + @testset "DynamicPPL.jl" begin include("utils.jl") include("compiler.jl") - include("compat/ad.jl") include("varinfo.jl") include("sampler.jl") include("prob_macro.jl") include("independence.jl") include("distribution_wrappers.jl") + + include("threadsafe.jl") + + @testset "compat" begin + include(joinpath("compat", "ad.jl")) + end end diff --git a/test/test_util.jl b/test/test_util.jl new file mode 100644 index 000000000..d8a926656 --- /dev/null +++ b/test/test_util.jl @@ -0,0 +1,38 @@ +function test_model_ad(model, logp_manual) + vi = VarInfo(model) + model(vi, SampleFromPrior()) + x = DynamicPPL.getall(vi) + + # Log probabilities using the model. + function logp_model(x) + new_vi = VarInfo(vi, SampleFromPrior(), x) + model(new_vi, SampleFromPrior()) + return getlogp(new_vi) + end + + # Check that both functions return the same values. + lp = logp_manual(x) + @test logp_model(x) ≈ lp + + # Gradients based on the manual implementation. + grad = ForwardDiff.gradient(logp_manual, x) + + y, back = Tracker.forward(logp_manual, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_manual, x) + @test y ≈ lp + @test back(1)[1] ≈ grad + + # Gradients based on the model. + @test ForwardDiff.gradient(logp_model, x) ≈ grad + + y, back = Tracker.forward(logp_model, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_model, x) + @test y ≈ lp + @test back(1)[1] ≈ grad +end diff --git a/test/threadsafe.jl b/test/threadsafe.jl new file mode 100644 index 000000000..8fddb313f --- /dev/null +++ b/test/threadsafe.jl @@ -0,0 +1,91 @@ +@testset "threadsafe.jl" begin + @testset "constructor" begin + vi = VarInfo(gdemo_default) + threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) + + @test threadsafe_vi.varinfo === vi + @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} + @test length(threadsafe_vi.logps) == Threads.nthreads() + @test all(iszero(x[]) for x in threadsafe_vi.logps) + end + + # TODO: Add more tests of the public API + @testset "API" begin + vi = VarInfo(gdemo_default) + threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) + + lp = getlogp(vi) + @test getlogp(threadsafe_vi) == lp + + acclogp!(threadsafe_vi, 42) + @test threadsafe_vi.logps[Threads.threadid()][] == 42 + @test getlogp(vi) == lp + @test getlogp(threadsafe_vi) == lp + 42 + + resetlogp!(threadsafe_vi) + @test iszero(getlogp(vi)) + @test iszero(getlogp(threadsafe_vi)) + @test all(iszero(x[]) for x in threadsafe_vi.logps) + + setlogp!(threadsafe_vi, 42) + @test getlogp(vi) == 42 + @test getlogp(threadsafe_vi) == 42 + @test all(iszero(x[]) for x in threadsafe_vi.logps) + end + + @testset "model" begin + println("Peforming threading tests with $(Threads.nthreads()) threads") + + x = rand(10_000) + + @model function wthreads(x) + x[1] ~ Normal(0, 1) + Threads.@threads for i in 2:length(x) + x[i] ~ Normal(x[i-1], 1) + end + end + + vi = VarInfo() + wthreads(x)(vi) + lp_w_threads = getlogp(vi) + + println("With `@threads`:") + println(" default:") + @time wthreads(x)(vi) + + # Ensure that we use `ThreadSafeVarInfo`. + @test getlogp(vi) ≈ lp_w_threads + DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + + println(" evaluate_multithreaded:") + @time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + + @model function wothreads(x) + x[1] ~ Normal(0, 1) + for i in 2:length(x) + x[i] ~ Normal(x[i-1], 1) + end + end + + vi = VarInfo() + wothreads(x)(vi) + lp_wo_threads = getlogp(vi) + + println("Without `@threads`:") + println(" default:") + @time wothreads(x)(vi) + + @test lp_w_threads ≈ lp_wo_threads + + # Ensure that we use `VarInfo`. + DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) + @test getlogp(vi) ≈ lp_w_threads + + println(" evaluate_singlethreaded:") + @time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) + end +end From 6d8cacafe2eaaf1f5ac816d514ccdc83d681b814 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 3 May 2020 21:54:36 +0200 Subject: [PATCH 13/13] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a54a1a905..c17dc5885 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" authors = ["mohamed82008 "] -version = "0.7.0" +version = "0.7.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"