diff --git a/docs/src/api.md b/docs/src/api.md index 6b9493985..ddd119816 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -243,6 +243,8 @@ DynamicPPL.StaticTransformation DynamicPPL.istrans DynamicPPL.settrans!! DynamicPPL.transformation +DynamicPPL.link +DynamicPPL.invlink DynamicPPL.link!! DynamicPPL.invlink!! DynamicPPL.default_transformation diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d1a21530b..1fd008ffe 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -60,8 +60,10 @@ export AbstractVarInfo, updategid!, setorder!, istrans, + link, link!, link!!, + invlink, invlink!, invlink!!, tonamedtuple, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c4cdda5a2..de1efe4c1 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -368,7 +368,8 @@ function settrans!! end link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) -Transforms the variables in `vi` to their linked space, using the transformation `t`. +Transform the variables in `vi` to their linked space, using the transformation `t`, +mutating `vi` if possible. If `t` is not provided, `default_transformation(model, vi)` will be used. @@ -383,12 +384,31 @@ function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) return link!!(default_transformation(model, vi), vi, spl, model) end +""" + link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + link([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transform the variables in `vi` to their linked space without mutating `vi`, using the transformation `t`. + +If `t` is not provided, `default_transformation(model, vi)` will be used. + +See also: [`default_transformation`](@ref), [`invlink`](@ref). 
+""" +link(vi::AbstractVarInfo, model::Model) = link(deepcopy(vi), SampleFromPrior(), model) +function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return link(t, deepcopy(vi), SampleFromPrior(), model) +end +function link(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + # Use `default_transformation` to decide which transformation to use if none is specified. + return link(default_transformation(model, vi), deepcopy(vi), spl, model) +end + """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) Transform the variables in `vi` to their constrained space, using the (inverse of) -transformation `t`. +transformation `t`, mutating `vi` if possible. If `t` is not provided, `default_transformation(model, vi)` will be used. @@ -434,6 +454,25 @@ function invlink!!( return settrans!!(vi_new, NoTransformation()) end +""" + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) + invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + +Transform the variables in `vi` to their constrained space without mutating `vi`, using the (inverse of) +transformation `t`. + +If `t` is not provided, `transformation(vi)` will be used. + +See also: [`default_transformation`](@ref), [`link`](@ref). 
+""" +invlink(vi::AbstractVarInfo, model::Model) = invlink(vi, SampleFromPrior(), model) +function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) + return invlink(t, vi, SampleFromPrior(), model) +end +function invlink(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) + return invlink(transformation(vi), vi, spl, model) +end + """ maybe_invlink_before_eval!!([t::Transformation,] vi, context, model) diff --git a/src/test_utils.jl b/src/test_utils.jl index 5028699f2..14da79afa 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -197,6 +197,115 @@ function logprior_true_with_logabsdet_jacobian( return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end +""" + demo_one_variable_multiple_constraints() + +A model with a single multivariate `x` whose components have multiple different constraints. + +# Model +```julia +x[1] ~ Normal() +x[2] ~ InverseGamma(2, 3) +x[3] ~ truncated(Normal(), -5, 20) +x[4:5] ~ Dirichlet([1.0, 2.0]) +``` + +""" +@model function demo_one_variable_multiple_constraints( + ::Type{TV}=Vector{Float64} +) where {TV} + x = TV(undef, 5) + x[1] ~ Normal() + x[2] ~ InverseGamma(2, 3) + x[3] ~ truncated(Normal(), -5, 20) + x[4:5] ~ Dirichlet([1.0, 2.0]) + + return (x=x,) +end + +function logprior_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x) + return ( + logpdf(Normal(), x[1]) + + logpdf(InverseGamma(2, 3), x[2]) + + logpdf(truncated(Normal(), -5, 20), x[3]) + + logpdf(Dirichlet([1.0, 2.0]), x[4:5]) + ) +end +function loglikelihood_true(model::Model{typeof(demo_one_variable_multiple_constraints)}, x) + return zero(float(eltype(x))) +end +function varnames(model::Model{typeof(demo_one_variable_multiple_constraints)}) + return [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(x[4:5])] +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_one_variable_multiple_constraints)}, x +) + b_x2 = Bijectors.bijector(InverseGamma(2, 3)) + b_x3 = 
Bijectors.bijector(truncated(Normal(), -5, 20)) + b_x4 = Bijectors.bijector(Dirichlet([1.0, 2.0])) + x_unconstrained = vcat(x[1], b_x2(x[2]), b_x3(x[3]), b_x4(x[4:5])) + Δlogp = ( + Bijectors.logabsdetjac(b_x2, x[2]) + + Bijectors.logabsdetjac(b_x3, x[3]) + + Bijectors.logabsdetjac(b_x4, x[4:5]) + ) + return (x=x_unconstrained,), logprior_true(model, x) - Δlogp +end + +function Random.rand( + rng::Random.AbstractRNG, + ::Type{NamedTuple}, + model::Model{typeof(demo_one_variable_multiple_constraints)}, +) + x = Vector{Float64}(undef, 5) + x[1] = rand(rng, Normal()) + x[2] = rand(rng, InverseGamma(2, 3)) + x[3] = rand(rng, truncated(Normal(), -5, 20)) + x[4:5] = rand(rng, Dirichlet([1.0, 2.0])) + return (x=x,) +end + +""" + demo_lkjchol(d=2) + +A model with a single variable `x` with support on the Cholesky factor of a +LKJ distribution. + +# Model +```julia +x ~ LKJCholesky(d, 1.0) +``` +""" +@model function demo_lkjchol(d::Int=2) + x ~ LKJCholesky(d, 1.0) + return (x=x,) +end + +function logprior_true(model::Model{typeof(demo_lkjchol)}, x) + return logpdf(LKJCholesky(model.args.d, 1.0), x) +end + +function loglikelihood_true(model::Model{typeof(demo_lkjchol)}, x) + return zero(float(eltype(x))) +end + +function varnames(model::Model{typeof(demo_lkjchol)}) + return [@varname(x)] +end + +function logprior_true_with_logabsdet_jacobian(model::Model{typeof(demo_lkjchol)}, x) + b_x = Bijectors.bijector(LKJCholesky(model.args.d, 1.0)) + x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x) + return (x=x_unconstrained,), logprior_true(model, x) - Δlogp +end + +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::Model{typeof(demo_lkjchol)} +) + x = rand(rng, LKJCholesky(model.args.d, 1.0)) + return (x=x,) +end + # A collection of models for which the posterior should be "similar". # Some utility methods for these. 
function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index dc8720e0a..f7ab3fa85 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -93,6 +93,18 @@ function invlink!!( return invlink!!(t, vi.varinfo, spl, model) end +function link( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return link(t, vi.varinfo, spl, model) +end + +function invlink( + t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model +) + return invlink(t, vi.varinfo, spl, model) +end + function maybe_invlink_before_eval!!( vi::ThreadSafeVarInfo, context::AbstractContext, model::Model ) diff --git a/src/transforming.jl b/src/transforming.jl index a544a814b..41c877c91 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -94,3 +94,15 @@ function invlink!!( NoTransformation(), ) end + +function link( + t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + return link!!(t, deepcopy(vi), spl, model) +end + +function invlink( + t::DynamicTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model +) + return invlink!!(t, deepcopy(vi), spl, model) +end diff --git a/src/utils.jl b/src/utils.jl index a1fb12788..ec11675ac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -501,6 +501,8 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end +# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 +# and https://github.com/JuliaFolds/BangBang.jl/pull/238. # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. function BangBang.possible( ::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer @@ -514,6 +516,23 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end +# HACK: Makes it possible to use ranges, etc. for setting a vector. 
+# For example, without this hack, BangBang.jl will consider +# +# x[1:2] = [1, 2] +# +# as NOT supported. This results in calling the immutable +# `BangBang.setindex` instead, which also ends up expanding the +# type of the containing array (`x` in the above scenario) to +# have element type `Any`. +# The below code just, correctly, marks this as possible and +# thus we hit the mutable `setindex!` instead. +function BangBang.possible( + ::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer} +) where {C<:AbstractVector,T<:AbstractVector} + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) +end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. diff --git a/src/varinfo.jl b/src/varinfo.jl index 60b6e93c1..fbe3f6088 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -363,7 +363,11 @@ Set the values of all the variables in `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val +function setall!(vi::UntypedVarInfo, val) + for r in vi.metadata.ranges + vi.metadata.vals[r] .= val[r] + end +end setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val) where {names} expr = Expr(:block) @@ -885,7 +889,6 @@ end end function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) - @debug "X -> ℝ for $(vn)..." 
# TODO: Use inplace versions to avoid allocations y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, getval(vi, vn)) yvec = vectorize(dist, y) @@ -899,6 +902,142 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) return vi end +function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) + return _link(varinfo) +end + +function _link(varinfo::UntypedVarInfo) + varinfo = deepcopy(varinfo) + return VarInfo( + _link_metadata!(varinfo, varinfo.metadata), + Base.Ref(getlogp(varinfo)), + Ref(get_num_produce(varinfo)), + ) +end + +function _link(varinfo::TypedVarInfo) + varinfo = deepcopy(varinfo) + md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) + # TODO: Update logp, etc. + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) +end + +function _link_metadata!(varinfo::VarInfo, metadata::Metadata) + vns = metadata.vns + + # Construct the new transformed values, and keep track of their lengths. + vals_new = map(vns) do vn + # Return early if we're already in unconstrained space. + if istrans(varinfo, vn) + return metadata.vals[getrange(metadata, vn)] + end + + # Transform to unconstrained space. + x = getval(varinfo, vn) + dist = getdist(varinfo, vn) + f = link_transform(dist) + y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, x) + # Vectorize value. + yvec = vectorize(dist, y) + # Accumulate the log-abs-det jacobian correction. + acclogp!!(varinfo, -logjac) + # Mark as transformed. + settrans!!(varinfo, true, vn) + # Return the vectorized transformed value. + return yvec + end + + # Determine new ranges. + ranges_new = similar(metadata.ranges) + offset = 0 + for (i, v) in enumerate(vals_new) + r_start, r_end = offset + 1, length(v) + offset + offset = r_end + ranges_new[i] = r_start:r_end + end + + # Now we just create a new metadata with the new `vals` and `ranges`. 
+ return Metadata( + metadata.idcs, + metadata.vns, + ranges_new, + reduce(vcat, vals_new), + metadata.dists, + metadata.gids, + metadata.orders, + metadata.flags, + ) +end + +function invlink( + ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model +) + return _invlink(varinfo) +end + +function _invlink(varinfo::UntypedVarInfo) + varinfo = deepcopy(varinfo) + return VarInfo( + _invlink_metadata!(varinfo, varinfo.metadata), + Base.Ref(getlogp(varinfo)), + Ref(get_num_produce(varinfo)), + ) +end + +function _invlink(varinfo::TypedVarInfo) + varinfo = deepcopy(varinfo) + md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) + # TODO: Update logp, etc. + return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) +end + +function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) + vns = metadata.vns + + # Construct the new transformed values, and keep track of their lengths. + vals_new = map(vns) do vn + # Return early if we're already in constrained space. + if !istrans(varinfo, vn) + return metadata.vals[getrange(metadata, vn)] + end + + # Transform to constrained space. + y = getval(varinfo, vn) + dist = getdist(varinfo, vn) + f = invlink_transform(dist) + x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y) + # Vectorize value. + xvec = vectorize(dist, x) + # Accumulate the log-abs-det jacobian correction. + acclogp!!(varinfo, -logjac) + # Mark as no longer transformed. + settrans!!(varinfo, false, vn) + # Return the vectorized transformed value. + return xvec + end + + # Determine new ranges. + ranges_new = similar(metadata.ranges) + offset = 0 + for (i, v) in enumerate(vals_new) + r_start, r_end = offset + 1, length(v) + offset + offset = r_end + ranges_new[i] = r_start:r_end + end + + # Now we just create a new metadata with the new `vals` and `ranges`. 
+ return Metadata( + metadata.idcs, + metadata.vns, + ranges_new, + reduce(vcat, vals_new), + metadata.dists, + metadata.gids, + metadata.orders, + metadata.flags, + ) +end + """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) diff --git a/test/linking.jl b/test/linking.jl index c9c0c318f..493a0d2b0 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -57,7 +57,7 @@ function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Boo return lp end -@testset "Linking" begin +@testset "Linking (mutable=$mutable)" for mutable in [false, true] @testset "simple matrix distribution" begin # Just making sure the transformations are okay. x = randn(3, 3) @@ -76,7 +76,11 @@ end @testset "$(short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + vi_linked = if mutable + DynamicPPL.link!!(deepcopy(vi), model) + else + DynamicPPL.link(vi, model) + end # Difference should just be the log-absdet-jacobian "correction". @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) @@ -84,7 +88,11 @@ end @test length(vi_linked[:]) < length(vi[:]) @test length(vi_linked[:]) == length(y) # Invlinked. - vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + vi_invlinked = if mutable + DynamicPPL.invlink!!(deepcopy(vi_linked), model) + else + DynamicPPL.invlink(vi_linked, model) + end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) @@ -112,12 +120,20 @@ end lp_model = logjoint(model, vi) @test lp_model ≈ lp # Linked. 
- vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + vi_linked = if mutable + DynamicPPL.link!!(deepcopy(vi), model) + else + DynamicPPL.link(vi, model) + end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. @test !(getlogp(vi_linked) ≈ lp) # Invlinked. - vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + vi_invlinked = if mutable + DynamicPPL.invlink!!(deepcopy(vi_linked), model) + else + DynamicPPL.invlink(vi_linked, model) + end @test length(vi_invlinked[:]) == d^2 @test getlogp(vi_invlinked) ≈ lp end @@ -137,12 +153,20 @@ end lp_model = logjoint(model, vi) @test lp_model ≈ lp # Linked. - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) + vi_linked = if mutable + DynamicPPL.link!!(deepcopy(vi), model) + else + DynamicPPL.link(vi, model) + end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. @test !(getlogp(vi_linked) ≈ lp) # Invlinked. - vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) + vi_invlinked = if mutable + DynamicPPL.invlink!!(deepcopy(vi_linked), model) + else + DynamicPPL.invlink(vi_linked, model) + end @test length(vi_invlinked[:]) == d @test getlogp(vi_invlinked) ≈ lp end diff --git a/test/varinfo.jl b/test/varinfo.jl index 35ab30dcd..598ea7814 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -356,4 +356,69 @@ end end end + + @testset "unflatten + linking" begin + @testset "Model: $(model.f)" for model in [ + DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), + DynamicPPL.TestUtils.demo_lkjchol(), + ] + @testset "mutating=$mutating" for mutating in [false, true] + value_true = rand(model) + varnames = DynamicPPL.TestUtils.varnames(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + if varinfo isa SimpleVarInfo{<:NamedTuple} + # NOTE: this is broken since we'll end up trying to set + # + # 
varinfo[@varname(x[4:5])] = [x[4],] + # + # upon linking (since `x[4:5]` will be projected onto a 1-dimensional + # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in + # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which + # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, + # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). + @test_broken false + continue + end + + # Evaluate the model once to update the logp of the varinfo. + varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) + + varinfo_linked = if mutating + DynamicPPL.link!!(deepcopy(varinfo), model) + else + DynamicPPL.link(varinfo, model) + end + @test length(varinfo[:]) > length(varinfo_linked[:]) + varinfo_linked_unflattened = DynamicPPL.unflatten( + varinfo_linked, varinfo_linked[:] + ) + @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) + + lp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) + value_linked_true, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, value_true... + ) + + lp = logjoint(model, varinfo) + @test lp ≈ lp_true + @test getlogp(varinfo) ≈ lp_true + lp_linked = getlogp(varinfo_linked) + @test lp_linked ≈ lp_linked_true + + # TODO: Compare values once we are no longer working with `NamedTuple` for + # the true values, e.g. `value_true`. + + if !mutating + # This is also compatible with invlinking of unflattened varinfo. + varinfo_invlinked = DynamicPPL.invlink( + varinfo_linked_unflattened, model + ) + @test length(varinfo_invlinked[:]) == length(varinfo[:]) + @test getlogp(varinfo_invlinked) ≈ lp_true + end + end + end + end + end end