Skip to content

Commit

Permalink
Merge pull request #1206 from julep-ai/x/none-model-validation
Browse files Browse the repository at this point in the history
fix(agents-api): fix ``model=None`` case for chat endpoint
  • Loading branch information
Ahmad-mtos authored Mar 3, 2025
2 parents e849e5c + b2599e5 commit 3aaf949
Show file tree
Hide file tree
Showing 5 changed files with 540 additions and 526 deletions.
5 changes: 2 additions & 3 deletions agents-api/agents_api/routers/sessions/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ async def render_chat_input(
session_id: UUID,
chat_input: ChatInput,
) -> tuple[list[dict], list[DocReference], list[dict] | None, dict, list[dict], ChatContext]:
if chat_input.model:
await validate_model(chat_input.model)

# check if the developer is paid
if "paid" not in developer.tags:
# get the session length
Expand All @@ -83,6 +80,8 @@ async def render_chat_input(
chat_context.merge_settings(chat_input)
settings: dict = chat_context.settings or {}

await validate_model(settings.get("model"))

# Get the past messages and doc references
past_messages, doc_references = await gather_messages(
developer=developer,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/utils/model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ...clients.litellm import get_model_list


async def validate_model(model_name: str) -> None:
async def validate_model(model_name: str | None) -> None:
"""
Validates if a given model name is available in LiteLLM.
Raises HTTPException if model is not available.
Expand Down
15 changes: 13 additions & 2 deletions agents-api/tests/test_model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tests.fixtures import SAMPLE_MODELS


@test("validate_model: succeeds when model is available")
@test("validate_model: succeeds when model is available in model list")
async def _():
# Use async context manager for patching
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
Expand All @@ -16,7 +16,7 @@ async def _():
mock_get_models.assert_called_once()


@test("validate_model: fails when model is unavailable")
@test("validate_model: fails when model is unavailable in model list")
async def _():
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
mock_get_models.return_value = SAMPLE_MODELS
Expand All @@ -26,3 +26,14 @@ async def _():
assert exc.raised.status_code == 400
assert "Model non-existent-model not available" in exc.raised.detail
mock_get_models.assert_called_once()


@test("validate_model: fails when model is None")
async def _():
with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models:
mock_get_models.return_value = SAMPLE_MODELS
with raises(HTTPException) as exc:
await validate_model(None)

assert exc.raised.status_code == 400
assert "Model None not available" in exc.raised.detail
Loading

0 comments on commit 3aaf949

Please sign in to comment.