Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 committed Jan 5, 2024
1 parent 7875f4b commit 6d47661
Show file tree
Hide file tree
Showing 12 changed files with 718 additions and 605 deletions.
1,121 changes: 590 additions & 531 deletions examples/fast-api-server/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/fast-api-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ mypy = "^1.7.1"
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
ruff = "^0.1.7"
pandas-stubs = "^2.1.4.231227"
types-jsonschema = "^4.20.0.0"

[build-system]
requires = ["poetry-core"]
Expand Down
3 changes: 0 additions & 3 deletions packages/openassistants/.mypy.ini

This file was deleted.

113 changes: 63 additions & 50 deletions packages/openassistants/openassistants/contrib/advisor_function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from typing import Literal, Sequence

import httpx
import pandas as pd
import requests
from openassistants.data_models.function_output import (
DataFrameOutput,
FunctionOutput,
Expand All @@ -28,6 +28,12 @@ async def advisor_query(
BEARER_TOKEN = os.getenv("ADVISOR_BEARER_TOKEN")
ADVISOR_API_BASE = os.getenv("ADVISOR_API_BASE")

if ADVISOR_API_BASE is None:
raise ValueError("ADVISOR_API_BASE environment variable not set")

if BEARER_TOKEN is None:
raise ValueError("ADVISOR_BEARER_TOKEN environment variable not set")

# Headers for the requests
HEADERS = {
"accept": "application/json",
Expand All @@ -41,63 +47,70 @@ async def advisor_query(
outputs += "Searching for the relevant dataset... \n"
yield [TextOutput(text=outputs)]

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/recommend",
headers=HEADERS,
json=recommend_payload,
)
if response.status_code != 200:
print("Error in recommend step")
return
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/recommend",
headers=HEADERS,
json=recommend_payload,
)

datasets = response.json().get("datasets")
if not datasets:
# Override yield
yield [TextOutput(text="No relevant datasets found. \n")]
return
if response.status_code != 200:
print("Error in recommend step")
return

outputs += f"Found relevant dataset: `{datasets[0]['id']}`. \n"
yield [TextOutput(text=outputs)]
datasets = response.json().get("datasets")
if not datasets:
# Override yield
yield [TextOutput(text="No relevant datasets found. \n")]
return

# Step 2: Generate SQL based on the dataset and user query
prompt_payload = {"prompt": user_query, "dataset": datasets[0], "timeout": 60}
outputs += f"Found relevant dataset: `{datasets[0]['id']}`. \n"
yield [TextOutput(text=outputs)]

outputs += "Generating SQL query... \n"
yield [TextOutput(text=outputs)]
# Step 2: Generate SQL based on the dataset and user query
prompt_payload = {"prompt": user_query, "dataset": datasets[0], "timeout": 60}

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/prompt", headers=HEADERS, json=prompt_payload
)
if response.status_code != 200:
print("Error in prompt step")
return

sql_query = response.json().get("response", {}).get("sql")
if not sql_query:
print("No SQL query found in prompt response")
return

# Step 3: Execute SQL
execute_payload = {
"query": {"type": "sql", "sql": sql_query},
"dataset": datasets[0],
"timeout": 60,
"mode": "nonblocking",
}
outputs += "Generating SQL query... \n"
yield [TextOutput(text=outputs)]

outputs += "Running SQL query... \n"
yield [TextOutput(text=outputs)]
response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/prompt",
headers=HEADERS,
json=prompt_payload,
)
if response.status_code != 200:
print("Error in prompt step")
return

response = requests.post(
ADVISOR_API_BASE + "/v0/datasets/execute", headers=HEADERS, json=execute_payload
)
if response.status_code != 200 or "error" in response.json():
print("Error in execute step")
return
sql_query = response.json().get("response", {}).get("sql")
if not sql_query:
print("No SQL query found in prompt response")
return

# Step 3: Execute SQL
execute_payload = {
"query": {"type": "sql", "sql": sql_query},
"dataset": datasets[0],
"timeout": 60,
"mode": "nonblocking",
}

outputs += "Running SQL query... \n"
yield [TextOutput(text=outputs)]

response = await client.post(
ADVISOR_API_BASE + "/v0/datasets/execute",
headers=HEADERS,
json=execute_payload,
)

if response.status_code != 200 or "error" in response.json():
print("Error in execute step")
return

# Parsing the result into a DataFrame
result = response.json().get("result", {}).get("dataframe", {})
error = response.json().get("error")
# Parsing the result into a DataFrame
result = response.json().get("result", {}).get("dataframe", {})
error = response.json().get("error")

if error:
yield [TextOutput(text=f"Failed to execute SQL: {error['message']}")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from typing import Dict, List, Literal, Optional, Sequence

from duckduckgo_search.exceptions import RateLimitException
from openassistants.data_models.function_output import FunctionOutput, TextOutput
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.utils import AsyncStreamVersion
from duckduckgo_search.exceptions import RateLimitException


def ddgs_text(query: str, max_results: Optional[int] = None) -> List[Dict[str, str]]:
"""Run query through DuckDuckGo text search and return results."""
Expand Down Expand Up @@ -48,7 +49,11 @@ async def execute(
try:
results = ddgs_text(query, max_results=4)
except RateLimitException:
yield [TextOutput(text="DuckDuckGo is currently unavailable, please try again later.")]
yield [
TextOutput(
text="DuckDuckGo is currently unavailable, please try again later." # noqa: E501
)
]
return

formatted_results = "\n\n".join(
Expand Down
19 changes: 12 additions & 7 deletions packages/openassistants/openassistants/contrib/text_response.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
from typing import Annotated, List, Literal, Sequence
import pandas as pd

from openassistants.data_models.function_output import FunctionOutput,TextOutput,SuggestedPrompt,FollowUpsOutput
import pandas as pd
from openassistants.data_models.function_output import (
FollowUpsOutput,
FunctionOutput,
SuggestedPrompt,
TextOutput,
)
from openassistants.functions.base import BaseFunction, FunctionExecutionDependency
from openassistants.functions.utils import AsyncStreamVersion
from openassistants.utils.strings import resolve_str_template
from pydantic import Field


class TextResponseFunction(BaseFunction):
type: Literal["TextResponseFunction"] = "TextRespnseFunction"
type: Literal["TextResponseFunction"] = "TextResponseFunction"
text_response: str
suggested_follow_ups: Annotated[List[SuggestedPrompt], Field(default_factory=list)]

async def execute(
self, deps: FunctionExecutionDependency
) -> AsyncStreamVersion[Sequence[FunctionOutput]]:

results: List[FunctionOutput] = []


results.extend([TextOutput(text=self.text_response)])
yield results
Expand All @@ -29,7 +32,9 @@ async def execute(
FollowUpsOutput(
follow_ups=[
SuggestedPrompt(
title=resolve_str_template(template.title, dfs=pd.DataFrame()),
title=resolve_str_template(
template.title, dfs=pd.DataFrame()
),
prompt=resolve_str_template(
template.prompt, dfs=pd.DataFrame()
),
Expand All @@ -39,4 +44,4 @@ async def execute(
)
]
)
yield results
yield results
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Annotated, List, Literal, Optional

from langchain.schema.messages import BaseMessage, merge_content
from pydantic import BaseModel, Field

from openassistants.data_models.function_input import FunctionCall, FunctionInputRequest
from openassistants.data_models.function_output import FunctionOutput
from pydantic import BaseModel, Field


class OpasUserMessage(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Annotated, Any, Literal

from openassistants.data_models.serialized_dataframe import SerializedDataFrame
from pydantic import BaseModel, Field

from openassistants.data_models.serialized_dataframe import SerializedDataFrame


class SuggestedPrompt(BaseModel):
title: str
Expand Down
16 changes: 13 additions & 3 deletions packages/openassistants/openassistants/functions/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
import json
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

from langchain.chains.openai_functions.openapi import openapi_spec_to_openai_fn
from langchain_community.utilities.openapi import OpenAPISpec
Expand Down Expand Up @@ -74,7 +84,7 @@ def read(self, function_id: str) -> Optional[BaseFunction]:
except Exception as e:
raise RuntimeError(f"Failed to load: {function_id}") from e

async def aread_all(self) -> List[BaseFunction]:
async def aread_all(self) -> List[IBaseFunction]:
ids = self.list_ids()
return [self.read(f_id) for f_id in ids] # type: ignore

Expand All @@ -85,7 +95,7 @@ def list_ids(self) -> List[str]:


class PythonCRUD(FunctionCRUD):
def __init__(self, functions: List[IBaseFunction]):
def __init__(self, functions: Sequence[IBaseFunction]):
self.functions = functions

def read(self, slug: str) -> Optional[IBaseFunction]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


async def last_value(src: AsyncStreamVersion) -> T:
async def last_value(src: AsyncStreamVersion[T]) -> T:
last = None
try:
async for last in src:
Expand Down
10 changes: 5 additions & 5 deletions packages/openassistants/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 21 additions & 1 deletion packages/openassistants/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ tabulate = "^0.9.0"
pydantic = "^2.5.2"
langchain-community = "^0.0.6"
openapi-pydantic = "^0.3.2"
requests = "^2.31.0"
httpx = "^0.26.0"

[tool.poetry.extras]
sql = ["sqlalchemy"]
Expand All @@ -42,6 +42,26 @@ coverage = "^7.3.2"
types-pyyaml = "^6.0.12.12"
ruff = "^0.1.7"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.mypy]
python_version = "3.10"
plugins = [
"pydantic.v1.mypy",
"pydantic.mypy"
]
files = [
"openassistants",
"tests"
]

[[tool.mypy.overrides]]
module = [
"ruamel",
"duckduckgo_search",
"duckduckgo_search.exceptions"
]
ignore_missing_imports = true

0 comments on commit 6d47661

Please sign in to comment.