Update metrics handling and improve score calculations for consistency
Signed-off-by: elronbandel <[email protected]>
elronbandel committed Dec 29, 2024
1 parent 5f2efae commit 146cb07
Showing 6 changed files with 109 additions and 74 deletions.
6 changes: 3 additions & 3 deletions src/unitxt/metric_utils.py
@@ -250,7 +250,7 @@ def recursive_mean(dic):
                     all_num_of_instances.append(score["num_of_instances"])
                 result[k] = score
 
-        result["score"] = nan_mean(all_scores)
+        result["score"] = float(nan_mean(all_scores))
         result["score_name"] = "subsets_mean"
         if all_num_of_instances:
             result["num_of_instances"] = sum(all_num_of_instances)
@@ -267,9 +267,9 @@ def recursive_mean(dic):
         if "subsets" in score:
             score["subsets"] = recursive_mean(score["subsets"])
             score["global"] = {
-                "score": score["subsets"]["score"],
+                "score": float(score["subsets"]["score"]),
                 "score_name": score["subsets"]["score_name"],
-                "subsets_mean": score["subsets"]["score"],
+                "subsets_mean": float(score["subsets"]["score"]),
             }
             if "num_of_instances" in score["subsets"]:
                 score["global"]["num_of_instances"] = score["subsets"][
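Note on the change above: nan_mean is numpy-backed, so it returns a numpy scalar rather than a built-in float, and the new float(...) casts normalize that before the value lands in the score dict. A minimal sketch of the behavior being normalized (assuming nan_mean wraps np.nanmean, which this diff does not show):

    import json
    import numpy as np

    raw = np.nanmean([0.5, float("nan"), 1.0])  # NaNs are ignored: mean of 0.5 and 1.0
    print(type(raw))   # <class 'numpy.float64'>, not a built-in float
    print(repr(raw))   # "np.float64(0.75)" under numpy >= 2.0, which leaks into dict reprs

    # Casting to a native float keeps reprs stable and the score dict plainly serializable:
    print(json.dumps({"score": float(raw), "score_name": "subsets_mean"}))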
39 changes: 31 additions & 8 deletions src/unitxt/metrics.py
@@ -390,6 +390,12 @@ def __new__(
         return super().__new__(cls, (prediction, references, task_data))
 
 
+def is_original_key(key):
+    if key.endswith("_ci_low") or key.endswith("_ci_high") or key == "score":
+        return False
+    return True
+
+
 class MapReduceMetric(
     StreamOperator,
     Metric,
@@ -437,6 +443,8 @@ def reduce_and_bootstrap(
     ) -> Dict[str, Any]:
         scores = self.reduce(intermediates)
         score_names = [k for k, v in scores.items() if isinstance(v, float)]
+        if self.n_resamples is None:
+            return scores
         intervals = self.bootstrap(intermediates, score_names)
         return {**scores, **intervals}
@@ -489,6 +497,21 @@ def map_stream(
     def process(self, stream: Stream, stream_name: Optional[str] = None):
         instances_scores, global_scores = self.compute(stream, stream_name)
         for instance, instance_scores in zip(stream, instances_scores):
+            previous_score = instance.get("score", {"global": {}, "instance": {}})
+
+            for key in global_scores:
+                if is_original_key(key) and key in previous_score["global"]:
+                    UnitxtWarning(
+                        message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
+                        f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
+                        f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
+                        f"which will yield, in this case, a score named: 'my_second_{key}')",
+                        additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
+                    )
+
+            global_scores = {**previous_score["global"], **global_scores}
+            instance_scores = {**previous_score["instance"], **instance_scores}
+
             yield {
                 **instance,
                 "score": {"global": global_scores, "instance": instance_scores},
@@ -783,11 +806,11 @@ def statistic(arr, axis, score_name=score_name):
                 random_state=self.new_random_generator(),
             ).confidence_interval
             full_score_name = ci_score_prefix + score_name
-            result[f"{full_score_name}_ci_low"] = ci.low
-            result[f"{full_score_name}_ci_high"] = ci.high
+            result[f"{full_score_name}_ci_low"] = float(ci.low)
+            result[f"{full_score_name}_ci_high"] = float(ci.high)
             if score_name == self.score_prefix + self.main_score:
-                result["score_ci_low"] = ci.low
-                result["score_ci_high"] = ci.high
+                result["score_ci_low"] = float(ci.low)
+                result["score_ci_high"] = float(ci.high)
         return result
 
     def resample_from_non_nan(self, values):
@@ -882,10 +905,10 @@ def metric(sample_refs, sample_preds, sample_task_data):
                 confidence_level=self.confidence_level,
                 random_state=random_gen,
             ).confidence_interval
-            result["score_ci_low"] = ci.low
-            result["score_ci_high"] = ci.high
-            result[f"{score_name}_ci_low"] = ci.low
-            result[f"{score_name}_ci_high"] = ci.high
+            result["score_ci_low"] = float(ci.low)
+            result["score_ci_high"] = float(ci.high)
+            result[f"{score_name}_ci_low"] = float(ci.low)
+            result[f"{score_name}_ci_high"] = float(ci.high)
         return result


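The new merge logic in process keeps previously recorded scores and warns before an original key (as defined by is_original_key above) would be overwritten. A conceptual sketch of that merge with plain dicts — previous_global and new_global are illustrative names, not the unitxt API:

    def is_original_key(key):
        if key.endswith("_ci_low") or key.endswith("_ci_high") or key == "score":
            return False
        return True

    previous_global = {"accuracy": 0.42, "score": 0.42, "score_name": "accuracy"}
    new_global = {"accuracy": 0.58, "score": 0.58, "score_name": "accuracy"}

    for key in new_global:
        if is_original_key(key) and key in previous_global:
            print(f"warning: '{key}'={new_global[key]} overwrites recorded {previous_global[key]}")

    merged = {**previous_global, **new_global}  # the later metric wins, as in the diff
    print(merged["accuracy"])  # 0.58; a score_prefix such as 'my_second_' would avoid the clash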
34 changes: 16 additions & 18 deletions src/unitxt/operators.py
@@ -1633,6 +1633,12 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
         yield from stream
 
 
+def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
+    for instance, score in zip(stream, scores):
+        instance["score"] = recursive_copy(score)
+        yield instance
+
+
 class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     """Applies metric operators to a stream based on a metric field specified in each instance.
@@ -1647,13 +1653,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric, MetricsList
 
-        def update_scores_of_stream_instances(
-            stream: Stream, scores: List[dict]
-        ) -> Generator:
-            for instance, score in zip(stream, scores):
-                instance["score"] = recursive_copy(score)
-                yield instance
-
         # to be populated only when two or more metrics
         accumulated_scores = []
@@ -1680,29 +1679,28 @@ def update_scores_of_stream_instances(
                     f"Operator {metric_name} must be a Metric or MetricsList"
                 )
 
+        for metric in metrics_list:
+            if not self.calc_confidence_intervals:
+                metric.disable_confidence_interval_calculation()
         # Each metric operator computes its score and then sets the main score, overwriting
         # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
         # This will cause the first listed metric to run last, and the main score will be set
         # by the first listed metric (as desired).
         metrics_list = list(reversed(metrics_list))
 
-        for metric_no, metric in enumerate(metrics_list):
-            if not self.calc_confidence_intervals:
-                metric.disable_confidence_interval_calculation()
-
-            if metric_no > 0:
-                # update input stream with accumulated scores
+        for i, metric in enumerate(metrics_list):
+            if i == 0:  # first metric
+                multi_stream = MultiStream({"tmp": stream})
+            else:  # metrics with previous scores
                 reusable_generator = ReusableGenerator(
                     generator=update_scores_of_stream_instances,
                     gen_kwargs={"stream": stream, "scores": accumulated_scores},
                 )
                 multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
-            else:
-                multi_stream = MultiStream.from_iterables({"tmp": stream})
 
             multi_stream = metric(multi_stream)
-            if metric_no < len(metrics_list) - 1:
-                # not the last metric, so prepare for the next metric by
-                # updating accumulated_scores
-
+            if i < len(metrics_list) - 1:  # last metric
                 accumulated_scores = []
                 for inst in multi_stream["tmp"]:
                     accumulated_scores.append(recursive_copy(inst["score"]))
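The restructured loop preserves the invariant described in the comment: metrics are applied in reverse order, so the first listed metric runs last and sets the final main score, and every pass after the first replays the stream with the scores accumulated so far. A stripped-down sketch of that control flow over plain lists (assumed shapes, not the unitxt streams API):

    def apply_metrics(instances, metrics):
        # Reverse so the first listed metric runs last and sets the main score.
        metrics = list(reversed(metrics))
        accumulated = None
        for i, metric in enumerate(metrics):
            if i == 0:  # first pass: the raw instances
                stream = instances
            else:  # later passes: replay with the accumulated scores attached
                stream = [{**inst, "score": score} for inst, score in zip(instances, accumulated)]
            # each metric returns the instance with an updated inst["score"]
            stream = [metric(inst) for inst in stream]
            if i < len(metrics) - 1:
                accumulated = [inst["score"] for inst in stream]
        return stream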
72 changes: 33 additions & 39 deletions tests/library/test_fusion.py
@@ -396,62 +396,56 @@ def test_end_to_end(self):
                 "num_of_instances": 13,
             },
             "wnli": {
+                "num_of_instances": 12,
                 "f1_macro": 0.357,
-                "f1_micro": 0.5,
                 "f1_entailment": 0.0,
                 "f1_not entailment": 0.714,
-                "score": 0.5,
-                "f1_micro_ci_low": 0.2,
-                "f1_micro_ci_high": 0.762,
                 "f1_entailment_ci_low": None,
                 "f1_entailment_ci_high": None,
                 "f1_not entailment_ci_low": 0.364,
                 "f1_not entailment_ci_high": 0.933,
                 "score_name": "f1_micro",
-                "score_ci_low": 0.235,
-                "score_ci_high": 0.736,
                 "f1_macro_ci_low": 0.205,
                 "f1_macro_ci_high": 0.429,
                 "accuracy": 0.417,
                 "accuracy_ci_low": 0.167,
                 "accuracy_ci_high": 0.667,
+                "f1_micro": 0.5,
+                "f1_micro_ci_low": 0.235,
+                "f1_micro_ci_high": 0.736,
+                "score": 0.5,
+                "score_ci_high": 0.762,
+                "score_ci_low": 0.2,
-                "num_of_instances": 12,
                 "groups": {
                     "template": {
                         "templates.classification.multi_class.relation.default": {
+                            "num_of_instances": 12,
                             "f1_macro": 0.357,
-                            "f1_micro": 0.5,
                             "f1_entailment": 0.0,
                             "f1_not entailment": 0.714,
-                            "score": 0.5,
-                            "f1_micro_ci_low": 0.2,
-                            "f1_micro_ci_high": 0.762,
                             "f1_entailment_ci_low": None,
                             "f1_entailment_ci_high": None,
                             "f1_not entailment_ci_low": 0.364,
                             "f1_not entailment_ci_high": 0.933,
                             "score_name": "f1_micro",
-                            "score_ci_low": 0.235,
-                            "score_ci_high": 0.736,
                             "f1_macro_ci_low": 0.205,
                             "f1_macro_ci_high": 0.429,
                             "accuracy": 0.417,
                             "accuracy_ci_low": 0.167,
                             "accuracy_ci_high": 0.667,
+                            "f1_micro": 0.5,
+                            "f1_micro_ci_low": 0.235,
+                            "f1_micro_ci_high": 0.736,
+                            "score": 0.5,
+                            "score_ci_high": 0.762,
+                            "score_ci_low": 0.2,
-                            "num_of_instances": 12,
                         }
                     }
                 },
             },
             "rte": {
+                "num_of_instances": 5,
                 "f1_macro": 0.333,
+                "f1_not entailment": 0.667,
-                "f1_micro": 0.5,
                 "f1_entailment": 0.0,
-                "score": 0.5,
-                "f1_not entailment": 0.667,
-                "f1_micro_ci_low": 0.0,
-                "f1_micro_ci_high": 0.889,
                 "f1_entailment_ci_low": None,
                 "f1_entailment_ci_high": None,
                 "f1_not entailment_ci_low": 0.0,
                 "f1_not entailment_ci_high": 1.0,
                 "score_name": "f1_micro",
+                "score": 0.5,
+                "score_ci_high": 0.889,
                 "score_ci_low": 0.0,
-                "score_ci_high": 0.795,
                 "f1_macro_ci_low": 0.0,
                 "f1_macro_ci_high": 0.75,
                 "accuracy": 0.4,
                 "accuracy_ci_low": 0.0,
                 "accuracy_ci_high": 0.8,
+                "f1_micro": 0.5,
+                "f1_micro_ci_low": 0.0,
+                "f1_micro_ci_high": 0.795,
-                "num_of_instances": 5,
             },
             "score": 0.161,
             "score_name": "subsets_mean",
28 changes: 24 additions & 4 deletions tests/library/test_operators.py
@@ -2641,10 +2641,30 @@ def _test_apply_metric(
         calc_confidence_intervals=False,
     ):
         inputs = [
-            {"prediction": "0", "references": ["1"], "metrics": metrics},
-            {"prediction": "1", "references": ["1"], "metrics": metrics},
-            {"prediction": "0", "references": ["2"], "metrics": metrics},
-            {"prediction": "0", "references": ["0"], "metrics": metrics},
+            {
+                "prediction": "0",
+                "references": ["1"],
+                "task_data": {"classes": ["0", "1", "2"]},
+                "metrics": metrics,
+            },
+            {
+                "prediction": "1",
+                "references": ["1"],
+                "task_data": {"classes": ["0", "1", "2"]},
+                "metrics": metrics,
+            },
+            {
+                "prediction": "0",
+                "references": ["2"],
+                "task_data": {"classes": ["0", "1", "2"]},
+                "metrics": metrics,
+            },
+            {
+                "prediction": "0",
+                "references": ["0"],
+                "task_data": {"classes": ["0", "1", "2"]},
+                "metrics": metrics,
+            },
         ]
         output = apply_operator(
             operator=ApplyMetric(
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -161,7 +161,7 @@
         "filename": "src/unitxt/metrics.py",
         "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
         "is_verified": false,
-        "line_number": 2483,
+        "line_number": 2506,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2024-12-26T18:41:32Z"
+  "generated_at": "2024-12-29T09:44:05Z"
}
