This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

fix mypy errors #89

Closed
wants to merge 1 commit
233 changes: 136 additions & 97 deletions examples/fast-api-server/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/fast-api-server/pyproject.toml
@@ -22,6 +22,8 @@ mypy = "^1.7.1"
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
ruff = "^0.1.7"
pandas-stubs = "^2.1.4.231227"
types-jsonschema = "^4.20.0.0"

[build-system]
requires = ["poetry-core"]
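The two new dev dependencies are stub-only packages: pandas and jsonschema do not ship complete inline type information, so mypy needs pandas-stubs and types-jsonschema to check code that uses them. A minimal sketch of what the stubs enable, assuming pandas is used roughly as elsewhere in this PR:

import pandas as pd

# With pandas-stubs installed, mypy knows the DataFrame API and can check
# these annotations; without the stubs it typically reports a missing-stubs
# error for the pandas import (or falls back to treating pandas as Any).
df: pd.DataFrame = pd.DataFrame({"a": [1, 2, 3]})
total: int = int(df["a"].sum())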
3 changes: 0 additions & 3 deletions packages/openassistants/.mypy.ini

This file was deleted.

113 changes: 63 additions & 50 deletions packages/openassistants/openassistants/contrib/advisor_function.py
@@ -1,8 +1,8 @@
import os
from typing import Literal, Sequence

import httpx
import pandas as pd
import requests
from openassistants.data_models.function_output import (
DataFrameOutput,
FunctionOutput,
@@ -28,6 +28,12 @@ async def advisor_query(
BEARER_TOKEN = os.getenv("ADVISOR_BEARER_TOKEN")
ADVISOR_API_BASE = os.getenv("ADVISOR_API_BASE")

if ADVISOR_API_BASE is None:
raise ValueError("ADVISOR_API_BASE environment variable not set")

if BEARER_TOKEN is None:
raise ValueError("ADVISOR_BEARER_TOKEN environment variable not set")

# Headers for the requests
HEADERS = {
"accept": "application/json",
@@ -41,63 +47,70 @@ async def advisor_query(
outputs += "Searching for the relevant dataset... \n"
yield [TextOutput(text=outputs)]

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/recommend",
headers=HEADERS,
json=recommend_payload,
)
if response.status_code != 200:
print("Error in recommend step")
return
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/recommend",
headers=HEADERS,
json=recommend_payload,
)

datasets = response.json().get("datasets")
if not datasets:
# Override yield
yield [TextOutput(text="No relevant datasets found. \n")]
return
if response.status_code != 200:
print("Error in recommend step")
return

outputs += f"Found relevant dataset: `{datasets[0]['id']}`. \n"
yield [TextOutput(text=outputs)]
datasets = response.json().get("datasets")
if not datasets:
# Override yield
yield [TextOutput(text="No relevant datasets found. \n")]
return

# Step 2: Generate SQL based on the dataset and user query
prompt_payload = {"prompt": user_query, "dataset": datasets[0], "timeout": 60}
outputs += f"Found relevant dataset: `{datasets[0]['id']}`. \n"
yield [TextOutput(text=outputs)]

outputs += "Generating SQL query... \n"
yield [TextOutput(text=outputs)]
# Step 2: Generate SQL based on the dataset and user query
prompt_payload = {"prompt": user_query, "dataset": datasets[0], "timeout": 60}

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/prompt", headers=HEADERS, json=prompt_payload
)
if response.status_code != 200:
print("Error in prompt step")
return

sql_query = response.json().get("response", {}).get("sql")
if not sql_query:
print("No SQL query found in prompt response")
return

# Step 3: Execute SQL
execute_payload = {
"query": {"type": "sql", "sql": sql_query},
"dataset": datasets[0],
"timeout": 60,
"mode": "nonblocking",
}
outputs += "Generating SQL query... \n"
yield [TextOutput(text=outputs)]

outputs += "Running SQL query... \n"
yield [TextOutput(text=outputs)]
response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/prompt",
headers=HEADERS,
json=prompt_payload,
)
if response.status_code != 200:
print("Error in prompt step")
return

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/execute", headers=HEADERS, json=execute_payload
)
if response.status_code != 200 or "error" in response.json():
print("Error in execute step")
return
sql_query = response.json().get("response", {}).get("sql")
if not sql_query:
print("No SQL query found in prompt response")
return

# Step 3: Execute SQL
execute_payload = {
"query": {"type": "sql", "sql": sql_query},
"dataset": datasets[0],
"timeout": 60,
"mode": "nonblocking",
}

outputs += "Running SQL query... \n"
yield [TextOutput(text=outputs)]

response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/execute",
headers=HEADERS,
json=execute_payload,
)

if response.status_code != 200 or "error" in response.json():
print("Error in execute step")
return

# Parsing the result into a DataFrame
result = response.json().get("result", {}).get("dataframe", {})
error = response.json().get("error")
# Parsing the result into a DataFrame
result = response.json().get("result", {}).get("dataframe", {})
error = response.json().get("error")

if error:
yield [TextOutput(text=f"Failed to execute SQL: {error['message']}")]
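The rewrite above swaps three blocking requests.post calls for a single httpx.AsyncClient reused across the recommend, prompt, and execute steps, and adds explicit None checks so mypy can narrow the two environment variables from Optional[str] to str before they are concatenated into URLs. A minimal sketch of that pattern, with the header construction and payload shown only illustratively:

import os

import httpx


async def call_advisor_endpoint() -> dict:
    # os.getenv returns Optional[str]; raising on None narrows the type to
    # str for the rest of the function, which is what the added guards do.
    api_base = os.getenv("ADVISOR_API_BASE")
    token = os.getenv("ADVISOR_BEARER_TOKEN")
    if api_base is None:
        raise ValueError("ADVISOR_API_BASE environment variable not set")
    if token is None:
        raise ValueError("ADVISOR_BEARER_TOKEN environment variable not set")

    headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}

    # One async client with a generous timeout, reused for every request in
    # the sequence, mirroring the `async with httpx.AsyncClient(timeout=120)`
    # block in the diff.
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            api_base + "/v0/datasets/recommend",
            headers=headers,
            json={"prompt": "example user query"},  # hypothetical payload
        )
        if response.status_code != 200:
            raise RuntimeError("Error in recommend step")
        return response.json()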
19 changes: 12 additions & 7 deletions packages/openassistants/openassistants/contrib/text_response.py
@@ -1,24 +1,27 @@
from typing import Annotated, List, Literal, Sequence
import pandas as pd

from openassistants.data_models.function_output import FunctionOutput,TextOutput,SuggestedPrompt,FollowUpsOutput
import pandas as pd
from openassistants.data_models.function_output import (
FollowUpsOutput,
FunctionOutput,
SuggestedPrompt,
TextOutput,
)
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.utils import AsyncStreamVersion
from openassistants.utils.strings import resolve_str_template
from pydantic import Field


class TextResponseFunction(BaseFunction):
type: Literal["TextResponseFunction"] = "TextRespnseFunction"
type: Literal["TextResponseFunction"] = "TextResponseFunction"
text_response: str
suggested_follow_ups: Annotated[List[SuggestedPrompt], Field(default_factory=list)]

async def execute(
self, deps: FunctionExecutionDependency
) -> AsyncStreamVersion[Sequence[FunctionOutput]]:

results: List[FunctionOutput] = []


results.extend([TextOutput(text=self.text_response)])
yield results
@@ -29,7 +32,9 @@ async def execute(
FollowUpsOutput(
follow_ups=[
SuggestedPrompt(
title=resolve_str_template(template.title, dfs=pd.DataFrame()),
title=resolve_str_template(
template.title, dfs=pd.DataFrame()
),
prompt=resolve_str_template(
template.prompt, dfs=pd.DataFrame()
),
@@ -39,4 +44,4 @@
)
]
)
yield results
yield results
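Besides the import cleanup, the substantive fix in this file is the type tag: the field is annotated Literal["TextResponseFunction"], so the misspelled default "TextRespnseFunction" is not a member of that Literal and mypy rejects the assignment (and a pydantic discriminator keyed on type could never match the class). A minimal sketch of the failure mode on a stripped-down model:

from typing import Literal

from pydantic import BaseModel


class TextResponseFunction(BaseModel):
    # mypy rejects the old default because "TextRespnseFunction" is not a
    # member of Literal["TextResponseFunction"]:
    # type: Literal["TextResponseFunction"] = "TextRespnseFunction"

    # The corrected default is the only value the annotation allows.
    type: Literal["TextResponseFunction"] = "TextResponseFunction"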
@@ -1,8 +1,9 @@
from typing import Annotated, Any, Literal

from openassistants.data_models.serialized_dataframe import SerializedDataFrame
from pydantic import BaseModel, Field

from openassistants.data_models.serialized_dataframe import SerializedDataFrame


class SuggestedPrompt(BaseModel):
title: str
16 changes: 13 additions & 3 deletions packages/openassistants/openassistants/functions/crud.py
@@ -3,7 +3,17 @@
import json
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn
from langchain_community.utilities.openapi import OpenAPISpec
@@ -74,7 +84,7 @@ def read(self, function_id: str) -> Optional[BaseFunction]:
except Exception as e:
raise RuntimeError(f"Failed to load: {function_id}") from e

async def aread_all(self) -> List[BaseFunction]:
async def aread_all(self) -> List[IBaseFunction]:
ids = self.list_ids()
return [self.read(f_id) for f_id in ids] # type: ignore

@@ -85,7 +95,7 @@ def list_ids(self) -> List[str]:


class PythonCRUD(FunctionCRUD):
def __init__(self, functions: List[IBaseFunction]):
def __init__(self, functions: Sequence[IBaseFunction]):
self.functions = functions

def read(self, slug: str) -> Optional[IBaseFunction]:
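Two related typing changes in this file: aread_all is now declared to return List[IBaseFunction], the interface type used elsewhere in the module, and PythonCRUD accepts Sequence[IBaseFunction] instead of List[IBaseFunction]. The second change matters because List is invariant while Sequence is covariant and read-only, so callers can pass a list of any concrete function type. A minimal sketch with hypothetical classes:

from typing import List, Sequence


class IBaseFunctionExample:        # stand-in for the IBaseFunction interface
    ...


class QueryFunctionExample(IBaseFunctionExample):   # stand-in concrete type
    ...


def takes_list(functions: List[IBaseFunctionExample]) -> None:
    ...


def takes_sequence(functions: Sequence[IBaseFunctionExample]) -> None:
    ...


concrete: List[QueryFunctionExample] = [QueryFunctionExample()]
# takes_list(concrete)     # mypy error: List is invariant in its element type
takes_sequence(concrete)   # accepted: Sequence[T] is covariant in T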
@@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


async def last_value(src: AsyncStreamVersion) -> T:
async def last_value(src: AsyncStreamVersion[T]) -> T:
last = None
try:
async for last in src:
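Parameterizing the annotation as AsyncStreamVersion[T] ties the stream's element type to the declared return type T, so mypy can infer T at each call site instead of leaving the type variable unbound. A minimal sketch under the assumption that AsyncStreamVersion is a generic alias for an async iterator; the real helper returns T and guards the empty-stream case separately, so Optional[T] here only keeps the sketch self-contained:

from typing import AsyncIterator, Optional, TypeVar

T = TypeVar("T")

# Assumption for this sketch: AsyncStreamVersion[T] is an async iterator of T.
AsyncStreamVersion = AsyncIterator


async def last_value(src: AsyncStreamVersion[T]) -> Optional[T]:
    # With the parameterized parameter type, mypy links the values yielded by
    # `src` to T; with the bare alias, T never gets bound and the return type
    # is effectively unchecked.
    last: Optional[T] = None
    async for last in src:
        pass
    return last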
10 changes: 5 additions & 5 deletions packages/openassistants/poetry.lock

Some generated files are not rendered by default.

22 changes: 21 additions & 1 deletion packages/openassistants/pyproject.toml
@@ -24,7 +24,7 @@ tabulate = "^0.9.0"
pydantic = "^2.5.2"
langchain-community = "^0.0.6"
openapi-pydantic = "^0.3.2"
requests = "^2.31.0"
httpx = "^0.26.0"

[tool.poetry.extras]
sql = ["sqlalchemy"]
@@ -42,6 +42,26 @@ coverage = "^7.3.2"
types-pyyaml = "^6.0.12.12"
ruff = "^0.1.7"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.mypy]
python_version = "3.10"
plugins = [
"pydantic.v1.mypy",
"pydantic.mypy"
]
files = [
"openassistants",
"tests"
]

[[tool.mypy.overrides]]
module = [
"ruamel",
"duckduckgo_search",
"duckduckgo_search.exceptions"
]
ignore_missing_imports = true
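The new [tool.mypy] table replaces the deleted .mypy.ini: it pins analysis to Python 3.10, limits checking to the openassistants and tests packages, enables the pydantic mypy plugins, and ignores missing imports only for the untyped third-party modules listed in the override. With the pydantic plugin active, mypy understands the constructor that BaseModel synthesizes, so field errors are caught statically; a minimal sketch with a hypothetical model, not one from this repo:

from pydantic import BaseModel


class Dataset(BaseModel):   # hypothetical model, for illustration only
    id: str
    timeout: int = 60


ok = Dataset(id="abc")            # accepted: required field provided
# bad = Dataset(id=123)           # flagged by the plugin: "id" expects str
# missing = Dataset(timeout=30)   # flagged by the plugin: "id" is required

Running mypy from packages/openassistants should now pick up both the plugins and the files list from pyproject.toml directly, with no separate .mypy.ini.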