diff --git a/src/funcchain/components.py b/src/funcchain/components.py index 3e73866..4654dcf 100644 --- a/src/funcchain/components.py +++ b/src/funcchain/components.py @@ -15,6 +15,9 @@ class Route(TypedDict): class ChatRouter(BaseModel): routes: Routes + class Config: + arbitrary_types_allowed = True + @field_validator("routes") def validate_routes(cls, v: Routes) -> Routes: if "default" not in v.keys(): diff --git a/tests/router_test.py b/tests/router_test.py new file mode 100644 index 0000000..67b38a6 --- /dev/null +++ b/tests/router_test.py @@ -0,0 +1,40 @@ +from funcchain.components import ChatRouter + + +def handle_pdf_requests(user_query: str) -> str: + return f"Handling PDF requests with user query: {user_query}" + + +def handle_csv_requests(user_query: str) -> str: + return f"Handling CSV requests with user query: {user_query}" + + +def handle_default_requests(user_query: str) -> str: + return f"Handling DEFAULT requests with user query: {user_query}" + + +router = ChatRouter( + routes={ + "pdf": { + "handler": handle_pdf_requests, + "description": "Call this for requests including PDF Files.", + }, + "csv": { + "handler": handle_csv_requests, + "description": "Call this for requests including CSV Files.", + }, + "default": handle_default_requests, + }, +) + + +def test_router() -> None: + assert "Handling CSV" in router.invoke_route("Can you summarize this csv?") + + assert "Handling PDF" in router.invoke_route("Can you summarize this pdf?") + + assert "Handling DEFAULT" in router.invoke_route("Hey, whatsup?") + + +if __name__ == "__main__": + test_router()