Merge branch 'main' into add-text2sql
perlitz authored Jan 8, 2025
2 parents 342b7c5 + 21b732c commit b6da498
Showing 36 changed files with 1,226 additions and 120 deletions.
58 changes: 58 additions & 0 deletions prepare/cards/frames_benchmark.py
@@ -0,0 +1,58 @@
import unitxt
from unitxt.blocks import LoadHF
from unitxt.card import TaskCard
from unitxt.catalog import add_to_catalog
from unitxt.collections_operators import Wrap
from unitxt.operators import (
Rename,
Set,
WikipediaFetcher,
)
from unitxt.processors import LiteralEval
from unitxt.test_utils.card import test_card

with unitxt.settings.context(allow_unverified_code=True):
card = TaskCard(
loader=LoadHF(
path="google/frames-benchmark", data_classification_policy=["public"]
),
preprocess_steps=[
Rename(field="Prompt", to_field="question"),
Rename(field="Answer", to_field="answer"),
Wrap(field="answer", inside="list", to_field="answers"),
LiteralEval(field="wiki_links", to_field="context"),
WikipediaFetcher(field="context", process_every_value=True),
Set(fields={"context_type": "wikipedia articles"}),
],
task="tasks.qa.with_context",
templates="templates.qa.with_context.all",
__description__=(
"""FRAMES is a comprehensive evaluation dataset designed to test the capabilities of Retrieval-Augmented Generation (RAG) systems across factuality, retrieval accuracy, and reasoning."""
),
__tags__={
"annotations_creators": "expert-generated",
"arxiv": "1904.09728",
"flags": ["NLU", "natural language understanding"],
"language": "en",
"language_creators": "other",
"license": "other",
"multilinguality": "monolingual",
"region": "us",
"size_categories": "10K<n<100K",
"source_datasets": "extended|other",
"task_categories": [
"text-classification",
"token-classification",
"question-answering",
],
"task_ids": [
"natural-language-inference",
"word-sense-disambiguation",
"coreference-resolution",
"extractive-qa",
],
},
)

test_card(card, strict=False)
add_to_catalog(card, "cards.frames", overwrite=True)
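A minimal usage sketch for the newly registered card, exercising it end to end through the unitxt API. The template name, loader_limit, split, and placeholder predictions below are assumptions for illustration, not part of this commit.

from unitxt import evaluate, load_dataset

# Sketch only: any template from templates.qa.with_context.all would do;
# loader_limit keeps the Wikipedia fetching quick for a smoke test.
dataset = load_dataset(
    card="cards.frames",
    template="templates.qa.with_context.simple",  # assumed catalog template
    loader_limit=10,
    split="test",  # assuming the dataset's usual single test split
)

# Placeholder predictions; a real run would come from an inference engine.
predictions = ["" for _ in dataset]
results = evaluate(predictions=predictions, data=dataset)
print(results[0]["score"]["global"])  # aggregate scores; exact layout may vary by version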
33 changes: 28 additions & 5 deletions prepare/cards/open_australian_legal_qa.py
@@ -6,11 +6,8 @@
SplitRandomMix,
TaskCard,
)
from unitxt.operators import (
Copy,
ListFieldValues,
Shuffle,
)
from unitxt.collections_operators import Wrap
from unitxt.operators import Copy, ListFieldValues, Set, Shuffle
from unitxt.test_utils.card import test_card

card = TaskCard(
@@ -49,3 +46,29 @@
add_to_catalog(
card, "cards.rag.response_generation.train.open_australian_legal_qa", overwrite=True
)


card = TaskCard(
loader=LoadHF(
path="umarbutler/open-australian-legal-qa",
),
preprocess_steps=[
SplitRandomMix(
{"train": "train[0.5]", "validation": "train[0.2]", "test": "train[0.3]"}
),
Shuffle(),
Set({"context_type": "legal document"}),
Copy(field="source/text", to_field="context/body"),
Copy(field="source/citation", to_field="context/title"),
Wrap(field="answer", inside="list", to_field="answers"),
],
task="tasks.qa.with_context",
templates="templates.qa.with_context.all",
)

test_card(
card,
strict=True,
demos_taken_from="test",
)
add_to_catalog(card, "cards.open_australian_legal_qa", overwrite=True)
2 changes: 1 addition & 1 deletion prepare/cards/rag/end_to_end/bioasq.py
@@ -86,7 +86,7 @@
RenameSplits({"test": "train"}),
Cast(field="id", to="str"),
Copy(field="id", to_field="document_id"),
Wrap(field="passage", inside="list"),
Wrap(field="passage", inside="list", to_field="passages"),
Set(
fields={
"metadata_field": "",
13 changes: 3 additions & 10 deletions prepare/cards/rag/end_to_end/clapnq.py
@@ -3,7 +3,7 @@

from unitxt import add_to_catalog
from unitxt.blocks import TaskCard
from unitxt.loaders import LoadCSV
from unitxt.loaders import LoadCSV, LoadHF
from unitxt.operators import Copy, ListFieldValues, Set
from unitxt.templates import InputOutputTemplate
from unitxt.test_utils.card import test_card
@@ -16,12 +16,6 @@ class ClapNqBenchmark:
TEST_RAW_FILE_URL: str = "https://raw.githubusercontent.com/primeqa/clapnq/main/retrieval/dev/question_dev_answerable.tsv"


@dataclass(frozen=True)
class ClapNqDocuments:
# Raw_data
RAW_FILE_URL: str = "https://media.githubusercontent.com/media/primeqa/clapnq/main/retrieval/passages.tsv"


card = TaskCard(
loader=LoadCSV(
sep="\t",
@@ -78,9 +72,8 @@ class ClapNqDocuments:

# Documents
card = TaskCard(
loader=LoadCSV(
sep="\t",
files={"train": ClapNqDocuments.RAW_FILE_URL},
loader=LoadHF(
path="PrimeQA/clapnq_passages",
data_classification_policy=["public"],
),
preprocess_steps=[
4 changes: 2 additions & 2 deletions prepare/metrics/accuracy.py
@@ -1,8 +1,8 @@
from unitxt import add_to_catalog
from unitxt.metrics import Accuracy, BinaryAccuracy, BinaryMaxAccuracy
from unitxt.metrics import AccuracyFast, BinaryAccuracy, BinaryMaxAccuracy
from unitxt.test_utils.metrics import test_metric

metric = Accuracy()
metric = AccuracyFast()

predictions = ["A", "B", "C"]
references = [["B"], ["A"], ["C"]]
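Because the catalog name stays metrics.accuracy, cards and recipes that reference the metric by name pick up the faster implementation transparently. A small sketch of checking this; the printed class name is the expected outcome, not something asserted by this commit's tests.

from unitxt.artifact import fetch_artifact

# fetch_artifact resolves a catalog name to its artifact (plus the catalog it came from).
metric, _ = fetch_artifact("metrics.accuracy")
print(type(metric).__name__)  # expected: AccuracyFast after this change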
9 changes: 5 additions & 4 deletions prepare/metrics/f1.py
@@ -2,20 +2,21 @@
from unitxt.metrics import (
BinaryMaxF1,
F1Binary,
F1Macro,
F1Fast,
F1MacroMultiLabel,
F1Micro,
F1MicroMultiLabel,
F1Strings,
F1Weighted,
PrecisionBinary,
RecallBinary,
)

metric = F1Macro()
metric = F1Fast(
main_score="f1_macro", averages=["macro", "per_class"], ci_score_names=["f1_macro"]
)
add_to_catalog(metric, "metrics.f1_macro", overwrite=True)

metric = F1Micro()
metric = F1Fast(main_score="f1_micro", averages=["micro"])
add_to_catalog(metric, "metrics.f1_micro", overwrite=True)

metric = F1MacroMultiLabel()
43 changes: 38 additions & 5 deletions prepare/metrics/meteor.py
@@ -1,8 +1,13 @@
from unitxt import add_to_catalog
from unitxt.metrics import HuggingfaceMetric, Meteor
from unitxt.metrics import HuggingfaceMetric, MeteorFast
from unitxt.test_utils.metrics import test_metric

metric = Meteor(n_resamples=3)
metric = MeteorFast(
__description__="""METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.
METEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.
""",
)

predictions = [
"It is a guide to action which ensures that the military always obeys the commands of the party",
@@ -28,6 +33,35 @@
{"meteor": 0.47, "score": 0.47, "score_name": "meteor"},
]

global_target = {
"meteor": 0.58,
"meteor_ci_high": 0.67,
"meteor_ci_low": 0.48,
"num_of_instances": 4,
"score": 0.58,
"score_ci_high": 0.67,
"score_ci_low": 0.48,
"score_name": "meteor",
}

outputs = test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)

# to match the setting to occur by testing on the global version, metric2, below, setting n_resamples=3

metric_hf = MeteorFast(
n_resamples=3,
__description__="""Huggingface version with bad confidence interval calculation of METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a machine translation evaluation metric, which is calculated based on the harmonic mean of precision and recall, with recall weighted more than precision.
METEOR is based on a generalized concept of unigram matching between the machine-produced translation and human-produced reference translations. Unigrams can be matched based on their surface forms, stemmed forms, and meanings. Once all generalized unigram matches between the two strings have been found, METEOR computes a score for this matching using a combination of unigram-precision, unigram-recall, and a measure of fragmentation that is designed to directly capture how well-ordered the matched words in the machine translation are in relation to the reference.
""",
)

global_target = {
"meteor": 0.58,
"meteor_ci_high": 0.59,
@@ -39,10 +73,8 @@
"num_of_instances": 4,
}

# to match the setting to occur by testing on the global version, metric2, below

outputs = test_metric(
metric=metric,
metric=metric_hf,
predictions=predictions,
references=references,
instance_targets=instance_targets,
@@ -63,3 +95,4 @@
)

add_to_catalog(metric, "metrics.meteor", overwrite=True)
add_to_catalog(metric_hf, "metrics.meteor_hf", overwrite=True)
6 changes: 4 additions & 2 deletions prepare/tasks/qa/tasks.py
@@ -2,7 +2,7 @@

from unitxt.blocks import Task
from unitxt.catalog import add_link_to_catalog, add_to_catalog
from unitxt.types import Audio, Dialog, Image, Table, Text
from unitxt.types import Audio, Dialog, Document, Image, MultiDocument, Table, Text

add_link_to_catalog(
artifact_linked_to="tasks.qa.extractive",
@@ -37,7 +37,9 @@
By default, classical Rouge metric is used , but list of additional applicable metrics can be found under 'metrics.qa' in the Unitxt catalog.
""",
input_fields={
"context": Union[Text, Image, Audio, Table, Dialog],
"context": Union[
Text, Image, Audio, Table, Dialog, Document, MultiDocument
],
"context_type": str,
"question": str,
},
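Judging from the open_australian_legal_qa card above, a Document-shaped context carries a title and a body. A hypothetical card could therefore populate the widened context field along these lines; the raw_title and raw_text source fields are illustrative, not from this commit.

from unitxt.operators import Copy, Set

# Sketch only: map a raw record into a Document-shaped context,
# mirroring the Copy steps of the open_australian_legal_qa card above.
preprocess_steps = [
    Set(fields={"context_type": "document"}),
    Copy(field="raw_title", to_field="context/title"),  # hypothetical source field
    Copy(field="raw_text", to_field="context/body"),    # hypothetical source field
]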
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -106,6 +106,7 @@ tests = [
"accelerate",
"spacy",
"func_timeout==4.3.5",
"Wikipedia-API"
]
ui = [
"gradio",
@@ -237,7 +238,7 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.params.Depends", "fastapi.
"src".msg = "Use unitxt outside src/ and relative imports inside src/ and install unitxt from source with `pip install -e '.[dev]'`."

[tool.codespell]
ignore-words-list = 'rouge,ot,ans,nd,cann,som,tha,vie,ment,criterias'
ignore-words-list = 'rouge,ot,ans,nd,cann,som,tha,vie,ment,criterias,atleast'
check-filenames = true
check-hidden = false
regex = "(?<![a-z])[a-z'`]+|[A-Z][a-z'`]*|[a-z]+'[a-z]*|[a-z]+(?=[_-])|[a-z]+(?=[A-Z])|\\d+"
11 changes: 9 additions & 2 deletions src/unitxt/api.py
@@ -7,6 +7,7 @@
from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
InferenceEngine,
LogProbInferenceEngine,
@@ -198,8 +199,14 @@ def load_dataset(
).with_transform(loads_instance)


def evaluate(predictions, data) -> EvaluationResults:
return _compute(predictions=predictions, references=data)
def evaluate(
predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
if dataset is None and data is None:
raise UnitxtError(message="Specify 'dataset' in evaluate")
if data is not None:
dataset = data # for backward compatibility
return _compute(predictions=predictions, references=dataset)


def post_process(predictions, data) -> List[Dict[str, Any]]:
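In practice both spellings now work. A minimal sketch, assuming predictions and test_dataset come from an earlier load_dataset/inference step as in the card examples above.

from unitxt import evaluate

results = evaluate(predictions, dataset=test_dataset)  # new keyword (also accepted positionally)
results = evaluate(predictions, data=test_dataset)     # old keyword kept for backward compatibility
# Calling evaluate(predictions) with neither argument now raises UnitxtError.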
73 changes: 73 additions & 0 deletions src/unitxt/catalog/cards/frames.json
@@ -0,0 +1,73 @@
{
"__type__": "task_card",
"loader": {
"__type__": "load_hf",
"path": "google/frames-benchmark",
"data_classification_policy": [
"public"
]
},
"preprocess_steps": [
{
"__type__": "rename",
"field": "Prompt",
"to_field": "question"
},
{
"__type__": "rename",
"field": "Answer",
"to_field": "answer"
},
{
"__type__": "wrap",
"field": "answer",
"inside": "list",
"to_field": "answers"
},
{
"__type__": "literal_eval",
"field": "wiki_links",
"to_field": "context"
},
{
"__type__": "wikipedia_fetcher",
"field": "context",
"process_every_value": true
},
{
"__type__": "set",
"fields": {
"context_type": "wikipedia articles"
}
}
],
"task": "tasks.qa.with_context",
"templates": "templates.qa.with_context.all",
"__description__": "FRAMES is a comprehensive evaluation dataset designed to test the capabilities of Retrieval-Augmented Generation (RAG) systems across factuality, retrieval accuracy, and reasoning.",
"__tags__": {
"annotations_creators": "expert-generated",
"arxiv": "1904.09728",
"flags": [
"NLU",
"natural language understanding"
],
"language": "en",
"language_creators": "other",
"license": "other",
"multilinguality": "monolingual",
"region": "us",
"size_categories": "10K<n<100K",
"source_datasets": "extended|other",
"task_categories": [
"text-classification",
"token-classification",
"question-answering"
],
"task_ids": [
"natural-language-inference",
"word-sense-disambiguation",
"coreference-resolution",
"extractive-qa"
]
}
}