From 2d6ef3f37958ab24d60080d60f85ff09835b9233 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 3 Apr 2021 18:07:30 +0000 Subject: [PATCH] Resample variable if not given in `setval!` (#216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, if one calls `DynamicPPL._setval!(vi, vi.metadata, values, keys)`, then only those values present in `keys` will be set, as expected, but the variables which are _not_ present in `keys` will simply be left as-is. This means that we get the following behavior: ``` julia julia> using Turing julia> @model function demo(x) m ~ Normal(0, 1) for i in eachindex(x) x[i] ~ Normal(m, 1) end end demo (generic function with 1 method) julia> m_missing = demo(fill(missing, 2)); julia> var_info_missing = DynamicPPL.VarInfo(m_missing); julia> var_info_missing.metadata.m.vals 1-element Array{Float64,1}: 0.7251417347423874 julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, )); julia> var_info_missing.metadata.m.vals # ✓ new value 1-element Array{Float64,1}: 0.0 julia> var_info_missing.metadata.x.vals # ✓ still the same value 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> m_missing(var_info_missing) # Re-run the model with new value for `m` julia> var_info_missing.metadata.x.vals # × still the same and thus not reflecting the change in `m`! 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 ``` _Personally_ I expected `x` to be resampled since now parts of the model have changed and thus the sample `x` is no longer representative of a sample from the model (under the sampler used). 
This PR "fixes" the above so that you get the following behavior: ``` julia julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> DynamicPPL.setval!(var_info_missing, (m = 0.0, )); julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: 1.2576791054418153 0.764913349211408 julia> m_missing(var_info_missing) julia> var_info_missing.metadata.x.vals 2-element Array{Float64,1}: -2.0493130638394947 0.3881955730968598 ``` This was discovered when debugging https://github.com/TuringLang/Turing.jl/issues/1352 as I want to move `Turing.predict` over to using `DynamicPPL.setval!` and it also has consequences for `DynamicPPL.generated_quantities` which uses `DynamicPPL.setval!` under the hood and thus suffers from the same issue. There's an alternative: instead of making this the default behavior, we could add `kwargs...` to `setval!` which includes `resample_missing::Bool` or something. I'm also completely fine with a solution like that :+1: --- Project.toml | 2 +- src/loglikelihoods.jl | 2 +- src/model.jl | 2 +- src/utils.jl | 7 ++ src/varinfo.jl | 232 ++++++++++++++++++++++++++++++++++++++---- src/varname.jl | 17 ++++ test/Project.toml | 2 + test/varinfo.jl | 112 ++++++++++++++------ 8 files changed, 320 insertions(+), 56 deletions(-) diff --git a/Project.toml b/Project.toml index cfd21170e..232c03392 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.8" +version = "0.10.9" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2e7966279..e74aa31ca 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -173,7 +173,7 @@ function pointwise_loglikelihoods( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval!(vi, chain, sample_idx, chain_idx) + 
setval_and_resample!(vi, chain, sample_idx, chain_idx) # Execute model model(vi, spl, ctx) diff --git a/src/model.jl b/src/model.jl index fbb83c8a9..b0b78f71f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -277,7 +277,7 @@ function generated_quantities(model::Model, chain::AbstractChains) varinfo = VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - setval!(varinfo, chain, sample_idx, chain_idx) + setval_and_resample!(varinfo, chain, sample_idx, chain_idx) model(varinfo) end end diff --git a/src/utils.jl b/src/utils.jl index a2d209fa2..c01051b81 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -153,3 +153,10 @@ end function inittrans(rng, dist::MatrixDistribution, n::Int) return invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) end + + +####################### +# Convenience methods # +####################### +collectmaybe(x) = x +collectmaybe(x::Base.AbstractSet) = collect(x) diff --git a/src/varinfo.jl b/src/varinfo.jl index db56b3533..8d9ffecce 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -584,11 +584,22 @@ end # Functions defined only for UntypedVarInfo """ - keys(vi::UntypedVarInfo) + keys(vi::AbstractVarInfo) Return an iterator over all `vns` in `vi`. 
""" -keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) +Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) + +@generated function Base.keys(vi::TypedVarInfo{<:NamedTuple{names}}) where {names} + expr = Expr(:call) + push!(expr.args, :vcat) + + for n in names + push!(expr.args, :(vi.metadata.$n.vns)) + end + + return expr +end """ setgid!(vi::VarInfo, gid::Selector, vn::VarName) @@ -1165,19 +1176,39 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) end end -setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x)) -function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) - return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) -end +# TODO: Maybe rename or something? +""" + _apply!(kernel!, vi::AbstractVarInfo, values, keys) + +Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. +""" +function _apply!(kernel!, vi::AbstractVarInfo, values, keys) + keys_strings = map(string, collectmaybe(keys)) + num_indices_seen = 0 -function _setval!(vi::AbstractVarInfo, values, keys) for vn in Base.keys(vi) - _setval_kernel!(vi, vn, values, keys) + indices_found = kernel!(vi, vn, values, keys_strings) + if indices_found !== nothing + num_indices_seen += length(indices_found) + end end + + if length(keys) > num_indices_seen + # Some keys have not been seen, i.e. attempted to set variables which + # we were not able to locate in `vi`. + # Find the ones we missed so we can warn the user. 
+ unused_keys = _find_missing_keys(vi, keys_strings) + @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" + end + return vi end -_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys) -@generated function _typed_setval!( + +_apply!(kernel!, vi::TypedVarInfo, values, keys) = _typed_apply!( + kernel!, vi, vi.metadata, values, collectmaybe(keys)) + +@generated function _typed_apply!( + kernel!, vi::TypedVarInfo, metadata::NamedTuple{names}, values, @@ -1186,30 +1217,189 @@ _setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, value updates = map(names) do n quote for vn in metadata.$n.vns - _setval_kernel!(vi, vn, values, keys) + indices_found = kernel!(vi, vn, values, keys_strings) + if indices_found !== nothing + num_indices_seen += length(indices_found) + end end end end - + return quote + keys_strings = map(string, keys) + num_indices_seen = 0 + $(updates...) + + if length(keys) > num_indices_seen + # Some keys have not been seen, i.e. attempted to set variables which + # we were not able to locate in `vi`. + # Find the ones we missed so we can warn the user. + unused_keys = _find_missing_keys(vi, keys_strings) + @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" + end + return vi end end +function _find_missing_keys(vi::AbstractVarInfo, keys) + string_vns = map(string, collectmaybe(Base.keys(vi))) + # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. + missing_keys = filter(keys) do key + !any(Base.Fix2(subsumes_string, key), string_vns) + end + + return missing_keys +end + +""" + setval!(vi::AbstractVarInfo, x) + setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) + +Set the values in `vi` to the provided values and leave those which are not present in +`x` or `chains` unchanged. 
+ +## Notes +This is rather limited for two reasons: +1. It uses `subsumes_string(string(vn), map(string, keys))` under the hood, + and therefore suffers from the same limitations as [`subsumes_string`](@ref). +2. It will set every `vn` present in `keys`. It will NOT however + set every `k` present in `keys`. This means that if `vn == [m[1], m[2]]`, + representing some variable `m`, calling `setval!(vi, (m = [1.0, 2.0],))` will + be a no-op since it will try to find `m[1]` and `m[2]` in `keys((m = [1.0, 2.0],))`. + +## Example +```jldoctest +julia> using DynamicPPL, Distributions, StableRNGs + +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1) + end + end; + +julia> rng = StableRNG(42); + +julia> m = demo([missing]); + +julia> var_info = DynamicPPL.VarInfo(rng, m); + +julia> var_info[@varname(m)] +-0.6702516921145671 + +julia> var_info[@varname(x[1])] +-0.22312984965118443 + +julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and keep `x[1]` + +julia> var_info[@varname(m)] # [✓] changed +100.0 + +julia> var_info[@varname(x[1])] # [✓] unchanged +-0.22312984965118443 + +julia> m(rng, var_info); # rerun model + +julia> var_info[@varname(m)] # [✓] unchanged +100.0 + +julia> var_info[@varname(x[1])] # [✓] unchanged +-0.22312984965118443 +``` +""" +setval!(vi::AbstractVarInfo, x) = _apply!(_setval_kernel!, vi, values(x), keys(x)) +function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) + return _apply!(_setval_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains)) +end + function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) - string_vn = string(vn) - string_vn_indexing = string_vn * "[" - indices = findall(keys) do x - string_x = string(x) - return string_x == string_vn || startswith(string_x, string_vn_indexing) + indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) + if !isempty(indices) + sorted_indices = sort!(indices; by=i -> 
keys[i], lt=NaturalSort.natural) + val = reduce(vcat, values[sorted_indices]) + setval!(vi, val, vn) + settrans!(vi, false, vn) end + + return indices +end + +""" + setval_and_resample!(vi::AbstractVarInfo, x) + setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx, chain_idx) + +Set the values in `vi` to the provided values and mark those which are not present +in `x` or `chains` to *be* resampled. + +Note that this does *not* resample the values not provided! It will call `set_flag!(vi, vn, "del")` +for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these +variables will be resampled. + +## Note +- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. + +## Example +```jldoctest +julia> using DynamicPPL, Distributions, StableRNGs + +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1) + end + end; + +julia> rng = StableRNG(42); + +julia> m = demo([missing]); + +julia> var_info = DynamicPPL.VarInfo(rng, m); + +julia> var_info[@varname(m)] +-0.6702516921145671 + +julia> var_info[@varname(x[1])] +-0.22312984965118443 + +julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling + +julia> var_info[@varname(m)] # [✓] changed +100.0 + +julia> var_info[@varname(x[1])] # [✓] unchanged +-0.22312984965118443 + +julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` + +julia> var_info[@varname(m)] # [✓] unchanged +100.0 + +julia> var_info[@varname(x[1])] # [✓] changed +101.37363069798343 +``` + +## See also +- [`setval!`](@ref) +""" +setval_and_resample!(vi::AbstractVarInfo, x) = _apply!(_setval_and_resample_kernel!, vi, values(x), keys(x)) +function setval_and_resample!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int) + return _apply!(_setval_and_resample_kernel!, vi, chains.value[sample_idx, :, chain_idx], keys(chains)) +end + +function 
_setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) + indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) - sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural) - val = mapreduce(vcat, sorted_indices) do i - values[i] - end + sorted_indices = sort!(indices; by=i -> keys[i], lt=NaturalSort.natural) + val = reduce(vcat, values[sorted_indices]) setval!(vi, val, vn) settrans!(vi, false, vn) + else + # Ensures that we'll resample the variable corresponding to `vn` if we run + # the model on `vi` again. + set_flag!(vi, vn, "del") end + + return indices end diff --git a/src/varname.jl b/src/varname.jl index f45b0b430..ca5823a8a 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,3 +1,20 @@ +""" + subsumes_string(u::String, v::String[, u_indexing]) + +Check whether stringified variable name `v` describes a sub-range of stringified variable `u`. + +This is a very restricted version of `subsumes(u::VarName, v::VarName)` only really supporting: +- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. + +## Note +- To get the same matching capabilities as `AbstractPPL.subsumes(u::VarName, v::VarName)` + for strings, one can always do `eval(varname(Meta.parse(u)))` to get `VarName` of `u`, + and similarly for `v`. But this is slow. 
+""" +function subsumes_string(u::String, v::String, u_indexing=u * "[") + return u == v || startswith(v, u_indexing) +end + """ inargnames(varname::VarName, model::Model) diff --git a/test/Project.toml b/test/Project.toml index 0084a3668..a5276eec5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -26,6 +27,7 @@ Documenter = "0.26.1" ForwardDiff = "0.10.12" MCMCChains = "4.0.4" MacroTools = "0.5.5" +StableRNGs = "1" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" julia = "1.3" diff --git a/test/varinfo.jl b/test/varinfo.jl index dc0fc3e59..bee6e781e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -57,9 +57,11 @@ @test isempty(vi) @test ~haskey(vi, vn) + @test !(vn in keys(vi)) push!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) + @test vn in keys(vi) @test length(vi[vn]) == 1 @test length(vi[SampleFromPrior()]) == 1 @@ -143,7 +145,7 @@ setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end - @testset "setval!" begin + @testset "setval! & setval_and_resample!" 
begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(), 0, Inf) @@ -151,40 +153,86 @@ x ~ MvNormal(m, s) end - x = randn(5) - model = testmodel(x) - - # UntypedVarInfo - vi = VarInfo() - model(vi, SampleFromPrior()) - - vicopy = deepcopy(vi) - DynamicPPL.setval!(vicopy, (m = zeros(5),)) - @test vicopy[@varname(m)] == zeros(5) - @test vicopy[@varname(s)] == vi[@varname(s)] - - DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)) - @test vicopy[@varname(m)] == 1:5 - @test vicopy[@varname(s)] == vi[@varname(s)] - - DynamicPPL.setval!(vicopy, (s = 42,)) - @test vicopy[@varname(m)] == 1:5 - @test vicopy[@varname(s)] == 42 + @model function testmodel_univariate(x, ::Type{TV} = Vector{Float64}) where {TV} + n = length(x) + s ~ truncated(Normal(), 0, Inf) - # TypedVarInfo - vi = VarInfo(model) + m = TV(undef, n) + for i in eachindex(m) + m[i] ~ Normal() + end - vicopy = deepcopy(vi) - DynamicPPL.setval!(vicopy, (m = zeros(5),)) - @test vicopy[@varname(m)] == zeros(5) - @test vicopy[@varname(s)] == vi[@varname(s)] + for i in eachindex(x) + x[i] ~ Normal(m[i], s) + end + end - DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)) - @test vicopy[@varname(m)] == 1:5 - @test vicopy[@varname(s)] == vi[@varname(s)] + x = randn(5) + model_mv = testmodel(x) + model_uv = testmodel_univariate(x) + + for model in [model_uv, model_mv] + m_vns = model == model_uv ? [@varname(m[i]) for i = 1:5] : @varname(m) + s_vns = @varname(s) + + vi_typed = VarInfo(model) + vi_untyped = VarInfo() + model(vi_untyped, SampleFromPrior()) + + for vi in [vi_untyped, vi_typed] + vicopy = deepcopy(vi) + + ### `setval` ### + DynamicPPL.setval!(vicopy, (m = zeros(5),)) + # Setting `m` fails for univariate due to limitations of `setval!` + # and `setval_and_resample!`. See docstring of `setval!` for more info. 
+ if model == model_uv + @test_broken vicopy[m_vns] == zeros(5) + else + @test vicopy[m_vns] == zeros(5) + end + @test vicopy[s_vns] == vi[s_vns] + + DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...)) + @test vicopy[m_vns] == 1:5 + @test vicopy[s_vns] == vi[s_vns] + + DynamicPPL.setval!(vicopy, (s = 42,)) + @test vicopy[m_vns] == 1:5 + @test vicopy[s_vns] == 42 + + ### `setval_and_resample!` ### + if model == model_mv && vi == vi_untyped + # Trying to re-run model with `MvNormal` on `vi_untyped` will call + # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` + # so we skip this particular case. + continue + end - DynamicPPL.setval!(vicopy, (s = 42,)) - @test vicopy[@varname(m)] == 1:5 - @test vicopy[@varname(s)] == 42 + vicopy = deepcopy(vi) + DynamicPPL.setval_and_resample!(vicopy, (m = zeros(5),)) + model(vicopy) + # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` + if model == model_uv + @test_broken vicopy[m_vns] == zeros(5) + else + @test vicopy[m_vns] == zeros(5) + end + @test vicopy[s_vns] != vi[s_vns] + + DynamicPPL.setval_and_resample!( + vicopy, + (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) + ) + model(vicopy) + @test vicopy[m_vns] == 1:5 + @test vicopy[s_vns] != vi[s_vns] + + DynamicPPL.setval_and_resample!(vicopy, (s = 42,)) + model(vicopy) + @test vicopy[m_vns] != 1:5 + @test vicopy[s_vns] == 42 + end + end end end