Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getting distance functions working #3

Merged
merged 12 commits into from
Oct 25, 2023
Prev Previous commit
Next Next commit
Add correlation and cosine
chrisflesher committed Oct 23, 2023
commit 9637d2713c973a067532752ae6fb7d7b44b91634
21 changes: 21 additions & 0 deletions src/jax_scipy_spatial/distance.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,27 @@ def cityblock(u: jax.Array, v: jax.Array, w: typing.Optional[jax.Array] = None)
return jnp.sum(l1_diff)


@_wraps(scipy.spatial.distance.correlation)
def correlation(u: jax.Array, v: jax.Array, w: typing.Optional[jax.Array] = None, centered: bool = True) -> jax.Array:
"""Compute the correlation distance between two 1-D arrays."""
if centered:
umu = jnp.average(u, weights=w)
vmu = jnp.average(v, weights=w)
u = u - umu
v = v - vmu
uv = jnp.average(u * v, weights=w)
uu = jnp.average(jnp.square(u), weights=w)
vv = jnp.average(jnp.square(v), weights=w)
dist = 1.0 - uv / jnp.sqrt(uu * vv)
return jnp.abs(dist)


@_wraps(scipy.spatial.distance.cosine)
def cosine(u: jax.Array, v: jax.Array, w: typing.Optional[jax.Array] = None) -> jax.Array:
"""Compute the Cosine distance between 1-D arrays."""
return jnp.clip(correlation(u, v, w=w, centered=False), 0.0, 2.0)


@_wraps(scipy.spatial.distance.euclidean)
def euclidean(u: jax.Array, v: jax.Array, w: typing.Optional[jax.Array] = None) -> jax.Array:
"""Computes the Euclidean distance between two 1-D arrays."""
24 changes: 24 additions & 0 deletions tests/scipy_spatial_distance_test.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,30 @@ def testCityblock(self, shape, dtype):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, tol=1e-4)

@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples,)],
)
def testCorrelation(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(shape, dtype))
jnp_fn = lambda u, v: jsp_distance.correlation(u, v)
np_fn = lambda u, v: osp_distance.correlation(u, v)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, tol=1e-4)

@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples,)],
)
def testCosine(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(shape, dtype))
jnp_fn = lambda u, v: jsp_distance.cosine(u, v)
np_fn = lambda u, v: osp_distance.cosine(u, v)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, tol=1e-4)

@jtu.sample_product(
dtype=float_dtypes,
shape=[(num_samples,)],