
L-BFGS optimizer with complex inputs #1141

Open
gautierronan opened this issue Nov 21, 2024 · 2 comments · May be fixed by #1142

Comments

@gautierronan

gautierronan commented Nov 21, 2024

I am adapting the docstring example of optax.lbfgs() to use complex-valued parameters instead of real-valued ones; the objective function is still real-valued. Calling solver.update then raises TypeError: true_fun and false_fun output must have identical types, preceded by a ComplexWarning: Casting complex values to real discards the imaginary part. I am guessing that the cast behind this warning is the source of the TypeError.

MWE:

import optax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.abs(x**2))

solver = optax.lbfgs()
params = jnp.array([1.0 + 1j, 2.0 + 2j, 3.0 + 3j])
print("Objective function: ", f(params))

opt_state = solver.init(params)
# value_and_grad_from_state reuses the value and gradient already stored in
# the optimizer state, to avoid recomputing them after the linesearch
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(5):
    value, grad = value_and_grad(params, state=opt_state)
    # note: the gradient passed via the keyword is conjugated here; the
    # docstring example passes grad unchanged
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=jnp.conj(grad), value_fn=f
    )
    params = optax.apply_updates(params, updates)
    print("Objective function: ", f(params))

Stack trace:

Objective function:  28.0
/python3.11/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=complex64 to dtype=float32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
  warnings.warn(
/python3.11/site-packages/jax/_src/ops/scatter.py:134: ComplexWarning: Casting complex values to real discards the imaginary part
  return lax_internal._convert_element_type(out, dtype, weak_type)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 16
     14 for _ in range(5):
     15     value, grad = value_and_grad(params, state=opt_state)
---> 16     updates, opt_state = solver.update(
     17         grad, opt_state, params, value=value, grad=jnp.conj(grad), value_fn=f
     18     )
     19     params = optax.apply_updates(params, updates)
     20     print("Objective function: ", f(params))

File ~/miniconda3/lib/python3.11/site-packages/optax/transforms/_combining.py:73, in chain.<locals>.update_fn(updates, state, params, **extra_args)
     71 new_state = []
     72 for s, fn in zip(state, update_fns):
---> 73   updates, new_s = fn(updates, s, params, **extra_args)
     74   new_state.append(new_s)
     75 return updates, tuple(new_state)

File ~/miniconda3/lib/python3.11/site-packages/optax/_src/linesearch.py:1493, in scale_by_zoom_linesearch.<locals>.update_fn(updates, state, params, value, grad, value_fn, **extra_args)
   1484 stepsize_guess = state.learning_rate
   1485 init_state = init_ls(
   1486     updates,
   1487     params,
   (...)
   1490     stepsize_guess=stepsize_guess,
   1491 )
-> 1493 final_state = jax.lax.while_loop(
   1494     cond_step_ls,
   1495     functools.partial(
   1496         step_ls, value_and_grad_fn=value_and_grad_fn, fn_kwargs=fn_kwargs
   1497     ),
   1498     init_state,
   1499 )
   1500 learning_rate = final_state.stepsize
   1501 scaled_updates = otu.tree_scalar_mul(learning_rate, updates)

    [... skipping hidden 9 frame]

File ~/miniconda3/lib/python3.11/site-packages/optax/_src/linesearch.py:1166, in zoom_linesearch.<locals>.step_fn(state, value_and_grad_fn, fn_kwargs)
   1159 def step_fn(
   1160     state: ZoomLinesearchState,
   1161     *,
   1162     value_and_grad_fn: Callable[..., tuple[chex.Numeric, base.Updates]],
   1163     fn_kwargs: dict[str, Any],
   1164 ) -> ZoomLinesearchState:
   1165   """Makes a step of the linesearch."""
-> 1166   new_state = jax.lax.cond(
   1167       state.interval_found,
   1168       functools.partial(
   1169           _zoom_into_interval,
   1170           value_and_grad_fn=value_and_grad_fn,
   1171           fn_kwargs=fn_kwargs,
   1172       ),
   1173       functools.partial(
   1174           _search_interval,
   1175           value_and_grad_fn=value_and_grad_fn,
   1176           fn_kwargs=fn_kwargs,
   1177       ),
   1178       state,
   1179   )
   1180   new_state = jax.lax.cond(
   1181       new_state.failed,
   1182       _try_safe_step,
   1183       lambda x: x,
   1184       new_state
   1185   )
   1186   return new_state

    [... skipping hidden 3 frame]

File ~/miniconda3/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py:214, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
    211 if not all(map(core.typematch, avals1, avals2)):
    212   diff = tree_map(_show_diff, tree_unflatten(tree1, avals1),
    213                   tree_unflatten(tree2, avals2))
--> 214   raise TypeError(f"{what} must have identical types, got\n{diff}.")

TypeError: true_fun and false_fun output must have identical types, got
ZoomLinesearchState(count='ShapedArray(int32[])', params='ShapedArray(complex64[3])', updates='ShapedArray(complex64[3])', stepsize_guess='ShapedArray(float32[], weak_type=True)', stepsize='DIFFERENT ShapedArray(complex64[]) vs. ShapedArray(float32[], weak_type=True)', value='ShapedArray(float32[])', grad='ShapedArray(complex64[3])', slope='ShapedArray(complex64[])', value_init='ShapedArray(float32[])', slope_init='ShapedArray(complex64[])', decrease_error='ShapedArray(complex64[])', curvature_error='ShapedArray(float32[])', error='ShapedArray(complex64[])', interval_found='ShapedArray(bool[])', done='ShapedArray(bool[])', failed='ShapedArray(bool[])', low='DIFFERENT ShapedArray(complex64[]) vs. ShapedArray(float32[], weak_type=True)', value_low='ShapedArray(float32[])', slope_low='ShapedArray(complex64[])', high='DIFFERENT ShapedArray(complex64[]) vs. ShapedArray(float32[], weak_type=True)', value_high='ShapedArray(float32[])', slope_high='ShapedArray(complex64[])', cubic_ref='ShapedArray(float32[], weak_type=True)', value_cubic_ref='ShapedArray(float32[])', safe_stepsize='DIFFERENT ShapedArray(complex64[]) vs. ShapedArray(float32[], weak_type=True)', safe_value='ShapedArray(float32[])', safe_grad='ShapedArray(complex64[3])').
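
For reference, in the state diff above slope is already complex64 and stepsize differs across the two cond branches (complex64 vs. float32). My guess at the mechanism (not verified against the optax source) is that the linesearch computes the slope as a vdot of the update direction with the gradient, which keeps a complex dtype even when the value itself is real:

import jax.numpy as jnp
import optax.tree_utils as otu

grad = jnp.array([1.0 + 1.0j, 2.0 + 2.0j])
updates = -grad  # descent direction
# the vdot of two complex trees is complex64 even though its imaginary part
# is zero here, so quantities derived from it (slope, stepsize) turn complex
print(otu.tree_vdot(updates, grad).dtype)  # complex64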
@vroulet
Collaborator

vroulet commented Nov 21, 2024

Hello @gautierronan,

The LBFGS optimizer does not yet support complex parameters. It should not be too hard to add this feature (see how jaxopt handles it: https://jaxopt.github.io/stable/_modules/jaxopt/_src/lbfgs.html#LBFGS).
Would you be willing to open such a PR?
If not, you may rephrase the issue title to "Add support for complex inputs to LBFGS" and I can take care of it later.
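
In the meantime, a possible workaround is to optimize over the real and imaginary parts as real arrays and rebuild the complex vector inside the objective. An untested sketch (the pytree layout with "re"/"im" keys is just illustrative):

import jax.numpy as jnp
import optax

def f(z):
    # real-valued objective of a complex vector
    return jnp.sum(jnp.abs(z**2))

def f_real(p):
    # rebuild the complex vector from two real arrays
    return f(p["re"] + 1j * p["im"])

solver = optax.lbfgs()
params = {"re": jnp.array([1.0, 2.0, 3.0]), "im": jnp.array([1.0, 2.0, 3.0])}
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(f_real)

for _ in range(5):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=f_real
    )
    params = optax.apply_updates(params, updates)
    print("Objective function: ", f_real(params))

Since everything the optimizer sees is real, the existing linesearch works unchanged.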

gautierronan linked a pull request Nov 22, 2024 that will close this issue
@gautierronan
Author

Attempt at a PR: #1142.
