diff --git a/src/sayvai_tools/tools/__init__.py b/src/sayvai_tools/tools/__init__.py index c0426fc..cec5a8f 100644 --- a/src/sayvai_tools/tools/__init__.py +++ b/src/sayvai_tools/tools/__init__.py @@ -28,6 +28,7 @@ from sayvai_tools.tools.youtube.comments import (ListCommentRepliesTool, ReplyToCommentTool) from sayvai_tools.tools.youtube.utils import get_youtube_credentials +from sayvai_tools.tools.collect_data_from_user import CollectUserDataTool, create_data_model __all__: List[str] = [ "ConversationalHuman", @@ -54,6 +55,8 @@ "ListCommentRepliesTool", "ReplyToCommentTool", "load_tools", + "CollectUserDataTool", + "create_data_model" ] diff --git a/src/sayvai_tools/tools/collect_data_from_user/__init__.py b/src/sayvai_tools/tools/collect_data_from_user/__init__.py new file mode 100644 index 0000000..1b298d2 --- /dev/null +++ b/src/sayvai_tools/tools/collect_data_from_user/__init__.py @@ -0,0 +1,6 @@ +from sayvai_tools.tools.collect_data_from_user.collect_data_from_user import CollectUserDataTool, create_data_model + +__all__ = [ + "CollectUserDataTool", + "create_data_model" +] diff --git a/src/sayvai_tools/tools/collect_data_from_user/collect_data_from_user.py b/src/sayvai_tools/tools/collect_data_from_user/collect_data_from_user.py new file mode 100644 index 0000000..208d143 --- /dev/null +++ b/src/sayvai_tools/tools/collect_data_from_user/collect_data_from_user.py @@ -0,0 +1,32 @@ +from typing import Type, Optional, Dict + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.tools import BaseTool +from pydantic import BaseModel, create_model, Field + + +def create_data_model(fields: Dict[str, str]) -> Type[BaseModel]: + """ + Dynamically creates a Pydantic model from the fields dictionary. + + :param fields: A dictionary where keys are field names and values are descriptions. + :return: A dynamically created Pydantic model. + """ + field_definitions = {field: (str, Field(description=desc)) for field, desc in fields.items()} + return create_model('UserData', **field_definitions) + + +class CollectUserDataTool(BaseTool): + name: str = "user_data_collector" + description: str = "Collects data from the user based on specified fields." + + args_schema: Type[BaseModel] + + def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None, **kwargs) -> bool: + try: + # Validate data using the dynamically created model + print("Collected and validated data:", kwargs) + return True + except Exception as e: + print("Error in data collection or validation:", str(e)) + return False