Resample variable if not given in setval! (#216)
Currently, if one calls `DynamicPPL._setval!(vi, vi.metadata, values, keys)`, then only those values present in `keys` will be set, as expected, but the variables which are _not_ present in `keys` will simply be left as-is. This means that we get the following behavior:
``` julia
julia> using Turing

julia> @model function demo(x)
           m ~ Normal(0, 1)
           for i in eachindex(x)
               x[i] ~ Normal(m, 1)
           end
       end
demo (generic function with 1 method)

julia> m_missing = demo(fill(missing, 2));

julia> var_info_missing = DynamicPPL.VarInfo(m_missing);

julia> var_info_missing.metadata.m.vals
1-element Array{Float64,1}:

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
julia> var_info_missing.metadata.m.vals # ✓ new value
1-element Array{Float64,1}:
julia> var_info_missing.metadata.x.vals # ✓ still the same value
2-element Array{Float64,1}:

julia> m_missing(var_info_missing) # Re-run the model with new value for `m`

julia> var_info_missing.metadata.x.vals # × still the same and thus not reflecting the change in `m`!
2-element Array{Float64,1}:
```

_Personally_ I expected `x` to be resampled, since parts of the model have now changed and thus the sample `x` is no longer representative of a sample from the model (under the sampler used).

This PR "fixes" the above so that you get the following behavior:
``` julia
julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:

julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, ));

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:

julia> m_missing(var_info_missing)

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
```

This was discovered when debugging TuringLang/Turing.jl#1352, as I want to move `Turing.predict` over to using `DynamicPPL.setval!`. It also has consequences for `DynamicPPL.generated_quantities`, which uses `DynamicPPL.setval!` under the hood and thus suffers from the same issue.

There's an alternative: instead of making this the default-behavior, we could add `kwargs...` to `setval!` which includes `resample_missing::Bool` or something. I'm also completely fine with a solution like that 👍
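
The keyword-argument alternative could look roughly like the following. This is a toy sketch only: `ToyVarInfo`, `toy_setval!` and the exact shape of `resample_missing` are hypothetical names for illustration and are not DynamicPPL API. The commit itself goes with a separate `setval_and_resample!` function instead.

```julia
# Toy stand-in for a VarInfo: current values plus the names flagged for resampling.
struct ToyVarInfo
    vals::Dict{Symbol,Float64}
    resample::Set{Symbol}
end

# Hypothetical API: one entry point, with a keyword deciding what happens to
# variables that are not present in `x`.
function toy_setval!(vi::ToyVarInfo, x::NamedTuple; resample_missing::Bool=false)
    for name in collect(keys(vi.vals))
        if haskey(x, name)
            vi.vals[name] = x[name]
            delete!(vi.resample, name)
        elseif resample_missing
            push!(vi.resample, name)
        end
    end
    return vi
end

vi = ToyVarInfo(Dict(:m => 0.3, :x => 1.2), Set{Symbol}())
toy_setval!(vi, (m = 100.0,))                         # `x` is left untouched
toy_setval!(vi, (m = 100.0,); resample_missing=true)  # `x` is flagged for resampling
```

With a single entry point, the default value of the keyword decides which behavior existing callers get, which is the trade-off against adding a second function.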
torfjelde committed Apr 3, 2021
1 parent 3602c56 commit 2d6ef3f
Showing 8 changed files with 320 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.10.8"
+version = "0.10.9"

AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
2 changes: 1 addition & 1 deletion src/loglikelihoods.jl
@@ -173,7 +173,7 @@ function pointwise_loglikelihoods(
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
# Update the values
-setval!(vi, chain, sample_idx, chain_idx)
+setval_and_resample!(vi, chain, sample_idx, chain_idx)

# Execute model
model(vi, spl, ctx)
2 changes: 1 addition & 1 deletion src/model.jl
@@ -277,7 +277,7 @@ function generated_quantities(model::Model, chain::AbstractChains)
varinfo = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
-setval!(varinfo, chain, sample_idx, chain_idx)
+setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
7 changes: 7 additions & 0 deletions src/utils.jl
@@ -153,3 +153,10 @@ end
function inittrans(rng, dist::MatrixDistribution, n::Int)
return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n])

# Convenience methods #
collectmaybe(x) = x
collectmaybe(x::Base.AbstractSet) = collect(x)
232 changes: 211 additions & 21 deletions src/varinfo.jl
@@ -584,11 +584,22 @@ end

# Functions defined only for UntypedVarInfo
Return an iterator over all `vns` in `vi`.
keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)

@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
expr = Expr(:call)
push!(expr.args, :vcat)

for n in names
push!(expr.args, :(vi.metadata.$n.vns))

return expr
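
The `@generated` method above builds a single `vcat(...)` call with one argument per field name, so the concatenation is spelled out at compile time. A minimal self-contained sketch of the same pattern follows, using a plain `NamedTuple` of vectors in place of `vi.metadata`; the name `concat_fields` is made up for this illustration.

```julia
# Hypothetical helper: concatenate every field of a NamedTuple of vectors.
# `names` is the tuple of field names, available inside the generator.
@generated function concat_fields(nt::NamedTuple{names}) where {names}
    expr = Expr(:call, :vcat)
    for n in names
        # Splice in the field access `nt.<n>` as one more argument to `vcat`.
        push!(expr.args, :(nt.$n))
    end
    return expr
end

concat_fields((m = [1.0], x = [2.0, 3.0]))
```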

setgid!(vi::VarInfo, gid::Selector, vn::VarName)
@@ -1165,19 +1176,39 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)

setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
# TODO: Maybe rename or something?
_apply!(kernel!, vi::AbstractVarInfo, values, keys)
Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`.
function _apply!(kernel!, vi::AbstractVarInfo, values, keys)
keys_strings = map(string, collectmaybe(keys))
num_indices_seen = 0

function _setval!(vi::AbstractVarInfo, values, keys)
for vn in Base.keys(vi)
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)

if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"

return vi
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
@generated function _typed_setval!(

_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!(
kernel!, vi, vi.metadata, values, collectmaybe(keys))

@generated function _typed_apply!(
@@ -1186,30 +1217,189 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
updates = map(names) do n
for vn in metadata.$n.vns
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)

return quote
keys_strings = map(string, keys)
num_indices_seen = 0


if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"

return vi

function _find_missing_keys(vi::AbstractVarInfo, keys)
string_vns = map(string, collectmaybe(Base.keys(vi)))
# If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
missing_keys = filter(keys) do key
!any(Base.Fix2(subsumes_string, key), string_vns)

return missing_keys
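
The missing-key check can be tried in isolation. The sketch below works on plain strings rather than a `VarInfo`; `find_missing_keys` is a simplified stand-in written for this example, and `subsumes_string` is restated with the same prefix rule the commit adds in `src/varname.jl`.

```julia
# Same string-prefix rule as `subsumes_string` in this commit.
subsumes_string(u::String, v::String) = u == v || startswith(v, u * "[")

# A key counts as missing when no variable name in `vns` subsumes it.
function find_missing_keys(vns::Vector{String}, keys::Vector{String})
    return filter(keys) do key
        !any(vn -> subsumes_string(vn, key), vns)
    end
end

vns = ["m", "x[1]", "x[2]"]
missing_keys = find_missing_keys(vns, ["m", "x[1]", "s", "y[3]"])
```

Here `"s"` and `"y[3]"` come back as missing, which are the keys the `@warn` message would report.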

setval!(vi::AbstractVarInfo, x)
setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
Set the values in `vi` to the provided values and leave those which are not present in
`x` or `chains` unchanged.
## Notes
This is rather limited for two reasons:
1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood,
and therefore suffers from the same limitations as [`subsumes_string`](@ref).
2. It will set every `vn` present in `keys`. It will NOT however
set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`,
representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0],))` will
be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0],))`.
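
The second limitation comes down to the direction of the string match, which can be checked without a `VarInfo`. In this sketch `matching_keys` is a hypothetical helper that mirrors the `findall` call in the kernel, and `subsumes_string` is restated from `src/varname.jl`.

```julia
subsumes_string(u::String, v::String) = u == v || startswith(v, u * "[")

# Indices of the keys that a given variable name would pick up.
matching_keys(vn::String, keys::Vector{String}) =
    findall(k -> subsumes_string(vn, k), keys)

# Variable stored element-wise, value supplied under the parent name: no match.
elementwise_vn = matching_keys("m[1]", ["m"])

# Variable stored under the parent name, values supplied element-wise: both match.
parent_vn = matching_keys("m", ["m[1]", "m[2]"])
```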
## Example
julia> using DynamicPPL, Distributions, StableRNGs
julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
julia> rng = StableRNG(42);
julia> m = demo([missing]);
julia> var_info = DynamicPPL.VarInfo(rng, m);
julia> var_info[@varname(m)]
julia> var_info[@varname(x[1])]
julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and keep `x[1]`
julia> var_info[@varname(m)] # [✓] changed
julia> var_info[@varname(x[1])] # [✓] unchanged
julia> m(rng, var_info); # rerun model
julia> var_info[@varname(m)] # [✓] unchanged
julia> var_info[@varname(x[1])] # [✓] unchanged
setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))

function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
string_vn = string(vn)
string_vn_indexing = string_vn * "["
indices = findall(keys) do x
string_x = string(x)
return string_x == string_vn || startswith(string_x, string_vn_indexing)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)

return indices
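
The kernel sorts the matched keys with `NaturalSort.natural` before concatenating their values, because plain string ordering misplaces multi-digit indices. The sketch below shows the difference using only Base; `first_index` is a hand-rolled stand-in for a natural sort that handles just a single leading integer index, not a replacement for NaturalSort.jl.

```julia
keys_ = ["x[10]", "x[2]", "x[1]"]
values_ = [10.0, 2.0, 1.0]

# Plain lexicographic order: "x[10]" sorts before "x[1]" because '0' < ']'.
lexicographic = sortperm(keys_)

# Stand-in for a natural sort: order by the first integer inside the brackets.
first_index(key) = parse(Int, match(r"\[(\d+)", key).captures[1])
natural = sortperm(keys_; by=first_index)

values_[lexicographic]  # [10.0, 1.0, 2.0]
values_[natural]        # [1.0, 2.0, 10.0]
```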

setval_and_resample!(vi::AbstractVarInfo, x)
setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx)
Set the values in `vi` to the provided values and mark those which are not present
in `x` or `chains` to *be* resampled.
Note that this does *not* resample the values not provided! It will call `set_flag!(vi, vn, "del")`
for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these
variables will be resampled.
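
The deferred behavior described here (flag now, resample on the next model run) can be mimicked with a toy container. Everything in this sketch is hypothetical and only illustrates the control flow: `FlaggedVarInfo`, `toy_setval_and_resample!` and `toy_model!` are made-up names, and the `del` set stands in for the `"del"` flag.

```julia
using Random

# Toy stand-in for a VarInfo: stored values plus a set of "del"-style flags.
struct FlaggedVarInfo
    vals::Dict{Symbol,Float64}
    del::Set{Symbol}
end

# Set the provided values; flag everything else so it is redrawn on the next run.
function toy_setval_and_resample!(vi::FlaggedVarInfo, x::NamedTuple)
    for name in collect(keys(vi.vals))
        if haskey(x, name)
            vi.vals[name] = x[name]
        else
            push!(vi.del, name)
        end
    end
    return vi
end

# Toy model `m ~ Normal(0, 1); x ~ Normal(m, 1)`: only flagged variables are redrawn.
function toy_model!(rng::AbstractRNG, vi::FlaggedVarInfo)
    if :m in vi.del
        vi.vals[:m] = randn(rng)
    end
    if :x in vi.del
        vi.vals[:x] = vi.vals[:m] + randn(rng)
    end
    empty!(vi.del)
    return vi
end

vi = FlaggedVarInfo(Dict(:m => 0.3, :x => 1.2), Set{Symbol}())
toy_setval_and_resample!(vi, (m = 100.0,))
x_before = vi.vals[:x]              # still 1.2: nothing has been resampled yet
toy_model!(MersenneTwister(1), vi)
x_after = vi.vals[:x]               # a fresh draw centred on m = 100.0
```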
## Note
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
## Example
julia> using DynamicPPL, Distributions, StableRNGs
julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
julia> rng = StableRNG(42);
julia> m = demo([missing]);
julia> var_info = DynamicPPL.VarInfo(rng, m);
julia> var_info[@varname(m)]
julia> var_info[@varname(x[1])]
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
julia> var_info[@varname(m)] # [✓] changed
julia> var_info[@varname(x[1])] # [✓] unchanged
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
julia> var_info[@varname(m)] # [✓] unchanged
julia> var_info[@varname(x[1])] # [✓] changed
## See also
- [`setval!`](@ref)
setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x))
function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))

function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
val = mapreduce(vcat, sorted_indices) do i
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)
# Ensures that we'll resample the variable corresponding to `vn` if we run
# the model on `vi` again.
set_flag!(vi, vn, "del")

return indices
17 changes: 17 additions & 0 deletions src/varname.jl
@@ -1,3 +1,20 @@
subsumes_string(u::String, v::String[, u_indexing])
Check whether stringified variable name `v` describes a sub-range of stringified variable `u`.
This is a very restricted version of `subsumes(u::VarName, v::VarName)`, only really supporting:
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
## Note
- To get the same matching capabilities as `AbstractPPL.subsumes(u::VarName, v::VarName)`
for strings, one can always do `eval(varname(Meta.parse(u)))` to get the `VarName` of `u`,
and similarly for `v`. But this is slow.
function subsumes_string(u::String, v::String, u_indexing=u * "[")
return u == v || startswith(v, u_indexing)
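
A few calls make the supported and unsupported cases concrete. The function is restated from the lines above so the snippet runs on its own; the last two cases follow directly from the prefix rule and show where a purely string-based check falls short.

```julia
function subsumes_string(u::String, v::String, u_indexing=u * "[")
    return u == v || startswith(v, u_indexing)
end

subsumes_string("x", "x[1, 2]")           # true: indexing into `x`
subsumes_string("x[1, 2]", "x[1, 2][3]")  # true: nested indexing
subsumes_string("x", "xy")                # false: a different variable
subsumes_string("x[:]", "x[1]")           # false: slices are not understood
subsumes_string("x[1, 2]", "x[1,2]")      # false: whitespace must match exactly
```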

inargnames(varname::VarName, model::Model)
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -26,6 +27,7 @@ Documenter = "0.26.1"
ForwardDiff = "0.10.12"
MCMCChains = "4.0.4"
MacroTools = "0.5.5"
StableRNGs = "1"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
julia = "1.3"

2 comments on commit 2d6ef3f

Registration pull request created: JuliaRegistries/General/33499

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.9 -m "<description of version>" 2d6ef3f37958ab24d60080d60f85ff09835b9233
git push origin v0.10.9
