Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix complex support for L-BFGS #1142

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

gautierronan
Copy link

Closes #1141.

Not 100% sure that this doesn't break other things or fully works, but at least the MWE below seems to work fine.

import optax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.abs(x**2))

solver = optax.lbfgs()
params = jnp.array([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j])
print("Objective function: ", f(params))

opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(5):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        jnp.conj(grad), opt_state, params, value=value, grad=jnp.conj(grad), value_fn=f
    )
    params = optax.apply_updates(params, updates)
    print("Objective function: ", f(params))

Notice the solve.update call which requires a jnp.conj(grad) twice. I believe this is correct and aligned with other optax solvers, but not sure either.

@vroulet
Copy link
Collaborator

vroulet commented Nov 22, 2024

Hey @gautierronan,
Thanks for the PR! We'll need a test. Take look at this PR: google/jaxopt#468 that added support for complex parameters for the lbfgs of jaxopt. I think you'll find all that you'll need in that PR.
Thanks again!

@gautierronan
Copy link
Author

gautierronan commented Nov 23, 2024

@vroulet Should be good for review. Note that, in the test, I have commented one linesearch option because the test was not passing, but I don't think it's related to complex support.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

L-BFGS optimizer with complex inputs
2 participants