From 2c59b8e6d079ce5782dca2315dd6b8e7bfd72728 Mon Sep 17 00:00:00 2001
From: Yoav Katz
Date: Sun, 15 Dec 2024 10:10:04 +0200
Subject: [PATCH 1/2] Revert "Revert "Allow a serializer per field in template.""

This reverts commit 88b4a4948727feeb09baf7e14ee20ddd89af9b39.
---
 examples/evaluate_rag_response_generation.py | 100 ++++++-------------
 src/unitxt/serializers.py                    |  20 ++--
 src/unitxt/templates.py                      |  38 +++++--
 tests/library/test_templates.py              |  23 ++++-
 4 files changed, 91 insertions(+), 90 deletions(-)

diff --git a/examples/evaluate_rag_response_generation.py b/examples/evaluate_rag_response_generation.py
index dd9e9cb496..3c4b3c083c 100644
--- a/examples/evaluate_rag_response_generation.py
+++ b/examples/evaluate_rag_response_generation.py
@@ -1,82 +1,48 @@
-from unitxt.api import evaluate, load_dataset
-from unitxt.blocks import (
-    TaskCard,
-)
-from unitxt.collections_operators import Wrap
-from unitxt.inference import (
-    HFPipelineBasedInferenceEngine,
-)
-from unitxt.loaders import LoadFromDictionary
-from unitxt.operators import Rename, Set
-from unitxt.templates import MultiReferenceTemplate, TemplatesDict
+from unitxt.api import create_dataset, evaluate
+from unitxt.inference import CrossProviderInferenceEngine
+from unitxt.serializers import ListSerializer
+from unitxt.templates import MultiReferenceTemplate
 from unitxt.text_utils import print_dict
 
 
 # Assume the RAG data is proved in this format
-data = {
-    "test": [
-        {
-            "query": "What city is the largest in Texas?",
-            "extracted_chunks": "Austin is the capital of Texas.\nHouston is the the largest city in Texas but not the capital of it. ",
-            "expected_answer": "Houston",
-        },
-        {
-            "query": "What city is the capital of Texas?",
-            "extracted_chunks": "Houston is the the largest city in Texas but not the capital of it. ",
-            "expected_answer": "Austin",
-        },
-    ]
-}
-
-
-card = TaskCard(
-    # Assumes this csv, contains 3 fields
-    # question (string), extracted_chunks (string), expected_answer (string)
-    loader=LoadFromDictionary(data=data),
-    # Map these fields to the fields of the task.rag.response_generation task.
-    # See https://www.unitxt.ai/en/latest/catalog/catalog.tasks.rag.response_generation.html
-    preprocess_steps=[
-        Rename(field_to_field={"query": "question"}),
-        Wrap(field="extracted_chunks", inside="list", to_field="contexts"),
-        Wrap(field="expected_answer", inside="list", to_field="reference_answers"),
-        Set(
-            fields={
-                "contexts_ids": [],
-            }
-        ),
-    ],
-    # Specify the task and the desired metrics (note that these are part of the default
-    # metrics for the task, so the metrics selection can be omitted).
-    task="tasks.rag.response_generation",
-    # Specify a default template
-    templates=TemplatesDict(
-        {
-            "simple": MultiReferenceTemplate(
-                instruction="Answer the question based on the information provided in the document given below.\n\n",
-                input_format="Document: {contexts}\nQuestion: {question}",
-                references_field="reference_answers",
-            ),
-        }
-    ),
+data = [
+    {
+        "question": "What city is the largest in Texas?",
+        "contexts": [
+            "Austin is the capital of Texas.",
+            "Houston is the the largest city in Texas but not the capital of it. ",
+        ],
+        "reference_answers": ["Houston"],
+    },
+    {
+        "question": "What city is the capital of Texas?",
+        "contexts": [
+            "Houston is the the largest city in Texas but not the capital of it. "
+        ],
+        "reference_answers": ["Austin"],
+    },
+]
+
+template = MultiReferenceTemplate(
+    instruction="Answer the question based on the information provided in the document given below.\n\n",
+    input_format="Contexts:\n\n{contexts}\n\nQuestion: {question}",
+    references_field="reference_answers",
+    serializer={"contexts": ListSerializer(separator="\n\n")},
 )
 
 # Verbalize the dataset using the template
-dataset = load_dataset(
-    card=card,
-    template_card_index="simple",
+dataset = create_dataset(
+    test_set=data,
+    template=template,
+    task="tasks.rag.response_generation",
     format="formats.chat_api",
     split="test",
-    max_test_instances=10,
 )
-
-# Infer using Llama-3.2-1B base using HF API
-engine = HFPipelineBasedInferenceEngine(
-    model_name="meta-llama/Llama-3.2-1B", max_new_tokens=32
-)
-# Change to this to infer with external APIs:
-# CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
+engine = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
 # The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]
 
+
 predictions = engine.infer(dataset)
 
 evaluated_dataset = evaluate(predictions=predictions, data=dataset)
diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py
index 5fb24abe6a..9d913982ec 100644
--- a/src/unitxt/serializers.py
+++ b/src/unitxt/serializers.py
@@ -6,7 +6,7 @@
 from .dataclass import AbstractField, Field
 from .operators import InstanceFieldOperator
 from .settings_utils import get_constants
-from .type_utils import isoftype, to_type_string
+from .type_utils import isoftype
 from .types import Dialog, Image, Number, Table, Video
 
 constants = get_constants()
@@ -26,29 +26,27 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
         return str(value)
 
 
-class SingleTypeSerializer(InstanceFieldOperator):
+class SingleTypeSerializer(Serializer):
     serialized_type: object = AbstractField()
 
     def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
         if not isoftype(value, self.serialized_type):
             raise ValueError(
-                f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {to_type_string(value)}"
+                f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {type(value)}. Value: {value}"
             )
         return self.serialize(value, instance)
 
 
-class DefaultListSerializer(Serializer):
-    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
-        if isinstance(value, list):
-            return ", ".join(str(item) for item in value)
-        return str(value)
-
-
 class ListSerializer(SingleTypeSerializer):
     serialized_type = list
+    separator: str = ", "
+    prefix: str = ""
+    suffix: str = ""
 
     def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
-        return ", ".join(str(item) for item in value)
+        return (
+            self.prefix + self.separator.join(str(item) for item in value) + self.suffix
+        )
 
 
 class DialogSerializer(SingleTypeSerializer):
diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py
index 03cc52745a..1686458423 100644
--- a/src/unitxt/templates.py
+++ b/src/unitxt/templates.py
@@ -72,7 +72,11 @@ def verify(self):
         super().verify()
         assert isoftype(
             self.postprocessors, List[Union[Operator, str]]
-        ), f"The template post processors field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."
+        ), f"The template's 'post processors' field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."
+        assert (
+            isoftype(self.serializer, Serializer)
+            or isoftype(self.serializer, Dict[str, Serializer])
+        ), f"The template's 'serializer' field '{self.serializer}' is not of type Serializer. Instead it is of type '{type(self.serializer)}'."
 
     def input_fields_to_instruction_and_target_prefix(self, input_fields):
         instruction = self.apply_formatting(
@@ -86,6 +90,22 @@ def input_fields_to_instruction_and_target_prefix(self, input_fields):
         )
         return instruction, target_prefix
 
+    def get_field_serializer(self, field: str):
+        if isoftype(self.serializer, Dict[str, Serializer]):
+            if field in self.serializer:
+                return self.serializer[field]
+        elif self.serializer is not None:
+            return self.serializer
+        return MultiTypeSerializer(
+            serializers=[
+                ImageSerializer(),
+                VideoSerializer(),
+                TableSerializer(),
+                DialogSerializer(),
+                ListSerializer(),
+            ]
+        )
+
     def preprocess_input_and_reference_fields(
         self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -162,7 +182,10 @@ def post_process_instance(self, instance):
     def serialize(
         self, data: Dict[str, Any], instance: Dict[str, Any]
     ) -> Dict[str, str]:
-        return {k: self.serializer.serialize(v, instance) for k, v in data.items()}
+        return {
+            k: self.get_field_serializer(k).process_instance_value(v, instance)
+            for k, v in data.items()
+        }
 
     @abstractmethod
     def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
@@ -729,9 +752,6 @@ def reference_fields_to_target_and_references(
 
 
 class OutputQuantizingTemplate(InputOutputTemplate):
-    serializer: MultiTypeSerializer = NonPositionalField(
-        default_factory=MultiTypeSerializer
-    )
     quantum: Union[float, int] = 0.1
 
     def prepare(self):
@@ -766,7 +786,6 @@ def preprocess_reference_fields(
 class MultiReferenceTemplate(InputOutputTemplate):
     references_field: str = "references"
     random_reference: bool = False
-    serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)
 
     def serialize(
         self, data: Dict[str, Any], instance: Dict[str, Any]
@@ -774,9 +793,12 @@ def serialize(
         result = {}
         for k, v in data.items():
             if k == self.references_field:
-                v = [self.serializer.serialize(item, instance) for item in v]
+                v = [
+                    self.get_field_serializer(k).process_instance_value(item, instance)
+                    for item in v
+                ]
             else:
-                v = self.serializer.serialize(v, instance)
+                v = self.get_field_serializer(k).process_instance_value(v, instance)
             result[k] = v
 
         return result
diff --git a/tests/library/test_templates.py b/tests/library/test_templates.py
index c1784ce505..81007ca74b 100644
--- a/tests/library/test_templates.py
+++ b/tests/library/test_templates.py
@@ -2,6 +2,7 @@
 
 from unitxt.dataclass import RequiredFieldError
 from unitxt.error_utils import UnitxtError
+from unitxt.serializers import ListSerializer
 from unitxt.templates import (
     ApplyRandomTemplate,
     ApplySingleTemplate,
@@ -404,7 +405,7 @@ def test_input_output_template_and_standard_template(self):
         template = InputOutputTemplate(
             input_format="This is my text:'{text}'",
             output_format="{label}",
-            instruction="Classify sentiment into: {labels}.\n",
+            instruction="Classify sentiment into: {labels}.",
             target_prefix="Sentiment is: ",
         )
 
@@ -442,7 +443,7 @@ def test_input_output_template_and_standard_template(self):
                 "source": "This is my text:'hello world'",
                 "target": "positive",
                 "references": ["positive"],
-                "instruction": "Classify sentiment into: positive, negative.\n",
+                "instruction": "Classify sentiment into: positive, negative.",
                 "target_prefix": "Sentiment is: ",
                 "postprocessors": ["processors.to_string_stripped"],
             },
@@ -455,7 +456,7 @@ def test_input_output_template_and_standard_template(self):
                 "source": "This is my text:'hello world\n, hell'",
                 "target": "positive",
                 "references": ["positive"],
-                "instruction": "Classify sentiment into: positive, negative.\n",
+                "instruction": "Classify sentiment into: positive, negative.",
                 "target_prefix": "Sentiment is: ",
                 "postprocessors": ["processors.to_string_stripped"],
             },
@@ -468,12 +469,26 @@ def test_input_output_template_and_standard_template(self):
                 "source": "This is my text:'hello world\n, hell'",
                 "target": "positive, 1",
                 "references": ["positive, 1"],
-                "instruction": "Classify sentiment into: positive, negative.\n",
+                "instruction": "Classify sentiment into: positive, negative.",
                 "target_prefix": "Sentiment is: ",
                 "postprocessors": ["processors.to_string_stripped"],
             },
         ]
 
+        check_operator(template, inputs, targets, tester=self)
+        # Now format with a different list serializer
+        for target in targets:
+            target["instruction"] = "Classify sentiment into: [positive/negative]."
+
+        template = InputOutputTemplate(
+            input_format="This is my text:'{text}'",
+            output_format="{label}",
+            instruction="Classify sentiment into: {labels}.",
+            target_prefix="Sentiment is: ",
+            serializer={
+                "labels": ListSerializer(separator="/", prefix="[", suffix="]")
+            },
+        )
         check_operator(template, inputs, targets, tester=self)
 
         # if "source" and "target" and "instruction_format" and "target_prefix" in instance - instance is not modified

From bbc221c3181b11564fb64ea62068597f37b37094 Mon Sep 17 00:00:00 2001
From: Yoav Katz
Date: Sun, 15 Dec 2024 10:31:16 +0200
Subject: [PATCH 2/2] Updated documentation

Signed-off-by: Yoav Katz
---
 docs/docs/types_and_serializers.rst | 32 ++++++++++++++++-------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/docs/docs/types_and_serializers.rst b/docs/docs/types_and_serializers.rst
index 78fd964a08..e64c5ad31b 100644
--- a/docs/docs/types_and_serializers.rst
+++ b/docs/docs/types_and_serializers.rst
@@ -132,21 +132,11 @@ Now if you print the input of the first instance of the dataset by ``print(datas
 Adding a Serializer to a Template
 ------------------------------------
 
-Another option is to set a default serializer for a given template. When creating a template, we need to add all the serializers for all the types we want to support. For this purpose, we use a multi-type serializer that wraps all the serializers together.
-
+Another option is to set a default serializer for a given template per field.
 .. code-block:: python
 
     from unitxt.serializers import (
-        MultiTypeSerializer, ImageSerializer, TableSerializer, DialogSerializer, ListSerializer,
-    )
-
-    serializer = MultiTypeSerializer(
-        serializers=[
-            ImageSerializer(),
-            TableSerializer(),
-            DialogSerializer(),
-            ListSerializer(),
-        ]
+        DialogSerializer
     )
 
 Now, we can add them to the template:
@@ -157,7 +147,21 @@ Now, we can add them to the template:
         instruction="Summarize the following dialog.",
         input_format="{dialog}",
         output_format="{summary}",
-        serializer=serializer
+        serializer={"dialog": DialogSerializer}
+    )
+
+As another example, we can use a customized ListSerializer, to format list of contexts with
+two newlines separators, instead of the standard comma list separator.
+
+.. code-block:: python
+
+    from unitxt.serializers import (
+    ListSerializer
+    )
+    template = MultiReferenceTemplate(
+        instruction="Answer the question based on the information provided in the document given below.\n\n",
+        input_format="Contexts:\n\n{contexts}\n\nQuestion: {question}",
+        references_field="reference_answers",
+        serializer={"contexts": ListSerializer(separator="\n\n")},
     )
 
-Important: Serializers are activated in the order they are defined, in a "first in, first serve" manner. This means that if you place the ``ListSerializer`` before the ``DialogSerializer``, the `ListSerializer` will serialize the dialog, as the ``Dialog`` is also a ``List`` and matches the type requirement of the ``ListSerializer``.
\ No newline at end of file