Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Native support for STRING_AGG #1393

Open
wants to merge 18 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions evadb/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ExpressionType(IntEnum):

FUNCTION_EXPRESSION = auto()

# Aggregation functions
AGGREGATION_COUNT = auto()
AGGREGATION_SUM = auto()
AGGREGATION_MIN = auto()
Expand All @@ -56,6 +57,7 @@ class ExpressionType(IntEnum):
AGGREGATION_FIRST = auto()
AGGREGATION_LAST = auto()
AGGREGATION_SEGMENT = auto()
AGGREGATION_STRING_AGG = auto()

CASE = auto()
# add other types
Expand Down
6 changes: 6 additions & 0 deletions evadb/expression/aggregation_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def evaluate(self, *args, **kwargs):
batch.aggregate("min")
elif self.etype == ExpressionType.AGGREGATION_MAX:
batch.aggregate("max")
elif self.etype == ExpressionType.AGGREGATION_STRING_AGG:
delimiter = kwargs.get("delimiter")
batch.aggregate_string_agg(delimiter)

batch.reset_index()

column_name = self.etype.name
Expand Down Expand Up @@ -82,6 +86,8 @@ def get_symbol(self) -> str:
return "MIN"
elif self.etype == ExpressionType.AGGREGATION_MAX:
return "MAX"
elif self.etype == ExpressionType.AGGREGATION_STRING_AGG:
return "STRING_AGG"
else:
raise NotImplementedError

Expand Down
15 changes: 15 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,21 @@ def aggregate(self, method: str) -> None:
"""
self._frames = self._frames.agg([method])

def aggregate_string_agg(self, delimiter: str) -> None:
    """
    Collapse every column of the batch into one delimiter-joined string.

    Each column of ``self._frames`` is converted to strings and joined with
    ``delimiter``; the frame is then replaced by a single-row frame holding
    the joined value for each column.

    Null entries (None / NaN) are skipped, as SQL STRING_AGG does, rather
    than being emitted as the literal text "None" / "nan".

    Arguments:
        delimiter (str): The delimiter to use for concatenation.

    Raises:
        ValueError: If ``delimiter`` is not a string.
    """
    if not isinstance(delimiter, str):
        raise ValueError("Delimiter must be a string")

    aggregated_data = {
        # dropna() must run before astype(str); afterwards nulls are
        # indistinguishable from ordinary text.
        col: [delimiter.join(self._frames[col].dropna().astype(str))]
        for col in self._frames
    }
    self._frames = pd.DataFrame(aggregated_data)

def empty(self):
"""Checks if the batch is empty
Returns:
Expand Down
6 changes: 4 additions & 2 deletions evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,12 @@ function_call: function ->function_call

function: simple_id "(" (STAR | function_args)? ")" dotted_id?

aggregate_windowed_function: aggregate_function_name "(" function_arg ")"
aggregate_windowed_function: STRING_AGG "(" function_arg "," function_arg ")"
| aggregate_function_name "(" function_arg ")"
| COUNT "(" (STAR | function_arg) ")"


aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT
aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT | STRING_AGG

function_args: (function_arg) ("," function_arg)*

Expand Down Expand Up @@ -488,6 +489,7 @@ FCOUNT: "FCOUNT"i
FIRST: "FIRST"i
LAST: "LAST"i
SEGMENT: "SEGMENT"i
STRING_AGG: "STRING_AGG"i

// Keywords, but can be ID
// Common Keywords, but can be ID
Expand Down
2 changes: 2 additions & 0 deletions evadb/parser/lark_visitor/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def get_aggregate_function_type(self, agg_func_name):
agg_func_type = ExpressionType.AGGREGATION_LAST
elif agg_func_name == "SEGMENT":
agg_func_type = ExpressionType.AGGREGATION_SEGMENT
elif agg_func_name == "STRING_AGG":
agg_func_type = ExpressionType.AGGREGATION_STRING_AGG
return agg_func_type

def aggregate_windowed_function(self, tree):
Expand Down
Empty file modified script/test/test.sh
100644 → 100755
Empty file.
54 changes: 53 additions & 1 deletion test/unit_tests/expression/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_aggregation_last(self):

def test_aggregation_segment(self):
columnName = TupleValueExpression(name=0)
columnName.col_alias = 0
columnName.col_alias = 0 # sets the col to use
aggr_expr = AggregationExpression(
ExpressionType.AGGREGATION_SEGMENT, None, columnName
)
Expand Down Expand Up @@ -116,6 +116,58 @@ def test_aggregation_max(self):
self.assertEqual(3, batch.frames.iloc[0][0])
self.assertNotEqual(str(aggr_expr), None)

def test_aggregation_string_agg(self):
    """STRING_AGG joins the selected column's values with the delimiter."""
    frame = pd.DataFrame(
        {
            0: ["Hello", "World", "EvaDB", "Here"],
            1: ["Goodbye", "Everyone", "EvaDB", "Out"],
        }
    )
    tuples = Batch(frame)

    # Select column 0 as the aggregation input.
    columnName = TupleValueExpression(name=0)
    columnName.col_alias = 0
    aggr_expr = AggregationExpression(
        ExpressionType.AGGREGATION_STRING_AGG, None, columnName
    )

    batch = aggr_expr.evaluate(tuples, delimiter=" ")

    self.assertEqual("Hello World EvaDB Here", batch.frames.iloc[0][0])
    self.assertNotEqual(str(aggr_expr), None)

def test_aggregation_string_agg_incorrect_column(self):
    """STRING_AGG over a column alias absent from the batch must fail."""
    columnName = TupleValueExpression(name=0)
    # Alias 2 does not exist in the batch below (only columns 0 and 1).
    columnName.col_alias = 2
    aggr_expr = AggregationExpression(
        ExpressionType.AGGREGATION_STRING_AGG, None, columnName
    )
    tuples = Batch(
        pd.DataFrame(
            {
                # Commas between every element: adjacent string literals
                # would otherwise be silently concatenated into one value.
                0: ["Hello", "World", "EvaDB", "Here"],
                1: ["Goodbye", "Everyone", "EvaDB", "Out"],
            }
        )
    )
    with pytest.raises(AssertionError):
        aggr_expr.evaluate(tuples, delimiter=" ")

def test_aggregation_string_agg_incorrect_delimiter(self):
    """STRING_AGG with a non-string delimiter must raise ValueError."""
    columnName = TupleValueExpression(name=0)
    columnName.col_alias = 0
    aggr_expr = AggregationExpression(
        ExpressionType.AGGREGATION_STRING_AGG, None, columnName
    )
    tuples = Batch(
        pd.DataFrame(
            {
                # Commas between every element: adjacent string literals
                # would otherwise be silently concatenated into one value.
                0: ["Hello", "World", "EvaDB", "Here"],
                1: ["Goodbye", "Everyone", "EvaDB", "Out"],
            }
        )
    )
    # An integer delimiter is rejected by the batch-level type check.
    with pytest.raises(ValueError):
        aggr_expr.evaluate(tuples, delimiter=0)

def test_aggregation_incorrect_etype(self):
incorrect_etype = 100
columnName = TupleValueExpression(name=0)
Expand Down