Skip to content

Commit

Permalink
Merge pull request #40 from ReactiveBayes/update-deps
Browse files Browse the repository at this point in the history
feat(breaking): update to Manopt 0.5
  • Loading branch information
bvdmitri authored Nov 21, 2024
2 parents a368d5a + 10cfbd9 commit bc21474
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ StaticTools = "86c06d3c-3f03-46de-9781-57580aa96d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
BayesBase = "1.3"
BayesBase = "1.5.0"
Bumper = "0.6"
Distributions = "0.25"
ExponentialFamily = "1.5"
ExponentialFamily = "1.6"
ExponentialFamilyManifolds = "1.5"
FastCholesky = "1.3"
FillArrays = "1"
Expand All @@ -38,7 +38,7 @@ LinearAlgebra = "1.10"
LoopVectorization = "0.12"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Manopt = "0.4"
Manopt = "0.5"
Random = "1.10"
RecursiveArrayTools = "3.2"
StableRNGs = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/manopt/bounded_norm_update_rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Manopt
using LoopVectorization

"""
BoundedNormUpdateRule(limit; direction = IdentityUpdateRule())
BoundedNormUpdateRule(limit; direction = Manopt.IdentityUpdateRule())
A `DirectionUpdateRule` is a direction rule that constrains the norm of the direction to a specified limit.
Expand All @@ -23,7 +23,7 @@ struct BoundedNormUpdateRule{L,D} <: Manopt.DirectionUpdateRule
direction::D
end

function BoundedNormUpdateRule(limit; direction = IdentityUpdateRule())
function BoundedNormUpdateRule(limit; direction = Manopt.IdentityUpdateRule())
return BoundedNormUpdateRule(limit, direction)
end

Expand Down
4 changes: 2 additions & 2 deletions src/projected_to.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ The following parameters are available:
* `strategy = ExponentialFamilyProjection.DefaultStrategy()`: The strategy to use to compute the gradients.
* `niterations = 100`: The number of iterations for the optimization procedure.
* `tolerance = 1e-6`: The tolerance for the norm of the gradient.
* `stepsize = ConstantStepsize(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
* `stepsize = ConstantLength(0.1)`: The stepsize for the optimization procedure. Accepts stepsizes from `Manopt.jl`.
* `seed`: Optional; Seed for the `rng`
* `rng`: Optional; Random number generator
* `direction = BoundedNormUpdateRule(static(1.0))`: Direction update rule. Accepts `Manopt.DirectionUpdateRule` from `Manopt.jl`.
Expand All @@ -117,7 +117,7 @@ Base.@kwdef struct ProjectionParameters{S,I,T,P,D,N,U}
strategy::S = DefaultStrategy()
niterations::I = 100
tolerance::T = 1e-6
stepsize::P = ConstantStepsize(0.1)
stepsize::P = ConstantLength(0.1)
seed::D = 42
rng::N = StableRNG(seed)
direction::U = BoundedNormUpdateRule(static(1.0))
Expand Down
6 changes: 3 additions & 3 deletions test/manopt/bounded_norm_update_rule_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# returning the unbounded gradient as the first argument and a collection of bounded gradients for testing.
function apply_update_rules_for_test(p, limit)
cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f))
gst = GradientDescentState(M, zero(p))
gst = GradientDescentState(M; p=zero(p))
Manopt.set_iterate!(gst, M, p)

_, X = IdentityUpdateRule()(cpa, gst, 1)
_, X = Manopt.IdentityUpdateRule()(cpa, gst, 1)
X_identity = copy(X)

_, X = BoundedNormUpdateRule(limit)(cpa, gst, 1)
Expand Down Expand Up @@ -79,7 +79,7 @@
@testset "JET tests" begin
for limit in (1, 1.0, 1.0f0), p in (zeros(Float64, 3), zeros(Float32, 3))
cpa = DefaultManoptProblem(M, ManifoldGradientObjective(f, grad_f))
gst = GradientDescentState(M, zero(p))
gst = GradientDescentState(M; p=zero(p))
@test_opt BoundedNormUpdateRule(limit)(cpa, gst, 1)
@test_opt BoundedNormUpdateRule(static(limit))(cpa, gst, 1)
@test_opt BoundedNormUpdateRule(
Expand Down
8 changes: 4 additions & 4 deletions test/projection/projected_to_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function test_convergence_nsamples(
nsamples_tolerance = _convergence_nsamples_default_tolerance(distribution),
nsamples_niterations = _convergence_nsamples_default_niterations(distribution),
nsamples_rng = StableRNG(42),
nsamples_stepsize = ConstantStepsize(0.1),
nsamples_stepsize = ConstantLength(0.1),
nsamples_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -168,7 +168,7 @@ function test_convergence_niterations(
niterations_tolerance = _convergence_niterations_default_tolerance(distribution),
niterations_nsamples = _convergence_niterations_default_nsamples(distribution),
niterations_rng = StableRNG(42),
niterations_stepsize = ConstantStepsize(0.1),
niterations_stepsize = ConstantLength(0.1),
niterations_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -219,7 +219,7 @@ function test_convergence_niterations_mle(
niterations_tolerance = _convergence_niterations_default_tolerance(distribution),
niterations_nsamples = _convergence_niterations_default_nsamples(distribution),
niterations_rng = StableRNG(42),
niterations_stepsize = ConstantStepsize(0.1),
niterations_stepsize = ConstantLength(0.1),
niterations_required_accuracy = 1e-1,
kwargs...,
)
Expand Down Expand Up @@ -268,7 +268,7 @@ function test_convergence_nsamples_mle(
nsamples_tolerance = _convergence_nsamples_default_tolerance(distribution),
nsamples_niterations = _convergence_nsamples_default_niterations(distribution),
nsamples_rng = StableRNG(42),
nsamples_stepsize = ConstantStepsize(0.1),
nsamples_stepsize = ConstantLength(0.1),
nsamples_required_accuracy = 1e-1,
kwargs...,
)
Expand Down
10 changes: 5 additions & 5 deletions test/projection/projected_to_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
@test typeof(get_stopping_criterion(defaultparams)) ==
typeof(get_stopping_criterion(parameters_from_creation))
# These should pass as soon as `Manopt` implements `==`
@test_broken getstepsize(defaultparams) == getstepsize(parameters_from_creation)
@test getstepsize(defaultparams) == getstepsize(parameters_from_creation)
@test_broken get_stopping_criterion(defaultparams) ==
get_stopping_criterion(parameters_from_creation)

Expand Down Expand Up @@ -518,7 +518,7 @@ end
nothing,
)
initialpoint = rand(manifold)
direction = MomentumGradient(manifold, initialpoint)
direction = MomentumGradient(p=initialpoint)

momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
Expand All @@ -529,7 +529,7 @@ end

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01
@test projection.parameters.direction isa MomentumGradient
@test projection.parameters.direction isa Manopt.ManifoldDefaultsFactory
end

@testitem "MomentumGradient direction update rule on samples" begin
Expand All @@ -547,7 +547,7 @@ end
)

initialpoint = rand(rng, manifold)
direction = MomentumGradient(manifold, initialpoint)
direction = MomentumGradient(p=initialpoint)

momentum_parameters =
ProjectionParameters(direction = direction, niterations = 1000, tolerance = 1e-8)
Expand All @@ -557,5 +557,5 @@ end

@test approximated isa MvNormalMeanCovariance
@test kldivergence(approximated, true_dist) < 0.01 # Ensure good approximation
@test projection.parameters.direction isa MomentumGradient
@test projection.parameters.direction isa Manopt.ManifoldDefaultsFactory
end

0 comments on commit bc21474

Please sign in to comment.