Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a serializer per field in template. #1438

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions docs/docs/types_and_serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,11 @@ Now if you print the input of the first instance of the dataset by ``print(datas
Adding a Serializer to a Template
------------------------------------

Another option is to set a default serializer for a given template, on a per-field basis. When creating a template, we map each field name to the serializer that should handle it.

.. code-block:: python

from unitxt.serializers import DialogSerializer

Now, we can add it to the template:
Expand All @@ -157,7 +147,21 @@ Now, we can add them to the template:
instruction="Summarize the following dialog.",
input_format="{dialog}",
output_format="{summary}",
serializer=serializer
serializer={"dialog": DialogSerializer()}
)

As another example, we can use a customized ListSerializer, to format list of contexts with
two newlines separators, instead of the standard comma list separator.

.. code-block:: python

from unitxt.serializers import (
ListSerializer
)
template = MultiReferenceTemplate(
instruction="Answer the question based on the information provided in the document given below.\n\n",
input_format="Contexts:\n\n{contexts}\n\nQuestion: {question}",
references_field="reference_answers",
serializer={"contexts": ListSerializer(separator="\n\n")},
)

Important: Serializers are activated in the order they are defined, in a "first come, first served" manner. This means that if you place the ``ListSerializer`` before the ``DialogSerializer``, the ``ListSerializer`` will serialize the dialog, as the ``Dialog`` is also a ``List`` and matches the type requirement of the ``ListSerializer``.
100 changes: 33 additions & 67 deletions examples/evaluate_rag_response_generation.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,48 @@
from unitxt.api import evaluate, load_dataset
from unitxt.blocks import (
TaskCard,
)
from unitxt.collections_operators import Wrap
from unitxt.inference import (
HFPipelineBasedInferenceEngine,
)
from unitxt.loaders import LoadFromDictionary
from unitxt.operators import Rename, Set
from unitxt.templates import MultiReferenceTemplate, TemplatesDict
from unitxt.api import create_dataset, evaluate
from unitxt.inference import CrossProviderInferenceEngine
from unitxt.serializers import ListSerializer
from unitxt.templates import MultiReferenceTemplate
from unitxt.text_utils import print_dict

# Assume the RAG data is proved in this format
data = {
"test": [
{
"query": "What city is the largest in Texas?",
"extracted_chunks": "Austin is the capital of Texas.\nHouston is the the largest city in Texas but not the capital of it. ",
"expected_answer": "Houston",
},
{
"query": "What city is the capital of Texas?",
"extracted_chunks": "Houston is the the largest city in Texas but not the capital of it. ",
"expected_answer": "Austin",
},
]
}


card = TaskCard(
# Assumes this csv, contains 3 fields
# question (string), extracted_chunks (string), expected_answer (string)
loader=LoadFromDictionary(data=data),
# Map these fields to the fields of the task.rag.response_generation task.
# See https://www.unitxt.ai/en/latest/catalog/catalog.tasks.rag.response_generation.html
preprocess_steps=[
Rename(field_to_field={"query": "question"}),
Wrap(field="extracted_chunks", inside="list", to_field="contexts"),
Wrap(field="expected_answer", inside="list", to_field="reference_answers"),
Set(
fields={
"contexts_ids": [],
}
),
],
# Specify the task and the desired metrics (note that these are part of the default
# metrics for the task, so the metrics selection can be omitted).
task="tasks.rag.response_generation",
# Specify a default template
templates=TemplatesDict(
{
"simple": MultiReferenceTemplate(
instruction="Answer the question based on the information provided in the document given below.\n\n",
input_format="Document: {contexts}\nQuestion: {question}",
references_field="reference_answers",
),
}
),
data = [
{
"question": "What city is the largest in Texas?",
"contexts": [
"Austin is the capital of Texas.",
"Houston is the the largest city in Texas but not the capital of it. ",
],
"reference_answers": ["Houston"],
},
{
"question": "What city is the capital of Texas?",
"contexts": [
"Houston is the the largest city in Texas but not the capital of it. "
],
"reference_answers": ["Austin"],
},
]

template = MultiReferenceTemplate(
instruction="Answer the question based on the information provided in the document given below.\n\n",
input_format="Contexts:\n\n{contexts}\n\nQuestion: {question}",
references_field="reference_answers",
serializer={"contexts": ListSerializer(separator="\n\n")},
)

# Verbalize the dataset using the template
dataset = load_dataset(
card=card,
template_card_index="simple",
dataset = create_dataset(
test_set=data,
template=template,
task="tasks.rag.response_generation",
format="formats.chat_api",
split="test",
max_test_instances=10,
)


# Infer using Llama-3.2-1B base using HF API
engine = HFPipelineBasedInferenceEngine(
model_name="meta-llama/Llama-3.2-1B", max_new_tokens=32
)
# Change to this to infer with external APIs:
# CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
engine = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]


predictions = engine.infer(dataset)
evaluated_dataset = evaluate(predictions=predictions, data=dataset)

Expand Down
20 changes: 9 additions & 11 deletions src/unitxt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dataclass import AbstractField, Field
from .operators import InstanceFieldOperator
from .settings_utils import get_constants
from .type_utils import isoftype, to_type_string
from .type_utils import isoftype
from .types import Dialog, Image, Number, Table, Video

constants = get_constants()
Expand All @@ -26,29 +26,27 @@ def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
return str(value)


class SingleTypeSerializer(Serializer):
    """Base class for serializers that accept exactly one type.

    Subclasses set ``serialized_type`` and implement ``serialize``; this class
    validates the incoming value's type before delegating to ``serialize``.
    """

    # The single type this serializer handles; must be provided by subclasses.
    serialized_type: object = AbstractField()

    def process_instance_value(self, value: Any, instance: Dict[str, Any]) -> str:
        # Fail fast with a descriptive error when the value does not match
        # the declared type (includes the offending value for debugging).
        if not isoftype(value, self.serialized_type):
            raise ValueError(
                f"SingleTypeSerializer for type {self.serialized_type} should get this type. got {type(value)}. Value: {value}"
            )
        return self.serialize(value, instance)


class DefaultListSerializer(Serializer):
    """Fallback serializer: joins list items with ", "; any other value is stringified."""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        # Non-list values simply use their string representation.
        if not isinstance(value, list):
            return str(value)
        return ", ".join(map(str, value))


class ListSerializer(SingleTypeSerializer):
    """Serializes a list by joining the stringified items with ``separator``,
    wrapped between ``prefix`` and ``suffix``.

    Defaults reproduce the classic ", "-joined rendering with no wrapping.
    """

    serialized_type = list
    separator: str = ", "
    prefix: str = ""
    suffix: str = ""

    def serialize(self, value: Any, instance: Dict[str, Any]) -> str:
        return (
            self.prefix + self.separator.join(str(item) for item in value) + self.suffix
        )


class DialogSerializer(SingleTypeSerializer):
Expand Down
38 changes: 30 additions & 8 deletions src/unitxt/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def verify(self):
super().verify()
assert isoftype(
self.postprocessors, List[Union[Operator, str]]
), f"The template post processors field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."
), f"The template's 'post processors' field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."
assert (
isoftype(self.serializer, Serializer)
or isoftype(self.serializer, Dict[str, Serializer])
), f"The template's 'serializer' field '{self.serializer}' is not of type Serializer. Instead it is of type '{type(self.serializer)}'."

def input_fields_to_instruction_and_target_prefix(self, input_fields):
instruction = self.apply_formatting(
Expand All @@ -86,6 +90,22 @@ def input_fields_to_instruction_and_target_prefix(self, input_fields):
)
return instruction, target_prefix

def get_field_serializer(self, field: str):
    """Return the serializer to use for *field*.

    Resolution order:
    1. If ``self.serializer`` is a per-field mapping and contains this field,
       use that entry.
    2. Otherwise, if a single (non-dict) serializer was configured, use it
       for every field.
    3. Otherwise, fall back to a default ``MultiTypeSerializer`` covering the
       common types.
    """
    if isoftype(self.serializer, Dict[str, Serializer]):
        if field in self.serializer:
            return self.serializer[field]
        # NOTE(review): fields absent from the mapping fall through to the
        # default MultiTypeSerializer below.
    elif self.serializer is not None:
        return self.serializer
    # Default fallback. Order matters: serializers are tried "first come,
    # first served", so ListSerializer is last (a Dialog is also a List).
    return MultiTypeSerializer(
        serializers=[
            ImageSerializer(),
            VideoSerializer(),
            TableSerializer(),
            DialogSerializer(),
            ListSerializer(),
        ]
    )

def preprocess_input_and_reference_fields(
self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand Down Expand Up @@ -162,7 +182,10 @@ def post_process_instance(self, instance):
def serialize(
    self, data: Dict[str, Any], instance: Dict[str, Any]
) -> Dict[str, str]:
    """Serialize every value in *data*, resolving a serializer per field.

    Each field is routed through ``get_field_serializer`` so that per-field
    serializer mappings (or the template-wide default) are honored.
    """
    return {
        k: self.get_field_serializer(k).process_instance_value(v, instance)
        for k, v in data.items()
    }

@abstractmethod
def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
Expand Down Expand Up @@ -729,9 +752,6 @@ def reference_fields_to_target_and_references(


class OutputQuantizingTemplate(InputOutputTemplate):
serializer: MultiTypeSerializer = NonPositionalField(
default_factory=MultiTypeSerializer
)
quantum: Union[float, int] = 0.1

def prepare(self):
Expand Down Expand Up @@ -766,17 +786,19 @@ def preprocess_reference_fields(
class MultiReferenceTemplate(InputOutputTemplate):
references_field: str = "references"
random_reference: bool = False
serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)

def serialize(
    self, data: Dict[str, Any], instance: Dict[str, Any]
) -> Dict[str, str]:
    """Serialize field values, keeping the references field as a list.

    The references field holds multiple gold answers, so each item is
    serialized individually and the field stays a list of strings; every
    other field is serialized to a single string.
    """
    result = {}
    for k, v in data.items():
        if k == self.references_field:
            # Serialize each reference on its own so the field remains a list.
            v = [
                self.get_field_serializer(k).process_instance_value(item, instance)
                for item in v
            ]
        else:
            v = self.get_field_serializer(k).process_instance_value(v, instance)
        result[k] = v
    return result

Expand Down
23 changes: 19 additions & 4 deletions tests/library/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from unitxt.dataclass import RequiredFieldError
from unitxt.error_utils import UnitxtError
from unitxt.serializers import ListSerializer
from unitxt.templates import (
ApplyRandomTemplate,
ApplySingleTemplate,
Expand Down Expand Up @@ -404,7 +405,7 @@ def test_input_output_template_and_standard_template(self):
template = InputOutputTemplate(
input_format="This is my text:'{text}'",
output_format="{label}",
instruction="Classify sentiment into: {labels}.\n",
instruction="Classify sentiment into: {labels}.",
target_prefix="Sentiment is: ",
)

Expand Down Expand Up @@ -442,7 +443,7 @@ def test_input_output_template_and_standard_template(self):
"source": "This is my text:'hello world'",
"target": "positive",
"references": ["positive"],
"instruction": "Classify sentiment into: positive, negative.\n",
"instruction": "Classify sentiment into: positive, negative.",
"target_prefix": "Sentiment is: ",
"postprocessors": ["processors.to_string_stripped"],
},
Expand All @@ -455,7 +456,7 @@ def test_input_output_template_and_standard_template(self):
"source": "This is my text:'hello world\n, hell'",
"target": "positive",
"references": ["positive"],
"instruction": "Classify sentiment into: positive, negative.\n",
"instruction": "Classify sentiment into: positive, negative.",
"target_prefix": "Sentiment is: ",
"postprocessors": ["processors.to_string_stripped"],
},
Expand All @@ -468,12 +469,26 @@ def test_input_output_template_and_standard_template(self):
"source": "This is my text:'hello world\n, hell'",
"target": "positive, 1",
"references": ["positive, 1"],
"instruction": "Classify sentiment into: positive, negative.\n",
"instruction": "Classify sentiment into: positive, negative.",
"target_prefix": "Sentiment is: ",
"postprocessors": ["processors.to_string_stripped"],
},
]

check_operator(template, inputs, targets, tester=self)
# Now format with a different list serializer
for target in targets:
target["instruction"] = "Classify sentiment into: [positive/negative]."

template = InputOutputTemplate(
input_format="This is my text:'{text}'",
output_format="{label}",
instruction="Classify sentiment into: {labels}.",
target_prefix="Sentiment is: ",
serializer={
"labels": ListSerializer(separator="/", prefix="[", suffix="]")
},
)
check_operator(template, inputs, targets, tester=self)

# if "source" and "target" and "instruction_format" and "target_prefix" in instance - instance is not modified
Expand Down
Loading