Skip to content

Commit

Permalink
Fix pydantic version comparison bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Nov 21, 2024
1 parent 2fb622d commit d76786c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
21 changes: 13 additions & 8 deletions tests/api/openapi/test_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
from autogen.agentchat import ConversableAgent
from fastapi import Body, FastAPI, Query
from packaging.version import Version
from pydantic import BaseModel, Field

from fastagency.api.openapi import OpenAPI
Expand Down Expand Up @@ -87,7 +88,7 @@ def openapi_schema(self, fastapi_app: FastAPI) -> dict[str, Any]:
return fastapi_app.openapi()

def test_openapi_schema(
self, openapi_schema: dict[str, Any], pydantic_version: float
self, openapi_schema: dict[str, Any], pydantic_version: Version
) -> None:
expected = {
"openapi": "3.1.0",
Expand Down Expand Up @@ -352,7 +353,7 @@ def test_openapi_schema(
},
}
pydantic28_delta = '{"paths": {"/items/{item_id}": {"put": {"requestBody": {"content": {"application/json": {"schema": {"allOf": [{"$$ref": "#/components/schemas/Item"}], "title": "Item", "$delete": ["$$ref"]}}}}}}, "/items/": {"post": {"requestBody": {"content": {"application/json": {"schema": {"allOf": [{"$$ref": "#/components/schemas/Item"}], "title": "Item", "$delete": ["$$ref"]}}}}}}}}'
if pydantic_version < 2.9:
if pydantic_version < Version("2.9"):
# print(f"pydantic28_delta = '{jsondiff.diff(expected, openapi_schema, dump=True)}'")
expected = jsondiff.patch(json.dumps(expected), pydantic28_delta, load=True)
# print(openapi_schema)
Expand All @@ -368,7 +369,7 @@ def generated_code_path(self, openapi_schema: dict[str, Any]) -> Iterator[Path]:
yield td

def test_generated_code_main(
self, generated_code_path: Path, pydantic_version: float
self, generated_code_path: Path, pydantic_version: Version
) -> None:
expected_pydantic_v28 = '''# generated by fastapi-codegen:
# filename: openapi.json
Expand Down Expand Up @@ -553,7 +554,9 @@ def delete_item_items__item_id__delete(
pass
'''
expected = (
expected_pydantic_v28 if pydantic_version < 2.9 else expected_pydantic_v29
expected_pydantic_v28
if pydantic_version < Version("2.9")
else expected_pydantic_v29
)
suffix = generated_code_path.name
expected = expected.replace("tmp61z6vu75", suffix)
Expand All @@ -569,7 +572,7 @@ def delete_item_items__item_id__delete(
assert main == expected

def test_generated_code_models(
self, generated_code_path: Path, pydantic_version: float
self, generated_code_path: Path, pydantic_version: Version
) -> None:
expected_pydantic_v28 = """# generated by fastapi-codegen:
# filename: openapi.json
Expand Down Expand Up @@ -681,7 +684,9 @@ class HTTPValidationError(BaseModel):
"""

expected = (
expected_pydantic_v28 if pydantic_version < 2.9 else expected_pydantic_v29
expected_pydantic_v28
if pydantic_version < Version("2.9")
else expected_pydantic_v29
)
assert generated_code_path.exists()
assert generated_code_path.is_dir()
Expand Down Expand Up @@ -721,7 +726,7 @@ def test_register_for_llm(
self,
client: OpenAPI,
azure_gpt35_turbo_16k_llm_config: dict[str, Any],
pydantic_version: float,
pydantic_version: Version,
) -> None:
class JSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
Expand Down Expand Up @@ -865,7 +870,7 @@ def default(self, o: Any) -> Any:
tools = agent.llm_config["tools"]
# print(tools)
pydantic28_delta = '{"0": {"function": {"parameters": {"properties": {"body": {"title": "ItemsPostRequest"}}}}}, "2": {"function": {"parameters": {"properties": {"body": {"title": "ItemsItemIdPutRequest"}}}}}}'
if pydantic_version < 2.9:
if pydantic_version < Version("2.9"):
# print(f"pydantic28_delta = '{jsondiff.diff(expected_tools, tools, dump=True)}'")
expected_tools = jsondiff.patch(
json.dumps(expected_tools, cls=JSONEncoder), pydantic28_delta, load=True
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
import uvicorn
from fastapi import FastAPI, Path
from packaging.version import Version
from pydantic import BaseModel
from pydantic import __version__ as version_of_pydantic

Expand Down Expand Up @@ -263,8 +264,8 @@ def fastapi_openapi_url(request: pytest.FixtureRequest) -> Iterator[str]:


@pytest.fixture
def pydantic_version() -> float:
return float(".".join(version_of_pydantic.split(".")[:2]))
def pydantic_version() -> Version:
return Version(version_of_pydantic)


################################################################################
Expand Down

0 comments on commit d76786c

Please sign in to comment.