Skip to content

Commit

Permalink
Merge pull request #41 from TuringLang/update
Browse files Browse the repository at this point in the history
Add AbstractMCMC 0.5 and update Turing
  • Loading branch information
yebai authored Feb 28, 2020
2 parents add2f5e + 62351c1 commit 7b4e3cd
Show file tree
Hide file tree
Showing 21 changed files with 794 additions and 584 deletions.
17 changes: 9 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
authors = ["mohamed82008 <[email protected]>"]
version = "0.4.0"
version = "0.4.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -10,35 +10,36 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

[compat]
AbstractMCMC = "0.4"
AbstractMCMC = "0.4, 0.5"
Bijectors = "0.5.2"
Distributions = "0.22"
MacroTools = "0.5.1"
julia = "1"

[extras]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[targets]
test = ["AbstractMCMC", "AdvancedHMC", "Bijectors", "Distributions", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "MCMCChains", "MacroTools", "Markdown", "PDMats", "ProgressMeter", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsFuns", "Test", "Tracker"]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
22 changes: 15 additions & 7 deletions test/Turing/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ module Turing

using Requires, Reexport, ForwardDiff
using Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra, ProgressMeter
using Statistics, LinearAlgebra
using Markdown, Libtask, MacroTools
using AbstractMCMC: sample, psample
@reexport using Distributions, MCMCChains, Libtask
using Tracker: Tracker

Expand All @@ -21,7 +20,7 @@ import DynamicPPL: getspace, runmodel!

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info("[Turing]: global PROGRESS is set as $switch")
@info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
PROGRESS[] = switch
end

Expand Down Expand Up @@ -50,10 +49,14 @@ using .Variational
# end

@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin
using Pkg;
Pkg.installed()["DynamicHMC"] < v"2.0" && error("Please upgdate your DynamicHMC, v1.x is no longer supported")
using ..Turing.DynamicHMC: DynamicHMC, mcmc_with_warmup
include("contrib/inference/dynamichmc.jl")
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
using ..DynamicHMC: mcmc_with_warmup
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end

###########
Expand All @@ -69,6 +72,7 @@ export @model, # modelling
DynamicPPL,

MH, # classic sampling
RWMH,
ESS,
Gibbs,

Expand Down Expand Up @@ -109,4 +113,8 @@ export @model, # modelling
LogPoisson,
NamedDist

# Reexports
using AbstractMCMC: sample, psample
export sample, psample

end
2 changes: 1 addition & 1 deletion test/Turing/contrib/inference/AdvancedSMCExtensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:

# Run SMC & CSMC nodes
for j in 1:spl.alg.n_nodes
reset_num_produce!(VarInfos[j])
VarInfos[j].num_produce = 0
VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1]
log_zs[j] = spl.info[:samplers][j].info[:logevidence][end]
end
Expand Down
40 changes: 33 additions & 7 deletions test/Turing/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using AbstractMCMC: NoCallback

###
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###
Expand Down Expand Up @@ -110,13 +108,41 @@ function Sampler(
return Sampler(alg, Dict{Symbol,Any}(), s, state)
end

# Disable the callback for DynamicHMC, since it has it's own progress meter.
function AbstractMCMC.init_callback(
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
function AbstractMCMC.sample(
rng::AbstractRNG,
model::Model,
s::Sampler{<:DynamicNUTS},
model::AbstractModel,
alg::DynamicNUTS,
N::Integer;
chain_type=Chains,
resume_from=nothing,
progress=PROGRESS[],
kwargs...
)
return NoCallback()
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
if resume_from === nothing
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
chain_type=chain_type, progress=false, kwargs...)
else
return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...)
end
end

function AbstractMCMC.psample(
rng::AbstractRNG,
model::AbstractModel,
alg::DynamicNUTS,
N::Integer,
n_chains::Integer;
chain_type=Chains,
progress=PROGRESS[],
kwargs...
)
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
chain_type=chain_type, progress=false, kwargs...)
end
6 changes: 0 additions & 6 deletions test/Turing/core/Core.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Core

using Bijectors
using MacroTools, Libtask, ForwardDiff, Random
using Distributions, LinearAlgebra
using ..Utilities, Reexport
Expand All @@ -14,14 +13,9 @@ import Bijectors: link, invlink
using DistributionsAD
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using Requires

include("container.jl")
include("ad.jl")
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("compat/zygote.jl")
export ZygoteAD
end

export @model,
@varname,
Expand Down
41 changes: 14 additions & 27 deletions test/Turing/core/ad.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
##############################
# Global variables/constants #
##############################
using Bijectors

const ADBACKEND = Ref(:forward_diff)
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
function setadbackend(::Val{:forward_diff})
CHUNKSIZE[] == 0 && setchunksize(40)
ADBACKEND[] = :forward_diff
end
function setadbackend(::Val{:reverse_diff})
ADBACKEND[] = :reverse_diff
function setadbackend(backend_sym)
@assert backend_sym == :forward_diff || backend_sym == :reverse_diff
backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40)
ADBACKEND[] = backend_sym

Bijectors.setadbackend(backend_sym)
end

const ADSAFE = Ref(false)
Expand Down Expand Up @@ -39,8 +39,7 @@ ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))

ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
ADBackend(::Val{:reverse_diff}) = TrackerAD
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
ADBackend(::Val) = TrackerAD

"""
getADtype(alg)
Expand Down Expand Up @@ -70,8 +69,8 @@ function gradient_logp(
ad_type = getADtype(sampler)
if ad_type <: ForwardDiffAD
return gradient_logp_forward(θ, vi, model, sampler)
else
return gradient_logp_reverse(ad_type(), θ, vi, model, sampler)
else ad_type <: TrackerAD
return gradient_logp_reverse(θ, vi, model, sampler)
end
end

Expand Down Expand Up @@ -114,22 +113,20 @@ end

"""
gradient_logp_reverse(
backend::ADBackend,
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
sampler::AbstractSampler=SampleFromPrior(),
)
Computes the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using reverse-mode AD from the specified `backend`, e.g. `TrackerAD()` which uses `Tracker.jl` or `ZygoteAD()` which uses `Zygote.jl`.
specified by `(vi, sampler, model)` using reverse-mode AD from Tracker.jl.
"""
function gradient_logp_reverse(
backend::TrackerAD,
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
sampler::AbstractSampler=SampleFromPrior(),
)
T = typeof(getlogp(vi))

Expand All @@ -141,19 +138,10 @@ function gradient_logp_reverse(

# Compute forward and reverse passes.
l_tracked, ȳ = Tracker.forward(f, θ)
# Remove tracking info from variables in model (because mutable state).
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data((1)[1])

# Remove tracking info from variables in model (because mutable state).
return l, ∂l∂θ
end
function gradient_logp_reverse(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
)
return gradient_logp_reverse(TrackerAD(), θ, vi, model, sampler)
end

function verifygrad(grad::AbstractVector{<:Real})
if any(isnan, grad) || any(isinf, grad)
Expand All @@ -165,7 +153,6 @@ function verifygrad(grad::AbstractVector{<:Real})
end
end

# Replace the adjoints below with Zygote ones
for F in (:link, :invlink)
@eval begin
function $F(
Expand Down
26 changes: 0 additions & 26 deletions test/Turing/core/compat/zygote.jl

This file was deleted.

11 changes: 5 additions & 6 deletions test/Turing/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function AbstractMCMC.step!(
spl::Sampler{<:SMC},
::Integer,
transition;
iteration = -1,
iteration=-1,
kwargs...
)
# check that we received a real iteration number
Expand Down Expand Up @@ -238,23 +238,22 @@ function AbstractMCMC.sample_end!(
spl::Sampler{<:ParticleInference},
N::Integer,
ts::Vector{<:ParticleTransition};
resume_from = nothing,
kwargs...
)
# Set the default for resuming the sampler.
resume_from = get(kwargs, :resume_from, nothing)

# Exponentiate the average log evidence.
# loge = exp(mean([t.le for t in ts]))
loge = mean(t.le for t in ts)

# If we already had a chain, grab the logevidence.
if resume_from !== nothing # concat samples
@assert resume_from isa Chains "resume_from needs to be a Chains object."
if resume_from isa Chains
# pushfirst!(samples, resume_from.info[:samples]...)
pre_loge = resume_from.logevidence
# Calculate new log-evidence
pre_n = length(resume_from)
loge = (pre_loge * pre_n + loge * N) / (pre_n + N)
elseif resume_from !== nothing
error("keyword argument `resume_from` has to be `nothing` or a `Chains` object")
end

# Store the logevidence.
Expand Down
Loading

0 comments on commit 7b4e3cd

Please sign in to comment.