From 4f196bb6a396cb30c457ae3f1e63ad3bed30ba85 Mon Sep 17 00:00:00 2001
From: Sarath S <47180054+sarathsgvr@users.noreply.github.com>
Date: Wed, 29 Jan 2025 19:15:58 +0530
Subject: [PATCH] Add CollateInstanceByField operator to group data by
 specific field (#1546)

* Add GroupByProcessor operator to group data and apply custom functions

Signed-off-by: Sarath-S

* Add CollateInstanceByField operator to group and aggregate data by field name

Signed-off-by: Sarath-S

* Add CollateInstanceByField operator to group and aggregate data by field name

Signed-off-by: Sarath-S

* Add consistency validation and test cases

Signed-off-by: Sarath-S

* fix for test case failures

Signed-off-by: Sarath-S

* fix for test case failures

Signed-off-by: Sarath-S

* Added more tests and clarified error messages

Signed-off-by: Yoav Katz

* More tests

Signed-off-by: Yoav Katz

* Added checks and fixed bug in data classification policy handling

Signed-off-by: Yoav Katz

* Added more tests and error message and remove doc_id default.

Signed-off-by: Yoav Katz

* Improved documentation

Signed-off-by: Yoav Katz

* Fix for handling data_classification_policy None cases

Signed-off-by: Sarath-S

---------

Signed-off-by: Sarath-S
Signed-off-by: Yoav Katz
Co-authored-by: Sarath-S
Co-authored-by: Yoav Katz
Co-authored-by: Yoav Katz <68273864+yoavkatz@users.noreply.github.com>
---
 src/unitxt/operators.py         |  97 ++++++++++++++++++++++++++++
 tests/library/test_operators.py | 110 ++++++++++++++++++++++++++++++++
 2 files changed, 207 insertions(+)

diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py
index 6b9e444a83..999a188efd 100644
--- a/src/unitxt/operators.py
+++ b/src/unitxt/operators.py
@@ -67,6 +67,7 @@
 from .dataclass import NonPositionalField, OptionalField
 from .deprecation_utils import deprecation
 from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
+from .error_utils import UnitxtError
 from .generator_utils import ReusableGenerator
 from .operator import (
     InstanceOperator,
@@ -2243,6 +2244,102 @@ def verify(self):
         )
 
 
+class CollateInstancesByField(StreamOperator):
+    """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.
+
+    Args:
+        by_field (str): the name of the field to group data by.
+        aggregate_fields (List[str]): the field names to aggregate into lists.
+
+    Returns:
+        A stream of instances grouped and aggregated by the specified field.
+
+    Raises:
+        UnitxtError: If non-aggregated fields have inconsistent values within a group.
+
+    Example:
+        Collate the instances based on field "category" and aggregate fields "value" and "id".
+
+        CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])
+
+        given input:
+        [
+            {"id": 1, "category": "A", "value": 10, "flag": True},
+            {"id": 2, "category": "B", "value": 20, "flag": False},
+            {"id": 3, "category": "A", "value": 30, "flag": True},
+            {"id": 4, "category": "B", "value": 40, "flag": False}
+        ]
+
+        the output is:
+        [
+            {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
+            {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
+        ]
+
+        Note that the "flag" field is not aggregated, and must be the same
+        in all instances in the same category, or an error is raised.
+    """
+
+    by_field: str = NonPositionalField(required=True)
+    aggregate_fields: List[str] = NonPositionalField(required=True)
+
+    def prepare(self):
+        super().prepare()
+
+    def verify(self):
+        super().verify()
+        if not isinstance(self.by_field, str):
+            raise UnitxtError(
+                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
+            )
+
+        if not isinstance(self.aggregate_fields, list):
+            raise UnitxtError(
+                f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
+            )
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None):
+        grouped_data = {}
+
+        for instance in stream:
+            if self.by_field not in instance:
+                raise UnitxtError(
+                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
+                )
+            for k in self.aggregate_fields:
+                if k not in instance:
+                    raise UnitxtError(
+                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
+                    )
+            key = instance[self.by_field]
+
+            if key not in grouped_data:
+                grouped_data[key] = {
+                    k: v for k, v in instance.items() if k not in self.aggregate_fields
+                }
+                # Add empty lists for fields to aggregate
+                for agg_field in self.aggregate_fields:
+                    if agg_field in instance:
+                        grouped_data[key][agg_field] = []
+
+            for k, v in instance.items():
+                # Merge classification policy list across instances with the same key
+                if k == "data_classification_policy" and instance[k]:
+                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
+                # Check consistency for all non-aggregated fields
+                elif k != self.by_field and k not in self.aggregate_fields:
+                    if k in grouped_data[key] and grouped_data[key][k] != v:
+                        raise UnitxtError(
+                            f"Inconsistent value for field '{k}' in group '{key}': "
+                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
+                        )
+                # Aggregate fields
+                elif k in self.aggregate_fields:
+                    grouped_data[key][k].append(instance[k])
+
+        yield from grouped_data.values()
+
+
 class WikipediaFetcher(FieldOperator):
     mode: Literal["summary", "text"] = "text"
     _requirements_list = ["Wikipedia-API"]
diff --git a/tests/library/test_operators.py b/tests/library/test_operators.py
index af038dd654..9fb9b688b0 100644
--- a/tests/library/test_operators.py
+++ b/tests/library/test_operators.py
@@ -12,6 +12,7 @@
     ApplyStreamOperatorsField,
     CastFields,
     CollateInstances,
+    CollateInstancesByField,
     Copy,
     Deduplicate,
     DeterministicBalancer,
@@ -2653,6 +2654,115 @@ def test_collate_instance(self):
             tester=self,
         )
 
+    def test_collate_instances_by_field(self):
+        inputs = [
+            {"id": 1, "category": "A", "value": 10},
+            {"id": 1, "category": "A", "value": 20},
+            {"id": 2, "category": "B", "value": 30},
+            {"id": 2, "category": "B", "value": 40},
+        ]
+
+        targets = [
+            {"category": "A", "id": 1, "value": [10, 20]},
+            {"category": "B", "id": 2, "value": [30, 40]},
+        ]
+
+        check_operator(
+            operator=CollateInstancesByField(
+                by_field="category", aggregate_fields=["value"]
+            ),
+            inputs=inputs,
+            targets=targets,
+            tester=self,
+        )
+
+        inputs = [
+            {
+                "id": 1,
+                "category": "A",
+                "value": 10,
+                "data_classification_policy": ["public"],
+            },
+            {
+                "id": 2,
+                "category": "A",
+                "value": 20,
+                "data_classification_policy": ["public"],
+            },
+            {
+                "id": 3,
+                "category": "B",
+                "value": 30,
+                "data_classification_policy": ["public"],
+            },
+            {
+                "id": 4,
+                "category": "B",
+                "value": 40,
+                "data_classification_policy": ["private"],
+            },
+        ]
+
+        targets = [
+            {
+                "category": "A",
+                "id": [1, 2],
+                "value": [10, 20],
+                "data_classification_policy": ["public"],
+            },
+            {
+                "category": "B",
+                "id": [3, 4],
+                "value": [30, 40],
+                "data_classification_policy": ["private", "public"],
+            },
+        ]
+
+        check_operator(
+            operator=CollateInstancesByField(
+                by_field="category", aggregate_fields=["value", "id"]
+            ),
+            inputs=inputs,
+            targets=targets,
+            tester=self,
+        )
+
+        exception_texts = [
+            "Inconsistent value for field 'id' in group 'A': '1' vs '2'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances.",
+        ]
+        check_operator_exception(
+            operator=CollateInstancesByField(
+                by_field="category", aggregate_fields=["value"]
+            ),
+            inputs=inputs,
+            exception_texts=exception_texts,
+            tester=self,
+        )
+
+        exception_texts = [
+            "The field 'not_exist' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
+        ]
+        check_operator_exception(
+            operator=CollateInstancesByField(
+                by_field="not_exist", aggregate_fields=["value"]
+            ),
+            inputs=inputs,
+            exception_texts=exception_texts,
+            tester=self,
+        )
+
+        exception_texts = [
+            "The field 'not_exist' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
+        ]
+        check_operator_exception(
+            operator=CollateInstancesByField(
+                by_field="category", aggregate_fields=["id", "value", "not_exist"]
+            ),
+            inputs=inputs,
+            exception_texts=exception_texts,
+            tester=self,
+        )
+
 
 class TestApplyMetric(UnitxtTestCase):
     def _test_apply_metric(