Need alternative to NamedTuple for SimpleVarInfo #528

Closed
torfjelde opened this issue Sep 1, 2023 · 13 comments

@torfjelde
Member

Problem

Now that we properly support different sizes in the underlying storage of the varinfo after linking, the current use of NamedTuple both for the "ground truth" in TestUtils, e.g.

function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal, kwargs...)
    for vn in vns
        @test isequal(vi[vn], get(vals, vn); kwargs...)
    end
end

and in SimpleVarInfo makes less sense than it did before.

To see why, let's consider the following example:

julia> using DynamicPPL, Distributions

julia> @model function demo()
           x = Vector{Float64}(undef, 5)
           x[1] ~ Normal()
           x[2:3] ~ Dirichlet([1.0, 2.0])

           return (x=x,)
       end
demo (generic function with 2 methods)

julia> model = demo();

julia> nt = model()
(x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],)

julia> # Construct `SimpleVarInfo` from `nt`.
       vi = SimpleVarInfo(nt)
SimpleVarInfo((x = [-0.08084553378437927, 0.6662187241949805, 0.3337812758050194, 6.93842891015994e-310, 6.93842829167434e-310],), 0.0)

julia> vn = @varname(x[2:3])
x[2:3]

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
 0.6662187241949805
 0.3337812758050194

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);
ERROR: DimensionMismatch: tried to assign 1 elements to 2 destinations
Stacktrace:
  [1] throw_setindex_mismatch(X::Vector{Float64}, I::Tuple{Int64})
    @ Base ./indices.jl:191
  [2] setindex_shape_check
    @ ./indices.jl:245 [inlined]
  [3] setindex!
    @ ./array.jl:994 [inlined]
  [4] _setindex!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:480 [inlined]
  [5] may
    @ ~/.julia/packages/BangBang/FUkah/src/core.jl:9 [inlined]
  [6] setindex!!
    @ ~/.julia/packages/BangBang/FUkah/src/base.jl:478 [inlined]
  [7] set(obj::Vector{Float64}, lens::BangBang.SetfieldImpl.Lens!!{Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, value::Vector{Float64})
    @ BangBang.SetfieldImpl ~/.julia/packages/BangBang/FUkah/src/setfield.jl:34
  [8] set
    @ ~/.julia/packages/Setfield/PdKfV/src/lens.jl:188 [inlined]
  [9] set
    @ ~/.julia/packages/BangBang/FUkah/src/setfield.jl:17 [inlined]
 [10] set!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/utils.jl:354 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Setfield/PdKfV/src/sugar.jl:197 [inlined]
 [12] setindex!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/simple_varinfo.jl:339 [inlined]
 [13] tilde_assume(#unused#::DynamicPPL.DynamicTransformationContext{false}, right::Dirichlet{Float64, Vector{Float64}, Float64}, vn::VarName{:x, Setfield.IndexLens{Tuple{UnitRange{Int64}}}}, vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:19
 [14] tilde_assume!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/context_implementations.jl:117 [inlined]
 [15] demo(__model__::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext}, __varinfo__::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, __context__::DynamicPPL.DynamicTransformationContext{false})
    @ Main ./REPL[47]:4
 [16] _evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:963 [inlined]
 [17] evaluate_threadunsafe!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:936 [inlined]
 [18] evaluate!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/model.jl:889 [inlined]
 [19] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/transforming.jl:86 [inlined]
 [20] link!!
    @ /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:384 [inlined]
 [21] link!!(vi::SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64, DynamicPPL.NoTransformation}, model::Model{typeof(demo), (), (), (), Tuple{}, Tuple{}, DefaultContext})
    @ DynamicPPL /drive-2/Projects/public/DynamicPPL.jl/src/abstract_varinfo.jl:378
 [22] top-level scope
    @ REPL[53]:2

The issue here really boils down to the fact that we're trying to use the varname

julia> vn
x[2:3]

to index into a NamedTuple whose entry is, after the transformation, represented by a length-1 vector rather than a length-2 vector.
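
Stripped of all the varinfo machinery, the failure is just a shape-mismatched slice assignment on the underlying vector (the same error as in the stacktrace above):

julia> x = zeros(5);

julia> # The linked value of `x[2:3]` has length 1, but the `x[2:3]` lens tries to
       # write it into a 2-element slice.
       x[2:3] = [1.0]
ERROR: DimensionMismatch: tried to assign 1 elements to 2 destinations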

In contrast, SimpleVarInfo{<:AbstractDict} will work just fine because here each varname gets its own entry:

julia> # Construct `SimpleVarInfo` using a dict now.
       vi = SimpleVarInfo(rand(OrderedDict, model))
SimpleVarInfo(OrderedDict{Any, Any}(x[1] => -0.12337922752695839, x[2:3] => [0.7836009759179734, 0.21639902408202646]), 0)

julia> # (✓) Everything works nicely
       vi[vn]
2-element Vector{Float64}:
 0.7836009759179734
 0.21639902408202646

julia> # Now we link it!
       vi_linked = DynamicPPL.link!!(vi, model);

julia> # (✓) Everything works nicely
         vi_linked[vn]
1-element Vector{Float64}:
 1.2867758943235161

"Luckily" it has always been the plan that SimpleVarInfo should be able to use different underlying representations fairly easily, e.g. I've successfully used it with ComponentVector from ComponentArrays.jl many times before. And so we should probably find a more flexible default representation for SimpleVarInfo that can be used in more cases.

Solution

Option 1: Use OrderedDict by default

This one is obviously not great for performance reasons, but it will "just work" in all cases and it's very simple to reason about.

Option 2: Dict-like flattened representation

In an ideal world, the underlying representation of the values in a varinfo would have the following properties:

  1. It's type-stable, when possible.
  2. It's contiguous in memory, when possible.
  3. It's indexable by VarName.

Something like an OrderedDict fails in two regards (illustrated after this list):

  1. It's not contiguous in memory.
  2. Type-stability is not guaranteed, unless we create a dictionary for each eltype or something similar.
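
For a quick illustration of both points, consider a hypothetical dictionary with mixed scalar and vector entries (plain Julia, not DynamicPPL code):

julia> using OrderedCollections

julia> d = OrderedDict("x[1]" => 1.0, "x[2:3]" => [0.5, 0.5]);

julia> valtype(d)  # abstract value type, so lookups are not type-stable
Any

Each vector entry also lives in its own allocation, so the numbers end up scattered across memory rather than stored in one contiguous block.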

Current Metadata used by VarInfo

The Metadata type in VarInfo is a good example of something that satisfies all three properties (of course, the "when possible" in Property (1) is not concrete, but VarInfo uses a NamedTuple of Metadata to achieve this in most common use-cases).

As a reminder, here is what Metadata looks like:

struct Metadata{
    TIdcs<:Dict{<:VarName,Int},
    TDists<:AbstractVector{<:Distribution},
    TVN<:AbstractVector{<:VarName},
    TVal<:AbstractVector{<:Real},
    TGIds<:AbstractVector{Set{Selector}},
}
    # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
    idcs::TIdcs # Dict{<:VarName,Int}
    # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
    vns::TVN # AbstractVector{<:VarName}
    # Vector of index ranges in `vals` corresponding to `vns`
    # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
    ranges::Vector{UnitRange{Int}}
    # Vector of values of all the univariate, multivariate and matrix variables
    # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
    vals::TVal # AbstractVector{<:Real}
    # Vector of distributions corresponding to `vns`
    dists::TDists # AbstractVector{<:Distribution}
    # Vector of sampler ids corresponding to `vns`
    # Each random variable can be sampled using multiple samplers, e.g. in Gibbs, hence the `Set`
    gids::TGIds # AbstractVector{Set{Selector}}
    # Number of `observe` statements before each random variable is sampled
    orders::Vector{Int}
    # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresponding to `vns[i]`
    flags::Dict{String,BitVector}
end

Most important for dict-like storage of values are the following fields:

# Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists`
idcs::TIdcs # Dict{<:VarName,Int}
# Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn`
vns::TVN # AbstractVector{<:VarName}
# Vector of index ranges in `vals` corresponding to `vns`
# Each `VarName` `vn` has a single index or a set of contiguous indices in `vals`
ranges::Vector{UnitRange{Int}}
# Vector of values of all the univariate, multivariate and matrix variables
# The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]`
vals::TVal # AbstractVector{<:Real}

With this, it's fairly easy to implement nice indexing behavior for VarInfo. Here's a simple sketch of what a getindex could look like for Metadata:

function Base.getindex(metadata::Metadata, varname::VarName)
    # Get the index for this `varname`
    idx = metadata.idcs[varname]
    # Get the range for this `varname`
    r = metadata.ranges[idx]
    # Extract the value.
    return metadata.vals[r]
end

This is effectively the getval currently implemented:

https://github.com/TuringLang/DynamicPPL.jl/blob/e05bb0935a1e1a06027c603cc04c20f23195a6c4/src/varinfo.jl#L318C44-L318C44

This then results in a Vector of the flattened representation of vn.

Our current implementation of Base.getindex for VarInfo then contains more complexity to convert the Vector back into the original form expected by the corresponding distribution, and its usage looks like

varinfo[varname, dist]

Since the dist is also stored in the Metadata, the above in fact works the same as varinfo[varname] if varinfo isa VarInfo rather than a SimpleVarInfo. But, as has been discussed many times before, this is not great because it doesn't properly handle dynamic constraints, etc.; we want to use the dist at the point of indexing, not the one from the construction of the varinfo.
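
For intuition, that dist-based conversion is conceptually something like the following (a rough sketch, not DynamicPPL's actual implementation; `unflatten` is a made-up helper name):

using Distributions

# Rough sketch of dist-based un-flattening of a value stored as a flat vector.
unflatten(::UnivariateDistribution, vals::AbstractVector{<:Real}) = only(vals)
unflatten(::MultivariateDistribution, vals::AbstractVector{<:Real}) = vals
unflatten(d::MatrixDistribution, vals::AbstractVector{<:Real}) = reshape(vals, size(d))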

Nonetheless, the value-storage part of Metadata arguably provides quite a nice way to store values in a dict-like manner while satisfying the three properties above.

So why not just use Metadata?

Well, we probably should. But if we do, we should probably simplify its structure quite a bit.

For example, should we drop the following fields?

  • dists: As mentioned, this is often not the correct thing to use.
  • gids: This is used by the Gibbs sampler, and will at some point not be of use anymore since we now have ways of programmatically conditioning and deconditioning models.
  • orders: Only used by particle methods to keep track of the number of observe statements hit. This should probably either be moved somewhere else or at least not be hardcoded into the "main" dict-like object.
  • flags: This might be generally useful, but the flags currently used (istrans and delete) are no longer that useful (istrans should be replaced by explicit transformations, as is done in SimpleVarInfo, and delete should no longer be needed since we now have a clear way of indicating whether we're running a model in "sampling mode" via SamplingContext).

But the problem with doing this is that we'll break a lot of code currently dependent on VarInfo functioning as is.
This is also the main reason why we introduced SimpleVarInfo: to allow us to create simpler and different representations of varinfos without breaking existing code.

So what should we do?

For now, it might be a good idea to just introduce a type very similar to Metadata but simpler in its form, i.e. mainly just a value container.

We could either implement our own, or we could see if there are existing implementations in the ecosystem that could benefit us; e.g. Dictionaries.jl seems like it might be suitable.

@torfjelde
Member Author

Dictionaries.jl will unfortunately not give us contiguous memory (at least not by default):

julia> using Dictionaries

julia> vnv = Dictionary{VarName}(OrderedDict(@varname(a) => [1.0,2,3], @varname(b) => [4,5,6.]))
2-element Dictionary{VarName, Vector{Float64}}
 a │ [1.0, 2.0, 3.0]
 b │ [4.0, 5.0, 6.0]

julia> vnv.values
2-element Vector{Vector{Float64}}:
 [1.0, 2.0, 3.0]
 [4.0, 5.0, 6.0]

@sunxd3
Member

sunxd3 commented Sep 1, 2023

How much does contiguous memory actually matter? I would imagine CPU cache fills in a lot of the performance gap.

@torfjelde
Member Author

Well, it depends. Generally speaking, I expect the access patterns here to be mainly views, and if you then reshape these into larger matrices that are used for downstream computations, we'd expect it to make a noticeable difference, no?
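
For example (a hypothetical illustration of the access pattern I have in mind, not DynamicPPL code):

# Contiguous flat storage: downstream code can get a matrix view without copying.
flat = rand(12)
X = reshape(view(flat, 1:12), 3, 4)

# Vector-of-vectors storage: building the same matrix requires a copy.
nested = [rand(3) for _ in 1:4]
Y = reduce(hcat, nested)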

But no matter what we do, we should benchmark these things at some point.

@torfjelde
Member Author

torfjelde commented Sep 1, 2023

And rolling our own shouldn't be too difficult.

For example, we could use something like

"""
    VarNameDict

A `VarNameDict` is a vector-like collection of values that can be indexed by `VarName`.

This is basically like a `OrderedDict{<:VarName}` but ensures that the underlying values
are stored contiguously in memory.
"""
struct VarNameDict{K,T,V<:AbstractVector{T},D<:AbstractDict{K}} <: AbstractDict{K,T}
    values::V
    varname_to_ranges::D
end

function VarNameDict(dict::OrderedDict)
    offset = 0
    ranges = map(values(dict)) do x
        r = (offset + 1):(offset + length(x))
        offset = r[end]
        r
    end
    vals = mapreduce(DynamicPPL.vectorize, vcat, values(dict))
    return VarNameDict(vals, OrderedDict(zip(keys(dict), ranges)))
end

# Dict-like functionality.
Base.keys(vnd::VarNameDict) = keys(vnd.varname_to_ranges)
Base.values(vnd::VarNameDict) = vnd.values
Base.length(vnd::VarNameDict) = length(vnd.values)

Base.getindex(vnd::VarNameDict, i) = getindex(vnd.values, i)
Base.setindex!(vnd::VarNameDict, val, i) = setindex!(vnd.values, val, i)

function nextrange(vnd::VarNameDict, x)
    n = length(vnd)
    return n + 1:n + length(x)
end

function Base.getindex(vnd::VarNameDict, vn::VarName)
    return getindex(vnd.values, vnd.varname_to_ranges[vn])
end
function Base.setindex!(vnd::VarNameDict, val, vn::VarName)
    # If we don't have `vn` in the dictionary, then we need to add it.
    if !haskey(vnd.varname_to_ranges, vn)
        # Set the range for the new variable.
        r = nextrange(vnd, val)
        vnd.varname_to_ranges[vn] = r
        # Resize the underlying vector to accommodate the new values.
        resize!(vnd.values, r[end])
    else
        # Existing keys needs to be handled differently depending on
        # whether the size of the value is increasing or decreasing.
        r = vnd.varname_to_ranges[vn]
        n_val = length(val)
        n_r = length(r)
        if n_val > n_r
            # Remove the old range.
            delete!(vnd.varname_to_ranges, vn)
            # Add the new range.
            r_new = nextrange(vnd, val)
            vnd.varname_to_ranges[vn] = r_new
            # Resize the underlying vector to accommodate the new values.
            resize!(vnd.values, r_new[end])
        elseif n_val < n_r
            # Just decrease the current range.
            vnd.varname_to_ranges[vn] = r[1]:r[1] + n_val - 1
        end

        # TODO: Keep track of unused ranges so we can perform sweeps
        # every now and then to free up memory and re-contiguize the
        # underlying vector.
    end

    return setindex!(vnd.values, val, vnd.varname_to_ranges[vn])
end

function BangBang.setindex!!(vnd::VarNameDict, val, vn::VarName)
    setindex!(vnd, val, vn)
    return vnd
end

function Base.iterate(vnd::VarNameDict, state=nothing)
    res = state === nothing ? iterate(vnd.varname_to_ranges) : iterate(vnd.varname_to_ranges, state)
    res === nothing && return nothing
    (vn, range), state_new = res
    return vn => vnd.values[range], state_new
end

Adding this to DPPL and exporting:

julia> using DynamicPPL

julia> model = DynamicPPL.TestUtils.demo_one_variable_multiple_constraints()
Model{typeof(DynamicPPL.TestUtils.demo_one_variable_multiple_constraints), (Symbol("##arg#375"),), (), (), Tuple{DataType}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_one_variable_multiple_constraints, (var"##arg#375" = 
Vector{Float64},), NamedTuple(), DefaultContext())

julia> x = rand(OrderedDict, model)
OrderedDict{Any, Any} with 4 entries:
  x[1]   => 1.18456
  x[2]   => 1.83516
  x[3]   => 0.726668
  x[4:5] => [0.191595, 0.808405]

julia> vnd = VarNameDict(x)
VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}} with 5 entries:
  x[1]   => [1.18456]
  x[2]   => [1.83516]
  x[3]   => [0.726668]
  x[4:5] => [0.191595, 0.808405]

julia> vnd[@varname(x[1])]
1-element Vector{Float64}:
 1.1845589710704487

julia> vnd[@varname(x[4:5])]
2-element Vector{Float64}:
 0.19159541239321576
 0.8084045876067844

julia> # Can create a `SimpleVarInfo` from a `VarNameDict`.
       vi = SimpleVarInfo(vnd)
SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [1.8351581568345814], x[3] => [0.726667648763101], x[4:5] => [0.19159541239321576, 0.8084045876067844]), 0.0)

julia> # Inherits from `AbstractDict`
       vi[@varname(x[1])]
1-element Vector{Float64}:
 1.1845589710704487

julia> vi[@varname(x[4:5])]
2-element Vector{Float64}:
 0.19159541239321576
 0.8084045876067844

julia> vi[@varname(x[4:5][1])]
0.19159541239321576

julia> vi_linked = link(vi, model)
Transformed SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [0.6071306668031453], x[3] => [-1.2135885975805698], x[4:5] => [-1.4396767388530345]), -3.354889980209639)

julia> vi_invlinked = invlink(vi_linked, model)
SimpleVarInfo(VarNameDict{Any, Float64, Vector{Float64}, OrderedDict{Any, UnitRange{Int64}}}(x[1] => [1.1845589710704487], x[2] => [1.8351581568345814], x[3] => [0.726667648763101], x[4:5] => [0.19159541239321573, 0.8084045876067842]), -3.5819390425447497)

Uncertain if it's really worth it to make it an AbstractDict 🤷

@yebai
Member

yebai commented Sep 1, 2023

I like the VarNameDict mechanism. It combines the strengths of NamedTuple and OrderedDict, allowing the more flexible Lens-like recursive indexing behaviour of a Dict while also keeping values in a contiguous vector-like container.

Related: #358 #416

EDIT: reserving a field for some metadata other than the values would be helpful. For example:

struct VarNameDict{K,T,M,V<:AbstractVector{T},D<:AbstractDict{K}} <: AbstractDict{K,T}
    metadata::M
    values::V
    varname_to_ranges::D
end

@yebai
Member

yebai commented Sep 1, 2023

I suggest that we rename VarNameDict to VarDict, but keep it a subtype of AbstractDict.

@yebai
Member

yebai commented Sep 1, 2023

I've successfully used it with ComponentVector from ComponentArrays.jl many times before.

Out of curiosity, what's the issue preventing us from using ComponentArrays as the default for SimpleVarInfo? That would save us from rolling out our home-baked VarNameDict/VarDict.

@torfjelde
Member Author

Out of curiosity, what's the issue preventing us from using ComponentArrays as the default for SimpleVarInfo? That would save us from rolling out our home-baked VarNameDict/VarDict.

Indexing using VarName, mainly.
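
For a rough illustration (assuming ComponentArrays.jl's symbol-based indexing; none of this is currently wired up in DynamicPPL):

using ComponentArrays

cv = ComponentVector(x = [1.0, 2.0, 3.0])
cv.x  # symbol-based access works fine
# But there's no built-in way to resolve `@varname(x[2:3])` against `cv`;
# we'd have to split the VarName into its symbol and index lens ourselves:
cv[:x][2:3]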

@torfjelde
Member Author

EDIT: reserving a field for some metadata other than the values would be helpful.

I also thought about this but didn't want to add it in the initial version. We could do this, yeah, but it might also be best to keep it separate in the varinfo? A bit uncertain.

@yebai
Member

yebai commented Sep 1, 2023

I also thought about this but didn't want to add it in the initial version. We could do this, yeah, but it might also be best to keep it separate in the varinfo? A bit uncertain.

Keeping it inside VarNameDict/VarDict feels more natural to me, where each key has its own value and metadata. But, yes, we can do this later.

@torfjelde
Member Author

Ah, one quite important thing we need from something like this VarDict or anything similar: a way to convert back to the original form of the variable.

In Metadata this role is covered by the distributions; if we don't want to keep those around, we need to keep some other information around 😕

@yebai
Member

yebai commented Sep 2, 2023

It's not too bad to save this "shape" information. I think it can be done transparently when adding new variables to VarDict. Shape information can be extracted by the vectorize function and then saved to the VarDict.

function BangBang.push!!(
    vi::SimpleVarInfo{<:VarNameDict},
    vn::VarName{sym},
    value,
    dist::Distribution,
    gidset::Set{Selector},
) where {sym}
    value_vectorized = vectorize(dist, value)
    vi.values[vn] = value_vectorized
    return vi
end
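
To make the inverse direction concrete, one option is to also record size(value) at this point. A hypothetical sketch (`varname_to_shapes` is an assumed extra field, not part of the VarNameDict sketch above):

# Record the original shape alongside the flattened value, so it can be
# reconstructed later without keeping the distribution around.
function push_with_shape!(vnd, vn::VarName, value::AbstractArray)
    vnd.varname_to_shapes[vn] = size(value)
    vnd[vn] = vec(value)  # flatten into the contiguous value vector
    return vnd
end

# Reconstruction then just reshapes the stored flat vector back to its original shape.
reconstruct_value(vnd, vn::VarName) = reshape(vnd[vn], vnd.varname_to_shapes[vn])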

@yebai
Member

yebai commented Dec 2, 2024

Fixed by #555

@yebai yebai closed this as completed Dec 2, 2024