Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix complex support for L-BFGS #1142

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
45 changes: 45 additions & 0 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
from optax._src import alias
from optax._src import base
from optax._src import linesearch as _linesearch
from optax._src import numerics
from optax._src import transform
from optax._src import update
Expand Down Expand Up @@ -333,6 +334,7 @@ def stopping_criterion(carry):
def step(carry):
params, state, count, _ = carry
value, grad = value_and_grad_fun(params)
grad = otu.tree_conj(grad)
updates, state = opt.update(
grad, state, params, value=value, grad=grad, value_fn=fun
)
Expand Down Expand Up @@ -865,6 +867,49 @@ def fun(x):
sol, _ = _run_opt(opt, fun, init_params=jnp.ones(n), tol=tol)
chex.assert_trees_all_close(sol, jnp.zeros(n), atol=tol, rtol=tol)

@parameterized.product(
linesearch=[
# _linesearch.zoom_linesearch(max_linesearch_steps=20),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove that commented line. There is no need to test for any possible variant of the linesearch. It's already great that you included the backtracking linesearch in addition to the default linesearch.

_linesearch.scale_by_backtracking_linesearch(max_backtracking_steps=20),
_linesearch.scale_by_zoom_linesearch(
max_linesearch_steps=20, initial_guess_strategy='one'
)
],
)
def test_complex(self, linesearch):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name it test_lbfgs_complex rather than test_complex.

# Test that optimization over complex variable matches equivalent real case

tol = 1e-5
mat = jnp.array(
[[1, - 2],
[3, 4],
[-4 + 2j, 5 - 3j],
[-2 - 2j, 6]]
)

def to_real(z):
return jnp.stack((z.real, z.imag))

def to_complex(x):
return x[..., 0, :] + 1j * x[..., 1, :]

def f_complex(z):
return jnp.sum(jnp.abs(mat @ z) ** 1.5)

def f_real(x):
return f_complex(to_complex(x))

z0 = jnp.array([1 - 1j, 0 + 1j])
x0 = to_real(z0)

opt_complex = alias.lbfgs(linesearch=linesearch)
opt_real = alias.lbfgs(linesearch=linesearch)
sol_complex, _ = _run_opt(opt_complex, f_complex, init_params=z0, tol=tol)
sol_real, _ = _run_opt(opt_real, f_real, init_params=x0, tol=tol)

chex.assert_trees_all_close(
sol_complex, to_complex(sol_real), atol=tol, rtol=tol
)

if __name__ == '__main__':
absltest.main()
6 changes: 3 additions & 3 deletions optax/_src/linesearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def update_fn(
# Slope of lr -> value_fn(params + lr * updates) at lr = 0
# Should be negative to ensure that there exists a lr (potentially
# infinitesimal) that satisfies the criterion.
slope = otu.tree_vdot(updates, grad)
slope = otu.tree_real(otu.tree_vdot(updates, otu.tree_conj(grad)))

def cond_fn(
search_state: BacktrackingLineSearchState,
Expand Down Expand Up @@ -698,7 +698,7 @@ def _value_and_slope_on_line(
"""
step = otu.tree_add_scalar_mul(params, stepsize, updates)
value_step, grad_step = value_and_grad_fn(step, **fn_kwargs)
slope_step = otu.tree_vdot(grad_step, updates)
slope_step = otu.tree_real(otu.tree_vdot(otu.tree_conj(grad_step), updates))
return step, value_step, grad_step, slope_step

def _compute_decrease_error(
Expand Down Expand Up @@ -1205,7 +1205,7 @@ def init_fn(
f"Unknown initial guess strategy: {initial_guess_strategy}"
)

slope = otu.tree_vdot(updates, grad)
slope = otu.tree_real(otu.tree_vdot(updates, grad))
return ZoomLinesearchState(
count=jnp.asarray(0, dtype=jnp.int32),
#
Expand Down
10 changes: 6 additions & 4 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ def right_product(vec, idx):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
alpha = rhos[idx] * otu.tree_vdot(dwi, vec)
alpha = rhos[idx] * otu.tree_real(otu.tree_vdot(dwi, vec))
vec = otu.tree_add_scalar_mul(vec, -alpha, dui)
return vec, alpha

Expand All @@ -1536,7 +1536,7 @@ def left_product(vec, idx_alpha):
dwi, dui = jax.tree.map(
lambda x: x[idx], (diff_params_memory, diff_updates_memory)
)
beta = rhos[idx] * otu.tree_vdot(dui, vec)
beta = rhos[idx] * otu.tree_real(otu.tree_vdot(dui, vec))
vec = otu.tree_add_scalar_mul(vec, alpha - beta, dwi)
return vec, beta

Expand Down Expand Up @@ -1666,7 +1666,9 @@ def update_fn(
# 1. Updates the memory buffers given fresh params and gradients/updates
diff_params = otu.tree_sub(params, state.params)
diff_updates = otu.tree_sub(updates, state.updates)
vdot_diff_params_updates = otu.tree_vdot(diff_updates, diff_params)
vdot_diff_params_updates = otu.tree_real(
otu.tree_vdot(diff_updates, diff_params)
)
weight = jnp.where(
vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates
)
Expand All @@ -1691,7 +1693,7 @@ def update_fn(
# used to initialize the approximation of the inverse through the memory
# buffer.
if scale_init_precond:
numerator = otu.tree_vdot(diff_updates, diff_params)
numerator = otu.tree_real(otu.tree_vdot(diff_updates, diff_params))
denominator = otu.tree_l2_norm(diff_updates, squared=True)
identity_scale = jnp.where(
denominator > 0.0, numerator / denominator, 1.0
Expand Down
2 changes: 2 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from optax.tree_utils._tree_math import tree_l2_norm
from optax.tree_utils._tree_math import tree_linf_norm
from optax.tree_utils._tree_math import tree_max
from optax.tree_utils._tree_math import tree_conj
from optax.tree_utils._tree_math import tree_real
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_scalar_mul
Expand Down
24 changes: 24 additions & 0 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,30 @@ def tree_max(tree: Any) -> chex.Numeric:
return jax.tree.reduce(jnp.maximum, maxes, initializer=jnp.array(-jnp.inf))


def tree_conj(tree: Any) -> Any:
"""Compute the conjugate of a pytree.

Args:
tree: pytree.

Returns:
a pytree with the same structure as ``tree``.
"""
return jax.tree.map(jnp.conj, tree)


def tree_real(tree: Any) -> Any:
"""Compute the real part of a pytree.

Args:
tree: pytree.

Returns:
a pytree with the same structure as ``tree``.
"""
return jax.tree.map(jnp.real, tree)


def _square(leaf):
return jnp.square(leaf.real) + jnp.square(leaf.imag)

Expand Down
18 changes: 18 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,24 @@ def test_tree_max(self, key):
got = tu.tree_max(tree)
np.testing.assert_allclose(expected, got)

def test_tree_conj(self):
expected = jnp.conj(self.array_a)
got = tu.tree_conj(self.array_a)
np.testing.assert_array_almost_equal(expected, got)

expected = (jnp.conj(self.tree_a[0]), jnp.conj(self.tree_a[1]))
got = tu.tree_conj(self.tree_a)
chex.assert_trees_all_close(expected, got)

def test_tree_real(self):
expected = jnp.real(self.array_a)
got = tu.tree_real(self.array_a)
np.testing.assert_array_almost_equal(expected, got)

expected = (jnp.real(self.tree_a[0]), jnp.real(self.tree_a[1]))
got = tu.tree_real(self.tree_a)
chex.assert_trees_all_close(expected, got)

def test_tree_l2_norm(self):
expected = jnp.sqrt(jnp.vdot(self.array_a, self.array_a).real)
got = tu.tree_l2_norm(self.array_a)
Expand Down
Loading