Merge branch 'main' into fix/correct-choice-position-handling
elronbandel authored Jan 16, 2025
2 parents b359ab4 + 9777799 commit 725f128
Showing 6 changed files with 91 additions and 13 deletions.
10 changes: 9 additions & 1 deletion docs/docs/evaluating_datasets.rst
@@ -107,4 +107,12 @@ Will print:

| templates.key_val | 1 | 0.222222 | f1_micro | 0.0 | 0.7225818346056374 | 7 |
| templates.classification.multi_class.relation.default | 3 | 0.285714 | f1_micro | 0.0 | 0.779447856172277 | 6 |
| templates.classification.multi_class.relation.default | 0 | 0.181818 | f1_micro | 0.0 | 0.4105379478071894 | 19 |
| templates.key_val | 0 | 0 | f1_micro | | | 7 |

Metadata
--------
The result object returned by the `evaluate` function contains a `metadata` field.
This field holds the dataset metadata and the inference engine metadata (when available).

This metadata can be accessed and used for further analysis or debugging.
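
For example, a minimal sketch of reading this metadata, assuming `predictions` and `dataset` were produced by the earlier steps on this page:

.. code-block:: python

    results = evaluate(predictions=predictions, data=dataset)

    # Dataset metadata, copied from dataset.info.description when present:
    print(results.metadata.get("dataset"))
    # Inference engine metadata, present when predictions carry a metadata attribute:
    print(results.metadata.get("predictions"))
    # Timestamp added by evaluate:
    print(results.metadata.get("creation_time"))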
14 changes: 14 additions & 0 deletions docs/docs/loading_datasets.rst
@@ -160,3 +160,17 @@ Here is an example of using random templates and a varying number of demonstrations:
.. code-block:: python

    dataset = load_dataset(
        card="cards.wnli",
        template=["templates.classification.multi_class.relation.default", "templates.key_val"],
        num_demos=[0, 1, 3],
        demos_pool_size=100,
    )

Metadata
--------
The result of the `load_dataset` function carries a metadata object. If the result is a Dataset or IterableDataset, the
metadata is saved under the path `info.description`. If the result is a dict of datasets, each dataset holds the
metadata at the same path. The metadata is a dictionary with information about the execution, including:

* All parameters passed to the `load_dataset` function
* Execution time
* Other relevant metadata

This metadata can be accessed and used for further analysis or debugging.
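
A minimal sketch of reading this metadata, assuming the default case where `load_dataset` returns a dict of splits:

.. code-block:: python

    dataset = load_dataset(card="cards.wnli", template="templates.key_val")

    # A dict of splits stores the metadata once per split; a single Dataset or
    # IterableDataset stores it directly under info.description.
    if isinstance(dataset, dict):
        for split_name, split in dataset.items():
            print(split_name, split.info.description)
    else:
        print(dataset.info.description)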

39 changes: 34 additions & 5 deletions src/unitxt/api.py
@@ -1,4 +1,6 @@
+import inspect
import json
+from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

@@ -190,13 +192,32 @@ def load_dataset(
    disable_cache = settings.disable_hf_datasets_cache

    if streaming:
-        return stream.to_iterable_dataset(
+        dataset = stream.to_iterable_dataset(
            features=UNITXT_DATASET_SCHEMA,
        ).map(loads_instance, batched=True)
+    else:
+        dataset = stream.to_dataset(
+            features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
+        ).with_transform(loads_instance)
+
+    frame = inspect.currentframe()
+    args, _, _, values = inspect.getargvalues(frame)
+    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
+    all_kwargs.update(kwargs)
+    metadata = fill_metadata(**all_kwargs)
+    if isinstance(dataset, dict):
+        for ds in dataset.values():
+            ds.info.description = metadata.copy()
+    else:
+        dataset.info.description = metadata
+    return dataset


-    return stream.to_dataset(
-        features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
-    ).with_transform(loads_instance)
+def fill_metadata(**kwargs):
+    metadata = kwargs.copy()
+    metadata["unitxt_version"] = get_constants().version
+    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+    return metadata
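
The inspect.currentframe() / inspect.getargvalues() pair above records every argument the enclosing function received. A standalone sketch of the same pattern, using a hypothetical function purely for illustration:

import inspect


def example(a, b=2, **kwargs):
    frame = inspect.currentframe()
    # getargvalues returns the named parameters and local variables of the frame.
    args, _, _, values = inspect.getargvalues(frame)
    captured = {key: values[key] for key in args if key != "kwargs"}
    captured.update(kwargs)  # fold in any extra keyword arguments
    return captured


print(example(1, c=3))  # prints {'a': 1, 'b': 2, 'c': 3}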


def evaluate(
@@ -206,7 +227,15 @@ def evaluate(
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
-    return _compute(predictions=predictions, references=dataset)
+    evaluation_result = _compute(predictions=predictions, references=dataset)
+    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
+        evaluation_result.metadata["dataset"] = dataset.info.description
+    if hasattr(predictions, "metadata"):
+        evaluation_result.metadata["predictions"] = predictions.metadata
+    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
+        "%Y-%m-%d %H:%M:%S.%f"
+    )[:-3]
+    return evaluation_result
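
Taken together, metadata now flows from load_dataset through the inference engine into the evaluation result. A sketch of the full loop; the engine and card choices are illustrative, not prescribed by this commit:

from unitxt import evaluate, load_dataset
from unitxt.inference import HFPipelineBasedInferenceEngine

dataset = load_dataset(card="cards.wnli", template="templates.key_val", split="test")
engine = HFPipelineBasedInferenceEngine(model_name="google/flan-t5-small", max_new_tokens=32)

predictions = engine.infer(dataset)  # a ListWithMetadata, see inference.py below
results = evaluate(predictions=predictions, data=dataset)

print(results.metadata["dataset"])      # from dataset.info.description
print(results.metadata["predictions"])  # from the engine's ListWithMetadata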


def post_process(predictions, data) -> List[Dict[str, Any]]:
31 changes: 27 additions & 4 deletions src/unitxt/inference.py
@@ -9,6 +9,7 @@
import time
import uuid
from collections import Counter
+from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import (
    Any,
@@ -21,6 +22,7 @@
    Sequence,
    Tuple,
    TypedDict,
+    TypeVar,
    Union,
)

@@ -131,6 +133,18 @@ class TextGenerationInferenceOutput:
    inference_type: Optional[str] = None


+T = TypeVar("T")
+
+
+class ListWithMetadata(List[T]):
+    def __init__(self, *args, metadata: Optional[dict] = None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metadata = metadata if metadata is not None else {}
+
+    def __repr__(self):
+        return f"ListWithMetadata(data={super().__repr__()}, metadata={self.metadata})"
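
A quick sketch of the behavior: the subclass acts exactly like a list while carrying a side-channel dict (the values here are illustrative):

predictions = ListWithMetadata(["entailment", "not entailment"], metadata={"inference_engine_type": "MockEngine"})
assert predictions[0] == "entailment" and len(predictions) == 2  # plain list behavior
print(predictions.metadata["inference_engine_type"])  # the attached metadata

Note that operations that build new lists (slicing, concatenation) return plain lists, so the metadata travels only with the original object.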


class InferenceEngine(Artifact):
"""Abstract base class for inference."""

@@ -162,14 +176,14 @@ def __call__(
        self,
        dataset: Union[List[Dict[str, Any]], Dataset],
        return_meta_data: bool = False,
-    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
        return self.infer(dataset=dataset, return_meta_data=return_meta_data)

    def infer(
        self,
        dataset: Union[List[Dict[str, Any]], Dataset],
        return_meta_data: bool = False,
-    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
+    ) -> Union[ListWithMetadata[str], ListWithMetadata[TextGenerationInferenceOutput]]:
        """Verifies the instances of a dataset and performs inference on the input dataset.

        If return_meta_data is True, returns a list of TextGenerationInferenceOutput; else returns a list of strings.
@@ -187,8 +201,17 @@ def infer(

        [self.verify_instance(instance) for instance in dataset]
        if settings.mock_inference_mode:
-            return self._mock_infer(dataset)
-        return self._infer(dataset, return_meta_data)
+            result = self._mock_infer(dataset)
+        else:
+            result = self._infer(dataset, return_meta_data)
+        return ListWithMetadata(
+            result,
+            metadata={
+                "init_dict": self._init_dict,
+                "inference_engine_type": self.__class__.__name__,
+                "creation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
+            },
+        )
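
Every call to infer, mock or real, thus stamps its output with the engine's constructor arguments, class name, and a millisecond-resolution timestamp. A sketch of reading them back, assuming engine is any configured InferenceEngine:

predictions = engine.infer(dataset)
print(predictions.metadata["inference_engine_type"])  # the engine class name
print(predictions.metadata["init_dict"])              # parameters the engine was created with
print(predictions.metadata["creation_time"])          # e.g. "2025-01-16 10:00:00.000"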

    def _mock_infer(
        self,
4 changes: 4 additions & 0 deletions src/unitxt/metric_utils.py
@@ -699,6 +699,10 @@ def __repr__(self):


class EvaluationResults(list):
+    def __init__(self, *args, metadata=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.metadata = metadata if metadata is not None else {}
+
    @property
    def global_scores(self):
        return GlobalScores(self[0]["score"]["global"])
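
A minimal sketch of the resulting behavior, with hypothetical per-instance scores:

scores = [{"score": {"global": {"score": 0.5, "f1_micro": 0.5}}}]
results = EvaluationResults(scores, metadata={"creation_time": "2025-01-16 10:00:00.000"})
print(results[0]["score"]["global"])       # still indexes like a plain list
print(results.metadata["creation_time"])   # run-level metadata rides alongside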
6 changes: 3 additions & 3 deletions utils/.secrets.baseline
@@ -133,15 +133,15 @@
      "filename": "src/unitxt/inference.py",
      "hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
      "is_verified": false,
-      "line_number": 1258,
+      "line_number": 1281,
      "is_secret": false
    },
    {
      "type": "Secret Keyword",
      "filename": "src/unitxt/inference.py",
      "hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
      "is_verified": false,
-      "line_number": 1743,
+      "line_number": 1766,
      "is_secret": false
    }
  ],
@@ -184,5 +184,5 @@
      }
    ]
  },
-  "generated_at": "2025-01-13T09:29:10Z"
+  "generated_at": "2025-01-15T12:35:17Z"
}
