Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
clean up query function to be more extendable (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 authored Dec 20, 2023
1 parent 183934f commit 027b3e0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 21 deletions.
15 changes: 7 additions & 8 deletions packages/openassistants/openassistants/contrib/duckdb_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,27 @@
from typing import List, Literal

import pandas as pd
from openassistants.contrib.sqlalchemy_query import QueryFunction
from openassistants.contrib.sqlalchemy_query import SQLAlchemyFunction
from openassistants.functions.base import FunctionExecutionDependency
from sqlalchemy import Engine, create_engine
from sqlalchemy import create_engine


class DuckDBQueryFunction(QueryFunction):
class DuckDBQueryFunction(SQLAlchemyFunction):
type: Literal["DuckDBQueryFunction"] = "DuckDBQueryFunction" # type: ignore
dataset: str

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._engine = create_engine("duckdb:///:memory:")
def __init__(self, **kwargs):
super().__init__(engine=create_engine("duckdb:///:memory:"), **kwargs)

async def _execute_sqls(
self, engine: Engine, deps: FunctionExecutionDependency
self, deps: FunctionExecutionDependency
) -> List[pd.DataFrame]:
# Set workdir to the dataset path
# Capture the current working directory
cwd = os.getcwd()
try:
os.chdir(self.dataset)
res = await super()._execute_sqls(engine, deps)
res = await super()._execute_sqls(deps)
finally:
# Restore the working directory even if an error occurs
os.chdir(cwd)
Expand Down
40 changes: 27 additions & 13 deletions packages/openassistants/openassistants/contrib/sqlalchemy_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import asyncio
from typing import Annotated, Any, List, Literal, Sequence

Expand Down Expand Up @@ -67,24 +68,18 @@ def _opas_to_summarization_lc(
return lc_messages


class QueryFunction(BaseFunction):
class QueryFunction(BaseFunction, abc.ABC):
type: Literal["QueryFunction"] = "QueryFunction"
sqls: List[str]
visualizations: List[str]
summarization: str
suggested_follow_ups: Annotated[List[SuggestedPrompt], Field(default_factory=list)]
_engine: Engine = PrivateAttr()

@abc.abstractmethod
async def _execute_sqls(
self, engine: Engine, deps: FunctionExecutionDependency
self, deps: FunctionExecutionDependency
) -> List[pd.DataFrame]:
res: List[pd.DataFrame] = await asyncio.gather( # type: ignore
*[
run_in_threadpool(run_sql, engine, sql, deps.arguments)
for sql in self.sqls
]
)
return res
pass

async def _execute_visualizations(
self, dfs: List[pd.DataFrame], deps: FunctionExecutionDependency
Expand All @@ -93,7 +88,7 @@ async def _execute_visualizations(
*[execute_visualization(viz, dfs) for viz in self.visualizations]
)

async def execute_summarization(
async def _execute_summarization(
self, dfs: List[pd.DataFrame], deps: FunctionExecutionDependency
) -> AsyncStreamVersion[str]:
chat_continued = [
Expand Down Expand Up @@ -159,7 +154,7 @@ async def execute(

results: List[FunctionOutput] = []

dataframes = await self._execute_sqls(self._engine, deps)
dataframes = await self._execute_sqls(deps)
results.extend(
[
DataFrameOutput(dataframe=SerializedDataFrame.from_pd(df))
Expand All @@ -180,7 +175,7 @@ async def execute(
# Add summarization
summarization_text = ""

async for summarization_text in self.execute_summarization(dataframes, deps):
async for summarization_text in self._execute_summarization(dataframes, deps):
yield results + [TextOutput(text=summarization_text)]

results.extend([TextOutput(text=summarization_text)])
Expand All @@ -205,3 +200,22 @@ async def execute(
)

yield results


class SQLAlchemyFunction(QueryFunction, abc.ABC):
    """Query function whose SQL statements run against a SQLAlchemy engine.

    The engine is handed in by the caller at construction time and held in
    the private ``_engine`` attribute.
    """

    _engine: Engine = PrivateAttr()

    def __init__(self, engine: Engine, **kwargs):
        """Initialize the parent with ``kwargs``, then remember ``engine``.

        Args:
            engine: The SQLAlchemy engine later used by ``_execute_sqls``.
            **kwargs: Passed through unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        self._engine = engine

    async def _execute_sqls(
        self, deps: FunctionExecutionDependency
    ) -> List[pd.DataFrame]:
        """Run every statement in ``self.sqls`` and collect the results.

        Each statement is dispatched through ``run_in_threadpool`` as
        ``run_sql(self._engine, statement, deps.arguments)``, and all of them
        are awaited together.

        Args:
            deps: Execution dependencies; only ``deps.arguments`` is read here.

        Returns:
            The gathered results, in the same order as ``self.sqls``
            (presumably one DataFrame per statement, per the annotation).
        """
        pending = [
            run_in_threadpool(run_sql, self._engine, statement, deps.arguments)
            for statement in self.sqls
        ]
        frames: List[pd.DataFrame] = await asyncio.gather(*pending)  # type: ignore
        return frames

0 comments on commit 027b3e0

Please sign in to comment.