Skip to content

Commit

Permalink
doc fixes + ot bar coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
eloitanguy committed Jan 21, 2025
1 parent 3e8421e commit ccf608a
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 41 deletions.
46 changes: 27 additions & 19 deletions examples/barycenters/plot_barycenter_generic_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@
projection onto a circle k. This is an example of the fixed-point barycenter
solver introduced in [74] which generalises [20] and [43].
The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
:math:`x` with Pytorch.
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
Barycentres of Measures for Generic Transport
Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
Barycentres of Measures for Generic Transport Costs.
arXiv preprint 2501.04016 (2024)
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
Conference in Machine Learning
[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
Barycenters. InternationalConference in Machine Learning
[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in
Wasserstein space. Journal of Mathematical Analysis and Applications 441.2
(2016): 744-762.
"""

Expand All @@ -32,7 +33,8 @@

# sphinx_gallery_thumbnail_number = 1

# %% Generate data
# %%
# Generate data
import torch
from torch.optim import Adam
from ot.utils import dist
Expand All @@ -43,7 +45,7 @@

torch.manual_seed(42)

n = 100 # number of points of the of the barycentre
n = 200 # number of points of the of the barycentre
d = 2 # dimensions of the original measure
K = 4 # number of measures to barycentre
m = 50 # number of points of the measures
Expand Down Expand Up @@ -82,7 +84,8 @@ def proj_circle(X, origin, radius):
Y_list.append(P_list[k](X_temp))


# %% Define costs and ground barycenter function
# %%
# Define costs and ground barycenter function
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
# (n, n_k) matrix of costs
def c1(x, y):
Expand Down Expand Up @@ -140,25 +143,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
return x


# %% Compute the barycenter measure
fixed_point_its = 10
# %%
# Compute the barycenter measure
fixed_point_its = 3
X_init = torch.rand(n, d)
X_bar = free_support_barycenter_generic_costs(
X_init,
Y_list,
b_list,
X_init,
cost_list,
B,
numItermax=fixed_point_its,
stopThr=stop_threshold,
)

# %% Plot Barycenter (Iteration 10)
alpha = 0.5
# %%
# Plot Barycenter (Iteration 3)
alpha = 0.4
s = 80
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
for Y, label in zip(Y_list, labels):
plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label)
plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha)
plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
plt.scatter(
*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s
)
plt.axis("equal")
plt.xlim(-0.3, 1.3)
plt.ylim(-0.3, 1.3)
Expand Down
61 changes: 39 additions & 22 deletions ot/lp/_barycenter_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,19 +429,20 @@ class StoppingCriterionReached(Exception):


def free_support_barycenter_generic_costs(
X_init,
measure_locations,
measure_weights,
X_init,
cost_list,
B,
a=None,
numItermax=5,
stopThr=1e-5,
log=False,
):
r"""
Solves the OT barycenter problem for generic costs using the fixed point
algorithm, iterating the ground barycenter function B on transport plans
between the current barycentre and the measures.
between the current barycenter and the measures.
The problem finds an optimal barycenter support `X` of given size (n, d)
(enforced by the initialisation), minimising a sum of pairwise transport
Expand All @@ -452,12 +453,13 @@ def free_support_barycenter_generic_costs(
where:
- :math:`X` (n, d) is the barycentre support,
- :math:`a` (n) is the (fixed) barycentre weights,
- :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
- :math:`X` (n, d) is the barycenter support,
- :math:`a` (n) is the (fixed) barycenter weights,
- :math:`Y_k` (m_k, d_k) is the k-th measure support
(`measure_locations[k]`),
- :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
- :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
- :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`:
.. math::
\mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
Expand All @@ -471,9 +473,10 @@ def free_support_barycenter_generic_costs(
in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k,
c_k(X, Y_k))`.
The algorithm requires a given ground barycentre function `B` which computes
a solution of the following minimisation problem given :math:`(y_1, \cdots,
y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
The algorithm requires a given ground barycenter function `B` which computes
(broadcasted of `n`) solutions of the following minimisation problem given
:math:`(Y_1, \cdots, Y_K) \in
\mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`:
.. math::
B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
Expand All @@ -482,23 +485,32 @@ def free_support_barycenter_generic_costs(
:math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
\cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
this function, and for certain costs it can be computed explicitly of
through a numerical solver.
through a numerical solver. The input function B takes a list of K arrays of
shape (n, d_k) and returns an array of shape (n, d).
This function implements [74] Algorithm 2, which generalises [20] and [43]
to general costs and includes convergence guarantees, including for discrete measures.
to general costs and includes convergence guarantees, including for discrete
measures.
Parameters
----------
X_init : array-like
Array of shape (n, d) representing initial barycentre points.
measure_locations : list of array-like
List of K arrays of measure positions, each of shape (m_k, d_k).
measure_weights : list of array-like
List of K arrays of measure weights, each of shape (m_k).
X_init : array-like
Array of shape (n, d) representing initial barycenter points.
cost_list : list of callable
List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
List of K cost functions :math:`c_k: \mathbb{R}^{n\times
d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
m_k}`.
B : callable
Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre.
Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
of shape (n\times d_K), computing the ground barycenters (broadcasted
over n).
a : array-like, optional
Array of shape (n,) representing weights of the barycenter
measure.Defaults to uniform.
numItermax : int, optional
Maximum number of iterations (default is 5).
stopThr : float, optional
Expand All @@ -509,7 +521,7 @@ def free_support_barycenter_generic_costs(
Returns
-------
X : array-like
Array of shape (n, d) representing barycentre points.
Array of shape (n, d) representing barycenter points.
log_dict : list of array-like, optional
log containing the exit status, list of iterations and list of
displacements if log is True.
Expand All @@ -518,22 +530,27 @@ def free_support_barycenter_generic_costs(
References
----------
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
.. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
barycenters of Measures for Generic Transport Costs. arXiv preprint
2501.04016 (2024)
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein
barycenters." International Conference on Machine Learning. 2014.
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to
barycenters in Wasserstein space." Journal of Mathematical Analysis and
Applications 441.2 (2016): 744-762.
See Also
--------
ot.lp.free_support_barycenter : Free support solver for the case where
:math:`c_k(x,y) = \|x-y\|_2^2`.
ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`.
ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
"""
nx = get_backend(X_init, measure_locations[0])
K = len(measure_locations)
n = X_init.shape[0]
a = nx.ones(n) / n
if a is None:
a = nx.ones(n, type_as=X_init) / n
X_list = [X_init] if log else [] # store the iterations
X = X_init
dX_list = [] # store the displacement squared norms
Expand Down
95 changes: 95 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch, tf

# import ot.lp._barycenter_solvers # TODO: remove this import


def test_emd_dimension_and_mass_mismatch():
# test emd and emd2 for dimension mismatch
Expand Down Expand Up @@ -395,6 +397,99 @@ def test_generalised_free_support_barycenter_backends(nx):
np.testing.assert_allclose(Y, nx.to_numpy(Y2))


def test_free_support_barycenter_generic_costs():
measures_locations = [
np.array([-1.0]).reshape((1, 1)),
np.array([1.0]).reshape((1, 1)),
]
measures_weights = [np.array([1.0]), np.array([1.0])]

X_init = np.array([-12.0]).reshape((1, 1))

# obvious barycenter location between two Diracs
bar_locations = np.array([0.0]).reshape((1, 1))

def cost(x, y):
return ot.dist(x, y)

cost_list = [cost, cost]

def B(y):
out = 0
for yk in y:
out += yk / len(y)
return out

X = ot.lp.free_support_barycenter_generic_costs(
measures_locations, measures_weights, X_init, cost_list, B
)

np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)

# test with log and specific weights
X2, log = ot.lp.free_support_barycenter_generic_costs(
measures_locations,
measures_weights,
X_init,
cost_list,
B,
a=ot.unif(1),
log=True,
)

assert "X_list" in log
assert "exit_status" in log
assert "dX_list" in log

np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7)

# test with one iteration for Max Iterations Reached
X3, log2 = ot.lp.free_support_barycenter_generic_costs(
measures_locations,
measures_weights,
X_init,
cost_list,
B,
numItermax=1,
log=True,
)
assert log2["exit_status"] == "Max iterations reached"


def test_free_support_barycenter_generic_costs_backends(nx):
measures_locations = [
np.array([-1.0]).reshape((1, 1)),
np.array([1.0]).reshape((1, 1)),
]
measures_weights = [np.array([1.0]), np.array([1.0])]
X_init = np.array([-12.0]).reshape((1, 1))

def cost(x, y):
return ot.dist(x, y)

cost_list = [cost, cost]

def B(y):
out = 0
for yk in y:
out += yk / len(y)
return out

X = ot.lp.free_support_barycenter_generic_costs(
measures_locations, measures_weights, X_init, cost_list, B
)

measures_locations2 = nx.from_numpy(*measures_locations)
measures_weights2 = nx.from_numpy(*measures_weights)
X_init2 = nx.from_numpy(X_init)

X2 = ot.lp.free_support_barycenter_generic_costs(
measures_locations2, measures_weights2, X_init2, cost_list, B
)

np.testing.assert_allclose(X, nx.to_numpy(X2))


@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():
a1 = np.array([1.0, 0, 0])[:, None]
Expand Down

0 comments on commit ccf608a

Please sign in to comment.