diff --git a/prepare/cards/wikitq.py b/prepare/cards/wikitq.py
index c0b127d1a..ea1674d0e 100644
--- a/prepare/cards/wikitq.py
+++ b/prepare/cards/wikitq.py
@@ -3,7 +3,8 @@
     TaskCard,
 )
 from unitxt.catalog import add_to_catalog
-from unitxt.operators import Copy, Set
+from unitxt.operators import Copy, FilterByCondition, Set
+from unitxt.struct_data_operators import GetNumOfTableCells
 from unitxt.templates import MultiReferenceTemplate
 from unitxt.test_utils.card import test_card
 
@@ -14,6 +15,10 @@
     ),
     preprocess_steps=[
         Set({"context_type": "table"}),
+        GetNumOfTableCells(field="table", to_field="table_cell_size"),
+        FilterByCondition(
+            values={"table_cell_size": 200}, condition="le"
+        ),  # filter out tables with more than 200 cells
         Copy(field="table", to_field="context"),
         # TruncateTableRows(field="table", to_field="context"),
     ],
diff --git a/src/unitxt/catalog/cards/wikitq.json b/src/unitxt/catalog/cards/wikitq.json
index d6e27ae45..652c77e0a 100644
--- a/src/unitxt/catalog/cards/wikitq.json
+++ b/src/unitxt/catalog/cards/wikitq.json
@@ -15,6 +15,18 @@
                 "context_type": "table"
             }
         },
+        {
+            "__type__": "get_num_of_table_cells",
+            "field": "table",
+            "to_field": "table_cell_size"
+        },
+        {
+            "__type__": "filter_by_condition",
+            "values": {
+                "table_cell_size": 200
+            },
+            "condition": "le"
+        },
         {
             "__type__": "copy",
             "field": "table",
diff --git a/src/unitxt/struct_data_operators.py b/src/unitxt/struct_data_operators.py
index 50c7ee809..ac5bb9d6d 100644
--- a/src/unitxt/struct_data_operators.py
+++ b/src/unitxt/struct_data_operators.py
@@ -517,6 +517,21 @@ def truncate_table_rows(self, table_content: Dict):
         return table_content
 
 
+class GetNumOfTableCells(FieldOperator):
+    """Get the number of cells in the given table.
+
+    The count is ``len(rows) * len(header)``. A "rows" or "header" entry
+    that is missing or ``None`` is treated as empty, so the result is 0.
+    """
+
+    def process_value(self, table: Dict[str, Any]) -> int:
+        # ``or ()`` covers both an absent key and an explicit None value,
+        # either of which would otherwise make len() raise TypeError.
+        num_of_rows = len(table.get("rows") or ())
+        num_of_cols = len(table.get("header") or ())
+        return num_of_rows * num_of_cols
+
+
 class SerializeTableRowAsText(InstanceOperator):
     """Serializes a table row as text.
 