Commit 86e8f2c

introduce sinkdiv simplified API, expand glossary with sinkhorn divergence and LR sinkhorn (#596)

* better sinkdiv

* add terms in point cloud demo, fixes

* fix test

* adding LR term in solver

* renormalized
marcocuturi authored Nov 13, 2024
1 parent 828acca commit 86e8f2c
Showing 8 changed files with 247,748 additions and 374,073 deletions.
39 changes: 38 additions & 1 deletion docs/glossary.rst
@@ -201,6 +201,24 @@ Glossary
where :math:`\pi` is a :term:`coupling` density with first marginal
:math:`\mu` and second marginal :math:`\nu`.

low-rank optimal transport
Variant of the :term:`Kantorovich problem` whereby the search for an
optimal :term:`coupling` matrix :math:`P` is restricted to a subset of
low-rank matrices. Effectively, this is parameterized by replacing
:math:`P` with a low-rank factorization
.. math::
P = Q \text{diag}(1/g) R^T,
where :math:`Q,R` are :term:`coupling` matrices of sizes ``[n,r]`` and
``[m,r]`` and :math:`g` is a vector of size ``[r,]``. To be effective,
one assumes implicitly that the rank :math:`r\ll n,m`. To solve this in
practice, the :term:`Kantorovich problem` is modified to only seek
solutions with this factorization, and updates on :math:`Q,R,g` are
carried out alternately. These updates are themselves obtained by solving
:term:`entropy-regularized optimal transport` problems.
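
For intuition, here is a minimal sketch (not part of this diff) of solving a
low-rank problem through OTT's linear solver, assuming, as in the updated
tests below, that the ``rank`` argument of :func:`~ott.solvers.linear.solve`
dispatches to the low-rank solver:

.. code-block:: python

  import jax
  from ott.geometry import pointcloud
  from ott.solvers import linear

  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.uniform(rng_x, (128, 3))  # n = 128 source points
  y = jax.random.uniform(rng_y, (64, 3))   # m = 64 target points
  geom = pointcloud.PointCloud(x, y, epsilon=5e-2)
  out = linear.solve(geom, rank=6)  # coupling restricted to rank r = 6
  # Low-rank factors, of shapes [128, 6], [64, 6] and [6,] respectively.
  q, r, g = out.q, out.r, out.g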


matching
A bijective pairing between two families of points of the same size
:math:`N`, parameterized using a permutation of size :math:`N`.
@@ -252,6 +270,24 @@ Glossary
:math:`g` (resp. :math:`u` and :math:`v`) that alternately cancel
their respective gradients, one at a time.

Sinkhorn divergence
Proxy for the :term:`Wasserstein distance` between two samples. Rather
than comparing two families of samples with the output of the
:term:`Kantorovich problem`, whose numerical resolution requires solving
a linear program, one uses instead the objective of
:term:`entropy-regularized optimal transport` or that of
:term:`low-rank optimal transport`, properly renormalized. This
renormalization is done by considering:

.. math::
\text{SD}(\mu, \nu):= \Delta(\mu, \nu)
- \tfrac12 \left(\Delta(\mu, \mu) + \Delta(\nu, \nu)\right)
where :math:`\Delta` is the objective of either
:term:`entropy-regularized optimal transport` or
:term:`low-rank optimal transport`.
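
Spelled out on two point clouds with entropic regularization, the
renormalization above amounts to three calls to a linear OT solver; a hedged
sketch (not part of this diff, mirroring the decomposition checked in the
tests below):

.. code-block:: python

  import jax
  from ott.geometry import pointcloud
  from ott.solvers import linear

  x = jax.random.uniform(jax.random.PRNGKey(0), (64, 2))
  y = jax.random.uniform(jax.random.PRNGKey(1), (32, 2))

  def delta(u, v, eps=5e-2):
    # Objective of entropy-regularized OT between two uniform point clouds.
    return linear.solve(pointcloud.PointCloud(u, v, epsilon=eps)).reg_ot_cost

  sd = delta(x, y) - 0.5 * (delta(x, x) + delta(y, y))  # nonnegative scalar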

transport map
A function :math:`T` that associates to each point :math:`x` in the
support of a source distribution :math:`\mu` another point :math:`T(x)`
@@ -306,4 +342,5 @@ Glossary
distance is truly a distance (in the sense that it satisfies all 3
`metric axioms <https://en.wikipedia.org/wiki/Metric_space#Definition>`_
), as long as the :term:`ground cost` is itself a distance to a power
- :math:`p\leq 1`, and the :math:`1/p`th power of the objective is taken.
+ :math:`p\leq 1`, and the :math:`p`-th root of the objective of the
+ :term:`Kantorovich problem` is used.
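
In symbols, the convention just stated reads as follows (a reconstruction
consistent with this entry, with :math:`d` the distance used as
:term:`ground cost` and :math:`\Pi(\mu, \nu)` the set of admissible
:term:`coupling` densities):

.. math::
  W(\mu, \nu) := \left(\inf_{\pi \in \Pi(\mu, \nu)}
  \int d(x, y)^p \,\mathrm{d}\pi(x, y)\right)^{1/p},
  \quad p \leq 1.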
1 change: 1 addition & 0 deletions docs/spelling/technical.txt
@@ -133,6 +133,7 @@ regularizer
regularizers
reimplementation
renormalize
renormalized
reparameterization
reproducibility
rescale
566,906 changes: 218,936 additions & 347,970 deletions docs/tutorials/geometry/000_point_cloud.ipynb

Large diffs are not rendered by default.

54,706 changes: 28,650 additions & 26,056 deletions docs/tutorials/linear/200_sinkhorn_divergence_gradient_flow.ipynb

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions src/ott/solvers/linear/sinkhorn_lr.py
@@ -261,13 +261,19 @@ def _inv_g(self) -> jnp.ndarray:
class LRSinkhorn(sinkhorn.Sinkhorn):
r"""Low-Rank Sinkhorn solver for linear reg-OT problems.
+ The algorithm tries to minimize the :term:`low-rank optimal transport`
+ problem, a constrained formulation of the :term:`Kantorovich problem` where
+ the :term:`coupling` variable is constrained to have low rank.
+ That problem is non-convex, and therefore any algorithm that tries to
+ solve it requires special attention to initialization and control of
+ convergence. Convergence is evaluated on successive evaluations of the
+ objective, whereas initializers are instances of the
+ :class:`~ott.initializers.linear.initializers_lr.LRInitializer` class.
+ The algorithm is described in :cite:`scetbon:21` and the implementation
+ contained here is adapted from `LOT <https://github.com/meyerscetbon/LOT>`_.
- The algorithm minimizes a non-convex problem. It therefore requires special
- care to initialization and convergence. Convergence is evaluated on successive
- evaluations of the objective.
Args:
rank: Rank constraint on the coupling to minimize the linear OT problem
gamma: The (inverse of) gradient step size used by mirror descent.
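
A minimal usage sketch for this solver (not part of the diff; it assumes the
current OTT API for assembling a linear problem):

import jax
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

x = jax.random.uniform(jax.random.PRNGKey(0), (100, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (50, 2))
prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y, epsilon=1e-1))
solver = sinkhorn_lr.LRSinkhorn(rank=4)  # gamma left at its default value
out = solver(prob)  # out.q, out.r, out.g hold the low-rank factors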
12 changes: 8 additions & 4 deletions src/ott/tools/plot.py
@@ -175,6 +175,10 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]:
*y.T, s=sy, edgecolors="k", marker="X", label="y"
)
self.ax.legend(fontsize=15)

+ if self._title is not None:
+   self.ax.set_title(self._title)

if not self._show_lines:
return []

@@ -191,8 +195,6 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]:
alpha=alpha
)
self._lines.append(line)
- if self._title is not None:
-   self.ax.set_title(self._title)
return [self._points_x, self._points_y] + self._lines

def update(self,
@@ -202,6 +204,10 @@ def update(self,
x, y, _, _ = self._scatter(ot)
self._points_x.set_offsets(x)
self._points_y.set_offsets(y)

+ if title is not None:
+   self.ax.set_title(title)

if not self._show_lines:
return []

@@ -232,8 +238,6 @@ def update(self,
self._lines.append(line)

self._lines = self._lines[:num_to_plot] # Maybe remove some
- if title is not None:
-   self.ax.set_title(title)
return [self._points_x, self._points_y] + self._lines

def animate(
83 changes: 72 additions & 11 deletions src/ott/tools/sinkhorn_divergence.py
@@ -111,33 +111,89 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*children, **aux_data)


def sinkdiv(
x: jnp.ndarray,
y: jnp.ndarray,
*,
cost_fn: Optional[costs.CostFn] = None,
epsilon: Optional[float] = None,
**kwargs: Any,
) -> Tuple[jnp.ndarray, SinkhornDivergenceOutput]:
"""Wrapper to get the :term:`Sinkhorn divergence` between two point clouds.
Convenience wrapper around
:meth:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` provided to
compute the :term:`Sinkhorn divergence` between two point clouds compared with
any ground cost :class:`~ott.geometry.costs.CostFn`. See other relevant
arguments in :meth:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence`.
Args:
x: Array of input points, of shape `[num_x, feature]`.
y: Array of target points, of shape `[num_y, feature]`.
cost_fn: cost function of interest.
epsilon: entropic regularization.
kwargs: keywords arguments passed on to the generic
:meth:`~ott.tools.sinkhorn_divergence.sinkhorn_divergence` method. Of
notable interest are ``a`` and ``b`` weight vectors, ``static_b`` and
``offset_static_b`` which can be used to bypass the computations of the
transport problem between points stored in ``y`` (possibly with weights
``b``) and themselves, and ``solve_kwargs`` to parameterize the linear
OT solver.
Returns:
The Sinkhorn divergence value, and output object detailing computations.
"""
return sinkhorn_divergence(
pointcloud.PointCloud,
x=x,
y=y,
cost_fn=cost_fn,
epsilon=epsilon,
**kwargs
)
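
As a usage illustration (a hedged sketch, not part of the diff), the wrapper
reduces the three-geometry interface below to a single call on raw arrays:

import jax
from ott.geometry import costs
from ott.tools import sinkhorn_divergence

x = jax.random.normal(jax.random.PRNGKey(0), (30, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (40, 2)) + 1.0
div, out = sinkhorn_divergence.sinkdiv(
    x, y, cost_fn=costs.SqEuclidean(), epsilon=1e-1
)
# div is a nonnegative scalar; out details the underlying computations.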


def sinkhorn_divergence(
geom: Type[geometry.Geometry],
*args: Any,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
solve_kwargs: Mapping[str, Any] = MappingProxyType({}),
static_b: bool = False,
offset_static_b: Optional[float] = None,
share_epsilon: bool = True,
symmetric_sinkhorn: bool = True,
**kwargs: Any,
) -> Tuple[jnp.ndarray, SinkhornDivergenceOutput]:
"""Compute Sinkhorn divergence defined by a geometry, weights, parameters.
r"""Compute :term:`Sinkhorn divergence` between two measures.
The :term:`Sinkhorn divergence` is computed between two measures :math:`\mu`
and :math:`\nu` by specifying three :class:`~ott.geometry.Geometry` objects,
each describing pairwise costs within points in :math:`\mu,\nu`,
:math:`\mu,\mu`, and :math:`\nu,\nu`.
This implementation proposes the most general interface, to generate those
three geometries by specifying first the type of
:class:`~ott.geometry.Geometry` that is used to compare
them, followed by the arguments used to generate these three
:class:`~ott.geometry.Geometry` instances through its corresponding
:meth:`~ott.geometry.geometry.Geometry.prepare_divergences` method.
Args:
geom: Type of the geometry.
args: Positional arguments to
:meth:`~ott.geometry.geometry.Geometry.prepare_divergences` that are
specific to each geometry.
- a: the weight of each input point. The sum of all elements of `a` must
-   match that of `b` to converge.
- b: the weight of each target point. The sum of all elements of `b` must
-   match that of `a` to converge.
+ a: the weight of each input point.
+ b: the weight of each target point.
solve_kwargs: keyword arguments for :func:`~ott.solvers.linear.solve`,
  which is called either twice when ``static_b == True``, or three times
  when ``static_b == False``.
- static_b: if ``True``, divergence of measure `b` against itself is **not**
-   computed.
+ static_b: if :obj:`True`, the divergence of the second measure
+   (with weights ``b``) to itself is **not** recomputed.
offset_static_b: when ``static_b`` is :obj:`True`, use this value to offset
  the computation. Useful when the value of the divergence of the second
  measure to itself is precomputed and not expected to change.
share_epsilon: if True, enforces that the same epsilon regularizer is shared
for all 2 or 3 terms of the Sinkhorn divergence. In that case, the epsilon
will be by default that used when comparing x to y (contained in the first
@@ -171,6 +227,7 @@ def sinkhorn_divergence(
a=a,
b=b,
symmetric_sinkhorn=symmetric_sinkhorn,
offset_yy=offset_static_b,
**solve_kwargs
)
return out.divergence, out
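
For reference, a hedged sketch of the generic interface documented above,
mirroring the updated tests below (geometry type first, then the arrays
handed over to ``prepare_divergences``):

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

x = jax.random.uniform(jax.random.PRNGKey(0), (25, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (35, 2))
a = jnp.ones(25) / 25  # weights on x
b = jnp.ones(35) / 35  # weights on y
div, out = sinkhorn_divergence.sinkhorn_divergence(
    pointcloud.PointCloud, x, y, a=a, b=b, epsilon=5e-2
)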
@@ -183,6 +240,7 @@ def _sinkhorn_divergence(
a: jnp.ndarray,
b: jnp.ndarray,
symmetric_sinkhorn: bool,
offset_yy: Optional[float],
**kwargs: Any,
) -> SinkhornDivergenceOutput:
"""Compute the (unbalanced) Sinkhorn divergence for the wrapper function.
@@ -205,6 +263,8 @@ def _sinkhorn_divergence(
all elements of ``b`` must match that of ``a`` to converge.
symmetric_sinkhorn: Use Sinkhorn updates in Eq. 25 of :cite:`feydy:19` for
symmetric terms comparing x/x and y/y.
offset_yy: if available, the regularized OT cost precomputed on
  ``geometry_yy``, when transporting the weight vector ``b`` onto itself.
kwargs: Keyword arguments to :func:`~ott.solvers.linear.solve`.
Returns:
@@ -239,7 +299,7 @@ def _sinkhorn_divergence(
out_xx = linear.solve(geometry_xx, a=a, b=a, **kwargs_symmetric)
if geometry_yy is None:
# Create dummy output, corresponds to scenario where static_b is True.
- out_yy = _empty_output(is_low_rank)
+ out_yy = _empty_output(is_low_rank, offset_yy)
else:
out_yy = linear.solve(geometry_yy, a=b, b=b, **kwargs_symmetric)
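
The ``offset_yy`` path enables the following pattern (a hedged sketch, not
part of the diff): precompute the ``y``/``y`` term once, then skip it on
repeated evaluations against a fixed target measure:

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers import linear
from ott.tools import sinkhorn_divergence

x = jax.random.uniform(jax.random.PRNGKey(0), (20, 2))
y = jax.random.uniform(jax.random.PRNGKey(1), (30, 2))
b = jnp.ones(30) / 30
eps = 5e-2
# Solve the y <-> y problem once; its value is constant across calls.
yy = linear.solve(pointcloud.PointCloud(y, epsilon=eps), a=b, b=b).reg_ot_cost
# Only the x <-> y and x <-> x problems are solved below.
div, _ = sinkhorn_divergence.sinkdiv(
    x, y, epsilon=eps, b=b, static_b=True, offset_static_b=yy
)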

@@ -406,7 +466,8 @@ def eval_fn(


def _empty_output(
- is_low_rank: bool
+ is_low_rank: bool,
+ offset_yy: Optional[float] = None
) -> Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]:
if is_low_rank:
return sinkhorn_lr.LRSinkhornOutput(
@@ -419,13 +480,13 @@ def _empty_output(
converged=True,
costs=jnp.array([-jnp.inf]),
errors=jnp.array([-jnp.inf]),
- reg_ot_cost=0.0,
+ reg_ot_cost=0.0 if offset_yy is None else offset_yy,
)

return sinkhorn.SinkhornOutput(
potentials=(None, None),
errors=jnp.array([-jnp.inf]),
- reg_ot_cost=0.0,
+ reg_ot_cost=0.0 if offset_yy is None else offset_yy,
threshold=0.0,
inner_iterations=0,
)
60 changes: 33 additions & 27 deletions tests/tools/sinkhorn_divergence_test.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Dict, Optional, Tuple

import pytest
@@ -45,30 +46,18 @@ def setUp(self, rng: jax.Array):
)
def test_euclidean_point_cloud(self, cost_fn: costs.CostFn, rank: int):

- def sinkdiv(
-     x: jnp.ndarray,
-     y: jnp.ndarray,
-     cost_fn: costs.CostFn,
-     epsilon: float,
- ) -> sinkhorn_divergence.SinkhornDivergenceOutput:
-   return sinkhorn_divergence.sinkhorn_divergence(
-       pointcloud.PointCloud,
-       x,
-       y,
-       cost_fn=cost_fn,
-       a=self._a,
-       b=self._b,
-       epsilon=epsilon,
-       solve_kwargs={"rank": rank},
-   )

is_low_rank = rank > 0
rngs = jax.random.split(self.rng, 2)
x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim))
y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim))

epsilon = 5e-2
- div, out = jax.jit(sinkdiv)(x, y, cost_fn, epsilon)
+ sd = functools.partial(
+     sinkhorn_divergence.sinkdiv, solve_kwargs={"rank": rank}
+ )
+ div, out = jax.jit(sd)(
+     x, y, cost_fn=cost_fn, epsilon=epsilon, a=self._a, b=self._b
+ )

assert div >= 0.0
assert out.is_low_rank == is_low_rank
@@ -84,27 +73,44 @@ def sinkdiv(
assert iters_xx < iters_xy
assert iters_yy < iters_xy

- grad = jax.jit(jax.grad(sinkdiv, has_aux=True, argnums=0))
- np.testing.assert_array_equal(
-     jnp.isfinite(grad(x, y, cost_fn, epsilon)[0]), True
- )

# Check computation of divergence matches that done separately.
geometry_xy = pointcloud.PointCloud(x, y, epsilon=epsilon, cost_fn=cost_fn)
geometry_xx = pointcloud.PointCloud(x, epsilon=epsilon, cost_fn=cost_fn)
geometry_yy = pointcloud.PointCloud(y, epsilon=epsilon, cost_fn=cost_fn)

- div2 = linear.solve(
+ div2_xy = linear.solve(
      geometry_xy, a=self._a, b=self._b, rank=rank
  ).reg_ot_cost
- div2 -= 0.5 * linear.solve(
+ div2_xx = linear.solve(
      geometry_xx, a=self._a, b=self._a, rank=rank
  ).reg_ot_cost
- div2 -= 0.5 * linear.solve(
+ div2_yy = linear.solve(
      geometry_yy, a=self._b, b=self._b, rank=rank
  ).reg_ot_cost

- np.testing.assert_allclose(out.divergence, div2, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(
+     div, div2_xy - .5 * (div2_xx + div2_yy), rtol=1e-5, atol=1e-5
+ )

# Check passing offset when using static_b works
div_offset, _ = sd(
x,
y,
cost_fn=cost_fn,
epsilon=epsilon,
a=self._a,
b=self._b,
static_b=True,
offset_static_b=div2_yy
)

np.testing.assert_allclose(div, div_offset, rtol=1e-5, atol=1e-5)

# Check gradient is finite
grad = jax.jit(jax.grad(sd, has_aux=True, argnums=0))
np.testing.assert_array_equal(
jnp.isfinite(grad(x, y, cost_fn=cost_fn, epsilon=epsilon)[0]), True
)

# Test divergence of x to itself close to 0.
div, out = sinkhorn_divergence.sinkhorn_divergence(
