Use OrderedDict in SimpleVarInfo + improvements and fixes for `values_as` (#420)

We currently use `Dict` with `SimpleVarInfo`, which leads to inconsistent ordering of the variables. An `OrderedDict`, by contrast, preserves the execution order of the model when it is generated from a `Model`.
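
As a minimal sketch of the ordering issue (using only `OrderedCollections`, with hypothetical variable names rather than ones produced by a real `Model`), the snippet below inserts the same entries into a `Dict` and an `OrderedDict`. Only the latter guarantees that iteration follows insertion order:

```julia
using OrderedCollections: OrderedDict

# Hypothetical variable names, listed in the order a model might visit them.
varnames = [:s, :m, :x1, :x2]

d = Dict{Symbol,Float64}()
od = OrderedDict{Symbol,Float64}()
for (i, name) in enumerate(varnames)
    d[name] = float(i)
    od[name] = float(i)
end

# `OrderedDict` iterates in insertion order, so flattening its values
# gives the same vector layout on every evaluation.
@assert collect(keys(od)) == varnames
@assert collect(values(od)) == [1.0, 2.0, 3.0, 4.0]

# `Dict` stores the same entries, but its iteration order is an
# implementation detail and should not be relied upon.
@assert Set(keys(d)) == Set(varnames)
```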

In addition, I've fixed some implementations of `values_as`, extended its support, and added proper tests. Since it is now better tested, is a nice-to-have feature, and will likely see extensive use after #417, it also seems reasonable to export `values_as` from DPPL.
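
One of the `values_as` changes in the diff replaces the `Dict`-specific methods with methods for any `D<:AbstractDict`, building the result via `ConstructionBase.constructorof(D)`. The following is a rough sketch of that pattern in isolation; `as_dict` is a hypothetical helper for illustration, not a DynamicPPL function:

```julia
using ConstructionBase: constructorof
using OrderedCollections: OrderedDict

# Hypothetical helper (not part of DynamicPPL): build a dictionary of the
# requested kind `D` from an iterator of `key => value` pairs.
function as_dict(pairs_iter, ::Type{D}) where {D<:AbstractDict}
    return constructorof(D)(pairs_iter)
end

entries = [:s => 1.0, :m => 2.0]

# Requesting an `OrderedDict` keeps the order of `entries`.
od = as_dict(entries, OrderedDict)
@assert od isa OrderedDict{Symbol,Float64}
@assert collect(keys(od)) == [:s, :m]

# `constructorof` drops the type parameters of `D`, so the key and value
# types come from the pairs, not from the requested type.
d = as_dict(entries, Dict{Any,Any})
@assert d isa Dict{Symbol,Float64}
```

Because `constructorof` strips the type parameters, the key and value types of the result are inferred from the pairs rather than taken from `D`.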

EDIT: This should be merged before #417
torfjelde committed Aug 29, 2022
1 parent 08ef935 commit e31a790
Showing 10 changed files with 223 additions and 31 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -28,6 +29,7 @@ ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
MacroTools = "0.5.6"
OrderedCollections = "1"
Setfield = "0.7.1, 0.8"
ZygoteRules = "0.2"
julia = "1.6"
4 changes: 4 additions & 0 deletions docs/src/api.md
@@ -167,6 +167,10 @@ push!!
empty!!
```

```@docs
values_as
```

#### `SimpleVarInfo`

```@docs
4 changes: 4 additions & 0 deletions src/DynamicPPL.jl
@@ -4,11 +4,13 @@ using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Bijectors
using Distributions
using OrderedCollections: OrderedDict

using AbstractMCMC: AbstractMCMC
using BangBang: BangBang, push!!, empty!!, setindex!!
using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Setfield: Setfield
using ZygoteRules: ZygoteRules

@@ -59,6 +61,7 @@ export AbstractVarInfo,
link!,
invlink!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
VarName,
inspace,
@@ -73,6 +76,7 @@ export AbstractVarInfo,
Sample,
init,
vectorize,
OrderedDict,
# Model
Model,
getmissings,
2 changes: 1 addition & 1 deletion src/model.jl
@@ -625,7 +625,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
x = last(
evaluate!!(
model,
SimpleVarInfo{Float64}(),
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
),
)
40 changes: 21 additions & 19 deletions src/simple_varinfo.jl
@@ -9,7 +9,7 @@ struct DefaultTransformation <: AbstractTransformation end
A simple wrapper of the parameters with a `logp` field for
accumulation of the logdensity.
Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`.
Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`.
# Fields
$(FIELDS)
@@ -69,8 +69,8 @@ julia> # (×) If we don't provide the container...
ERROR: type NamedTuple has no field x
[...]
julia> # If one does not know the varnames, we can use a `Dict` instead.
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx);
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx);
julia> # (✓) Sort of fast, but only possible at runtime.
vi[@varname(x[1])]
@@ -86,6 +86,11 @@ ERROR: KeyError: key x[1:2] not found
[...]
```
_Technically_, it's possible to use any implementation of `AbstractDict` in place of
`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening
of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is
the preferred implementation of `AbstractDict` to use here.
You can also sample in _transformed_ space:
```jldoctest simplevarinfo-general
@@ -109,8 +114,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo()
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
true
julia> # And with `Dict` of course!
_, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx);
julia> # And with `OrderedDict` of course!
_, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx);
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
0.6225185067787314
@@ -165,9 +170,9 @@ ERROR: type NamedTuple has no field b
[...]
```
Using `Dict` as underlying storage.
Using `OrderedDict` as underlying storage.
```jldoctest
julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], )));
julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], )));
julia> svi_dict[@varname(m)]
(a = [1.0],)
@@ -274,7 +279,7 @@ end

Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn)

# `Dict`
# `AbstractDict`
function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName)
return nested_getindex(vi.values, vn)
end
@@ -364,7 +369,7 @@ function BangBang.push!!(
return Setfield.@set vi.values = set!!(vi.values, vn, value)
end

# `Dict`
# `AbstractDict`
function BangBang.push!!(
vi::SimpleVarInfo{<:AbstractDict},
vn::VarName,
@@ -473,17 +478,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)

"""
values_as(varinfo[, Type])
Return the values/realizations in `varinfo` as `Type`, if implemented.
If no `Type` is provided, return values as stored in `varinfo`.
"""
values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values))
values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values))
values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
end
function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple})
return NamedTuple((Symbol(k), v) for (k, v) in vi.values)
end

"""
logjoint(model::Model, θ)
101 changes: 91 additions & 10 deletions src/varinfo.jl
@@ -1550,30 +1550,111 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys)
end

"""
values_as(vi::AbstractVarInfo)
"""
values_as(vi::VarInfo) = vi.metadata
values_as(varinfo[, Type])
"""
values_as(vi::AbstractVarInfo, ::Type{NamedTuple})
values_as(vi::AbstractVarInfo, ::Type{Dict})
Return the values/realizations in `varinfo` as `Type`, if implemented.
If no `Type` is provided, return values as stored in `varinfo`.
# Examples
`SimpleVarInfo` with `NamedTuple`:
```jldoctest
julia> data = (x = 1.0, m = [2.0]);
julia> values_as(SimpleVarInfo(data))
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries:
x => 1.0
m => [2.0]
```
`SimpleVarInfo` with `OrderedDict`:
```jldoctest
julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]);
julia> values_as(SimpleVarInfo(data))
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]
julia> values_as(SimpleVarInfo(data), NamedTuple)
(x = 1.0, m = [2.0])
julia> values_as(SimpleVarInfo(data), OrderedDict)
OrderedDict{Any, Any} with 2 entries:
x => 1.0
m => [2.0]
```
`TypedVarInfo`:
```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
julia> # For the sake of brevity, let's just check the type.
md = values_as(vi); md.s isa DynamicPPL.Metadata
true
julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
```
Return values in `vi` as the specified type.
`UntypedVarInfo`:
```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
julia> # For the sake of brevity, let's just check the type.
values_as(vi) isa DynamicPPL.Metadata
true
julia> values_as(vi, NamedTuple)
(s = 1.0, m = 2.0)
julia> values_as(vi, OrderedDict)
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
s => 1.0
m => 2.0
```
"""
values_as(vi::VarInfo) = vi.metadata
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
iter = values_from_metadata(vi.metadata)
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
end
values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata))
function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata))
end

function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names}
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
end

function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names}
function values_as(
vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
) where {names,D<:AbstractDict}
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
return Dict(iter)
return ConstructionBase.constructorof(D)(iter)
end

function values_from_metadata(md::Metadata)
3 changes: 3 additions & 0 deletions src/varname.jl
@@ -1,3 +1,6 @@
# FIXME: This fix should be in `AbstractPPL`.
AbstractPPL.subsumes(::Setfield.IdentityLens, ::Setfield.IdentityLens) = true

"""
subsumes_string(u::String, v::String[, u_indexing])
2 changes: 1 addition & 1 deletion test/model.jl
@@ -134,7 +134,7 @@ end
Random.seed!(1776)
s, m = model()
sample_namedtuple = (; s=s, m=m)
sample_dict = Dict(:s => s, :m => m)
sample_dict = Dict(@varname(s) => s, @varname(m) => m)

# With explicit RNG
@test rand(Random.seed!(1776), model) == sample_namedtuple
55 changes: 55 additions & 0 deletions test/test_util.jl
@@ -74,3 +74,58 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
end
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Return string representing a short description of `vi`.
"""
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo)
short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"

"""
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
Return instance similar to `vi` but with `vns` set to values from `vals`.
"""
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
for vn in vns
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
end
return vi
end

"""
test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
"""
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
for vn in vns
@test vi[vn] == get(vals, vn)
end
end

"""
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
"""
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
# <:VarInfo
vi_untyped = VarInfo()
model(vi_untyped)
vi_typed = TypedVarInfo(vi_untyped)
# <:SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())

return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
# Set them all to the same values.
update_values!!(vi, example_values, varnames)
end
end

2 comments on commit e31a790

@torfjelde (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/67973

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.20.1 -m "<description of version>" e31a7905e33533c2d23f6a1000523009be1a68b2
git push origin v0.20.1
