From 1f649274d1f1a7f7f0d94d52b76602b166cdd1bf Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 5 Dec 2024 22:44:40 +0000 Subject: [PATCH 01/15] feat: tooling init --- application/cache.py | 18 ++--- application/llm/base.py | 14 ++-- application/llm/openai.py | 22 +++++-- application/requirements.txt | 2 +- application/retriever/classic_rag.py | 11 ++-- application/tools/agent.py | 98 ++++++++++++++++++++++++++++ application/tools/base.py | 20 ++++++ application/tools/cryptoprice.py | 73 +++++++++++++++++++++ application/tools/telegram.py | 79 ++++++++++++++++++++++ application/tools/tool_manager.py | 43 ++++++++++++ application/usage.py | 19 ++++-- application/utils.py | 18 ++++- 12 files changed, 384 insertions(+), 33 deletions(-) create mode 100644 application/tools/agent.py create mode 100644 application/tools/base.py create mode 100644 application/tools/cryptoprice.py create mode 100644 application/tools/telegram.py create mode 100644 application/tools/tool_manager.py diff --git a/application/cache.py b/application/cache.py index 33022e45f..7239abacd 100644 --- a/application/cache.py +++ b/application/cache.py @@ -5,6 +5,7 @@ from threading import Lock from application.core.settings import settings from application.utils import get_hash +import sys logger = logging.getLogger(__name__) @@ -23,18 +24,19 @@ def get_redis_instance(): _redis_instance = None return _redis_instance -def gen_cache_key(*messages, model="docgpt"): +def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") - messages_str = json.dumps(list(messages), sort_keys=True) - combined = f"{model}_{messages_str}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + combined = f"{model}_{messages_str}_{tools_str}" cache_key = get_hash(combined) return cache_key def gen_cache(func): - def wrapper(self, model, messages, *args, **kwargs): + def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages, model, tools) redis_client = get_redis_instance() if redis_client: try: @@ -44,8 +46,8 @@ def wrapper(self, model, messages, *args, **kwargs): except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - result = func(self, model, messages, *args, **kwargs) - if redis_client: + result = func(self, model, messages, stream, tools, *args, **kwargs) + if redis_client and isinstance(result, str): try: redis_client.set(cache_key, result, ex=1800) except redis.ConnectionError as e: @@ -59,7 +61,7 @@ def wrapper(self, model, messages, *args, **kwargs): def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") redis_client = get_redis_instance() diff --git a/application/llm/base.py b/application/llm/base.py index 1caab5d38..b9b0e5243 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -13,12 +13,12 @@ def _apply_decorator(self, method, decorators, *args, **kwargs): return method(self, *args, **kwargs) @abstractmethod - def _raw_gen(self, model, messages, stream, *args, **kwargs): + def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, *args, **kwargs): + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): decorators = [gen_token_usage, gen_cache] - return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs) @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): @@ -26,4 +26,10 @@ def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): def gen_stream(self, model, messages, stream=True, *args, **kwargs): decorators = [stream_cache, stream_token_usage] - return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) \ No newline at end of file + return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + + def supports_tools(self): + return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools')) + + def _supports_tools(self): + raise NotImplementedError("Subclass must implement _supports_tools method") \ No newline at end of file diff --git a/application/llm/openai.py b/application/llm/openai.py index f85de6eae..cc2285a12 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -25,14 +25,20 @@ def _raw_gen( model, messages, stream=False, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - return response.choices[0].message.content + ): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + return response.choices[0].message.content def _raw_gen_stream( self, @@ -40,6 +46,7 @@ def _raw_gen_stream( model, messages, stream=True, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs ): @@ -52,6 +59,9 @@ def _raw_gen_stream( # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content + + def _supports_tools(self): + return True class AzureOpenAILLM(OpenAILLM): diff --git a/application/requirements.txt b/application/requirements.txt index 2f28c2ea6..c8f16d85e 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -43,7 +43,7 @@ multidict==6.1.0 mypy-extensions==1.0.0 networkx==3.3 numpy==1.26.4 -openai==1.46.1 +openai==1.57.0 openapi-schema-validator==0.6.2 openapi-spec-validator==0.6.0 openapi3-parser==1.1.18 diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 42e318d20..4ac52bc51 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -2,6 +2,7 @@ from application.core.settings import settings from application.vectorstore.vector_creator import VectorCreator from application.llm.llm_creator import LLMCreator +from application.tools.agent import Agent from application.utils import num_tokens_from_string @@ -90,10 +91,12 @@ def gen(self): ) messages_combine.append({"role": "user", "content": self.question}) - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key - ) - completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + # llm = LLMCreator.create_llm( + # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + # ) + # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key) + completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} diff --git a/application/tools/agent.py b/application/tools/agent.py new file mode 100644 index 000000000..2df14442b --- /dev/null +++ b/application/tools/agent.py @@ -0,0 +1,98 @@ +from application.llm.llm_creator import LLMCreator +from application.core.settings import settings +from application.tools.tool_manager import ToolManager +import json + +tool_tg = { + "name": "telegram_send_message", + "description": "Send a notification to telegram about current chat", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text to send in the notification" + } + }, + "required": ["text"], + "additionalProperties": False + } +} + +tool_crypto = { + "name": "cryptoprice_get", + "description": "Retrieve the price of a specified cryptocurrency in a given currency", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The cryptocurrency symbol (e.g. BTC)" + }, + "currency": { + "type": "string", + "description": "The currency in which you want the price (e.g. USD)" + } + }, + "required": ["symbol", "currency"], + "additionalProperties": False + } +} + +class Agent: + def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + # Initialize the LLM with the provided parameters + self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key) + self.gpt_model = gpt_model + # Static tool configuration (to be replaced later) + self.tools = [ + { + "type": "function", + "function": tool_crypto + } + ] + self.tool_config = { + } + + def gen(self, messages): + # Generate initial response from the LLM + resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + + if isinstance(resp, str): + # Yield the response if it's a string and exit + yield resp + return + + while resp.finish_reason == "tool_calls": + # Append the assistant's message to the conversation + messages.append(json.loads(resp.model_dump_json())['message']) + # Handle each tool call + tool_calls = resp.message.tool_calls + for call in tool_calls: + tm = ToolManager(config={}) + call_name = call.function.name + call_args = json.loads(call.function.arguments) + call_id = call.id + # Determine the tool name and load it + tool_name = call_name.split("_")[0] + tool = tm.load_tool(tool_name, tool_config=self.tool_config) + # Execute the tool's action + resp_tool = tool.execute_action(call_name, **call_args) + # Append the tool's response to the conversation + messages.append( + { + "role": "tool", + "content": str(resp_tool), + "tool_call_id": call_id + } + ) + # Generate a new response from the LLM after processing tools + resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + + # If no tool calls are needed, generate the final response + if isinstance(resp, str): + yield resp + else: + completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) + for line in completion: + yield line diff --git a/application/tools/base.py b/application/tools/base.py new file mode 100644 index 000000000..00cfee3a0 --- /dev/null +++ b/application/tools/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod + +class Tool(ABC): + @abstractmethod + def execute_action(self, action_name: str, **kwargs): + pass + + @abstractmethod + def get_actions_metadata(self): + """ + Returns a list of JSON objects describing the actions supported by the tool. + """ + pass + + @abstractmethod + def get_config_requirements(self): + """ + Returns a dictionary describing the configuration requirements for the tool. + """ + pass diff --git a/application/tools/cryptoprice.py b/application/tools/cryptoprice.py new file mode 100644 index 000000000..d7cf61e13 --- /dev/null +++ b/application/tools/cryptoprice.py @@ -0,0 +1,73 @@ +from application.tools.base import Tool +import requests + +class CryptoPriceTool(Tool): + def __init__(self, config): + self.config = config + + def execute_action(self, action_name, **kwargs): + actions = { + "cryptoprice_get": self.get_price + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def get_price(self, symbol, currency): + """ + Fetches the current price of a given cryptocurrency symbol in the specified currency. + Example: + symbol = "BTC" + currency = "USD" + returns price in USD. + """ + url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + # data will be like {"USD": } if the call is successful + if currency.upper() in data: + return { + "status_code": response.status_code, + "price": data[currency.upper()], + "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully." + } + else: + return { + "status_code": response.status_code, + "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}." + } + else: + return { + "status_code": response.status_code, + "message": "Failed to retrieve price." + } + + def get_actions_metadata(self): + return [ + { + "name": "cryptoprice_get", + "description": "Retrieve the price of a specified cryptocurrency in a given currency", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The cryptocurrency symbol (e.g. BTC)" + }, + "currency": { + "type": "string", + "description": "The currency in which you want the price (e.g. USD)" + } + }, + "required": ["symbol", "currency"], + "additionalProperties": False + } + } + ] + + def get_config_requirements(self): + # No specific configuration needed for this tool as it just queries a public endpoint + return {} diff --git a/application/tools/telegram.py b/application/tools/telegram.py new file mode 100644 index 000000000..8210d8e71 --- /dev/null +++ b/application/tools/telegram.py @@ -0,0 +1,79 @@ +from application.tools.base import Tool +import requests + +class TelegramTool(Tool): + def __init__(self, config): + self.config = config + self.chat_id = config.get("chat_id", "142189016") + self.token = config.get("token", "YOUR_TG_TOKEN") + + def execute_action(self, action_name, **kwargs): + actions = { + "telegram_send_message": self.send_message, + "telegram_send_image": self.send_image + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def send_message(self, text): + print(f"Sending message: {text}") + url = f"https://api.telegram.org/bot{self.token}/sendMessage" + payload = {"chat_id": self.chat_id, "text": text} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Message sent"} + + def send_image(self, image_url): + print(f"Sending image: {image_url}") + url = f"https://api.telegram.org/bot{self.token}/sendPhoto" + payload = {"chat_id": self.chat_id, "photo": image_url} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Image sent"} + + def get_actions_metadata(self): + return [ + { + "name": "telegram_send_message", + "description": "Send a notification to telegram chat", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text to send in the notification" + } + }, + "required": ["text"], + "additionalProperties": False + } + }, + { + "name": "telegram_send_image", + "description": "Send an image to the Telegram chat", + "parameters": { + "type": "object", + "properties": { + "image_url": { + "type": "string", + "description": "URL of the image to send" + } + }, + "required": ["image_url"], + "additionalProperties": False + } + } + ] + + def get_config_requirements(self): + return { + "chat_id": { + "type": "string", + "description": "Telegram chat ID to send messages to" + }, + "token": { + "type": "string", + "description": "Bot token for authentication" + } + } diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py new file mode 100644 index 000000000..10231cb28 --- /dev/null +++ b/application/tools/tool_manager.py @@ -0,0 +1,43 @@ +import importlib +import inspect +import pkgutil +import os + +from application.tools.base import Tool + +class ToolManager: + def __init__(self, config): + self.config = config + self.tools = {} + self.load_tools() + + def load_tools(self): + tools_dir = os.path.dirname(__file__) + for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): + if name == 'base' or name.startswith('__'): + continue + module = importlib.import_module(f'application.tools.{name}') + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + tool_config = self.config.get(name, {}) + self.tools[name] = obj(tool_config) + + def load_tool(self, tool_name, tool_config): + self.config[tool_name] = tool_config + tools_dir = os.path.dirname(__file__) + module = importlib.import_module(f'application.tools.{tool_name}') + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + return obj(tool_config) + + + def execute_action(self, tool_name, action_name, **kwargs): + if tool_name not in self.tools: + raise ValueError(f"Tool '{tool_name}' not loaded") + return self.tools[tool_name].execute_action(action_name, **kwargs) + + def get_all_actions_metadata(self): + metadata = [] + for tool in self.tools.values(): + metadata.extend(tool.get_actions_metadata()) + return metadata diff --git a/application/usage.py b/application/usage.py index e87ebe385..fe4cd50e6 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,7 @@ import sys from datetime import datetime from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string +from application.utils import num_tokens_from_string, num_tokens_from_object_or_list mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage): def gen_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) - result = func(self, model, messages, stream, **kwargs) - self.token_usage["generated_tokens"] += num_tokens_from_string(result) + if message["content"]: + self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + result = func(self, model, messages, stream, tools, **kwargs) + # check if result is a string + if isinstance(result, str): + self.token_usage["generated_tokens"] += num_tokens_from_string(result) + else: + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) update_token_usage(self.user_api_key, self.token_usage) return result @@ -33,11 +38,11 @@ def wrapper(self, model, messages, stream, **kwargs): def stream_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) batch = [] - result = func(self, model, messages, stream, **kwargs) + result = func(self, model, messages, stream, tools, **kwargs) for r in result: batch.append(r) yield r diff --git a/application/utils.py b/application/utils.py index 1fc9e3291..3b2eb9f30 100644 --- a/application/utils.py +++ b/application/utils.py @@ -15,9 +15,21 @@ def get_encoding(): def num_tokens_from_string(string: str) -> int: encoding = get_encoding() - num_tokens = len(encoding.encode(string)) - return num_tokens - + if isinstance(string, str): + num_tokens = len(encoding.encode(string)) + return num_tokens + else: + return 0 + +def num_tokens_from_object_or_list(thing): + if isinstance(thing, list): + return sum([num_tokens_from_object_or_list(x) for x in thing]) + elif isinstance(thing, dict): + return sum([num_tokens_from_object_or_list(x) for x in thing.values()]) + elif isinstance(thing, str): + return num_tokens_from_string(thing) + else: + return 0 def count_tokens_docs(docs): docs_content = "" From 863950963f867e6fc5af2beb54ce53db7b7bc75e Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 6 Dec 2024 22:19:01 +0000 Subject: [PATCH 02/15] simple user tool handling endpoint --- application/api/user/routes.py | 121 +++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 6a2f3bea3..814490dd6 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -30,6 +30,7 @@ token_usage_collection = db["token_usage"] shared_conversations_collections = db["shared_conversations"] user_logs_collection = db["user_logs"] +user_tools_collection = db["user_tools"] user = Blueprint("user", __name__) user_ns = Namespace("user", description="User related operations", path="/") @@ -1786,3 +1787,123 @@ def post(self): ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) + + +@user_ns.route("/api/create_tool") +class CreateTool(Resource): + # write code such that it will accept tool_name, took_config and tool_actions + create_tool_model = api.model( + "CreateToolModel", + { + "tool_name": fields.String(required=True, description="Name of the tool"), + "tool_config": fields.Raw(required=True, description="Configuration of the tool"), + "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + }, + ) + + @api.expect(create_tool_model) + @api.doc(description="Create a new tool") + def post(self): + data = request.get_json() + required_fields = ["tool_name", "tool_config", "tool_actions"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + user = "local" + try: + new_tool = { + "tool_name": data["tool_name"], + "tool_config": data["tool_config"], + "tool_actions": data["tool_actions"], + "user": user, + } + resp = user_tools_collection.insert_one(new_tool) + new_id = str(resp.inserted_id) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"id": new_id}), 200) + +@user_ns.route("/api/update_tool_config") +class UpdateToolConfig(Resource): + update_tool_config_model = api.model( + "UpdateToolConfigModel", + { + "tool_id": fields.String(required=True, description="Tool ID"), + "tool_config": fields.Raw(required=True, description="Configuration of the tool"), + }, + ) + + @api.expect(update_tool_config_model) + @api.doc(description="Update the configuration of a tool") + def post(self): + data = request.get_json() + required_fields = ["tool_id", "tool_config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["tool_id"])}, + {"$set": {"tool_config": data["tool_config"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + +@user_ns.route("/api/update_tool_actions") +class UpdateToolActions(Resource): + update_tool_actions_model = api.model( + "UpdateToolActionsModel", + { + "tool_id": fields.String(required=True, description="Tool ID"), + "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + }, + ) + + @api.expect(update_tool_actions_model) + @api.doc(description="Update the actions of a tool") + def post(self): + data = request.get_json() + required_fields = ["tool_id", "tool_actions"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["tool_id"])}, + {"$set": {"tool_actions": data["tool_actions"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + +@user_ns.route("/api/delete_tool") +class DeleteTool(Resource): + delete_tool_model = api.model( + "DeleteToolModel", + {"tool_id": fields.String(required=True, description="Tool ID")}, + ) + + @api.expect(delete_tool_model) + @api.doc(description="Delete a tool by ID") + def post(self): + data = request.get_json() + required_fields = ["tool_id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + result = user_tools_collection.delete_one({"_id": ObjectId(data["tool_id"])}) + if result.deleted_count == 0: + return {"success": False, "message": "Tool not found"}, 404 + except Exception as err: + return {"success": False, "error": str(err)}, 400 + + return {"success": True}, 200 \ No newline at end of file From 3e2e1ecddfc5431c63105ea43d99cee071ba48f9 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 6 Dec 2024 23:11:16 +0000 Subject: [PATCH 03/15] fix: add status to tools --- application/api/user/routes.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 814490dd6..627f5665c 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1798,6 +1798,8 @@ class CreateTool(Resource): "tool_name": fields.String(required=True, description="Name of the tool"), "tool_config": fields.Raw(required=True, description="Configuration of the tool"), "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + "status": fields.Boolean(required=True, description="Status of the tool") + }, ) @@ -1805,7 +1807,7 @@ class CreateTool(Resource): @api.doc(description="Create a new tool") def post(self): data = request.get_json() - required_fields = ["tool_name", "tool_config", "tool_actions"] + required_fields = ["tool_name", "tool_config", "tool_actions", "status"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields @@ -1817,6 +1819,7 @@ def post(self): "tool_config": data["tool_config"], "tool_actions": data["tool_actions"], "user": user, + "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) new_id = str(resp.inserted_id) @@ -1882,6 +1885,35 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) + +@user_ns.route("/api/update_tool_status") +class UpdateToolStatus(Resource): + update_tool_status_model = api.model( + "UpdateToolStatusModel", + { + "tool_id": fields.String(required=True, description="Tool ID"), + "status": fields.Boolean(required=True, description="Status of the tool"), + }, + ) + + @api.expect(update_tool_status_model) + @api.doc(description="Update the status of a tool") + def post(self): + data = request.get_json() + required_fields = ["tool_id", "status"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["tool_id"])}, + {"$set": {"status": data["status"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) @user_ns.route("/api/delete_tool") class DeleteTool(Resource): From f87ae429f422aba53c793b107b955c7705b4cc07 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 9 Dec 2024 17:52:20 +0000 Subject: [PATCH 04/15] fix: edit names --- application/api/user/routes.py | 47 +++++++++++++++++----------------- application/tools/agent.py | 38 ++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 627f5665c..efb0242dc 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1791,13 +1791,12 @@ def post(self): @user_ns.route("/api/create_tool") class CreateTool(Resource): - # write code such that it will accept tool_name, took_config and tool_actions create_tool_model = api.model( "CreateToolModel", { - "tool_name": fields.String(required=True, description="Name of the tool"), - "tool_config": fields.Raw(required=True, description="Configuration of the tool"), - "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + "name": fields.String(required=True, description="Name of the tool"), + "config": fields.Raw(required=True, description="Configuration of the tool"), + "actions": fields.List(required=True, description="Actions the tool can perform"), "status": fields.Boolean(required=True, description="Status of the tool") }, @@ -1807,7 +1806,7 @@ class CreateTool(Resource): @api.doc(description="Create a new tool") def post(self): data = request.get_json() - required_fields = ["tool_name", "tool_config", "tool_actions", "status"] + required_fields = ["name", "config", "actions", "status"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields @@ -1815,9 +1814,9 @@ def post(self): user = "local" try: new_tool = { - "tool_name": data["tool_name"], - "tool_config": data["tool_config"], - "tool_actions": data["tool_actions"], + "name": data["name"], + "config": data["config"], + "actions": data["actions"], "user": user, "status": data["status"], } @@ -1833,8 +1832,8 @@ class UpdateToolConfig(Resource): update_tool_config_model = api.model( "UpdateToolConfigModel", { - "tool_id": fields.String(required=True, description="Tool ID"), - "tool_config": fields.Raw(required=True, description="Configuration of the tool"), + "id": fields.String(required=True, description="Tool ID"), + "config": fields.Raw(required=True, description="Configuration of the tool"), }, ) @@ -1842,15 +1841,15 @@ class UpdateToolConfig(Resource): @api.doc(description="Update the configuration of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "tool_config"] + required_fields = ["id", "config"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, - {"$set": {"tool_config": data["tool_config"]}}, + {"_id": ObjectId(data["id"])}, + {"$set": {"config": data["config"]}}, ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) @@ -1862,8 +1861,8 @@ class UpdateToolActions(Resource): update_tool_actions_model = api.model( "UpdateToolActionsModel", { - "tool_id": fields.String(required=True, description="Tool ID"), - "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + "id": fields.String(required=True, description="Tool ID"), + "actions": fields.List(required=True, description="Actions the tool can perform"), }, ) @@ -1871,15 +1870,15 @@ class UpdateToolActions(Resource): @api.doc(description="Update the actions of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "tool_actions"] + required_fields = ["id", "actions"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, - {"$set": {"tool_actions": data["tool_actions"]}}, + {"_id": ObjectId(data["id"])}, + {"$set": {"actions": data["actions"]}}, ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) @@ -1891,7 +1890,7 @@ class UpdateToolStatus(Resource): update_tool_status_model = api.model( "UpdateToolStatusModel", { - "tool_id": fields.String(required=True, description="Tool ID"), + "id": fields.String(required=True, description="Tool ID"), "status": fields.Boolean(required=True, description="Status of the tool"), }, ) @@ -1900,14 +1899,14 @@ class UpdateToolStatus(Resource): @api.doc(description="Update the status of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "status"] + required_fields = ["id", "status"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, + {"_id": ObjectId(data["id"])}, {"$set": {"status": data["status"]}}, ) except Exception as err: @@ -1919,20 +1918,20 @@ def post(self): class DeleteTool(Resource): delete_tool_model = api.model( "DeleteToolModel", - {"tool_id": fields.String(required=True, description="Tool ID")}, + {"id": fields.String(required=True, description="Tool ID")}, ) @api.expect(delete_tool_model) @api.doc(description="Delete a tool by ID") def post(self): data = request.get_json() - required_fields = ["tool_id"] + required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: - result = user_tools_collection.delete_one({"_id": ObjectId(data["tool_id"])}) + result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])}) if result.deleted_count == 0: return {"success": False, "message": "Tool not found"}, 404 except Exception as err: diff --git a/application/tools/agent.py b/application/tools/agent.py index 2df14442b..af23b99e2 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,6 +1,7 @@ from application.llm.llm_creator import LLMCreator from application.core.settings import settings from application.tools.tool_manager import ToolManager +from application.core.mongo_db import MongoDB import json tool_tg = { @@ -53,9 +54,31 @@ def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): ] self.tool_config = { } + + def _get_user_tools(self, user="local"): + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_tools_collection = db["user_tools"] + user_tools = user_tools_collection.find({"user": user, "status": True}) + user_tools = list(user_tools) + for tool in user_tools: + tool.pop("_id") + user_tools = {tool["name"]: tool for tool in user_tools} + return user_tools + + def _simple_tool_agent(self, messages): + tools_dict = self._get_user_tools() + # combine all tool_actions into one list + self.tools.extend([ + { + "type": "function", + "function": tool_action + } + for tool in tools_dict.values() + for tool_action in tool["actions"] + ]) + - def gen(self, messages): - # Generate initial response from the LLM resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) if isinstance(resp, str): @@ -75,7 +98,7 @@ def gen(self, messages): call_id = call.id # Determine the tool name and load it tool_name = call_name.split("_")[0] - tool = tm.load_tool(tool_name, tool_config=self.tool_config) + tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config']) # Execute the tool's action resp_tool = tool.execute_action(call_name, **call_args) # Append the tool's response to the conversation @@ -96,3 +119,12 @@ def gen(self, messages): completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) for line in completion: yield line + + def gen(self, messages): + # Generate initial response from the LLM + if self.llm.supports_tools(): + self._simple_tool_agent(messages) + else: + resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) + for line in resp: + yield line \ No newline at end of file From 46b0de367a7a3ed12f14db99ba0b25c1117e54d2 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 12 Dec 2024 10:40:55 +0000 Subject: [PATCH 05/15] fix: strings --- application/api/user/routes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index efb0242dc..a5807b81e 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1796,7 +1796,7 @@ class CreateTool(Resource): { "name": fields.String(required=True, description="Name of the tool"), "config": fields.Raw(required=True, description="Configuration of the tool"), - "actions": fields.List(required=True, description="Actions the tool can perform"), + "actions": fields.List(fields.String, required=True, description="Actions the tool can perform"), "status": fields.Boolean(required=True, description="Status of the tool") }, @@ -1862,7 +1862,7 @@ class UpdateToolActions(Resource): "UpdateToolActionsModel", { "id": fields.String(required=True, description="Tool ID"), - "actions": fields.List(required=True, description="Actions the tool can perform"), + "actions": fields.List(fields.String, required=True, description="Actions the tool can perform"), }, ) From f9a7db11ebc89d63d7d7dad4cea062bfbe1567d0 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 18 Dec 2024 22:48:40 +0530 Subject: [PATCH 06/15] feat: tools frontend and endpoints refactor --- application/api/user/routes.py | 235 +++++++++++--- application/llm/groq.py | 39 +-- application/tools/agent.py | 82 +++-- application/tools/base.py | 1 + .../{ => implementations}/cryptoprice.py | 30 +- .../tools/{ => implementations}/telegram.py | 42 +-- application/tools/tool_manager.py | 18 +- .../public/toolIcons/tool_cryptoprice.svg | 1 + frontend/public/toolIcons/tool_telegram.svg | 10 + frontend/src/api/endpoints.ts | 5 + frontend/src/api/services/userService.ts | 10 + frontend/src/assets/cogwheel.svg | 3 + frontend/src/components/SettingsBar.tsx | 1 + frontend/src/index.css | 9 + frontend/src/locale/en.json | 3 + frontend/src/modals/AddToolModal.tsx | 136 ++++++++ frontend/src/modals/ConfigToolModal.tsx | 95 ++++++ frontend/src/modals/types/index.ts | 11 + frontend/src/settings/Documents.tsx | 18 +- frontend/src/settings/ToolConfig.tsx | 293 ++++++++++++++++++ frontend/src/settings/Tools.tsx | 157 ++++++++++ frontend/src/settings/index.tsx | 5 +- frontend/src/settings/types/index.ts | 29 ++ 23 files changed, 1067 insertions(+), 166 deletions(-) rename application/tools/{ => implementations}/cryptoprice.py (82%) rename application/tools/{ => implementations}/telegram.py (75%) create mode 100644 frontend/public/toolIcons/tool_cryptoprice.svg create mode 100644 frontend/public/toolIcons/tool_telegram.svg create mode 100644 frontend/src/assets/cogwheel.svg create mode 100644 frontend/src/modals/AddToolModal.tsx create mode 100644 frontend/src/modals/ConfigToolModal.tsx create mode 100644 frontend/src/modals/types/index.ts create mode 100644 frontend/src/settings/ToolConfig.tsx create mode 100644 frontend/src/settings/Tools.tsx diff --git a/application/api/user/routes.py b/application/api/user/routes.py index efb0242dc..05248f97a 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,14 +1,14 @@ import datetime +import math import os import shutil import uuid -import math from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId -from flask import Blueprint, jsonify, make_response, request, redirect -from flask_restx import inputs, fields, Namespace, Resource +from flask import Blueprint, jsonify, make_response, redirect, request +from flask_restx import fields, inputs, Namespace, Resource from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote @@ -16,9 +16,10 @@ from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api +from application.tools.tool_manager import ToolManager +from application.tts.google_tts import GoogleTTS from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator -from application.tts.google_tts import GoogleTTS mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -40,6 +41,9 @@ os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +tool_config = {} +tool_manager = ToolManager(config=tool_config) + def generate_minute_range(start_date, end_date): return { @@ -1789,24 +1793,88 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) +@user_ns.route("/api/available_tools") +class AvailableTools(Resource): + @api.doc(description="Get available tools for a user") + def get(self): + try: + tools_metadata = [] + for tool_name, tool_instance in tool_manager.tools.items(): + doc = tool_instance.__doc__.strip() + lines = doc.split("\n", 1) + name = lines[0].strip() + description = lines[1].strip() if len(lines) > 1 else "" + tools_metadata.append( + { + "name": tool_name, + "displayName": name, + "description": description, + "configRequirements": tool_instance.get_config_requirements(), + "actions": tool_instance.get_actions_metadata(), + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "data": tools_metadata}), 200) + + +@user_ns.route("/api/get_tools") +class GetTools(Resource): + @api.doc(description="Get tools created by a user") + def get(self): + try: + user = "local" + tools = user_tools_collection.find({"user": user}) + user_tools = [] + for tool in tools: + tool["id"] = str(tool["_id"]) + tool.pop("_id") + user_tools.append(tool) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "tools": user_tools}), 200) + + @user_ns.route("/api/create_tool") class CreateTool(Resource): - create_tool_model = api.model( - "CreateToolModel", - { - "name": fields.String(required=True, description="Name of the tool"), - "config": fields.Raw(required=True, description="Configuration of the tool"), - "actions": fields.List(required=True, description="Actions the tool can perform"), - "status": fields.Boolean(required=True, description="Status of the tool") - - }, + @api.expect( + api.model( + "CreateToolModel", + { + "name": fields.String(required=True, description="Name of the tool"), + "displayName": fields.String( + required=True, description="Display name for the tool" + ), + "description": fields.String( + required=True, description="Tool description" + ), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) ) - - @api.expect(create_tool_model) @api.doc(description="Create a new tool") def post(self): data = request.get_json() - required_fields = ["name", "config", "actions", "status"] + required_fields = [ + "name", + "displayName", + "description", + "actions", + "config", + "status", + ] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields @@ -1814,10 +1882,12 @@ def post(self): user = "local" try: new_tool = { + "user": user, "name": data["name"], - "config": data["config"], + "displayName": data["displayName"], + "description": data["description"], "actions": data["actions"], - "user": user, + "config": data["config"], "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) @@ -1826,18 +1896,72 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"id": new_id}), 200) - + + +@user_ns.route("/api/update_tool") +class UpdateTool(Resource): + @api.expect( + api.model( + "UpdateToolModel", + { + "id": fields.String(required=True, description="Tool ID"), + "name": fields.String(description="Name of the tool"), + "displayName": fields.String(description="Display name for the tool"), + "description": fields.String(description="Tool description"), + "config": fields.Raw(description="Configuration of the tool"), + "actions": fields.List( + fields.Raw, description="Actions the tool can perform" + ), + "status": fields.Boolean(description="Status of the tool"), + }, + ) + ) + @api.doc(description="Update a tool by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + update_data = {} + if "name" in data: + update_data["name"] = data["name"] + if "displayName" in data: + update_data["displayName"] = data["displayName"] + if "description" in data: + update_data["description"] = data["description"] + if "actions" in data: + update_data["actions"] = data["actions"] + if "config" in data: + update_data["config"] = data["config"] + if "status" in data: + update_data["status"] = data["status"] + + user_tools_collection.update_one( + {"_id": ObjectId(data["id"]), "user": "local"}, + {"$set": update_data}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + @user_ns.route("/api/update_tool_config") class UpdateToolConfig(Resource): - update_tool_config_model = api.model( - "UpdateToolConfigModel", - { - "id": fields.String(required=True, description="Tool ID"), - "config": fields.Raw(required=True, description="Configuration of the tool"), - }, + @api.expect( + api.model( + "UpdateToolConfigModel", + { + "id": fields.String(required=True, description="Tool ID"), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + }, + ) ) - - @api.expect(update_tool_config_model) @api.doc(description="Update the configuration of a tool") def post(self): data = request.get_json() @@ -1855,18 +1979,23 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) - + + @user_ns.route("/api/update_tool_actions") class UpdateToolActions(Resource): - update_tool_actions_model = api.model( - "UpdateToolActionsModel", - { - "id": fields.String(required=True, description="Tool ID"), - "actions": fields.List(required=True, description="Actions the tool can perform"), - }, + @api.expect( + api.model( + "UpdateToolActionsModel", + { + "id": fields.String(required=True, description="Tool ID"), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + }, + ) ) - - @api.expect(update_tool_actions_model) @api.doc(description="Update the actions of a tool") def post(self): data = request.get_json() @@ -1885,17 +2014,20 @@ def post(self): return make_response(jsonify({"success": True}), 200) + @user_ns.route("/api/update_tool_status") class UpdateToolStatus(Resource): - update_tool_status_model = api.model( - "UpdateToolStatusModel", - { - "id": fields.String(required=True, description="Tool ID"), - "status": fields.Boolean(required=True, description="Status of the tool"), - }, + @api.expect( + api.model( + "UpdateToolStatusModel", + { + "id": fields.String(required=True, description="Tool ID"), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) ) - - @api.expect(update_tool_status_model) @api.doc(description="Update the status of a tool") def post(self): data = request.get_json() @@ -1913,15 +2045,16 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) - + + @user_ns.route("/api/delete_tool") class DeleteTool(Resource): - delete_tool_model = api.model( - "DeleteToolModel", - {"id": fields.String(required=True, description="Tool ID")}, + @api.expect( + api.model( + "DeleteToolModel", + {"id": fields.String(required=True, description="Tool ID")}, + ) ) - - @api.expect(delete_tool_model) @api.doc(description="Delete a tool by ID") def post(self): data = request.get_json() @@ -1937,4 +2070,4 @@ def post(self): except Exception as err: return {"success": False, "error": str(err)}, 400 - return {"success": True}, 200 \ No newline at end of file + return {"success": True}, 200 diff --git a/application/llm/groq.py b/application/llm/groq.py index b5731a905..f2fcfbeb4 100644 --- a/application/llm/groq.py +++ b/application/llm/groq.py @@ -1,45 +1,32 @@ from application.llm.base import BaseLLM - +from openai import OpenAI class GroqLLM(BaseLLM): - def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): - from openai import OpenAI - super().__init__(*args, **kwargs) self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1") self.api_key = api_key self.user_api_key = user_api_key - def _raw_gen( - self, - baseself, - model, - messages, - stream=False, - **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) return response.choices[0].message.content def _raw_gen_stream( - self, - baseself, - model, - messages, - stream=True, - **kwargs - ): + self, baseself, model, messages, stream=True, tools=None, **kwargs + ): response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) - for line in response: - # import sys - # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content diff --git a/application/tools/agent.py b/application/tools/agent.py index af23b99e2..ffd14770e 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,8 +1,9 @@ -from application.llm.llm_creator import LLMCreator +import json + +from application.core.mongo_db import MongoDB from application.core.settings import settings +from application.llm.llm_creator import LLMCreator from application.tools.tool_manager import ToolManager -from application.core.mongo_db import MongoDB -import json tool_tg = { "name": "telegram_send_message", @@ -12,15 +13,15 @@ "properties": { "text": { "type": "string", - "description": "Text to send in the notification" + "description": "Text to send in the notification", } }, "required": ["text"], - "additionalProperties": False - } + "additionalProperties": False, + }, } -tool_crypto = { +tool_crypto = { "name": "cryptoprice_get", "description": "Retrieve the price of a specified cryptocurrency in a given currency", "parameters": { @@ -28,33 +29,30 @@ "properties": { "symbol": { "type": "string", - "description": "The cryptocurrency symbol (e.g. BTC)" + "description": "The cryptocurrency symbol (e.g. BTC)", }, "currency": { "type": "string", - "description": "The currency in which you want the price (e.g. USD)" - } + "description": "The currency in which you want the price (e.g. USD)", + }, }, "required": ["symbol", "currency"], - "additionalProperties": False - } + "additionalProperties": False, + }, } + class Agent: def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): # Initialize the LLM with the provided parameters - self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key) + self.llm = LLMCreator.create_llm( + llm_name, api_key=api_key, user_api_key=user_api_key + ) self.gpt_model = gpt_model # Static tool configuration (to be replaced later) - self.tools = [ - { - "type": "function", - "function": tool_crypto - } - ] - self.tool_config = { - } - + self.tools = [{"type": "function", "function": tool_crypto}] + self.tool_config = {} + def _get_user_tools(self, user="local"): mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -65,19 +63,17 @@ def _get_user_tools(self, user="local"): tool.pop("_id") user_tools = {tool["name"]: tool for tool in user_tools} return user_tools - + def _simple_tool_agent(self, messages): tools_dict = self._get_user_tools() # combine all tool_actions into one list - self.tools.extend([ - { - "type": "function", - "function": tool_action - } - for tool in tools_dict.values() - for tool_action in tool["actions"] - ]) - + self.tools.extend( + [ + {"type": "function", "function": tool_action} + for tool in tools_dict.values() + for tool_action in tool["actions"] + ] + ) resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) @@ -88,7 +84,7 @@ def _simple_tool_agent(self, messages): while resp.finish_reason == "tool_calls": # Append the assistant's message to the conversation - messages.append(json.loads(resp.model_dump_json())['message']) + messages.append(json.loads(resp.model_dump_json())["message"]) # Handle each tool call tool_calls = resp.message.tool_calls for call in tool_calls: @@ -98,25 +94,27 @@ def _simple_tool_agent(self, messages): call_id = call.id # Determine the tool name and load it tool_name = call_name.split("_")[0] - tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config']) + tool = tm.load_tool( + tool_name, tool_config=tools_dict[tool_name]["config"] + ) # Execute the tool's action resp_tool = tool.execute_action(call_name, **call_args) # Append the tool's response to the conversation messages.append( - { - "role": "tool", - "content": str(resp_tool), - "tool_call_id": call_id - } + {"role": "tool", "content": str(resp_tool), "tool_call_id": call_id} ) # Generate a new response from the LLM after processing tools - resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + resp = self.llm.gen( + model=self.gpt_model, messages=messages, tools=self.tools + ) # If no tool calls are needed, generate the final response if isinstance(resp, str): yield resp else: - completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) + completion = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) for line in completion: yield line @@ -127,4 +125,4 @@ def gen(self, messages): else: resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) for line in resp: - yield line \ No newline at end of file + yield line diff --git a/application/tools/base.py b/application/tools/base.py index 00cfee3a0..fd7b4a852 100644 --- a/application/tools/base.py +++ b/application/tools/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class Tool(ABC): @abstractmethod def execute_action(self, action_name: str, **kwargs): diff --git a/application/tools/cryptoprice.py b/application/tools/implementations/cryptoprice.py similarity index 82% rename from application/tools/cryptoprice.py rename to application/tools/implementations/cryptoprice.py index d7cf61e13..7b88c866d 100644 --- a/application/tools/cryptoprice.py +++ b/application/tools/implementations/cryptoprice.py @@ -1,21 +1,25 @@ -from application.tools.base import Tool import requests +from application.tools.base import Tool + class CryptoPriceTool(Tool): + """ + CryptoPrice + A tool for retrieving cryptocurrency prices using the CryptoCompare public API + """ + def __init__(self, config): self.config = config def execute_action(self, action_name, **kwargs): - actions = { - "cryptoprice_get": self.get_price - } + actions = {"cryptoprice_get": self._get_price} if action_name in actions: return actions[action_name](**kwargs) else: raise ValueError(f"Unknown action: {action_name}") - def get_price(self, symbol, currency): + def _get_price(self, symbol, currency): """ Fetches the current price of a given cryptocurrency symbol in the specified currency. Example: @@ -32,17 +36,17 @@ def get_price(self, symbol, currency): return { "status_code": response.status_code, "price": data[currency.upper()], - "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully." + "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.", } else: return { "status_code": response.status_code, - "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}." + "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.", } else: return { "status_code": response.status_code, - "message": "Failed to retrieve price." + "message": "Failed to retrieve price.", } def get_actions_metadata(self): @@ -55,16 +59,16 @@ def get_actions_metadata(self): "properties": { "symbol": { "type": "string", - "description": "The cryptocurrency symbol (e.g. BTC)" + "description": "The cryptocurrency symbol (e.g. BTC)", }, "currency": { "type": "string", - "description": "The currency in which you want the price (e.g. USD)" - } + "description": "The currency in which you want the price (e.g. USD)", + }, }, "required": ["symbol", "currency"], - "additionalProperties": False - } + "additionalProperties": False, + }, } ] diff --git a/application/tools/telegram.py b/application/tools/implementations/telegram.py similarity index 75% rename from application/tools/telegram.py rename to application/tools/implementations/telegram.py index 8210d8e71..a2b436b47 100644 --- a/application/tools/telegram.py +++ b/application/tools/implementations/telegram.py @@ -1,16 +1,23 @@ -from application.tools.base import Tool import requests +from application.tools.base import Tool + class TelegramTool(Tool): + """ + Telegram Bot + A flexible Telegram tool for performing various actions (e.g., sending messages, images). + Requires a bot token and chat ID for configuration + """ + def __init__(self, config): self.config = config - self.chat_id = config.get("chat_id", "142189016") - self.token = config.get("token", "YOUR_TG_TOKEN") + self.token = config.get("token", "") + self.chat_id = config.get("chat_id", "") def execute_action(self, action_name, **kwargs): actions = { - "telegram_send_message": self.send_message, - "telegram_send_image": self.send_image + "telegram_send_message": self._send_message, + "telegram_send_image": self._send_image, } if action_name in actions: @@ -18,14 +25,14 @@ def execute_action(self, action_name, **kwargs): else: raise ValueError(f"Unknown action: {action_name}") - def send_message(self, text): + def _send_message(self, text): print(f"Sending message: {text}") url = f"https://api.telegram.org/bot{self.token}/sendMessage" payload = {"chat_id": self.chat_id, "text": text} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Message sent"} - def send_image(self, image_url): + def _send_image(self, image_url): print(f"Sending image: {image_url}") url = f"https://api.telegram.org/bot{self.token}/sendPhoto" payload = {"chat_id": self.chat_id, "photo": image_url} @@ -42,12 +49,12 @@ def get_actions_metadata(self): "properties": { "text": { "type": "string", - "description": "Text to send in the notification" + "description": "Text to send in the notification", } }, "required": ["text"], - "additionalProperties": False - } + "additionalProperties": False, + }, }, { "name": "telegram_send_image", @@ -57,23 +64,20 @@ def get_actions_metadata(self): "properties": { "image_url": { "type": "string", - "description": "URL of the image to send" + "description": "URL of the image to send", } }, "required": ["image_url"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, ] def get_config_requirements(self): return { "chat_id": { "type": "string", - "description": "Telegram chat ID to send messages to" + "description": "Telegram chat ID to send messages to", }, - "token": { - "type": "string", - "description": "Bot token for authentication" - } + "token": {"type": "string", "description": "Bot token for authentication"}, } diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py index 10231cb28..cc9a055a0 100644 --- a/application/tools/tool_manager.py +++ b/application/tools/tool_manager.py @@ -1,10 +1,11 @@ import importlib import inspect -import pkgutil import os +import pkgutil from application.tools.base import Tool + class ToolManager: def __init__(self, config): self.config = config @@ -12,11 +13,13 @@ def __init__(self, config): self.load_tools() def load_tools(self): - tools_dir = os.path.dirname(__file__) + tools_dir = os.path.join(os.path.dirname(__file__), "implementations") for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): - if name == 'base' or name.startswith('__'): + if name == "base" or name.startswith("__"): continue - module = importlib.import_module(f'application.tools.{name}') + module = importlib.import_module( + f"application.tools.implementations.{name}" + ) for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: tool_config = self.config.get(name, {}) @@ -24,13 +27,14 @@ def load_tools(self): def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - tools_dir = os.path.dirname(__file__) - module = importlib.import_module(f'application.tools.{tool_name}') + tools_dir = os.path.join(os.path.dirname(__file__), "implementations") + module = importlib.import_module( + f"application.tools.implementations.{tool_name}" + ) for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: return obj(tool_config) - def execute_action(self, tool_name, action_name, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") diff --git a/frontend/public/toolIcons/tool_cryptoprice.svg b/frontend/public/toolIcons/tool_cryptoprice.svg new file mode 100644 index 000000000..6a4226949 --- /dev/null +++ b/frontend/public/toolIcons/tool_cryptoprice.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/public/toolIcons/tool_telegram.svg b/frontend/public/toolIcons/tool_telegram.svg new file mode 100644 index 000000000..27536dedf --- /dev/null +++ b/frontend/public/toolIcons/tool_telegram.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 4e7112d04..8a7f9ae23 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -18,6 +18,11 @@ const endpoints = { FEEDBACK_ANALYTICS: '/api/get_feedback_analytics', LOGS: `/api/get_user_logs`, MANAGE_SYNC: '/api/manage_sync', + GET_AVAILABLE_TOOLS: '/api/available_tools', + GET_USER_TOOLS: '/api/get_tools', + CREATE_TOOL: '/api/create_tool', + UPDATE_TOOL_STATUS: '/api/update_tool_status', + UPDATE_TOOL: '/api/update_tool', }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 942318ae4..ab91a0a42 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -35,6 +35,16 @@ const userService = { apiClient.post(endpoints.USER.LOGS, data), manageSync: (data: any): Promise => apiClient.post(endpoints.USER.MANAGE_SYNC, data), + getAvailableTools: (): Promise => + apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS), + getUserTools: (): Promise => + apiClient.get(endpoints.USER.GET_USER_TOOLS), + createTool: (data: any): Promise => + apiClient.post(endpoints.USER.CREATE_TOOL, data), + updateToolStatus: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data), + updateTool: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL, data), }; export default userService; diff --git a/frontend/src/assets/cogwheel.svg b/frontend/src/assets/cogwheel.svg new file mode 100644 index 000000000..f5299b8ba --- /dev/null +++ b/frontend/src/assets/cogwheel.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/components/SettingsBar.tsx b/frontend/src/components/SettingsBar.tsx index f617c6e82..c6970600d 100644 --- a/frontend/src/components/SettingsBar.tsx +++ b/frontend/src/components/SettingsBar.tsx @@ -13,6 +13,7 @@ const useTabs = () => { t('settings.apiKeys.label'), t('settings.analytics.label'), t('settings.logs.label'), + t('settings.tools.label'), ]; return tabs; }; diff --git a/frontend/src/index.css b/frontend/src/index.css index 9bd2ec964..ecd4da1eb 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -68,6 +68,15 @@ body.dark { .table-default td:last-child { @apply border-r-0; /* Ensure no right border on the last column */ } + + .table-default th, + .table-default td { + min-width: 150px; + max-width: 320px; + overflow: auto; + scrollbar-width: thin; + scrollbar-color: grey transparent; + } } /*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 134604633..4048f6324 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -73,6 +73,9 @@ }, "logs": { "label": "Logs" + }, + "tools": { + "label": "Tools" } }, "modals": { diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx new file mode 100644 index 000000000..330cf3bb2 --- /dev/null +++ b/frontend/src/modals/AddToolModal.tsx @@ -0,0 +1,136 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import Exit from '../assets/exit.svg'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import ConfigToolModal from './ConfigToolModal'; + +export default function AddToolModal({ + message, + modalState, + setModalState, + getUserTools, +}: { + message: string; + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + getUserTools: () => void; +}) { + const [availableTools, setAvailableTools] = React.useState( + [], + ); + const [selectedTool, setSelectedTool] = React.useState( + null, + ); + const [configModalState, setConfigModalState] = + React.useState('INACTIVE'); + + const getAvailableTools = () => { + userService + .getAvailableTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setAvailableTools(data.data); + }); + }; + + const handleAddTool = (tool: AvailableTool) => { + if (Object.keys(tool.configRequirements).length === 0) { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: {}, + actions: tool.actions, + status: true, + }) + .then((res) => { + if (res.status === 200) { + getUserTools(); + setModalState('INACTIVE'); + } + }); + } else { + setModalState('INACTIVE'); + setConfigModalState('ACTIVE'); + } + }; + + React.useEffect(() => { + if (modalState === 'ACTIVE') getAvailableTools(); + }, [modalState]); + return ( + <> +
+
+
+ +
+

+ Select a tool to set up +

+
+ {availableTools.map((tool, index) => ( +
{ + setSelectedTool(tool); + handleAddTool(tool); + }} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + setSelectedTool(tool); + handleAddTool(tool); + } + }} + > +
+
+ +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ ))} +
+
+
+
+
+ + + ); +} diff --git a/frontend/src/modals/ConfigToolModal.tsx b/frontend/src/modals/ConfigToolModal.tsx new file mode 100644 index 000000000..96bb15be8 --- /dev/null +++ b/frontend/src/modals/ConfigToolModal.tsx @@ -0,0 +1,95 @@ +import React from 'react'; + +import Exit from '../assets/exit.svg'; +import Input from '../components/Input'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import userService from '../api/services/userService'; + +export default function ConfigToolModal({ + modalState, + setModalState, + tool, + getUserTools, +}: { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + tool: AvailableTool | null; + getUserTools: () => void; +}) { + const [authKey, setAuthKey] = React.useState(''); + + const handleAddTool = (tool: AvailableTool) => { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: true, + }) + .then(() => { + setModalState('INACTIVE'); + getUserTools(); + }); + }; + return ( +
+
+
+ +
+

+ Tool Config +

+

+ Type: {tool?.name} +

+
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+
+ + +
+
+
+
+
+ ); +} diff --git a/frontend/src/modals/types/index.ts b/frontend/src/modals/types/index.ts new file mode 100644 index 000000000..044ba4386 --- /dev/null +++ b/frontend/src/modals/types/index.ts @@ -0,0 +1,11 @@ +export type AvailableTool = { + name: string; + displayName: string; + description: string; + configRequirements: object; + actions: { + name: string; + description: string; + parameters: object; + }[]; +}; diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx index f91a33559..6375b4e27 100644 --- a/frontend/src/settings/Documents.tsx +++ b/frontend/src/settings/Documents.tsx @@ -215,18 +215,22 @@ const Documents: React.FC = ({ {document.type === 'remote' ? 'Pre-loaded' : 'Private'} -
+
{document.type !== 'remote' && ( - Delete { event.stopPropagation(); handleDeleteDocument(index, document); }} - /> + > + Delete + )} {document.syncFrequency && (
diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx new file mode 100644 index 000000000..0de4dab96 --- /dev/null +++ b/frontend/src/settings/ToolConfig.tsx @@ -0,0 +1,293 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import ArrowLeft from '../assets/arrow-left.svg'; +import Input from '../components/Input'; +import { UserTool } from './types'; + +export default function ToolConfig({ + tool, + setTool, + handleGoBack, +}: { + tool: UserTool; + setTool: (tool: UserTool) => void; + handleGoBack: () => void; +}) { + const [authKey, setAuthKey] = React.useState( + tool.config?.token || '', + ); + + const handleCheckboxChange = (actionIndex: number, property: string) => { + setTool({ + ...tool, + actions: tool.actions.map((action, index) => { + if (index === actionIndex) { + return { + ...action, + parameters: { + ...action.parameters, + properties: { + ...action.parameters.properties, + [property]: { + ...action.parameters.properties[property], + filled_by_llm: + !action.parameters.properties[property].filled_by_llm, + }, + }, + }, + }; + } + return action; + }), + }); + }; + + const handleSaveChanges = () => { + userService + .updateTool({ + id: tool.id, + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: tool.status, + }) + .then(() => { + handleGoBack(); + }); + }; + return ( +
+
+ +

Back to all tools

+
+
+

+ Type +

+

+ {tool.name} +

+
+
+ {Object.keys(tool?.config).length !== 0 && ( +

+ Authentication +

+ )} +
+ {Object.keys(tool?.config).length !== 0 && ( +
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+ )} + +
+
+
+
+

+ Actions +

+
+ {tool.actions.map((action, actionIndex) => { + return ( +
+
+

+ {action.name} +

+ +
+
+ { + setTool({ + ...tool, + actions: tool.actions.map((act, index) => { + if (index === actionIndex) { + return { + ...act, + description: e.target.value, + }; + } + return act; + }), + }); + }} + borderVariant="thin" + > +
+
+ + + + + + + + + + + + {Object.entries(action.parameters?.properties).map( + (param, index) => { + const uniqueKey = `${actionIndex}-${param[0]}`; + return ( + + + + + + + + ); + }, + )} + +
Field NameField TypeFilled by LLMFIeld descriptionValue
{param[0]}{param[1].type} + + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + description: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + value: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > +
+
+
+ ); + })} +
+
+
+ ); +} diff --git a/frontend/src/settings/Tools.tsx b/frontend/src/settings/Tools.tsx new file mode 100644 index 000000000..d29ba42c6 --- /dev/null +++ b/frontend/src/settings/Tools.tsx @@ -0,0 +1,157 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import CogwheelIcon from '../assets/cogwheel.svg'; +import Input from '../components/Input'; +import AddToolModal from '../modals/AddToolModal'; +import { ActiveState } from '../models/misc'; +import { UserTool } from './types'; +import ToolConfig from './ToolConfig'; + +export default function Tools() { + const [searchTerm, setSearchTerm] = React.useState(''); + const [addToolModalState, setAddToolModalState] = + React.useState('INACTIVE'); + const [userTools, setUserTools] = React.useState([]); + const [selectedTool, setSelectedTool] = React.useState(null); + + const getUserTools = () => { + userService + .getUserTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setUserTools(data.tools); + }); + }; + + const updateToolStatus = (toolId: string, newStatus: boolean) => { + userService + .updateToolStatus({ id: toolId, status: newStatus }) + .then(() => { + setUserTools((prevTools) => + prevTools.map((tool) => + tool.id === toolId ? { ...tool, status: newStatus } : tool, + ), + ); + }) + .catch((error) => { + console.error('Failed to update tool status:', error); + }); + }; + + const handleSettingsClick = (tool: UserTool) => { + setSelectedTool(tool); + }; + + const handleGoBack = () => { + setSelectedTool(null); + getUserTools(); + }; + + React.useEffect(() => { + getUserTools(); + }, []); + return ( +
+ {selectedTool ? ( + + ) : ( +
+
+
+
+ setSearchTerm(e.target.value)} + /> +
+ +
+
+ {userTools + .filter((tool) => + tool.displayName + .toLowerCase() + .includes(searchTerm.toLowerCase()), + ) + .map((tool, index) => ( +
+
+
+ + +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ +
+
+ ))} +
+
+ +
+ )} +
+ ); +} diff --git a/frontend/src/settings/index.tsx b/frontend/src/settings/index.tsx index 15c7ce086..89f29a29c 100644 --- a/frontend/src/settings/index.tsx +++ b/frontend/src/settings/index.tsx @@ -7,8 +7,8 @@ import SettingsBar from '../components/SettingsBar'; import i18n from '../locale/i18n'; import { Doc } from '../models/misc'; import { - selectSourceDocs, selectPaginatedDocuments, + selectSourceDocs, setPaginatedDocuments, setSourceDocs, } from '../preferences/preferenceSlice'; @@ -17,6 +17,7 @@ import APIKeys from './APIKeys'; import Documents from './Documents'; import General from './General'; import Logs from './Logs'; +import Tools from './Tools'; import Widgets from './Widgets'; export default function Settings() { @@ -100,6 +101,8 @@ export default function Settings() { return ; case t('settings.logs.label'): return ; + case t('settings.tools.label'): + return ; default: return null; } diff --git a/frontend/src/settings/types/index.ts b/frontend/src/settings/types/index.ts index 52a58f236..322bdfeb3 100644 --- a/frontend/src/settings/types/index.ts +++ b/frontend/src/settings/types/index.ts @@ -18,3 +18,32 @@ export type LogData = { retriever_params: Record; timestamp: string; }; + +export type UserTool = { + id: string; + name: string; + displayName: string; + description: string; + status: boolean; + config: { + [key: string]: string; + }; + actions: { + name: string; + description: string; + parameters: { + properties: { + [key: string]: { + type: string; + description: string; + filled_by_llm: boolean; + value: string; + }; + }; + additionalProperties: boolean; + required: string[]; + type: string; + }; + active: boolean; + }[]; +}; From 343569ba195f380c8b9e889dbffa8f44d4bdd86c Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 19 Dec 2024 09:58:32 +0530 Subject: [PATCH 07/15] fix: create_tool endpoint for new fields --- application/api/user/routes.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 05248f97a..bcbc2d95a 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1880,13 +1880,24 @@ def post(self): return missing_fields user = "local" + transformed_actions = [] + for action in data["actions"]: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) try: new_tool = { "user": user, "name": data["name"], "displayName": data["displayName"], "description": data["description"], - "actions": data["actions"], + "actions": transformed_actions, "config": data["config"], "status": data["status"], } From c3f538c2f6d866d7be34a75abc32f49d2661910a Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 19 Dec 2024 09:59:38 +0530 Subject: [PATCH 08/15] fix: merge errors --- application/api/user/routes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 05256f288..bcbc2d95a 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1863,6 +1863,7 @@ class CreateTool(Resource): ), }, ) + ) @api.doc(description="Create a new tool") def post(self): data = request.get_json() @@ -2005,6 +2006,7 @@ class UpdateToolActions(Resource): ), }, ) + ) @api.doc(description="Update the actions of a tool") def post(self): data = request.get_json() From daa332aa20c23e66d25d0fd6e1c36204c66c5005 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 19 Dec 2024 10:06:06 +0530 Subject: [PATCH 09/15] fix: python lint errors --- application/cache.py | 30 ++++++++++++++++++---------- application/retriever/classic_rag.py | 24 ++++++++++++---------- application/tools/agent.py | 1 - application/tools/tool_manager.py | 1 - 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/application/cache.py b/application/cache.py index 7239abacd..76b594c93 100644 --- a/application/cache.py +++ b/application/cache.py @@ -1,29 +1,34 @@ -import redis -import time import json import logging +import time from threading import Lock + +import redis + from application.core.settings import settings from application.utils import get_hash -import sys logger = logging.getLogger(__name__) _redis_instance = None _instance_lock = Lock() + def get_redis_instance(): global _redis_instance if _redis_instance is None: with _instance_lock: if _redis_instance is None: try: - _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2) + _redis_instance = redis.Redis.from_url( + settings.CACHE_REDIS_URL, socket_connect_timeout=2 + ) except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") _redis_instance = None return _redis_instance + def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") @@ -33,6 +38,7 @@ def gen_cache_key(messages, model="docgpt", tools=None): cache_key = get_hash(combined) return cache_key + def gen_cache(func): def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: @@ -42,7 +48,7 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: cached_response = redis_client.get(cache_key) if cached_response: - return cached_response.decode('utf-8') + return cached_response.decode("utf-8") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") @@ -57,20 +63,22 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): except ValueError as e: logger.error(e) return "Error: No user message found in the conversation to generate a cache key." + return wrapper + def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") - + redis_client = get_redis_instance() if redis_client: try: cached_response = redis_client.get(cache_key) if cached_response: logger.info(f"Cache hit for stream key: {cache_key}") - cached_response = json.loads(cached_response.decode('utf-8')) + cached_response = json.loads(cached_response.decode("utf-8")) for chunk in cached_response: yield chunk time.sleep(0.03) @@ -80,16 +88,16 @@ def wrapper(self, model, messages, stream, *args, **kwargs): result = func(self, model, messages, stream, *args, **kwargs) stream_cache_data = [] - + for chunk in result: stream_cache_data.append(chunk) yield chunk - + if redis_client: try: redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800) logger.info(f"Stream cache saved for key: {cache_key}") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - - return wrapper \ No newline at end of file + + return wrapper diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 4ac52bc51..81a5985bc 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,10 +1,9 @@ -from application.retriever.base import BaseRetriever from application.core.settings import settings -from application.vectorstore.vector_creator import VectorCreator -from application.llm.llm_creator import LLMCreator +from application.retriever.base import BaseRetriever from application.tools.agent import Agent from application.utils import num_tokens_from_string +from application.vectorstore.vector_creator import VectorCreator class ClassicRAG(BaseRetriever): @@ -21,7 +20,7 @@ def __init__( user_api_key=None, ): self.question = question - self.vectorstore = source['active_docs'] if 'active_docs' in source else None + self.vectorstore = source["active_docs"] if "active_docs" in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -78,9 +77,9 @@ def gen(self): # count tokens in history for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( - i["response"] - ) + tokens_batch = num_tokens_from_string( + i["prompt"] + ) + num_tokens_from_string(i["response"]) if tokens_current_history + tokens_batch < self.token_limit: tokens_current_history += tokens_batch messages_combine.append( @@ -95,14 +94,19 @@ def gen(self): # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key # ) # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) - agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key) + agent = Agent( + llm_name=settings.LLM_NAME, + gpt_model=self.gpt_model, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, + ) completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} def search(self): return self._get_data() - + def get_params(self): return { "question": self.question, @@ -112,5 +116,5 @@ def get_params(self): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/tools/agent.py b/application/tools/agent.py index ffd14770e..e02c40f71 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,7 +1,6 @@ import json from application.core.mongo_db import MongoDB -from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.tools.tool_manager import ToolManager diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py index cc9a055a0..3e0766cfe 100644 --- a/application/tools/tool_manager.py +++ b/application/tools/tool_manager.py @@ -27,7 +27,6 @@ def load_tools(self): def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - tools_dir = os.path.join(os.path.dirname(__file__), "implementations") module = importlib.import_module( f"application.tools.implementations.{tool_name}" ) From f67b79f00766df29dfda9faa8ae152c0b1f3e161 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 19 Dec 2024 17:55:58 +0530 Subject: [PATCH 10/15] fix: missing yield in tool agent --- application/tools/agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/application/tools/agent.py b/application/tools/agent.py index e02c40f71..ab797992d 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -120,7 +120,9 @@ def _simple_tool_agent(self, messages): def gen(self, messages): # Generate initial response from the LLM if self.llm.supports_tools(): - self._simple_tool_agent(messages) + resp = self._simple_tool_agent(messages) + for line in resp: + yield line else: resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) for line in resp: From 4c3f990d4b86a962b56e5bd7bf8c6d03ecfb1cf3 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 19 Dec 2024 20:34:20 +0530 Subject: [PATCH 11/15] feat: tools agent refactor for custom fields and unique actions --- application/llm/groq.py | 2 +- application/tools/agent.py | 156 ++++++++++-------- application/tools/implementations/telegram.py | 27 +-- 3 files changed, 104 insertions(+), 81 deletions(-) diff --git a/application/llm/groq.py b/application/llm/groq.py index f2fcfbeb4..282d7f477 100644 --- a/application/llm/groq.py +++ b/application/llm/groq.py @@ -19,7 +19,7 @@ def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) - return response.choices[0].message.content + return response.choices[0].message.content def _raw_gen_stream( self, baseself, model, messages, stream=True, tools=None, **kwargs diff --git a/application/tools/agent.py b/application/tools/agent.py index ab797992d..d4077e45d 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -4,42 +4,6 @@ from application.llm.llm_creator import LLMCreator from application.tools.tool_manager import ToolManager -tool_tg = { - "name": "telegram_send_message", - "description": "Send a notification to telegram about current chat", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "Text to send in the notification", - } - }, - "required": ["text"], - "additionalProperties": False, - }, -} - -tool_crypto = { - "name": "cryptoprice_get", - "description": "Retrieve the price of a specified cryptocurrency in a given currency", - "parameters": { - "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The cryptocurrency symbol (e.g. BTC)", - }, - "currency": { - "type": "string", - "description": "The currency in which you want the price (e.g. USD)", - }, - }, - "required": ["symbol", "currency"], - "additionalProperties": False, - }, -} - class Agent: def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): @@ -49,7 +13,7 @@ def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): ) self.gpt_model = gpt_model # Static tool configuration (to be replaced later) - self.tools = [{"type": "function", "function": tool_crypto}] + self.tools = [] self.tool_config = {} def _get_user_tools(self, user="local"): @@ -58,50 +22,102 @@ def _get_user_tools(self, user="local"): user_tools_collection = db["user_tools"] user_tools = user_tools_collection.find({"user": user, "status": True}) user_tools = list(user_tools) - for tool in user_tools: - tool.pop("_id") - user_tools = {tool["name"]: tool for tool in user_tools} - return user_tools + tools_by_id = {str(tool["_id"]): tool for tool in user_tools} + return tools_by_id + + def _prepare_tools(self, tools_dict): + self.tools = [ + { + "type": "function", + "function": { + "name": f"{action['name']}_{tool_id}", + "description": action["description"], + "parameters": { + **action["parameters"], + "properties": { + k: { + key: value + for key, value in v.items() + if key != "filled_by_llm" and key != "value" + } + for k, v in action["parameters"]["properties"].items() + if v.get("filled_by_llm", False) + }, + "required": [ + key + for key in action["parameters"]["required"] + if key in action["parameters"]["properties"] + and action["parameters"]["properties"][key].get( + "filled_by_llm", False + ) + ], + }, + }, + } + for tool_id, tool in tools_dict.items() + for action in tool["actions"] + if action["active"] + ] + + def _execute_tool_action(self, tools_dict, call): + call_id = call.id + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + + tool_data = tools_dict[tool_id] + action_data = next( + action for action in tool_data["actions"] if action["name"] == action_name + ) + + for param, details in action_data["parameters"]["properties"].items(): + if param not in call_args and "value" in details: + call_args[param] = details["value"] + + tm = ToolManager(config={}) + tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"]) + print(f"Executing tool: {action_name} with args: {call_args}") + return tool.execute_action(action_name, **call_args), call_id def _simple_tool_agent(self, messages): tools_dict = self._get_user_tools() - # combine all tool_actions into one list - self.tools.extend( - [ - {"type": "function", "function": tool_action} - for tool in tools_dict.values() - for tool_action in tool["actions"] - ] - ) + self._prepare_tools(tools_dict) resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) if isinstance(resp, str): - # Yield the response if it's a string and exit yield resp return + if resp.message.content: + yield resp.message.content + return while resp.finish_reason == "tool_calls": - # Append the assistant's message to the conversation - messages.append(json.loads(resp.model_dump_json())["message"]) - # Handle each tool call + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) tool_calls = resp.message.tool_calls for call in tool_calls: - tm = ToolManager(config={}) - call_name = call.function.name - call_args = json.loads(call.function.arguments) - call_id = call.id - # Determine the tool name and load it - tool_name = call_name.split("_")[0] - tool = tm.load_tool( - tool_name, tool_config=tools_dict[tool_name]["config"] - ) - # Execute the tool's action - resp_tool = tool.execute_action(call_name, **call_args) - # Append the tool's response to the conversation - messages.append( - {"role": "tool", "content": str(resp_tool), "tool_call_id": call_id} - ) + try: + tool_response, call_id = self._execute_tool_action(tools_dict, call) + messages.append( + { + "role": "tool", + "content": str(tool_response), + "tool_call_id": call_id, + } + ) + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) # Generate a new response from the LLM after processing tools resp = self.llm.gen( model=self.gpt_model, messages=messages, tools=self.tools @@ -110,6 +126,8 @@ def _simple_tool_agent(self, messages): # If no tool calls are needed, generate the final response if isinstance(resp, str): yield resp + elif resp.message.content: + yield resp.message.content else: completion = self.llm.gen_stream( model=self.gpt_model, messages=messages, tools=self.tools @@ -117,6 +135,8 @@ def _simple_tool_agent(self, messages): for line in completion: yield line + return + def gen(self, messages): # Generate initial response from the LLM if self.llm.supports_tools(): diff --git a/application/tools/implementations/telegram.py b/application/tools/implementations/telegram.py index a2b436b47..a32bbe881 100644 --- a/application/tools/implementations/telegram.py +++ b/application/tools/implementations/telegram.py @@ -12,7 +12,6 @@ class TelegramTool(Tool): def __init__(self, config): self.config = config self.token = config.get("token", "") - self.chat_id = config.get("chat_id", "") def execute_action(self, action_name, **kwargs): actions = { @@ -25,17 +24,17 @@ def execute_action(self, action_name, **kwargs): else: raise ValueError(f"Unknown action: {action_name}") - def _send_message(self, text): + def _send_message(self, text, chat_id): print(f"Sending message: {text}") url = f"https://api.telegram.org/bot{self.token}/sendMessage" - payload = {"chat_id": self.chat_id, "text": text} + payload = {"chat_id": chat_id, "text": text} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Message sent"} - def _send_image(self, image_url): + def _send_image(self, image_url, chat_id): print(f"Sending image: {image_url}") url = f"https://api.telegram.org/bot{self.token}/sendPhoto" - payload = {"chat_id": self.chat_id, "photo": image_url} + payload = {"chat_id": chat_id, "photo": image_url} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Image sent"} @@ -43,14 +42,18 @@ def get_actions_metadata(self): return [ { "name": "telegram_send_message", - "description": "Send a notification to telegram chat", + "description": "Send a notification to Telegram chat", "parameters": { "type": "object", "properties": { "text": { "type": "string", "description": "Text to send in the notification", - } + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the notification to", + }, }, "required": ["text"], "additionalProperties": False, @@ -65,7 +68,11 @@ def get_actions_metadata(self): "image_url": { "type": "string", "description": "URL of the image to send", - } + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the image to", + }, }, "required": ["image_url"], "additionalProperties": False, @@ -75,9 +82,5 @@ def get_actions_metadata(self): def get_config_requirements(self): return { - "chat_id": { - "type": "string", - "description": "Telegram chat ID to send messages to", - }, "token": {"type": "string", "description": "Bot token for authentication"}, } From 856419832146ecf9247298f9c14c66427d81d964 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 19 Dec 2024 16:02:27 +0000 Subject: [PATCH 12/15] mini-model as default --- application/api/answer/routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f109db268..94a176934 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,7 +37,7 @@ gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = "gpt-3.5-turbo" + gpt_model = "gpt-4o-mini" elif settings.LLM_NAME == "anthropic": gpt_model = "claude-2" elif settings.LLM_NAME == "groq": From 6fc4723d61b3f666649579b02b228bef13051bfb Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 19 Dec 2024 16:03:39 +0000 Subject: [PATCH 13/15] feat: flask debugger for vscode --- .vscode/launch.json | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index fc4b8128f..5be1f7115 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,6 +11,26 @@ "skipFiles": [ "/**" ] - } + }, + { + "name": "Python Debugger: Flask", + "type": "debugpy", + "request": "launch", + "module": "flask", + "env": { + "FLASK_APP": "application/app.py", + "PYTHONPATH": "${workspaceFolder}", + "FLASK_ENV": "development", + "FLASK_DEBUG": "1", + "FLASK_RUN_PORT": "7091", + "FLASK_RUN_HOST": "0.0.0.0" + + }, + "args": [ + "run", + "--no-debugger" + ], + "cwd": "${workspaceFolder}", + }, ] } \ No newline at end of file From c2a95b5bec24e07d8eed3a42553d1f21f45cfdbc Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 20 Dec 2024 17:32:58 +0000 Subject: [PATCH 14/15] lint: fixing index and classc rag --- application/retriever/classic_rag.py | 1 - frontend/src/modals/types/index.ts | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index a00e191fa..2e3555137 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -2,7 +2,6 @@ from application.retriever.base import BaseRetriever from application.tools.agent import Agent -from application.utils import num_tokens_from_string from application.vectorstore.vector_creator import VectorCreator diff --git a/frontend/src/modals/types/index.ts b/frontend/src/modals/types/index.ts index c018e44cb..458496d28 100644 --- a/frontend/src/modals/types/index.ts +++ b/frontend/src/modals/types/index.ts @@ -8,6 +8,7 @@ export type AvailableTool = { description: string; parameters: object; }[]; +}; export type WrapperModalProps = { children?: React.ReactNode; From 1f75f0c0821ca413a6617404a3c94f9c1d86c5ce Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 20 Dec 2024 18:13:37 +0000 Subject: [PATCH 15/15] fix: tests --- application/llm/anthropic.py | 4 ++-- application/llm/sagemaker.py | 4 ++-- tests/llm/test_anthropic.py | 3 ++- tests/llm/test_sagemaker.py | 2 +- tests/test_cache.py | 25 ++++++++++++++----------- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 4081bcd08..1fa3b5b20 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -17,7 +17,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): self.AI_PROMPT = AI_PROMPT def _raw_gen( - self, baseself, model, messages, stream=False, max_tokens=300, **kwargs + self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] @@ -34,7 +34,7 @@ def _raw_gen( return completion.completion def _raw_gen_stream( - self, baseself, model, messages, stream=True, max_tokens=300, **kwargs + self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 63947430d..aaf99a126 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -76,7 +76,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -105,7 +105,7 @@ def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): print(result[0]["generated_text"], file=sys.stderr) return result[0]["generated_text"][len(prompt) :] - def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py index 689013c0d..50ddbe294 100644 --- a/tests/llm/test_anthropic.py +++ b/tests/llm/test_anthropic.py @@ -46,6 +46,7 @@ def test_gen_stream(self): {"content": "question"} ] mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")] + mock_tools = Mock() with patch("application.cache.get_redis_instance") as mock_make_redis: mock_redis_instance = mock_make_redis.return_value @@ -53,7 +54,7 @@ def test_gen_stream(self): mock_redis_instance.set = Mock() with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: - responses = list(self.llm.gen_stream("test_model", messages)) + responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools)) self.assertListEqual(responses, ["response_1", "response_2"]) prompt_expected = "### Context \n context \n ### Question \n question" diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index d659d4983..2b893a9a4 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -76,7 +76,7 @@ def test_gen_stream(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, self.messages)) + output = list(self.sagemaker.gen_stream(None, self.messages, tools=None)) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', diff --git a/tests/test_cache.py b/tests/test_cache.py index 4270a1819..af2b5e008 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -12,18 +12,21 @@ def test_make_gen_cache_key(): {'role': 'system', 'content': 'test_system_message'}, ] model = "test_docgpt" + tools = None # Manually calculate the expected hash - expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + expected_combined = f"{model}_{messages_str}_{tools_str}" expected_hash = get_hash(expected_combined) - cache_key = gen_cache_key(*messages, model=model) + cache_key = gen_cache_key(messages, model=model, tools=None) assert cache_key == expected_hash def test_gen_cache_key_invalid_message_format(): # Test when messages is not a list with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context: - gen_cache_key("This is not a list", model="docgpt") + gen_cache_key("This is not a list", model="docgpt", tools=None) assert str(context.exception) == "All messages must be dictionaries." # Test for gen_cache decorator @@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, stream, tools): return "new_result" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "cached_result" # Should return cached result @@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, steam, tools): return "new_result" messages = [ @@ -67,7 +70,7 @@ def mock_function(self, model, messages): ] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "new_result" @@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = cached_chunk @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["chunk1", "chunk2"] # Should return cached chunks @@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [ @@ -117,7 +120,7 @@ def mock_function(self, model, messages, stream): model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["new_chunk"]