Add CollateInstancesByField operator to group data by specific field (#1546)

* Add GroupByProcessor operator to group data and apply custom functions

Signed-off-by: Sarath-S <[email protected]>

* Add CollateInstanceByField operator to group and aggregate data by field name

Signed-off-by: Sarath-S <[email protected]>

* Add CollateInstanceByField operator to group and aggregate data by field name

Signed-off-by: Sarath-S <[email protected]>

* Add consistency validation and test cases

Signed-off-by: Sarath-S <[email protected]>

* Fix for test case failures

Signed-off-by: Sarath-S <[email protected]>

* Fix for test case failures

Signed-off-by: Sarath-S <[email protected]>

* Added more tests and clarified error messages

Signed-off-by: Yoav Katz <[email protected]>

* More tests

Signed-off-by: Yoav Katz <[email protected]>

* Added checks and fixed a bug in data classification policy handling

Signed-off-by: Yoav Katz <[email protected]>

* Added more tests and error messages, and removed the doc_id default.

Signed-off-by: Yoav Katz <[email protected]>

* Improved documentation

Signed-off-by: Yoav Katz <[email protected]>

* Fix for handling data_classification_policy None cases

Signed-off-by: Sarath-S <[email protected]>

---------

Signed-off-by: Sarath-S <[email protected]>
Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: Sarath-S <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
4 people authored Jan 29, 2025
1 parent 6ce9541 commit 4f196bb
Showing 2 changed files with 207 additions and 0 deletions.
97 changes: 97 additions & 0 deletions src/unitxt/operators.py
@@ -67,6 +67,7 @@
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .error_utils import UnitxtError
from .generator_utils import ReusableGenerator
from .operator import (
InstanceOperator,
@@ -2243,6 +2244,102 @@ def verify(self):
)


class CollateInstancesByField(StreamOperator):
"""Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.
Args:
by_field str: the name of the field to group data by.
aggregate_fields list(str): the field names to aggregate into lists.
Returns:
A stream of instances grouped and aggregated by the specified field.
Raises:
UnitxtError: If non-aggregate fields have inconsistent values.
Example:
Collate the instances based on field "category" and aggregate fields "value" and "id".
CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])
given input:
[
{"id": 1, "category": "A", "value": 10", "flag" : True},
{"id": 2, "category": "B", "value": 20", "flag" : False},
{"id": 3, "category": "A", "value": 30", "flag" : True},
{"id": 4, "category": "B", "value": 40", "flag" : False}
]
the output is:
[
{"category": "A", "id": [1, 3], "value": [10, 30], "info": True},
{"category": "B", "id": [2, 4], "value": [20, 40], "info": False}
]
Note that the "flag" field is not aggregated, and must be the same
in all instances in the same category, or an error is raised.
"""

by_field: str = NonPositionalField(required=True)
aggregate_fields: List[str] = NonPositionalField(required=True)

def prepare(self):
super().prepare()

def verify(self):
super().verify()
if not isinstance(self.by_field, str):
raise UnitxtError(
f"The 'by_field' value is not a string but '{type(self.by_field)}'"
)

if not isinstance(self.aggregate_fields, list):
raise UnitxtError(
f"The 'allowed_field_values' is not a list but '{type(self.aggregate_fields)}'"
)

def process(self, stream: Stream, stream_name: Optional[str] = None):
grouped_data = {}

for instance in stream:
if self.by_field not in instance:
raise UnitxtError(
f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
)
for k in self.aggregate_fields:
if k not in instance:
raise UnitxtError(
f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
)
key = instance[self.by_field]

if key not in grouped_data:
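# First instance seen for this key: seed the group with every non-aggregated field
# (including the group key itself and any data_classification_policy).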
grouped_data[key] = {
k: v for k, v in instance.items() if k not in self.aggregate_fields
}
# Add empty lists for fields to aggregate
for agg_field in self.aggregate_fields:
if agg_field in instance:
grouped_data[key][agg_field] = []

for k, v in instance.items():
# Merge data classification policy lists across instances with the same key
if k == "data_classification_policy" and instance[k]:
grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
# Check consistency for all non-aggregate fields
elif k != self.by_field and k not in self.aggregate_fields:
if k in grouped_data[key] and grouped_data[key][k] != v:
raise ValueError(
f"Inconsistent value for field '{k}' in group '{key}': "
f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
)
# Aggregate fields
elif k in self.aggregate_fields:
grouped_data[key][k].append(instance[k])

yield from grouped_data.values()


class WikipediaFetcher(FieldOperator):
mode: Literal["summary", "text"] = "text"
_requirements_list = ["Wikipedia-API"]
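To make the intended usage concrete, here is a minimal sketch (not part of the commit) that feeds a plain Python list directly into the new operator's process method. Treating a list as the stream is an assumption that only works because process merely iterates its input; in a real unitxt pipeline the operator receives a proper Stream. The import path unitxt.operators is assumed from the file location shown above.

from unitxt.operators import CollateInstancesByField

# A small batch of instances; "flag" is not aggregated, so it must be
# identical within each category or the operator raises an error.
instances = [
    {"id": 1, "category": "A", "value": 10, "flag": True},
    {"id": 2, "category": "B", "value": 20, "flag": False},
    {"id": 3, "category": "A", "value": 30, "flag": True},
    {"id": 4, "category": "B", "value": 40, "flag": False},
]

op = CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])

# process() is a generator that yields one collated instance per group.
for group in op.process(instances):
    print(group)
# Expected output (key order may vary):
# {'category': 'A', 'flag': True, 'value': [10, 30], 'id': [1, 3]}
# {'category': 'B', 'flag': False, 'value': [20, 40], 'id': [2, 4]}

If the instances carry data_classification_policy lists, the second test below shows the intended behaviour: the per-group policy becomes the sorted union of the policies of all instances collated into that group.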
110 changes: 110 additions & 0 deletions tests/library/test_operators.py
@@ -12,6 +12,7 @@
ApplyStreamOperatorsField,
CastFields,
CollateInstances,
CollateInstancesByField,
Copy,
Deduplicate,
DeterministicBalancer,
@@ -2653,6 +2654,115 @@ def test_collate_instance(self):
tester=self,
)

def test_collate_instances_by_field(self):
inputs = [
{"id": 1, "category": "A", "value": 10},
{"id": 1, "category": "A", "value": 20},
{"id": 2, "category": "B", "value": 30},
{"id": 2, "category": "B", "value": 40},
]

targets = [
{"category": "A", "id": 1, "value": [10, 20]},
{"category": "B", "id": 2, "value": [30, 40]},
]

check_operator(
operator=CollateInstancesByField(
by_field="category", aggregate_fields=["value"]
),
inputs=inputs,
targets=targets,
tester=self,
)

inputs = [
{
"id": 1,
"category": "A",
"value": 10,
"data_classification_policy": ["public"],
},
{
"id": 2,
"category": "A",
"value": 20,
"data_classification_policy": ["public"],
},
{
"id": 3,
"category": "B",
"value": 30,
"data_classification_policy": ["public"],
},
{
"id": 4,
"category": "B",
"value": 40,
"data_classification_policy": ["private"],
},
]

targets = [
{
"category": "A",
"id": [1, 2],
"value": [10, 20],
"data_classification_policy": ["public"],
},
{
"category": "B",
"id": [3, 4],
"value": [30, 40],
"data_classification_policy": ["private", "public"],
},
]

check_operator(
operator=CollateInstancesByField(
by_field="category", aggregate_fields=["value", "id"]
),
inputs=inputs,
targets=targets,
tester=self,
)

exception_texts = [
"Inconsistent value for field 'id' in group 'A': '1' vs '2'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances.",
]
check_operator_exception(
operator=CollateInstancesByField(
by_field="category", aggregate_fields=["value"]
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

exception_texts = [
"The field 'not_exist' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
]
check_operator_exception(
operator=CollateInstancesByField(
by_field="not_exist", aggregate_fields=["value"]
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

exception_texts = [
"The field 'not_exist' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
]
check_operator_exception(
operator=CollateInstancesByField(
by_field="category", aggregate_fields=["id", "value", "not_exist"]
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)


class TestApplyMetric(UnitxtTestCase):
def _test_apply_metric(
