diff --git a/Project.toml b/Project.toml index 162448316..e9babbaa6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.20.0" +version = "0.20.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/contexts.jl b/src/contexts.jl index 23a6128e1..bd8acf278 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -272,28 +272,18 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end -struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext +struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx - - function ConditionContext{Values}( - values::Values, context::AbstractContext - ) where {names,Values<:NamedTuple{names}} - return new{names,typeof(values),typeof(context)}(values, context) - end end -function ConditionContext(values::NamedTuple) - return ConditionContext(values, DefaultContext()) -end -function ConditionContext(values::NamedTuple, context::AbstractContext) - return ConditionContext{typeof(values)}(values, context) -end +const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} +const DictConditionContext = ConditionContext{<:AbstractDict} + +ConditionContext(values) = ConditionContext(values, DefaultContext()) # Try to avoid nested `ConditionContext`. -function ConditionContext( - values::NamedTuple{Names}, context::ConditionContext -) where {Names} +function ConditionContext(values::NamedTuple, context::NamedConditionContext) # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), childcontext(context)) @@ -303,7 +293,7 @@ function Base.show(io::IO, context::ConditionContext) return print(io, "ConditionContext($(context.values), $(childcontext(context)))") end -NodeTrait(context::ConditionContext) = IsParent() +NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) @@ -313,14 +303,9 @@ setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.value Return `true` if `vn` is found in `context`. """ hasvalue(context, vn) = false - -function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} - return sym in vars -end -function hasvalue( - context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} -) where {vars,sym} - return sym in vars +hasvalue(context::ConditionContext, vn::VarName) = nested_haskey(context.values, vn) +function hasvalue(context::ConditionContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(nested_haskey, context.values), vns) end """ @@ -331,7 +316,8 @@ Return value of `vn` in `context`. function getvalue(context::AbstractContext, vn) return error("context $(context) does not contain value for $vn") end -getvalue(context::ConditionContext, vn) = get(context.values, vn) +getvalue(context::NamedConditionContext, vn) = get(context.values, vn) +getvalue(context::ConditionContext, vn) = nested_getindex(context.values, vn) """ hasvalue_nested(context, vn) @@ -386,15 +372,33 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -AbstractPPL.condition(; values...) = condition(DefaultContext(), NamedTuple(values)) +AbstractPPL.condition(; values...) = condition(NamedTuple(values)) AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values) +function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...) + return condition((value, values...)) +end +function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}}) + return condition(DefaultContext(), values) +end AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context -function AbstractPPL.condition(context::AbstractContext, values::NamedTuple) +function AbstractPPL.condition( + context::AbstractContext, values::Union{AbstractDict,NamedTuple} +) return ConditionContext(values, context) end function AbstractPPL.condition(context::AbstractContext; values...) return condition(context, NamedTuple(values)) end +function AbstractPPL.condition( + context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}... +) + return condition(context, (value, values...)) +end +function AbstractPPL.condition( + context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}} +) + return condition(context, Dict(values)) +end """ decondition(context::AbstractContext, syms...) @@ -430,6 +434,19 @@ function AbstractPPL.decondition(context::ConditionContext, sym, syms...) ) end +function AbstractPPL.decondition( + context::NamedConditionContext, vn::VarName{sym} +) where {sym} + return condition( + decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym) + ) +end +function AbstractPPL.decondition(context::ConditionContext, vn::VarName) + return condition( + decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn) + ) +end + """ conditioned(context::AbstractContext) diff --git a/src/model.jl b/src/model.jl index f7dc4b113..b9974254d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -109,7 +109,7 @@ This is done for the sake of backwards compatibility. # Examples ## Simple univariate model ```jldoctest condition -julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. +julia> using Distributions julia> @model function demo() m ~ Normal() @@ -120,27 +120,48 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> model(rng) -(m = -0.6702516921145671, x = -0.22312984965118443) +julia> m, x = model(); (m ≠ 1.0 && x ≠ 100.0) +true julia> # Create a new instance which treats `x` as observed # with value `100.0`, and similarly for `m=1.0`. conditioned_model = condition(model, x=100.0, m=1.0); -julia> conditioned_model(rng) -(m = 1.0, x = 100.0) +julia> m, x = conditioned_model(); (m == 1.0 && x == 100.0) +true julia> # Let's only condition on `x = 100.0`. conditioned_model = condition(model, x = 100.0); -julia> conditioned_model(rng) -(m = 1.3736306979834252, x = 100.0) +julia> m, x =conditioned_model(); (m ≠ 1.0 && x == 100.0) +true julia> # We can also use the nicer `|` syntax. conditioned_model = model | (x = 100.0, ); -julia> conditioned_model(rng) -(m = 1.3095394956381083, x = 100.0) +julia> m, x = conditioned_model(); (m ≠ 1.0 && x == 100.0) +true +``` + +The above uses a `NamedTuple` to hold the conditioning variables, which allows us to perform some +additional optimizations; in many cases, the above has zero runtime-overhead. + +But we can also use a `Dict`, which offers more flexibility in the conditioning +(see examples further below) but generally has worse performance than the `NamedTuple` +approach: + +```jldoctest condition +julia> conditioned_model_dict = condition(model, Dict(@varname(x) => 100.0)); + +julia> m, x = conditioned_model_dict(); (m ≠ 1.0 && x == 100.0) +true + +julia> # There's also an option using `|` by letting the right-hand side be a tuple + # with elements of type `Pair{<:VarName}`, i.e. `vn => value` with `vn isa VarName`. + conditioned_model_dict = model | (@varname(x) => 100.0, ); + +julia> m, x = conditioned_model_dict(); (m ≠ 1.0 && x == 100.0) +true ``` ## Condition only a part of a multivariate variable @@ -162,23 +183,31 @@ julia> model = demo_mv(); julia> conditioned_model = condition(model, m = [missing, 1.0]); -julia> conditioned_model(rng) # (✓) `m[1]` sampled, `m[2]` is fixed -2-element Vector{Float64}: - 0.12607002180931043 - 1.0 +julia> # (✓) `m[1]` sampled while `m[2]` is fixed + m = conditioned_model(); (m[1] ≠ 1.0 && m[2] == 1.0) +true ``` -Intuitively one might also expect to be able to write `model | (x[1] = 1.0, )`. -Unfortunately this is not supported due to performance. +Intuitively one might also expect to be able to write `model | (m[1] = 1.0, )`. +Unfortunately this is not supported as it has the potential of increasing compilation +times but without offering any benefit with respect to runtime: ```jldoctest condition -julia> condition(model, var"x[2]" = 1.0)(rng) # (×) `x[2]` is not set to 1.0. -2-element Vector{Float64}: - 0.683947930996541 - -1.019202452456547 +julia> # (×) `m[2]` is not set to 1.0. + m = condition(model, var"m[2]" = 1.0)(); m[2] == 1.0 +false ``` -We will likely provide some syntactic sugar for this in the future. +But you _can_ do this if you use a `Dict` as the underlying storage instead: + +```jldoctest condition +julia> # Alternatives: + # - `model | (@varname(m[2]) => 1.0,)` + # - `condition(model, Dict(@varname(m[2] => 1.0)))` + # (✓) `m[2]` is set to 1.0. + m = condition(model, @varname(m[2]) => 1.0)(); (m[1] ≠ 1.0 && m[2] == 1.0) +true +``` ## Nested models @@ -197,12 +226,12 @@ demo_outer (generic function with 2 methods) julia> model = demo_outer(); -julia> model(rng) --0.7935128416361353 +julia> model() ≠ 1.0 +true julia> conditioned_model = model | (m = 1.0, ); -julia> conditioned_model(rng) +julia> conditioned_model() 1.0 ``` @@ -215,16 +244,16 @@ julia> @model function demo_outer_prefix() end demo_outer_prefix (generic function with 2 methods) -julia> # This doesn't work now! +julia> # (×) This doesn't work now! conditioned_model = demo_outer_prefix() | (m = 1.0, ); -julia> conditioned_model(rng) -1.7747246334368165 +julia> conditioned_model() == 1.0 +false -julia> # `m` in `demo_inner` is referred to as `inner.m` internally, so we do: +julia> # (✓) `m` in `demo_inner` is referred to as `inner.m` internally, so we do: conditioned_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); -julia> conditioned_model(rng) +julia> conditioned_model() 1.0 julia> # Note that the above `var"..."` is just standard Julia syntax: @@ -232,6 +261,15 @@ julia> # Note that the above `var"..."` is just standard Julia syntax: (Symbol("inner.m"),) ``` +And similarly when using `Dict`: + +```jldoctest condition +julia> conditioned_model_dict = demo_outer_prefix() | (@varname(var"inner.m") => 1.0); + +julia> conditioned_model_dict() +1.0 +``` + The difference is maybe more obvious once we look at how these different in their trace/`VarInfo`: @@ -250,24 +288,27 @@ is in the two different models. """ AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values)) -function AbstractPPL.condition(model::Model, values) - return contextualize(model, condition(model.context, values)) +function AbstractPPL.condition(model::Model, value, values...) + return contextualize(model, condition(model.context, value, values...)) end """ decondition(model::Model) - decondition(model::Model, syms...) + decondition(model::Model, variables...) -Return a `Model` for which `syms...` are _not_ considered observations. -If no `syms` are provided, then all variables currently considered observations +Return a `Model` for which `variables...` are _not_ considered observations. +If no `variables` are provided, then all variables currently considered observations will no longer be. This is essentially the inverse of [`condition`](@ref). This also means that it suffers from the same limitiations. +Note that currently we only support `variables` to take on explicit values +provided to `condition. + # Examples -```jldoctest -julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. +```jldoctest decondition +julia> using Distributions julia> @model function demo() m ~ Normal() @@ -278,29 +319,88 @@ demo (generic function with 2 methods) julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0); -julia> conditioned_model(rng) +julia> conditioned_model() (m = 1.0, x = 10.0) -julia> model = decondition(conditioned_model, :m); +julia> # By specifying the `VarName` to `decondition`. + model = decondition(conditioned_model, @varname(m)); + +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true -julia> model(rng) -(m = -0.6702516921145671, x = 10.0) +julia> # When `NamedTuple` is used as the underlying, you can also provide + # the symbol directly (though the `@varname` approach is preferable if + # if the variable is known at compile-time). + model = decondition(conditioned_model, :m); + +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true julia> # `decondition` multiple at once: - decondition(model, :m, :x)(rng) -(m = 0.4471218424633827, x = 1.820752540446808) + (m, x) = decondition(model, :m, :x)(); (m ≠ 1.0 && x ≠ 10.0) +true julia> # `decondition` without any symbols will `decondition` all variables. - decondition(model)(rng) -(m = 1.3095394956381083, x = 1.4356095174474188) + (m, x) = decondition(model)(); (m ≠ 1.0 && x ≠ 10.0) +true julia> # Usage of `Val` to perform `decondition` at compile-time if possible # is also supported. model = decondition(conditioned_model, Val{:m}()); -julia> model(rng) -(m = 0.683947930996541, x = 10.0) +julia> (m, x) = model(); (m ≠ 1.0 && x == 10.0) +true +``` + +Similarly when using a `Dict`: + +```jldoctest decondition +julia> conditioned_model_dict = condition(demo(), @varname(m) => 1.0, @varname(x) => 10.0); + +julia> conditioned_model_dict() +(m = 1.0, x = 10.0) + +julia> deconditioned_model_dict = decondition(conditioned_model_dict, @varname(m)); + +julia> (m, x) = deconditioned_model_dict(); m ≠ 1.0 && x == 10.0 +true +``` + +But, as mentioned, `decondition` is only supported for variables explicitly +provided to `condition` earlier; + +```jldoctest decondition +julia> @model function demo_mv(::Type{TV}=Float64) where {TV} + m = Vector{TV}(undef, 2) + m[1] ~ Normal() + m[2] ~ Normal() + return m + end +demo_mv (generic function with 3 methods) + +julia> model = demo_mv(); + +julia> conditioned_model = condition(model, @varname(m) => [1.0, 2.0]); + +julia> conditioned_model() +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> deconditioned_model = decondition(conditioned_model, @varname(m[1])); + +julia> deconditioned_model() # (×) `m[1]` is still conditioned +2-element Vector{Float64}: + 1.0 + 2.0 + +julia> # (✓) this works though + deconditioned_model_2 = deconditioned_model | (@varname(m[1]) => missing); + +julia> m = deconditioned_model_2(); (m[1] ≠ 1.0 && m[2] == 2.0) +true ``` + """ function AbstractPPL.decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5b9edefdf..b243234bb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -276,28 +276,7 @@ Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `Dict` function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) - if haskey(vi.values, vn) - return vi.values[vn] - end - - # Split the lens into the key / `parent` and the extraction lens / `child`. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(vi.values, VarName(vn, l)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent - - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end - - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? - value = vi.values[VarName(vn, keylens)] - return get(value, child) + return nested_getindex(vi.values, vn) end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than @@ -327,38 +306,7 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut return reconstruct(dist, vals, length(vns)) end -Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) -function _haskey(nt::NamedTuple, vn::VarName) - # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the lens can view into `nt`. - sym = getsym(vn) - return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) -end - -# For `dictlike` we need to check wether `vn` is "immediately" present, or -# if some ancestor of `vn` is present in `dictlike`. -function _haskey(dict::AbstractDict, vn::VarName) - # First we check if `vn` is present as is. - haskey(dict, vn) && return true - - # If `vn` is not present, we check any parent-varnames by attempting - # to split the lens into the key / `parent` and the extraction lens / `child`. - # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitlens(getlens(vn)) do lens - l = lens === nothing ? Setfield.IdentityLens() : lens - haskey(dict, VarName(vn, l)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. - keylens = parent === nothing ? Setfield.IdentityLens() : parent - - # Return early if no such split could be found. - issuccess || return false - - # At this point we just need to check that we `canview` the value. - value = dict[VarName(vn, keylens)] - - return canview(child, value) -end +Base.haskey(vi::SimpleVarInfo, vn::VarName) = nested_haskey(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. diff --git a/src/utils.jl b/src/utils.jl index ac9222818..917e1d71b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -426,3 +426,123 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end + +""" + nested_getindex(values::AbstractDict, vn::VarName) + +Return value corresponding to `vn` in `values` by also looking +in the the actual values of the dict. + +# Examples + +```jldoctest +julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x)) # same as `getindex` +1-element Vector{Float64}: + 1.0 + +julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[1])) # different from `getindex` +1.0 + +julia> DynamicPPL.nested_getindex(Dict(@varname(x) => [1.0]), @varname(x[2])) +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] +``` +""" +function nested_getindex(values::AbstractDict, vn::VarName) + maybeval = get(values, vn, nothing) + if maybeval !== nothing + return maybeval + end + + # Split the lens into the key / `parent` and the extraction lens / `child`. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(values, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # If we found a valid split, then we can extract the value. + if !issuccess + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) + end + + # TODO: Should we also check that we `canview` the extracted `value` + # rather than just let it fail upon `get` call? + value = values[VarName(vn, keylens)] + return get(value, child) +end + +""" + nested_haskey(x, vn::VarName) + +Determine whether `x` has a mapping for a given `vn`. + +# Examples +With `x` as a `NamedTuple`: +```jldoctest +julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x)) +true + +julia> DynamicPPL.nested_haskey((x = 1.0, ), @varname(x[1])) +false + +julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x)) +true + +julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[1])) +true + +julia> DynamicPPL.nested_haskey((x = [1.0],), @varname(x[2])) +false +``` + +With `x` as a `AbstractDict`: +```jldoctest +julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x)) +true + +julia> DynamicPPL.nested_haskey(Dict(@varname(x) => 1.0, ), @varname(x[1])) +false + +julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x)) +true + +julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[1])) +true + +julia> DynamicPPL.nested_haskey(Dict(@varname(x) => [1.0]), @varname(x[2])) +false +``` +""" +function nested_haskey(nt::NamedTuple, vn::VarName{sym}) where {sym} + # LHS: Ensure that `nt` indeed has the property we want. + # RHS: Ensure that the lens can view into `nt`. + return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) +end + +# For `dictlike` we need to check wether `vn` is "immediately" present, or +# if some ancestor of `vn` is present in `dictlike`. +function nested_haskey(dict::AbstractDict, vn::VarName) + # First we check if `vn` is present as is. + haskey(dict, vn) && return true + + # If `vn` is not present, we check any parent-varnames by attempting + # to split the lens into the key / `parent` and the extraction lens / `child`. + # If `issuccess` is `true`, we found such a split, and hence `vn` is present. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(dict, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # Return early if no such split could be found. + issuccess || return false + + # At this point we just need to check that we `canview` the value. + value = dict[VarName(vn, keylens)] + + return canview(child, value) +end