Tags: ott-jax/ott
Feature/batched vmap (#588)
* Start batched vmap
* Initial `batched_vmap` impl
* Nicer formatting
* Fix getting shape
* Remove private API usage
* Fix new args
* Add a TODO
* Canonicalize axes
* Add `batched_vmap` to docs
* Remove batched transport functions
* Remove `_norm_{x,y}` from `CostFn`
* Implement `apply_lse_kernel`
* Implement `apply_kernel`
* Implement `apply_cost`
* Remove old functions
* Make function private
* Refactor `apply_cost` to have consistent shapes
* Use `_apply_cost_to_vec` in `PointCloud`
* Remove TODO
* Formatting
* Simplify `_apply_sqeucl_cost`
* Fix `RecursionError`
* Remove docstring of a private method
* Fix `apply_lse_kernel`
* Squeeze only 1 axis of the cost
* Add TODO
* Rename function, make a property
* Remove unused helper function
* Compute mean summary online
* Compute mean online
* Compute max cost matrix
* Update error message
* Remove TODO
* Flatten out axes
* Fix missing cross terms in the costs
* Fix geom tests
* Fix dtype
* Start implementing transport functions
* Implement online transport functions
* Fix solver tests
* Fix Bures test
* Don't use `pairwise` in tests
* Update notebook that uses `norm`
* Fix bug in `UnbalancedBures`
* Rename `pairwise -> __call__`
* Remove old shape code
* Always instantiate the cost for online
* Remove old TODO
* Extract `_apply_cost_to_vec_fast`
* Update max cost in LRCGeom
* Fix test, use more `multi_dot`
* Remove `batch_size` from `LRCGeometry`
* Add better warning/error
* Reorder properties
* Add docs to `batched_vmap`
* Start adding tests
* Reorder functions in test
* Fix axes, add a test
* Update test fn
* Move out assert
* Don't canonicalize `out_axes`
* Check max traces
* Test memory of batched vmap
* Install `typing_extensions`
* Remove `.` from description
* Add more `out_axes` tests
* Add `in_axes` test
* Fix negative axes
* Increase memory limit in the test
* Add `in_axes` pytree test
* Remove old warnings filters
* Update fixtures
* Update SqEucl cost
* Update docstrings
* Remove unused imports from the docs
* Revert test pre-commits
* Fix ICNN init notebook (was broken by #551)
* Improve error message

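The PR above adds `batched_vmap` and computes kernel/cost applications online instead of materializing full matrices. Below is a minimal sketch of the general idea only, not OTT's `batched_vmap` signature. The helper `chunked_vmap` is hypothetical and maps a single array over axis 0 (no `in_axes`/`out_axes` pytrees). It evaluates at most `batch_size` rows at a time, and it is applied to a row-wise log-sum-exp over a squared-Euclidean cost.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def chunked_vmap(fn, batch_size):
    """Hypothetical helper: like jax.vmap(fn) over axis 0 of a single array,
    but evaluates at most `batch_size` rows at a time to bound peak memory."""
    vmapped = jax.vmap(fn)

    def wrapper(x):
        n = x.shape[0]
        num_full = n // batch_size
        if num_full == 0:
            # Fewer rows than one batch: a plain vmap is already small.
            return vmapped(x)
        n_full = num_full * batch_size
        # Group full batches as [num_full, batch_size, ...] and run them
        # sequentially with lax.map, so only one batch is live at a time.
        head = x[:n_full].reshape((num_full, batch_size) + x.shape[1:])
        out = jax.lax.map(vmapped, head)
        out = out.reshape((n_full,) + out.shape[2:])
        if n_full < n:
            # Leftover rows that do not fill a whole batch.
            out = jnp.concatenate([out, vmapped(x[n_full:])], axis=0)
        return out

    return wrapper


def row_lse(x_i, y, g, eps):
    """log sum_j exp((g_j - ||x_i - y_j||^2) / eps) for a single row x_i;
    only one row of the [n, m] cost matrix exists at a time."""
    cost = jnp.sum((x_i[None, :] - y) ** 2, axis=-1)
    return logsumexp((g - cost) / eps)


# Usage: online log-sum-exp over a squared-Euclidean cost, 4 rows at a time.
kx, ky, kg = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (10, 3))
y = jax.random.normal(ky, (7, 3))
g = jax.random.normal(kg, (7,))
eps = 0.5

online = chunked_vmap(lambda x_i: row_lse(x_i, y, g, eps), batch_size=4)(x)

# Dense reference that builds the whole cost matrix at once.
dense_cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
dense = logsumexp((g[None, :] - dense_cost) / eps, axis=1)
```

The trade-off is memory for sequential work. `lax.map` runs the batches one after another, so peak memory scales with `batch_size * m` rather than `n * m`.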
Refactor regularized TI costs (#553)
* Update docs, add new costs, update tests, final polishing
* Decrease tolerance in `test_l2_moreau_envelope`
* Fix not passing `kwargs`
* Expose initial estimate in `h_transform`
* Test for `x_init`
* Change variable name in `h_transform`
* Better `h_transform` parameterization
* Add more references
* Add a missing reference
* Update docs of p-norm
* Improve the docs of `RegTICost.h_transform`
* Rename `is_squared` -> `is_factor`
* Better `h_legendre` docs
* Expose solver in `TICost.h_transform`

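Regularized translation-invariant (TI) costs rely on proximal operators and Moreau envelopes of a regularizer; the PR mentions a test named `test_l2_moreau_envelope`. This is a hedged, self-contained illustration of that concept only, not OTT's `RegTICost` API, and all function names are hypothetical. The sketch computes the Moreau envelope of the Euclidean norm f(z) = ||z||_2 by evaluating the objective at its proximal point, then compares it with the known Huber-like closed form.

```python
import jax.numpy as jnp


def prox_l2_norm(x, t):
    """Proximal operator of f(z) = ||z||_2 with step t (block soft-thresholding)."""
    norm = jnp.linalg.norm(x)
    scale = jnp.maximum(0.0, 1.0 - t / jnp.maximum(norm, 1e-12))
    return scale * x


def moreau_envelope_l2_norm(x, t):
    """env_t f(x) = min_z ||z|| + ||x - z||^2 / (2 t), evaluated at the prox."""
    z = prox_l2_norm(x, t)
    return jnp.linalg.norm(z) + jnp.sum((x - z) ** 2) / (2.0 * t)


def huber_closed_form(x, t):
    """Closed form of the same envelope: ||x|| - t/2 outside the ball, quadratic inside."""
    norm = jnp.linalg.norm(x)
    return jnp.where(norm >= t, norm - t / 2.0, norm ** 2 / (2.0 * t))


# ||x|| = 5 >= t = 1, so the envelope is 5 - 1/2 = 4.5.
val = moreau_envelope_l2_norm(jnp.array([3.0, 4.0]), 1.0)
```

Inside the ball of radius t the prox collapses to zero and the envelope is quadratic. Outside it, the envelope grows linearly, which is what makes it a smooth surrogate for the norm.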
Bug fix: avoid mixing up linear and quadratic in GENOT (#517)
* Bug fix: avoid mixing up linear and quadratic part by returning a dict in GENOT `prepare_data()`
* Fix `data_match_fn()` setup in GENOT tests
* `prepare_data()` in GENOT now returns a tuple instead of a dict; change order of args in `utils.match_quadratic()`
* Update docs
* Fix typo

---------

Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>

Add low-rank kernel geometry (#440)
* Add Gaussian kernel
* Make runnable with Sinkhorn
* Fix epsilon
* Use the same random vectors, fix epsilon
* Add arccos kernel
* Add citation
* Fix citation, start working on docs
* Add Lambert W
* Polish docstrings
* Update Lambert W docs
* Format references
* Remove useless `TYPE_CHECKING`
* Update docstring
* Add test skeletons
* Add rank test
* Update tests
* Start working on arccos cost
* Add arccos to docs
* Add `s=2` option for `Arccos`
* Fix LRK tests
* Fix tests and tree API
* Rename argument
* Add generic order using `jax.grad`
* Test also the implementation of `J(theta)`
* Polish flaky test
* Address comments

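A low-rank kernel geometry replaces the n x m kernel matrix with a product of feature matrices, K ≈ Φ_x Φ_yᵀ, and the PR adds an arccos kernel among others. The sketch below shows the general technique using the standard random-feature construction for the arc-cosine kernel (Cho & Saul, 2009). It is not OTT's implementation: the function names, the integer `order` argument and the feature count are assumptions. The features are nonnegative, so the factored kernel stays nonnegative, as Sinkhorn requires.

```python
import jax
import jax.numpy as jnp


def arccos_features(x, w, order):
    """Nonnegative random features for the arc-cosine kernel of integer order.

    x: [n, d] points; w: [r, d] standard Gaussian directions.
    Returns phi: [n, r] such that phi(x) @ phi(y).T approximates k_order(x, y).
    """
    proj = x @ w.T
    if order == 0:
        act = (proj > 0).astype(x.dtype)  # Heaviside step
    else:
        act = jnp.maximum(proj, 0.0) ** order  # rectified power
    return jnp.sqrt(2.0 / w.shape[0]) * act


def arccos_kernel_order1(x, y):
    """Exact order-1 kernel: ||x|| ||y|| (sin t + (pi - t) cos t) / pi."""
    nx = jnp.linalg.norm(x, axis=-1)[:, None]
    ny = jnp.linalg.norm(y, axis=-1)[None, :]
    cos_t = jnp.clip((x @ y.T) / (nx * ny), -1.0, 1.0)
    theta = jnp.arccos(cos_t)
    return nx * ny * (jnp.sin(theta) + (jnp.pi - theta) * cos_t) / jnp.pi


kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (5, 3))
y = jax.random.normal(ky, (4, 3))
# Unit-norm points keep the Monte Carlo error small and uniform.
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
w = jax.random.normal(kw, (200_000, 3))

phi_x = arccos_features(x, w, order=1)  # [5, r]
phi_y = arccos_features(y, w, order=1)  # [4, r]
approx = phi_x @ phi_y.T  # factored kernel; K @ v = phi_x @ (phi_y.T @ v)
exact = arccos_kernel_order1(x, y)
```

In practice the number of features r is much smaller than n and m. That is what makes kernel-vector products cost O((n + m) r); the large r here only tightens the Monte Carlo check.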
Feature/ulrgw (#410)
* Remove low-rank from GromovWasserstein solver
* First skeleton loop
* Add LRGW implementation
* Add ULFGW
* Revert change
* Add a TODO
* Fix `grad_g` in the fused case
* Update docs
* Remove duplicate citation
* Fix cost for the fused case
* Fix bugs in TI
* Remove unused import
* Change how array extraction works in LR init
* Disallow LR in the old GW solver
* Disallow LR in old GW class
* Remove `is_entropic` property
* Use `jnp.linalg.norm`
* Simplify initializers in GW
* Simplify initializer creation for low-rank
* Remove temporary name
* Fix norms
* Fix linkcheck
* Remove old initializers test
* Fix more initializer tests
* Remove `LRQuadraticInitializer`, `reg_ot_cost -> reg_gw_cost`
* `host_callback` -> `io_callback`
* Fix more initializers tests
* Fix more tests
* Remove initializer mention from the docs
* Remove mention of LR initializer
* Start incorporating GWLoss
* Simplify reg GW cost computation
* Finish `primal_cost`
* Don't calculate unbal. grads in balanced case
* Fix `primal_cost` in balanced case
* Update GW LR notebook
* Convert quad problem to LR if possible
* Convert quad problem to LR if possible
* Regenerate GWLR Sinkhorn
* Regenerate `LRSinkhorn`
* [ci skip] Fix linter
* Fix convergence metric
* Undo TODO
* Fix factor
* Regenerate notebooks
* Add tests

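Low-rank GW stays linear in the number of points when both the cost matrices and the coupling are low-rank, as in "Convert quad problem to LR if possible". This is a hedged sketch of why, not OTT's implementation. It assumes the usual low-rank coupling parameterization P = Q diag(1/g) Rᵀ and low-rank costs C_x = A_x B_xᵀ, C_y = A_y B_yᵀ, and the function name is hypothetical. It then forms the product C_x P C_y, a term that appears in GW gradients for symmetric costs, using only thin matrices.

```python
import jax
import jax.numpy as jnp


def lr_cx_p_cy(a_x, b_x, q, r, g, a_y, b_y):
    """Return (left, right) with left @ right.T == C_x @ P @ C_y, where
    C_x = a_x @ b_x.T ([n, n]), P = q @ diag(1 / g) @ r.T ([n, m]),
    C_y = a_y @ b_y.T ([m, m]). No n x n, n x m or m x m matrix is formed."""
    # [k_x, s] core: b_x.T @ q, with column j divided by g_j (the diag(1/g)).
    core_x = (b_x.T @ q) / g[None, :]
    # [s, k_y] core: r.T @ a_y.
    core_y = r.T @ a_y
    left = jnp.linalg.multi_dot([a_x, core_x, core_y])  # [n, k_y]
    return left, b_y


n, m, k_x, k_y, s = 6, 5, 2, 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 6)
a_x = jax.random.normal(keys[0], (n, k_x))
b_x = jax.random.normal(keys[1], (n, k_x))
q = jax.random.uniform(keys[2], (n, s))
r = jax.random.uniform(keys[3], (m, s))
g = jnp.array([1.0, 2.0, 3.0, 4.0])
a_y = jax.random.normal(keys[4], (m, k_y))
b_y = jax.random.normal(keys[5], (m, k_y))

left, right = lr_cx_p_cy(a_x, b_x, q, r, g, a_y, b_y)

# Dense reference, only feasible for small n and m.
dense = (a_x @ b_x.T) @ (q @ jnp.diag(1.0 / g) @ r.T) @ (a_y @ b_y.T)
```

Keeping the result in factored form lets later steps multiply it against vectors in O((n + m) k_y) time.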