Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding NormalizeToRange op and support for unique_labels in the RocAuc computation. #914

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions kauldron/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,28 +142,27 @@ def compute(self) -> float:
# for which there are no GT examples and renormalize probabilities
# This will give wrong results, but allows getting a value during training
# where it cannot be guaranteed that each batch contains all classes.
curr_unique_label = np.unique(labels).tolist()
if self.parent.unique_labels is None:
unique_labels = np.unique(labels).tolist()
curr_unique_label = unique_labels
unique_labels = curr_unique_label
else:
# If we are testing on a small subset of data and by chance it does not
# contain all classes, we need to provide the groundtruth labels
# separately.
unique_labels = self.parent.unique_labels
curr_unique_label = np.unique(labels).tolist()

probs = out.probs[..., unique_labels]
probs /= probs.sum(axis=-1, keepdims=True) # renormalize
check_type(probs, Float["b n"])
if len(unique_labels) == 2:
# Binary mode: make it binary, otherwise sklearn complains.
# Binary mode: make it binary and assume positive class is 1, otherwise
# sklearn complains.
assert (
probs.shape[-1] == 2
), f"Unique labels are binary but probs.shape is {probs.shape}"
probs = probs[..., 1]
mask = out.mask[..., 0].astype(np.float32)
check_type(mask, Float["b"])
if len(curr_unique_label) > 1:
if len(curr_unique_label) == len(unique_labels):
# See comment above about small data subsets.
return sklearn_metrics.roc_auc_score(
y_true=labels,
Expand Down
34 changes: 34 additions & 0 deletions kauldron/metrics/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,37 @@ def test_roc():
s3.merge(s0)
with pytest.raises(ValueError, match='from different metrics'):
s0.merge(s3)


def test_roc_with_binary_labels():
  """RocAuc with two classes: binary mode, positive class is column 1."""
  roc_metric = metrics.RocAuc()

  # Three examples, two classes; the middle example is the only class-0 label.
  binary_logits = jnp.asarray(
      [[1.0, 0.0], [0.3, 0.4], [0.0, 2.0]]
  )
  binary_labels = jnp.asarray([[1], [0], [1]], dtype=jnp.int32)

  state = roc_metric.get_state(logits=binary_logits, labels=binary_labels)
  # One positive ranked above the negative, one below -> AUC of 0.5.
  np.testing.assert_allclose(state.compute(), 0.5)



def test_roc_with_unique_labels():
  """RocAuc with explicitly provided ground-truth label set.

  The batch only contains labels {0, 2}; passing unique_labels=[0, 1, 2]
  exercises the code path where the batch does not cover all classes.
  """
  roc_metric = metrics.RocAuc(unique_labels=[0, 1, 2])

  three_class_logits = jnp.asarray(
      [[1.0, 0.0, 0.0], [0.3, 0.4, 0.3], [0.0, 2.0, 8.0]]
  )
  three_class_labels = jnp.asarray([[2], [0], [2]], dtype=jnp.int32)

  state = roc_metric.get_state(
      logits=three_class_logits, labels=three_class_labels
  )
  np.testing.assert_allclose(state.compute(), 0.0)


Loading