From 99b3d722a5fbd0fff63054cf10a1249c904b9218 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Fri, 24 Nov 2023 00:08:44 -0500 Subject: [PATCH 01/17] Added STRING_AGG to evadb.lark --- evadb/parser/evadb.lark | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 4b96bf647b..4d92ab9447 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -300,11 +300,12 @@ function_call: function ->function_call function: simple_id "(" (STAR | function_args)? ")" dotted_id? -aggregate_windowed_function: aggregate_function_name "(" function_arg ")" +aggregate_windowed_function: STRING_AGG "(" function_arg "," function_arg ")" + | aggregate_function_name "(" function_arg ")" | COUNT "(" (STAR | function_arg) ")" -aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT +aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT | STRING_AGG function_args: (function_arg) ("," function_arg)* From 3d3bbdaeaad55128806c01e64be3434788ca14f3 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Fri, 24 Nov 2023 00:16:50 -0500 Subject: [PATCH 02/17] Add STRING_AGG to aggregate functions --- evadb/expression/abstract_expression.py | 2 ++ evadb/parser/lark_visitor/_functions.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/evadb/expression/abstract_expression.py b/evadb/expression/abstract_expression.py index 9b72f32e68..1d4a42354c 100644 --- a/evadb/expression/abstract_expression.py +++ b/evadb/expression/abstract_expression.py @@ -48,6 +48,7 @@ class ExpressionType(IntEnum): FUNCTION_EXPRESSION = auto() + # Aggregation functions AGGREGATION_COUNT = auto() AGGREGATION_SUM = auto() AGGREGATION_MIN = auto() @@ -56,6 +57,7 @@ class ExpressionType(IntEnum): AGGREGATION_FIRST = auto() AGGREGATION_LAST = auto() AGGREGATION_SEGMENT = auto() + AGGREGATION_STRING_AGG = auto() CASE = auto() # add other types diff --git a/evadb/parser/lark_visitor/_functions.py b/evadb/parser/lark_visitor/_functions.py index 2b2c180953..4db63256af 100644 --- a/evadb/parser/lark_visitor/_functions.py +++ b/evadb/parser/lark_visitor/_functions.py @@ -134,6 +134,8 @@ def get_aggregate_function_type(self, agg_func_name): agg_func_type = ExpressionType.AGGREGATION_LAST elif agg_func_name == "SEGMENT": agg_func_type = ExpressionType.AGGREGATION_SEGMENT + elif agg_func_name == "STRING_AGG": + agg_func_type = ExpressionType.AGGREGATION_STRING_AGG return agg_func_type def aggregate_windowed_function(self, tree): From 17c94ac82b718e0baf35fe43c17afdaad2483585 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Fri, 24 Nov 2023 01:35:25 -0500 Subject: [PATCH 03/17] Native support for STRING_AGG --- .../unit_tests/expression/test_aggregation.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/unit_tests/expression/test_aggregation.py b/test/unit_tests/expression/test_aggregation.py index e8bf2a1875..53c0ef77ab 100644 --- a/test/unit_tests/expression/test_aggregation.py +++ b/test/unit_tests/expression/test_aggregation.py @@ -52,7 +52,7 @@ def test_aggregation_last(self): def test_aggregation_segment(self): columnName = TupleValueExpression(name=0) - columnName.col_alias = 0 + columnName.col_alias = 0 # sets the col to use aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_SEGMENT, None, columnName ) @@ -116,6 +116,37 @@ def test_aggregation_max(self): self.assertEqual(3, batch.frames.iloc[0][0]) self.assertNotEqual(str(aggr_expr), None) + def test_aggregation_string_agg(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 0 + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) + batch = aggr_expr.evaluate(tuples, delimiter=" ") + self.assertEqual("Hello World EvaDB Here", batch.frames.iloc[0][0]) + self.assertNotEqual(str(aggr_expr), None) + + def test_aggregation_string_agg_incorrect_column(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 2 + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) + with pytest.raises(KeyError): + batch = aggr_expr.evaluate(tuples, delimiter=" ") + + def test_aggregation_string_agg_incorrect_delimiter(self): + columnName = TupleValueExpression(name=0) + columnName.col_alias = 0 + aggr_expr = AggregationExpression( + ExpressionType.AGGREGATION_STRING_AGG, None, columnName + ) + tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) + with pytest.raises(ValueError): + batch = aggr_expr.evaluate(tuples, delimiter=0) + def test_aggregation_incorrect_etype(self): incorrect_etype = 100 columnName = TupleValueExpression(name=0) From 0e10f32aa8a2de503e2f765fb23ba0824c2e8040 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Fri, 24 Nov 2023 01:43:54 -0500 Subject: [PATCH 04/17] Native support for STRING_AGG --- evadb/expression/aggregation_expression.py | 9 +++++++-- evadb/models/storage/batch.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index f1ba6b16c4..a203d9c563 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -37,7 +37,7 @@ def __init__( ) # can also be a float def evaluate(self, *args, **kwargs): - batch: Batch = self.get_child(0).evaluate(*args, **kwargs) + batch: Batch = self.get_child(0).evaluate(*args, **kwargs) #.get_child returns a TupleValueExpression -> .evaluate takes in the pandas if self.etype == ExpressionType.AGGREGATION_FIRST: batch = batch[0] elif self.etype == ExpressionType.AGGREGATION_LAST: @@ -54,10 +54,15 @@ def evaluate(self, *args, **kwargs): batch.aggregate("min") elif self.etype == ExpressionType.AGGREGATION_MAX: batch.aggregate("max") + elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: + column_name = self.get_child(0).col_alias + delimiter = kwargs.get('delimiter') + batch.aggregate_string_agg(column_name, delimiter) + batch.reset_index() column_name = self.etype.name - if column_name.find("AGGREGATION_") != -1: + if column_name.find("AGGREGATION_") != -1: # Not an aggregation function # AGGREGATION_MAX -> MAX updated_column_name = column_name.replace("AGGREGATION_", "") batch.modify_column_alias(updated_column_name) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 43e69cc4fc..1cd53f1640 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -376,6 +376,24 @@ def aggregate(self, method: str) -> None: """ self._frames = self._frames.agg([method]) + def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: + """ + Aggregate strings in a column using a specified delimiter. + + Arguments: + column_name (str): The name of the column to aggregate. + delimiter (str): The delimiter to use for concatenation. + """ + verified_col = column_name if column_name in self._frames else None + + if not verified_col: + raise KeyError(f"ERROR: column '{column_name}' does not exist") + + if not delimiter or not isinstance(delimiter, str): + raise ValueError("Delimiter must be a string") + + self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[verified_col] + def empty(self): """Checks if the batch is empty Returns: From 1657a59efced66a10bce67d98c0411657ca4dfa3 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Fri, 24 Nov 2023 18:59:09 -0500 Subject: [PATCH 05/17] Removed comments --- evadb/expression/aggregation_expression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index a203d9c563..f1a6d2f3e5 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -37,7 +37,7 @@ def __init__( ) # can also be a float def evaluate(self, *args, **kwargs): - batch: Batch = self.get_child(0).evaluate(*args, **kwargs) #.get_child returns a TupleValueExpression -> .evaluate takes in the pandas + batch: Batch = self.get_child(0).evaluate(*args, **kwargs) if self.etype == ExpressionType.AGGREGATION_FIRST: batch = batch[0] elif self.etype == ExpressionType.AGGREGATION_LAST: @@ -62,7 +62,7 @@ def evaluate(self, *args, **kwargs): batch.reset_index() column_name = self.etype.name - if column_name.find("AGGREGATION_") != -1: # Not an aggregation function + if column_name.find("AGGREGATION_") != -1: # AGGREGATION_MAX -> MAX updated_column_name = column_name.replace("AGGREGATION_", "") batch.modify_column_alias(updated_column_name) From fcc91fae68bd032a679fdb6b7b357daf32def62c Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 03:52:50 -0500 Subject: [PATCH 06/17] KeyError Bug Fix --- evadb/expression/aggregation_expression.py | 2 ++ evadb/models/storage/batch.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index f1a6d2f3e5..5dbe8822b2 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -38,6 +38,8 @@ def __init__( def evaluate(self, *args, **kwargs): batch: Batch = self.get_child(0).evaluate(*args, **kwargs) + column_name = self.get_child(0).col_alias + print(column_name, batch.columns) if self.etype == ExpressionType.AGGREGATION_FIRST: batch = batch[0] elif self.etype == ExpressionType.AGGREGATION_LAST: diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 1cd53f1640..5bc6c91863 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -384,7 +384,8 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: column_name (str): The name of the column to aggregate. delimiter (str): The delimiter to use for concatenation. """ - verified_col = column_name if column_name in self._frames else None + updated_column_name = f"STRING_AGG.{column_name}" + verified_col = updated_column_name if updated_column_name in self._frames else None if not verified_col: raise KeyError(f"ERROR: column '{column_name}' does not exist") From 20a7d780fb19f92188b22367f34d6b93cab0cd74 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 03:59:40 -0500 Subject: [PATCH 07/17] Defined STRING_AGG in LARK grammar --- evadb/parser/evadb.lark | 1 + 1 file changed, 1 insertion(+) diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 4d92ab9447..264ec1e6aa 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -489,6 +489,7 @@ FCOUNT: "FCOUNT"i FIRST: "FIRST"i LAST: "LAST"i SEGMENT: "SEGMENT"i +STRING_AGG "STRING_AGG"i // Keywords, but can be ID // Common Keywords, but can be ID From 92edfe2be744f8e9043af317664fcc5926465fe9 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:07:36 -0500 Subject: [PATCH 08/17] Defined STRING_AGG in LARK grammar --- evadb/parser/evadb.lark | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index 264ec1e6aa..d95c5776cc 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -489,7 +489,7 @@ FCOUNT: "FCOUNT"i FIRST: "FIRST"i LAST: "LAST"i SEGMENT: "SEGMENT"i -STRING_AGG "STRING_AGG"i +STRING_AGG: "STRING_AGG"i // Keywords, but can be ID // Common Keywords, but can be ID From 9becc4b7088e38b06e3305170a8a179a77220599 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:09:18 -0500 Subject: [PATCH 09/17] Removed unused batch variable --- test/unit_tests/expression/test_aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit_tests/expression/test_aggregation.py b/test/unit_tests/expression/test_aggregation.py index 53c0ef77ab..125ef969a3 100644 --- a/test/unit_tests/expression/test_aggregation.py +++ b/test/unit_tests/expression/test_aggregation.py @@ -135,7 +135,7 @@ def test_aggregation_string_agg_incorrect_column(self): ) tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) with pytest.raises(KeyError): - batch = aggr_expr.evaluate(tuples, delimiter=" ") + aggr_expr.evaluate(tuples, delimiter=" ") def test_aggregation_string_agg_incorrect_delimiter(self): columnName = TupleValueExpression(name=0) @@ -145,7 +145,7 @@ def test_aggregation_string_agg_incorrect_delimiter(self): ) tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) with pytest.raises(ValueError): - batch = aggr_expr.evaluate(tuples, delimiter=0) + aggr_expr.evaluate(tuples, delimiter=0) def test_aggregation_incorrect_etype(self): incorrect_etype = 100 From 4b8c94e7352d8a57aa10c53ea8036dafb63f8e6b Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:15:17 -0500 Subject: [PATCH 10/17] Removed print() statement --- evadb/expression/aggregation_expression.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index 5dbe8822b2..f1a6d2f3e5 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -38,8 +38,6 @@ def __init__( def evaluate(self, *args, **kwargs): batch: Batch = self.get_child(0).evaluate(*args, **kwargs) - column_name = self.get_child(0).col_alias - print(column_name, batch.columns) if self.etype == ExpressionType.AGGREGATION_FIRST: batch = batch[0] elif self.etype == ExpressionType.AGGREGATION_LAST: From 55e0b3ad6c3c30ff3062a69fc18b3a951c4a7c41 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:32:52 -0500 Subject: [PATCH 11/17] KeyError message update --- evadb/models/storage/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 5bc6c91863..6ebdbe028f 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -388,7 +388,7 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: verified_col = updated_column_name if updated_column_name in self._frames else None if not verified_col: - raise KeyError(f"ERROR: column '{column_name}' does not exist") + raise KeyError(f"ERROR: column '{updated_column_name}' does not exist in columns: {self._frames.columns}") if not delimiter or not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") From 79204473d7b16909f41be0de4f747c92d440c11e Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:40:44 -0500 Subject: [PATCH 12/17] column_name check in self._frames --- evadb/models/storage/batch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 6ebdbe028f..a1c885e833 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -384,11 +384,10 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: column_name (str): The name of the column to aggregate. delimiter (str): The delimiter to use for concatenation. """ - updated_column_name = f"STRING_AGG.{column_name}" - verified_col = updated_column_name if updated_column_name in self._frames else None + verified_col = column_name if column_name in self._frames else None if not verified_col: - raise KeyError(f"ERROR: column '{updated_column_name}' does not exist in columns: {self._frames.columns}") + raise KeyError(f"ERROR: column '{column_name}' does not exist in columns: {self._frames.columns}") if not delimiter or not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") From 2cc16b03105b3763ada034a8ba7ace112a54d0f5 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 04:50:20 -0500 Subject: [PATCH 13/17] Check column_name against self._frames.columns directly --- evadb/models/storage/batch.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index a1c885e833..8d94868a5b 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -384,15 +384,13 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: column_name (str): The name of the column to aggregate. delimiter (str): The delimiter to use for concatenation. """ - verified_col = column_name if column_name in self._frames else None - - if not verified_col: + if column_name not in self._frames.columns: raise KeyError(f"ERROR: column '{column_name}' does not exist in columns: {self._frames.columns}") if not delimiter or not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") - self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[verified_col] + self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[column_name] def empty(self): """Checks if the batch is empty From 360e8f3b96cedc8602e66a1118e41d27f42de602 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 05:06:22 -0500 Subject: [PATCH 14/17] Debug self._frames --- evadb/models/storage/batch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 8d94868a5b..353c642cfb 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -391,7 +391,9 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: raise ValueError("Delimiter must be a string") self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[column_name] - + if isinstance(self._frames, str): + raise TypeError(f"ERROR: self._frames converted to a string: {self._frames}") + def empty(self): """Checks if the batch is empty Returns: From e91fef9100ad19abc118d58be9d6adef5c5261e1 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 05:12:06 -0500 Subject: [PATCH 15/17] self._frames to remain a Pandas DataFrame --- evadb/models/storage/batch.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 353c642cfb..8e873a10e2 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -390,9 +390,7 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: if not delimiter or not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") - self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[column_name] - if isinstance(self._frames, str): - raise TypeError(f"ERROR: self._frames converted to a string: {self._frames}") + self._frames = self._frames[[column_name]].agg(lambda x: delimiter.join(x.astype(str)), axis=0) def empty(self): """Checks if the batch is empty From 1d940f798e850886d35981fa55edaa052f1f27ce Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 05:12:06 -0500 Subject: [PATCH 16/17] self._frames to remain a Pandas DataFrame --- evadb/models/storage/batch.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 353c642cfb..25dc659252 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -385,14 +385,12 @@ def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: delimiter (str): The delimiter to use for concatenation. """ if column_name not in self._frames.columns: - raise KeyError(f"ERROR: column '{column_name}' does not exist in columns: {self._frames.columns}") + raise KeyError(f"ERROR: column '{column_name}' does not exist in columns: {self.columns}") if not delimiter or not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") - self._frames = self._frames.agg(lambda x: delimiter.join(x.astype(str)), axis=0)[column_name] - if isinstance(self._frames, str): - raise TypeError(f"ERROR: self._frames converted to a string: {self._frames}") + self._frames = pd.DataFrame({column_name: [self._frames[column_name].astype(str).agg(delimiter.join)]}) def empty(self): """Checks if the batch is empty From 41ead18dcfe93b2241ca7a4ccf5bf2a80da26fb2 Mon Sep 17 00:00:00 2001 From: ryanmle2001 Date: Sat, 25 Nov 2023 17:57:33 -0500 Subject: [PATCH 17/17] Native support for STRING_AGG --- evadb/expression/abstract_expression.py | 2 +- evadb/expression/aggregation_expression.py | 11 ++++--- evadb/models/storage/batch.py | 17 +++++----- script/test/test.sh | 0 .../unit_tests/expression/test_aggregation.py | 31 ++++++++++++++++--- 5 files changed, 41 insertions(+), 20 deletions(-) mode change 100644 => 100755 script/test/test.sh diff --git a/evadb/expression/abstract_expression.py b/evadb/expression/abstract_expression.py index 1d4a42354c..53a1287230 100644 --- a/evadb/expression/abstract_expression.py +++ b/evadb/expression/abstract_expression.py @@ -57,7 +57,7 @@ class ExpressionType(IntEnum): AGGREGATION_FIRST = auto() AGGREGATION_LAST = auto() AGGREGATION_SEGMENT = auto() - AGGREGATION_STRING_AGG = auto() + AGGREGATION_STRING_AGG = auto() CASE = auto() # add other types diff --git a/evadb/expression/aggregation_expression.py b/evadb/expression/aggregation_expression.py index f1a6d2f3e5..cb8932a378 100644 --- a/evadb/expression/aggregation_expression.py +++ b/evadb/expression/aggregation_expression.py @@ -37,7 +37,7 @@ def __init__( ) # can also be a float def evaluate(self, *args, **kwargs): - batch: Batch = self.get_child(0).evaluate(*args, **kwargs) + batch: Batch = self.get_child(0).evaluate(*args, **kwargs) if self.etype == ExpressionType.AGGREGATION_FIRST: batch = batch[0] elif self.etype == ExpressionType.AGGREGATION_LAST: @@ -55,14 +55,13 @@ def evaluate(self, *args, **kwargs): elif self.etype == ExpressionType.AGGREGATION_MAX: batch.aggregate("max") elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: - column_name = self.get_child(0).col_alias - delimiter = kwargs.get('delimiter') - batch.aggregate_string_agg(column_name, delimiter) + delimiter = kwargs.get("delimiter") + batch.aggregate_string_agg(delimiter) batch.reset_index() column_name = self.etype.name - if column_name.find("AGGREGATION_") != -1: + if column_name.find("AGGREGATION_") != -1: # AGGREGATION_MAX -> MAX updated_column_name = column_name.replace("AGGREGATION_", "") batch.modify_column_alias(updated_column_name) @@ -87,6 +86,8 @@ def get_symbol(self) -> str: return "MIN" elif self.etype == ExpressionType.AGGREGATION_MAX: return "MAX" + elif self.etype == ExpressionType.AGGREGATION_STRING_AGG: + return "STRING_AGG" else: raise NotImplementedError diff --git a/evadb/models/storage/batch.py b/evadb/models/storage/batch.py index 25dc659252..975efa31ef 100644 --- a/evadb/models/storage/batch.py +++ b/evadb/models/storage/batch.py @@ -376,22 +376,21 @@ def aggregate(self, method: str) -> None: """ self._frames = self._frames.agg([method]) - def aggregate_string_agg(self, column_name: str, delimiter: str) -> None: + def aggregate_string_agg(self, delimiter: str) -> None: """ - Aggregate strings in a column using a specified delimiter. + Aggregate strings using a specified delimiter. Arguments: - column_name (str): The name of the column to aggregate. delimiter (str): The delimiter to use for concatenation. """ - if column_name not in self._frames.columns: - raise KeyError(f"ERROR: column '{column_name}' does not exist in columns: {self.columns}") - - if not delimiter or not isinstance(delimiter, str): + if not isinstance(delimiter, str): raise ValueError("Delimiter must be a string") - self._frames = pd.DataFrame({column_name: [self._frames[column_name].astype(str).agg(delimiter.join)]}) - + aggregated_data = { + col: [delimiter.join(self._frames[col].astype(str))] for col in self._frames + } + self._frames = pd.DataFrame(aggregated_data) + def empty(self): """Checks if the batch is empty Returns: diff --git a/script/test/test.sh b/script/test/test.sh old mode 100644 new mode 100755 diff --git a/test/unit_tests/expression/test_aggregation.py b/test/unit_tests/expression/test_aggregation.py index 125ef969a3..801afd841f 100644 --- a/test/unit_tests/expression/test_aggregation.py +++ b/test/unit_tests/expression/test_aggregation.py @@ -52,7 +52,7 @@ def test_aggregation_last(self): def test_aggregation_segment(self): columnName = TupleValueExpression(name=0) - columnName.col_alias = 0 # sets the col to use + columnName.col_alias = 0 # sets the col to use aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_SEGMENT, None, columnName ) @@ -122,7 +122,14 @@ def test_aggregation_string_agg(self): aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_STRING_AGG, None, columnName ) - tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB", "Here"], + 1: ["Goodbye", "Everyone", "EvaDB", "Out"], + } + ) + ) batch = aggr_expr.evaluate(tuples, delimiter=" ") self.assertEqual("Hello World EvaDB Here", batch.frames.iloc[0][0]) self.assertNotEqual(str(aggr_expr), None) @@ -133,8 +140,15 @@ def test_aggregation_string_agg_incorrect_column(self): aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_STRING_AGG, None, columnName ) - tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) - with pytest.raises(KeyError): + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB" "Here"], + 1: ["Goodbye", "Everyone", "EvaDB" "Out"], + } + ) + ) + with pytest.raises(AssertionError): aggr_expr.evaluate(tuples, delimiter=" ") def test_aggregation_string_agg_incorrect_delimiter(self): @@ -143,7 +157,14 @@ def test_aggregation_string_agg_incorrect_delimiter(self): aggr_expr = AggregationExpression( ExpressionType.AGGREGATION_STRING_AGG, None, columnName ) - tuples = Batch(pd.DataFrame({0: ["Hello", "World", "EvaDB" "Here"], 1: ["Goodbye", "Everyone", "EvaDB" "Out"]})) + tuples = Batch( + pd.DataFrame( + { + 0: ["Hello", "World", "EvaDB" "Here"], + 1: ["Goodbye", "Everyone", "EvaDB" "Out"], + } + ) + ) with pytest.raises(ValueError): aggr_expr.evaluate(tuples, delimiter=0)