diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 08346fa13..b2c2d1865 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -2514,6 +2514,8 @@ def lbfgs( constrain the trust-region of the first step to an Euclidean ball of radius 1 at the first iteration. The choice of :math:`\gamma_0` is not detailed in the references above, so this is a heuristic choice. + + .. note:: The algorithm can support complex inputs. """ if learning_rate is None: base_scaling = transform.scale(-1.0) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index b03471b98..b48ea2d8f 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -27,6 +27,7 @@ import numpy as np from optax._src import alias from optax._src import base +from optax._src import linesearch as _linesearch from optax._src import numerics from optax._src import transform from optax._src import update @@ -333,6 +334,7 @@ def stopping_criterion(carry): def step(carry): params, state, count, _ = carry value, grad = value_and_grad_fun(params) + grad = otu.tree_conj(grad) updates, state = opt.update( grad, state, params, value=value, grad=grad, value_fn=fun ) @@ -865,6 +867,76 @@ def fun(x): sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol) chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol) + @parameterized.product( + linesearch=[ + _linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=20), + _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=20, initial_guess_strategy='one' + ) + ], + ) + def test_lbfgs_complex(self, linesearch): + # Test that optimization over complex variable matches equivalent real case + + tol = 1e-5 + mat = jnp.array( + [[1, - 2], + [3, 4], + [-4 + 2j, 5 - 3j], + [-2 - 2j, 6]] + ) + + def to_real(z): + return jnp.stack((z.real, z.imag)) + + def to_complex(x): + return x[..., 0, :] + 1j * x[..., 1, :] + + def f_complex(z): + return jnp.sum(jnp.abs(mat @ z) ** 1.5) + + def f_real(x): + return f_complex(to_complex(x)) + + z0 = jnp.array([1 - 1j, 0 + 1j]) + x0 = to_real(z0) + + opt_complex = alias.lbfgs(linesearch=linesearch) + opt_real = alias.lbfgs(linesearch=linesearch) + sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol) + sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol) + + chex.assert_trees_all_close( + sol_complex, to_complex(sol_real), atol=tol, rtol=tol + ) + + @parameterized.product( + linesearch=[ + _linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=20), + _linesearch.scale_by_zoom_linesearch( + max_linesearch_steps=20, initial_guess_strategy='one' + ) + ], + ) + def test_lbfgs_complex_rosenbrock(self, linesearch): + # Taken from previous jax tests + tol = 1e-5 + complex_dim = 5 + + fun_real = _get_problem('rosenbrock')['fun'] + init_real = jnp.zeros((2 * complex_dim,), dtype=complex) + expected_real = jnp.ones((2 * complex_dim,), dtype=complex) + + def fun(z): + x_real = jnp.concatenate([jnp.real(z), jnp.imag(z)]) + return fun_real(x_real) + + init = init_real[:complex_dim] + 1.j * init_real[complex_dim:] + expected = expected_real[:complex_dim] + 1.j * expected_real[complex_dim:] + + opt = alias.lbfgs(linesearch=linesearch) + got, _ = _run_opt(opt, fun, init, maxiter=500, tol=tol) + chex.assert_trees_all_close(got, expected, atol=tol, rtol=tol) if __name__ == '__main__': absltest.main() diff --git a/optax/_src/linesearch.py b/optax/_src/linesearch.py index 83f988e9a..a28a65af9 100644 --- a/optax/_src/linesearch.py +++ b/optax/_src/linesearch.py @@ -240,6 +240,8 @@ def scale_by_backtracking_linesearch( after the backtracking line-search doesn't necessarily need to satisfy the descent direction property (one could for example use momentum). + .. note:: The algorithm can support complex inputs. + .. seealso:: :func:`optax.value_and_grad_from_state` to make this method more efficient for non-stochastic objectives. @@ -319,7 +321,7 @@ def update_fn( # Slope of lr -> value_fn(params + lr * updates) at lr = 0 # Should be negative to ensure that there exists a lr (potentially # infinitesimal) that satisfies the criterion. - slope = otu.tree_vdot(updates, grad) + slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad))) def cond_fn( search_state: BacktrackingLineSearchState, @@ -698,7 +700,7 @@ def _value_and_slope_on_line( """ step = otu.tree_add_scalar_mul(params, stepsize, updates) value_step, grad_step = value_and_grad_fn(step, **fn_kwargs) - slope_step = otu.tree_vdot(grad_step, updates) + slope_step = otu.tree_real(otu.tree_vdot(otu.tree_conj(grad_step), updates)) return step, value_step, grad_step, slope_step def _compute_decrease_error( @@ -1205,7 +1207,7 @@ def init_fn( f"Unknown initial guess strategy: {initial_guess_strategy}" ) - slope = otu.tree_vdot(updates, grad) + slope = otu.tree_real(otu.tree_vdot(updates, grad)) return ZoomLinesearchState( count=jnp.asarray(0, dtype=jnp.int32), # @@ -1511,6 +1513,8 @@ def scale_by_zoom_linesearch( This can be sufficient in practice and avoids having the linesearch spend many iterations trying to satisfy the small curvature criterion. + .. note:: The algorithm can support complex inputs. + .. seealso:: :func:`optax.value_and_grad_from_state` to make this method more efficient for non-stochastic objectives. """ diff --git a/optax/_src/transform.py b/optax/_src/transform.py index dcfefdea8..a0675a608 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -1521,7 +1521,7 @@ def right_product(vec, idx): dwi, dui = jax.tree.map( lambda x: x[idx], (diff_params_memory, diff_updates_memory) ) - alpha = rhos[idx] * otu.tree_vdot(dwi, vec) + alpha = rhos[idx] * otu.tree_real(otu.tree_vdot(dwi, vec)) vec = otu.tree_add_scalar_mul(vec, -alpha, dui) return vec, alpha @@ -1536,7 +1536,7 @@ def left_product(vec, idx_alpha): dwi, dui = jax.tree.map( lambda x: x[idx], (diff_params_memory, diff_updates_memory) ) - beta = rhos[idx] * otu.tree_vdot(dui, vec) + beta = rhos[idx] * otu.tree_real(otu.tree_vdot(dui, vec)) vec = otu.tree_add_scalar_mul(vec, alpha - beta, dwi) return vec, beta @@ -1666,7 +1666,9 @@ def update_fn( # 1. Updates the memory buffers given fresh params and gradients/updates diff_params = otu.tree_sub(params, state.params) diff_updates = otu.tree_sub(updates, state.updates) - vdot_diff_params_updates = otu.tree_vdot(diff_updates, diff_params) + vdot_diff_params_updates = otu.tree_real( + otu.tree_vdot(diff_updates, diff_params) + ) weight = jnp.where( vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates ) @@ -1691,7 +1693,7 @@ def update_fn( # used to initialize the approximation of the inverse through the memory # buffer. if scale_init_precond: - numerator = otu.tree_vdot(diff_updates, diff_params) + numerator = otu.tree_real(otu.tree_vdot(diff_updates, diff_params)) denominator = otu.tree_l2_norm(diff_updates, squared=True) identity_scale = jnp.where( denominator > 0.0, numerator / denominator, 1.0 diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 130b9d909..ff28716f6 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -35,6 +35,8 @@ from optax.tree_utils._tree_math import tree_l2_norm from optax.tree_utils._tree_math import tree_linf_norm from optax.tree_utils._tree_math import tree_max +from optax.tree_utils._tree_math import tree_conj +from optax.tree_utils._tree_math import tree_real from optax.tree_utils._tree_math import tree_mul from optax.tree_utils._tree_math import tree_ones_like from optax.tree_utils._tree_math import tree_scalar_mul diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 37e0ea27c..307acca9a 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -183,6 +183,30 @@ def tree_max(tree: Any) -> chex.Numeric: return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf)) +def tree_conj(tree: Any) -> Any: + """Compute the conjugate of a pytree. + + Args: + tree: pytree. + + Returns: + a pytree with the same structure as ``tree``. + """ + return jax.tree.map(jnp.conj, tree) + + +def tree_real(tree: Any) -> Any: + """Compute the real part of a pytree. + + Args: + tree: pytree. + + Returns: + a pytree with the same structure as ``tree``. + """ + return jax.tree.map(jnp.real, tree) + + def _square(leaf): return jnp.square(leaf.real) + jnp.square(leaf.imag) diff --git a/optax/tree_utils/_tree_math_test.py b/optax/tree_utils/_tree_math_test.py index c1908e620..8eeff0fb4 100644 --- a/optax/tree_utils/_tree_math_test.py +++ b/optax/tree_utils/_tree_math_test.py @@ -152,6 +152,24 @@ def test_tree_max(self, key): got = tu.tree_max(tree) np.testing.assert_allclose(expected, got) + def test_tree_conj(self): + expected = jnp.conj(self.array_a) + got = tu.tree_conj(self.array_a) + np.testing.assert_array_almost_equal(expected, got) + + expected = (jnp.conj(self.tree_a[0]), jnp.conj(self.tree_a[1])) + got = tu.tree_conj(self.tree_a) + chex.assert_trees_all_close(expected, got) + + def test_tree_real(self): + expected = jnp.real(self.array_a) + got = tu.tree_real(self.array_a) + np.testing.assert_array_almost_equal(expected, got) + + expected = (jnp.real(self.tree_a[0]), jnp.real(self.tree_a[1])) + got = tu.tree_real(self.tree_a) + chex.assert_trees_all_close(expected, got) + def test_tree_l2_norm(self): expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real) got = tu.tree_l2_norm(self.array_a)