Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres integration #1427

Open
wants to merge 2 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/postgres_setup.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
## INSTRUCTIONS TO SET UP POSTGRES DATABASE

These are the steps to install and configure PostgreSQL on macOS:

1. Install postgres
``` brew install postgresql```
2. Start PostgreSQL:
```brew services start postgresql```
3. Create a user and database:

Username: evadb

Password: password

- ```psql postgres```
- ```CREATE ROLE evadb WITH LOGIN PASSWORD 'password';```
- ```ALTER ROLE evadb CREATEDB;```
- ```\q```
4. Log in as your new user:

``` psql -d postgres -U evadb```
5. Create the database evadb:

```CREATE DATABASE evadb;```
6. Install the PostgreSQL driver for Python: ```pip install psycopg2```
13 changes: 8 additions & 5 deletions evadb/catalog/sql_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ def __init__(self, uri):
# set echo=True to log SQL

# Default to SQLite.
connect_args = {"timeout": 1000}
self.engine = create_engine(
self.worker_uri,
connect_args=connect_args,
)
# connect_args = {"timeout": 1000}
# self.engine = create_engine(
# self.worker_uri,
# connect_args=connect_args,
# )

# Postgres version
self.engine = create_engine(self.worker_uri)

if self.engine.url.get_backend_name() == "sqlite":
# enforce foreign key constraint and wal logging for sqlite
Expand Down
15 changes: 14 additions & 1 deletion evadb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,23 @@ def catalog(self) -> "CatalogManager":
return self.catalog_func(self.catalog_uri)


def get_default_db_uri(evadb_dir: Path):
def get_default_db_uri_sqlite(evadb_dir: Path):
    """
    Build the SQLite connection URI for the database file inside evadb_dir.

    Args:
        evadb_dir (Path): directory that holds the database file; it is
            resolved before being embedded in the URI.

    Returns:
        str: a ``sqlite:///`` URI pointing at DB_DEFAULT_NAME in that directory.
    """
    resolved_dir = evadb_dir.resolve()
    database_file = f"{resolved_dir}/{DB_DEFAULT_NAME}"
    return f"sqlite:///{database_file}"

def get_default_db_uri(
    evadb_dir: Path,
    *,
    user: str = "evadb",
    password: str = "password",
    host: str = "localhost",
    port: int = 5432,
    db_name: str = "evadb",
) -> str:
    """
    Generates a PostgreSQL connection URI for the local database.

    Called with only ``evadb_dir`` it returns the same URI as before:
    ``postgresql://evadb:password@localhost:5432/evadb``.

    Args:
        evadb_dir (Path): kept so existing callers keep working; it is not
            used when building the PostgreSQL URI.
        user (str): role to connect as.
        password (str): password for ``user``.
        host (str): server host name.
        port (int): server port; 5432 is the default PostgreSQL port.
        db_name (str): database to connect to.

    Returns:
        str: A PostgreSQL connection URI.
    """
    # Function-scope import so the block does not rely on a file-level import.
    from urllib.parse import quote_plus

    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # user name or password would otherwise be parsed as URI delimiters.
    credentials = f"{quote_plus(user)}:{quote_plus(password)}"
    return f"postgresql://{credentials}@{host}:{port}/{db_name}"

def init_evadb_instance(
db_dir: str, host: str = None, port: int = None, custom_db_uri: str = None
Expand Down
1 change: 1 addition & 0 deletions evadb/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ExpressionType(IntEnum):
AGGREGATION_FIRST = auto()
AGGREGATION_LAST = auto()
AGGREGATION_SEGMENT = auto()
AGGREGATION_STRING_AGG = auto()

CASE = auto()
# add other types
Expand Down
6 changes: 6 additions & 0 deletions evadb/expression/aggregation_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def evaluate(self, *args, **kwargs):
batch.aggregate("min")
elif self.etype == ExpressionType.AGGREGATION_MAX:
batch.aggregate("max")
elif self.etype == ExpressionType.AGGREGATION_STRING_AGG:
# Assuming two children: the column and the delimiter
column_to_aggregate = self.get_child(0).evaluate(*args, **kwargs)
delimiter = kwargs.get('delimiter')
batch.aggregate_string_aggregation(column_to_aggregate, delimiter)

batch.reset_index()

column_name = self.etype.name
Expand Down
16 changes: 16 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,19 @@ def to_numpy(self):
def rename(self, columns) -> None:
"Rename column names"
self._frames.rename(columns=columns, inplace=True)

def aggregate_string_aggregation(self, column_name: str, delimiter: str) -> None:
    """
    Collapse one column into a single delimiter-joined string.

    After the call, ``self._frames`` is a one-row DataFrame whose only column
    is ``column_name`` and whose single value is the joined text. Every other
    column of the previous frame is discarded.

    Args:
        column_name (str): column of ``self._frames`` to aggregate.
        delimiter (str): separator placed between consecutive values.

    Raises:
        TypeError: if ``delimiter`` is not a string (for example ``None``).
    """
    if not isinstance(delimiter, str):
        # Fail with a clear message instead of the AttributeError that
        # calling ``.join`` on a non-string would produce.
        raise TypeError(
            f"STRING_AGG delimiter must be a string, got {type(delimiter).__name__}"
        )

    # Drop missing values before stringifying: otherwise they would appear in
    # the output as literal "None"/"nan" text, whereas SQL STRING_AGG skips
    # NULL inputs.
    values = self._frames[column_name].dropna().astype(str)

    # Replace the frame with the single aggregated row.
    self._frames = pd.DataFrame({column_name: [delimiter.join(values)]})
4 changes: 4 additions & 0 deletions evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ or_replace: OR REPLACE

function_call: function ->function_call
| aggregate_windowed_function ->aggregate_function_call
| string_agg_function

function: simple_id "(" (STAR | function_args)? ")" dotted_id?

Expand All @@ -306,6 +307,9 @@ aggregate_windowed_function: aggregate_function_name "(" function_arg ")"

aggregate_function_name: AVG | MAX | MIN | SUM | FIRST | LAST | SEGMENT

string_agg_function: "STRING_AGG" LR_BRACKET expression COMMA string_literal RR_BRACKET


function_args: (function_arg) ("," function_arg)*

function_arg: constant | expression
Expand Down
10 changes: 10 additions & 0 deletions evadb/parser/lark_visitor/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,14 @@ def get_aggregate_function_type(self, agg_func_name):
agg_func_type = ExpressionType.AGGREGATION_LAST
elif agg_func_name == "SEGMENT":
agg_func_type = ExpressionType.AGGREGATION_SEGMENT
elif agg_func_name == "STRING_AGG":
agg_func_type = ExpressionType.AGGREGATION_STRING_AGG
return agg_func_type

def aggregate_windowed_function(self, tree):
agg_func_arg = None
agg_func_name = None
agg_func_args = []

for child in tree.children:
if isinstance(child, Tree):
Expand All @@ -156,6 +159,13 @@ def aggregate_windowed_function(self, tree):
else:
agg_func_arg = TupleValueExpression(name="id")

if agg_func_name == "STRING_AGG":
if len(agg_func_args) != 2:
raise ValueError("String Agg requires exactly two arguments")
agg_func_type = self.get_aggregate_function_type(agg_func_name)
agg_expr = AggregationExpression(agg_func_type, None, *agg_func_args)
return agg_expr

agg_func_type = self.get_aggregate_function_type(agg_func_name)
agg_expr = AggregationExpression(agg_func_type, None, agg_func_arg)
return agg_expr