diff --git a/Project.toml b/Project.toml index 88bee3794..c17dc5885 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" authors = ["mohamed82008 "] -version = "0.7.0" +version = "0.7.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -45,6 +45,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"] +test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"] diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8f914a2f8..5fe778988 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -302,7 +302,8 @@ function get_and_set_val!( else r = init(dist, spl, n) for i in 1:n - push!(vi, vns[i], r[:,i], dist, spl) + vn = vns[i] + push!(vi, vn, r[:,i], dist, spl) settrans!(vi, false, vn) end end diff --git a/src/model.jl b/src/model.jl index 343e3b923..2198e8163 100644 --- a/src/model.jl +++ b/src/model.jl @@ -154,7 +154,7 @@ function evaluate_multithreaded(model, varinfo, sampler, context) end wrapper = ThreadSafeVarInfo(varinfo) result = model.f(model, wrapper, sampler, context) - acclogp!(varinfo, sum(wrapper.logps)) + setlogp!(varinfo, getlogp(wrapper)) return result end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2e71cd410..996934c53 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -9,29 +9,33 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo logps::L end function ThreadSafeVarInfo(vi::AbstractVarInfo) - return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()]) + return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. function acclogp!(vi::ThreadSafeVarInfo, logp) - vi.logps[Threads.threadid()] += logp + vi.logps[Threads.threadid()][] += logp return vi end # The current log probability of the variables has to be computed from # both the wrapped variables and the thread-specific log probabilities. -getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps) +getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps) # TODO: Make remaining methods thread-safe. function resetlogp!(vi::ThreadSafeVarInfo) - fill!(vi.logps, zero(getlogp(vi))) + for x in vi.logps + x[] = zero(x[]) + end return resetlogp!(vi.varinfo) end function setlogp!(vi::ThreadSafeVarInfo, logp) - fill!(vi.logps, zero(logp)) + for x in vi.logps + x[] = zero(x[]) + end return setlogp!(vi.varinfo, logp) end @@ -45,6 +49,7 @@ syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) setgid!(vi.varinfo, gid, vn) end +setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) diff --git a/test/compat/ad.jl b/test/compat/ad.jl new file mode 100644 index 000000000..814a3d245 --- /dev/null +++ b/test/compat/ad.jl @@ -0,0 +1,27 @@ +@testset "ad.jl" begin + @testset "logp" begin + # Hand-written log probabilities for vector `x = [s, m]`. + function logp_gdemo_default(x) + s = x[1] + m = x[2] + dist = Normal(m, sqrt(s)) + + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + + logpdf(dist, 1.5) + logpdf(dist, 2.0) + end + + test_model_ad(gdemo_default, logp_gdemo_default) + + @model function wishart_ad() + v ~ Wishart(7, [1 0.5; 0.5 1]) + end + + # Hand-written log probabilities for `x = [v]`. + function logp_wishart_ad(x) + dist = Wishart(7, [1 0.5; 0.5 1]) + return logpdf(dist, reshape(x, 2, 2)) + end + + test_model_ad(wishart_ad(), logp_wishart_ad) + end +end diff --git a/test/compiler.jl b/test/compiler.jl index 4df301878..204ddeda3 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -580,42 +580,6 @@ end model = demo() @test all(iszero(model()) for _ in 1:1000) end - @testset "threading" begin - @info "Peforming threading tests with $(Threads.nthreads()) threads" - - x = rand(10_000) - - @model function wthreads(x) - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) - end - end - - vi = VarInfo() - wthreads(x)(vi) - lp_w_threads = getlogp(vi) - - println("With threading:") - @time wthreads(x)(vi) - - @model function wothreads(x) - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i-1], 1) - end - end - - vi = VarInfo() - wothreads(x)(vi) - lp_wo_threads = getlogp(vi) - - println("Without threading:") - @time wothreads(x)(vi) - - @test lp_w_threads ≈ lp_wo_threads - end - @testset "docstring" begin "This is a test" @model function demo(x) diff --git a/test/runtests.jl b/test/runtests.jl index 154bf6dcf..d1c457737 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,20 @@ -using Test, DynamicPPL +using DynamicPPL +using Distributions +using ForwardDiff +using Tracker +using Zygote + +using Random +using Test + dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] include(dir*"/test/Turing/Turing.jl") using .Turing turnprogress(false) +include("test_util.jl") + @testset "DynamicPPL.jl" begin include("utils.jl") include("compiler.jl") @@ -13,4 +23,10 @@ turnprogress(false) include("prob_macro.jl") include("independence.jl") include("distribution_wrappers.jl") + + include("threadsafe.jl") + + @testset "compat" begin + include(joinpath("compat", "ad.jl")) + end end diff --git a/test/test_util.jl b/test/test_util.jl new file mode 100644 index 000000000..d8a926656 --- /dev/null +++ b/test/test_util.jl @@ -0,0 +1,38 @@ +function test_model_ad(model, logp_manual) + vi = VarInfo(model) + model(vi, SampleFromPrior()) + x = DynamicPPL.getall(vi) + + # Log probabilities using the model. + function logp_model(x) + new_vi = VarInfo(vi, SampleFromPrior(), x) + model(new_vi, SampleFromPrior()) + return getlogp(new_vi) + end + + # Check that both functions return the same values. + lp = logp_manual(x) + @test logp_model(x) ≈ lp + + # Gradients based on the manual implementation. + grad = ForwardDiff.gradient(logp_manual, x) + + y, back = Tracker.forward(logp_manual, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_manual, x) + @test y ≈ lp + @test back(1)[1] ≈ grad + + # Gradients based on the model. + @test ForwardDiff.gradient(logp_model, x) ≈ grad + + y, back = Tracker.forward(logp_model, x) + @test Tracker.data(y) ≈ lp + @test Tracker.data(back(1)[1]) ≈ grad + + y, back = Zygote.pullback(logp_model, x) + @test y ≈ lp + @test back(1)[1] ≈ grad +end diff --git a/test/threadsafe.jl b/test/threadsafe.jl new file mode 100644 index 000000000..8fddb313f --- /dev/null +++ b/test/threadsafe.jl @@ -0,0 +1,91 @@ +@testset "threadsafe.jl" begin + @testset "constructor" begin + vi = VarInfo(gdemo_default) + threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) + + @test threadsafe_vi.varinfo === vi + @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} + @test length(threadsafe_vi.logps) == Threads.nthreads() + @test all(iszero(x[]) for x in threadsafe_vi.logps) + end + + # TODO: Add more tests of the public API + @testset "API" begin + vi = VarInfo(gdemo_default) + threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) + + lp = getlogp(vi) + @test getlogp(threadsafe_vi) == lp + + acclogp!(threadsafe_vi, 42) + @test threadsafe_vi.logps[Threads.threadid()][] == 42 + @test getlogp(vi) == lp + @test getlogp(threadsafe_vi) == lp + 42 + + resetlogp!(threadsafe_vi) + @test iszero(getlogp(vi)) + @test iszero(getlogp(threadsafe_vi)) + @test all(iszero(x[]) for x in threadsafe_vi.logps) + + setlogp!(threadsafe_vi, 42) + @test getlogp(vi) == 42 + @test getlogp(threadsafe_vi) == 42 + @test all(iszero(x[]) for x in threadsafe_vi.logps) + end + + @testset "model" begin + println("Peforming threading tests with $(Threads.nthreads()) threads") + + x = rand(10_000) + + @model function wthreads(x) + x[1] ~ Normal(0, 1) + Threads.@threads for i in 2:length(x) + x[i] ~ Normal(x[i-1], 1) + end + end + + vi = VarInfo() + wthreads(x)(vi) + lp_w_threads = getlogp(vi) + + println("With `@threads`:") + println(" default:") + @time wthreads(x)(vi) + + # Ensure that we use `ThreadSafeVarInfo`. + @test getlogp(vi) ≈ lp_w_threads + DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + + println(" evaluate_multithreaded:") + @time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(), + DefaultContext()) + + @model function wothreads(x) + x[1] ~ Normal(0, 1) + for i in 2:length(x) + x[i] ~ Normal(x[i-1], 1) + end + end + + vi = VarInfo() + wothreads(x)(vi) + lp_wo_threads = getlogp(vi) + + println("Without `@threads`:") + println(" default:") + @time wothreads(x)(vi) + + @test lp_w_threads ≈ lp_wo_threads + + # Ensure that we use `VarInfo`. + DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) + @test getlogp(vi) ≈ lp_w_threads + + println(" evaluate_singlethreaded:") + @time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(), + DefaultContext()) + end +end