Update to AbstractMCMC 2 (#150)
Co-authored-by: Tor Erlend Fjelde <[email protected]>
devmotion and torfjelde authored Nov 26, 2020
1 parent cf95183 commit 6ac3922
Showing 23 changed files with 856 additions and 1,093 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.9.8"
version = "0.10.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -12,7 +12,7 @@ NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
AbstractMCMC = "1"
AbstractMCMC = "2"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
ChainRulesCore = "0.9.7"
Distributions = "0.23.8"
3 changes: 0 additions & 3 deletions src/DynamicPPL.jl
@@ -61,7 +61,6 @@ export AbstractVarInfo,
Sample,
init,
vectorize,
set_resume!,
# Model
Model,
getmissings,
@@ -122,6 +121,4 @@ include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")

include("deprecations.jl")

end # module
22 changes: 0 additions & 22 deletions src/deprecations.jl

This file was deleted.

6 changes: 0 additions & 6 deletions src/model.jl
@@ -124,9 +124,6 @@ See also: [`evaluate_threadsafe`](@ref)
"""
function evaluate_threadunsafe(rng, model, varinfo, sampler, context)
resetlogp!(varinfo)
if has_eval_num(sampler)
sampler.state.eval_num += 1
end
return _evaluate(rng, model, varinfo, sampler, context)
end

@@ -143,9 +140,6 @@ See also: [`evaluate_threadunsafe`](@ref)
"""
function evaluate_threadsafe(rng, model, varinfo, sampler, context)
resetlogp!(varinfo)
if has_eval_num(sampler)
sampler.state.eval_num += 1
end
wrapper = ThreadSafeVarInfo(varinfo)
result = _evaluate(rng, model, wrapper, sampler, context)
setlogp!(varinfo, getlogp(wrapper))
140 changes: 105 additions & 35 deletions src/sampler.jl
@@ -1,3 +1,5 @@
# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler`
# That would let us use all defaults for Sampler, combine it with other samplers etc.
"""
Robust initialization method for model parameters in Hamiltonian samplers.
"""
@@ -17,55 +19,123 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n)
end

"""
has_eval_num(spl::AbstractSampler)
Check whether `spl` has a field called `eval_num` in its state variables or not.
"""
has_eval_num(spl::SampleFromUniform) = false
has_eval_num(spl::SampleFromPrior) = false
has_eval_num(spl::AbstractSampler) = :eval_num in fieldnames(typeof(spl.state))

"""
An abstract type that mutable sampler state structs inherit from.
"""
abstract type AbstractSamplerState end

"""
Sampler{T}
Generic interface for implementing inference algorithms.
An implementation of an algorithm should include the following:
1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
2. A method of `sample` function that produces results of inference, which is where actual inference happens.
Generic sampler type for inference algorithms of type `T` in DynamicPPL.
DynamicPPL translates models to chunks that call the modelling functions at specified points.
The dispatch is based on the value of a `sampler` variable.
To include a new inference algorithm implements the requirements mentioned above in a separate file,
then include that file at the end of this one.
`Sampler` should implement the AbstractMCMC interface, and in particular
[`AbstractMCMC.step`](@ref). A default implementation of the initial sampling step is
provided that supports resuming sampling from a previous state and setting initial
parameter values. It requires overloading [`loadstate`](@ref) and [`initialstep`](@ref)
for loading previous states and actually performing the initial sampling step,
respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
that specifies how the initial parameter values are sampled if they are not provided.
By default, values are sampled from the prior.
"""
mutable struct Sampler{T, S<:AbstractSamplerState} <: AbstractSampler
alg :: T
info :: Dict{Symbol, Any} # sampler infomation
selector :: Selector
state :: S
struct Sampler{T} <: AbstractSampler
alg::T
selector::Selector # Can we remove it?
# TODO: add space such that we can integrate existing external samplers in DynamicPPL
end
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, model, s)
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

# AbstractMCMC interface for SampleFromUniform and SampleFromPrior

function AbstractMCMC.step!(
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
::Integer,
transition;
state = nothing;
kwargs...
)
vi = VarInfo()
model(vi, sampler)
return vi
model(rng, vi, sampler)
return vi, nothing
end
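
As an aside (not part of the diff): the new convention is that `step` returns a `(transition, state)` tuple. A minimal sketch of calling this method directly, where the `demo` model below is invented for illustration:

using Random
using AbstractMCMC, Distributions, DynamicPPL

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)

# One prior draw: the transition is the filled VarInfo, the state is `nothing`.
vi, state = AbstractMCMC.step(Random.GLOBAL_RNG, model, SampleFromPrior())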

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
resume_from = nothing,
kwargs...
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

# Sample initial values.
_spl = initialsampler(spl)
vi = VarInfo(rng, model, _spl)

# Update the parameters if provided.
if haskey(kwargs, :init_params)
initialize_parameters!(vi, kwargs[:init_params], spl)

# Update joint log probability.
model(rng, vi, _spl)
end

return initialstep(rng, model, spl, vi; kwargs...)
end
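
Downstream packages would not normally call this method directly; the keyword arguments are forwarded through `AbstractMCMC.sample`. A hedged sketch, with `model` and `spl` standing for an arbitrary model and `Sampler`, and a chain type whose package overloads `loadstate`:

# Fresh run with user-supplied initial values (flattened in variable order).
chain = sample(model, spl, 1_000; init_params = [0.1, -0.5])

# Continue from the final state of a previous run; requires a `loadstate`
# method for `typeof(chain)`.
chain2 = sample(model, spl, 1_000; resume_from = chain)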

"""
loadstate(data)
Load sampler state from `data`.
"""
function loadstate end

"""
initialsampler(sampler::Sampler)
Return the sampler that is used for generating the initial parameters when sampling with
`sampler`.
By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()
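
For example, a hypothetical Hamiltonian-style algorithm (`MyHMC` below is invented for illustration) that prefers transformed-uniform initial values could overload this as:

struct MyHMC end  # hypothetical algorithm type, not part of DynamicPPL

DynamicPPL.initialsampler(::Sampler{MyHMC}) = SampleFromUniform()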

function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler)
@debug "Using passed-in initial variable values" init_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
vec([x;])
end

# Get all values.
linked = islinked(vi, spl)
linked && invlink!(vi, spl)
theta = vi[spl]
length(theta) == length(init_theta) ||
error("Provided initial value doesn't match the dimension of the model")

# Update values that are provided.
for i in 1:length(init_theta)
x = init_theta[i]
if x !== missing
theta[i] = x
end
end

# Update in `vi`.
vi[spl] = theta
linked && link!(vi, spl)

return
end
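
As a quick illustration (not part of the diff) of the flattening step above, mixed scalar, array, and `missing` inputs are concatenated into a single vector, and `missing` entries later leave the corresponding randomly initialised values untouched:

init_params = [1.0, [2.0, 3.0], missing]
init_theta = mapreduce(vcat, init_params) do x
    vec([x;])
end
# -> Union{Missing, Float64}[1.0, 2.0, 3.0, missing]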

"""
initialstep(rng, model, sampler, varinfo; kwargs...)
Perform the initial sampling step of the `sampler` for the `model`.
The `varinfo` contains the initial samples, which can be provided by the user or
sampled randomly.
"""
function initialstep end
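
Putting the pieces together, and echoing the TODO at the top of this file, here is an illustrative sketch (not part of the commit) of what a trivial `Prior` algorithm could look like on top of this interface; the `Prior` type and both methods below are invented for the example:

using Random
using AbstractMCMC, DynamicPPL

struct Prior end  # hypothetical algorithm: every draw is an independent sample from the prior

# Initial step: the default `AbstractMCMC.step(rng, model, spl; kwargs...)` above has
# already filled `vi` via `initialsampler(spl)` (the prior by default), so just return it.
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::Model,
    spl::Sampler{Prior},
    vi::AbstractVarInfo;
    kwargs...
)
    return vi, nothing
end

# Subsequent steps: draw fresh values from the prior each time.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::Model,
    spl::Sampler{Prior},
    state;
    kwargs...
)
    vi = VarInfo(rng, model, SampleFromPrior())
    return vi, nothing
end

With these two methods, `sample(Random.GLOBAL_RNG, model, Sampler(Prior()), 100)` should return a vector of 100 independent prior draws (each a `TypedVarInfo`), under the assumptions above.
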
30 changes: 23 additions & 7 deletions src/varinfo.jl
@@ -105,21 +105,38 @@ end
const UntypedVarInfo = VarInfo{<:Metadata}
const TypedVarInfo = VarInfo{<:NamedTuple}

function VarInfo(model::Model, ctx = DefaultContext())
vi = VarInfo()
model(vi, SampleFromPrior(), ctx)
return TypedVarInfo(vi)
end

function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
new_vi = deepcopy(old_vi)
new_vi[spl] = x
return new_vi
end

function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
VarInfo(md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)))
end

function VarInfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
context::AbstractContext = DefaultContext(),
)
varinfo = VarInfo()
model(rng, varinfo, sampler, context)
return TypedVarInfo(varinfo)
end
VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...)

# without AbstractSampler
function VarInfo(
rng::Random.AbstractRNG,
model::Model,
context::AbstractContext,
)
return VarInfo(rng, model, SampleFromPrior(), context)
end
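
For illustration (the `model` below stands for any `@model`-constructed model; the calls are not part of the diff), the new constructors compose as follows:

using Random

vi1 = VarInfo(model)                       # Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()
vi2 = VarInfo(MersenneTwister(42), model)  # explicit RNG, prior initialisation
vi3 = VarInfo(MersenneTwister(42), model, SampleFromUniform())
# Each evaluates the model once and returns a TypedVarInfo.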

@generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space}
exprs = []
offset = :(0)
@@ -1000,7 +1017,6 @@ from a distribution `dist` to `VarInfo` `vi`.
The sampler is passed here to invalidate its cache where defined.
"""
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler)
spl.info[:cache_updated] = CACHERESET
return push!(vi, vn, r, dist, spl.selector)
end
function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler)
8 changes: 5 additions & 3 deletions test/Project.toml
@@ -3,6 +3,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -31,15 +32,16 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0.1"
AbstractMCMC = "2.1"
AdvancedHMC = "0.2.25"
AdvancedMH = "0.5.1"
AdvancedMH = "0.5.2"
AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8.2"
Distributions = "0.23.8"
DistributionsAD = "0.6.3"
DocStringExtensions = "0.8.2"
EllipticalSliceSampling = "0.2.2"
EllipticalSliceSampling = "0.3"
ForwardDiff = "0.10.12"
Libtask = "0.4.1, 0.5"
LogDensityProblems = "0.10.3"

2 comments on commit 6ac3922

@devmotion (Member Author)
@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/25340

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.0 -m "<description of version>" 6ac3922bf09cf8e0d124b448353dd18b7afee00d
git push origin v0.10.0
