diff --git a/Project.toml b/Project.toml index 51a0c452c..5df51eea9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" authors = ["mohamed82008 "] -version = "0.4.0" +version = "0.4.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -10,35 +10,36 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" [compat] -AbstractMCMC = "0.4" +AbstractMCMC = "0.4, 0.5" Bijectors = "0.5.2" Distributions = "0.22" MacroTools = "0.5.1" julia = "1" [extras] -AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" +BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [targets] -test = ["AbstractMCMC", "AdvancedHMC", "Bijectors", "Distributions", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "MCMCChains", "MacroTools", "Markdown", "PDMats", "ProgressMeter", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsFuns", "Test", "Tracker"] +test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"] diff --git a/test/Turing/Turing.jl b/test/Turing/Turing.jl index 8927ec48f..78a885e88 100644 --- a/test/Turing/Turing.jl +++ b/test/Turing/Turing.jl @@ -10,9 +10,8 @@ module Turing using Requires, Reexport, ForwardDiff using Bijectors, StatsFuns, SpecialFunctions -using Statistics, LinearAlgebra, ProgressMeter +using Statistics, LinearAlgebra using Markdown, Libtask, MacroTools -using AbstractMCMC: sample, psample @reexport using Distributions, MCMCChains, Libtask using Tracker: Tracker @@ -21,7 +20,7 @@ import DynamicPPL: getspace, runmodel! const PROGRESS = Ref(true) function turnprogress(switch::Bool) - @info("[Turing]: global PROGRESS is set as $switch") + @info "[Turing]: progress logging is $(switch ? 
"enabled" : "disabled") globally" PROGRESS[] = switch end @@ -50,10 +49,14 @@ using .Variational # end @init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin - using Pkg; - Pkg.installed()["DynamicHMC"] < v"2.0" && error("Please upgdate your DynamicHMC, v1.x is no longer supported") - using ..Turing.DynamicHMC: DynamicHMC, mcmc_with_warmup - include("contrib/inference/dynamichmc.jl") + import ..DynamicHMC + + if isdefined(DynamicHMC, :mcmc_with_warmup) + using ..DynamicHMC: mcmc_with_warmup + include("contrib/inference/dynamichmc.jl") + else + error("Please update DynamicHMC, v1.x is no longer supported") + end end ########### @@ -69,6 +72,7 @@ export @model, # modelling DynamicPPL, MH, # classic sampling + RWMH, ESS, Gibbs, @@ -109,4 +113,8 @@ export @model, # modelling LogPoisson, NamedDist +# Reexports +using AbstractMCMC: sample, psample +export sample, psample + end diff --git a/test/Turing/contrib/inference/AdvancedSMCExtensions.jl b/test/Turing/contrib/inference/AdvancedSMCExtensions.jl index 7adc20eac..27d2f0805 100644 --- a/test/Turing/contrib/inference/AdvancedSMCExtensions.jl +++ b/test/Turing/contrib/inference/AdvancedSMCExtensions.jl @@ -255,7 +255,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first: # Run SMC & CSMC nodes for j in 1:spl.alg.n_nodes - reset_num_produce!(VarInfos[j]) + VarInfos[j].num_produce = 0 VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1] log_zs[j] = spl.info[:samplers][j].info[:logevidence][end] end diff --git a/test/Turing/contrib/inference/dynamichmc.jl b/test/Turing/contrib/inference/dynamichmc.jl index 39e9ea97c..b0d158eb5 100644 --- a/test/Turing/contrib/inference/dynamichmc.jl +++ b/test/Turing/contrib/inference/dynamichmc.jl @@ -1,5 +1,3 @@ -using AbstractMCMC: NoCallback - ### ### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl ### @@ -110,13 +108,41 @@ function Sampler( return Sampler(alg, Dict{Symbol,Any}(), s, state) end -# Disable the callback for DynamicHMC, since it has it's own progress meter. -function AbstractMCMC.init_callback( + # Disable the progress logging for DynamicHMC, since it has its own progress meter. + function AbstractMCMC.sample( rng::AbstractRNG, - model::Model, - s::Sampler{<:DynamicNUTS}, + model::AbstractModel, + alg::DynamicNUTS, N::Integer; + chain_type=Chains, + resume_from=nothing, + progress=PROGRESS[], kwargs... ) - return NoCallback() + if progress + @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + end + if resume_from === nothing + return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; + chain_type=chain_type, progress=false, kwargs...) + else + return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...) + end end + +function AbstractMCMC.psample( + rng::AbstractRNG, + model::AbstractModel, + alg::DynamicNUTS, + N::Integer, + n_chains::Integer; + chain_type=Chains, + progress=PROGRESS[], + kwargs... +) + if progress + @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + end + return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; + chain_type=chain_type, progress=false, kwargs...) 
+end \ No newline at end of file diff --git a/test/Turing/core/Core.jl b/test/Turing/core/Core.jl index e66b721f4..fbaeee16a 100644 --- a/test/Turing/core/Core.jl +++ b/test/Turing/core/Core.jl @@ -1,6 +1,5 @@ module Core -using Bijectors using MacroTools, Libtask, ForwardDiff, Random using Distributions, LinearAlgebra using ..Utilities, Reexport @@ -14,14 +13,9 @@ import Bijectors: link, invlink using DistributionsAD using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using Requires include("container.jl") include("ad.jl") -@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD -end export @model, @varname, diff --git a/test/Turing/core/ad.jl b/test/Turing/core/ad.jl index ffca99f4c..da9ee34ba 100644 --- a/test/Turing/core/ad.jl +++ b/test/Turing/core/ad.jl @@ -1,15 +1,15 @@ ############################## # Global variables/constants # ############################## +using Bijectors const ADBACKEND = Ref(:forward_diff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - CHUNKSIZE[] == 0 && setchunksize(40) - ADBACKEND[] = :forward_diff -end -function setadbackend(::Val{:reverse_diff}) - ADBACKEND[] = :reverse_diff +function setadbackend(backend_sym) + @assert backend_sym == :forward_diff || backend_sym == :reverse_diff + backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40) + ADBACKEND[] = backend_sym + + Bijectors.setadbackend(backend_sym) end const ADSAFE = Ref(false) @@ -39,8 +39,7 @@ ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:reverse_diff}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") +ADBackend(::Val) = TrackerAD """ getADtype(alg) @@ -70,8 +69,8 @@ function gradient_logp( ad_type = getADtype(sampler) if ad_type <: ForwardDiffAD return gradient_logp_forward(θ, vi, model, sampler) - else - return gradient_logp_reverse(ad_type(), θ, vi, model, sampler) + else ad_type <: TrackerAD + return gradient_logp_reverse(θ, vi, model, sampler) end end @@ -114,22 +113,20 @@ end """ gradient_logp_reverse( - backend::ADBackend, θ::AbstractVector{<:Real}, vi::VarInfo, model::Model, - sampler::AbstractSampler = SampleFromPrior(), + sampler::AbstractSampler=SampleFromPrior(), ) Computes the value of the log joint of `θ` and its gradient for the model -specified by `(vi, sampler, model)` using reverse-mode AD from the specified `backend`, e.g. `TrackerAD()` which uses `Tracker.jl` or `ZygoteAD()` which uses `Zygote.jl`. +specified by `(vi, sampler, model)` using reverse-mode AD from Tracker.jl. """ function gradient_logp_reverse( - backend::TrackerAD, θ::AbstractVector{<:Real}, vi::VarInfo, model::Model, - sampler::AbstractSampler = SampleFromPrior(), + sampler::AbstractSampler=SampleFromPrior(), ) T = typeof(getlogp(vi)) @@ -141,19 +138,10 @@ function gradient_logp_reverse( # Compute forward and reverse passes. l_tracked, ȳ = Tracker.forward(f, θ) - # Remove tracking info from variables in model (because mutable state). l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1]) - + # Remove tracking info from variables in model (because mutable state). 
return l, ∂l∂θ end -function gradient_logp_reverse( - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), -) - return gradient_logp_reverse(TrackerAD(), θ, vi, model, sampler) -end function verifygrad(grad::AbstractVector{<:Real}) if any(isnan, grad) || any(isinf, grad) @@ -165,7 +153,6 @@ function verifygrad(grad::AbstractVector{<:Real}) end end -# Replace the adjoints below with Zygote ones for F in (:link, :invlink) @eval begin function $F( diff --git a/test/Turing/core/compat/zygote.jl b/test/Turing/core/compat/zygote.jl deleted file mode 100644 index 054c93496..000000000 --- a/test/Turing/core/compat/zygote.jl +++ /dev/null @@ -1,26 +0,0 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote -end - -function gradient_logp_reverse( - backend::ZygoteAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), -) - - # Specify objective function. - function f(θ) - new_vi = VarInfo(vi, sampler, θ) - return getlogp(runmodel!(model, new_vi, sampler)) - end - - # Compute forward and reverse passes. - l, ȳ = Zygote.pullback(f, θ) - ∂l∂θ = ȳ(1)[1] - - return l, ∂l∂θ -end diff --git a/test/Turing/inference/AdvancedSMC.jl b/test/Turing/inference/AdvancedSMC.jl index f7e4139ef..646ca58d4 100644 --- a/test/Turing/inference/AdvancedSMC.jl +++ b/test/Turing/inference/AdvancedSMC.jl @@ -109,7 +109,7 @@ function AbstractMCMC.step!( spl::Sampler{<:SMC}, ::Integer, transition; - iteration = -1, + iteration=-1, kwargs... ) # check that we received a real iteration number @@ -238,23 +238,22 @@ function AbstractMCMC.sample_end!( spl::Sampler{<:ParticleInference}, N::Integer, ts::Vector{<:ParticleTransition}; + resume_from = nothing, kwargs... ) - # Set the default for resuming the sampler. - resume_from = get(kwargs, :resume_from, nothing) - # Exponentiate the average log evidence. # loge = exp(mean([t.le for t in ts])) loge = mean(t.le for t in ts) # If we already had a chain, grab the logevidence. - if resume_from !== nothing # concat samples - @assert resume_from isa Chains "resume_from needs to be a Chains object." + if resume_from isa Chains # pushfirst!(samples, resume_from.info[:samples]...) pre_loge = resume_from.logevidence # Calculate new log-evidence pre_n = length(resume_from) loge = (pre_loge * pre_n + loge * N) / (pre_n + N) + elseif resume_from !== nothing + error("keyword argument `resume_from` has to be `nothing` or a `Chains` object") end # Store the logevidence. diff --git a/test/Turing/inference/Inference.jl b/test/Turing/inference/Inference.jl index 7803ca3fe..564538bf3 100644 --- a/test/Turing/inference/Inference.jl +++ b/test/Turing/inference/Inference.jl @@ -8,19 +8,21 @@ using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo, Selector, AbstractSamplerState, DefaultContext, PriorContext, LikelihoodContext, MiniBatchContext, set_flag!, unset_flag! 
using Distributions, Libtask, Bijectors -using ProgressMeter, LinearAlgebra +using LinearAlgebra using ..Turing: PROGRESS, NamedDist, NoDist, Turing using StatsFuns: logsumexp -using Random: GLOBAL_RNG, AbstractRNG, randexp +using Random: AbstractRNG, randexp using DynamicPPL -using Bijectors: _debug +using AbstractMCMC: AbstractModel, AbstractSampler +using MCMCChains: Chains -import MCMCChains: Chains +import AbstractMCMC import AdvancedHMC; const AHMC = AdvancedHMC +import AdvancedMH; const AMH = AdvancedMH import ..Core: getchunksize, getADtype -import AbstractMCMC -using AbstractMCMC: AbstractModel, AbstractCallback, AbstractSampler -import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type +import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type, + VarName, _getranges, _getindex, getval, _getvns +import Random export InferenceAlgorithm, Hamiltonian, @@ -126,41 +128,40 @@ const TURING_INTERNAL_VARS = (internals = [ ######################################### function AbstractMCMC.sample( - rng::AbstractRNG, model::AbstractModel, alg::InferenceAlgorithm, N::Integer; - chain_type=Chains, kwargs... ) - return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...) + return AbstractMCMC.sample(Random.GLOBAL_RNG, model, alg, N; kwargs...) end function AbstractMCMC.sample( - model::Model, + rng::AbstractRNG, + model::AbstractModel, alg::InferenceAlgorithm, N::Integer; - resume_from=nothing, chain_type=Chains, + resume_from=nothing, + progress=PROGRESS[], kwargs... ) if resume_from === nothing - return AbstractMCMC.sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...) + return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; + chain_type=chain_type, progress=progress, kwargs...) else - return resume(resume_from, N) + return resume(resume_from, N; chain_type=chain_type, progress=progress, kwargs...) end end - function AbstractMCMC.psample( model::AbstractModel, alg::InferenceAlgorithm, N::Integer, n_chains::Integer; - chain_type=Chains, kwargs... ) - return AbstractMCMC.psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, chain_type=chain_type, kwargs...) + return AbstractMCMC.psample(Random.GLOBAL_RNG, model, alg, N, n_chains; kwargs...) end function AbstractMCMC.psample( @@ -170,9 +171,11 @@ function AbstractMCMC.psample( N::Integer, n_chains::Integer; chain_type=Chains, + progress=PROGRESS[], kwargs... ) - return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, chain_type=chain_type, kwargs...) + return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; + chain_type=chain_type, progress=progress, kwargs...) end function AbstractMCMC.sample_init!( @@ -201,7 +204,7 @@ function AbstractMCMC.sample_end!( end function initialize_parameters!( - spl::AbstractSampler; + spl::Sampler; init_theta::Union{Nothing,Vector}=nothing, verbose::Bool=false, kwargs... @@ -305,13 +308,13 @@ end # Default Chains constructor. function AbstractMCMC.bundle_samples( rng::AbstractRNG, - model::AbstractModel, + model::Model, spl::Sampler, N::Integer, ts::Vector, - ::Type{Chains}; + chain_type::Type{Chains}; discard_adapt::Bool=true, - save_state=true, + save_state=false, kwargs... ) # Check if we have adaptation samples. 
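# Sketch of the call pattern implied by the refactored entry points above (illustrative only,
# not part of the patch; `gdemo` stands in for any Turing model and is assumed here):
#
#     using Random
#     chn  = sample(gdemo([1.5, 2.0]), NUTS(0.65), 1_000)            # forwards to Random.GLOBAL_RNG
#     chn2 = sample(MersenneTwister(1), gdemo([1.5, 2.0]), NUTS(0.65), 1_000; progress = false)
#     chns = psample(gdemo([1.5, 2.0]), NUTS(0.65), 1_000, 4)        # 4 chains, progress defaults to PROGRESS[]
#
# Both `sample` and `psample` now take `progress` as an ordinary keyword instead of hard-coding it,
# and `resume_from` is forwarded together with `chain_type` when continuing an earlier chain.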
@@ -361,7 +364,36 @@ function AbstractMCMC.bundle_samples( ) end -function save(c::Chains, spl::AbstractSampler, model, vi, samples) +function AbstractMCMC.bundle_samples( + rng::AbstractRNG, + model::Model, + spl::Sampler, + N::Integer, + ts::Vector, + chain_type::Type{Vector{NamedTuple}}; + discard_adapt::Bool=true, + save_state=false, + kwargs... +) + nts = Vector{NamedTuple}(undef, N) + + for (i,t) in enumerate(ts) + k = collect(keys(t.θ)) + vs = [] + for v in values(t.θ) + push!(vs, v[1]) + end + + push!(k, :lp) + + + nts[i] = NamedTuple{tuple(k...)}(tuple(vs..., t.lp)) + end + + return map(identity, nts) +end + +function save(c::Chains, spl::Sampler, model, vi, samples) nt = NamedTuple{(:spl, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples)) return setinfo(c, merge(nt, c.info)) end @@ -377,7 +409,7 @@ function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...) n_iter; resume_from=c, reuse_spl_n=n_iter, - chain_type=chain_type, + chain_type=Chains, kwargs... ) @@ -565,7 +597,7 @@ function assume( vi[vn] = vectorize(dist, r) setorder!(vi, vn, get_num_produce(vi)) else - r = vi[vn] + r = vi[vn] end else r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist) @@ -687,7 +719,7 @@ function dot_assume( var::AbstractMatrix, vi::VarInfo, ) - @assert dim(dist) == size(var, 1) + @assert length(dist) == size(var, 1) r = get_and_set_val!(vi, vns, dist, spl) lp = sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) var .= r @@ -831,8 +863,8 @@ function dot_observe( vi::VarInfo, ) increment_num_produce!(vi) - Turing.DEBUG && _debug("dist = $dist") - Turing.DEBUG && _debug("value = $value") + Turing.DEBUG && @debug "dist = $dist" + Turing.DEBUG && @debug "value = $value" return sum(logpdf(dist, value)) end function dot_observe( @@ -842,8 +874,8 @@ function dot_observe( vi::VarInfo, ) increment_num_produce!(vi) - Turing.DEBUG && _debug("dists = $dists") - Turing.DEBUG && _debug("value = $value") + Turing.DEBUG && @debug "dists = $dists" + Turing.DEBUG && @debug "value = $value" return sum(logpdf.(dists, value)) end function dot_observe( diff --git a/test/Turing/inference/ess.jl b/test/Turing/inference/ess.jl index d6aad3ccb..5cffad6b0 100644 --- a/test/Turing/inference/ess.jl +++ b/test/Turing/inference/ess.jl @@ -56,7 +56,7 @@ isgaussian(::AbstractMvNormal) = true # always accept in the first step function AbstractMCMC.step!( ::AbstractRNG, - ::Model, + model::Model, spl::Sampler{<:ESS}, ::Integer, ::Nothing; diff --git a/test/Turing/inference/gibbs.jl b/test/Turing/inference/gibbs.jl index 93df86e19..61e7d7f6c 100644 --- a/test/Turing/inference/gibbs.jl +++ b/test/Turing/inference/gibbs.jl @@ -126,8 +126,7 @@ function AbstractMCMC.sample_end!( end end - -# Steps +# Steps 2 function AbstractMCMC.step!( rng::AbstractRNG, model::Model, diff --git a/test/Turing/inference/hmc.jl b/test/Turing/inference/hmc.jl index 0b77a2469..ec2e3ebdb 100644 --- a/test/Turing/inference/hmc.jl +++ b/test/Turing/inference/hmc.jl @@ -119,7 +119,10 @@ function AbstractMCMC.sample_init!( if spl.alg isa AdaptiveHamiltonian # If there's no chain passed in, verify the n_adapts. if resume_from === nothing - if spl.alg.n_adapts == 0 + # if n_adapts is -1, then the user called a convenience + # constructor like NUTS() or NUTS(0.65), and we should + # set a default for them. + if spl.alg.n_adapts == -1 spl.alg.n_adapts = min(1000, N ÷ 2) elseif spl.alg.n_adapts > N # Verify that n_adapts is not greater than the number of samples to draw. 
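# Illustrative sketch of the sentinel logic introduced above (not part of the patch): `-1` now
# means "the user did not specify `n_adapts`", so the convenience constructors can no longer be
# confused with an explicit request for zero adaptation steps.
#
#     if spl.alg.n_adapts == -1
#         spl.alg.n_adapts = min(1000, N ÷ 2)   # default for NUTS(), NUTS(0.65), etc.
#     elseif spl.alg.n_adapts > N
#         # too many adaptation steps relative to N; handled as before
#     end
#
# An explicit `n_adapts == 0` still maps to `AHMC.Adaptation.NoAdaptation()` in `AHMCAdaptor`.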
@@ -179,7 +182,7 @@ function HMCDA{AD}( init_ϵ::Float64=0.0, metricT=AHMC.UnitEuclideanMetric ) where AD - return HMCDA{AD}(0, δ, λ, init_ϵ, metricT, ()) + return HMCDA{AD}(-1, δ, λ, init_ϵ, metricT, ()) end function HMCDA{AD}( @@ -275,11 +278,11 @@ function NUTS{AD}( init_ϵ::Float64=0.0, metricT=AHMC.DiagEuclideanMetric ) where AD - NUTS{AD}(0, δ, max_depth, Δ_max, init_ϵ, metricT, ()) + NUTS{AD}(-1, δ, max_depth, Δ_max, init_ϵ, metricT, ()) end function NUTS{AD}(kwargs...) where AD - NUTS{AD}(0, 0.65; kwargs...) + NUTS{AD}(-1, 0.65; kwargs...) end for alg in (:HMC, :HMCDA, :NUTS) @@ -304,7 +307,7 @@ function Sampler( initial_spl = Sampler(alg, info, s, initial_state) # Create the actual state based on the alg type. - state = HMCState(model, initial_spl, GLOBAL_RNG) + state = HMCState(model, initial_spl, Random.GLOBAL_RNG) # Create a real sampler after getting all the types/running the init phase. return Sampler(alg, initial_spl.info, initial_spl.selector, state) @@ -436,14 +439,14 @@ function assume( vn::VarName, vi::VarInfo ) - Turing.DEBUG && _debug("assuming...") + Turing.DEBUG && @debug "assuming..." updategid!(vi, vn, spl) r = vi[vn] # acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn))) # r - Turing.DEBUG && _debug("dist = $dist") - Turing.DEBUG && _debug("vn = $vn") - Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))") + Turing.DEBUG && @debug "dist = $dist" + Turing.DEBUG && @debug "vn = $vn" + Turing.DEBUG && @debug "r = $r" "typeof(r)=$(typeof(r))" return r, logpdf_with_trans(dist, r, istrans(vi, vn)) end @@ -454,7 +457,7 @@ function dot_assume( var::AbstractMatrix, vi::VarInfo, ) - @assert dim(dist) == size(var, 1) + @assert length(dist) == size(var, 1) updategid!.(Ref(vi), vns, Ref(spl)) r = vi[vns] var .= r @@ -498,6 +501,7 @@ end function AHMCAdaptor(alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric; ϵ=alg.ϵ) pc = AHMC.Preconditioner(metric) da = AHMC.NesterovDualAveraging(alg.δ, ϵ) + if iszero(alg.n_adapts) adaptor = AHMC.Adaptation.NoAdaptation() else @@ -508,6 +512,7 @@ function AHMCAdaptor(alg::AdaptiveHamiltonian, metric::AHMC.AbstractMetric; ϵ=a AHMC.initialize!(adaptor, alg.n_adapts) end end + return adaptor end @@ -560,39 +565,4 @@ function HMCState( invlink!(vi, spl) return HMCState(vi, 0, 0, traj, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) -end - -####################################################### -# Special callback functionality for the HMC samplers # -####################################################### - -mutable struct HMCCallback{ - ProgType<:ProgressMeter.AbstractProgress -} <: AbstractCallback - p :: ProgType -end - - -function AbstractMCMC.callback( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:Union{StaticHamiltonian, AdaptiveHamiltonian}}, - N::Integer, - iteration::Integer, - t::HamiltonianTransition, - cb::HMCCallback; - kwargs... -) - AHMC.pm_next!(cb.p, (iteration=iteration, t.stat..., mass_matrix=spl.state.h.metric)) -end - -function AbstractMCMC.init_callback( - rng::AbstractRNG, - model::Model, - s::Sampler{<:Union{StaticHamiltonian, AdaptiveHamiltonian}}, - N::Integer; - dt::Real=0.25, - kwargs... 
-) - return HMCCallback(ProgressMeter.Progress(N, dt=dt, desc="Sampling ", barlen=31)) -end +end \ No newline at end of file diff --git a/test/Turing/inference/mh.jl b/test/Turing/inference/mh.jl index 54ad81500..76e158f42 100644 --- a/test/Turing/inference/mh.jl +++ b/test/Turing/inference/mh.jl @@ -1,174 +1,301 @@ -mutable struct MHState{V<:VarInfo} <: AbstractSamplerState - proposal_ratio :: Float64 - prior_prob :: Float64 - violating_support :: Bool - vi :: V +### +### Sampler states +### + +struct MH{space, P} <: InferenceAlgorithm + proposals::P end -MHState(model::Model) = MHState(0.0, 0.0, false, VarInfo(model)) +function MH(space...) + syms = Symbol[] -""" - MH() + prop_syms = Symbol[] + props = AMH.Proposal[] + + check_support(dist) = insupport(dist, z) + + for s in space + if s isa Symbol + push!(syms, s) + elseif s isa Pair || s isa Tuple + push!(prop_syms, s[1]) + + if s[2] isa AMH.Proposal + push!(props, s[2]) + elseif s[2] isa Distribution + push!(props, AMH.Proposal(AMH.Static(), s[2])) + elseif s[2] isa Function + push!(props, AMH.Proposal(AMH.Static(), s[2])) + end + end + end -Metropolis-Hastings sampler. + proposals = NamedTuple{tuple(prop_syms...)}(tuple(props...)) + syms = vcat(syms, prop_syms) + return MH{tuple(syms...), typeof(proposals)}(proposals) +end -Usage: +alg_str(::Sampler{<:MH}) = "MH" -```julia -MH(:m) -MH((:m, x -> Normal(x, 0.1))) -``` +################# +# MH Transition # +################# -Example: +struct MHTransition{T, F<:AbstractFloat, M<:AMH.Transition} + θ :: T + lp :: F + mh_trans :: M +end -```julia -# Define a simple Normal model with unknown mean and variance. -@model gdemo(x) = begin - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x[1] ~ Normal(m, sqrt(s)) - x[2] ~ Normal(m, sqrt(s)) +function MHTransition(spl::Sampler{<:MH}, mh_trans::AMH.Transition) + theta = tonamedtuple(spl.state.vi) + return MHTransition(theta, mh_trans.lp, mh_trans) end -chn = sample(gdemo([1.5, 2]), MH(), 1000) -``` +##################### +# Utility functions # +##################### + """ -mutable struct MH{space} <: InferenceAlgorithm - proposals :: Dict{Symbol,Any} # Proposals for paramters + set_namedtuple!(vi::VarInfo, nt::NamedTuple) + +Places the values of a `NamedTuple` into the relevant places of a `VarInfo`. +""" +function set_namedtuple!(vi::VarInfo, nt::NamedTuple) + for (n, vals) in pairs(nt) + vns = vi.metadata[n].vns + + n_vns = length(vns) + n_vals = length(vals) + v_isarr = vals isa AbstractArray + + if v_isarr && n_vals == 1 && n_vns > 1 + for (vn, val) in zip(vns, vals[1]) + vi[vn] = val isa AbstractArray ? val : [val] + end + elseif v_isarr && n_vals > 1 && n_vns == 1 + vi[vns[1]] = vals + elseif v_isarr && n_vals == 1 && n_vns == 1 + if vals[1] isa AbstractArray + vi[vns[1]] = vals[1] + else + vi[vns[1]] = [vals[1]] + end + elseif !(v_isarr) + vi[vns[1]] = [vals] + else + error("Cannot assign `NamedTuple` to `VarInfo`") + end + end end -function MH(proposals::Dict{Symbol, Any}, space::Tuple) - return MH{space}(proposals) +""" + gen_logπ_mh(vi::VarInfo, spl::Sampler, model) + +Generate a log density function -- this variant uses the +`set_namedtuple!` function to update the `VarInfo`. +""" +function gen_logπ_mh(spl::Sampler, model) + function logπ(x)::Float64 + vi = spl.state.vi + x_old, lj_old = vi[spl], getlogp(vi) + # vi[spl] = x + set_namedtuple!(vi, x) + runmodel!(model, vi) + lj = getlogp(vi) + vi[spl] = x_old + setlogp!(vi, lj_old) + return lj + end + return logπ end -function MH(space...) 
- new_space = () - proposals = Dict{Symbol,Any}() +""" + dist_val_tuple(spl::Sampler{<:MH}) + +Returns two `NamedTuples`. The first `NamedTuple` has symbols as keys and distributions as values. +The second `NamedTuple` has model symbols as keys and their stored values as values. +""" +function dist_val_tuple(spl::Sampler{<:MH}) + vns = _getvns(spl.state.vi, spl) + dt = _dist_tuple(spl.state.vi.metadata, spl.alg.proposals, spl.state.vi, vns) + vt = _val_tuple(spl.state.vi.metadata, spl.state.vi, vns) + return dt, vt +end - # parse random variables with their hypothetical proposal - for element in space - if isa(element, Symbol) - new_space = (new_space..., element) +@generated function _val_tuple(metadata::NamedTuple, vi::VarInfo, vns::NamedTuple{names}) where {names} + length(names) === 0 && return :(NamedTuple()) + expr = Expr(:tuple) + map(names) do f + push!(expr.args, Expr(:(=), f, :( + length(metadata.$f.vns) == 1 ? getindex(vi, metadata.$f.vns)[1] : getindex.(Ref(vi), metadata.$f.vns) + ))) + end + return expr +end + +@generated function _dist_tuple( + metadata::NamedTuple, + props::NamedTuple{propnames}, + vi::VarInfo, + vns::NamedTuple{names} +) where {names, propnames} + length(names) === 0 && return :(NamedTuple()) + expr = Expr(:tuple) + map(names) do f + if f in propnames + # We've been given a custom proposal, use that instead. + push!(expr.args, Expr(:(=), f, :(props.$f))) else - @assert isa(element[1], Symbol) "[MH] ($element[1]) should be a Symbol. For proposal, use the syntax MH((:m, x -> Normal(x, 0.1)))" - new_space = (new_space..., element[1]) - proposals[element[1]] = element[2] + # Otherwise, use the default proposal. + push!(expr.args, Expr(:(=), f, :(AMH.Proposal(AMH.Static(), metadata.$f.dists[1])))) end end - return MH(proposals, new_space) + return expr +end + +################# +# Sampler state # +################# + +mutable struct MHState{V<:VarInfo} <: AbstractSamplerState + vi :: V + density_model :: AMH.DensityModel end -function Sampler(alg::MH, model::Model, s::Selector) - alg_str = "MH" +############################### +# Static MH (from prior only) # +############################### +function Sampler( + alg::MH, + model::Model, + s::Selector=Selector() +) + # Set up info dict. info = Dict{Symbol, Any}() - state = MHState(model) - return Sampler(alg, info, s, state) -end + # Make a varinfo. + vi = VarInfo(model) + + # Make a density model. + dm = AMH.DensityModel(x -> 0.0) + + # Set up state struct. + state = MHState(vi, dm) + + # Generate a sampler. + spl = Sampler(alg, info, s, state) -function propose(model, spl::Sampler{<:MH}, vi::VarInfo) - spl.state.proposal_ratio = 0.0 - spl.state.prior_prob = 0.0 - spl.state.violating_support = false - return runmodel!(model, spl.state.vi, spl) + # Update the density model. + spl.state.density_model = AMH.DensityModel(gen_logπ_mh(spl, model)) + + return spl end -# First step always returns a value. -function AbstractMCMC.step!( - ::AbstractRNG, +function AbstractMCMC.sample_init!( + rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, - ::Integer, - ::Nothing; + N::Integer; + verbose::Bool=true, + resume_from=nothing, kwargs... ) - return Transition(spl) + # Resume the sampler. + set_resume!(spl; resume_from=resume_from, kwargs...) + + # Get `init_theta` + initialize_parameters!(spl; verbose=verbose, kwargs...) end -# Every step after the first. function AbstractMCMC.step!( - ::AbstractRNG, + rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, - ::Integer, + N::Integer, transition; kwargs... 
) if spl.selector.rerun # Recompute joint in logp runmodel!(model, spl.state.vi) end - old_θ = copy(spl.state.vi[spl]) - old_logp = getlogp(spl.state.vi) - Turing.DEBUG && @debug "Propose new parameters from proposals..." - propose(model, spl, spl.state.vi) + # Retrieve distribution and value NamedTuples. + dt, vt = dist_val_tuple(spl) - Turing.DEBUG && @debug "Decide whether to accept..." - accepted = !spl.state.violating_support && mh_accept(old_logp, getlogp(spl.state.vi), spl.state.proposal_ratio) + # Create a sampler and the previous transition. + mh_sampler = AMH.MetropolisHastings(dt) + prev_trans = AMH.Transition(vt, getlogp(spl.state.vi)) - # reset Θ and logp if the proposal is rejected - if !accepted - spl.state.vi[spl] = old_θ - setlogp!(spl.state.vi, old_logp) - end + # Make a new transition. + trans = AbstractMCMC.step!(rng, spl.state.density_model, mh_sampler, 1, prev_trans) + + # Update the values in the VarInfo. + set_namedtuple!(spl.state.vi, trans.params) + setlogp!(spl.state.vi, trans.lp) return Transition(spl) end -function assume(spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi::VarInfo) - if vn in getspace(spl) - if ~haskey(vi, vn) error("[MH] does not handle stochastic existence yet") end - old_val = vi[vn] - sym = getsym(vn) - - if sym in keys(spl.alg.proposals) # Custom proposal for this parameter - proposal = spl.alg.proposals[sym](old_val) - if proposal isa Distributions.Normal # If Gaussian proposal - σ = std(proposal) - lb = support(dist).lb - ub = support(dist).ub - stdG = Normal() - r = rand(truncated(Normal(proposal.μ, proposal.σ), lb, ub)) - # cf http://fsaad.scripts.mit.edu/randomseed/metropolis-hastings-sampling-with-gaussian-drift-proposal-on-bounded-support/ - spl.state.proposal_ratio += log(cdf(stdG, (ub-old_val)/σ) - cdf(stdG,(lb-old_val)/σ)) - spl.state.proposal_ratio -= log(cdf(stdG, (ub-r)/σ) - cdf(stdG,(lb-r)/σ)) - else # Other than Gaussian proposal - r = rand(proposal) - if !(insupport(dist, r)) # check if value lies in support - spl.state.violating_support = true - r = old_val - end - spl.state.proposal_ratio -= logpdf(proposal, r) # accumulate pdf of proposal - reverse_proposal = spl.alg.proposals[sym](r) - spl.state.proposal_ratio += logpdf(reverse_proposal, old_val) - end - - else # Prior as proposal - r = rand(dist) - spl.state.proposal_ratio += (logpdf(dist, old_val) - logpdf(dist, r)) - end - - spl.state.prior_prob += logpdf(dist, r) # accumulate prior for PMMH - vi[vn] = vectorize(dist, r) - setgid!(vi, spl.selector, vn) - else - r = vi[vn] - end +#### +#### Compiler interface, i.e. tilde operators. 
+#### +function assume( + spl::Sampler{<:MH}, + dist::Distribution, + vn::VarName, + vi::VarInfo +) + updategid!(vi, vn, spl) + r = vi[vn] + return r, logpdf_with_trans(dist, r, istrans(vi, vn)) +end - # acclogp!(vi, logpdf(dist, r)) # accumulate pdf of prior - r, logpdf(dist, r) +function dot_assume( + spl::Sampler{<:MH}, + dist::MultivariateDistribution, + vn::VarName, + var::AbstractMatrix, + vi::VarInfo, +) + @assert dim(dist) == size(var, 1) + getvn = i -> VarName(vn, vn.indexing * "[:,$i]") + vns = getvn.(1:size(var, 2)) + updategid!.(Ref(vi), vns, Ref(spl)) + r = vi[vns] + var .= r + return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1]))) +end +function dot_assume( + spl::Sampler{<:MH}, + dists::Union{Distribution, AbstractArray{<:Distribution}}, + vn::VarName, + var::AbstractArray, + vi::VarInfo, +) + getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]") + vns = getvn.(CartesianIndices(var)) + updategid!.(Ref(vi), vns, Ref(spl)) + r = reshape(vi[vec(vns)], size(var)) + var .= r + return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) end -function observe(spl::Sampler{<:MH}, d::Distribution, value, vi::VarInfo) - return observe(SampleFromPrior(), d, value, vi) # accumulate pdf of likelihood +function observe( + spl::Sampler{<:MH}, + d::Distribution, + value, + vi::VarInfo, +) + return observe(SampleFromPrior(), d, value, vi) end function dot_observe( spl::Sampler{<:MH}, - ds, - value, + ds::Union{Distribution, AbstractArray{<:Distribution}}, + value::AbstractArray, vi::VarInfo, ) - return dot_observe(SampleFromPrior(), ds, value, vi) # accumulate pdf of likelihood + return dot_observe(SampleFromPrior(), ds, value, vi) end diff --git a/test/Turing/stdlib/RandomMeasures.jl b/test/Turing/stdlib/RandomMeasures.jl index dd5c5f0f6..c11e39680 100644 --- a/test/Turing/stdlib/RandomMeasures.jl +++ b/test/Turing/stdlib/RandomMeasures.jl @@ -82,7 +82,7 @@ function rand(rng::AbstractRNG, d::ChineseRestaurantProcess) end minimum(d::ChineseRestaurantProcess) = 1 -maximum(d::ChineseRestaurantProcess) = length(d.m) + 1 +maximum(d::ChineseRestaurantProcess) = any(iszero, d.m) ? length(d.m) : length(d.m)+1 ## ################# ## ## Random partitions ## @@ -128,41 +128,28 @@ function distribution(d::SizeBiasedSamplingProcess{<:DirichletProcess}) return LocationScale(zero(α), d.surplus, Beta(one(α), α)) end -function _logpdf_table(d::DirichletProcess, m::AbstractVector{Int}) - # compute the sum of all cluster counts - sum_m = sum(m) - - # shortcut if all cluster counts are zero - dα = d.α - T = typeof(dα) - iszero(sum_m) && return zeros(T, 1) - - # pre-calculations - z = log(sum_m - 1 + dα) +function _logpdf_table(d::DirichletProcess{T}, m::AbstractVector{Int}) where {T<:Real} # construct the table - K = length(m) - table = Vector{T}(undef, K) - contains_zero = false - @inbounds for i in 1:K - mi = m[i] - - if iszero(mi) - if contains_zero - table[i] = -Inf - else - table[i] = log(dα) - z - contains_zero = true - end - else - table[i] = log(mi) - z - end + first_zero = findfirst(iszero, m) + K = first_zero === nothing ? length(m)+1 : length(m) + table = fill(T(-Inf), K) + + # exit if m is empty or contains only zeros + if iszero(m) + table[1] = T(0) + return table end - if !contains_zero - push!(table, log(dα) - z) + # compute logpdf for each occupied table + @inbounds for i in 1:(K-1) + table[i] = T(log(m[i])) end + # logpdf for new table + k_new = first_zero === nothing ? 
K : first_zero + table[k_new] = log(d.α) + return table end @@ -211,42 +198,30 @@ function distribution(d::SizeBiasedSamplingProcess{<:PitmanYorProcess}) return LocationScale(zero(d_rpm_d), d.surplus, dist) end -function _logpdf_table(d::PitmanYorProcess, m::AbstractVector{Int}) - # compute the sum of all cluster counts - sum_m = sum(m) +function _logpdf_table(d::PitmanYorProcess{T}, m::AbstractVector{Int}) where {T<:Real} + # sanity check + @assert d.t == sum(!iszero, m) - # shortcut if all cluster counts are zero - dd = d.d - T = typeof(dd) - iszero(sum_m) && return zeros(T, 1) + # construct table + first_zero = findfirst(iszero, m) + K = first_zero === nothing ? length(m)+1 : length(m) + table = fill(T(-Inf), K) - # pre-calculations - dθ = d.θ - z = log(sum_m + dθ) - - # construct the table - K = length(m) - table = Vector{T}(undef, K) - contains_zero = false - @inbounds for i in 1:K - mi = m[i] - - if iszero(mi) - if contains_zero - table[i] = -Inf - else - table[i] = log(dθ + dd * d.t) - z - contains_zero = true - end - else - table[i] = log(mi - dd) - z - end + # exit if m is empty or contains only zeros + if iszero(m) + table[1] = T(0) + return table end - if !contains_zero - push!(table, log(dθ + dd * d.t) - z) + # compute logpdf for each occupied table + @inbounds for i in 1:(K-1) + !iszero(m[i]) && ( table[i] = T(log(m[i] - d.d)) ) end + # logpdf for new table + k_new = first_zero === nothing ? K : first_zero + table[k_new] = log(d.θ + d.d * d.t) + return table end diff --git a/test/Turing/stdlib/distributions.jl b/test/Turing/stdlib/distributions.jl index b292b87c0..8407de140 100644 --- a/test/Turing/stdlib/distributions.jl +++ b/test/Turing/stdlib/distributions.jl @@ -103,6 +103,8 @@ function Distributions.logpdf(d::OrderedLogistic, k::Int) return logp end +Distributions.pdf(d::OrderedLogistic, k::Int) = exp(logpdf(d,k)) + function Distributions.rand(rng::AbstractRNG, d::OrderedLogistic) cutpoints = d.cutpoints η = d.η diff --git a/test/Turing/variational/VariationalInference.jl b/test/Turing/variational/VariationalInference.jl index b9e0ff6eb..ef23df198 100644 --- a/test/Turing/variational/VariationalInference.jl +++ b/test/Turing/variational/VariationalInference.jl @@ -2,10 +2,9 @@ module Variational using ..Core, ..Utilities using Distributions, Bijectors, DynamicPPL -using ProgressMeter, LinearAlgebra -using ..Turing: PROGRESS +using LinearAlgebra +using ..Turing: PROGRESS, Turing using DynamicPPL: Model, SampleFromPrior, SampleFromUniform -using ..Turing: Turing using Random: AbstractRNG using ForwardDiff @@ -13,16 +12,27 @@ using Tracker import ..Core: getchunksize, getADtype +import ProgressLogging + +import Logging +import UUIDs + using Requires function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) + @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin + apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) + Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) + Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) + end end export vi, ADVI, ELBO, - TruncatedADAGrad + elbo, + TruncatedADAGrad, + DecayedADAGrad abstract type VariationalInference{AD} end @@ -33,29 +43,9 @@ abstract type VariationalObjective end const VariationalPosterior = Distribution{Multivariate, Continuous} -""" - rand(vi::VariationalInference, num_samples) - -Produces `num_samples` samples for the given VI method using number of samples equal -to `num_samples`. 
-""" -function rand(vi::VariationalPosterior, num_samples) end - -""" - objective(vi::VariationalInference, q::VariationalPosterior, model::Model, args...) - -Computes the variational objective to be optimized for a given VI method. -""" -function objective( - vi::VariationalInference, - q::VariationalPosterior, - model::Model, - num_samples) -end - """ - grad!(vo, vi::VariationalInference, q::VariationalPosterior, model::Model, θ, out, args...) + grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) Computes the gradients used in `optimize!`. Default implementation is provided for `VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. @@ -63,47 +53,40 @@ This implicitly also gives a default implementation of `optimize!`. Variance reduction techniques, e.g. control variates, should be implemented in this function. """ -function grad!( - vo, vi::VariationalInference, - q::VariationalPosterior, - model::Model, - θ, - out, - args... -) - error("Turing.Variational.grad!: unmanaged variational inference algorithm: " - * "$(typeof(alg))") -end +function grad! end """ - vi(model::Model, alg::VariationalInference) - vi(model::Model, alg::VariationalInference, q::VariationalPosterior) + vi(model, alg::VariationalInference) + vi(model, alg::VariationalInference, q::VariationalPosterior) + vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) Constructs the variational posterior from the `model` and performs the optimization following the configuration of the given `VariationalInference` instance. + +# Arguments +- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations +- `alg`: the VI algorithm used +- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. +- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` +- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior """ -function vi(model::Model, alg::VariationalInference) - error("Turing.Variational.vi: variational inference algorithm $(typeof(alg)) " - * "is not implemented") -end -function vi(model::Model, alg::VariationalInference, q::VariationalPosterior) - error("Turing.Variational.vi: variational inference algorithm $(typeof(alg)) " - * "is not implemented") -end +function vi end # default implementations function grad!( vo, alg::VariationalInference{<:ForwardDiffAD}, - q::VariationalPosterior, - model::Model, + q, + model, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, args... ) - # TODO: this probably slows down executation quite a bit; exists a better way - # of doing this? - f(θ_) = - vo(alg, q, model, θ_, args...) + f(θ_) = if (q isa VariationalPosterior) + - vo(alg, update(q, θ_), model, args...) + else + - vo(alg, q(θ_), model, args...) + end chunk_size = getchunksize(typeof(alg)) # Set chunk size and do ForwardMode. @@ -115,22 +98,26 @@ end function grad!( vo, alg::VariationalInference{<:TrackerAD}, - q::VariationalPosterior, - model::Model, + q, + model, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, args... ) - θ_tracked = [Tracker.param(θ[i]) for i ∈ eachindex(θ)] - y = - vo(alg, q, model, θ_tracked, args...) + θ_tracked = Tracker.param(θ) + y = if (q isa VariationalPosterior) + - vo(alg, update(q, θ_tracked), model, args...) + else + - vo(alg, q(θ_tracked), model, args...) 
+ end Tracker.back!(y, 1.0) DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, [Tracker.grad(θ_tracked[i]) for i ∈ eachindex(θ_tracked)]) + DiffResults.gradient!(out, Tracker.grad(θ_tracked)) end """ - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) + optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model, θ; optimizer = TruncatedADAGrad()) Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute the steps. @@ -138,57 +125,97 @@ the steps. function optimize!( vo, alg::VariationalInference, - q::VariationalPosterior, - model::Model, - θ; - optimizer = TruncatedADAGrad() + q, + model, + θ::AbstractVector{<:Real}; + optimizer = TruncatedADAGrad(), + progress = Turing.PROGRESS[], + progressname = "[$(alg_str(alg))] Optimizing..." ) # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) samples_per_step = alg.samples_per_step max_iters = alg.max_iters - # number of previous gradients to use to compute `s` in adaGrad - stepsize_num_prev = 10 - - num_params = length(q) - # TODO: really need a better way to warn the user about potentially # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) + if optimizer isa TruncatedADAGrad && θ ∉ keys(optimizer.acc) # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) + @info "[$(alg_str(alg))] Should only be seen once: optimizer created for θ" objectid(θ) end - + diff_result = DiffResults.GradientResult(θ) - i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 + # Create the progress bar. + if progress + progressid = UUIDs.uuid4() + Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN, + _id=progressid) end - # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) + try + # add criterion? A running mean maybe? + for i in 1:max_iters + grad!(vo, alg, q, model, θ, diff_result, samples_per_step) + + # apply update rule + Δ = DiffResults.gradient(diff_result) + Δ = apply!(optimizer, θ, Δ) + @. θ = θ - Δ + + Turing.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) + + # Update the progress bar. + if progress + Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, + progress=i/max_iters, _id=progressid) + end + end + finally + if progress + Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress="done", + _id=progressid) + end + end - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ - - Turing.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) + return θ +end - i += 1 +""" + make_logjoint(model::Model; weight = 1.0) + +Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). + +The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to +use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`. + +## Notes +- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. 
This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. +""" +function make_logjoint(model::Model; weight = 1.0) + # setup + ctx = DynamicPPL.MiniBatchContext( + DynamicPPL.DefaultContext(), + weight + ) + varinfo_init = Turing.VarInfo(model, ctx) + + function logπ(z) + varinfo = VarInfo(varinfo_init, SampleFromUniform(), z) + model(varinfo) + + return getlogp(varinfo) end - return θ + return logπ +end + +function logjoint(model::Model, varinfo, z) + varinfo = VarInfo(varinfo, SampleFromUniform(), z) + model(varinfo) + + return getlogp(varinfo) end -# distributions -include("distributions.jl") # objectives include("objectives.jl") diff --git a/test/Turing/variational/advi.jl b/test/Turing/variational/advi.jl index 8ab364ee5..70e05afda 100644 --- a/test/Turing/variational/advi.jl +++ b/test/Turing/variational/advi.jl @@ -1,26 +1,78 @@ +using StatsFuns +using DistributionsAD +using Bijectors +using Bijectors: TransformedDistribution +using Random: AbstractRNG, GLOBAL_RNG +import Bijectors: bijector + +update(d::TuringDiagMvNormal, μ, σ) = TuringDiagMvNormal(μ, σ) +update(td::TransformedDistribution, θ...) = transformed(update(td.dist, θ...), td.transform) +function update(td::TransformedDistribution{<:TuringDiagMvNormal}, θ::AbstractArray) + μ, ω = θ[1:length(td)], θ[length(td) + 1:end] + return update(td, μ, softplus.(ω)) +end + +# TODO: add these to DistributionsAD.jl and remove from here +Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ) + """ - ADVI(samples_per_step = 10, max_iters = 5000) + bijector(model::Model; sym_to_ranges = Val(false)) -Automatic Differentiation Variational Inference (ADVI) for a given model. +Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` +denoting the dimensionality of the latent variables. """ -struct ADVI{AD} <: VariationalInference{AD} - samples_per_step # number of samples used to estimate the ELBO in each optimization step - max_iters # maximum number of gradient steps used in optimization -end +function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) where {sym2ranges} + varinfo = Turing.VarInfo(model) + num_params = sum([size(varinfo.metadata[sym].vals, 1) + for sym ∈ keys(varinfo.metadata)]) -ADVI(args...) = ADVI{ADBackend()}(args...) -ADVI() = ADVI(10, 5000) + dists = vcat([varinfo.metadata[sym].dists for sym ∈ keys(varinfo.metadata)]...) -alg_str(::ADVI) = "ADVI" + num_ranges = sum([length(varinfo.metadata[sym].ranges) + for sym ∈ keys(varinfo.metadata)]) + ranges = Vector{UnitRange{Int}}(undef, num_ranges) + idx = 0 + range_idx = 1 -function vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad()) + # ranges might be discontinuous => values are vectors of ranges rather than just ranges + sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}() + for sym ∈ keys(varinfo.metadata) + sym_lookup[sym] = Vector{UnitRange{Int}}() + for r ∈ varinfo.metadata[sym].ranges + ranges[range_idx] = idx .+ r + push!(sym_lookup[sym], ranges[range_idx]) + range_idx += 1 + end + + idx += varinfo.metadata[sym].ranges[end][end] + end + + bs = inv.(bijector.(tuple(dists...))) + + if sym2ranges + return Stacked(bs, ranges), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...) 
+ else + return Stacked(bs, ranges) + end +end + +""" + meanfield(model::Model) + meanfield(rng::AbstractRNG, model::Model) + +Creates a mean-field approximation with multivariate normal as underlying distribution. +""" +meanfield(model::Model) = meanfield(GLOBAL_RNG, model) +function meanfield(rng::AbstractRNG, model::Model) # setup varinfo = Turing.VarInfo(model) - num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata)]) + num_params = sum([size(varinfo.metadata[sym].vals, 1) + for sym ∈ keys(varinfo.metadata)]) dists = vcat([varinfo.metadata[sym].dists for sym ∈ keys(varinfo.metadata)]...) - num_ranges = sum([length(varinfo.metadata[sym].ranges) for sym ∈ keys(varinfo.metadata)]) + num_ranges = sum([length(varinfo.metadata[sym].ranges) + for sym ∈ keys(varinfo.metadata)]) ranges = Vector{UnitRange{Int}}(undef, num_ranges) idx = 0 range_idx = 1 @@ -34,101 +86,125 @@ function vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad()) idx += varinfo.metadata[sym].ranges[end][end] end - q = Variational.MeanField(zeros(num_params), zeros(num_params), dists, ranges) - - # construct objective - elbo = ELBO() + # initial params + μ = randn(rng, num_params) + σ = softplus.(randn(rng, num_params)) - Turing.DEBUG && @debug "Optimizing ADVI..." - θ = optimize(elbo, alg, q, model; optimizer = optimizer) - μ, ω = θ[1:length(q)], θ[length(q) + 1:end] + # construct variational posterior + d = TuringDiagMvNormal(μ, σ) + bs = inv.(bijector.(tuple(dists...))) + b = Stacked(bs, ranges) + + return transformed(d, b) +end + +""" + ADVI(samples_per_step = 1, max_iters = 1000) - return MeanField(μ, ω, dists, ranges) +Automatic Differentiation Variational Inference (ADVI) for a given model. +""" +struct ADVI{AD} <: VariationalInference{AD} + samples_per_step # number of samples used to estimate the ELBO in each optimization step + max_iters # maximum number of gradient steps used in optimization end -# TODO: implement optimize like this? -# (advi::ADVI)(elbo::EBLO, q::MeanField, model::Model) = begin -# end +ADVI(args...) = ADVI{ADBackend()}(args...) +ADVI() = ADVI(1, 1000) + +alg_str(::ADVI) = "ADVI" -function optimize(elbo::ELBO, alg::ADVI, q::MeanField, model::Model; optimizer = TruncatedADAGrad()) - θ = randn(2 * length(q)) + +function vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad()) + q = meanfield(model) + return vi(model, alg, q; optimizer = optimizer) +end + +function vi(model, alg::ADVI, q::TransformedDistribution{<:TuringDiagMvNormal}; optimizer = TruncatedADAGrad()) + Turing.DEBUG && @debug "Optimizing ADVI..." + # Initial parameters for mean-field approx + μ, σs = params(q) + θ = vcat(μ, invsoftplus.(σs)) + + # Optimize + optimize!(elbo, alg, q, model, θ; optimizer = optimizer) + + # Return updated `Distribution` + return update(q, θ) +end + +function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) + Turing.DEBUG && @debug "Optimizing ADVI..." 
+ θ = copy(θ_init) optimize!(elbo, alg, q, model, θ; optimizer = optimizer) + # If `q` is a mean-field approx we use the specialized `update` function + if q isa Distribution + return update(q, θ) + else + # Otherwise we assume it's a mapping θ → q + return q(θ) + end +end + + +function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) + θ = copy(θ_init) + + if model isa Model + optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer) + else + # `model` assumed to be callable z ↦ p(x, z) + optimize!(elbo, alg, q, model, θ; optimizer = optimizer) + end + return θ end +# WITHOUT updating parameters inside ELBO function (elbo::ELBO)( + rng::AbstractRNG, alg::ADVI, - q::MeanField, - model::Model, - θ::AbstractVector{<:Real}, + q::VariationalPosterior, + logπ::Function, num_samples ) - # setup - varinfo = Turing.VarInfo(model) - - T = eltype(θ) - num_params = length(q) - μ, ω = θ[1:num_params], θ[num_params + 1: end] + # 𝔼_q(z)[log p(xᵢ, z)] + # = ∫ log p(xᵢ, z) q(z) dz + # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) + # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) + # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + + # 𝔼_q(z)[log q(z)] + # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) + # = 𝔼_q̃(ϕ) [log q(f(ϕ))] + # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] + # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + + # Finally, the ELBO is given by + # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] + # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) + + # If f: supp(p(z | x)) → ℝ then + # ELBO = 𝔼[log p(x, z) - log q(z)] + # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) + # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) + + # But our `forward(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` + _, z, logjac, _ = forward(rng, q) + res = (logπ(z) + logjac) / num_samples + + if q isa TransformedDistribution + res += entropy(q.dist) + else + res += entropy(q) + end - elbo_acc = 0.0 - - # TODO: instead use `rand(q, num_samples)` and iterate through? 
- # Requires new interface for Bijectors.jl - - for i = 1:num_samples - # iterate through priors, sample and update - idx = 0 - z = zeros(T, num_params) - - for sym ∈ keys(varinfo.metadata) - md = varinfo.metadata[sym] - - for i = 1:size(md.dists, 1) - prior = md.dists[i] - r = md.ranges[i] .+ idx - - # mean-field params for this set of model params - μ_i = μ[r] - ω_i = ω[r] - - # obtain samples from mean-field posterior approximation - η = randn(length(μ_i)) - ζ = center_diag_gaussian_inv(η, μ_i, exp.(ω_i)) - - # inverse-transform back to domain of original priro - z[r] .= invlink(prior, ζ) - - # update - # @info θ - # z[md.ranges[i]] .= θ - # @info md.vals - - # add the log-det-jacobian of inverse transform; - # `logabsdet` returns `(log(abs(det(M))), sign(det(M)))` so return first entry - # add `eps` to ensure SingularException does not occurr in `logabsdet` - elbo_acc += logabsdet(jac_inv_transform(prior, ζ) .+ eps(T))[1] / num_samples - end - - idx += md.ranges[end][end] - end - - # compute log density - varinfo = VarInfo(varinfo, SampleFromUniform(), z) - model(varinfo) - elbo_acc += getlogp(varinfo) / num_samples + for i = 2:num_samples + _, z, logjac, _ = forward(rng, q) + res += (logπ(z) + logjac) / num_samples end - # add the term for the entropy of the variational posterior - variational_posterior_entropy = sum(ω) - elbo_acc += variational_posterior_entropy - - return elbo_acc + return res end -function (elbo::ELBO)(alg::ADVI, q::MeanField, model::Model, num_samples) - # extract the mean-field Gaussian params - θ = vcat(q.μ, q.ω) - - return elbo(alg, q, model, θ, num_samples) -end diff --git a/test/Turing/variational/distributions.jl b/test/Turing/variational/distributions.jl deleted file mode 100644 index d146e8699..000000000 --- a/test/Turing/variational/distributions.jl +++ /dev/null @@ -1,57 +0,0 @@ -import Distributions: _rand! - - -function jac_inv_transform(dist::Distribution, x::Real) - ForwardDiff.derivative(x -> invlink(dist, x), x) -end - -function jac_inv_transform(dist::Distribution, x::AbstractArray{<:Real}) - ForwardDiff.jacobian(x -> invlink(dist, x), x) -end - -function jac_inv_transform(dist::Distribution, x::TrackedArray{<:Real}) - Tracker.jacobian(x -> invlink(dist, x), x) -end - -# instead of creating a diagonal matrix, we just do elementwise multiplication -center_diag_gaussian(x, μ, σ) = (x .- μ) ./ σ -center_diag_gaussian_inv(η, μ, σ) = (η .* σ) .+ μ - - -# Mean-field approximation used by ADVI -struct MeanField{TDists, V} <: VariationalPosterior where {V <: AbstractVector{<: Real}, TDists <: AbstractVector{<: Distribution}} - μ::V - ω::V - dists::TDists - ranges::Vector{UnitRange{Int}} -end - -Base.length(advi::MeanField) = length(advi.μ) - -function _rand!( - rng::AbstractRNG, - q::MeanField, - x::AbstractVector{<:Real} -) - # extract parameters for convenience - μ, ω = q.μ, q.ω - num_params = length(q) - - for i = 1:size(q.dists, 1) - prior = q.dists[i] - r = q.ranges[i] - - # initials - μ_i = μ[r] - ω_i = ω[r] - - # # sample from VI posterior - η = randn(rng, length(μ_i)) - ζ = center_diag_gaussian_inv(η, μ_i, exp.(ω_i)) - θ = invlink(prior, ζ) - - x[r] = θ - end - - return x -end diff --git a/test/Turing/variational/objectives.jl b/test/Turing/variational/objectives.jl index 58776e164..794d1c9ef 100644 --- a/test/Turing/variational/objectives.jl +++ b/test/Turing/variational/objectives.jl @@ -1 +1,21 @@ +using Random: GLOBAL_RNG + struct ELBO <: VariationalObjective end + +function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) 
+ return elbo(GLOBAL_RNG, alg, q, logπ, num_samples; kwargs...) +end + +function (elbo::ELBO)( + rng::AbstractRNG, + alg::VariationalInference, + q, + model::Model, + num_samples; + weight = 1.0, + kwargs... +) + return elbo(rng, alg, q, make_logjoint(model; weight = weight), num_samples; kwargs...) +end + +const elbo = ELBO() diff --git a/test/Turing/variational/optimisers.jl b/test/Turing/variational/optimisers.jl index 8785bfedb..8bbc5b4a4 100644 --- a/test/Turing/variational/optimisers.jl +++ b/test/Turing/variational/optimisers.jl @@ -1,5 +1,21 @@ const ϵ = 1e-8 +""" + TruncatedADAGrad(η=0.1, τ=1.0, n=100) + +Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. + +## Parameters + - η: learning rate + - τ: constant scale factor + - n: number of previous gradient norms to use in the scaling. + +## References +[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. +Parameters don't need tuning. + +[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). +""" mutable struct TruncatedADAGrad eta::Float64 tau::Float64 @@ -14,14 +30,16 @@ function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) end function apply!(o::TruncatedADAGrad, x, Δ) + T = eltype(Tracker.data(Δ)) + η = o.eta τ = o.tau g² = get!( o.acc, x, - [fill(0.0, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(x)), 1} + [zeros(T, size(x)) for j = 1:o.n] + )::Array{typeof(Tracker.data(Δ)), 1} i = get!(o.iters, x, 1)::Int # Example: suppose i = 12 and o.n = 10 @@ -42,3 +60,35 @@ function apply!(o::TruncatedADAGrad, x, Δ) @. Δ *= η / (τ + sqrt(s) + ϵ) end +""" + DecayedADAGrad(η=0.1, pre=1.0, post=0.9) + +Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. + +## Parameters + - η: learning rate + - pre: weight of new gradient norm + - post: weight of histroy of gradient norms + +## References +[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. +Parameters don't need tuning. +""" +mutable struct DecayedADAGrad + eta::Float64 + pre::Float64 + post::Float64 + + acc::IdDict +end + +DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) + +function apply!(o::DecayedADAGrad, x, Δ) + T = eltype(Tracker.data(Δ)) + + η = o.eta + acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) + @. acc = o.post * acc + o.pre * Δ^2 + @. 
Δ *= η / (√acc + ϵ) +end diff --git a/test/prob_macro.jl b/test/prob_macro.jl index d1cbb8e6a..4c89e1dae 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -35,7 +35,7 @@ Random.seed!(129) varinfo = VarInfo(demo(missing)) @test logprob"x = xval, m = mval | model = demo, varinfo = varinfo" == logjoint - chain = sample(demo(xval), IS(), iters) + chain = sample(demo(xval), IS(), iters; save_state = true) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) lps = logpdf.(Normal.(vec(chain["m"].value), 1), xval) @test logprob"x = xval | chain = chain" == lps @@ -70,7 +70,7 @@ Random.seed!(129) @test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike # Currently, we cannot easily pre-allocate `VarInfo` for vector data - chain = sample(demo(xval), HMC(0.5, 1), iters) + chain = sample(demo(xval), HMC(0.5, 1), iters; save_state = true) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) lps = like.([[chain["m[$i]"].value[j] for i in 1:n] for j in 1:iters], Ref(xval)) @test logprob"x = xval | chain = chain" == lps
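# Note on the prob_macro test changes above (a sketch of why they are needed, under the assumption
# that the `logprob` string macro reads the model and `VarInfo` that `save` stores in `chain.info`):
# with `save_state` now defaulting to `false` in `bundle_samples`, tests that later evaluate
# probabilities against a chain have to opt back in explicitly, e.g.
#
#     chain = sample(demo(xval), IS(), iters; save_state = true)
#     logprob"x = xval | chain = chain"   # relies on the model/VarInfo saved in chain.info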