Commit c3e1611

Untuned working version on Ackley
Signed-off-by: anindex <[email protected]>
anindex committed Jul 22, 2023
1 parent 757969f commit c3e1611
Showing 13 changed files with 417 additions and 151 deletions.
8 changes: 4 additions & 4 deletions scripts/run.py
@@ -263,12 +263,12 @@ def mixture_log_prob_fn(sample):
 def run_mcmc(
     rng_key,
     model: Callable,
-    reference_data: Union[np.ndarray, jnp.ndarray],
+    reference_data: Union[np.array, jnp.array],
     n_train_data: int,
     n_eval_data: int,
     n_warmup: int = 1000,
     n_chains: int = 1,
-) -> jnp.ndarray:
+) -> jnp.array:

     n_proposal_samples_per_chain = math.ceil((n_train_data + n_eval_data) / n_chains)
     kernel = NUTS(model, dense_mass=False, init_strategy=init_to_median(num_samples=10_000), target_accept_prob=0.7)
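The rest of run_mcmc is clipped here. For orientation only: a NUTS kernel like the one above is normally driven through NumPyro's MCMC wrapper. This is a hedged sketch of that standard usage, not the repo's actual continuation; the argument wiring is an assumption:

from numpyro.infer import MCMC

# Standard NumPyro driver loop (sketch; not the repo's code):
mcmc = MCMC(
    kernel,
    num_warmup=n_warmup,
    num_samples=n_proposal_samples_per_chain,
    num_chains=n_chains,
)
mcmc.run(rng_key)             # model conditioning omitted; depends on `model`
samples = mcmc.get_samples()  # dict of posterior draws, one entry per sample site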
@@ -316,7 +316,7 @@ def run_density_estimation(
     train_dataloader = build_dataloader(train_dataset, batch_size=batch_size)
     eval_dataloader = build_dataloader(eval_dataset, batch_size=batch_size)

-    def training_step(state: TrainState, data: jnp.ndarray):
+    def training_step(state: TrainState, data: jnp.array):
         def loss_fn(model_params: hk.Params, batch):
             loss = -jnp.mean(state.apply_fn(model_params, batch))
             return loss
@@ -326,7 +326,7 @@ def loss_fn(model_params: hk.Params, batch):

         return state, None

-    def eval_fn(state: TrainState, data: jnp.ndarray):
+    def eval_fn(state: TrainState, data: jnp.array):
         loss = -jnp.mean(state.apply_fn(state.params, data))
         return state, loss

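The (state, None) and (state, loss) signatures above follow jax.lax.scan's carry/output convention, so both functions can be driven by scan over a stacked array of batches. A minimal self-contained sketch under that assumption — the stand-in functions and shapes here are hypothetical, not objects from this repo:

import jax
import jax.numpy as jnp

# Stand-in with the same (carry, output) signature as training_step above.
def training_step(state, data):
    return state + jnp.mean(data), None  # toy update; state is a scalar here

# Stand-in with the same signature as eval_fn above.
def eval_fn(state, data):
    return state, -jnp.mean(data)

batches = jnp.ones((10, 32, 2))  # [num_batches, batch_size, dim]
state, _ = jax.lax.scan(training_step, jnp.zeros(()), batches)  # all training steps
_, losses = jax.lax.scan(eval_fn, state, batches)               # one loss per batch
print(state, losses.shape)  # 10.0 (10,)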
100 changes: 100 additions & 0 deletions scripts/test.py
@@ -1,5 +1,105 @@
import jax
import jax.numpy as jnp
from jax import jit

import numpy as np
import matplotlib.pyplot as plt
import time

from ott.solvers.linear.sinkhorn import Sinkhorn
from ott.geometry.epsilon_scheduler import Epsilon

from ssax.ss.initializer import SSGaussianInitializer, SSUniformInitializer
from ssax.ss.costs import GenericCost
from ssax.ss.solver import SinkhornStep
from ssax.ss.epsilon_scheduler import LinearEpsilon

from ssax.objectives.visualization import plot_objective
from ssax.objectives.synthetic import Ackley, Beale


if __name__ == '__main__':
    # plt.figure()
    # plot_objective(Ackley())
    # rng = jax.random.PRNGKey(0)
    # num_points = 100
    # # initializer = SSGaussianInitializer(
    # #     jnp.array([0.0, 0.0]),
    # #     jnp.eye(2),
    # #     rng=rng
    # # )
    # initializer = SSUniformInitializer(
    #     jnp.array([[-5.0, 5.0], [-5.0, 5.0]]),
    #     rng=rng
    # )
    # X = initializer(num_points)
    # plt.scatter(X[:, 0], X[:, 1], c='r')
    # plt.show()

    rng = jax.random.PRNGKey(0)
    num_points = 100000

    # sinkhorn solver
    solver = Sinkhorn(
        threshold=1e-3,
        inner_iterations=1,
        min_iterations=1,
        max_iterations=100,
        initializer='default'
    )

    # cost function
    objective_fn = Ackley()

    # initialize points
    initializer = SSUniformInitializer(
        jnp.array([[-5.0, 5.0], [-5.0, 5.0]]),
        rng=rng
    )
    X = initializer(num_points)

    # Sinkhorn Step solver
    epsilon = LinearEpsilon(
        target=0.3,
        init=1.,
        decay=0.01,
    )
    sinkhorn_step = SinkhornStep(
        objective_fn=objective_fn,
        linear_ot_solver=solver,
        epsilon=epsilon,
        polytope_type='orthoplex',
        step_radius=0.15,
        probe_radius=0.2,
        num_probe=5,
        min_iterations=5,
        max_iterations=100,
        threshold=1e-3,
        rng=rng
    )

    # run Sinkhorn Step
    state = sinkhorn_step.init_state(X)
    plt.figure()
    ax = plt.gca()
    for i in range(100):
        plt.clf()
        tic = time.time()
        state = sinkhorn_step.step(state, i)
        toc = time.time()
        print(f'Iteration {i}, time: {toc - tic}')
        plot_objective(objective_fn, ax=ax)
        X = state.X
        plt.scatter(X[:, 0], X[:, 1], c='r', s=3)
        ax.set_aspect('equal')
        plt.draw()
        plt.pause(1e-4)

    # tic = time.time()
    # state = sinkhorn_step.iterations(X)
    # toc = time.time()
    # print(f'Time: {toc - tic}')
    # plt.figure()
    # plot_objective(objective_fn)
    # X = state.X
    # plt.scatter(X[:, 0], X[:, 1], c='r', s=3)
    # plt.show()
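If SinkhornStep.step is a pure function of (state, i) — which this untuned version may not yet guarantee — the per-iteration overhead above could be cut by compiling it once. A hedged sketch reusing the objects defined in the script:

step_fn = jax.jit(sinkhorn_step.step)  # assumes step is jit-compatible

state = sinkhorn_step.init_state(X)
for i in range(100):
    state = step_fn(state, i)  # compiled once, reused every iteration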
23 changes: 13 additions & 10 deletions ssax/objectives/base.py
@@ -12,8 +12,8 @@ class ObjectiveFn(abc.ABC):

     dim: int
     _optimal_value: float
-    _optimizers: Optional[jnp.ndarray] = None
-    _bounds: jnp.ndarray
+    _optimizers: Optional[jnp.array] = None
+    _bounds: jnp.array

     def __init__(self,
                  noise_std: Optional[float] = None,
@@ -41,24 +41,27 @@ def optimal_value(self) -> float:
         return -self._optimal_value if self.negate else self._optimal_value

     @abc.abstractmethod
-    def evaluate(self, X: jnp.ndarray) -> jnp.ndarray:
+    def evaluate(self, X: jnp.array) -> jnp.array:
         """Compute cost
         Args:
-            X: Array.
+            X: array.
         Returns:
             The cost array.
         """

-    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
+    def __call__(self, X: jnp.array) -> jnp.array:
         cost = self.evaluate(X)
         return cost

-    def tree_flatten(self):  # noqa: D102
-        return (), None
+    def tree_flatten(self):
+        return (), {
+            "noise_std": self.noise_std,
+            "negate": self.negate,
+            "bounds": self._bounds,
+        }

     @classmethod
-    def tree_unflatten(cls, aux_data, children):  # noqa: D102
-        del aux_data
-        return cls(*children)
+    def tree_unflatten(cls, aux_data, children):
+        return cls(*children, **aux_data)
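The new tree_flatten/tree_unflatten pair moves the objective's hyperparameters into aux_data, so JAX transformations can take an objective apart and rebuild an equivalent instance. A minimal self-contained sketch of the same pattern — the Quadratic class here is hypothetical, not part of the repo:

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Quadratic:
    """Hypothetical objective illustrating the aux_data pattern above."""
    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def evaluate(self, X):
        return self.scale * jnp.sum(X ** 2, axis=-1)

    def tree_flatten(self):
        # no traced children; all configuration is static aux_data
        return (), {"scale": self.scale}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

# round-trip: flatten to (leaves, treedef), then rebuild an equivalent object
leaves, treedef = jax.tree_util.tree_flatten(Quadratic(2.0))
obj = jax.tree_util.tree_unflatten(treedef, leaves)
assert obj.scale == 2.0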
19 changes: 15 additions & 4 deletions ssax/objectives/synthetic.py
@@ -47,30 +47,41 @@ def __init__(
         self.c = 2 * jnp.pi

     @jit
-    def evaluate(self, X: jnp.ndarray) -> jnp.ndarray:
+    def evaluate(self, X: jnp.array) -> jnp.array:
         a, b, c = self.a, self.b, self.c
         part1 = -a * jnp.exp(-b / jnp.sqrt(self.dim) * jnp.linalg.norm(X, axis=-1))
         part2 = -(jnp.exp(jnp.mean(jnp.cos(c * X), axis=-1)))
         return part1 + part2 + a + jnp.e

+    def tree_flatten(self):
+        return (), {
+            "dim": self.dim,
+            "noise_std": self.noise_std,
+            "negate": self.negate,
+            "bounds": self._bounds,
+        }
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(*children, **aux_data)


 @jax.tree_util.register_pytree_node_class
 class Beale(ObjectiveFn):

-    dim = 2
     _optimal_value = 0.0
     _bounds = jnp.array([(-4.5, 4.5), (-4.5, 4.5)])
     _optimizers = jnp.array([(3.0, 0.5)])

     def __init__(self,
                  noise_std: float | None = None,
                  negate: bool = False,
                  bounds: List[Tuple[float, float]] | None = None,
                  **kwargs: Any):
+        self.dim = 2
         super().__init__(noise_std, negate, bounds, **kwargs)

     @jit
-    def evaluate(self, X: jnp.ndarray) -> jnp.ndarray:
+    def evaluate(self, X: jnp.array) -> jnp.array:
         x1, x2 = X[..., 0], X[..., 1]
         part1 = (1.5 - x1 + x1 * x2) ** 2
         part2 = (2.25 - x1 + x1 * x2**2) ** 2
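For reference, Ackley's evaluate above implements the standard d-dimensional Ackley function; -b / sqrt(d) * ||X|| is an algebraic rewrite of -b * sqrt(mean(X**2)). In LaTeX (a and b are set in the __init__ clipped above the hunk; the conventional defaults are a = 20 and b = 0.2, and the code sets c = 2π):

f(x) = -a \exp\!\Big(-b \sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2}\Big)
       - \exp\!\Big(\tfrac{1}{d}\sum_{i=1}^{d} \cos(c\, x_i)\Big) + a + e

The global minimum is f(0) = 0 at the origin, which is where the uniformly initialized points in scripts/test.py should concentrate.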
2 changes: 1 addition & 1 deletion ssax/objectives/visualization.py
@@ -12,4 +12,4 @@ def plot_objective(objective_fn, ax=None):
     X, Y = np.meshgrid(x, y)
     Z = objective_fn(np.stack([X, Y], axis=-1))
     plt.contourf(X, Y, Z, 100)
-    plt.colorbar()
+    # plt.colorbar()
56 changes: 27 additions & 29 deletions ssax/ss/costs.py
@@ -4,6 +4,8 @@
 import jax
 import jax.numpy as jnp

+from ssax.ss.polytopes import POLYTOPE_NUM_VERTICES_MAP
+
 from ott.geometry import geometry

 __all__ = ["GenericCost"]
@@ -16,69 +18,65 @@ class GenericCost(geometry.Geometry):
     def __init__(
         self,
         objective_fn: Any,
+        X: jnp.array,
         **kwargs: Any
     ):
         super().__init__(**kwargs)
-        self.objective_fn = objective_fn
-        self._X = None
-
-    @property.setter
-    def X(self, new_x: jnp.ndarray):  # noqa: D102
-        assert new_x.ndim == 4  # polytope vertices [batch, num_vertices, num_probe, d]
-        self._X = new_x
-        self._compute_cost_matrix()
-
-    @property
-    def X(self) -> jnp.ndarray:  # noqa: D102
-        return self._X
+        self.objective_fn = objective_fn
+
+        assert X.ndim == 4  # polytope vertices [batch, num_vertices, num_probe, d]
+        self.X = X

     @property
-    def cost_matrix(self) -> Optional[jnp.ndarray]:  # noqa: D102
+    def cost_matrix(self) -> Optional[jnp.array]:
         if self._cost_matrix is None:
             self._compute_cost_matrix()
         return self._cost_matrix * self.inv_scale_cost

     @property
-    def kernel_matrix(self) -> Optional[jnp.ndarray]:  # noqa: D102
+    def kernel_matrix(self) -> Optional[jnp.array]:
         return jnp.exp(-self.cost_matrix / self.epsilon)

     @property
-    def shape(self) -> Tuple[int, int, int]:
+    def shape(self) -> Tuple[int, int]:
         # in the process of flattening/unflattening in vmap, `__init__`
         # can be called with dummy objects
         # we optionally access `shape` in order to get the batch size
-        if self._X is None:
-            return 0
-        return self._X.shape
+        return self.X.shape[:2]

     @property
-    def is_symmetric(self) -> bool:  # noqa: D102
-        return self._X.shape[0] == self._X.shape[1]
+    def is_symmetric(self) -> bool:
+        return self.X.shape[0] == self.X.shape[1]

-    def _compute_cost_matrix(self) -> jnp.ndarray:
-        costs = self.objective_fn(self._X)
+    def _compute_cost_matrix(self) -> jnp.array:
+        costs = self.objective_fn(self.X)
         self._cost_matrix = costs.mean(axis=-1)  # [batch, num_vertices]

-    def tree_flatten(self):  # noqa: D102
+    def evaluate(self, X: jnp.array) -> jnp.array:
+        """Evaluate cost function at given points."""
+        return self.objective_fn(X)
+
+    def tree_flatten(self):
         return (
-            self._X,
+            self.objective_fn,
+            self.X,
             self._src_mask,
             self._tgt_mask,
             self._epsilon_init,
-            self.objective_fn,
         ), {
-            "scale_cost": self._scale_cost
+            "scale_cost": self._scale_cost,
+            "relative_epsilon": self._relative_epsilon
         }

     @classmethod
-    def tree_unflatten(cls, aux_data, children):  # noqa: D102
-        X, src_mask, tgt_mask, epsilon, objective_fn = children
+    def tree_unflatten(cls, aux_data, children):
+        objective_fn, X, src_mask, tgt_mask, epsilon = children
         return cls(
-            X,
             objective_fn=objective_fn,
+            X=X,
             src_mask=src_mask,
             tgt_mask=tgt_mask,
             epsilon=epsilon,
             **aux_data
         )
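The shape convention above is the core of the probe mechanism: X stacks num_probe probe points for every candidate polytope vertex of every particle, and _compute_cost_matrix averages the objective over the probe axis. A toy shape-check, with a hypothetical quadratic objective standing in for the repo's cost functions:

import jax.numpy as jnp

batch, num_vertices, num_probe, d = 8, 4, 5, 2     # hypothetical sizes
X = jnp.ones((batch, num_vertices, num_probe, d))  # polytope probe points

def objective_fn(X):
    # any vectorized objective mapping [..., d] -> [...]
    return jnp.sum(X ** 2, axis=-1)

costs = objective_fn(X)                # [batch, num_vertices, num_probe]
cost_matrix = costs.mean(axis=-1)      # [batch, num_vertices], as in _compute_cost_matrix
print(costs.shape, cost_matrix.shape)  # (8, 4, 5) (8, 4)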

4 changes: 2 additions & 2 deletions ssax/ss/epsilon_scheduler.py
@@ -18,6 +18,6 @@ def __init__(self, target: float | None = None, scale_epsilon: float | None = No
     def at(self, iteration: Optional[int] = 1) -> float:
         if iteration is None:
             return self.target
-        eps = jnp.minimum(self._init - (self._decay * iteration), self.target)
-
+        eps = jnp.maximum(self._init - (self._decay * iteration), self.target)
         return eps
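With the previous jnp.minimum, the schedule could never rise above target, so ε started at target and kept shrinking without a floor; jnp.maximum gives the intended linear decay clamped at target. Plugging in the settings from scripts/test.py:

import jax.numpy as jnp

init, decay, target = 1.0, 0.01, 0.3  # values used in scripts/test.py
eps = lambda i: jnp.maximum(init - decay * i, target)
print(eps(0), eps(50), eps(70), eps(100))  # 1.0 0.5 0.3 0.3 — clamped from iteration 70 on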