diff --git a/evadb/expression/abstract_expression.py b/evadb/expression/abstract_expression.py index 9b72f32e6..53a128723 100644 --- a/evadb/expression/abstract_expression.py +++ b/evadb/expression/abstract_expression.py @@ -48,6 +48,7 @@ class ExpressionType(IntEnum): FUNCTION_EXPRESSION = auto() + # Aggregation functions AGGREGATION_COUNT = auto() AGGREGATION_SUM = auto() AGGREGATION_MIN = auto() @@ -56,6 +57,7 @@ class ExpressionType(IntEnum): AGGREGATION_FIRST = auto() AGGREGATION_LAST = auto() AGGREGATION_SEGMENT = auto() + AGGREGATION_STRING_AGG = auto() CASE = auto() # add other types diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index f1ba6b16c..cb8932a37 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -54,6 +54,10 @@ def evaluate(self, *args, **kwargs): batch.aggregate("min") elif self.etype == ExpressionType.AGGREGATION_MAX: batch.aggregate("max") + elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: + delimiter = kwargs.get("delimiter") + batch.aggregate_string_agg(delimiter) + batch.reset_index() column_name = self.etype.name @@ -82,6 +86,8 @@ def get_symbol(self) -> str: return "MIN" elif self.etype == ExpressionType.AGGREGATION_MAX: return "MAX" + elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: + return "STRING_AGG" else: raise NotImplementedError diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 43e69cc4f..975efa31e 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -376,6 +376,21 @@ def aggregate(self, method: str) -> None: """ self._frames = self._frames.agg([method]) + def aggregate_string_agg(self, delimiter: str) -> None: + """ + Join every column into one delimiter-separated string; null cells are skipped. + + Arguments: + delimiter (str): Separator placed between values; a non-str raises ValueError.
+ """ + if not isinstance(delimiter, str): + raise ValueError("Delimiter must be a string") + + aggregated_data = { + c: [delimiter.join(s.dropna().astype(str))] for c, s in self._frames.items() + } + self._frames = pd.DataFrame(aggregated_data) + def empty(self): """Checks if the batch is empty Returns: diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 4b96bf647..d95c5776c 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -300,11 +300,12 @@ function_call: function ->function_call function: simple_id "(" (STAR | function_args)? ")" dotted_id? -aggregate_windowed_function: aggregate_function_name "(" function_arg ")" +aggregate_windowed_function: STRING_AGG "(" function_arg "," function_arg ")" + | aggregate_function_name "(" function_arg ")" | COUNT "(" (STAR | function_arg) ")" -aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT +aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT | STRING_AGG function_args: (function_arg) ("," function_arg)* @@ -488,6 +489,7 @@ FCOUNT: "FCOUNT"i FIRST: "FIRST"i LAST: "LAST"i SEGMENT: "SEGMENT"i +STRING_AGG: "STRING_AGG"i // Keywords, but can be ID // Common Keywords, but can be ID diff --git a/evadb/parser/lark_visitor/_functions.py b/evadb/parser/lark_visitor/_functions.py index 2b2c18095..4db63256a 100644 --- a/evadb/parser/lark_visitor/_functions.py +++ b/evadb/parser/lark_visitor/_functions.py @@ -134,6 +134,8 @@ def get_aggregate_function_type(self, agg_func_name): agg_func_type = ExpressionType.AGGREGATION_LAST elif agg_func_name == "SEGMENT": agg_func_type = ExpressionType.AGGREGATION_SEGMENT + elif agg_func_name == "STRING_AGG": + agg_func_type = ExpressionType.AGGREGATION_STRING_AGG return agg_func_type def aggregate_windowed_function(self, tree): diff --git a/script/test/test.sh b/script/test/test.sh old mode 100644 new mode 100755 diff --git a/test/unit_tests/expression/test_aggregation.py b/test/unit_tests/expression/test_aggregation.py
index e8bf2a187..801afd841 100644 --- a/test/unit_tests/expression/test_aggregation.py +++ b/test/unit_tests/expression/test_aggregation.py @@ -52,7 +52,7 @@ def test_aggregation_last(self): def test_aggregation_segment(self): columnName = TupleValueExpression(name=0) - columnName.col_alias = 0 + columnName.col_alias = 0 # sets the col to use aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_SEGMENT, None, columnName ) @@ -116,6 +116,58 @@ def test_aggregation_max(self): self.assertEqual(3, batch.frames.iloc[0][0]) self.assertNotEqual(str(aggr_expr), None) + def test_aggregation_string_agg(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 0 + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB", "Here"], + 1: ["Goodbye", "Everyone", "EvaDB", "Out"], + } + ) + ) + batch = aggr_expr.evaluate(tuples, delimiter=" ") + self.assertEqual("Hello World EvaDB Here", batch.frames.iloc[0][0]) + self.assertNotEqual(str(aggr_expr), None) + + def test_aggregation_string_agg_incorrect_column(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 2 # no column 2 in the batch below + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB", "Here"], + 1: ["Goodbye", "Everyone", "EvaDB", "Out"], + } + ) + ) + with pytest.raises(AssertionError): + aggr_expr.evaluate(tuples, delimiter=" ") + + def test_aggregation_string_agg_incorrect_delimiter(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 0 + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB", "Here"], + 1: ["Goodbye", "Everyone", "EvaDB", "Out"], + } + ) + ) + with pytest.raises(ValueError): + aggr_expr.evaluate(tuples, delimiter=0) # delimiter must be a str + def
test_aggregation_incorrect_etype(self): incorrect_etype = 100 columnName = TupleValueExpression(name=0)