Merge branch 'main' into inverse-gamma-tests
Nimrais committed Nov 6, 2024
2 parents 3103a23 + 4d8623a commit 65caac7
Showing 6 changed files with 179 additions and 70 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "ExponentialFamilyProjection"
uuid = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
authors = ["Mykola Lukashchuk <[email protected]>", "Dmitry Bagaev <[email protected]>", "Albert Podusenko <[email protected]>"]
-version = "1.1.2"
+version = "1.2.0"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -30,7 +30,7 @@ BayesBase = "1.3"
Bumper = "0.6"
Distributions = "0.25"
ExponentialFamily = "1.5"
-ExponentialFamilyManifolds = "1.2"
+ExponentialFamilyManifolds = "1.5"
FastCholesky = "1.3"
FillArrays = "1"
ForwardDiff = "0.10.36"
26 changes: 21 additions & 5 deletions src/projected_to.jl
@@ -16,6 +16,7 @@ The following arguments are optional:
* `conditioner = nothing`: a conditioner to use for the projection, not all exponential family members require a conditioner, but some do, e.g. `Laplace`
* `parameters = DefaultProjectionParameters`: parameters for the projection procedure
+* `kwargs = nothing`: additional arguments passed to `Manopt.gradient_descent!` (optional). For details on `gradient_descent!` parameters, see the [Manopt.jl documentation](https://manoptjl.org/stable/solvers/gradient_descent/#Manopt.gradient_descent). Note that `kwargs` passed to `project_to` take precedence over `kwargs` specified in the parameters.

```jldoctest
julia> using ExponentialFamily
@@ -33,28 +34,32 @@ julia> projected_to = ProjectedTo(Laplace, conditioner = 2.0)
ProjectedTo(Laplace, conditioner = 2.0)
```
"""
-struct ProjectedTo{T,D,C,P}
+struct ProjectedTo{T,D,C,P,E}
    dims::D
    conditioner::C
    parameters::P
+   kwargs::E
end

ProjectedTo(
    dims::Vararg{Int};
    conditioner = nothing,
    parameters = DefaultProjectionParameters(),
+   kwargs = nothing,
) = ProjectedTo(
    ExponentialFamilyDistribution,
    dims...,
    conditioner = conditioner,
    parameters = parameters,
+   kwargs = kwargs,
)

function ProjectedTo(
    ::Type{T},
    dims...;
    conditioner::C = nothing,
    parameters::P = DefaultProjectionParameters(),
-) where {T,C,P}
+   kwargs::E = nothing,
+) where {T,C,P,E}
    # Check that `dims` are all integers
    if !all(d -> typeof(d) <: Int, dims)
        # If not, throw an error, also suggesting to use keyword arguments
@@ -65,13 +70,14 @@ function ProjectedTo(
        end
        error(msg)
    end
-   return ProjectedTo{T,typeof(dims),C,P}(dims, conditioner, parameters)
+   return ProjectedTo{T,typeof(dims),C,P,E}(dims, conditioner, parameters, kwargs)
end

get_projected_to_type(::ProjectedTo{T}) where {T} = T
get_projected_to_dims(prj::ProjectedTo) = prj.dims
get_projected_to_conditioner(prj::ProjectedTo) = prj.conditioner
get_projected_to_parameters(prj::ProjectedTo) = prj.parameters
+get_projected_to_kwargs(prj::ProjectedTo) = prj.kwargs
get_projected_to_manifold(prj::ProjectedTo) =
    ExponentialFamilyManifolds.get_natural_manifold(
        get_projected_to_type(prj),
@@ -268,9 +274,19 @@ function project_to(
        supplementary_η,
    )

+   # First we query the `kwargs` stored in the `ProjectedTo` object
+   prj_kwargs = get_projected_to_kwargs(prj)
+   prj_kwargs = isnothing(prj_kwargs) ? (;) : prj_kwargs
+   # And attach the `kwargs` passed to `project_to`, those may override
+   # some settings stored in the `ProjectedTo` object
+   if !isnothing(kwargs)
+       prj_kwargs = (; prj_kwargs..., kwargs...)
+   end
    # We disable the default `debug` statements, which are set in `Manopt`,
    # in order to improve the performance a little bit
-   kwargs = !haskey(kwargs, :debug) ? (; kwargs..., debug = missing) : kwargs
+   if !haskey(prj_kwargs, :debug)
+       prj_kwargs = (; prj_kwargs..., debug = missing)
+   end

    return _kernel_project_to(
        get_projected_to_type(prj),
@@ -281,7 +297,7 @@ function project_to(
        strategy,
        state,
        current_η,
-       kwargs,
+       prj_kwargs,
    )
end
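As a quick orientation for the API change above: after this commit a `ProjectedTo` object can carry default `Manopt.gradient_descent!` options, while options passed directly to `project_to` still win. A minimal sketch of the two call styles, mirroring the updated debug tests further down (the Bernoulli target is illustrative only):

```julia
using ExponentialFamily, ExponentialFamilyProjection, Distributions

targetf = (x) -> logpdf(Bernoulli(0.5), x)

# Option 1: store default solver options on the `ProjectedTo` object itself,
# e.g. silence Manopt's debug output with `debug = missing`.
prj = ProjectedTo(Bernoulli, kwargs = (debug = missing,))
approximated = project_to(prj, targetf)

# Option 2: pass options at the call site. These take precedence over the
# stored ones, because `project_to` splats them last into the merged
# NamedTuple: `(; prj_kwargs..., kwargs...)`.
approximated = project_to(ProjectedTo(Bernoulli), targetf; debug = missing)
```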
25 changes: 8 additions & 17 deletions src/strategies/control_variate.jl
@@ -295,18 +295,6 @@ function control_variate_compute_gradient_buffered!(
        state.sufficientstatistics',
        state.gradsamples',
    )
-   # --
-
-   # Next we compute the `corr_matrix` using the sample principle, preallocate the storage
-   # The naive code would be `corr_matrix = cov_matrix * inv_fisher`
-   # --
-   corr_matrix = @alloc(
-       promote_type(eltype(cov_matrix), eltype(inv_fisher)),
-       size(cov_matrix, 1),
-       size(inv_fisher, 2)
-   )
-   mul!(corr_matrix, cov_matrix, inv_fisher)
-   # --

    # Compute means of sufficientstatistics and gradsamples inplace
    # The naive code would be
@@ -323,19 +311,22 @@

    # The next four lines finish the computation, and are essentially equivalent to the following code
    # `estimated_grad_vector = mean_gradsamples - corr_matrix * (mean_sufficientstats - gradlogpartition)`
+   # where `corr_matrix = cov_matrix * inv_fisher`
    # `ef_gradient = η - inv_fisher * estimated_grad_vector` # or (η - (η_ef + inv_fisher * estimated_grad_vector))
    # --
    tmp1 = @alloc(
        promote_type(eltype(mean_sufficientstats), eltype(gradlogpartition)),
        length(mean_sufficientstats)
    )
-   tmp2 = @alloc(promote_type(eltype(corr_matrix), eltype(tmp1)), length(tmp1))
+   tmp2 = @alloc(promote_type(eltype(inv_fisher), eltype(tmp1)), length(tmp1))
+   tmp3 = @alloc(promote_type(eltype(cov_matrix), eltype(tmp2)), length(tmp2))

    map!(-, tmp1, mean_sufficientstats, gradlogpartition) # tmp1 = (mean_sufficientstats - gradlogpartition)
-   mul!(tmp2, corr_matrix, tmp1)         # tmp2 = corr_matrix * tmp1
-   map!(-, tmp1, mean_gradsamples, tmp2) # tmp1 = estimated_grad_vector = mean_gradsamples - tmp2
-   mul!(tmp2, inv_fisher, tmp1)          # tmp2 = inv_fisher * estimated_grad_vector
-   map!(-, X, η, tmp2)                   # X .= η .- tmp2
+   mul!(tmp2, inv_fisher, tmp1)          # tmp2 = inv_fisher * tmp1
+   mul!(tmp3, cov_matrix, tmp2)          # tmp3 = cov_matrix * tmp2, such that tmp3 = cov_matrix * inv_fisher * tmp1
+   map!(-, tmp1, mean_gradsamples, tmp3) # tmp1 = estimated_grad_vector = mean_gradsamples - tmp3
+   mul!(tmp3, inv_fisher, tmp1)          # tmp3 = inv_fisher * estimated_grad_vector
+   map!(-, X, η, tmp3)                   # X .= η .- tmp3
    # --

    nothing
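The rewritten kernel above no longer materializes the buffered `corr_matrix = cov_matrix * inv_fisher`; it applies the two factors to the vector one at a time instead. A self-contained sketch (hypothetical sizes, random data) of the reassociation this relies on:

```julia
using LinearAlgebra

k = 4                      # hypothetical dimension
cov_matrix = rand(k, k)
inv_fisher = rand(k, k)
v = rand(k)

# Old path: materialize the k×k matrix-matrix product, then apply it to v.
naive = (cov_matrix * inv_fisher) * v

# New path: two matrix-vector products, no k×k intermediate buffer.
# Matrix multiplication is associative, so the results agree up to
# floating-point round-off, at O(k^2) instead of O(k^3) cost.
reassociated = cov_matrix * (inv_fisher * v)

@assert naive ≈ reassociated
```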
28 changes: 22 additions & 6 deletions test/projection/helpers/debug.jl
@@ -4,17 +4,27 @@

include("../projected_to_setuptests.jl")

-function test_projection_with_debug(n_iterations, do_debug)
+function test_projection_with_debug(n_iterations, do_debug, pass_to_prj = false)
    distribution = Bernoulli(0.5)
    buf = IOBuffer()
    if do_debug
-       debug = [DebugCost(io=buf), DebugDivider("\n";io=buf)]
+       debug = [DebugCost(io = buf), DebugDivider("\n"; io = buf)]
    else
        debug = missing
    end

-   projection_parameters = ProjectionParameters(niterations=n_iterations)
-   project_to(ProjectedTo(Bernoulli, parameters = projection_parameters), (x) -> logpdf(distribution, x); debug=debug)
+   projection_parameters = ProjectionParameters(niterations = n_iterations)
+   if pass_to_prj
+       prj = ProjectedTo(
+           Bernoulli,
+           parameters = projection_parameters,
+           kwargs = (debug = debug,),
+       )
+       project_to(prj, (x) -> logpdf(distribution, x))
+   else
+       prj = ProjectedTo(Bernoulli, parameters = projection_parameters)
+       project_to(prj, (x) -> logpdf(distribution, x); debug = debug)
+   end
    debug_string = String(take!(buf))
    if do_debug
        lines = split(debug_string, '\n')
@@ -26,14 +36,20 @@
end

@testset "projections with debug" begin
-   for n in 1:10
+   for n = 1:10
        test_projection_with_debug(n, true)
    end
end

@testset "projections without debug" begin
-   for n in 1:10
+   for n = 1:10
        test_projection_with_debug(n, false)
    end
end
+
+@testset "projections with debug passed through ProjectedTo" begin
+   for n = 1:10
+       test_projection_with_debug(n, true, true)
+   end
+end
end
61 changes: 52 additions & 9 deletions test/projection/projected_to_normal_tests.jl
@@ -69,10 +69,7 @@ end

@testset let distribution =
    MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4))
-   @test test_projection_convergence(
-       distribution,
-       niterations_range = 500:100:2000
-   )
+   @test test_projection_convergence(distribution, niterations_range = 500:100:2000)
end
end

@@ -88,14 +85,60 @@ end
    @test test_projection_convergence(
        distribution,
        to = MvNormalMeanCovariance,
-       dims = (2, ),
+       dims = (2,),
        conditioner = nothing,
    )
end

end

+@testitem "Project a product of `MvNormalMeanScalePrecision` and `MvNormalMeanScalePrecision` to `MvNormalMeanScalePrecision`" begin
+   using BayesBase, ExponentialFamily, Distributions, LinearAlgebra
+
+   include("./projected_to_setuptests.jl")
+
+   @testset let distribution = ProductOf(
+       MvNormalMeanScalePrecision(ones(2), 2),
+       MvNormalMeanScalePrecision(ones(2), 3),
+   )
+       @test test_projection_convergence(
+           distribution,
+           to = MvNormalMeanScalePrecision,
+           dims = (2,),
+           conditioner = nothing,
+       )
+   end
+
+   @testset let distribution = ProductOf(
+       MvNormalMeanScalePrecision(ones(8), 2),
+       MvNormalMeanScalePrecision(ones(8), 3),
+   )
+       @test test_projection_convergence(
+           distribution,
+           to = MvNormalMeanScalePrecision,
+           dims = (8,),
+           conditioner = nothing,
+       )
+   end
+
+   @testset let distribution = ProductOf(
+       MvNormalMeanScalePrecision(ones(20), 2),
+       MvNormalMeanScalePrecision(ones(20), 3),
+   )
+       @test test_projection_convergence(
+           distribution,
+           to = MvNormalMeanScalePrecision,
+           dims = (20,),
+           conditioner = nothing,
+           nsamples_niterations = 6000,
+           nsamples_range = 1000:1000:6000,
+           niterations_range = 400:100:1000,
+           nsamples_required_accuracy = 0.3,
+           niterations_required_accuracy = 0.3,
+       )
+   end
+
+end
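The new test item above projects an unnormalized product of two `MvNormalMeanScalePrecision` distributions back onto the same family. Outside the `test_projection_convergence` helper (whose internals are not part of this diff), the direct call could look roughly like the sketch below, assuming the log-density form of `project_to` used in the debug tests:

```julia
using BayesBase, ExponentialFamily, ExponentialFamilyProjection, Distributions

# An unnormalized product of two members of the same family.
product = ProductOf(
    MvNormalMeanScalePrecision(ones(2), 2),
    MvNormalMeanScalePrecision(ones(2), 3),
)

# Project the product's log-density onto `MvNormalMeanScalePrecision`;
# the positional `2` mirrors `dims = (2,)` in the tests above.
prj = ProjectedTo(MvNormalMeanScalePrecision, 2)
approximated = project_to(prj, (x) -> logpdf(product, x))
```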

@testitem "MLE" begin
    using BayesBase, ExponentialFamily, Distributions, JET
@@ -123,6 +166,9 @@ end
    @test test_projection_mle(distribution)
end

+@testset let distribution = MvNormalMeanScalePrecision(ones(2), 2)
+   @test test_projection_mle(distribution)
+end

@testset let distribution = MvNormalMeanCovariance(ones(2), Matrix(Diagonal(ones(2))))
    @test test_projection_mle(distribution)
@@ -137,10 +183,7 @@

@testset let distribution =
    MvNormalMeanCovariance(10randn(StableRNG(42), 4), 10rand(StableRNG(43), 4))
-   @test test_projection_mle(
-       distribution,
-       niterations_range = 500:100:2000
-   )
+   @test test_projection_mle(distribution, niterations_range = 500:100:2000)
end

end