From 413f2c66487f6b37b03061906bef05f6fe900808 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Fri, 22 Nov 2024 18:22:54 +0100
Subject: [PATCH 01/10] Fix complex support for L-BFGS

---
 optax/_src/linesearch.py       |  4 ++--
 optax/tree_utils/__init__.py   |  1 +
 optax/tree_utils/_tree_math.py | 14 +++++++++++++-
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py
index 83f988e9..669b52eb 100644
--- a/optax/_src/linesearch.py
+++ b/optax/_src/linesearch.py
@@ -698,7 +698,7 @@ def _value_and_slope_on_line(
   """
   step = otu.tree_add_scalar_mul(params, stepsize, updates)
   value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
-  slope_step = otu.tree_vdot(grad_step, updates)
+  slope_step = otu.tree_vdot(otu.tree_conj(grad_step), updates)
   return step, value_step, grad_step, slope_step
 
 def _compute_decrease_error(
@@ -1205,7 +1205,7 @@ def init_fn(
         f"Unknown initial guess strategy: {initial_guess_strategy}"
     )
 
-    slope = otu.tree_vdot(updates, grad)
+    slope = otu.tree_vdot(updates, otu.tree_conj(grad))
     return ZoomLinesearchState(
         count=jnp.asarray(0, dtype=jnp.int32),
         #
diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py
index 130b9d90..8d722082 100644
--- a/optax/tree_utils/__init__.py
+++ b/optax/tree_utils/__init__.py
@@ -35,6 +35,7 @@
 from optax.tree_utils._tree_math import tree_l2_norm
 from optax.tree_utils._tree_math import tree_linf_norm
 from optax.tree_utils._tree_math import tree_max
+from optax.tree_utils._tree_math import tree_conj
 from optax.tree_utils._tree_math import tree_mul
 from optax.tree_utils._tree_math import tree_ones_like
 from optax.tree_utils._tree_math import tree_scalar_mul
diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py
index 37e0ea27..adefec11 100644
--- a/optax/tree_utils/_tree_math.py
+++ b/optax/tree_utils/_tree_math.py
@@ -127,7 +127,7 @@ def tree_add_scalar_mul(
 
 
 def _vdot_safe(a, b):
-  return _vdot(jnp.asarray(a), jnp.asarray(b))
+  return _vdot(jnp.asarray(a), jnp.asarray(b)).real
 
 
 def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
@@ -183,6 +183,18 @@ def tree_max(tree: Any) -> chex.Numeric:
   return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))
 
 
+def tree_conj(tree: Any) -> Any:
+  """Compute the conjugate of a pytree.
+
+  Args:
+    tree: pytree.
+
+  Returns:
+    a pytree with the same structure as ``tree``.
+  """
+  return jax.tree.map(jnp.conj, tree)
+
+
 def _square(leaf):
   return jnp.square(leaf.real) + jnp.square(leaf.imag)
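Note on the helper added above: `tree_conj` is just a leaf-wise `jnp.conj`, so real
leaves pass through unchanged. A minimal usage sketch (the values here are
illustrative, not from the patch):

    import jax.numpy as jnp
    from optax import tree_utils as otu

    tree = {'a': jnp.array([1. + 2.j, 3. - 4.j]), 'b': jnp.array([5.0, 6.0])}
    otu.tree_conj(tree)
    # {'a': Array([1.-2.j, 3.+4.j]), 'b': Array([5., 6.])}

Since `tree_vdot` (like `jnp.vdot`) conjugates its first argument, the patched
`tree_vdot(tree_conj(grad_step), updates)` evaluates the unconjugated sum
`sum(grad_step * updates)` over all leaves.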
From d4e1d6fdedb68e937f9d85bd70d94d668657fb79 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 12:24:23 +0100
Subject: [PATCH 02/10] Introduce `otu.tree_real` instead of fixing `otu.tree_vdot`

---
 optax/_src/linesearch.py       |  6 +++---
 optax/tree_utils/__init__.py   |  1 +
 optax/tree_utils/_tree_math.py | 14 +++++++++++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py
index 669b52eb..43b98f43 100644
--- a/optax/_src/linesearch.py
+++ b/optax/_src/linesearch.py
@@ -319,7 +319,7 @@ def update_fn(
     # Slope of lr -> value_fn(params + lr * updates) at lr = 0
     # Should be negative to ensure that there exists a lr (potentially
     # infinitesimal) that satisfies the criterion.
-    slope = otu.tree_vdot(updates, grad)
+    slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))
 
     def cond_fn(
         search_state: BacktrackingLineSearchState,
@@ -698,7 +698,7 @@ def _value_and_slope_on_line(
   """
   step = otu.tree_add_scalar_mul(params, stepsize, updates)
   value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
-  slope_step = otu.tree_vdot(otu.tree_conj(grad_step), updates)
+  slope_step = otu.tree_real(otu.tree_vdot(otu.tree_conj(grad_step), updates))
   return step, value_step, grad_step, slope_step
 
 def _compute_decrease_error(
@@ -1205,7 +1205,7 @@ def init_fn(
         f"Unknown initial guess strategy: {initial_guess_strategy}"
    )
 
-    slope = otu.tree_vdot(updates, otu.tree_conj(grad))
+    slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))
     return ZoomLinesearchState(
         count=jnp.asarray(0, dtype=jnp.int32),
         #
diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py
index 8d722082..ff28716f 100644
--- a/optax/tree_utils/__init__.py
+++ b/optax/tree_utils/__init__.py
@@ -36,6 +36,7 @@
 from optax.tree_utils._tree_math import tree_linf_norm
 from optax.tree_utils._tree_math import tree_max
 from optax.tree_utils._tree_math import tree_conj
+from optax.tree_utils._tree_math import tree_real
 from optax.tree_utils._tree_math import tree_mul
 from optax.tree_utils._tree_math import tree_ones_like
 from optax.tree_utils._tree_math import tree_scalar_mul
diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py
index adefec11..307acca9 100644
--- a/optax/tree_utils/_tree_math.py
+++ b/optax/tree_utils/_tree_math.py
@@ -127,7 +127,7 @@ def tree_add_scalar_mul(
 
 
 def _vdot_safe(a, b):
-  return _vdot(jnp.asarray(a), jnp.asarray(b)).real
+  return _vdot(jnp.asarray(a), jnp.asarray(b))
 
 
 def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
@@ -195,6 +195,18 @@ def tree_conj(tree: Any) -> Any:
   return jax.tree.map(jnp.conj, tree)
 
 
+def tree_real(tree: Any) -> Any:
+  """Compute the real part of a pytree.
+
+  Args:
+    tree: pytree.
+
+  Returns:
+    a pytree with the same structure as ``tree``.
+  """
+  return jax.tree.map(jnp.real, tree)
+
+
 def _square(leaf):
   return jnp.square(leaf.real) + jnp.square(leaf.imag)
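Note: with `tree_real` in place, the line searches force the slope itself to be a
real scalar instead of making `tree_vdot` silently discard imaginary parts for all
callers. A sketch of the composition used above (illustrative values):

    import jax.numpy as jnp
    from optax import tree_utils as otu

    updates = jnp.array([1. + 1.j, 2.j])
    grad = jnp.array([0.5 - 1.j, 1. + 1.j])

    # Same expression as in init_fn above; real-valued for real or complex trees.
    slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))

For all-real pytrees both `tree_conj` and `tree_real` are identity maps, so the
existing real-parameter behavior is untouched.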
From 02b5968a1b68798a2faf08bd5a9e7f325ef9eca7 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 12:38:24 +0100
Subject: [PATCH 03/10] Additional fixes for lbfgs vdots

---
 optax/_src/transform.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index dcfefdea..6ab7109a 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1521,7 +1521,7 @@ def right_product(vec, idx):
     dwi, dui = jax.tree.map(
         lambda x: x[idx], (diff_params_memory, diff_updates_memory)
     )
-    alpha = rhos[idx] * otu.tree_vdot(dwi, vec)
+    alpha = rhos[idx] * otu.tree_real(otu.tree_vdot(dwi, vec))
     vec = otu.tree_add_scalar_mul(vec, -alpha, dui)
     return vec, alpha
 
@@ -1536,7 +1536,7 @@ def left_product(vec, idx_alpha):
     dwi, dui = jax.tree.map(
         lambda x: x[idx], (diff_params_memory, diff_updates_memory)
     )
-    beta = rhos[idx] * otu.tree_vdot(dui, vec)
+    beta = rhos[idx] * otu.tree_real(otu.tree_vdot(dui, vec))
     vec = otu.tree_add_scalar_mul(vec, alpha - beta, dwi)
     return vec, beta
 
@@ -1666,7 +1666,7 @@ def update_fn(
     # 1. Updates the memory buffers given fresh params and gradients/updates
     diff_params = otu.tree_sub(params, state.params)
     diff_updates = otu.tree_sub(updates, state.updates)
-    vdot_diff_params_updates = otu.tree_vdot(diff_updates, diff_params)
+    vdot_diff_params_updates = otu.tree_real(otu.tree_vdot(diff_updates, diff_params))
     weight = jnp.where(
         vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
     )

From b2ccb3556b5aed715af20ca5d5f9c6ab3e501156 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 12:40:36 +0100
Subject: [PATCH 04/10] Add tests for tree_math

---
 optax/tree_utils/_tree_math_test.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py
index c1908e62..8eeff0fb 100644
--- a/optax/tree_utils/_tree_math_test.py
+++ b/optax/tree_utils/_tree_math_test.py
@@ -152,6 +152,24 @@ def test_tree_max(self, key):
     got = tu.tree_max(tree)
     np.testing.assert_allclose(expected, got)
 
+  def test_tree_conj(self):
+    expected = jnp.conj(self.array_a)
+    got = tu.tree_conj(self.array_a)
+    np.testing.assert_array_almost_equal(expected, got)
+
+    expected = (jnp.conj(self.tree_a[0]), jnp.conj(self.tree_a[1]))
+    got = tu.tree_conj(self.tree_a)
+    chex.assert_trees_all_close(expected, got)
+
+  def test_tree_real(self):
+    expected = jnp.real(self.array_a)
+    got = tu.tree_real(self.array_a)
+    np.testing.assert_array_almost_equal(expected, got)
+
+    expected = (jnp.real(self.tree_a[0]), jnp.real(self.tree_a[1]))
+    got = tu.tree_real(self.tree_a)
+    chex.assert_trees_all_close(expected, got)
+
   def test_tree_l2_norm(self):
     expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real)
     got = tu.tree_l2_norm(self.array_a)

From cf9de95d30aca20440db5e6f5fa959d387d908ab Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 12:52:30 +0100
Subject: [PATCH 05/10] Add test for complex-valued LBFGS

---
 optax/_src/alias_test.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index b03471b9..f03165f7 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -865,6 +865,41 @@ def fun(x):
     sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol)
     chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol)
 
+  def test_complex(self):
+    """Test that optimization over complex variable z = x + jy matches equivalent
+    real case"""
+
+    tol=1e-5
+    W = jnp.array(
+        [[1, - 2],
+         [3, 4],
+         [-4 + 2j, 5 - 3j],
+         [-2 - 2j, 6]]
+    )
+
+    def to_real(z):
+      return jnp.stack((z.real, z.imag))
+
+    def to_complex(x):
+      return x[..., 0, :] + 1j * x[..., 1, :]  # if len(x)>0 else jnp.zeros_like(x)
+
+    def f(z):
+      return W @ z
+
+    def fun_complex(z):
+      return jnp.sum(jnp.abs(f(z)) ** 1.5)
+
+    def fun_real(z):
+      return fun_complex(to_complex(z))
+
+    z0 = jnp.array([1 - 1j, 0 + 1j])
+
+    opt_complex = alias.lbfgs()
+    opt_real = alias.lbfgs()
+    sol_complex, _ = _run_opt(opt_complex, fun_complex, init_params=z0, tol=tol)
+    sol_real, _ = _run_opt(opt_real, fun_real, init_params=to_real(z0), tol=tol)
+
+    chex.assert_trees_all_close(sol_complex, to_complex(sol_real), atol=tol, rtol=tol)
 
 if __name__ == '__main__':
   absltest.main()
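Note: the new test leans on the fact that stacking real and imaginary parts is a
lossless reparametrization, so the complex and real runs should reach the same
minimizer up to tolerance. A quick standalone sanity check of the round trip
(not part of the patch):

    import jax.numpy as jnp

    z = jnp.array([1. - 1.j, 0. + 1.j])
    x = jnp.stack((z.real, z.imag))            # to_real
    z_back = x[..., 0, :] + 1j * x[..., 1, :]  # to_complex
    assert jnp.allclose(z_back, z)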
From 3917b63e763243ce8e45ff31eea6c199611fe61a Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 12:55:35 +0100
Subject: [PATCH 06/10] Fix CI formatting

---
 optax/_src/alias_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index f03165f7..59fd5d90 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -869,7 +869,7 @@ def test_complex(self):
     """Test that optimization over complex variable z = x + jy matches equivalent
     real case"""
 
-    tol=1e-5
+    tol = 1e-5
     W = jnp.array(
         [[1, - 2],
          [3, 4],

From c9e85b2bb57aea6505ee3d986ce959afe0cb92b6 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 14:45:39 +0100
Subject: [PATCH 07/10] Fix tests

---
 optax/_src/alias_test.py | 18 ++++++++----------
 optax/_src/linesearch.py |  2 +-
 optax/_src/transform.py  |  2 +-
 3 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 59fd5d90..7ce1a4d9 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -334,7 +334,7 @@ def step(carry):
     params, state, count, _ = carry
     value, grad = value_and_grad_fun(params)
     updates, state = opt.update(
-        grad, state, params, value=value, grad=grad, value_fn=fun
+        jnp.conj(grad), state, params, value=value, grad=jnp.conj(grad), value_fn=fun
     )
     params = update.apply_updates(params, updates)
     return params, state, count + 1, grad
@@ -883,21 +883,19 @@ def to_real(z):
     def to_complex(x):
       return x[..., 0, :] + 1j * x[..., 1, :]  # if len(x)>0 else jnp.zeros_like(x)
 
-    def f(z):
-      return W @ z
+    def f_complex(z):
+      return jnp.sum(jnp.abs(W @ z) ** 1.5)
 
-    def fun_complex(z):
-      return jnp.sum(jnp.abs(f(z)) ** 1.5)
-
-    def fun_real(z):
-      return fun_complex(to_complex(z))
+    def f_real(x):
+      return f_complex(to_complex(x))
 
     z0 = jnp.array([1 - 1j, 0 + 1j])
+    x0 = to_real(z0)
 
     opt_complex = alias.lbfgs()
     opt_real = alias.lbfgs()
-    sol_complex, _ = _run_opt(opt_complex, fun_complex, init_params=z0, tol=tol)
-    sol_real, _ = _run_opt(opt_real, fun_real, init_params=to_real(z0), tol=tol)
+    sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol)
+    sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol)
 
     chex.assert_trees_all_close(sol_complex, to_complex(sol_real), atol=tol, rtol=tol)
 
diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py
index 43b98f43..9174257e 100644
--- a/optax/_src/linesearch.py
+++ b/optax/_src/linesearch.py
@@ -1205,7 +1205,7 @@ def init_fn(
         f"Unknown initial guess strategy: {initial_guess_strategy}"
     )
 
-    slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))
+    slope = otu.tree_real(otu.tree_vdot(updates, grad))
     return ZoomLinesearchState(
         count=jnp.asarray(0, dtype=jnp.int32),
         #
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 6ab7109a..9f9f1547 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1691,7 +1691,7 @@ def update_fn(
     # used to initialize the approximation of the inverse through the memory
     # buffer.
     if scale_init_precond:
-      numerator = otu.tree_vdot(diff_updates, diff_params)
+      numerator = otu.tree_real(otu.tree_vdot(diff_updates, diff_params))
       denominator = otu.tree_l2_norm(diff_updates, squared=True)
       identity_scale = jnp.where(
           denominator > 0.0, numerator / denominator, 1.0
      )
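Note: the `scale_init_precond` branch touched above computes the standard initial
inverse-Hessian scale gamma = <dg, dw> / ||dg||^2, now with an explicit real part
so the scale stays a real scalar for complex trees. A standalone sketch, where
`dw` and `dg` stand in for the parameter and gradient differences (illustrative
values, not from the patch):

    import jax.numpy as jnp
    from optax import tree_utils as otu

    dw = jnp.array([0.1 + 0.2j, -0.3j])       # params_new - params_old
    dg = jnp.array([0.4 - 0.1j, 0.2 + 0.5j])  # updates_new - updates_old

    gamma = otu.tree_real(otu.tree_vdot(dg, dw)) / otu.tree_l2_norm(dg, squared=True)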
From 0d22fc0f07eed79ea7c68570b3489e81a1ae4105 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 14:59:52 +0100
Subject: [PATCH 08/10] Add test parametrization over linesearch

---
 optax/_src/alias_test.py | 16 +++++++++++++---
 optax/_src/transform.py  |  4 +++-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 7ce1a4d9..3919b294 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -37,6 +37,7 @@
 import scipy.optimize as scipy_optimize
 from sklearn import datasets
 from sklearn import linear_model
+from optax._src import linesearch as _linesearch
 
 
 ##############
@@ -865,7 +866,16 @@ def fun(x):
     sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol)
     chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol)
 
-  def test_complex(self):
+  @parameterized.product(
+      linesearch=[
+          # _linesearch.zoom_linesearch(max_linesearch_steps=20),
+          _linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=20),
+          _linesearch.scale_by_zoom_linesearch(
+              max_linesearch_steps=20, initial_guess_strategy='one'
+          )
+      ],
+  )
+  def test_complex(self, linesearch):
     """Test that optimization over complex variable z = x + jy matches equivalent
     real case"""
 
@@ -892,7 +902,7 @@ def f_real(x):
     z0 = jnp.array([1 - 1j, 0 + 1j])
     x0 = to_real(z0)
 
-    opt_complex = alias.lbfgs()
-    opt_real = alias.lbfgs()
+    opt_complex = alias.lbfgs(linesearch=linesearch)
+    opt_real = alias.lbfgs(linesearch=linesearch)
     sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol)
     sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol)
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 9f9f1547..a0675a60 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -1666,7 +1666,9 @@ def update_fn(
     # 1. Updates the memory buffers given fresh params and gradients/updates
     diff_params = otu.tree_sub(params, state.params)
     diff_updates = otu.tree_sub(updates, state.updates)
-    vdot_diff_params_updates = otu.tree_real(otu.tree_vdot(diff_updates, diff_params))
+    vdot_diff_params_updates = otu.tree_real(
+        otu.tree_vdot(diff_updates, diff_params)
+    )
     weight = jnp.where(
         vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
     )
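Note: the line-search choices exercised by the parametrized test are reachable
through the public API as well; a hedged sketch of wiring one into L-BFGS:

    import optax

    ls = optax.scale_by_zoom_linesearch(
        max_linesearch_steps=20, initial_guess_strategy='one'
    )
    opt = optax.lbfgs(linesearch=ls)

`lbfgs` simply chains its direction estimate with whatever `linesearch`
transformation it is handed, which is what lets the test sweep both the
backtracking and zoom variants.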
From efcad3630cc2860422e32374836cd78c0114d32b Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 15:07:02 +0100
Subject: [PATCH 09/10] Fix linting for CI

---
 optax/_src/alias_test.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 3919b294..6ccf3106 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -27,6 +27,7 @@
 import numpy as np
 from optax._src import alias
 from optax._src import base
+from optax._src import linesearch as _linesearch
 from optax._src import numerics
 from optax._src import transform
 from optax._src import update
@@ -37,7 +38,6 @@
 import scipy.optimize as scipy_optimize
 from sklearn import datasets
 from sklearn import linear_model
-from optax._src import linesearch as _linesearch
 
 
 ##############
@@ -334,8 +334,9 @@ def stopping_criterion(carry):
   def step(carry):
     params, state, count, _ = carry
     value, grad = value_and_grad_fun(params)
+    grad = jnp.conj(grad)
     updates, state = opt.update(
-        jnp.conj(grad), state, params, value=value, grad=jnp.conj(grad), value_fn=fun
+        grad, state, params, value=value, grad=grad, value_fn=fun
     )
     params = update.apply_updates(params, updates)
     return params, state, count + 1, grad
@@ -876,11 +877,10 @@ def fun(x):
       ],
   )
   def test_complex(self, linesearch):
-    """Test that optimization over complex variable z = x + jy matches equivalent
-    real case"""
+    # Test that optimization over complex variable matches equivalent real case
 
     tol = 1e-5
-    W = jnp.array(
+    mat = jnp.array(
         [[1, - 2],
          [3, 4],
          [-4 + 2j, 5 - 3j],
@@ -891,10 +891,10 @@ def to_real(z):
       return jnp.stack((z.real, z.imag))
 
     def to_complex(x):
-      return x[..., 0, :] + 1j * x[..., 1, :]  # if len(x)>0 else jnp.zeros_like(x)
+      return x[..., 0, :] + 1j * x[..., 1, :]
 
     def f_complex(z):
-      return jnp.sum(jnp.abs(W @ z) ** 1.5)
+      return jnp.sum(jnp.abs(mat @ z) ** 1.5)
 
     def f_real(x):
       return f_complex(to_complex(x))
@@ -907,7 +907,9 @@ def f_real(x):
     sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol)
     sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol)
 
-    chex.assert_trees_all_close(sol_complex, to_complex(sol_real), atol=tol, rtol=tol)
+    chex.assert_trees_all_close(
+        sol_complex, to_complex(sol_real), atol=tol, rtol=tol
+    )
 
 if __name__ == '__main__':
   absltest.main()

From 5d1e6bc4d4167b8dbfd6022fb7f04c2486892e22 Mon Sep 17 00:00:00 2001
From: Ronan Gautier
Date: Sat, 23 Nov 2024 15:38:49 +0100
Subject: [PATCH 10/10] Fix remaining tests in CI

---
 optax/_src/alias_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py
index 6ccf3106..73f03c4a 100644
--- a/optax/_src/alias_test.py
+++ b/optax/_src/alias_test.py
@@ -334,7 +334,7 @@ def stopping_criterion(carry):
   def step(carry):
     params, state, count, _ = carry
     value, grad = value_and_grad_fun(params)
-    grad = jnp.conj(grad)
+    grad = otu.tree_conj(grad)
     updates, state = opt.update(
         grad, state, params, value=value, grad=grad, value_fn=fun
     )
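Note: putting the series together, a minimal end-to-end sketch of running L-BFGS on
complex parameters (assumes an optax build containing these patches; the loss,
iteration count, and initial point are illustrative):

    import jax
    import jax.numpy as jnp
    import optax
    from optax import tree_utils as otu

    def loss(z):  # real-valued loss of a complex parameter vector
      return jnp.sum(jnp.abs(z) ** 1.5)

    opt = optax.lbfgs()
    z = jnp.array([1. - 1.j, 0. + 1.j])
    state = opt.init(z)
    for _ in range(50):
      value, grad = jax.value_and_grad(loss)(z)
      grad = otu.tree_conj(grad)  # same convention as the test's step function
      updates, state = opt.update(
          grad, state, z, value=value, grad=grad, value_fn=loss
      )
      z = optax.apply_updates(z, updates)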