Skip to content

Commit

Permalink
Iteration
Browse files Browse the repository at this point in the history
Signed-off-by: An Thai Le <[email protected]>
  • Loading branch information
anindex committed Jul 19, 2023
1 parent c386342 commit dee7039
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 120 deletions.
52 changes: 52 additions & 0 deletions ssax/objectives/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import abc
import functools
import math
from typing import Any, Callable, Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jaxopt
import numpy as np

from ott.math import fixed_point_loop, matrix_square_root
from ott.math import utils as mu



@jax.tree_util.register_pytree_node_class
class ObjectiveFn(abc.ABC):
"""Base class for all costs.
"""

@abc.abstractmethod
def evaluate(self, x: jnp.ndarray) -> jnp.ndarray:
"""Compute cost between :math:`x` and :math:`y`.
Args:
x: Array.
Returns:
The cost array.
"""

def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost between :math:`x` and :math:`y`.
Args:
x: Array.
y: Array.
Returns:
The cost, optionally including the :attr:`norms <norm>` of
:math:`x`/:math:`y`.
"""
cost = self.evaluate(x)
return cost

def tree_flatten(self): # noqa: D102
return (), None

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
del aux_data
return cls(*children)
144 changes: 24 additions & 120 deletions ssax/ss/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import costs, geometry, low_rank
from ott.math import utils as mu
from ott.geometry import geometry

__all__ = ["GenericCost"]

Expand All @@ -18,164 +16,70 @@ class GenericCost(geometry.Geometry):
def __init__(
self,
x: jnp.ndarray,
cost_fn: costs.CostFn,
is_online: Optional[bool] = True,
scale_cost: Union[bool, int, float,
Literal["mean", "max_norm", "max_bound", "max_cost",
"median"]] = 1.0,
objective_fn: Any,
**kwargs: Any
):
super().__init__(**kwargs)
self.x = x # polytope vertices [b, n, d]
self._x = x # polytope vertices [batch, num_vertices, d]
self.objective_fn = objective_fn

self.cost_fn = cost_fn
self._axis_norm = 0 if callable(self.cost_fn.norm) else None

self.is_online = is_online
self._scale_cost = "mean" if scale_cost is True else scale_cost
@property.setter
def x(self, new_x: jnp.ndarray): # noqa: D102
assert new_x.ndim == 3
self._x = new_x
self._compute_cost_matrix()

@property
def _norm_x(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self.cost_fn.norm(self.x)
return 0.
def x(self) -> jnp.ndarray: # noqa: D102
return self._x

@property
def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102
if self.is_online:
return None
cost_matrix = self._compute_cost_matrix()
return cost_matrix * self.inv_scale_cost
if self._cost_matrix is None:
self._compute_cost_matrix()
return self._cost_matrix * self.inv_scale_cost

@property
def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102
if self.is_online:
return None
return jnp.exp(-self.cost_matrix / self.epsilon)

@property
def shape(self) -> Tuple[int, int, int]:
# in the process of flattening/unflattening in vmap, `__init__`
# can be called with dummy objects
# we optionally access `shape` in order to get the batch size
if self.x is None:
if self._x is None:
return 0
return self.x.shape
return self._x.shape

@property
def is_symmetric(self) -> bool: # noqa: D102
return self.x.shape[0] == self.x.shape[1]

@property
def inv_scale_cost(self) -> float: # noqa: D102
if isinstance(self._scale_cost,
(int, float)) or utils.is_jax_array(self._scale_cost):
return 1.0 / self._scale_cost
self = self._masked_geom()
if self._scale_cost == "max_cost":
if self.is_online:
return 1.0 / self._compute_summary_online(self._scale_cost)
return 1.0 / jnp.max(self._compute_cost_matrix())
if self._scale_cost == "mean":
if self.is_online:
return 1.0 / self._compute_summary_online(self._scale_cost)
if self.shape[0] > 0:
geom = self._masked_geom(mask_value=jnp.nan)._compute_cost_matrix()
return 1.0 / jnp.nanmean(geom)
return 1.0
if self._scale_cost == "median":
if not self.is_online:
geom = self._masked_geom(mask_value=jnp.nan)
return 1.0 / jnp.nanmedian(geom._compute_cost_matrix())
raise NotImplementedError(
"Using the median as scaling factor for "
"the cost matrix with the online mode is not implemented."
)
if self._scale_cost == "max_norm":
if self.cost_fn.norm is not None:
return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max())
return 1.0
if self._scale_cost == "max_bound":
if self.is_squared_euclidean:
x_argmax = jnp.argmax(self._norm_x)
y_argmax = jnp.argmax(self._norm_y)
max_bound = (
self._norm_x[x_argmax] + self._norm_y[y_argmax] +
2 * jnp.sqrt(self._norm_x[x_argmax] * self._norm_y[y_argmax])
)
return 1.0 / max_bound
raise NotImplementedError(
"Using max_bound as scaling factor for "
"the cost matrix when the cost is not squared euclidean "
"is not implemented."
)
raise ValueError(f"Scaling {self._scale_cost} not implemented.")
return self._x.shape[0] == self._x.shape[1]

def _compute_cost_matrix(self) -> jnp.ndarray:
cost_matrix = ...
return cost_matrix

def apply_cost(
self,
arr: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
is_linear: bool = False,
) -> jnp.ndarray:

if self.is_squared_euclidean and (fn is None or is_linear):
return self.vec_apply_cost(arr, axis, fn=fn)

return self._apply_cost(arr, axis, fn=fn)

def _apply_cost(
self, arr: jnp.ndarray, axis: int = 0, fn=None
) -> jnp.ndarray:
"""See :meth:`apply_cost`."""
if not self.is_online:
return super().apply_cost(arr, axis, fn)

app = jax.vmap(
_apply_cost_xy,
in_axes=[None, 0, None, self._axis_norm, None, None, None, None]
)
if arr.ndim == 1:
arr = arr.reshape(-1, 1)

if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, arr, self.cost_fn,
self.inv_scale_cost, fn
)
return app(
self.y, self.x, self._norm_y, self._norm_x, arr, self.cost_fn,
self.inv_scale_cost, fn
)
self._cost_matrix = self.objective_fn(self._x)

def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray:
"""Compute barycenter of points in self.x using weights."""
return self.cost_fn.barycenter(self.x, weights)[0]
"""Compute barycenter of points in self._x using weights."""
return jnp.average(self._x, weights=weights, axis=1) # [batch, d]

def tree_flatten(self): # noqa: D102
return (
self.x,
self.y,
self._x,
self._src_mask,
self._tgt_mask,
self._epsilon_init,
self.cost_fn,
self.objective_fn,
), {
"batch_size": self._batch_size,
"scale_cost": self._scale_cost
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
x, y, src_mask, tgt_mask, epsilon, cost_fn = children
x, src_mask, tgt_mask, epsilon, objective_fn = children
return cls(
x,
y,
cost_fn=cost_fn,
objective_fn=objective_fn,
src_mask=src_mask,
tgt_mask=tgt_mask,
epsilon=epsilon,
Expand Down

0 comments on commit dee7039

Please sign in to comment.