Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(breaking): update to Manopt 0.5 #40

Merged
merged 7 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ StaticTools = "86c06d3c-3f03-46de-9781-57580aa96d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
BayesBase = "1.3"
BayesBase = "1.5.0"
Bumper = "0.6"
Distributions = "0.25"
ExponentialFamily = "1.5"
ExponentialFamily = "1.6"
ExponentialFamilyManifolds = "1.5"
FastCholesky = "1.3"
FillArrays = "1"
Expand All @@ -38,7 +38,7 @@ LinearAlgebra = "1.10"
LoopVectorization = "0.12"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Manopt = "0.4"
Manopt = "0.5"
Random = "1.10"
RecursiveArrayTools = "3.2"
StableRNGs = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/manopt/bounded_norm_update_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Manopt
using LoopVectorization

"""
BoundedNormUpdateRule(limit; direction = IdentityUpdateRule())
BoundedNormUpdateRule(limit; direction = Manopt.IdentityUpdateRule())

A `DirectionUpdateRule` is a direction rule that constrains the norm of the direction to a specified limit.

Expand All @@ -23,7 +23,7 @@ struct BoundedNormUpdateRule{L,D} <: Manopt.DirectionUpdateRule
direction::D
end

function BoundedNormUpdateRule(limit; direction = IdentityUpdateRule())
function BoundedNormUpdateRule(limit; direction = Manopt.IdentityUpdateRule())
return BoundedNormUpdateRule(limit, direction)
end

Expand Down
4 changes: 2 additions & 2 deletions src/projected_to.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ The following parameters are available:
* `strategy = ExponentialFamilyProjection.DefaultStrategy()`: The strategy to use to compute the gradients.
* `niterations = 100`: The number of iterations for the optimization procedure.
* `tolerance = 1e-6`: The tolerance for the norm of the gradient.
* `stepsize = ConstantStepsize(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
* `stepsize = ConstantLength(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
* `seed`: Optional; Seed for the `rng`
* `rng`: Optional; Random number generator
* `direction = BoundedNormUpdateRule(static(1.0)`: Direction update rule. Accepts `Manopt.DirectionUpdateRule` from `Manopt.jl`.
Expand All @@ -117,7 +117,7 @@ Base.@kwdef struct ProjectionParameters{S,I,T,P,D,N,U}
strategy::S = DefaultStrategy()
niterations::I = 100
tolerance::T = 1e-6
stepsize::P = ConstantStepsize(0.1)
stepsize::P = ConstantLength(0.1)
seed::D = 42
rng::N = StableRNG(seed)
direction::U = BoundedNormUpdateRule(static(1.0))
Expand Down
6 changes: 3 additions & 3 deletions test/manopt/bounded_norm_update_rule_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# returning the unbounded gradient as the first argument and a collection of bounded gradients for testing.
function apply_update_rules_for_test(p, limit)
cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f))
gst = GradientDescentState(M, zero(p))
gst = GradientDescentState(M; p=zero(p))
Manopt.set_iterate!(gst, M, p)

_, X = IdentityUpdateRule()(cpa, gst, 1)
_, X = Manopt.IdentityUpdateRule()(cpa, gst, 1)
X_identity = copy(X)

_, X = BoundedNormUpdateRule(limit)(cpa, gst, 1)
Expand Down Expand Up @@ -79,7 +79,7 @@
@testset "JET tests" begin
for limit in (1, 1.0, 1.0f0), p in (zeros(Float64, 3), zeros(Float32, 3))
cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f))
gst = GradientDescentState(M, zero(p))
gst = GradientDescentState(M; p=zero(p))
@test_opt BoundedNormUpdateRule(limit)(cpa, gst, 1)
@test_opt BoundedNormUpdateRule(static(limit))(cpa, gst, 1)
@test_opt BoundedNormUpdateRule(
Expand Down
8 changes: 4 additions & 4 deletions test/projection/projected_to_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function test_convergence_nsamples(
nsamples_tolerance = _convergence_nsamples_default_tolerance(distribution),
nsamples_niterations = _convergence_nsamples_default_niterations(distribution),
nsamples_rng = StableRNG(42),
nsamples_stepsize = ConstantStepsize(0.1),
nsamples_stepsize = ConstantLength(0.1),
nsamples_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -168,7 +168,7 @@ function test_convergence_niterations(
niterations_tolerance = _convergence_niterations_default_tolerance(distribution),
niterations_nsamples = _convergence_niterations_default_nsamples(distribution),
niterations_rng = StableRNG(42),
niterations_stepsize = ConstantStepsize(0.1),
niterations_stepsize = ConstantLength(0.1),
niterations_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -219,7 +219,7 @@ function test_convergence_niterations_mle(
niterations_tolerance = _convergence_niterations_default_tolerance(distribution),
niterations_nsamples = _convergence_niterations_default_nsamples(distribution),
niterations_rng = StableRNG(42),
niterations_stepsize = ConstantStepsize(0.1),
niterations_stepsize = ConstantLength(0.1),
niterations_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -268,7 +268,7 @@ function test_convergence_nsamples_mle(
nsamples_tolerance = _convergence_nsamples_default_tolerance(distribution),
nsamples_niterations = _convergence_nsamples_default_niterations(distribution),
nsamples_rng = StableRNG(42),
nsamples_stepsize = ConstantStepsize(0.1),
nsamples_stepsize = ConstantLength(0.1),
nsamples_required_accuracy = 1e-1,
kwargs...,
)
Expand Down
10 changes: 5 additions & 5 deletions test/projection/projected_to_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
@test typeof(get_stopping_criterion(defaultparams)) ==
typeof(get_stopping_criterion(parameters_from_creation))
# These should pass as soon as `Manopt` implements `==`
@test_broken getstepsize(defaultparams) == getstepsize(parameters_from_creation)
@test getstepsize(defaultparams) == getstepsize(parameters_from_creation)
@test_broken get_stopping_criterion(defaultparams) ==
get_stopping_criterion(parameters_from_creation)

Expand Down Expand Up @@ -518,7 +518,7 @@ end
nothing,
)
initialpoint = rand(manifold)
direction = MomentumGradient(manifold, initialpoint)
direction = MomentumGradient(p=initialpoint)

momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
Expand All @@ -529,7 +529,7 @@ end

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01
@test projection.parameters.direction isa MomentumGradient
@test projection.parameters.direction isa Manopt.ManifoldDefaultsFactory
end

@testitem "MomentumGradient direction update rule on samples" begin
Expand All @@ -547,7 +547,7 @@ end
)

initialpoint = rand(rng, manifold)
direction = MomentumGradient(manifold, initialpoint)
direction = MomentumGradient(p=initialpoint)

momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
Expand All @@ -557,5 +557,5 @@ end

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation
@test projection.parameters.direction isa MomentumGradient
@test projection.parameters.direction isa Manopt.ManifoldDefaultsFactory
end
Loading