Skip to content

Commit

Permalink
Merge pull request #97 from TuringLang/fixes_threaded
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 3, 2020
2 parents 0758d01 + 6d8caca commit 07c0cee
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 46 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
authors = ["mohamed82008 <[email protected]>"]
version = "0.7.0"
version = "0.7.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -45,6 +45,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
3 changes: 2 additions & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ function get_and_set_val!(
else
r = init(dist, spl, n)
for i in 1:n
push!(vi, vns[i], r[:,i], dist, spl)
vn = vns[i]
push!(vi, vn, r[:,i], dist, spl)
settrans!(vi, false, vn)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function evaluate_multithreaded(model, varinfo, sampler, context)
end
wrapper = ThreadSafeVarInfo(varinfo)
result = model.f(model, wrapper, sampler, context)
acclogp!(varinfo, sum(wrapper.logps))
setlogp!(varinfo, getlogp(wrapper))
return result
end

Expand Down
15 changes: 10 additions & 5 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,33 @@ struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo
logps::L
end
function ThreadSafeVarInfo(vi::AbstractVarInfo)
return ThreadSafeVarInfo(vi, [zero(getlogp(vi)) for _ in 1:Threads.nthreads()])
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
end
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi

# Instead of updating the log probability of the underlying variables we
# just update the array of log probabilities.
function acclogp!(vi::ThreadSafeVarInfo, logp)
vi.logps[Threads.threadid()] += logp
vi.logps[Threads.threadid()][] += logp
return vi
end

# The current log probability of the variables has to be computed from
# both the wrapped variables and the thread-specific log probabilities.
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps)

# TODO: Make remaining methods thread-safe.

function resetlogp!(vi::ThreadSafeVarInfo)
fill!(vi.logps, zero(getlogp(vi)))
for x in vi.logps
x[] = zero(x[])
end
return resetlogp!(vi.varinfo)
end
function setlogp!(vi::ThreadSafeVarInfo, logp)
fill!(vi.logps, zero(logp))
for x in vi.logps
x[] = zero(x[])
end
return setlogp!(vi.varinfo, logp)
end

Expand All @@ -45,6 +49,7 @@ syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
setgid!(vi.varinfo, gid, vn)
end
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)

keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
Expand Down
27 changes: 27 additions & 0 deletions test/compat/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "ad.jl" begin
@testset "logp" begin
# Hand-written log probabilities for vector `x = [s, m]`.
function logp_gdemo_default(x)
s = x[1]
m = x[2]
dist = Normal(m, sqrt(s))

return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +
logpdf(dist, 1.5) + logpdf(dist, 2.0)
end

test_model_ad(gdemo_default, logp_gdemo_default)

@model function wishart_ad()
v ~ Wishart(7, [1 0.5; 0.5 1])
end

# Hand-written log probabilities for `x = [v]`.
function logp_wishart_ad(x)
dist = Wishart(7, [1 0.5; 0.5 1])
return logpdf(dist, reshape(x, 2, 2))
end

test_model_ad(wishart_ad(), logp_wishart_ad)
end
end
36 changes: 0 additions & 36 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -580,42 +580,6 @@ end
model = demo()
@test all(iszero(model()) for _ in 1:1000)
end
@testset "threading" begin
@info "Peforming threading tests with $(Threads.nthreads()) threads"

x = rand(10_000)

@model function wthreads(x)
x[1] ~ Normal(0, 1)
Threads.@threads for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wthreads(x)(vi)
lp_w_threads = getlogp(vi)

println("With threading:")
@time wthreads(x)(vi)

@model function wothreads(x)
x[1] ~ Normal(0, 1)
for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wothreads(x)(vi)
lp_wo_threads = getlogp(vi)

println("Without threading:")
@time wothreads(x)(vi)

@test lp_w_threads lp_wo_threads
end

@testset "docstring" begin
"This is a test"
@model function demo(x)
Expand Down
18 changes: 17 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
using Test, DynamicPPL
using DynamicPPL
using Distributions
using ForwardDiff
using Tracker
using Zygote

using Random
using Test

dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1]
include(dir*"/test/Turing/Turing.jl")
using .Turing

turnprogress(false)

include("test_util.jl")

@testset "DynamicPPL.jl" begin
include("utils.jl")
include("compiler.jl")
Expand All @@ -13,4 +23,10 @@ turnprogress(false)
include("prob_macro.jl")
include("independence.jl")
include("distribution_wrappers.jl")

include("threadsafe.jl")

@testset "compat" begin
include(joinpath("compat", "ad.jl"))
end
end
38 changes: 38 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
function test_model_ad(model, logp_manual)
vi = VarInfo(model)
model(vi, SampleFromPrior())
x = DynamicPPL.getall(vi)

# Log probabilities using the model.
function logp_model(x)
new_vi = VarInfo(vi, SampleFromPrior(), x)
model(new_vi, SampleFromPrior())
return getlogp(new_vi)
end

# Check that both functions return the same values.
lp = logp_manual(x)
@test logp_model(x) lp

# Gradients based on the manual implementation.
grad = ForwardDiff.gradient(logp_manual, x)

y, back = Tracker.forward(logp_manual, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_manual, x)
@test y lp
@test back(1)[1] grad

# Gradients based on the model.
@test ForwardDiff.gradient(logp_model, x) grad

y, back = Tracker.forward(logp_model, x)
@test Tracker.data(y) lp
@test Tracker.data(back(1)[1]) grad

y, back = Zygote.pullback(logp_model, x)
@test y lp
@test back(1)[1] grad
end
91 changes: 91 additions & 0 deletions test/threadsafe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
@testset "threadsafe.jl" begin
@testset "constructor" begin
vi = VarInfo(gdemo_default)
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)

@test threadsafe_vi.varinfo === vi
@test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))}
@test length(threadsafe_vi.logps) == Threads.nthreads()
@test all(iszero(x[]) for x in threadsafe_vi.logps)
end

# TODO: Add more tests of the public API
@testset "API" begin
vi = VarInfo(gdemo_default)
threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi)

lp = getlogp(vi)
@test getlogp(threadsafe_vi) == lp

acclogp!(threadsafe_vi, 42)
@test threadsafe_vi.logps[Threads.threadid()][] == 42
@test getlogp(vi) == lp
@test getlogp(threadsafe_vi) == lp + 42

resetlogp!(threadsafe_vi)
@test iszero(getlogp(vi))
@test iszero(getlogp(threadsafe_vi))
@test all(iszero(x[]) for x in threadsafe_vi.logps)

setlogp!(threadsafe_vi, 42)
@test getlogp(vi) == 42
@test getlogp(threadsafe_vi) == 42
@test all(iszero(x[]) for x in threadsafe_vi.logps)
end

@testset "model" begin
println("Peforming threading tests with $(Threads.nthreads()) threads")

x = rand(10_000)

@model function wthreads(x)
x[1] ~ Normal(0, 1)
Threads.@threads for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wthreads(x)(vi)
lp_w_threads = getlogp(vi)

println("With `@threads`:")
println(" default:")
@time wthreads(x)(vi)

# Ensure that we use `ThreadSafeVarInfo`.
@test getlogp(vi) lp_w_threads
DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
DefaultContext())

println(" evaluate_multithreaded:")
@time DynamicPPL.evaluate_multithreaded(wthreads(x), vi, SampleFromPrior(),
DefaultContext())

@model function wothreads(x)
x[1] ~ Normal(0, 1)
for i in 2:length(x)
x[i] ~ Normal(x[i-1], 1)
end
end

vi = VarInfo()
wothreads(x)(vi)
lp_wo_threads = getlogp(vi)

println("Without `@threads`:")
println(" default:")
@time wothreads(x)(vi)

@test lp_w_threads lp_wo_threads

# Ensure that we use `VarInfo`.
DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
DefaultContext())
@test getlogp(vi) lp_w_threads

println(" evaluate_singlethreaded:")
@time DynamicPPL.evaluate_singlethreaded(wothreads(x), vi, SampleFromPrior(),
DefaultContext())
end
end

2 comments on commit 07c0cee

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/14102

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" 07c0ceec1fe4f8a2ba05e5b5a85e4f4222afde35
git push origin v0.7.1

Please sign in to comment.