Skip to content

Commit

Permalink
Resample variable if not given in setval! (#216)
Browse files Browse the repository at this point in the history
Currently if one calls `DynamicPPL._setval!(vi, vi.metadata, values, keys)` , then only those values present in `keys` will be set, as expected, but the variables which are _not_ present in `keys` will simply be left as-is. This means that we get the following behavior:
``` julia
julia> using Turing

julia> @model function demo(x)
           m ~ Normal(0, 1)
           for i in eachindex(x)
               x[i] ~ Normal(m, 1)
           end
       end
demo (generic function with 1 method)

julia> m_missing = demo(fill(missing, 2));

julia> var_info_missing = DynamicPPL.VarInfo(m_missing);

julia> var_info_missing.metadata.m.vals
1-element Array{Float64,1}:
 0.7251417347423874

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408
julia> var_info_missing.metadata.m.vals # ✓ new value
1-element Array{Float64,1}:
 0.0
julia> var_info_missing.metadata.x.vals # ✓ still the same value
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> m_missing(var_info_missing) # Re-run the model with new value for `m`

julia> var_info_missing.metadata.x.vals # × still the same and thus not reflecting the change in `m`!
2-element Array{Float64,1}:
 1.2576791054418153
0.764913349211408
```

_Personally_ I expected `x` to be resampled since now parts of the model has changed and thus the sample `x` is no longer representative of a sample from the model (under the sampler used).

This PR "fixes" the above so that you get the following behavior:
``` julia
julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, ));

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 1.2576791054418153
 0.764913349211408

julia> m_missing(var_info_missing)

julia> var_info_missing.metadata.x.vals
2-element Array{Float64,1}:
 -2.0493130638394947
  0.3881955730968598
```

This was discoverd when debugging TuringLang/Turing.jl#1352 as I want to move `Turing.predict` over to using `DynamicPPL.setval!` and it also has consequences for `DynamicPPL.generated_quantities` which uses `DynamicPPL.setval!` under the hood and thus suffer from the same issue.


There's an alternative: instead of making this the default-behavior, we could add `kwargs...` to `setval!` which includes `resample_missing::Bool` or something. I'm also completely fine with a solution like that 👍
  • Loading branch information
torfjelde committed Apr 3, 2021
1 parent 3602c56 commit 2d6ef3f
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.8"
version = "0.10.9"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/loglikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ function pointwise_loglikelihoods(
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
for (sample_idx, chain_idx) in iters
# Update the values
setval!(vi, chain, sample_idx, chain_idx)
setval_and_resample!(vi, chain, sample_idx, chain_idx)

# Execute model
model(vi, spl, ctx)
Expand Down
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ function generated_quantities(model::Model, chain::AbstractChains)
varinfo = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
setval!(varinfo, chain, sample_idx, chain_idx)
setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
model(varinfo)
end
end
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,10 @@ end
function inittrans(rng, dist::MatrixDistribution, n::Int)
return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n])
end


#######################
# Convenience methods #
#######################
collectmaybe(x) = x
collectmaybe(x::Base.AbstractSet) = collect(x)
232 changes: 211 additions & 21 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,22 @@ end

# Functions defined only for UntypedVarInfo
"""
keys(vi::UntypedVarInfo)
keys(vi::AbstractVarInfo)
Return an iterator over all `vns` in `vi`.
"""
keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)
Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs)

@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names}
expr = Expr(:call)
push!(expr.args, :vcat)

for n in names
push!(expr.args, :(vi.metadata.$n.vns))
end

return expr
end

"""
setgid!(vi::VarInfo, gid::Selector, vn::VarName)
Expand Down Expand Up @@ -1165,19 +1176,39 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
end
end

setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end
# TODO: Maybe rename or something?
"""
_apply!(kernel!, vi::AbstractVarInfo, values, keys)
Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`.
"""
function _apply!(kernel!, vi::AbstractVarInfo, values, keys)
keys_strings = map(string, collectmaybe(keys))
num_indices_seen = 0

function _setval!(vi::AbstractVarInfo, values, keys)
for vn in Base.keys(vi)
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)
end
end

if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
end

return vi
end
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
@generated function _typed_setval!(

_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!(
kernel!, vi, vi.metadata, values, collectmaybe(keys))

@generated function _typed_apply!(
kernel!,
vi::TypedVarInfo,
metadata::NamedTuple{names},
values,
Expand All @@ -1186,30 +1217,189 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value
updates = map(names) do n
quote
for vn in metadata.$n.vns
_setval_kernel!(vi, vn, values, keys)
indices_found = kernel!(vi, vn, values, keys_strings)
if indices_found !== nothing
num_indices_seen += length(indices_found)
end
end
end
end

return quote
keys_strings = map(string, keys)
num_indices_seen = 0

$(updates...)

if length(keys) > num_indices_seen
# Some keys have not been seen, i.e. attempted to set variables which
# we were not able to locate in `vi`.
# Find the ones we missed so we can warn the user.
unused_keys = _find_missing_keys(vi, keys_strings)
@warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)"
end

return vi
end
end

function _find_missing_keys(vi::AbstractVarInfo, keys)
string_vns = map(string, collectmaybe(Base.keys(vi)))
# If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`.
missing_keys = filter(keys) do key
!any(Base.Fix2(subsumes_string, key), string_vns)
end

return missing_keys
end

"""
setval!(vi::AbstractVarInfo, x)
setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
Set the values in `vi` to the provided values and leave those which are not present in
`x` or `chains` unchanged.
## Notes
This is rather limited for two reasons:
1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood,
and therefore suffers from the same limitations as [`subsumes_string`](@ref).
2. It will set every `vn` present in `keys`. It will NOT however
set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`,
representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0]))` will
be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0]))`.
## Example
```jldoctest
julia> using DynamicPPL, Distributions, StableRNGs
julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
end
end;
julia> rng = StableRNG(42);
julia> m = demo([missing]);
julia> var_info = DynamicPPL.VarInfo(rng, m);
julia> var_info[@varname(m)]
-0.6702516921145671
julia> var_info[@varname(x[1])]
-0.22312984965118443
julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1]`
julia> var_info[@varname(m)] # [✓] changed
100.0
julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443
julia> m(rng, var_info); # rerun model
julia> var_info[@varname(m)] # [✓] unchanged
100.0
julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443
```
"""
setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
string_vn = string(vn)
string_vn_indexing = string_vn * "["
indices = findall(keys) do x
string_x = string(x)
return string_x == string_vn || startswith(string_x, string_vn_indexing)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)
end

return indices
end

"""
setval_and_resample!(vi::AbstractVarInfo, x)
setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx)
Set the values in `vi` to the provided values and those which are not present
in `x` or `chains` to *be* resampled.
Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")`
for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these
variables will be resampled.
## Note
- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info.
## Example
```jldoctest
julia> using DynamicPPL, Distributions, StableRNGs
julia> @model function demo(x)
m ~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m, 1)
end
end;
julia> rng = StableRNG(42);
julia> m = demo([missing]);
julia> var_info = DynamicPPL.VarInfo(rng, m);
julia> var_info[@varname(m)]
-0.6702516921145671
julia> var_info[@varname(x[1])]
-0.22312984965118443
julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling
julia> var_info[@varname(m)] # [✓] changed
100.0
julia> var_info[@varname(x[1])] # [✓] unchanged
-0.22312984965118443
julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0`
julia> var_info[@varname(m)] # [✓] unchanged
100.0
julia> var_info[@varname(x[1])] # [✓] changed
101.37363069798343
```
## See also
- [`setval!`](@ref)
"""
setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x))
function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
indices = findall(Base.Fix1(subsumes_string, string(vn)), keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
val = mapreduce(vcat, sorted_indices) do i
values[i]
end
sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural)
val = reduce(vcat, values[sorted_indices])
setval!(vi, val, vn)
settrans!(vi, false, vn)
else
# Ensures that we'll resample the variable corresponding to `vn` if we run
# the model on `vi` again.
set_flag!(vi, vn, "del")
end

return indices
end
17 changes: 17 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
subsumes_string(u::String, v::String[, u_indexing])
Check whether stringified variable name `v` describes a sub-range of stringified variable `u`.
This is a very restricted version `subumes(u::VarName, v::VarName)` only really supporting:
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
## Note
- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)`
for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`,
and similarly to `v`. But this is slow.
"""
function subsumes_string(u::String, v::String, u_indexing=u * "[")
return u == v || startswith(v, u_indexing)
end

"""
inargnames(varname::VarName, model::Model)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -26,6 +27,7 @@ Documenter = "0.26.1"
ForwardDiff = "0.10.12"
MCMCChains = "4.0.4"
MacroTools = "0.5.5"
StableRNGs = "1"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
julia = "1.3"
Loading

2 comments on commit 2d6ef3f

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33499

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.9 -m "<description of version>" 2d6ef3f37958ab24d60080d60f85ff09835b9233
git push origin v0.10.9

Please sign in to comment.