Skip to content

Commit

Permalink
norm bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rkansal47 committed Feb 9, 2024
1 parent 77effb6 commit b514d85
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
8 changes: 4 additions & 4 deletions jetnet/evaluation/gen_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,10 @@ def fpd(
stacklevel=2,
)

real_features, gen_features = _check_get_ndarray(real_features, gen_features)
X, Y = _check_get_ndarray(real_features, gen_features)

if normalise:
X, Y = _normalise_features(real_features, gen_features)
X, Y = _normalise_features(X, Y)

# regular intervals in 1/N
batches = (1 / np.linspace(1.0 / min_samples, 1.0 / max_samples, num_points)).astype("int32")
Expand Down Expand Up @@ -836,10 +836,10 @@ def kpd(
Returns:
Tuple[float, float]: median and error of KPD.
"""
real_features, gen_features = _check_get_ndarray(real_features, gen_features)
X, Y = _check_get_ndarray(real_features, gen_features)

if normalise:
X, Y = _normalise_features(real_features, gen_features)
X, Y = _normalise_features(X, Y)

if num_threads is None:
vals_point = _kpd_batches(X, Y, num_batches, batch_size, seed)
Expand Down
20 changes: 17 additions & 3 deletions tests/evaluation/test_gen_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,33 @@

test_zeros = np.zeros((50_000, 2))
test_ones = np.ones((50_000, 2))
test_twos = np.ones((50_000, 2)) * 2


def test_fpd():
val, err = evaluation.fpd(test_zeros, test_zeros)
assert val == approx(0, abs=0.01)
assert err < 1e-3

val, err = evaluation.fpd(test_zeros, test_ones)
assert val == approx(2, rel=0.01)
val, err = evaluation.fpd(test_twos, test_zeros)
assert val == approx(2, rel=0.01) # 1^2 + 1^2
assert err < 1e-3

# test normalization
val, err = evaluation.fpd(test_zeros, test_zeros, normalise=False) # should have no effect
assert val == approx(0, abs=0.01)
assert err < 1e-3

val, err = evaluation.fpd(test_twos, test_zeros, normalise=False)
assert val == approx(8, rel=0.01) # 2^2 + 2^2
assert err < 1e-3


@pytest.mark.parametrize("num_threads", [None, 2]) # test numba parallelization
def test_kpd(num_threads):
assert evaluation.kpd(test_zeros, test_zeros, num_threads=num_threads) == approx([0, 0])
assert evaluation.kpd(test_zeros, test_ones, num_threads=num_threads) == approx([15, 0])
assert evaluation.kpd(test_twos, test_zeros, num_threads=num_threads) == approx([15, 0])

# test normalization
assert evaluation.kpd(test_zeros, test_zeros, normalise=False, num_threads=num_threads) == approx([0, 0])
assert evaluation.kpd(test_twos, test_zeros, normalise=False, num_threads=num_threads) == approx([624, 0])

0 comments on commit b514d85

Please sign in to comment.