From 0418307da435aa1a5afa40fc0eef71f56fa9a97a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 15 Sep 2024 15:52:18 +0200 Subject: [PATCH 01/28] Require Julia v1.10 --- .github/workflows/CI.yml | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f0e0e023..6e12beea 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1.10' - '1' - 'pre' os: diff --git a/Project.toml b/Project.toml index f005cb3d..4856f5cb 100644 --- a/Project.toml +++ b/Project.toml @@ -54,4 +54,4 @@ Static = "0.8, 1" Statistics = "1" Test = "1" Tricks = "0.1" -julia = "1.6" +julia = "1.10" From 3c8c393a8e23c7570d64ae2a9afca9d33becd817 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 15 Sep 2024 16:04:29 +0200 Subject: [PATCH 02/28] Move ChainRulesCore support to extension --- Project.toml | 7 +++++- ext/MeasureBaseChainRulesCoreExt.jl | 37 +++++++++++++++++++++++++++++ src/MeasureBase.jl | 1 - src/density-core.jl | 8 ------- src/getdof.jl | 4 +--- src/insupport.jl | 5 ---- src/transport.jl | 3 --- 7 files changed, 44 insertions(+), 21 deletions(-) create mode 100644 ext/MeasureBaseChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index 4856f5cb..9e3b74ef 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Chad Scherrer ", "Oliver Schulz Date: Sun, 15 Sep 2024 16:04:37 +0200 Subject: [PATCH 03/28] Add ForwardDiff extension --- Project.toml | 3 +++ ext/MeasureBaseForwardDiffExt.jl | 14 ++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 ext/MeasureBaseForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index 9e3b74ef..058dea9d 100644 --- a/Project.toml +++ b/Project.toml @@ -30,9 +30,11 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [extensions] MeasureBaseChainRulesCoreExt = "ChainRulesCore" +MeasureBaseForwardDiffExt = "ForwardDiff" [compat] ChainRulesCore = "1" @@ -41,6 +43,7 @@ Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" FillArrays = "0.12, 0.13, 1" +ForwardDiff = "0.8, 0.9, 0.10" FunctionChains = "0.1" IfElse = "0.1" IntervalSets = "0.7" diff --git a/ext/MeasureBaseForwardDiffExt.jl b/ext/MeasureBaseForwardDiffExt.jl new file mode 100644 index 00000000..8a1cab44 --- /dev/null +++ b/ext/MeasureBaseForwardDiffExt.jl @@ -0,0 +1,14 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +module MeasureBaseForwardDiffExt + +using MeasureBase +import ForwardDiff + +function MeasureBase.containsnan(x::ForwardDiff.Dual) + a = containsnan(x.value) + b = containsnan(x.partials) + return a || b +end + +end # module MeasureBaseForwardDiffExt From 3d45f64cae0954ae4bb85e80443001d13c2bb19c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 15 Sep 2024 16:16:44 +0200 Subject: [PATCH 04/28] Add Distributions and DistributionsForwardDiff extensions --- Project.toml | 7 ++++++- ext/MeasureBaseDistributionsExt.jl | 8 ++++++++ ext/MeasureBaseDistributionsForwardDiffExt.jl | 9 +++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 ext/MeasureBaseDistributionsExt.jl create mode 100644 ext/MeasureBaseDistributionsForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index 058dea9d..fef0dc97 100644 --- a/Project.toml +++ b/Project.toml @@ -30,10 +30,13 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [extensions] MeasureBaseChainRulesCoreExt = "ChainRulesCore" +MeasureBaseDistributionsExt = "Distributions" +MeasureBaseDistributionsForwardDiffExt = ["Distributions", "ForwardDiff"] MeasureBaseForwardDiffExt = "ForwardDiff" [compat] @@ -42,6 +45,8 @@ ChangesOfVariables = "0.1.3" Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" +Distributions = "0.25.1" +Distributions = "0.25.111" FillArrays = "0.12, 0.13, 1" ForwardDiff = "0.8, 0.9, 0.10" FunctionChains = "0.1" @@ -50,8 +55,8 @@ IntervalSets = "0.7" InverseFunctions = "0.1.8" IrrationalConstants = "0.1, 0.2" LinearAlgebra = "1" -LogExpFunctions = "0.3" LogarithmicNumbers = "1" +LogExpFunctions = "0.3" MappedArrays = "0.4" NaNMath = "0.3, 1" PrettyPrinting = "0.3, 0.4" diff --git a/ext/MeasureBaseDistributionsExt.jl b/ext/MeasureBaseDistributionsExt.jl new file mode 100644 index 00000000..beb47821 --- /dev/null +++ b/ext/MeasureBaseDistributionsExt.jl @@ -0,0 +1,8 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +module MeasureBaseDistributionsExt + +using MeasureBase +import Distributions + +end # module MeasureBaseDistributionsExt diff --git a/ext/MeasureBaseDistributionsForwardDiffExt.jl b/ext/MeasureBaseDistributionsForwardDiffExt.jl new file mode 100644 index 00000000..36218eec --- /dev/null +++ b/ext/MeasureBaseDistributionsForwardDiffExt.jl @@ -0,0 +1,9 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +module MeasureBaseDistributionsForwardDiffExt + +using MeasureBase +import Distributions +import ForwardDiff + +end # module MeasureBaseDistributionsForwardDiffExt From aad3dae9ef8b4da8588259d421e0bbb85a7caf75 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 15 Sep 2024 16:21:51 +0200 Subject: [PATCH 05/28] Add DistributionsChainRulesCore extension --- Project.toml | 1 + ext/MeasureBaseDistributionsChainRulesCoreExt.jl | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 ext/MeasureBaseDistributionsChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index fef0dc97..a1c38ff6 100644 --- a/Project.toml +++ b/Project.toml @@ -36,6 +36,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" [extensions] MeasureBaseChainRulesCoreExt = "ChainRulesCore" MeasureBaseDistributionsExt = "Distributions" +MeasureBaseDistributionsChainRulesCoreExt = ["Distributions", "ChainRulesCore"] MeasureBaseDistributionsForwardDiffExt = ["Distributions", "ForwardDiff"] MeasureBaseForwardDiffExt = "ForwardDiff" diff --git a/ext/MeasureBaseDistributionsChainRulesCoreExt.jl b/ext/MeasureBaseDistributionsChainRulesCoreExt.jl new file mode 100644 index 00000000..4dd3f4ff --- /dev/null +++ b/ext/MeasureBaseDistributionsChainRulesCoreExt.jl @@ -0,0 +1,9 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +module MeasureBaseDistributionsChainRulesCoreExt + +using MeasureBase +import Distributions +import ChainRulesCore + +end # module MeasureBaseDistributionsChainRulesCoreExt From b70ce70afb7414af41727fca7eecea2b7987d41d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 15 Sep 2024 15:50:57 +0200 Subject: [PATCH 06/28] Add function asmeasure Will be used a lot when bridging from Distributions to MeasureBase. --- src/MeasureBase.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 8f7161c1..ed2d7827 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -58,6 +58,21 @@ abstract type AbstractMeasure end AbstractMeasure(m::AbstractMeasure) = m + +""" + asmeasure(m) + +Turns a measure-like object `m` into an `AbstractMeasure`. + +Calls `convert(AbstractMeasure, m)` by default +""" +function asmeasure end + +@inline asmeasure(m::AbstractMeasure) = m +asmeasure(m) = convert(AbstractMeasure, m) +export asmeasure + + function Pretty.quoteof(d::M) where {M<:AbstractMeasure} the_names = fieldnames(typeof(d)) :($M($([getfield(d, n) for n in the_names]...))) From 0f89a572aa4785f86ba0675ff0d7d40585082df5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 18:42:14 +0100 Subject: [PATCH 07/28] Add AsMeasure --- src/MeasureBase.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index ed2d7827..1bf198b9 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -73,6 +73,28 @@ asmeasure(m) = convert(AbstractMeasure, m) export asmeasure +""" + struct AsMeasure{T} + +Wrapes a measure-like object into an `AbstractMeasure`. + +Constructor: + +``` +AsMeasure{T}(obj::T) +``` + +User code should not create instances of `AsMeasure` directly, but should +call `asmeasure(obj)` instead. +""" +struct AsMeasure{T} + obj::T + + AsMeasure{T}(obj::T) = new(obj) +end + + + function Pretty.quoteof(d::M) where {M<:AbstractMeasure} the_names = fieldnames(typeof(d)) :($M($([getfield(d, n) for n in the_names]...))) From 0fb20093761b1f37aa2d3202a73defde5fcfc658 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:00:29 +0100 Subject: [PATCH 08/28] Add collection utils --- ext/MeasureBaseChainRulesCoreExt.jl | 47 +++++++++++++++++++++++++++++ src/MeasureBase.jl | 1 + src/collection_utils.jl | 24 +++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 src/collection_utils.jl diff --git a/ext/MeasureBaseChainRulesCoreExt.jl b/ext/MeasureBaseChainRulesCoreExt.jl index 56aabf72..0b4de769 100644 --- a/ext/MeasureBaseChainRulesCoreExt.jl +++ b/ext/MeasureBaseChainRulesCoreExt.jl @@ -7,6 +7,53 @@ using ChainRulesCore: NoTangent, ZeroTangent import ChainRulesCore +# = collection utils ========================================================= + +using MeasureBase: _dropfront, _dropback, _rev_cumsum, _exp_cumsum_log + +function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x) + result = _pushfront(v, x) + function _pushfront_pullback(thunked_ΔΩ) + ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)]) + end + return result, _pushfront_pullback +end + + +function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x) + result = _pushback(v, x) + function _pushback_pullback(thunked_ΔΩ) + ΔΩ = ChainRulesCore.unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)]) + end + return result, _pushback_pullback +end + + +function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector) + result = _rev_cumsum(xs) + function _rev_cumsum_pullback(ΔΩ) + ∂xs = ChainRulesCore.@thunk cumsum(ChainRulesCore.unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _rev_cumsum_pullback +end + + +function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector) + result = _exp_cumsum_log(xs) + function _exp_cumsum_log_pullback(ΔΩ) + ∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* ChainRulesCore.unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _exp_cumsum_log_pullback +end + + +# = measure functions ======================================================== + + @inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) y = _checksupport(cond, result) function _checksupport_pullback(ȳ) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 1bf198b9..4d444cc6 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -144,6 +144,7 @@ using Compat using IrrationalConstants include("static.jl") +include("collection_utils.jl") include("smf.jl") include("getdof.jl") include("transport.jl") diff --git a/src/collection_utils.jl b/src/collection_utils.jl new file mode 100644 index 00000000..1de51f7e --- /dev/null +++ b/src/collection_utils.jl @@ -0,0 +1,24 @@ +function _pushfront(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[firstindex(r)] = x + r[firstindex(r)+1:lastindex(r)] = v + r +end + +function _pushback(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[lastindex(r)] = x + r[firstindex(r):lastindex(r)-1] = v + r +end + +_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] + +_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] + +_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) + +# Equivalent to `cumprod(xs)``: +_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) From 04e24b02f877458a7b3e28e0b44108d45e9ffb35 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:00:29 +0100 Subject: [PATCH 09/28] Add mreshape --- src/MeasureBase.jl | 1 + src/combinators/reshape.jl | 49 +++++++++++++++++++++++++++++++++++++ test/combinators/reshape.jl | 7 ++++++ test/runtests.jl | 1 + 4 files changed, 58 insertions(+) create mode 100644 src/combinators/reshape.jl create mode 100644 test/combinators/reshape.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 4d444cc6..3dbab005 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -166,6 +166,7 @@ include("primitives/trivial.jl") include("combinators/bind.jl") include("combinators/transformedmeasure.jl") +include("combinators/reshape.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") include("combinators/product.jl") diff --git a/src/combinators/reshape.jl b/src/combinators/reshape.jl new file mode 100644 index 00000000..1f24ca85 --- /dev/null +++ b/src/combinators/reshape.jl @@ -0,0 +1,49 @@ +# ToDo: Support static resizes for static arrays + +""" + struct MeasureBase.Reshape <: Function + +Represents a function that reshapes an array. + +Supports `InverseFunctions.inverse` and +`ChangesOfVariables.with_logabsdet_jacobian`. + +Constructor: + +```julia +Reshape(output_size::Dims, input_size::Dims) +``` +""" +struct Reshape{M,N} <: Function + output_size::NTuple{M,Int} + input_size::NTuple{N,Int} +end + +_throw_reshape_mismatch(sz, sz_x) = throw(DimensionMismatch("Reshape input size is $sz but got input of size $sz_x")) + +function (f::Reshape)(x::AbstractArray) + sz_x = size(x) + f.input_size == sz_x || _throw_reshape_mismatch(f.input_size, sz_x) + return reshape(x, f.output_size) +end + +InverseFunctions.inverse(f::Reshape) = Reshape(f.input_size, f.output_size) + +ChangesOfVariables.with_logabsdet_jacobian(::Reshape, x::AbstractArray) = zero(real_numtype(typeof(x))) + + +""" + mreshape(m::AbstractMeasure, sz::Vararg{N,Integer}) where N + mreshape(m::AbstractMeasure, sz::NTuple{N,Integer}) where N + +Reshape a measure `m` over an array-valued space, returning a measure over +a space of arrays with shape `sz`. +""" +function mreshape end + +_elsize_for_reshape(m::AbstractMeasure) = _elsize_for_reshape(some_mspace_elsize(m), m) +_elsize_for_reshape(sz::NTuple{<:Any,Integer}, ::AbstractMeasure) = sz +_elsize_for_reshape(::NoMSpaceElementSize, m::AbstractMeasure) = size(testvalue(m)) + +mreshape(m::AbstractMeasure, sz::Vararg{<:Any,Integer}) = mreshape(m, sz) +mreshape(m::AbstractMeasure, sz::NTuple{<:Any,Integer}) = pushfwd(Reshape(sz, _elsize_for_reshape(m)), m) diff --git a/test/combinators/reshape.jl b/test/combinators/reshape.jl new file mode 100644 index 00000000..c6624582 --- /dev/null +++ b/test/combinators/reshape.jl @@ -0,0 +1,7 @@ +using Test + +using MeasureBase + +@testset "reshape" begin + +end diff --git a/test/runtests.jl b/test/runtests.jl index f9263b6d..b31a9da5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,5 +19,6 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("combinators/reshape.jl") include("test_docs.jl") From 3135e1e0ea1c1a75e8f657af66eaafafb7fab746 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:01:07 +0100 Subject: [PATCH 10/28] STASH Distributions ext impl --- Project.toml | 1 - ext/MeasureBaseDistributionsExt.jl | 3 +- ext/distributions/_bat_dist_transforms.jl | 477 ++++++++++++++++++ ext/distributions/autodiff_utils.jl | 75 +++ ext/distributions/dirac.jl | 14 + ext/distributions/dirichlet.jl | 33 ++ ext/distributions/dist_vartransform.jl | 16 + ext/distributions/distribution_measure.jl | 71 +++ ext/distributions/distributions.jl | 66 +++ ext/distributions/mixture.jl | 4 + ext/distributions/product.jl | 17 + ext/distributions/reshaped.jl | 13 + ext/distributions/standardmv.jl | 33 ++ ext/distributions/univariate.jl | 176 +++++++ ext/distributions/utils.jl | 32 ++ test/distributions/getjacobian.jl | 34 ++ test/distributions/test_autodiff_utils.jl | 19 + .../test_distribution_measure.jl | 54 ++ test/distributions/test_distributions.jl | 12 + test/distributions/test_measure_interface.jl | 44 ++ test/distributions/test_standard_dist.jl | 129 +++++ test/distributions/test_standard_normal.jl | 130 +++++ test/distributions/test_standard_uniform.jl | 119 +++++ test/distributions/test_transport.jl | 149 ++++++ test/runtests.jl | 2 + 25 files changed, 1720 insertions(+), 3 deletions(-) create mode 100644 ext/distributions/_bat_dist_transforms.jl create mode 100644 ext/distributions/autodiff_utils.jl create mode 100644 ext/distributions/dirac.jl create mode 100644 ext/distributions/dirichlet.jl create mode 100644 ext/distributions/dist_vartransform.jl create mode 100644 ext/distributions/distribution_measure.jl create mode 100644 ext/distributions/distributions.jl create mode 100644 ext/distributions/mixture.jl create mode 100644 ext/distributions/product.jl create mode 100644 ext/distributions/reshaped.jl create mode 100644 ext/distributions/standardmv.jl create mode 100644 ext/distributions/univariate.jl create mode 100644 ext/distributions/utils.jl create mode 100644 test/distributions/getjacobian.jl create mode 100644 test/distributions/test_autodiff_utils.jl create mode 100644 test/distributions/test_distribution_measure.jl create mode 100644 test/distributions/test_distributions.jl create mode 100644 test/distributions/test_measure_interface.jl create mode 100644 test/distributions/test_standard_dist.jl create mode 100644 test/distributions/test_standard_normal.jl create mode 100644 test/distributions/test_standard_uniform.jl create mode 100644 test/distributions/test_transport.jl diff --git a/Project.toml b/Project.toml index a1c38ff6..881a80e0 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,6 @@ Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" Distributions = "0.25.1" -Distributions = "0.25.111" FillArrays = "0.12, 0.13, 1" ForwardDiff = "0.8, 0.9, 0.10" FunctionChains = "0.1" diff --git a/ext/MeasureBaseDistributionsExt.jl b/ext/MeasureBaseDistributionsExt.jl index beb47821..a4c74601 100644 --- a/ext/MeasureBaseDistributionsExt.jl +++ b/ext/MeasureBaseDistributionsExt.jl @@ -2,7 +2,6 @@ module MeasureBaseDistributionsExt -using MeasureBase -import Distributions +include "distributions/distributions.jl" end # module MeasureBaseDistributionsExt diff --git a/ext/distributions/_bat_dist_transforms.jl b/ext/distributions/_bat_dist_transforms.jl new file mode 100644 index 00000000..3b8043a2 --- /dev/null +++ b/ext/distributions/_bat_dist_transforms.jl @@ -0,0 +1,477 @@ + + +# Use ForwardDiff for univariate distribution transformations: +@inline function ChainRulesCore.rrule(::typeof(apply_dist_trafo), trg_d::Distribution{Univariate}, src_d::Distribution{Univariate}, src_v::Any) + ChainRulesCore.rrule(fwddiff(apply_dist_trafo), trg_d, src_d, src_v) +end + + + +const _StdDistType = Union{Uniform, Normal} + +_trg_disttype(::Type{Uniform}, ::Type{Univariate}) = StandardUvUniform +_trg_disttype(::Type{Uniform}, ::Type{Multivariate}) = StandardMvUniform +_trg_disttype(::Type{Normal}, ::Type{Univariate}) = StandardUvNormal +_trg_disttype(::Type{Normal}, ::Type{Multivariate}) = StandardMvNormal + +function _trg_dist(disttype::Type{<:_StdDistType}, source_dist::Distribution{Univariate,Continuous}) + trg_dt = _trg_disttype(disttype, Univariate) + trg_dt() +end + +function _trg_dist(disttype::Type{<:_StdDistType}, source_dist::Distribution{Multivariate,Continuous}) + trg_dt = _trg_disttype(disttype, Multivariate) + trg_dt(eff_totalndof(source_dist)) +end + +function _trg_dist(disttype::Type{<:_StdDistType}, source_dist::ContinuousDistribution) + trg_dt = _trg_disttype(disttype, Multivariate) + trg_dt(eff_totalndof(source_dist)) +end + + +function DistributionTransform(disttype::Type{<:_StdDistType}, source_dist::ContinuousDistribution) + trg_d = _trg_dist(disttype, source_dist) + DistributionTransform(trg_d, source_dist) +end + + +function std_dist_from(src_d::Distribution) + throw(ArgumentError("No standard intermediate distribution defined to transform from $(typeof(src_d).name)")) +end + +function std_dist_to(trg_d::Distribution) + throw(ArgumentError("No standard intermediate distribution defined to transform into $(typeof(trg_d).name)")) +end + + +@inline function _intermediate_std_dist(trg_d::Distribution, src_d::Distribution) + _select_intermediate_dist(std_dist_to(trg_d), std_dist_from(src_d)) +end + +@inline _intermediate_std_dist(::Union{StdUvDist,StdMvDist}, src_d::Distribution) = std_dist_from(src_d) + +@inline _intermediate_std_dist(trg_d::Distribution, ::Union{StdUvDist,StdMvDist}) = std_dist_to(trg_d) + +function _intermediate_std_dist(::Union{StdUvDist,StdMvDist}, ::Union{StdUvDist,StdMvDist}) + throw(ArgumentError("Direct conversions must be used between standard intermediate distributions")) +end + +@inline _select_intermediate_dist(a::D, ::D) where D<:Union{StdUvDist,StdMvDist} = a +@inline _select_intermediate_dist(a::D, ::D) where D<:Union{StandardUvUniform,StandardMvUniform} = a +@inline _select_intermediate_dist(a::Union{StandardUvUniform,StandardMvUniform}, ::Union{StdUvDist,StdMvDist}) = a +@inline _select_intermediate_dist(::Union{StdUvDist,StdMvDist}, b::Union{StandardUvUniform,StandardMvUniform}) = b + +_check_conv_eff_totalndof(trg_d::Uniform, src_d::Uniform) = nothing + +function _check_conv_eff_totalndof(trg_d::Distribution, src_d::Distribution) + trg_d_n = eff_totalndof(trg_d) + src_d_n = eff_totalndof(src_d) + if trg_d_n != src_d_n + throw(ArgumentError("Can't convert to $(typeof(trg_d).name) with $(trg_d_n) eff. DOF from $(typeof(src_d).name) with $(src_d_n) eff. DOF")) + end + nothing +end + +function apply_dist_trafo(trg_d::Distribution, src_d::Distribution, src_v::Any) + _check_conv_eff_totalndof(trg_d, src_d) + intermediate_d = _intermediate_std_dist(trg_d, src_d) + intermediate_d === trg_d && throw(ArgumentError("No transformation path between distributions")) + intermediate_v = apply_dist_trafo(intermediate_d, src_d, src_v) + apply_dist_trafo(trg_d, intermediate_d, intermediate_v) +end + + +function apply_dist_trafo(trg_d::DT, src_d::DT, src_v) where {DT <: StdMvDist} + @argcheck src_v isa AbstractVector{<:Real} + @argcheck length(trg_d) == length(src_d) == length(eachindex(src_v)) + return src_v +end + + +_dist_params_numtype(d::Distribution) = realnumtype(typeof(params(d))) + +function ChainRulesCore.rrule(::typeof(_dist_params_numtype), d::Distribution) + _dist_params_numtype_pullback(ΔΩ) = (NoTangent(), NoTangent()) + _dist_params_numtype(d), _dist_params_numtype_pullback +end + + +@inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) + +@inline _trafo_cdf_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, x::Real) = cdf(d, x) + +@inline function _trafo_cdf_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, x::ForwardDiff.Dual{TAG}) where TAG + x_v = ForwardDiff.value(x) + u = cdf(d, x_v) + dudx = pdf(d, x_v) + ForwardDiff.Dual{TAG}(u, dudx * ForwardDiff.partials(x)) +end + + +@inline _trafo_quantile(d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl(_dist_params_numtype(d), d, u) + +@inline _trafo_quantile_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl_generic(d, u) + +@inline function _trafo_quantile_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, u::ForwardDiff.Dual{TAG}) where {TAG} + x = _trafo_quantile_impl_generic(d, ForwardDiff.value(u)) + dxdu = inv(pdf(d, x)) + ForwardDiff.Dual{TAG}(x, dxdu * ForwardDiff.partials(u)) +end + +# Workaround for Beta dist, ForwardDiff doesn't work for parameters: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Real) where {T<:ForwardDiff.Dual} = convert(float(typeof(u)), NaN) +# Workaround for Beta dist, current quantile implementation only supports Float64: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Union{Integer,AbstractFloat}) where {T<:Union{Integer,AbstractFloat}} = _trafo_quantile_impl(T, d, convert(promote_type(Float64, typeof(u)), u)) +# Workaround for StatsFuns issues #133, caused by SpecialFunctions, fixed in SpecialFunctions v2.1.4: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Float64) where {T<:Union{Integer,AbstractFloat}} = (d.α ≈ 1 && d.β ≈ 1 && u < 1e-19) ? u : convert(Float64, quantile(d, u)) + +@inline _trafo_quantile_impl_generic(d::Distribution{Univariate,Continuous}, u::Real) = quantile(d, u) + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end + + +@inline function _eval_dist_trafo_func(f::typeof(_trafo_cdf), d::Distribution{Univariate,Continuous}, src_v::Real) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) + if insupport(d, src_v) + trg_v = f(d, src_v) + convert(R_V, trg_v) + else + convert(R_V, NaN) + end +end + +@inline function _eval_dist_trafo_func(f::typeof(_trafo_quantile), d::Distribution{Univariate,Continuous}, src_v::Real) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) + if 0 <= src_v <= 1 + trg_v = f(d, src_v) + convert(R_V, trg_v) + else + convert(R_V, NaN) + end +end + + +std_dist_from(src_d::Distribution{Univariate,Continuous}) = StandardUvUniform() + +function apply_dist_trafo(::StandardUvUniform, src_d::Distribution{Univariate,Continuous}, src_v::Real) + _eval_dist_trafo_func(_trafo_cdf, src_d, src_v) +end + +std_dist_to(trg_d::Distribution{Univariate,Continuous}) = StandardUvUniform() + +function apply_dist_trafo(trg_d::Distribution{Univariate,Continuous}, ::StandardUvUniform, src_v::Real) + TV = float(typeof(src_v)) + # Avoid src_v ≈ 0 and src_v ≈ 1 to avoid infinite variate values for target distributions with infinite support: + mod_src_v = ifelse(src_v ≈ 0, zero(TV) + eps(TV), ifelse(src_v ≈ 1, one(TV) - eps(TV), convert(TV, src_v))) + _eval_dist_trafo_func(_trafo_quantile, trg_d, mod_src_v) +end + + + +function _dist_trafo_rescale_impl(trg_d, src_d, src_v::Real) + R = float(typeof(src_v)) + trg_offs, trg_scale = location(trg_d), scale(trg_d) + src_offs, src_scale = location(src_d), scale(src_d) + rescale_factor = trg_scale / src_scale + (src_v - src_offs) * rescale_factor + trg_offs +end + +@inline apply_dist_trafo(trg_d::Uniform, src_d::Uniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::StandardUvUniform, src_d::Uniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::Uniform, src_d::StandardUvUniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) + +# ToDo: Use StandardUvNormal as standard intermediate dist for Normal? Would +# be useful if StandardUvNormal would be a better standard intermediate than +# StandardUvUniform for some other uniform distributions as well. +# +# std_dist_from(src_d::Normal) = StandardUvNormal() +# std_dist_to(trg_d::Normal) = StandardUvNormal() + +@inline apply_dist_trafo(trg_d::Normal, src_d::Normal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::StandardUvNormal, src_d::Normal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::Normal, src_d::StandardUvNormal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) + + +# ToDo: Optimized implementation for Distributions.Truncated <-> StandardUvUniform + + +@inline apply_dist_trafo(trg_d::StandardUvUniform, src_d::StandardUvUniform, src_v::Real) = src_v + +@inline apply_dist_trafo(trg_d::StandardUvNormal, src_d::StandardUvNormal, src_v::Real) = src_v + +@inline function apply_dist_trafo(trg_d::StandardUvUniform, src_d::StandardUvNormal, src_v::Real) + apply_dist_trafo(StandardUvUniform(), Normal(), src_v) +end + +@inline function apply_dist_trafo(trg_d::StandardUvNormal, src_d::StandardUvUniform, src_v::Real) + apply_dist_trafo(Normal(), StandardUvUniform(), src_v) +end + + +@inline function apply_dist_trafo(trg_d::StandardMvUniform, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvUniform(), StandardUvNormal(), src_v) +end + +@inline function apply_dist_trafo(trg_d::StandardMvNormal, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvNormal(), StandardUvUniform(), src_v) +end + + +std_dist_from(src_d::MvNormal) = StandardMvNormal(length(src_d)) + +_cholesky_L(A) = cholesky(A).L +_cholesky_L(A::Diagonal{<:Real}) = Diagonal(sqrt.(diag(A))) +_cholesky_L(A::PDiagMat{<:Real}) = Diagonal(sqrt.(A.diag)) +_cholesky_L(A::ScalMat{<:Real}) = Diagonal(Fill(sqrt(A.value), A.dim)) + +function apply_dist_trafo(trg_d::StandardMvNormal, src_d::MvNormal, src_v::AbstractVector{<:Real}) + @argcheck length(trg_d) == length(src_d) + _cholesky_L(src_d.Σ) \ (src_v - src_d.μ) +end + +std_dist_to(trg_d::MvNormal) = StandardMvNormal(length(trg_d)) + +function apply_dist_trafo(trg_d::MvNormal, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @argcheck length(trg_d) == length(src_d) + _cholesky_L(trg_d.Σ) * src_v + trg_d.μ +end + + +eff_totalndof(d::Dirichlet) = length(d) - 1 +eff_totalndof(d::DistributionsAD.TuringDirichlet) = length(d) - 1 + +std_dist_to(trg_d::Dirichlet) = StandardMvUniform(eff_totalndof(trg_d)) +std_dist_to(trg_d::DistributionsAD.TuringDirichlet) = StandardMvUniform(eff_totalndof(trg_d)) + +std_dist_from(trg_d::Dirichlet) = StandardMvUniform(eff_totalndof(trg_d)) +std_dist_from(trg_d::DistributionsAD.TuringDirichlet) = StandardMvUniform(eff_totalndof(trg_d)) + + +function apply_dist_trafo(trg_d::Dirichlet, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + apply_dist_trafo(DistributionsAD.TuringDirichlet(trg_d.alpha), src_d, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::Dirichlet, src_v::AbstractVector{<:Real}) + apply_dist_trafo(trg_d, DistributionsAD.TuringDirichlet(src_d.alpha), src_v) +end + +function _dirichlet_beta_trafo(α::Real, β::Real, src_v::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(src_v))) + convert(R, apply_dist_trafo(Beta(α, β), StandardUvUniform(), src_v))::R +end + +_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) + +function apply_dist_trafo(trg_d::DistributionsAD.TuringDirichlet, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + # See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", + # https://arxiv.org/abs/1010.3436 + + @_adignore @argcheck length(trg_d) == length(src_d) + 1 + αs = _dropfront(_rev_cumsum(trg_d.alpha)) + βs = _dropback(trg_d.alpha) + beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, src_v) + beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) + beta_v_ext = _pushback(beta_v, 0) + fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) +end + +function _inv_dirichlet_beta_trafo(α::Real, β::Real, beta_v::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(beta_v))) + convert(R, apply_dist_trafo(StandardUvUniform(), Beta(α, β), beta_v))::R +end + +# ToDo: Find efficient pullback for this: +function _dirichlet_variate_to_beta_v(src_v::AbstractVector{<:Real}) + idxs = eachindex(src_v) + beta_v = similar(src_v, length(idxs) - 1) + @assert firstindex(beta_v) == firstindex(src_v) + @assert lastindex(beta_v) == lastindex(src_v) - 1 + T = eltype(src_v) + sum_log_beta_v::T = 0 + @inbounds for i in eachindex(beta_v) + beta_v[i] = 1 - src_v[i] / exp(sum_log_beta_v) + sum_log_beta_v += log(beta_v[i]) + end + return beta_v +end + +# ToDo: Make Zygote-compatible: +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::DistributionsAD.TuringDirichlet, src_v::AbstractVector{<:Real}) + @_adignore @argcheck length(trg_d) == length(src_d) - 1 + αs = _dropfront(_rev_cumsum(src_d.alpha)) + βs = _dropback(src_d.alpha) + beta_v = _dirichlet_variate_to_beta_v(src_v) + fwddiff(_inv_dirichlet_beta_trafo).(αs, βs, beta_v) +end + + +function _product_dist_trafo_impl(trg_ds, src_ds, src_v::AbstractVector{<:Real}) + fwddiff(apply_dist_trafo).(trg_ds, src_ds, src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvUniform(), src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvNormal, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvNormal(), src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, StandardUvUniform(), src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, StandardUvNormal(), src_v) +end + + +_flat_ntd_orig_elshape(d::Distribution) = ArrayShape{Real}(totalndof(varshape(d))) + +function _flat_ntd_orig_accessors(d::NamedTupleDist{names,DT,AT,VT}) where {names,DT,AT,VT} + shapes = map(_flat_ntd_orig_elshape, values(d)) + vs = NamedTupleShape(VT, NamedTuple{names}(shapes)) + values(vs) +end + +_flat_ntd_eff_elshape(d::Distribution) = ArrayShape{Real}(eff_totalndof(d)) + +function _flat_ntd_eff_accessors(d::NamedTupleDist{names,DT,AT,VT}) where {names,DT,AT,VT} + shapes = map(_flat_ntd_eff_elshape, values(d)) + vs = NamedTupleShape(VT, NamedTuple{names}(shapes)) + values(vs) +end + +function _flat_ntdistelem_to_stdmv(trg_d::StdMvDist, sd::Distribution, src_v_unshaped::AbstractVector{<:Real}, src_acc::ValueAccessor) + td = view(trg_d, Base.OneTo(eff_totalndof(sd))) + sv = src_acc(src_v_unshaped) + apply_dist_trafo(td, unshaped(sd), sv) +end + +function _flat_ntdistelem_to_stdmv(trg_d::StdMvDist, sd::ConstValueDist, src_v_unshaped::AbstractVector{<:Real}, src_acc::ValueAccessor) + Bool[] +end + +function apply_dist_trafo(trg_d::StdMvDist, src_d::ValueShapes.UnshapedNTD, src_v::AbstractVector{<:Real}) + @argcheck length(src_d) == length(eachindex(src_v)) + src_accessors = _flat_ntd_orig_accessors(src_d.shaped) + rs = map((src_acc, sd) -> _flat_ntdistelem_to_stdmv(trg_d, sd, src_v, src_acc), src_accessors, values(src_d.shaped)) + vcat(rs...) +end + +apply_dist_trafo(trg_d::StdMvDist, src_d::ValueShapes.UnshapedNTD, src_v) = throw(ArgumentError("Invalid variate type $(nameof(typeof(src_v)))) for NamedTupleDist")) + +function apply_dist_trafo(trg_d::StdMvDist, src_d::NamedTupleDist, src_v::Union{NamedTuple,ShapedAsNT}) + src_v_unshaped = unshaped(src_v, varshape(src_d)) + apply_dist_trafo(trg_d, unshaped(src_d), src_v_unshaped) +end + +apply_dist_trafo(trg_d::StdMvDist, src_d::NamedTupleDist, src_v) = throw(ArgumentError("Invalid variate type $(nameof(typeof(src_v))) for NamedTupleDist")) + + +function _stdmv_to_flat_ntdistelem(td::Distribution, src_d::StdMvDist, src_v::AbstractVector{<:Real}, src_acc::ValueAccessor) + sd = view(src_d, ValueShapes.view_idxs(Base.OneTo(length(src_d)), src_acc)) + sv = src_acc(src_v) + apply_dist_trafo(unshaped(td), sd, sv) +end + +function _stdmv_to_flat_ntdistelem(td::ConstValueDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}, src_acc::ValueAccessor) + Bool[] +end + +function apply_dist_trafo(trg_d::ValueShapes.UnshapedNTD, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + @argcheck length(src_d) == length(eachindex(src_v)) + src_accessors = _flat_ntd_eff_accessors(trg_d.shaped) + rs = map((acc, td) -> _stdmv_to_flat_ntdistelem(td, src_d, src_v, acc), src_accessors, values(trg_d.shaped)) + vcat(rs...) +end + +function apply_dist_trafo(trg_d::NamedTupleDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + unshaped_result = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + varshape(trg_d)(unshaped_result) +end + +@static if isdefined(Distributions, :ReshapedDistribution) + const AnyReshapedDist = Union{Distributions.ReshapedDistribution,ValueShapes.ReshapedDist} +else + const AnyReshapedDist = Union{Distributions.MatrixReshaped,ValueShapes.ReshapedDist} +end + +eff_totalndof(d::AnyReshapedDist) = eff_totalndof(unshaped(d)) +std_dist_from(src_d::AnyReshapedDist) = std_dist_from(unshaped(src_d)) +std_dist_to(trg_d::AnyReshapedDist) = std_dist_to(unshaped(trg_d)) + +function apply_dist_trafo(trg_d::Distribution{Multivariate}, src_d::AnyReshapedDist, src_v::Any) + src_vs = varshape(src_d) + @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + apply_dist_trafo(trg_d, unshaped(src_d), unshaped(src_v, src_vs)) +end + +function apply_dist_trafo(trg_d::AnyReshapedDist, src_d::Distribution{Multivariate}, src_v::AbstractVector{<:Real}) + trg_vs = varshape(trg_d) + @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + r = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + trg_vs(r) +end + +function apply_dist_trafo(trg_d::AnyReshapedDist, src_d::AnyReshapedDist, src_v::AbstractVector{<:Real}) + trg_vs = varshape(trg_d) + src_vs = varshape(src_d) + @argcheck totalndof(trg_vs) == totalndof(src_vs) + r = apply_dist_trafo(unshaped(trg_d), unshaped(src_d), unshaped(src_v, src_vs)) + v = trg_vs(r) +end + + +function apply_dist_trafo(trg_d::StdMvDist, src_d::UnshapedHDist, src_v::AbstractVector{<:Real}) + src_v_primary, src_v_secondary = _hd_split(src_d, src_v) + trg_d_primary = typeof(trg_d)(length(eachindex(src_v_primary))) + trg_d_secondary = typeof(trg_d)(length(eachindex(src_v_secondary))) + trg_v_primary = apply_dist_trafo(trg_d_primary, _hd_pridist(src_d), src_v_primary) + trg_v_secondary = apply_dist_trafo(trg_d_secondary, _hd_secdist(src_d, src_v_primary), src_v_secondary) + vcat(trg_v_primary, trg_v_secondary) +end + +function apply_dist_trafo(trg_d::StdMvDist, src_d::HierarchicalDistribution, src_v::Any) + src_v_unshaped = unshaped(src_v, varshape(src_d)) + apply_dist_trafo(trg_d, unshaped(src_d), src_v_unshaped) +end + +function apply_dist_trafo(trg_d::UnshapedHDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + src_v_primary, src_v_secondary = _hd_split_efftotalndof(trg_d, src_v) + src_d_primary = typeof(src_d)(length(eachindex(src_v_primary))) + src_d_secondary = typeof(src_d)(length(eachindex(src_v_secondary))) + trg_v_primary = apply_dist_trafo(_hd_pridist(trg_d), src_d_primary, src_v_primary) + trg_v_secondary = apply_dist_trafo(_hd_secdist(trg_d, trg_v_primary), src_d_secondary, src_v_secondary) + vcat(trg_v_primary, trg_v_secondary) +end + +function apply_dist_trafo(trg_d::HierarchicalDistribution, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + unshaped_result = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + varshape(trg_d)(unshaped_result) +end diff --git a/ext/distributions/autodiff_utils.jl b/ext/distributions/autodiff_utils.jl new file mode 100644 index 00000000..6f2e1516 --- /dev/null +++ b/ext/distributions/autodiff_utils.jl @@ -0,0 +1,75 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +@inline _adignore_call(f) = f() +@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent()) +ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback + +macro _adignore(expr) + :(_adignore_call(() -> $(esc(expr)))) +end + + +function _pushfront(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[firstindex(r)] = x + r[firstindex(r)+1:lastindex(r)] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x) + result = _pushfront(v, x) + function _pushfront_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)]) + end + return result, _pushfront_pullback +end + + +function _pushback(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[lastindex(r)] = x + r[firstindex(r):lastindex(r)-1] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x) + result = _pushback(v, x) + function _pushback_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)]) + end + return result, _pushback_pullback +end + + +_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] + +_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] + + +_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) + +function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector) + result = _rev_cumsum(xs) + function _rev_cumsum_pullback(ΔΩ) + ∂xs = @thunk cumsum(unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _rev_cumsum_pullback +end + + +# Equivalent to `cumprod(xs)``: +_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) + +function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector) + result = _exp_cumsum_log(xs) + function _exp_cumsum_log_pullback(ΔΩ) + ∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _exp_cumsum_log_pullback +end diff --git a/ext/distributions/dirac.jl b/ext/distributions/dirac.jl new file mode 100644 index 00000000..9967f46f --- /dev/null +++ b/ext/distributions/dirac.jl @@ -0,0 +1,14 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +MeasureBase.AbstractMeasure(obj::Distributions.Dirac) = MeasureBase.Dirac(obj.value) + +function AsMeasure{D}(::D) where {D<:Distributions.Dirac} + throw(ArgumentError("Don't wrap Distributions.Dirac into MeasureBase.AsMeasure, use asmeasure to convert instead.")) +end + + +Distributions.Distribution(m::MeasureBase.Dirac{<:Real}) = Distribtions.Dirac(m.x) + +function Distributions.Distribution(@nospecialize(m::MeasureBase.Dirac{T})) where T + throw(ArgumentError("Can only convert MeasureBase.Dirac{<:Real} to Distributions.Dirac, but not MeasureBase.Dirac{<:$(nameof(T))}")) +end diff --git a/ext/distributions/dirichlet.jl b/ext/distributions/dirichlet.jl new file mode 100644 index 00000000..226532c0 --- /dev/null +++ b/ext/distributions/dirichlet.jl @@ -0,0 +1,33 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +const DirichletMeasure = AsMeasure{<:Dirichlet} + +MeasureBase.getdof(m::DirichletMeasure) = length(m.obj) - 1 + +MeasureBase.transport_origin(m::DirichletMeasure) = StdUniform()^getdof(m) + + + +function _dirichlet_beta_trafo(α::Real, β::Real, x::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(x))) + convert(R, transport_def(Beta(α, β), StdUniform(), x))::R +end + +_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) + +function MeasureBase.from_origin(ν::Dirichlet, x) + # See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", + # https://arxiv.org/abs/1010.3436 + + # Sanity check (TODO - remove?): + @_adignore @argcheck length(ν) == length(x) + 1 + + αs = _dropfront(_rev_cumsum(ν.alpha)) + βs = _dropback(ν.alpha) + beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x) + beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) + beta_v_ext = _pushback(beta_v, 0) + fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) +end + +# ToDo: MeasureBase.to_origin(ν::Dirichlet, y) diff --git a/ext/distributions/dist_vartransform.jl b/ext/distributions/dist_vartransform.jl new file mode 100644 index 00000000..569fefa8 --- /dev/null +++ b/ext/distributions/dist_vartransform.jl @@ -0,0 +1,16 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +const _AnyStdUniform = Union{StandardUniform, Uniform} +const _AnyStdNormal = Union{StandardNormal, Normal} + +const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal} + +_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform +_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal + +_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M() +_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof) +_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ)) + +MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ) +MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν)) diff --git a/ext/distributions/distribution_measure.jl b/ext/distributions/distribution_measure.jl new file mode 100644 index 00000000..e15f66c8 --- /dev/null +++ b/ext/distributions/distribution_measure.jl @@ -0,0 +1,71 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +const DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} = AsMeasure{D} + +@inline MeasureBase.AbstractMeasure(obj::Distribution) = AsMeasure{typeof(obj)}(obj) +@inline Base.convert(::Type{AbstractMeasure}, obj::Distribution) = AbstractMeasure(obj) + +@inline Distributions.Distribution(m::DistributionMeasure) = m.obj +@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) +@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) + +@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m) +@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) +@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) + + +Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.obj)) + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real} + convert_realtype(T, reshape(rand(d, prod(sz)), sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real} + convert_realtype(T, reshape(rand(rng, d, prod(sz)), size(d)..., sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N} + convert_realtype(T, reshape(rand(rng, d.dist, prod(sz)), d.dims..., sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real} + flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(rng, d, sz)))) +end + +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N} + _flat_powrand(rng, T, m.parent.obj, map(length, m.axes)) +end + +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N} + flat_data = _flat_powrand(rng, T, m.parent.obj, map(length, m.axes)) + ArrayOfSimilarArrays{T,M,N}(flat_data) +end + + +@inline DensityInterface.densityof(m::DistributionMeasure) = densityof(m.obj) +@inline DensityInterface.logdensityof(m::DistributionMeasure) = logdensityof(m.obj) + +@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x) +@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x) +@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.obj, x) + +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.obj) +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting() +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.obj) + +@inline MeasureBase.basemeasure(m::DistributionMeasure) = rootmeasure(m) + +@inline MeasureBase.mspace_elsize(m::DistributionMeasure{<:ArrayLikeVariate}) = size(m.obj) + +@inline MeasureBase.getdof(m::DistributionMeasure{<:ArrayLikeVariate{0}}) = 1 + +@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.obj) +@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{propertynames(m.obj)}(Distributions.params(m.obj)) + +# @inline MeasureBase.testvalue(m::DistributionMeasure) = testvalue(basemeasure(d)) + + +@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf))) +@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d) diff --git a/ext/distributions/distributions.jl b/ext/distributions/distributions.jl new file mode 100644 index 00000000..39a9620a --- /dev/null +++ b/ext/distributions/distributions.jl @@ -0,0 +1,66 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using LinearAlgebra: Diagonal, dot, cholesky + +import Random +using Random: AbstractRNG, rand! + +import DensityInterface +using DensityInterface: logdensityof + +import MeasureBase +using MeasureBase: AbstractMeasure, AsMeasure +using MeasureBase: Lebesgue, Counting, ℝ +using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic +using MeasureBase: PowerMeasure, WeightedMeasure +using MeasureBase: basemeasure, testvalue +using MeasureBase: getdof, checked_arg +using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin +using MeasureBase: NoTransformOrigin, NoTransport + +import Distributions +using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution +using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete +using Distributions: Uniform, Exponential, Logistic, Normal +using Distributions: MvNormal, Beta, Dirichlet +using Distributions: ReshapedDistribution + +import Statistics +import StatsBase +import StatsFuns +import PDMats + +using IrrationalConstants: log2π, invsqrt2π + +using Static: True, False, StaticInt, static +using FillArrays: Fill, Ones, Zeros + +import ChainRulesCore +using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk + +import ForwardDiff +using ForwardDiffPullbacks: fwddiff + +import Functors +using Functors: fmap + +using ArgCheck: @argcheck + +using ArraysOfArrays: ArrayOfSimilarArrays, flatview + +include("utils.jl") +include("autodiff_utils.jl") +include("standard_dist.jl") +include("standard_uniform.jl") +include("standard_normal.jl") +include("distribution_measure.jl") +include("dist_vartransform.jl") +include("univariate.jl") +include("standardmv.jl") +include("product.jl") +include("reshaped.jl") +include("dirichlet.jl") + +export StdNormal +export DistributionMeasure +export StandardDist diff --git a/ext/distributions/mixture.jl b/ext/distributions/mixture.jl new file mode 100644 index 00000000..587d3e73 --- /dev/null +++ b/ext/distributions/mixture.jl @@ -0,0 +1,4 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +# ToDo: +# AbstractMixtureModel: MixtureModel, UnivariateGMM diff --git a/ext/distributions/product.jl b/ext/distributions/product.jl new file mode 100644 index 00000000..07b38299 --- /dev/null +++ b/ext/distributions/product.jl @@ -0,0 +1,17 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +@static if isdefined(Distributions, :Product) + MeasureBase.AbstractMeasure(obj::Distributions.Product) = productmeasure(map(asmeasure, obj.v)) + + function AsMeasure{D}(::D) where {D<:Distributions.Product} + throw(ArgumentError("Don't wrap Distributions.Product into MeasureBase.AsMeasure, use asmeasure to convert instead.")) + end +end + +@static if isdefined(Distributions, :ProductDistribution) + MeasureBase.AbstractMeasure(obj::Distributions.ProductDistribution) = productmeasure(map(asmeasure, obj.dists)) + + function AsMeasure{D}(::D) where {D<:Distributions.ProductDistribution} + throw(ArgumentError("Don't wrap Distributions.ProductDistribution into MeasureBase.AsMeasure, use asmeasure to convert instead.")) + end +end diff --git a/ext/distributions/reshaped.jl b/ext/distributions/reshaped.jl new file mode 100644 index 00000000..6cbd5ede --- /dev/null +++ b/ext/distributions/reshaped.jl @@ -0,0 +1,13 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +function MeasureBase.AbstractMeasure(d::Distributions.ReshapedDistribution) + orig_dist = d.dist + pushfwd(Reshape(size(d), size(orig_dist)), AbstractMeasure(orig_dist)) +end + +function AsMeasure{D}(::D) where {D<:Distributions.ReshapedDistribution} + throw(ArgumentError("Don't wrap Distributions.ReshapedDistribution into MeasureBase.AsMeasure, use asmeasure to convert instead.")) +end + + +# ToDo: Conversion back to Distribution diff --git a/ext/distributions/standardmv.jl b/ext/distributions/standardmv.jl new file mode 100644 index 00000000..2e1e36ee --- /dev/null +++ b/ext/distributions/standardmv.jl @@ -0,0 +1,33 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +MeasureBase.getdof(m::AsMeasure{<:AbstractMvNormal}) = length(m.obj) + +MeasureBase.transport_origin(ν::MvNormal) = StandardDist{Normal}(length(ν)) + +function MeasureBase.from_origin(ν::MvNormal, x) + A = cholesky(ν.Σ).L + b = ν.μ + muladd(A, x, b) +end + +function MeasureBase.to_origin(ν::MvNormal, y) + A = cholesky(ν.Σ).L + b = ν.μ + A \ (y - b) +end + + +AbstractMvNormal +AbstractMvLogNormal + +#DirichletMultinomial +#Distributions.AbstractMvLogNormal +#Distributions.AbstractMvTDist +#Distributions.ProductDistribution{1} +#Distributions.ReshapedDistribution{1, S, D} where {S<:ValueSupport, D<:(Distribution{<:ArrayLikeVariate, S})} +#JointOrderStatistics +#Multinomial +#MultivariateMixture (alias for AbstractMixtureModel{ArrayLikeVariate{1}}) +#MvLogitNormal +#VonMisesFisher diff --git a/ext/distributions/univariate.jl b/ext/distributions/univariate.jl new file mode 100644 index 00000000..899f615a --- /dev/null +++ b/ext/distributions/univariate.jl @@ -0,0 +1,176 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +@inline MeasureBase.getdof(::Distribution{Univariate}) = static(1) + +@inline MeasureBase.check_dof(a::Distribution{Univariate}, b::Distribution{Univariate}) = nothing + + +# Use ForwardDiff for univariate transformations: +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::MeasureBase.StdMeasure, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::MeasureBase.StdMeasure, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end + + +# Generic transformations to/from StdUniform via cdf/quantile: + + +_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...) + + +@inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) + +@inline _trafo_cdf_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, x::Real) = Distributions.cdf(d, x) + +@inline function _trafo_cdf_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, x::ForwardDiff.Dual{TAG}) where TAG + x_v = ForwardDiff.value(x) + u = Distributions.cdf(d, x_v) + dudx = Distributions.pdf(d, x_v) + ForwardDiff.Dual{TAG}(u, dudx * ForwardDiff.partials(x)) +end + + +@inline _trafo_quantile(d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl(_dist_params_numtype(d), d, u) + +@inline _trafo_quantile_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl_generic(d, u) + +@inline function _trafo_quantile_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, u::ForwardDiff.Dual{TAG}) where {TAG} + x = _trafo_quantile_impl_generic(d, ForwardDiff.value(u)) + dxdu = inv(Distributions.pdf(d, x)) + ForwardDiff.Dual{TAG}(x, dxdu * ForwardDiff.partials(u)) +end + + +@inline _trafo_quantile_impl_generic(d::Distribution{Univariate,Continuous}, u::Real) = Distributions.quantile(d, u) + +# Workaround for Beta dist, ForwardDiff doesn't work for parameters: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Real) where {T<:ForwardDiff.Dual} = convert(float(typeof(u)), NaN) +# Workaround for Beta dist, current quantile implementation only supports Float64: +@inline function _trafo_quantile_impl_generic(d::Beta{T}, u::Union{Integer,AbstractFloat}) where {T<:Union{Integer,AbstractFloat}} + Distributions.quantile(d, convert(promote_type(Float64, typeof(u)), u)) +end + +#= +# ToDo: + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end +=# + + +@inline function _result_numtype(d::Distribution{Univariate}, x::T) where {T<:Real} + float(promote_type(T, eltype(Distributions.params(d)))) + # firsttype(first(typeof(x), promote_type(map(eltype, Distributions.params(d))...))) +end + + +@inline function MeasureBase.transport_def(::StdUniform, μ::Distribution{Univariate,Continuous}, x) + R = _result_numtype(μ, x) + if Distributions.insupport(μ, x) + y = _trafo_cdf(μ, x) + convert(R, y) + else + convert(R, NaN) + end +end + + +@inline function MeasureBase.transport_def(ν::Distribution{Univariate,Continuous}, ::StdUniform, x::T) where T + R = _result_numtype(ν, x) + TF = float(T) + if 0 <= x <= 1 + # Avoid x ≈ 0 and x ≈ 1 to avoid infinite variate values for target distributions with infinite support: + mod_x = ifelse(x == 0, zero(TF) + eps(TF), ifelse(x == 1, one(TF) - eps(TF), convert(TF, x))) + y = _trafo_quantile(ν, mod_x) + convert(R, y) + else + convert(R, NaN) + end +end + + +# Use standard measures as transformation origin for scaled/translated equivalents: + +function _origin_to_affine(ν::Distribution{Univariate}, y::T) where {T<:Real} + trg_offs, trg_scale = Distributions.location(ν), Distributions.scale(ν) + x = muladd(y, trg_scale, trg_offs) + convert(_result_numtype(ν, y), x) +end + +function _affine_to_origin(μ::Distribution{Univariate}, x::T) where {T<:Real} + src_offs, src_scale = Distributions.location(μ), Distributions.scale(μ) + y = (x - src_offs) / src_scale + convert(_result_numtype(μ, x), y) +end + +for (A, B) in [ + (Uniform, StdUniform), + (Logistic, StdLogistic), + (Normal, StdNormal) +] + @eval begin + @inline MeasureBase.transport_origin(::$A) = $B() + @inline MeasureBase.to_origin(ν::$A, y) = _affine_to_origin(ν, y) + @inline MeasureBase.from_origin(ν::$A, x) = _origin_to_affine(ν, x) + end +end + +@inline MeasureBase.transport_origin(::Exponential) = StdExponential() +@inline MeasureBase.to_origin(ν::Exponential, y) = Distributions.scale(ν) \ y +@inline MeasureBase.from_origin(ν::Exponential, x) = Distributions.scale(ν) * x + + + +# Transform between univariate and single-element power measure + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) +end + +function MeasureBase.transport_def(ν::PowerMeasure{<:StdMeasure}, μ::Distribution{Univariate}, x) + return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +end + + +# Transform between univariate and single-element standard multivariate + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::StandardDist{D,1}, x) where D + return transport_def(ν, StandardDist{D}(), only(x)) +end + +function MeasureBase.transport_def(ν::StandardDist{D,1}, μ::Distribution{Univariate}, x) where D + return Fill(transport_def(StandardDist{D}(), μ, only(x)), size(ν)...) +end diff --git a/ext/distributions/utils.jl b/ext/distributions/utils.jl new file mode 100644 index 00000000..be146786 --- /dev/null +++ b/ext/distributions/utils.jl @@ -0,0 +1,32 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +""" + convert_realtype(::Type{T}, x) where {T<:Real} + +Convert x to use `T` as it's underlying type for real numbers. +""" +function convert_realtype end + +_convert_realtype_pullback(ΔΩ) = NoTangent(), NoTangent, ΔΩ +ChainRulesCore.rrule(::typeof(convert_realtype), ::Type{T}, x) where T = convert_realtype(T, x), _convert_realtype_pullback + +@inline convert_realtype(::Type{T}, x::T) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::AbstractArray{T}) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::U) where {T<:Real,U<:Real} = T(x) +convert_realtype(::Type{T}, x::AbstractArray{U}) where {T<:Real,U<:Real} = T.(x) +convert_realtype(::Type{T}, x) where {T<:Real} = fmap(elem -> convert_realtype(T, elem), x) + + +""" + firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} + +Return the first type, but as a dual number type if the second one is dual. + +If `U <: ForwardDiff.Dual{tag,<:Real,N}`, returns `ForwardDiff.Dual{tag,T,N}`, +otherwise returns `T` +""" +function firsttype end + +firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} = T +firsttype(::Type{T}, ::Type{<:ForwardDiff.Dual{tag,<:Real,N}}) where {T<:Real,tag,N} = ForwardDiff.Dual{tag,T,N} diff --git a/test/distributions/getjacobian.jl b/test/distributions/getjacobian.jl new file mode 100644 index 00000000..87de7b86 --- /dev/null +++ b/test/distributions/getjacobian.jl @@ -0,0 +1,34 @@ +# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). + +import ForwardDiff + +torv_and_back(V::AbstractVector{<:Real}) = V, identity +torv_and_back(x::Real) = [x], V -> V[1] +torv_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2]) +torv_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N)) + +function torv_and_back(x::Ref) + xval = x[] + V, to_xval = torv_and_back(xval) + back_to_ref(V) = Ref(to_xval(V)) + return (V, back_to_ref) +end + +torv_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A)) + +function torv_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N} + RA = cat(real.(A), imag.(A), dims = N+1) + V, to_array = torv_and_back(RA) + function back_to_complex(V) + RA = to_array(V) + Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2)) + end + return (V, back_to_complex) +end + + +function getjacobian(f, x) + V, to_x = torv_and_back(x) + vf(V) = torv_and_back(f(to_x(V)))[1] + ForwardDiff.jacobian(vf, V) +end diff --git a/test/distributions/test_autodiff_utils.jl b/test/distributions/test_autodiff_utils.jl new file mode 100644 index 00000000..6399bb80 --- /dev/null +++ b/test/distributions/test_autodiff_utils.jl @@ -0,0 +1,19 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + + +@testset "trafo_utils" begin + xs = rand(5) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushfront(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushfront(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushback(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushback(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._rev_cumsum, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._rev_cumsum, xs) + @test Zygote.jacobian(DistributionMeasures._exp_cumsum_log, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._exp_cumsum_log, xs) ≈ ForwardDiff.jacobian(cumprod, xs) +end diff --git a/test/distributions/test_distribution_measure.jl b/test/distributions/test_distribution_measure.jl new file mode 100644 index 00000000..33b1ecbe --- /dev/null +++ b/test/distributions/test_distribution_measure.jl @@ -0,0 +1,54 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +using Distributions: Distribution +import MeasureBase +using MeasureBase: AbstractMeasure + +@testset "Measure interface" begin + d = Distributions.Weibull() + @test @inferred(AbstractMeasure(d)) isa AbstractMeasure + @test @inferred(AbstractMeasure(d)) isa DistributionMeasure + @test @inferred(convert(AbstractMeasure, d)) isa AbstractMeasure + @test @inferred(convert(AbstractMeasure, d)) isa DistributionMeasure + @test @inferred(Distribution(AbstractMeasure(d))) === d + @test @inferred(convert(Distribution, convert(AbstractMeasure, d))) === d + + + c0 = AbstractMeasure(Distributions.Weibull(0.7, 1.3)) + c1 = AbstractMeasure(Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1])) + + d0 = AbstractMeasure(Distributions.Poisson(0.7)) + d1 = AbstractMeasure(Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4]))) + + for μ in [c0, c1, d0, d1] + d = Distribution(μ) + x = rand(μ) + @test @inferred(MeasureBase.logdensity_def(μ, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(μ, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end +end diff --git a/test/distributions/test_distributions.jl b/test/distributions/test_distributions.jl new file mode 100644 index 00000000..6ad52a73 --- /dev/null +++ b/test/distributions/test_distributions.jl @@ -0,0 +1,12 @@ +using DistributionMeasures +using Test + +@testset "Distributions extension" begin + include("test_autodiff_utils.jl") + include("test_measure_interface.jl") + include("test_distribution_measure.jl") + include("test_standard_dist.jl") + include("test_standard_uniform.jl") + include("test_standard_normal.jl") + include("test_transport.jl") +end diff --git a/test/distributions/test_measure_interface.jl b/test/distributions/test_measure_interface.jl new file mode 100644 index 00000000..50873533 --- /dev/null +++ b/test/distributions/test_measure_interface.jl @@ -0,0 +1,44 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +import MeasureBase + +@testset "Measure interface" begin + c0 = Distributions.Weibull(0.7, 1.3) + c1 = Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1]) + + d0 = Distributions.Poisson(0.7) + d1 = Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4])) + + for d in [c0, c1, d0, d1] + x = rand(d) + @test @inferred(MeasureBase.logdensity_def(d, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(d, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end + + @test MeasureBase.∫(x -> Distributions.Normal(x, 0), Distributions.Normal()) isa MeasureBase.DensityMeasure +end diff --git a/test/distributions/test_standard_dist.jl b/test/distributions/test_standard_dist.jl new file mode 100644 index 00000000..4d211d87 --- /dev/null +++ b/test/distributions/test_standard_dist.jl @@ -0,0 +1,129 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +import ForwardDiff, ChainRulesTestUtils + + +@testset "standard_dist" begin + stblrng() = StableRNG(789990641) + + for (D, sz, dref) in [ + (Uniform, (), Uniform()), + (Uniform, (5,), product_distribution(fill(Uniform(0.0, 1.0), 5))), + (Uniform, (2, 3), reshape(product_distribution(fill(Uniform(0.0, 1.0), 6)), 2, 3)), + (Normal, (), Normal()), + (Normal, (), Normal(0., 1.0)), + (Normal, (5,), MvNormal(Diagonal(fill(1.0, 5)))), + (Normal, (2, 3), reshape(MvNormal(Diagonal(fill(1.0, 6))), 2, 3)), + (Exponential, (), Exponential()), + (Exponential, (5,), product_distribution(fill(Exponential(1.0), 5))), + (Exponential, (2, 3), reshape(product_distribution(fill(Exponential(1.0), 6)), 2, 3)), + ] + @testset "StandardDist{$D}($(join(sz,",")))" begin + N = length(sz) + + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + + d = StandardDist{D}(sz...) + + if size(d) == () + @test @inferred(DistributionMeasures.nonstddist(d)) == dref + end + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + for f in [minimum, maximum, mean, median, mode, modes, var, std, skewness, kurtosis, location, scale, entropy] + supported_by_dref = try f(dref); true catch MethodError; false; end + if supported_by_dref + @test @inferred(f(d)) ≈ f(dref) + end + end + + for x in [rand(dref) for i in 1:10] + ref_gradlogpdf = try + gradlogpdf(dref, x) + catch MethodError + ForwardDiff.gradient(x -> logpdf(dref, x), x) + end + @test @inferred(gradlogpdf(d, x)) ≈ ref_gradlogpdf + @test @inferred(logpdf(d, x)) ≈ logpdf(dref, x) + @test @inferred(pdf(d, x)) ≈ pdf(dref, x) + end + + if size(d) == () + for x in [minimum(dref), quantile(dref, 1//3), quantile(dref, 1//2), quantile(dref, 2//3), maximum(dref)] + for f in [logpdf, pdf, gradlogpdf, logcdf, cdf, logccdf, ccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in [0, 1//3, 1//2, 2//3, 1] + for f in [quantile, cquantile] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in log.([0, 1//3, 1//2, 2//3, 1]) + for f in [invlogcdf, invlogccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test isapprox(@inferred(mgf(d, t)), mgf(dref, t), rtol = 1e-5) + @test isapprox(@inferred(cf(d, t)), cf(dref, t), rtol = 1e-5) + end + + @test @inferred(truncated(d, quantile(dref, 1//3), quantile(dref, 2//3))) == truncated(dref, quantile(dref, 1//3), quantile(dref, 2//3)) + + @test @inferred(product_distribution(fill(d, 3))) == StandardDist{typeof(d)}(3) + @test @inferred(product_distribution(fill(d, 3, 4))) == StandardDist{typeof(d)}(3, 4) + end + + if length(size(d)) == 1 + @test @inferred(convert(Distributions.Product, d)) isa Distributions.Product + d_as_prod = convert(Distributions.Product, d) + @test d_as_prod.v == fill(StandardDist{D}(), size(d)...) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), d, 5) + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + @test @inferred(rand!(stblrng(), d, zeros(size(d)...))) == rand!(stblrng(), dref, zeros(size(dref)...)) + if length(size(d)) == 1 + @test @inferred(rand!(stblrng(), d, zeros(size(d)..., 5))) == rand!(stblrng(), dref, zeros(size(dref)..., 5)) + end + end + end + + @testset "StandardDist{Normal}()" begin + # TODO: Add @inferred + d = StandardDist{Normal}(4) + d_uv = StandardDist{Normal}() + dref = MvNormal(Diagonal(fill(1.0, 4))) + @test (MvNormal(d)) == dref + @test (Base.convert(MvNormal, d)) == dref + end +end diff --git a/test/distributions/test_standard_normal.jl b/test/distributions/test_standard_normal.jl new file mode 100644 index 00000000..4d5cbad4 --- /dev/null +++ b/test/distributions/test_standard_normal.jl @@ -0,0 +1,130 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs + + +@testset "StandardDist{Normal}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Normal,0}" begin + @test @inferred(Normal(StandardDist{Normal}())) isa Normal{Float64} + @test @inferred(Normal(StandardDist{Normal}())) == Normal() + @test @inferred(convert(Normal, StandardDist{Normal}())) == Normal() + + d = StandardDist{Normal}() + dref = Normal() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) == var(dref) + @test @inferred(std(d)) == std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) == kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-Inf, -1.3, 0.0, 1.3, +Inf] + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Normal}(), -2.2f0, 3.1f0)) isa Truncated{Normal{Float64}} + @test truncated(StandardDist{Normal}(), -2.2f0, 3.1f0) == truncated(Normal(0.0, 1.0), -2.2f0, 3.1f0) + + @test @inferred(product_distribution(fill(StandardDist{Normal}(), 3))) isa StandardDist{Normal,1} + @test product_distribution(fill(StandardDist{Normal}(), 3)) == StandardDist{Normal}(3) + end + + + @testset "StandardDist{Normal,1}" begin + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + + @test @inferred(MvNormal(StandardDist{Normal}(3))) isa MvNormal{Int} + @test @inferred(MvNormal(StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + @test @inferred(convert(MvNormal, StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + + d = StandardDist{Normal}(3) + dref = MvNormal(ScalMat(3, 1.0)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) == var(dref) + @test @inferred(cov(d)) == cov(dref) + + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) == modes(dref) + + @test @inferred(invcov(d)) == invcov(dref) + @test @inferred(logdetcov(d)) == logdetcov(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + # Distributions.insupport is inconsistent at +- Inf between Normal and MvNormal + if !any(isinf, x) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + end + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(sqmahal(d, x)) == sqmahal(dref, x) + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/distributions/test_standard_uniform.jl b/test/distributions/test_standard_uniform.jl new file mode 100644 index 00000000..66c8a99c --- /dev/null +++ b/test/distributions/test_standard_uniform.jl @@ -0,0 +1,119 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +using FillArrays +using ForwardDiff + + +@testset "StandardDist{Uniform}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Uniform,0}" begin + @test @inferred(Uniform(StandardDist{Uniform}())) isa Uniform{Float64} + @test @inferred(Uniform(StandardDist{Uniform}())) == Uniform() + @test @inferred(convert(Uniform, StandardDist{Uniform}())) == Uniform() + + d = StandardDist{Uniform}() + dref = Uniform() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(std(d)) ≈ std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) ≈ kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-0.5, 0.0, 0.25, 0.75, 1.0, 1.5] + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0)) isa Uniform{Float64} + @test truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0) == Uniform(0.0f0, 0.7f0) + @test truncated(StandardDist{Uniform}(), 0.2f0, 0.7f0) == Uniform(0.2f0, 0.7f0) + + @test @inferred(product_distribution(fill(StandardDist{Uniform}(), 3))) isa DistributionMeasures.StandardDist{Uniform,1} + @test product_distribution(fill(StandardDist{Uniform}(), 3)) == DistributionMeasures.StandardDist{Uniform}(3) + end + + + @testset "StandardDist{Uniform,1}" begin + d = DistributionMeasures.StandardDist{Uniform}(3) + dref = product_distribution(fill(Uniform(), 3)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(cov(d)) ≈ cov(dref) + + @test @inferred(mode(d)) == [0.5, 0.5, 0.5] + @test @inferred(modes(d)) == fill([0, 0,0 ]) + + @test @inferred(invcov(d)) == inv(cov(dref)) + @test @inferred(logdetcov(d)) == logdet(cov(dref)) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(gradlogpdf(d, x)) == ForwardDiff.gradient(x -> logpdf(d, x), x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/distributions/test_transport.jl b/test/distributions/test_transport.jl new file mode 100644 index 00000000..1542dcc2 --- /dev/null +++ b/test/distributions/test_transport.jl @@ -0,0 +1,149 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using InverseFunctions, ChangesOfVariables +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + +using MeasureBase: transport_to, transport_def, transport_origin +using MeasureBase: StdExponential +using DistributionMeasures: _trafo_cdf, _trafo_quantile + +include("getjacobian.jl") + + +@testset "test_distribution_transform" begin + function test_back_and_forth(trg, src) + @testset "transform $(typeof(trg).name) <-> $(typeof(src).name)" begin + x = rand(src) + y = transport_def(trg, src, x) + src_v_reco = transport_def(src, trg, y) + + @test x ≈ src_v_reco + + f = x -> transport_def(trg, src, x) + ref_ladj = logpdf(src, x) - logpdf(trg, y) + @test ref_ladj ≈ logabsdet(getjacobian(f, x))[1] + end + end + + reshaped_rand(d::Distribution{Univariate}, n) = rand(d, n) + reshaped_rand(d::Distribution{Multivariate}, n) = nestedview(rand(d, n)) + + function test_dist_trafo_moments(trg, src) + unshaped(x) = first(torv_and_back(x)) + @testset "check moments of trafo $(typeof(trg).name) <- $(typeof(src).name)" begin + X = reshaped_rand(src, 10^5) + Y = transport_to(trg, src).(X) + Y_ref = reshaped_rand(trg, 10^6) + @test isapprox(mean(unshaped.(Y)), mean(unshaped.(Y_ref)), rtol = 0.5) + @test isapprox(cov(unshaped.(Y)), cov(unshaped.(Y_ref)), rtol = 0.5) + end + end + + @testset "transforms-tests" begin + stduvuni = StandardDist{Uniform}() + stduvnorm = StandardDist{Uniform}() + + uniform1 = Uniform(-5.0, -0.01) + uniform2 = Uniform(0.01, 5.0) + + normal1 = Normal(-10, 1) + normal2 = Normal(10, 5) + + stdmvnorm1 = StandardDist{Normal}(1) + stdmvnorm2 = StandardDist{Normal}(2) + + stdmvuni2 = StandardDist{Uniform}(2) + + standnorm2_reshaped = reshape(stdmvnorm2, 1, 2) + + mvnorm = MvNormal([0.3, -2.9], [1.7 0.5; 0.5 2.3]) + beta = Beta(3,1) + gamma = Gamma(0.1,0.7) + dirich = Dirichlet([0.1,4]) + + test_back_and_forth(stduvuni, stduvuni) + test_back_and_forth(stduvnorm, stduvnorm) + test_back_and_forth(stduvuni, stduvnorm) + test_back_and_forth(stduvnorm, stduvuni) + + test_back_and_forth(stdmvuni2, stdmvuni2) + test_back_and_forth(stdmvnorm2, stdmvnorm2) + test_back_and_forth(stdmvuni2, stdmvnorm2) + test_back_and_forth(stdmvnorm2, stdmvuni2) + + test_back_and_forth(beta, stduvnorm) + test_back_and_forth(gamma, stduvnorm) + test_back_and_forth(gamma, beta) + + test_back_and_forth(mvnorm, stdmvuni2) + test_back_and_forth(stdmvuni2, mvnorm) + + test_back_and_forth(mvnorm, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, mvnorm) + test_back_and_forth(stdmvnorm2, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, standnorm2_reshaped) + + test_dist_trafo_moments(normal2, normal1) + test_dist_trafo_moments(uniform2, uniform1) + + test_dist_trafo_moments(beta, stduvnorm) + test_dist_trafo_moments(gamma, stduvnorm) + + test_dist_trafo_moments(mvnorm, stdmvnorm2) + test_dist_trafo_moments(dirich, stdmvnorm1) + + let + mvuni = product_distribution([Uniform(), Uniform()]) + + x = rand() + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + + x = rand(2) + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + end + end + + @testset "Custom cdf and quantile for dual numbers" begin + Dual = ForwardDiff.Dual + + @test isapprox(_trafo_cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), rtol = 10^-6) + @test isapprox(_trafo_cdf(Normal(0, 1), Dual(0.5, 1)), cdf(Normal(0, 1), Dual(0.5, 1)), rtol = 10^-6) + + @test isapprox(_trafo_quantile(Normal(0, 1), Dual(0.5, 1)), quantile(Normal(0, 1), Dual(0.5, 1)), rtol = 10^-6) + @test isapprox(_trafo_quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), rtol = 10^-6) + end + + @testset "trafo autodiff pullbacks" begin + x = [0.6, 0.7, 0.8, 0.9] + f = transport_to(Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]), Uniform) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + f = inverse(transport_to(Normal, Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]))) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + end + + + @testset "transport_to autosel" begin + for (M,R) in [ + (StandardNormal, StandardNormal) + (Normal, StandardNormal) + (StandardUniform, StandardUniform) + (Uniform, StandardUniform) + ] + @test @inferred(transport_to(M, Weibull())) == transport_to(R(), Weibull()) + @test @inferred(transport_to(Weibull(), M)) == transport_to(Weibull(), R()) + @test @inferred(transport_to(M, MvNormal(float(I(5))))) == transport_to(R(5), MvNormal(float(I(5)))) + @test @inferred(transport_to(MvNormal(float(I(5))), M)) == transport_to(MvNormal(float(I(5))), R(5)) + @test @inferred(transport_to(M, StdExponential()^(2,3))) == transport_to(R(6), StdExponential()^(2,3)) + @test @inferred(transport_to(StdExponential()^(2,3), M)) == transport_to(StdExponential()^(2,3), R(6)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b31a9da5..374d56be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,4 +21,6 @@ include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") include("combinators/reshape.jl") +include("test_distributions.jl") + include("test_docs.jl") From 1d271885dfeeac1149817af1b748588b67a4d696 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:07 +0100 Subject: [PATCH 11/28] Remove PowerWeightedMeasure Unused and untested. --- src/MeasureBase.jl | 1 - src/combinators/powerweighted.jl | 37 -------------------------------- 2 files changed, 38 deletions(-) delete mode 100644 src/combinators/powerweighted.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 3dbab005..ceaadae1 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -176,7 +176,6 @@ include("combinators/likelihood.jl") include("combinators/pointwise.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") -include("combinators/powerweighted.jl") include("combinators/conditional.jl") include("standard/stdmeasure.jl") diff --git a/src/combinators/powerweighted.jl b/src/combinators/powerweighted.jl deleted file mode 100644 index 47f50da4..00000000 --- a/src/combinators/powerweighted.jl +++ /dev/null @@ -1,37 +0,0 @@ -export ↑ - -struct PowerWeightedMeasure{M,A} <: AbstractMeasure - parent::M - exponent::A -end - -logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x) - -basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x)↑d.exponent - -basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent)↑d.exponent - -function powerweightedmeasure(d, α) - isone(α) && return d - PowerWeightedMeasure(d, α) -end - -(d::AbstractMeasure)↑α = powerweightedmeasure(d, α) - -insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x) - -function Base.show(io::IO, d::PowerWeightedMeasure) - print(io, d.parent, " ↑ ", d.exponent) -end - -function powerweightedmeasure(d::PowerWeightedMeasure, α) - powerweightedmeasure(d.parent, α * d.exponent) -end - -function powerweightedmeasure(d::WeightedMeasure, α) - weightedmeasure(α * d.logweight, powerweightedmeasure(d.base, α)) -end - -function Pretty.tile(d::PowerWeightedMeasure) - Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep = " ↑ ") -end From 98445838a025ddfc310e59f0e4470baf6263f6b1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:07 +0100 Subject: [PATCH 12/28] Remove kernelfactor Not used currently. --- src/parameterized.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/parameterized.jl b/src/parameterized.jl index 78e43995..8b1c8c88 100644 --- a/src/parameterized.jl +++ b/src/parameterized.jl @@ -127,14 +127,3 @@ params(::Type{PM}) where {N,PM<:ParameterizedMeasure{N}} = N function paramnames(μ, constraints::NamedTuple{N}) where {N} tuple((k for k in paramnames(μ) if k ∉ N)...) end - -############################################################################### -# kernelfactor - -function kernelfactor(::Type{P}) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end - -function kernelfactor(::P) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end From cd92b8d828d10a3d4702eb5de4ff34b24705f868 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:07 +0100 Subject: [PATCH 13/28] Rename pullback to pullbck and export it pullback has a huge potential for naming conflickts, and pullbck is more in line with pushfwd. Also simplify implementation of pullbck. --- src/combinators/transformedmeasure.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 803b404b..dab76d5f 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -140,7 +140,7 @@ end # pullback """ - pullback(f, μ, volcorr = WithVolCorr()) + pullbck(f, μ, volcorr = WithVolCorr()) A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a map _from_ the support of a measure, a pullback requires a map _into_ the @@ -152,8 +152,11 @@ in terms of the inverse function; the "forward" function is not used at all. In some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call -`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`. +`pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) - pushfwd(setinverse(inverse(f), f), μ, volcorr) +function pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(inverse(f), f, μ, volcorr) end +export pullbck + +@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) From f500aea9f7b6d508e8146d25998aa38eb2ba0391 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 14/28] Rename bind to mbind and deprecate rightarrowtail Bind has too much naming conflict potential with Base.bind. The rightarrowtail operator looks very similar to the `>=>` "fish" operator (e.g. in Haskell), which is not a monadic bind. --- src/combinators/bind.jl | 45 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index cc2022f2..491bd5a2 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -3,34 +3,47 @@ struct Bind{M,K} <: AbstractMeasure k::K end -export ↣ """ -If -- μ is an `AbstractMeasure` or satisfies the Measure interface, and -- k is a function taking values from the support of μ and returning a measure + mbind(k, μ)::AbstractMeasure + +Given + +- a measure μ +- a kernel function k that takes values from the support of μ and returns a + measure -Then `μ ↣ k` is a measure, called a *monadic bind*. In a -probabilistic programming language like Soss.jl, this could be expressed as +The *monadic bind* operation `mbind(k, μ)` returns is a new measure. -Note that bind is usually written `>>=`, but this symbol is unavailable in Julia. +A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is +unavailable in Julia. ``` -bind = @model μ,k begin - x ~ μ - y ~ k(x) - return y +μ = StdExponential() +ν = mbind(μ) do scale + pushfwd(Base.Fix1(*, scale), StdNormal()) end ``` - -See also `bind` and `Bind` """ -↣(μ, k) = bind(μ, k) - -bind(μ, k) = Bind(μ, k) +mbind(k, μ) = Bind(μ, k) +export mbind function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} x = rand(rng, T, d.μ) y = rand(rng, T, d.k(x)) return y end + + +# ToDo: Remove `bind` (breaking). +@noinline function bind(μ, k) + Base.depwarn("`foo(μ, k)` is deprecated, use `mbind(k, μ)` instead.", :bind) + mbind(k, μ) +end + + +# ToDo: Remove `↣` (breaking): It looks too similar to the `>=>` "fish" +# operator (e.g. in Haskell) that is typically understood to take two monadic +# functions as an argument, while a bind take a monad and a monadic functions. +@deprecate ↣(μ, k) mbind(μ, k) +export ↣ From fefe36df7f1d5053292dc9984b1e07d8164535a1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 15/28] Introduce mintegrate and mintegrate_exp Removes the integral operators from MeasureBase, to be re-introduced in the submodule MeasureOperators. Also improves the likelihood documentation. --- src/combinators/likelihood.jl | 134 +++++++++++++++++++++------------- src/density.jl | 92 +++++++++++++++-------- 2 files changed, 143 insertions(+), 83 deletions(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 6dfd164f..93dc1186 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -11,9 +11,9 @@ abstract type AbstractLikelihood end # insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) @doc raw""" - Likelihood(k::AbstractTransitionKernel, x) + Likelihood(k, x) -"Observe" a value `x`, yielding a function from the parameters to ℝ. +Default result of [`likelihoodof(k, x)`](@ref). Likelihoods are most commonly used in conjunction with an existing _prior_ measure to yield a new measure, the _posterior_. In Bayes's Law, we have @@ -91,12 +91,10 @@ Similarly to the above, we have Finally, let's return to the expression for Bayes's Law, -``P(θ|x) ∝ P(θ) P(x|θ)`` +``P(θ|x) ∝ P(x|θ) P(θ)`` -The product on the right side is computed pointwise. To work with this in -MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a -likelihood, and returns a new measure, that is, the unnormalized posterior that -has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior. +In measure theory, the product on the right side is actually the Lebesgue integral, +of the likelihood with respect to the prior. For example, say we have @@ -104,21 +102,24 @@ For example, say we have x ~ Normal(μ,σ) σ = 1 -and we observe `x=3`. We can compute the posterior measure on `μ` as - - julia> post = Normal() ⊙ Likelihood(Normal{(:μ, :σ)}, (σ=1,), 3) - Normal() ⊙ Likelihood(Normal{(:μ, :σ), T} where T, (σ = 1,), 3) +and we observe `x=3`. We can compute the (non-normalized) posterior measure on +`μ` as - julia> logdensity_def(post, 2) - -2.5 + julia> prior = Normal() + julia> likelihood = Likelihood(μ -> Normal(μ, 1), 3) + julia> post = mintegrate(likelihood, prior) + julia> post isa MeasureBase.DensityMeasure + true + julia> logdensity_rel(post, Lebesgue(), 2) + -4.337877066409345 """ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x) - Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x) - Likelihood(μ, x) = Likelihood(kernel(μ), x) + Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) +#!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` + Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) end (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) @@ -150,58 +151,87 @@ end export likelihoodof -""" - likelihoodof(k::AbstractTransitionKernel, x; constraints...) - likelihoodof(k::AbstractTransitionKernel, x, constraints::NamedTuple) +@doc raw""" + likelihoodof(k, x) -A likelihood is *not* a measure. Rather, a likelihood acts on a measure, through -the "pointwise product" `⊙`, yielding another measure. -""" -function likelihoodof end +Returns the likelihood of observing `x` under a family of probability +measures that is generated by a transition kernel `k(θ)`. + +`k(θ)` maps points in the parameter space to measures (resp. objects that can +be converted to measures) on a implicit set `Χ` that contains values like `x`. + +`likelihoodof(k, x)` returns a likelihood object. A likelihhood is **not** a +measure, it is a function from the parameter space to `ℝ₊`. Likelihood +objects can also be interpreted as "generic densities" (but **not** as +probability densities). -likelihoodof(k, x, ::NamedTuple{()}) = Likelihood(k, x) +`likelihoodof(k, x)` implicitly chooses `ξ = rootmeasure(k(θ))` as the +reference measure on the observation set `Χ`. Note that this implicit +`ξ` **must** be independent of `θ`. -likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +`ℒₓ = likelihoodof(k, x)` has the mathematical interpretation -likelihoodof(k, x, pars::NamedTuple) = likelihoodof(kernel(k, pars), x) +```math +\mathcal{L}_x(\theta) = \frac{\rm{d}\, k(\theta)}{\rm{d}\, \chi}(x) +``` -likelihoodof(k::AbstractTransitionKernel, x) = Likelihood(k, x) +`likelihoodof` must return an object that implements the +[`DensityInterface`](https://github.com/JuliaMath/DensityInterface.jl)` API +and `ℒₓ = likelihoodof(k, x)` must satisfy -export log_likelihood_ratio +```julia +log(ℒₓ(θ)) == logdensityof(ℒₓ, θ) ≈ logdensityof(k(θ), x) +DensityKind(ℒₓ) isa IsDensity +``` + +By default, an instance of [`MeasureBase.Likelihood`](@ref) is returned. """ - log_likelihood_ratio(ℓ::Likelihood, p, q) +function likelihoodof end -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is computed as +likelihoodof(k, x) = Likelihood(k, x) - logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -Since `logdensity_rel` can leave common base measure unevaluated, this can be -more efficient than +############################################################################### +# At the least, we need to think through in some more detail whether +# (log-)likelihood ratios expressed in this way are correct and useful. For now +# this code is commented out; we may remove it entirely in the future. - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# export log_likelihood_ratio -# likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) -export likelihood_ratio +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is computed as -""" - likelihood_ratio(ℓ::Likelihood, p, q) +# logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is equal to +# Since `logdensity_rel` can leave common base measure unevaluated, this can be +# more efficient than - density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. -Since `density_rel` can leave common base measure unevaluated, this can be -more efficient than +# # likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -function likelihood_ratio(ℓ::Likelihood, p, q) - exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) -end +# export likelihood_ratio + +# """ +# likelihood_ratio(ℓ::Likelihood, p, q) + +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is equal to + +# density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) + +# but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. +# Since `density_rel` can leave common base measure unevaluated, this can be +# more efficient than + +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# function likelihood_ratio(ℓ::Likelihood, p, q) +# exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) +# end diff --git a/src/density.jl b/src/density.jl index 4862dcb1..a3b1b95e 100644 --- a/src/density.jl +++ b/src/density.jl @@ -98,12 +98,13 @@ DensityInterface.funcdensity(d::LogDensity) = throw(MethodError(funcdensity, (d, base :: B end -A `DensityMeasure` is a measure defined by a density or log-density with respect -to some other "base" measure. +A `DensityMeasure` is a measure defined by a density or log-density with +respect to some other "base" measure. -Users should not call `DensityMeasure` directly, but should instead call `∫(f, -base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or -`∫exp(f, base)` (if `f` is a log-density function). +Users should not instantiate `DensityMeasure` directly, but should instead +call `mintegral_exp(f, base)` (if `f` is a density function or +`DensityInterface.IsDensity` object) or `mintegral_exp(f, base)` (if `f` +is a log-density function). """ struct DensityMeasure{F,B} <: AbstractMeasure f::F @@ -120,48 +121,77 @@ end end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} - result = Pretty.literal("DensityMeasure ∫(") + result = Pretty.literal("mintegrate(") result *= Pretty.pair_layout(Pretty.tile(μ.f), Pretty.tile(μ.base); sep = ", ") result *= Pretty.literal(")") end -export ∫ +basemeasure(μ::DensityMeasure) = μ.base + +logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) + +density_def(μ::DensityMeasure, x) = densityof(μ.f, x) -""" - ∫(f, base::AbstractMeasure) -Define a new measure in terms of a density `f` over some measure `base`. + +@doc raw""" + mintegrate(f, μ::AbstractMeasure)::AbstractMeasure + +Returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. + +`ν = mintegrate(f, μ)` generates a measure `ν` that has the mathematical +interpretation + +math``` +\nu(A) = \int_A f(a) \, \rm{d}\mu(a) +``` """ -∫(f, base) = _densitymeasure(f, base, DensityKind(f)) +function mintegrate end +export mintegrate + +mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) -_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) -function _densitymeasure(f, base, ::HasDensity) - @error "`∫(f, base)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`." +_mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) +function _mintegrate_impl(f, μ, ::HasDensity) + throw(ArgumentError( "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.")) end -_densitymeasure(f, base, ::NoDensity) = DensityMeasure(funcdensity(f), base) +_mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) -export ∫exp -""" - ∫exp(f, base::AbstractMeasure) +@doc raw""" + mintegrate_exp(log_f, μ::AbstractMeasure) + +Given a function `log_f` that semantically represents the log of a function +`f`, `mintegrate` returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. + +`ν = mintegrate_exp(log_f, μ)` generates a measure `ν` that has the +mathematical interpretation -Define a new measure in terms of a log-density `f` over some measure `base`. +math``` +\nu(A) = \int_A e^{log(f(a))} \, \rm{d}\mu(a) = \int_A f(a) \, \rm{d}\mu(a) +``` + +Note that `exp(log_f(...))` is usually not run explicitly, calculations that +involve the resulting measure are typically performed in log-space, +internally. """ -∫exp(f, base) = _logdensitymeasure(f, base, DensityKind(f)) +function mintegrate_exp end +export mintegrate_exp + +mintegrate_exp(log_f, μ::AbstractMeasure) = _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) -function _logdensitymeasure(f, base, ::IsDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == IsDensity()`. Use `∫(f, base)` instead." +function _mintegrate_exp_impl(log_f, μ, ::IsDensity) + throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.")) end -function _logdensitymeasure(f, base, ::HasDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == HasDensity()`." +function _mintegrate_exp_impl(log_f, μ, ::HasDensity) + throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) end -_logdensitymeasure(f, base, ::NoDensity) = DensityMeasure(logfuncdensity(f), base) +_mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) -basemeasure(μ::DensityMeasure) = μ.base - -logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) - -density_def(μ::DensityMeasure, x) = densityof(μ.f, x) """ rebase(μ, ν) @@ -172,4 +202,4 @@ basemeasure(rebase(μ, ν)) == ν density(rebase(μ, ν)) == 𝒹(μ,ν) ``` """ -rebase(μ, ν) = ∫(𝒹(μ, ν), ν) +rebase(μ, ν) = mintegrate(density_rel(μ, ν), ν) From fb3c98cecb995e48b64367ccc7122081c9113b2b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 16/28] Remove the rebase function A rebase can easily be written explicitly. --- src/MeasureBase.jl | 1 - src/density.jl | 12 ------------ 2 files changed, 13 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index ceaadae1..c38685d2 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -38,7 +38,6 @@ using FunctionChains export ≪ export gentype -export rebase export AbstractMeasure diff --git a/src/density.jl b/src/density.jl index a3b1b95e..ea976462 100644 --- a/src/density.jl +++ b/src/density.jl @@ -191,15 +191,3 @@ function _mintegrate_exp_impl(log_f, μ, ::HasDensity) throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) end _mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) - - -""" - rebase(μ, ν) - -Express `μ` in terms of a density over `ν`. Satisfies -``` -basemeasure(rebase(μ, ν)) == ν -density(rebase(μ, ν)) == 𝒹(μ,ν) -``` -""" -rebase(μ, ν) = mintegrate(density_rel(μ, ν), ν) From 86db05e2c734b6073117d6f6176576addf8a9ae7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 17/28] Rename bind to mbind and remove fish operator --- src/combinators/bind.jl | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 491bd5a2..e987ea8e 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -33,17 +33,3 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} y = rand(rng, T, d.k(x)) return y end - - -# ToDo: Remove `bind` (breaking). -@noinline function bind(μ, k) - Base.depwarn("`foo(μ, k)` is deprecated, use `mbind(k, μ)` instead.", :bind) - mbind(k, μ) -end - - -# ToDo: Remove `↣` (breaking): It looks too similar to the `>=>` "fish" -# operator (e.g. in Haskell) that is typically understood to take two monadic -# functions as an argument, while a bind take a monad and a monadic functions. -@deprecate ↣(μ, k) mbind(μ, k) -export ↣ From 58f07468158ab0d20de1afbbe1e7f84f21674c8d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 18/28] Change field order of Bind and improve docs. --- src/combinators/bind.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index e987ea8e..a27be6e0 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,6 +1,20 @@ +""" + struct MeasureBase.Bind{M,K} <: AbstractMeasure + +Represents a monatic bind. User code should create instances of `Bind` +directly, but should call `mbind(k, μ)` instead. +""" struct Bind{M,K} <: AbstractMeasure - μ::M k::K + μ::M +end + +getdof(d::Bind) = NoDOF{typeof(d)}() + +function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} + x = rand(rng, T, d.μ) + y = rand(rng, T, d.k(x)) + return y end @@ -25,11 +39,5 @@ unavailable in Julia. end ``` """ -mbind(k, μ) = Bind(μ, k) +mbind(k, μ) = Bind(k, μ) export mbind - -function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} - x = rand(rng, T, d.μ) - y = rand(rng, T, d.k(x)) - return y -end From 0cdca3d443b14b28313264f17bfdc085a190c394 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 19/28] Remove operator otimes To be re-introduced in sub-module MeasureOperators. --- src/combinators/product.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index cb7a0aaf..3a5c6494 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -167,18 +167,6 @@ function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(d)) end -export ⊗ - -""" - ⊗(μs::AbstractMeasure...) - -`⊗` is a binary operator for building product measures. This satisfies the law - -``` - basemeasure(μ ⊗ ν) == basemeasure(μ) ⊗ basemeasure(ν) -``` -""" -⊗(μs::AbstractMeasure...) = productmeasure(μs) ############################################################################### # I <: Base.Generator From 3c611806cd54740d2458a5192c97f77085c186cb Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 20/28] Removes PointwiseProductMeasure `mintegral` should be used instead to express posteriors. --- src/MeasureBase.jl | 1 - src/combinators/pointwise.jl | 30 ------------------------------ 2 files changed, 31 deletions(-) delete mode 100644 src/combinators/pointwise.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index c38685d2..8798e902 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -172,7 +172,6 @@ include("combinators/product.jl") include("combinators/power.jl") include("combinators/spikemixture.jl") include("combinators/likelihood.jl") -include("combinators/pointwise.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") include("combinators/conditional.jl") diff --git a/src/combinators/pointwise.jl b/src/combinators/pointwise.jl deleted file mode 100644 index 778e7f4e..00000000 --- a/src/combinators/pointwise.jl +++ /dev/null @@ -1,30 +0,0 @@ -export ⊙ - -struct PointwiseProductMeasure{P,L} <: AbstractMeasure - prior::P - likelihood::L -end - -iterate(p::PointwiseProductMeasure, i = 1) = iterate((p.prior, p.likelihood), i) - -function Pretty.tile(d::PointwiseProductMeasure) - Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = " ⊙ ") -end - -⊙(prior, ℓ) = pointwiseproduct(prior, ℓ) - -@inbounds function insupport(d::PointwiseProductMeasure, p) - prior, ℓ = d - istrue(insupport(prior, p)) && istrue(insupport(ℓ, p)) -end - -@inline function logdensity_def(d::PointwiseProductMeasure, p) - prior, ℓ = d - unsafe_logdensityof(ℓ, p) -end - -basemeasure(d::PointwiseProductMeasure) = d.prior - -function gentype(d::PointwiseProductMeasure) - gentype(d.prior) -end From 867adbeeeb0fe2bad0a2d691ddb54b050d0aa8e5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 21/28] Remove scrd operator To be reintroduced in submodule MeasureOperators --- src/density.jl | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/src/density.jl b/src/density.jl index ea976462..bba69985 100644 --- a/src/density.jl +++ b/src/density.jl @@ -20,8 +20,7 @@ For measures `μ` and `ν`, `Density(μ,ν)` represents the _density function_ `dμ/dν`, also called the _Radom-Nikodym derivative_: https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative -Instead of calling this directly, users should call `density_rel(μ, ν)` or -its abbreviated form, `𝒹(μ,ν)`. +Instead of calling this directly, users should call `density_rel(μ, ν)`. """ struct Density{M,B} <: AbstractDensity μ::M @@ -32,15 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d -export 𝒹 - -""" - 𝒹(μ, base) - -Compute the density (Radom-Nikodym derivative) of μ with respect to `base`. This -is a shorthand form for `density_rel(μ, base)`. -""" -𝒹(μ, base) = density_rel(μ, base) density_rel(μ, base) = Density(μ, base) @@ -73,16 +63,6 @@ Base.:∘(::typeof(exp), d::LogDensity) = density_rel(d.μ, d.base) Base.exp(d::LogDensity) = exp ∘ d -export log𝒹 - -""" - log𝒹(μ, base) - -Compute the log-density (Radom-Nikodym derivative) of μ with respect to `base`. -This is a shorthand form for `logdensity_rel(μ, base)` -""" -log𝒹(μ, base) = logdensity_rel(μ, base) - logdensity_rel(μ, base) = LogDensity(μ, base) (f::LogDensity)(x) = logdensity_rel(f.μ, f.base, x) From e2b88a22680a20f7467aa5b03268201249131681 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 22/28] Remove ll-operator Absolute continuity is not really implemented yet. --- src/MeasureBase.jl | 1 - src/absolutecontinuity.jl | 3 +++ src/combinators/weighted.jl | 3 --- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 8798e902..678a0dc1 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -36,7 +36,6 @@ using Static using Static: StaticInteger using FunctionChains -export ≪ export gentype export AbstractMeasure diff --git a/src/absolutecontinuity.jl b/src/absolutecontinuity.jl index 8062198c..c65aeaf3 100644 --- a/src/absolutecontinuity.jl +++ b/src/absolutecontinuity.jl @@ -54,3 +54,6 @@ # representative(μ) ≪ representative(ν) && return true # return false # end + +# ≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true +# ≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index db239b50..124662b6 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -46,9 +46,6 @@ end Base.:*(m::AbstractMeasure, k::Real) = k * m -≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true -≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true - gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) From 5404ff1b492112913d4d0060e2050ba78875bdb0 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 23/28] Add measure operators in submodule MeasureOperators Having the operators in a sub-module makes it easier for users to control whether of they want them in their namespace. Operators have a larger naming conflict potential. --- src/MeasureBase.jl | 2 + src/measure_operators.jl | 141 ++++++++++++++++++++++++++++++++++++++ test/measure_operators.jl | 24 +++++++ test/runtests.jl | 2 + 4 files changed, 169 insertions(+) create mode 100644 src/measure_operators.jl create mode 100644 test/measure_operators.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 678a0dc1..9f23fc8a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -190,6 +190,8 @@ include("density-core.jl") include("interface.jl") +include("measure_operators.jl") + using .Interface end # module MeasureBase diff --git a/src/measure_operators.jl b/src/measure_operators.jl new file mode 100644 index 00000000..ee8a7d8d --- /dev/null +++ b/src/measure_operators.jl @@ -0,0 +1,141 @@ +""" + module MeasureOperators + +Defines the following operators for measures: + +* `f ⋄ μ == pushfwd(f, μ)` + +* `μ ⊙ f == inverse(f) ⋄ μ` +""" +module MeasureOperators + +using MeasureBase: AbstractMeasure +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using InverseFunctions: inverse +using Reexport: @reexport + + +@doc raw""" + ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) + +The `\\diamond` operator denotes a pushforward operation: `ν = f ⋄ μ` +generates a +[pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure). + +A common mathematical notation for a pushforward is ``f_*μ``, but as +there is no "subscript-star" operator in Julia, we use `⋄`. + +See [`pushfwd(f, μ)`](@ref) for details. + +Also see [`ν ⊙ f`](@ref), the pullback operator. +""" +⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) +export ⋄ + + +@doc raw""" + ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) + +The `\\odot` operator denotes a pullback operation. + +See also [`pullbck(ν, f)`](@ref) for details. Note that `pullbck` takes it's +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. A pullback is mathematically the +precomposition of a measure `μ`` with the function `f` applied to sets. so +`⊙` takes the measure as the first and the function as the second argument, +as common in mathematical notation for precomposition. + +A common mathematical notation for pullback in measure theory is +``f \circ μ``, but as `∘` is used for function composition in Julia and as +`f` semantically acts point-wise on sets, we use `⊙`. + +Also see [f ⋄ μ](@ref), the pushforward operator. +""" +⊙(ν::AbstractMeasure, f) = pullbck(f, ν) +export ⊙ + + +""" + μ ▷ k = mbind(k, μ) + +The `\\triangleright` operator denotes a measure monadic bind operation. + +A common operator choice for a monadics bind operator is `>>=` (e.g. in +the Haskell programming language), but this has a different meaning in +Julia and there is no close equivalent, so we use `▷`. + +See [`mbind(k, μ)`](@ref) for details. Note that `mbind` takes its +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. `▷`, on the other hand, takes +its arguments in the order common for monadic binds in functional +programming (like the Haskell `>>=` operator) and mathematics. +""" +▷(μ::AbstractMeasure,k) = mbind(k, μ) +export ▷ + + +# ToDo: Use `⨂` instead of `⊗` for better readability? +""" + ⊗(μs::AbstractMeasure...) = productmeasure(μs) + +`⊗` is an operator for building product measures. + +See [`productmeasure(μs)`](@ref) for details. +""" +⊗(μs::AbstractMeasure...) = productmeasure(μs) +export ⊗ + + +""" + ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) + +Denotes an indefinite integral of the function `f` with respect to the +measure `μ`. + +See [`mintegrate(f, μ)`](@ref) for details. +""" +∫(f, μ::AbstractMeasure) = mintegrate(f, μ) +export ∫ + + +""" + ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) + +Generates a new measure that is the indefinite integral of `exp` of `f` +with respect to the measure `μ`. + +See [`mintegrate_exp(f, μ)`](@ref) for details. +""" +∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) +export ∫exp + + +""" + 𝒹(ν, μ) = density_rel(ν, μ) + +Compute the density, i.e. the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`density_rel(ν, μ)`}(@ref). +""" +𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) +export 𝒹 + + + +""" + log𝒹(ν, μ) = logdensity_rel(ν, μ) + +Compute the log-density, i.e. the logarithm of the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`logdensity_rel(ν, μ)`}(@ref). +""" +log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) +export log𝒹 + + +end # module MeasureOperators diff --git a/test/measure_operators.jl b/test/measure_operators.jl new file mode 100644 index 00000000..a3adaa8f --- /dev/null +++ b/test/measure_operators.jl @@ -0,0 +1,24 @@ +using Test + +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using MeasureBase.MeasureOperators: ⋄, ⊙, ▷, ⊗, ∫, ∫exp, 𝒹, log𝒹 + +@testset "MeasureOperators" begin + μ = StdExponential() + ν = StdUniform() + k(σ) = pushfwd(x -> σ * x, StdNormal()) + μs = (StdExponential(), StdLogistic(), StdUniform()) + f = sqrt + + @test @inferred(f ⋄ μ) == pushfwd(f, μ) + @test @inferred(ν ⊙ f) == pullbck(f, ν) + @test @inferred(μ ▷ k) == mbind(k, μ) + @test @inferred(⊗(μs...)) == productmeasure(μs) + @test @inferred(∫(f, μ)) == mintegrate(f, μ) + @test @inferred(∫exp(f, μ)) == mintegrate_exp(f, μ) + @test @inferred(𝒹(ν, μ)) == density_rel(ν, μ) + @test @inferred(log𝒹(ν, μ)) == logdensity_rel(ν, μ) +end diff --git a/test/runtests.jl b/test/runtests.jl index 374d56be..97314dd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,4 +23,6 @@ include("combinators/reshape.jl") include("test_distributions.jl") +include("measure_operators.jl") + include("test_docs.jl") From aa7416bba982ea596840410fcfe3ab4d79e678e5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 24/28] Improve docstring for mbind Co-authored-by: Chad Scherrer --- src/combinators/bind.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a27be6e0..4b34e7a2 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -28,6 +28,9 @@ Given measure The *monadic bind* operation `mbind(k, μ)` returns is a new measure. +If `ν == mbind(k, μ)` and all measures involved are sampleable, then +samples from `rand(ν)` follow the same distribution as those from `rand(k(rand(μ)))`. + A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is unavailable in Julia. From 470db239a83c0aebae1dcce876f8cd59185d0a97 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 25/28] Improve likelihood docs Co-authored-by: Chad Scherrer --- src/combinators/likelihood.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 93dc1186..c5229b38 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -93,7 +93,7 @@ Finally, let's return to the expression for Bayes's Law, ``P(θ|x) ∝ P(x|θ) P(θ)`` -In measure theory, the product on the right side is actually the Lebesgue integral, +In measure theory, the product on the right side is the Lebesgue integral of the likelihood with respect to the prior. For example, say we have From 28f25ef677dd8fc97f489e14c07d9edc2eacab1e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 26/28] Apply JuliaFormatter --- src/combinators/bind.jl | 1 - src/combinators/likelihood.jl | 28 +--------------------------- src/combinators/product.jl | 1 - src/density.jl | 26 ++++++++++++++++++-------- src/measure_operators.jl | 12 +----------- src/static.jl | 4 +++- test/static.jl | 11 +++++++---- 7 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index 4b34e7a2..465f2bf7 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -17,7 +17,6 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} return y end - """ mbind(k, μ)::AbstractMeasure diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index c5229b38..9b6ac567 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -64,31 +64,6 @@ With several parameters, things work as expected: --------- - Likelihood(M<:ParameterizedMeasure, constraint::NamedTuple, x) - -In some cases the measure might have several parameters, and we may want the -(log-)likelihood with respect to some subset of them. In this case, we can use -the three-argument form, where the second argument is a constraint. For example, - - julia> ℓ = Likelihood(Normal{(:μ,:σ)}, (σ=3.0,), 2.0) - Likelihood(Normal{(:μ, :σ), T} where T, (σ = 3.0,), 2.0) - -Similarly to the above, we have - - julia> density_def(ℓ, (μ=2.0,)) - 0.3333333333333333 - - julia> logdensity_def(ℓ, (μ=2.0,)) - -1.0986122886681098 - - julia> density_def(ℓ, 2.0) - 0.3333333333333333 - - julia> logdensity_def(ℓ, 2.0) - -1.0986122886681098 - ------------------------ - Finally, let's return to the expression for Bayes's Law, ``P(θ|x) ∝ P(x|θ) P(θ)`` @@ -118,7 +93,7 @@ struct Likelihood{K,X} <: AbstractLikelihood x::X Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) -#!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` + #!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) end @@ -191,7 +166,6 @@ function likelihoodof end likelihoodof(k, x) = Likelihood(k, x) - ############################################################################### # At the least, we need to think through in some more detail whether # (log-)likelihood ratios expressed in this way are correct and useful. For now diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 3a5c6494..516678f5 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -167,7 +167,6 @@ function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(d)) end - ############################################################################### # I <: Base.Generator diff --git a/src/density.jl b/src/density.jl index bba69985..0f4c8e03 100644 --- a/src/density.jl +++ b/src/density.jl @@ -31,7 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d - density_rel(μ, base) = Density(μ, base) (f::Density)(x) = density_rel(f.μ, f.base, x) @@ -112,8 +111,6 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) density_def(μ::DensityMeasure, x) = densityof(μ.f, x) - - @doc raw""" mintegrate(f, μ::AbstractMeasure)::AbstractMeasure @@ -135,11 +132,14 @@ mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) _mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) function _mintegrate_impl(f, μ, ::HasDensity) - throw(ArgumentError( "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.")) + throw( + ArgumentError( + "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.", + ), + ) end _mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) - @doc raw""" mintegrate_exp(log_f, μ::AbstractMeasure) @@ -162,12 +162,22 @@ internally. function mintegrate_exp end export mintegrate_exp -mintegrate_exp(log_f, μ::AbstractMeasure) = _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +function mintegrate_exp(log_f, μ::AbstractMeasure) + _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +end function _mintegrate_exp_impl(log_f, μ, ::IsDensity) - throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.")) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.", + ), + ) end function _mintegrate_exp_impl(log_f, μ, ::HasDensity) - throw(ArgumentError("`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.")) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.", + ), + ) end _mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) diff --git a/src/measure_operators.jl b/src/measure_operators.jl index ee8a7d8d..5822d4de 100644 --- a/src/measure_operators.jl +++ b/src/measure_operators.jl @@ -15,7 +15,6 @@ using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel using InverseFunctions: inverse using Reexport: @reexport - @doc raw""" ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) @@ -33,7 +32,6 @@ Also see [`ν ⊙ f`](@ref), the pullback operator. ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) export ⋄ - @doc raw""" ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) @@ -55,7 +53,6 @@ Also see [f ⋄ μ](@ref), the pushforward operator. ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) export ⊙ - """ μ ▷ k = mbind(k, μ) @@ -71,10 +68,9 @@ passing functions as the first argument. `▷`, on the other hand, takes its arguments in the order common for monadic binds in functional programming (like the Haskell `>>=` operator) and mathematics. """ -▷(μ::AbstractMeasure,k) = mbind(k, μ) +▷(μ::AbstractMeasure, k) = mbind(k, μ) export ▷ - # ToDo: Use `⨂` instead of `⊗` for better readability? """ ⊗(μs::AbstractMeasure...) = productmeasure(μs) @@ -86,7 +82,6 @@ See [`productmeasure(μs)`](@ref) for details. ⊗(μs::AbstractMeasure...) = productmeasure(μs) export ⊗ - """ ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) @@ -98,7 +93,6 @@ See [`mintegrate(f, μ)`](@ref) for details. ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) export ∫ - """ ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) @@ -110,7 +104,6 @@ See [`mintegrate_exp(f, μ)`](@ref) for details. ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) export ∫exp - """ 𝒹(ν, μ) = density_rel(ν, μ) @@ -123,8 +116,6 @@ For details, see [`density_rel(ν, μ)`}(@ref). 𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) export 𝒹 - - """ log𝒹(ν, μ) = logdensity_rel(ν, μ) @@ -137,5 +128,4 @@ For details, see [`logdensity_rel(ν, μ)`}(@ref). log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) export log𝒹 - end # module MeasureOperators diff --git a/src/static.jl b/src/static.jl index b723d043..da471b62 100644 --- a/src/static.jl +++ b/src/static.jl @@ -49,7 +49,9 @@ Returns the length of `x` as a dynamic or static integer. """ maybestatic_length(x) = length(x) maybestatic_length(x::AbstractUnitRange) = length(x) -function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B} +function maybestatic_length( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, +) where {A,B} StaticInt{B - A + 1}() end diff --git a/test/static.jl b/test/static.jl index a6c50db2..f618124b 100644 --- a/test/static.jl +++ b/test/static.jl @@ -11,7 +11,7 @@ import FillArrays @test static(2) isa MeasureBase.IntegerLike @test true isa MeasureBase.IntegerLike @test static(true) isa MeasureBase.IntegerLike - + @test @inferred(MeasureBase.one_to(7)) isa Base.OneTo @test @inferred(MeasureBase.one_to(7)) == 1:7 @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo @@ -19,10 +19,13 @@ import FillArrays @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7) + @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == + FillArrays.Fill(4.2, 3, 7) @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5)) + @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == + FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == + FillArrays.Fill(4.2, (3:7, 2:5)) @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7 From bf35d4020fa9bb9438785b1c8e60b096aaac6e71 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 27/28] Improve Likelihood ctor --- src/combinators/likelihood.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 9b6ac567..b244fd0f 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -92,11 +92,12 @@ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K,X} = new{K,X}(k, x) - #!!!!!!!!!!! # For type stability if `K isa UnionAll (e.g. a parameterized MeasureType)` - Likelihood(::Type{K}, x::X) where {K<:AbstractMeasure,X} = new{K,X}(K, x) + Likelihood{K,X}(k, x) where {K,X} = new{K,X}(k, x) end +# For type stability, in case k is a type (resp. a constructor): +Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) + (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() From 7aae58e9f20c580e3ae81f91fafff905870131bc Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 1 Nov 2024 20:28:08 +0100 Subject: [PATCH 28/28] Fix typo in _mintegrate_exp_impl exception --- src/density.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/density.jl b/src/density.jl index 0f4c8e03..57367ec5 100644 --- a/src/density.jl +++ b/src/density.jl @@ -169,7 +169,7 @@ end function _mintegrate_exp_impl(log_f, μ, ::IsDensity) throw( ArgumentError( - "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegral(log_f, μ)` instead.", + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegrate(log_f, μ)` instead.", ), ) end