Skip to content

Commit

Permalink
Merge pull request #7 from pfnet-research/harness-v4
Browse files Browse the repository at this point in the history
feat: lm_eval update
  • Loading branch information
masanorihirano authored Apr 13, 2024
2 parents 0e57c13 + 26f6fff commit 934ddfc
Show file tree
Hide file tree
Showing 58 changed files with 1,245 additions and 1,834 deletions.
1 change: 1 addition & 0 deletions jlm_fin_eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .api import metrics
69 changes: 69 additions & 0 deletions jlm_fin_eval/api/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Tuple

import numpy as np
from lm_eval.api.registry import register_aggregation
from lm_eval.api.registry import register_metric
from sklearn.metrics import f1_score


@register_aggregation("macro_f1_score")
def macro_f1_score(items: Tuple) -> float | np.ndarray:
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="macro")
return fscore


@register_aggregation("2class_adjusted_macro_f1_score_for_chabsa")
def two_class_adjusted_macro_f1_score_for_chabsa(items: Tuple) -> float | np.ndarray:
unzipped_list = list(zip(*items))
golds = unzipped_list[0]
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="macro") * 1.5
return fscore


@register_metric(
metric="f1_norm",
higher_is_better=True,
output_type="multiple_choice",
)
def f1_norm_fn(items): # This is a passthrough function
return items


@register_metric(
metric="map",
higher_is_better=True,
output_type="multiple_choice",
)
def map_fn(items): # This is a passthrough function
return items


@register_metric(
metric="map_2",
higher_is_better=True,
output_type="multiple_choice",
)
def map_2_fn(items): # This is a passthrough function
return items


@register_metric(
metric="map_3",
higher_is_better=True,
output_type="multiple_choice",
)
def map_3_fn(items): # This is a passthrough function
return items


@register_metric(
metric="map_4",
higher_is_better=True,
output_type="multiple_choice",
)
def map_4_fn(items): # This is a passthrough function
return items
Loading

0 comments on commit 934ddfc

Please sign in to comment.