diff --git a/packages/openassistants/openassistants/contrib/duckdb_query.py b/packages/openassistants/openassistants/contrib/duckdb_query.py index a3c4b68..d2d04c5 100644 --- a/packages/openassistants/openassistants/contrib/duckdb_query.py +++ b/packages/openassistants/openassistants/contrib/duckdb_query.py @@ -2,28 +2,27 @@ from typing import List, Literal import pandas as pd -from openassistants.contrib.sqlalchemy_query import QueryFunction +from openassistants.contrib.sqlalchemy_query import SQLAlchemyFunction from openassistants.functions.base import FunctionExecutionDependency -from sqlalchemy import Engine, create_engine +from sqlalchemy import create_engine -class DuckDBQueryFunction(QueryFunction): +class DuckDBQueryFunction(SQLAlchemyFunction): type: Literal["DuckDBQueryFunction"] = "DuckDBQueryFunction" # type: ignore dataset: str - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._engine = create_engine("duckdb:///:memory:") + def __init__(self, **kwargs): + super().__init__(engine=create_engine("duckdb:///:memory:"), **kwargs) async def _execute_sqls( - self, engine: Engine, deps: FunctionExecutionDependency + self, deps: FunctionExecutionDependency ) -> List[pd.DataFrame]: # Set workdir to the dataset path # Capture the current working directory cwd = os.getcwd() try: os.chdir(self.dataset) - res = await super()._execute_sqls(engine, deps) + res = await super()._execute_sqls(deps) finally: # Restore the working directory even if an error occurs os.chdir(cwd) diff --git a/packages/openassistants/openassistants/contrib/sqlalchemy_query.py b/packages/openassistants/openassistants/contrib/sqlalchemy_query.py index d5d11fb..7411ffb 100644 --- a/packages/openassistants/openassistants/contrib/sqlalchemy_query.py +++ b/packages/openassistants/openassistants/contrib/sqlalchemy_query.py @@ -1,3 +1,4 @@ +import abc import asyncio from typing import Annotated, Any, List, Literal, Sequence @@ -67,24 +68,18 @@ def _opas_to_summarization_lc( return lc_messages -class QueryFunction(BaseFunction): +class QueryFunction(BaseFunction, abc.ABC): type: Literal["QueryFunction"] = "QueryFunction" sqls: List[str] visualizations: List[str] summarization: str suggested_follow_ups: Annotated[List[SuggestedPrompt], Field(default_factory=list)] - _engine: Engine = PrivateAttr() + @abc.abstractmethod async def _execute_sqls( - self, engine: Engine, deps: FunctionExecutionDependency + self, deps: FunctionExecutionDependency ) -> List[pd.DataFrame]: - res: List[pd.DataFrame] = await asyncio.gather( # type: ignore - *[ - run_in_threadpool(run_sql, engine, sql, deps.arguments) - for sql in self.sqls - ] - ) - return res + pass async def _execute_visualizations( self, dfs: List[pd.DataFrame], deps: FunctionExecutionDependency @@ -93,7 +88,7 @@ async def _execute_visualizations( *[execute_visualization(viz, dfs) for viz in self.visualizations] ) - async def execute_summarization( + async def _execute_summarization( self, dfs: List[pd.DataFrame], deps: FunctionExecutionDependency ) -> AsyncStreamVersion[str]: chat_continued = [ @@ -159,7 +154,7 @@ async def execute( results: List[FunctionOutput] = [] - dataframes = await self._execute_sqls(self._engine, deps) + dataframes = await self._execute_sqls(deps) results.extend( [ DataFrameOutput(dataframe=SerializedDataFrame.from_pd(df)) @@ -180,7 +175,7 @@ async def execute( # Add summarization summarization_text = "" - async for summarization_text in self.execute_summarization(dataframes, deps): + async for summarization_text in self._execute_summarization(dataframes, deps): yield results + [TextOutput(text=summarization_text)] results.extend([TextOutput(text=summarization_text)]) @@ -205,3 +200,22 @@ async def execute( ) yield results + + +class SQLAlchemyFunction(QueryFunction, abc.ABC): + _engine: Engine = PrivateAttr() + + def __init__(self, engine: Engine, **kwargs): + super().__init__(**kwargs) + self._engine = engine + + async def _execute_sqls( + self, deps: FunctionExecutionDependency + ) -> List[pd.DataFrame]: + res: List[pd.DataFrame] = await asyncio.gather( # type: ignore + *[ + run_in_threadpool(run_sql, self._engine, sql, deps.arguments) + for sql in self.sqls + ] + ) + return res