Skip to content

Commit

Permalink
chore: ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
alxndrkalinin committed Feb 18, 2025
1 parent af495c3 commit c016abe
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/copairs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def pairwise_chebyshev(x_sample: np.ndarray, y_sample: np.ndarray) -> np.ndarray
return 1 / (1 + c_dist)


def _cdist_diag_sim(x_sample: np.ndarray, y_sample: np.ndarray, metric: str) -> np.ndarray:
def _cdist_diag_sim(
x_sample: np.ndarray, y_sample: np.ndarray, metric: str
) -> np.ndarray:
"""Compute similarity based on the diagonal of the ScipY's cdist result (row-wise distance).
Parameters
Expand Down Expand Up @@ -308,7 +310,9 @@ def get_similarity_fn(distance: Union[str, Callable]) -> Callable:
if distance in similarity_functions:
similarity_fn = similarity_functions[distance]
elif distance in SCIPY_METRICS_NAMES:
similarity_fn = lambda x_sample, y_sample: _cdist_diag_sim(x_sample, y_sample, distance)
similarity_fn = lambda x_sample, y_sample: _cdist_diag_sim(
x_sample, y_sample, distance
)
else:
raise ValueError(
f"Unsupported distance function: {distance}. Supported functions are: {set(similarity_functions.keys()) | set(SCIPY_METRICS_NAMES)}"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,4 @@ def test_hamming():
hamming_gt = hamming_naive(feats, pairs)
hamming_fn = compute.get_similarity_fn("hamming")
hamming = hamming_fn(feats, pairs, batch_size)
assert np.allclose(hamming_gt, hamming)
assert np.allclose(hamming_gt, hamming)

0 comments on commit c016abe

Please sign in to comment.