Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add non-negative least squares solver. #1155

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

carlosgmartin
Copy link
Contributor

Fixes #1152.

@carlosgmartin
Copy link
Contributor Author

@rdyro How does this look?

@carlosgmartin carlosgmartin force-pushed the nnls branch 2 times, most recently from a503ed5 to fe9fc6b Compare February 4, 2025 19:02
@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

@carlosgmartin This looks great, I left one comment!

@carlosgmartin
Copy link
Contributor Author

@rdyro Where is this comment?

optax/_src/linear_algebra.py Outdated Show resolved Hide resolved
@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

@rdyro Where is this comment?

Oops, should be up now!

@rdyro
Copy link
Collaborator

rdyro commented Feb 5, 2025

# We use lstsq with a pre-computed AtA to reduce computation time.
s = jnp.linalg.lstsq(AtA * p[:, None] * p[None, :], Atb * p)[0]

If speed is a consideration, perhaps we should use the jnp.linlag.lsqt(A * p, b) directly, letting XLA optimize as necessary?

I'd be curious if we could introduce a once-factorized cholesky version of this algorithm where we repeatedly apply cho_solve to a masked L or U factor of A^T A?

@carlosgmartin @fabianp ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add nnls (non-negative least squares)
2 participants