diff --git a/Project.toml b/Project.toml
index 7df6400..2328021 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ExponentialFamilyProjection"
 uuid = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
 authors = ["Mykola Lukashchuk ", "Dmitry Bagaev ", "Albert Podusenko "]
-version = "1.1.2"
+version = "1.2.0"
 
 [deps]
 BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -30,7 +30,7 @@ BayesBase = "1.3"
 Bumper = "0.6"
 Distributions = "0.25"
 ExponentialFamily = "1.5"
-ExponentialFamilyManifolds = "1.2"
+ExponentialFamilyManifolds = "1.5"
 FastCholesky = "1.3"
 FillArrays = "1"
 ForwardDiff = "0.10.36"
diff --git a/src/projected_to.jl b/src/projected_to.jl
index 36548e7..875493b 100644
--- a/src/projected_to.jl
+++ b/src/projected_to.jl
@@ -16,6 +16,7 @@ The following arguments are optional:
 
 * `conditioner = nothing`: a conditioner to use for the projection, not all exponential family members require a conditioner, but some do, e.g. `Laplace`
 * `parameters = DefaultProjectionParameters`: parameters for the projection procedure
+* `kwargs = nothing`: additional arguments passed to `Manopt.gradient_descent!` (optional). For details on `gradient_descent!` parameters, see the [Manopt.jl documentation](https://manoptjl.org/stable/solvers/gradient_descent/#Manopt.gradient_descent). Note that `kwargs` passed to `project_to` take precedence over the `kwargs` specified in `ProjectedTo`.
 
 ```jldoctest
 julia> using ExponentialFamily
@@ -33,28 +34,32 @@
 julia> projected_to = ProjectedTo(Laplace, conditioner = 2.0)
 ProjectedTo(Laplace, conditioner = 2.0)
 ```
 """
-struct ProjectedTo{T,D,C,P}
+struct ProjectedTo{T,D,C,P,E}
     dims::D
     conditioner::C
     parameters::P
+    kwargs::E
 end
 
 ProjectedTo(
     dims::Vararg{Int};
     conditioner = nothing,
     parameters = DefaultProjectionParameters(),
+    kwargs = nothing,
 ) = ProjectedTo(
     ExponentialFamilyDistribution,
     dims...,
     conditioner = conditioner,
     parameters = parameters,
+    kwargs = kwargs,
 )
 
 function ProjectedTo(
     ::Type{T},
     dims...;
     conditioner::C = nothing,
     parameters::P = DefaultProjectionParameters(),
-) where {T,C,P}
+    kwargs::E = nothing,
+) where {T,C,P,E}
     # Check that `dims` are all integers
     if !all(d -> typeof(d) <: Int, dims)
         # If not, throw an error, also suggesting to use keyword arguments
@@ -65,13 +70,14 @@ function ProjectedTo(
         end
         error(msg)
     end
-    return ProjectedTo{T,typeof(dims),C,P}(dims, conditioner, parameters)
+    return ProjectedTo{T,typeof(dims),C,P,E}(dims, conditioner, parameters, kwargs)
 end
 
 get_projected_to_type(::ProjectedTo{T}) where {T} = T
 get_projected_to_dims(prj::ProjectedTo) = prj.dims
 get_projected_to_conditioner(prj::ProjectedTo) = prj.conditioner
 get_projected_to_parameters(prj::ProjectedTo) = prj.parameters
+get_projected_to_kwargs(prj::ProjectedTo) = prj.kwargs
 get_projected_to_manifold(prj::ProjectedTo) =
     ExponentialFamilyManifolds.get_natural_manifold(
         get_projected_to_type(prj),
@@ -268,9 +274,19 @@ function project_to(
         supplementary_η,
     )
 
+    # First we query the `kwargs` stored in the `ProjectedTo` specification
+    prj_kwargs = get_projected_to_kwargs(prj)
+    prj_kwargs = isnothing(prj_kwargs) ? (;) : prj_kwargs
+    # And attach the `kwargs` passed to `project_to`; these may override
+    # some settings stored in the `ProjectedTo` specification
+    if !isnothing(kwargs)
+        prj_kwargs = (; prj_kwargs..., kwargs...)
+    end
     # We disable the default `debug` statements, which are set in `Manopt`
     # in order to improve the performance a little bit
-    kwargs = !haskey(kwargs, :debug) ? (; kwargs..., debug = missing) : kwargs
+    if !haskey(prj_kwargs, :debug)
+        prj_kwargs = (; prj_kwargs..., debug = missing)
+    end
 
     return _kernel_project_to(
         get_projected_to_type(prj),
@@ -281,7 +297,7 @@
         strategy,
         state,
         current_η,
-        kwargs,
+        prj_kwargs,
     )
 end
 
diff --git a/src/strategies/control_variate.jl b/src/strategies/control_variate.jl
index 39c2c68..6e2d8ac 100644
--- a/src/strategies/control_variate.jl
+++ b/src/strategies/control_variate.jl
@@ -295,18 +295,6 @@ function control_variate_compute_gradient_buffered!(
         state.sufficientstatistics',
         state.gradsamples',
     )
-    # --
-
-    # Next we compute the `corr_matrix` using the sample principle, preallocate the storage
-    # The naive code would be `corr_matrix = cov_matrix * inv_fisher`
-    # --
-    corr_matrix = @alloc(
-        promote_type(eltype(cov_matrix), eltype(inv_fisher)),
-        size(cov_matrix, 1),
-        size(inv_fisher, 2)
-    )
-    mul!(corr_matrix, cov_matrix, inv_fisher)
-    # --
 
     # Compute means of sufficientstatistics and gradsamples inplace
     # The naive code would be
@@ -323,19 +311,22 @@
     # The next four lines finish the computation, and essentially equivalent to the following code
     # `estimated_grad_vector = mean_gradsamples - corr_matrix * (mean_sufficientstats - gradlogpartition)`
+    # where `corr_matrix = cov_matrix * inv_fisher`
     # `ef_gradient = η - inv_fisher * estimated_grad_vector`
     # or (η - (η_ef + inv_fisher * estimated_grad_vector))
     # --
     tmp1 = @alloc(
         promote_type(eltype(mean_sufficientstats), eltype(gradlogpartition)),
         length(mean_sufficientstats)
     )
-    tmp2 = @alloc(promote_type(eltype(corr_matrix), eltype(tmp1)), length(tmp1))
+    tmp2 = @alloc(promote_type(eltype(inv_fisher), eltype(tmp1)), length(tmp1))
+    tmp3 = @alloc(promote_type(eltype(cov_matrix), eltype(tmp2)), length(tmp2))
     map!(-, tmp1, mean_sufficientstats, gradlogpartition) # tmp1 = (mean_sufficientstats - gradlogpartition)
-    mul!(tmp2, corr_matrix, tmp1) # tmp2 = corr_matrix * tmp1
-    map!(-, tmp1, mean_gradsamples, tmp2) # tmp1 = estimated_grad_vector = mean_gradsamples - tmp2
-    mul!(tmp2, inv_fisher, tmp1) # tmp2 = inv_fisher * estimated_grad_vector
-    map!(-, X, η, tmp2) # X .= η .- tmp2
+    mul!(tmp2, inv_fisher, tmp1) # tmp2 = inv_fisher * tmp1
+    mul!(tmp3, cov_matrix, tmp2) # tmp3 = cov_matrix * tmp2, such that tmp3 = cov_matrix * inv_fisher * tmp1
+    map!(-, tmp1, mean_gradsamples, tmp3) # tmp1 = estimated_grad_vector = mean_gradsamples - tmp3
+    mul!(tmp3, inv_fisher, tmp1) # tmp3 = inv_fisher * estimated_grad_vector
+    map!(-, X, η, tmp3) # X .= η .- tmp3
 
     # --
 
     nothing
diff --git a/test/projection/helpers/debug.jl b/test/projection/helpers/debug.jl
index 13545c2..9108c8c 100644
--- a/test/projection/helpers/debug.jl
+++ b/test/projection/helpers/debug.jl
@@ -4,17 +4,27 @@
 
     include("../projected_to_setuptests.jl")
 
-    function test_projection_with_debug(n_iterations, do_debug)
+    function test_projection_with_debug(n_iterations, do_debug, pass_to_prj = false)
        distribution = Bernoulli(0.5)
        buf = IOBuffer()
        if do_debug
-            debug = [DebugCost(io=buf), DebugDivider("\n";io=buf)]
+            debug = [DebugCost(io = buf), DebugDivider("\n"; io = buf)]
        else
            debug = missing
        end
-        projection_parameters = ProjectionParameters(niterations=n_iterations)
-        project_to(ProjectedTo(Bernoulli, parameters = projection_parameters), (x) -> logpdf(distribution, x); debug=debug)
+        projection_parameters = ProjectionParameters(niterations = n_iterations)
+        if pass_to_prj
+            prj = ProjectedTo(
+                Bernoulli,
+                parameters = projection_parameters,
+                kwargs = (debug = debug,),
+            )
+            project_to(prj, (x) -> logpdf(distribution, x))
+        else
+            prj = ProjectedTo(Bernoulli, parameters = projection_parameters)
+            project_to(prj, (x) -> logpdf(distribution, x); debug = debug)
+        end
        debug_string = String(take!(buf))
        if do_debug
            lines = split(debug_string, '\n')
@@ -26,14 +36,20 @@
    end
 
    @testset "projections with debug" begin
-        for n in 1:10
+        for n = 1:10
            test_projection_with_debug(n, true)
        end
    end
 
    @testset "projections without debug" begin
-        for n in 1:10
+        for n = 1:10
            test_projection_with_debug(n, false)
        end
    end
+
+    @testset "projections with debug passed through ProjectedTo" begin
+        for n = 1:10
+            test_projection_with_debug(n, true, true)
+        end
+    end
 end
\ No newline at end of file
diff --git a/test/projection/projected_to_normal_tests.jl b/test/projection/projected_to_normal_tests.jl
index 0469301..fc2bac6 100644
--- a/test/projection/projected_to_normal_tests.jl
+++ b/test/projection/projected_to_normal_tests.jl
@@ -69,10 +69,7 @@ end
 
    @testset let distribution =
            MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4))
-        @test test_projection_convergence(
-            distribution,
-            niterations_range = 500:100:2000
-        )
+        @test test_projection_convergence(distribution, niterations_range = 500:100:2000)
    end
 end
 
@@ -88,14 +85,60 @@
        @test test_projection_convergence(
            distribution,
            to = MvNormalMeanCovariance,
-            dims = (2, ),
+            dims = (2,),
            conditioner = nothing,
        )
    end
 end
 
+@testitem "Project a product of `MvNormalMeanScalePrecision` and `MvNormalMeanScalePrecision` to `MvNormalMeanScalePrecision`" begin
+    using BayesBase, ExponentialFamily, Distributions, LinearAlgebra
+
+    include("./projected_to_setuptests.jl")
+
+    @testset let distribution = ProductOf(
+            MvNormalMeanScalePrecision(ones(2), 2),
+            MvNormalMeanScalePrecision(ones(2), 3),
+        )
+        @test test_projection_convergence(
+            distribution,
+            to = MvNormalMeanScalePrecision,
+            dims = (2,),
+            conditioner = nothing,
+        )
+    end
+
+    @testset let distribution = ProductOf(
+            MvNormalMeanScalePrecision(ones(8), 2),
+            MvNormalMeanScalePrecision(ones(8), 3),
+        )
+        @test test_projection_convergence(
+            distribution,
+            to = MvNormalMeanScalePrecision,
+            dims = (8,),
+            conditioner = nothing,
+        )
+    end
+
+    @testset let distribution = ProductOf(
+            MvNormalMeanScalePrecision(ones(20), 2),
+            MvNormalMeanScalePrecision(ones(20), 3),
+        )
+        @test test_projection_convergence(
+            distribution,
+            to = MvNormalMeanScalePrecision,
+            dims = (20,),
+            conditioner = nothing,
+            nsamples_niterations = 6000,
+            nsamples_range = 1000:1000:6000,
+            niterations_range = 400:100:1000,
+            nsamples_required_accuracy = 0.3,
+            niterations_required_accuracy = 0.3
+        )
+    end
+end
 
 @testitem "MLE" begin
    using BayesBase, ExponentialFamily, Distributions, JET
@@ -123,6 +166,9 @@
        @test test_projection_mle(distribution)
    end
 
+    @testset let distribution = MvNormalMeanScalePrecision(ones(2), 2)
+        @test test_projection_mle(distribution)
+    end
    @testset let distribution = MvNormalMeanCovariance(ones(2), Matrix(Diagonal(ones(2))))
        @test test_projection_mle(distribution)
    end
@@ -137,10 +183,7 @@ end
 
    @testset let distribution =
            MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4))
-        @test test_projection_mle(
-            distribution,
-            niterations_range = 500:100:2000
-        )
+        @test test_projection_mle(distribution, niterations_range = 500:100:2000)
    end
 end
\ No newline at end of file
diff --git a/test/projection/projected_to_tests.jl b/test/projection/projected_to_tests.jl
index 96a85ef..cfd01b3 100644
--- a/test/projection/projected_to_tests.jl
+++ b/test/projection/projected_to_tests.jl
@@ -431,35 +431,76 @@
    # Do not produce debug output by default
    @test_logs match_mode = :all project_to(prj, targetfn)
    @test_logs match_mode = :all project_to(prj, targetfn, debug = [])
-    
+
+end
+
+@testitem "kwargs in the `project_to` should take precedence over kwargs in `ProjectedTo`" begin
+    using ExponentialFamilyProjection, StableRNGs, ExponentialFamily, Manopt, JET
+
+    @testset begin
+        rng = StableRNG(42)
+        prj = ProjectedTo(Beta; kwargs = (debug = missing,))
+        targetfn = (x) -> rand(rng) > 0.5 ? 1 : -1
+
+        @test_logs match_mode = :all project_to(prj, targetfn)
+        @test_logs (:warn, r"The cost increased.*") match_mode = :any project_to(
+            prj,
+            targetfn,
+            debug = [Manopt.DebugWarnIfCostIncreases()],
+        )
+    end
+
+    @testset begin
+        rng = StableRNG(42)
+        prj = ProjectedTo(Beta; kwargs = (debug = [Manopt.DebugWarnIfCostIncreases()],))
+        targetfn = (x) -> rand(rng) > 0.5 ? 1 : -1
+
+        @test_logs (:warn, r"The cost increased.*") match_mode = :any project_to(
+            prj,
+            targetfn,
+        )
+        @test_logs match_mode = :all project_to(prj, targetfn, debug = missing)
+        @test_logs match_mode = :all project_to(prj, targetfn, debug = [])
+    end
+
 end
 
-@testitem "Direction rule can improve for MLE" begin 
+@testitem "Direction rule can improve for MLE" begin
    using BayesBase, ExponentialFamily, Distributions
    using ExponentialFamilyProjection, StableRNGs
 
-    dists = (Beta(1, 1), Gamma(10, 20), Bernoulli(0.8), NormalMeanVariance(-10, 0.1), Poisson(4.8))
-    
+    dists = (
+        Beta(1, 1),
+        Gamma(10, 20),
+        Bernoulli(0.8),
+        NormalMeanVariance(-10, 0.1),
+        Poisson(4.8),
+    )
+
    for dist in dists
        rng = StableRNG(42)
        data = rand(rng, dist, 4000)
-        
+
        norm_bounds = [0.01, 0.1, 10.0]
-        
+
        divergences = map(norm_bounds) do norm
            parameters = ProjectionParameters(
-                direction = ExponentialFamilyProjection.BoundedNormUpdateRule(norm)
+                direction = ExponentialFamilyProjection.BoundedNormUpdateRule(norm),
+            )
+            projection = ProjectedTo(
+                ExponentialFamily.exponential_family_typetag(dist),
+                ()...,
+                parameters = parameters,
            )
-            projection = ProjectedTo(ExponentialFamily.exponential_family_typetag(dist), ()..., parameters = parameters)
            approximated = project_to(projection, data)
            kldivergence(approximated, dist)
        end
 
        @testset "true dist $(dist)" begin
-            @test issorted(divergences, rev=true)
+            @test issorted(divergences, rev = true)
            @test (divergences[1] - divergences[end]) / divergences[1] > 0.05
        end
-        
+
    end
 end
 
@@ -471,20 +512,21 @@ end
    true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
    logp = (x) -> logpdf(true_dist, x)
 
-    manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)
+    manifold = ExponentialFamilyManifolds.get_natural_manifold(
+        MvNormalMeanCovariance,
+        (2,),
+        nothing,
+    )
    initialpoint = rand(manifold)
    direction = MomentumGradient(manifold, initialpoint)
 
-    momentum_parameters = ProjectionParameters(
-        direction = direction,
-        niterations = 1000,
-        tolerance = 1e-8
-    )
+    momentum_parameters =
+        ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
+
+    projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = momentum_parameters)
 
-    projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters)
-    
    approximated = project_to(projection, logp, initialpoint = initialpoint)
-    
+
    @test approximated isa MvNormalMeanCovariance
    @test kldivergence(approximated, true_dist) < 0.01
    @test projection.parameters.direction isa MomentumGradient
@@ -497,21 +539,22 @@ end
    true_dist = MvNormal([1.0, 2.0], [1.0 0.7; 0.7 2.0])
    rng = StableRNG(42)
    samples = rand(rng, true_dist, 1000)
-    
-    manifold = ExponentialFamilyManifolds.get_natural_manifold(MvNormalMeanCovariance, (2,), nothing)
-    
+
+    manifold = ExponentialFamilyManifolds.get_natural_manifold(
+        MvNormalMeanCovariance,
+        (2,),
+        nothing,
+    )
+
    initialpoint = rand(rng, manifold)
    direction = MomentumGradient(manifold, initialpoint)
-    
-    momentum_parameters = ProjectionParameters(
-        direction = direction,
-        niterations = 1000,
-        tolerance = 1e-8
-    )
-    
-    projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters=momentum_parameters)
+
+    momentum_parameters =
+        ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
+
+    projection = ProjectedTo(MvNormalMeanCovariance, 2, parameters = momentum_parameters)
    approximated = project_to(projection, samples, initialpoint = initialpoint)
-    
+
    @test approximated isa MvNormalMeanCovariance
    @test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation
    @test projection.parameters.direction isa MomentumGradient
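
A minimal usage sketch of the new `kwargs` option introduced in `ProjectedTo`, based on the docstring and the tests in this changeset. The `Beta` target, the iteration count, and the particular `Manopt` debug actions below are illustrative choices, not part of the changeset itself:

```julia
using ExponentialFamily, ExponentialFamilyProjection, Distributions, Manopt

# Store default `Manopt.gradient_descent!` options inside the `ProjectedTo` specification
prj = ProjectedTo(
    Beta,
    parameters = ProjectionParameters(niterations = 300),
    kwargs = (debug = [Manopt.DebugCost(), Manopt.DebugDivider("\n")],),
)

targetfn = (x) -> logpdf(Beta(2.0, 3.0), x)

# Uses the `debug` actions stored in `prj` and prints the cost on every iteration
project_to(prj, targetfn)

# `kwargs` passed to `project_to` take precedence over those stored in `prj`,
# so this call runs without debug output
project_to(prj, targetfn, debug = missing)
```

The precedence rule mirrors the merge in `project_to`: `(; prj_kwargs..., kwargs...)` lets call-site keyword arguments override the defaults stored in `ProjectedTo`.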
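The control-variate change above reassociates `(cov_matrix * inv_fisher) * v` into `cov_matrix * (inv_fisher * v)`, which removes the `corr_matrix` buffer and replaces a matrix-matrix product with two matrix-vector products. A small self-contained check of the equivalence, using plain arrays and a hypothetical size rather than the package's `Bumper.@alloc` buffers:

```julia
using LinearAlgebra

d = 100
cov_matrix = randn(d, d)
inv_fisher = randn(d, d)
v = randn(d)

# Old order: materializes a d×d `corr_matrix` first (O(d^3) matrix-matrix product)
corr_matrix = cov_matrix * inv_fisher
old = corr_matrix * v

# New order: two O(d^2) matrix-vector products, no d×d temporary
new = cov_matrix * (inv_fisher * v)

old ≈ new # true up to floating-point roundoff
```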