
Add MapReduceMetric, a new base class to integrate all metrics into
Signed-off-by: elronbandel <[email protected]>
elronbandel committed Dec 26, 2024
1 parent 0196391 commit c0217d9
Showing 5 changed files with 299 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -236,7 +236,7 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.params.Depends", "fastapi.
"src".msg = "Use unitxt outside src/ and relative imports inside src/ and install unitxt from source with `pip install -e '.[dev]'`."

[tool.codespell]
ignore-words-list = 'rouge,ot,ans,nd,cann,som,tha,vie,ment,criterias'
ignore-words-list = 'rouge,ot,ans,nd,cann,som,tha,vie,ment,criterias,atleast'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
8 changes: 6 additions & 2 deletions src/unitxt/api.py
@@ -198,8 +198,12 @@ def load_dataset(
).with_transform(loads_instance)


def evaluate(predictions, data) -> EvaluationResults:
return _compute(predictions=predictions, references=data)
def evaluate(
predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
if data is not None:
dataset = data # for backward compatibility
return _compute(predictions=predictions, references=dataset)


def post_process(predictions, data) -> List[Dict[str, Any]]:
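For illustration, a minimal usage sketch of the updated evaluate() signature; the recipe string, split, and dummy predictions below are placeholders, not taken from this commit:

from unitxt.api import evaluate, load_dataset

# Placeholder recipe string; any unitxt dataset query works here.
dataset = load_dataset("card=cards.wnli,template_card_index=0", split="test")
predictions = ["entailment"] * len(dataset)

results = evaluate(predictions, dataset)       # new preferred form
results = evaluate(predictions, data=dataset)  # old `data=` keyword still accepted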
255 changes: 254 additions & 1 deletion src/unitxt/metrics.py
@@ -7,7 +7,7 @@
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections import Counter, defaultdict, namedtuple
from dataclasses import field
from functools import lru_cache
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
@@ -317,6 +317,259 @@ def update_and_adjust_global_score(
instance["score"]["global"].pop(score_ci)


def new_random_generator():
# hash(..) can return a 64-bit integer, so mask with '& MAX_32BIT' to derive
# a stable 32-bit seed for np.random.default_rng.
_max_32bit = 2**32 - 1
return np.random.default_rng(hash(get_seed()) & _max_32bit)


class ConfidenceIntervalMixin(Artifact):
n_resamples: int = 1000
confidence_level: float = 0.95
ci_score_names: List[str] = None

@abstractmethod
def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
pass

def get_statistic(self, data: List[Any], score_names: List[str]):
def statistic_function(indices, axis=0):
# indices might be a 1D or 2D array, depending on bootstrap internals
# For simplicity, ensure we handle them as 1D.
indices = np.atleast_1d(indices).astype(int)

# Gather the subset
sample = [data[i] for i in indices]

# Compute metrics on this sample
scores = self._sample_to_scores(sample)

# Return them in consistent order
return np.array([scores[m] for m in score_names])

return statistic_function

def bootstrap(self, data: List[Any], score_names: List[str]):
if self.ci_score_names is not None:
score_names = self.ci_score_names

intervals = bootstrap(
(np.arange(len(data)),),
statistic=self.get_statistic(data, score_names),
n_resamples=self.n_resamples,
confidence_level=self.confidence_level,
random_state=new_random_generator(),
paired=False,
vectorized=False, # set to True if your statistic function is vectorized
method="BCa",
).confidence_interval

result = {}
for i, metric in enumerate(score_names):
result[f"{metric}_ci_low"] = float(intervals.low[i])
result[f"{metric}_ci_high"] = float(intervals.high[i])

return result

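For intuition, a standalone sketch of the same resample-by-index bootstrap pattern that get_statistic() and bootstrap() implement above, using plain scipy with illustrative per-instance scores:

import numpy as np
from scipy.stats import bootstrap

data = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0]  # illustrative per-instance scores

def statistic(indices, axis=0):
    # Resample instances by index, then aggregate them into one statistic.
    indices = np.atleast_1d(indices).astype(int)
    return float(np.mean([data[i] for i in indices]))

ci = bootstrap(
    (np.arange(len(data)),),
    statistic=statistic,
    n_resamples=1000,
    confidence_level=0.95,
    vectorized=False,
    method="BCa",
).confidence_interval
print(ci.low, ci.high)  # 95% CI for the mean score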

from typing import Generic, TypeVar

IntermediateType = TypeVar("IntermediateType")
PredictionType = TypeVar("PredictionType")


class EvaluationInput(tuple, Generic[PredictionType]):
def __new__(
cls,
prediction: PredictionType,
references: List[PredictionType],
task_data: Dict[str, Any],
) -> "EvaluationInput[PredictionType]":
return super().__new__(cls, (prediction, references, task_data))

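EvaluationInput is just a typed 3-tuple, so it unpacks directly into (prediction, references, task_data), which is what map_stream() below relies on. A quick illustrative example:

ei = EvaluationInput(prediction="yes", references=["yes", "y"], task_data={})
prediction, references, task_data = ei  # plain tuple unpacking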

class MapReduceMetric(
StreamOperator,
Metric,
ConfidenceIntervalMixin,
Generic[PredictionType, IntermediateType],
):
reference_field: str = NonPositionalField(default="references")
prediction_field: str = NonPositionalField(default="prediction")

def map(
self,
prediction: PredictionType,
references: List[PredictionType],
task_data: Dict[str, Any],
) -> IntermediateType:
raise NotImplementedError()

def reduce_one(self, intermediate: IntermediateType):
return self.reduce([intermediate])

@abstractmethod
def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
return {}

def disable_confidence_interval_calculation(self):
self.n_resamples = None

def annotate_main_score(self, scores):
scores["score_name"] = self.main_score
scores["score"] = scores[self.main_score]
for level in ["high", "low"]:
if f"{self.main_score}_ci_{level}" in scores:
scores[f"score_ci_{level}"] = scores[f"{self.main_score}_ci_{level}"]
return scores

def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
return self.reduce(sample)

def reduce_and_bootstrap(
self, intermediates: List[IntermediateType]
) -> Dict[str, Any]:
scores = self.reduce(intermediates)
score_names = [k for k, v in scores.items() if isinstance(v, float)]
intervals = self.bootstrap(intermediates, score_names)
return {**scores, **intervals}

def _instance_to_evaluation_input(
self, instance: Dict[str, Any]
) -> EvaluationInput[PredictionType]:
instance = self.verify_instance(instance)

task_data = instance.get("task_data", {})

if self.reference_field == "references":
references = instance["references"]
else:
references = task_data[self.reference_field]
if not isinstance(references, list):
references = [references]
if self.prediction_field == "prediction":
prediction = instance["prediction"]
else:
prediction = task_data[self.prediction_field]

self._validate_prediction(prediction)
self._validate_reference(references)

return EvaluationInput[PredictionType](
prediction=prediction, references=references, task_data=task_data
)

def _instances_stream_to_evaluation_inputs(
self, stream: Stream
) -> Generator[EvaluationInput[PredictionType], None, None]:
for instance in stream:
yield self._instance_to_evaluation_input(instance)

def map_stream(
self,
evaluation_inputs_stream: Generator[
EvaluationInput[PredictionType], None, None
],
):
intermediates = []
for prediction, references, task_data in evaluation_inputs_stream:
intermediate = self.map(
prediction=prediction, references=references, task_data=task_data
)

intermediates.append(intermediate)
return intermediates

def process(self, stream: Stream, stream_name: Optional[str] = None):
instances_scores, global_scores = self.compute(stream, stream_name)
for instance, instance_scores in zip(stream, instances_scores):
yield {
**instance,
"score": {"global": global_scores, "instance": instance_scores},
}

def compute(self, stream: Stream, stream_name: Optional[str] = None):
evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
intermediates_list = self.map_stream(evaluation_inputs_stream)

instances_scores = []
for intermediate in intermediates_list:
instance_score = self.reduce_one(intermediate)
instance_score = self.annotate_main_score(instance_score)
instances_scores.append(instance_score)

global_scores = self.reduce_and_bootstrap(intermediates_list)
global_scores = self.annotate_main_score(global_scores)

return instances_scores, global_scores

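To show the contract the new base class expects, here is a hypothetical minimal subclass (not part of this commit): map() turns one (prediction, references, task_data) triple into an intermediate value, and reduce() aggregates a list of intermediates into a flat dict of float scores; per-instance scores, global scores, and bootstrap confidence intervals are then produced by MapReduceMetric itself.

from typing import Any, Dict, List

from unitxt.metrics import MapReduceMetric

class ExactMatchFast(MapReduceMetric[str, float]):
    """Hypothetical example metric; not part of this commit."""

    main_score = "exact_match"

    def map(
        self, prediction: str, references: List[str], task_data: Dict[str, Any]
    ) -> float:
        # 1.0 if the prediction equals any reference, else 0.0
        return float(any(prediction == reference for reference in references))

    def reduce(self, intermediates: List[float]) -> Dict[str, Any]:
        return {"exact_match": sum(intermediates) / len(intermediates)}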

def get_index_or_default(lst, item, default=-1):
try:
return lst.index(item)
except ValueError:
return default


class F1Fast(MapReduceMetric[str, Tuple[int, int, List[str]]]):
main_score = "f1"

def prepare(self):
super().prepare()
from sklearn.metrics import f1_score

self._metric = f1_score

def map(
self, prediction: str, references: List[str], task_data: Dict[str, Any]
) -> Tuple[int, int, List[str]]:
assert "classes" in task_data, "F1Fast has to have classes"
classes: List[str] = task_data["classes"]
prediction_index = get_index_or_default(classes, prediction)
reference_index = get_index_or_default(classes, references[0])

return prediction_index, reference_index, classes

def reduce(self, intermediates: List[Tuple[int, int, List[str]]]) -> Dict[str, Any]:
# All classes are assumed to be the same set for each intermediate
all_classes = intermediates[0][2]
num_classes = len(all_classes)

# Build lists of true (y_true) and predicted (y_pred) indices
y_true = []
y_pred = []
for pred_idx, ref_idx, classes in intermediates:
y_pred.append(pred_idx if 0 <= pred_idx < num_classes else -1)
y_true.append(ref_idx if 0 <= ref_idx < num_classes else -1)

# Compute F1 scores
f1_macro = self._metric(
y_true, y_pred, average="macro", labels=range(num_classes), zero_division=0
)
f1_micro = self._metric(
y_true, y_pred, average="micro", labels=range(num_classes), zero_division=0
)
# For per-class F1, average=None returns an array of F1 for each label
f1_per_class = self._metric(
y_true, y_pred, average=None, labels=range(num_classes), zero_division=0
)

# Create a flat dict of all metrics
result = {
"f1": float(f1_macro), # Use macro-F1 as the "main_score"
"f1_macro": float(f1_macro),
"f1_micro": float(f1_micro),
}

# Add class-wise F1 using "f1_class_<class_name>" keys
for class_name, score in zip(all_classes, f1_per_class):
result[f"f1_{class_name}"] = float(score)

return result


class MetricWithConfidenceInterval(Metric):
# The number of resamples used to estimate the confidence intervals of this metric.
# Use None to disable confidence interval computation.
36 changes: 36 additions & 0 deletions tests/library/test_metrics.py
@@ -13,6 +13,7 @@
Detector,
F1Binary,
F1BinaryPosOnly,
F1Fast,
F1Macro,
F1MacroMultiLabel,
F1Micro,
@@ -592,6 +593,41 @@ def test_f1_macro(self):
self.assertEqual("f1_macro", outputs[0]["score"]["global"]["score_name"])
self.assertEqual("f1_macro", outputs[0]["score"]["instance"]["score_name"])

def test_f1_macro_fast(self):
metric = F1Fast(main_score="f1_macro")
references = [["cat"], ["dog"], ["dog"], ["dog"], ["cat"], ["cat"]]
predictions = ["cat", "cat", "dog", "dog", "cat", "cat"]
task_data = [
{"classes": ["dog", "cat"]},
{"classes": ["dog", "cat"]},
{"classes": ["dog", "cat"]},
{"classes": ["dog", "cat"]},
{"classes": ["dog", "cat"]},
{"classes": ["dog", "cat"]},
]
# recall class 'dog' = 2/3 ≈ 0.667, precision = 2/2 = 1, f1 = 0.8
# recall class 'cat' = 3/3 = 1, precision = 3/4 = 0.75, f1 = 0.857142857143
# macro f1 = (0.8 + 0.857142857143) / 2 ≈ 0.82857142857
global_target = 0.82857142
global_target_dog = 0.8
global_target_cat = 0.857142857143

outputs = apply_metric(
metric=metric,
predictions=predictions,
references=references,
task_data=task_data,
)
self.assertAlmostEqual(global_target, outputs[0]["score"]["global"]["score"])
self.assertAlmostEqual(
global_target_dog, outputs[0]["score"]["global"]["f1_dog"]
)
self.assertAlmostEqual(
global_target_cat, outputs[0]["score"]["global"]["f1_cat"]
)
self.assertEqual("f1_macro", outputs[0]["score"]["global"]["score_name"])
self.assertEqual("f1_macro", outputs[0]["score"]["instance"]["score_name"])

def test_f1_weighted(self):
metric = F1Weighted()
references = [
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -161,7 +161,7 @@
"filename": "src/unitxt/metrics.py",
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_verified": false,
"line_number": 2118,
"line_number": 2371,
"is_secret": false
}
],
@@ -184,5 +184,5 @@
}
]
},
"generated_at": "2024-12-24T18:00:14Z"
"generated_at": "2024-12-26T12:47:09Z"
}
