Skip to content

Commit

Permalink
[Fix] l2_exp random fail in half-float32 mixed precision on self-neig…
Browse files Browse the repository at this point in the history
…hboring (#596)

Authors:
  - rhdong (https://github.com/rhdong)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #596
  • Loading branch information
rhdong authored Jan 22, 2025
1 parent 43969ca commit 1c91e1f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
34 changes: 21 additions & 13 deletions cpp/src/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ namespace cuvs::distance::detail::ops {
* for round-off error tolerance.
* @tparam DataT
*/
template <typename DataT>
__device__ constexpr DataT get_clamp_precision()
template <typename DataT, typename AccT>
__device__ constexpr AccT get_clamp_precision()
{
switch (sizeof(DataT)) {
case 2: return 1e-3;
case 4: return 1e-6;
case 8: return 1e-15;
default: return 0;
case 2: return AccT{1e-3};
case 4: return AccT{1e-6};
case 8: return AccT{1e-15};
default: return AccT{0};
}
}

Expand All @@ -46,19 +46,27 @@ struct l2_exp_cutlass_op {

__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
inline __device__ AccT operator()(AccT aNorm, AccT bNorm, AccT accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
AccT outVal = aNorm + bNorm - AccT(2.0) * accVal;

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
* can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
*/
outVal = outVal * AccT(!((outVal * outVal < get_clamp_precision<AccT>()) * (aNorm == bNorm)));
outVal =
outVal * AccT(!((outVal * outVal < get_clamp_precision<DataT, AccT>()) * (aNorm == bNorm)));
return sqrt ? raft::sqrt(outVal * static_cast<AccT>(outVal > AccT(0))) : outVal;
}

__device__ AccT operator()(DataT aData) const noexcept { return aData; }
__device__ AccT operator()(DataT aData) const noexcept
{
if constexpr (std::is_same_v<DataT, half> && std::is_same_v<AccT, float>) {
return __half2float(aData);
} else {
return aData;
}
}
};

/**
Expand Down Expand Up @@ -121,9 +129,9 @@ struct l2_exp_distance_op {
* (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
* instead.
*/
acc[i][j] =
val * static_cast<AccT>((val > AccT(0))) *
static_cast<AccT>(!((val * val < get_clamp_precision<AccT>()) * (regxn[i] == regyn[j])));
acc[i][j] = val * static_cast<AccT>((val > AccT(0))) *
static_cast<AccT>(
!((val * val < get_clamp_precision<DataT, AccT>()) * (regxn[i] == regyn[j])));
}
}
if (sqrt) {
Expand Down
5 changes: 2 additions & 3 deletions python/cuvs/cuvs/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cuvs.distance import pairwise_distance


@pytest.mark.parametrize("times", range(20))
@pytest.mark.parametrize("n_rows", [50, 100])
@pytest.mark.parametrize("n_cols", [10, 50])
@pytest.mark.parametrize(
Expand All @@ -43,7 +44,7 @@
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16])
def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
def test_distance(n_rows, n_cols, inplace, order, metric, dtype, times):
input1 = np.random.random_sample((n_rows, n_cols))
input1 = np.asarray(input1, order=order).astype(dtype)

Expand Down Expand Up @@ -79,7 +80,5 @@ def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
actual = output_device.copy_to_host()

tol = 1e-3
if np.issubdtype(dtype, np.float16):
tol = 1e-1

assert np.allclose(expected, actual, atol=tol, rtol=tol)

0 comments on commit 1c91e1f

Please sign in to comment.