diff --git a/.vscode/launch.json b/.vscode/launch.json index fc4b8128f..5be1f7115 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,6 +11,26 @@ "skipFiles": [ "/**" ] - } + }, + { + "name": "Python Debugger: Flask", + "type": "debugpy", + "request": "launch", + "module": "flask", + "env": { + "FLASK_APP": "application/app.py", + "PYTHONPATH": "${workspaceFolder}", + "FLASK_ENV": "development", + "FLASK_DEBUG": "1", + "FLASK_RUN_PORT": "7091", + "FLASK_RUN_HOST": "0.0.0.0" + + }, + "args": [ + "run", + "--no-debugger" + ], + "cwd": "${workspaceFolder}", + }, ] } \ No newline at end of file diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f109db268..94a176934 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,7 +37,7 @@ gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = "gpt-3.5-turbo" + gpt_model = "gpt-4o-mini" elif settings.LLM_NAME == "anthropic": gpt_model = "claude-2" elif settings.LLM_NAME == "groq": diff --git a/application/api/user/routes.py b/application/api/user/routes.py index a5807b81e..bcbc2d95a 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,14 +1,14 @@ import datetime +import math import os import shutil import uuid -import math from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId -from flask import Blueprint, jsonify, make_response, request, redirect -from flask_restx import inputs, fields, Namespace, Resource +from flask import Blueprint, jsonify, make_response, redirect, request +from flask_restx import fields, inputs, Namespace, Resource from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote @@ -16,9 +16,10 @@ from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api +from application.tools.tool_manager import ToolManager +from application.tts.google_tts import GoogleTTS from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator -from application.tts.google_tts import GoogleTTS mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -40,6 +41,9 @@ os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +tool_config = {} +tool_manager = ToolManager(config=tool_config) + def generate_minute_range(start_date, end_date): return { @@ -1789,35 +1793,112 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) +@user_ns.route("/api/available_tools") +class AvailableTools(Resource): + @api.doc(description="Get available tools for a user") + def get(self): + try: + tools_metadata = [] + for tool_name, tool_instance in tool_manager.tools.items(): + doc = tool_instance.__doc__.strip() + lines = doc.split("\n", 1) + name = lines[0].strip() + description = lines[1].strip() if len(lines) > 1 else "" + tools_metadata.append( + { + "name": tool_name, + "displayName": name, + "description": description, + "configRequirements": tool_instance.get_config_requirements(), + "actions": tool_instance.get_actions_metadata(), + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "data": tools_metadata}), 200) + + +@user_ns.route("/api/get_tools") +class GetTools(Resource): + @api.doc(description="Get tools created by a user") + def get(self): + try: + user = "local" + tools = user_tools_collection.find({"user": user}) + user_tools = [] + for tool in tools: + tool["id"] = str(tool["_id"]) + tool.pop("_id") + user_tools.append(tool) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "tools": user_tools}), 200) + + @user_ns.route("/api/create_tool") class CreateTool(Resource): - create_tool_model = api.model( - "CreateToolModel", - { - "name": fields.String(required=True, description="Name of the tool"), - "config": fields.Raw(required=True, description="Configuration of the tool"), - "actions": fields.List(fields.String, required=True, description="Actions the tool can perform"), - "status": fields.Boolean(required=True, description="Status of the tool") - - }, + @api.expect( + api.model( + "CreateToolModel", + { + "name": fields.String(required=True, description="Name of the tool"), + "displayName": fields.String( + required=True, description="Display name for the tool" + ), + "description": fields.String( + required=True, description="Tool description" + ), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) ) - - @api.expect(create_tool_model) @api.doc(description="Create a new tool") def post(self): data = request.get_json() - required_fields = ["name", "config", "actions", "status"] + required_fields = [ + "name", + "displayName", + "description", + "actions", + "config", + "status", + ] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields user = "local" + transformed_actions = [] + for action in data["actions"]: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) try: new_tool = { + "user": user, "name": data["name"], + "displayName": data["displayName"], + "description": data["description"], + "actions": transformed_actions, "config": data["config"], - "actions": data["actions"], - "user": user, "status": data["status"], } resp = user_tools_collection.insert_one(new_tool) @@ -1826,18 +1907,72 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"id": new_id}), 200) - + + +@user_ns.route("/api/update_tool") +class UpdateTool(Resource): + @api.expect( + api.model( + "UpdateToolModel", + { + "id": fields.String(required=True, description="Tool ID"), + "name": fields.String(description="Name of the tool"), + "displayName": fields.String(description="Display name for the tool"), + "description": fields.String(description="Tool description"), + "config": fields.Raw(description="Configuration of the tool"), + "actions": fields.List( + fields.Raw, description="Actions the tool can perform" + ), + "status": fields.Boolean(description="Status of the tool"), + }, + ) + ) + @api.doc(description="Update a tool by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + update_data = {} + if "name" in data: + update_data["name"] = data["name"] + if "displayName" in data: + update_data["displayName"] = data["displayName"] + if "description" in data: + update_data["description"] = data["description"] + if "actions" in data: + update_data["actions"] = data["actions"] + if "config" in data: + update_data["config"] = data["config"] + if "status" in data: + update_data["status"] = data["status"] + + user_tools_collection.update_one( + {"_id": ObjectId(data["id"]), "user": "local"}, + {"$set": update_data}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + @user_ns.route("/api/update_tool_config") class UpdateToolConfig(Resource): - update_tool_config_model = api.model( - "UpdateToolConfigModel", - { - "id": fields.String(required=True, description="Tool ID"), - "config": fields.Raw(required=True, description="Configuration of the tool"), - }, + @api.expect( + api.model( + "UpdateToolConfigModel", + { + "id": fields.String(required=True, description="Tool ID"), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + }, + ) ) - - @api.expect(update_tool_config_model) @api.doc(description="Update the configuration of a tool") def post(self): data = request.get_json() @@ -1855,18 +1990,23 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) - + + @user_ns.route("/api/update_tool_actions") class UpdateToolActions(Resource): - update_tool_actions_model = api.model( - "UpdateToolActionsModel", - { - "id": fields.String(required=True, description="Tool ID"), - "actions": fields.List(fields.String, required=True, description="Actions the tool can perform"), - }, + @api.expect( + api.model( + "UpdateToolActionsModel", + { + "id": fields.String(required=True, description="Tool ID"), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + }, + ) ) - - @api.expect(update_tool_actions_model) @api.doc(description="Update the actions of a tool") def post(self): data = request.get_json() @@ -1885,17 +2025,20 @@ def post(self): return make_response(jsonify({"success": True}), 200) + @user_ns.route("/api/update_tool_status") class UpdateToolStatus(Resource): - update_tool_status_model = api.model( - "UpdateToolStatusModel", - { - "id": fields.String(required=True, description="Tool ID"), - "status": fields.Boolean(required=True, description="Status of the tool"), - }, + @api.expect( + api.model( + "UpdateToolStatusModel", + { + "id": fields.String(required=True, description="Tool ID"), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) ) - - @api.expect(update_tool_status_model) @api.doc(description="Update the status of a tool") def post(self): data = request.get_json() @@ -1913,15 +2056,16 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) return make_response(jsonify({"success": True}), 200) - + + @user_ns.route("/api/delete_tool") class DeleteTool(Resource): - delete_tool_model = api.model( - "DeleteToolModel", - {"id": fields.String(required=True, description="Tool ID")}, + @api.expect( + api.model( + "DeleteToolModel", + {"id": fields.String(required=True, description="Tool ID")}, + ) ) - - @api.expect(delete_tool_model) @api.doc(description="Delete a tool by ID") def post(self): data = request.get_json() @@ -1937,4 +2081,4 @@ def post(self): except Exception as err: return {"success": False, "error": str(err)}, 400 - return {"success": True}, 200 \ No newline at end of file + return {"success": True}, 200 diff --git a/application/cache.py b/application/cache.py index 7239abacd..76b594c93 100644 --- a/application/cache.py +++ b/application/cache.py @@ -1,29 +1,34 @@ -import redis -import time import json import logging +import time from threading import Lock + +import redis + from application.core.settings import settings from application.utils import get_hash -import sys logger = logging.getLogger(__name__) _redis_instance = None _instance_lock = Lock() + def get_redis_instance(): global _redis_instance if _redis_instance is None: with _instance_lock: if _redis_instance is None: try: - _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2) + _redis_instance = redis.Redis.from_url( + settings.CACHE_REDIS_URL, socket_connect_timeout=2 + ) except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") _redis_instance = None return _redis_instance + def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") @@ -33,6 +38,7 @@ def gen_cache_key(messages, model="docgpt", tools=None): cache_key = get_hash(combined) return cache_key + def gen_cache(func): def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: @@ -42,7 +48,7 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: cached_response = redis_client.get(cache_key) if cached_response: - return cached_response.decode('utf-8') + return cached_response.decode("utf-8") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") @@ -57,20 +63,22 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): except ValueError as e: logger.error(e) return "Error: No user message found in the conversation to generate a cache key." + return wrapper + def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") - + redis_client = get_redis_instance() if redis_client: try: cached_response = redis_client.get(cache_key) if cached_response: logger.info(f"Cache hit for stream key: {cache_key}") - cached_response = json.loads(cached_response.decode('utf-8')) + cached_response = json.loads(cached_response.decode("utf-8")) for chunk in cached_response: yield chunk time.sleep(0.03) @@ -80,16 +88,16 @@ def wrapper(self, model, messages, stream, *args, **kwargs): result = func(self, model, messages, stream, *args, **kwargs) stream_cache_data = [] - + for chunk in result: stream_cache_data.append(chunk) yield chunk - + if redis_client: try: redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800) logger.info(f"Stream cache saved for key: {cache_key}") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - - return wrapper \ No newline at end of file + + return wrapper diff --git a/application/llm/groq.py b/application/llm/groq.py index b5731a905..282d7f477 100644 --- a/application/llm/groq.py +++ b/application/llm/groq.py @@ -1,45 +1,32 @@ from application.llm.base import BaseLLM - +from openai import OpenAI class GroqLLM(BaseLLM): - def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): - from openai import OpenAI - super().__init__(*args, **kwargs) self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1") self.api_key = api_key self.user_api_key = user_api_key - def _raw_gen( - self, - baseself, - model, - messages, - stream=False, - **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - return response.choices[0].message.content + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + return response.choices[0].message.content def _raw_gen_stream( - self, - baseself, - model, - messages, - stream=True, - **kwargs - ): + self, baseself, model, messages, stream=True, tools=None, **kwargs + ): response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) - for line in response: - # import sys - # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 4ac52bc51..81a5985bc 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,10 +1,9 @@ -from application.retriever.base import BaseRetriever from application.core.settings import settings -from application.vectorstore.vector_creator import VectorCreator -from application.llm.llm_creator import LLMCreator +from application.retriever.base import BaseRetriever from application.tools.agent import Agent from application.utils import num_tokens_from_string +from application.vectorstore.vector_creator import VectorCreator class ClassicRAG(BaseRetriever): @@ -21,7 +20,7 @@ def __init__( user_api_key=None, ): self.question = question - self.vectorstore = source['active_docs'] if 'active_docs' in source else None + self.vectorstore = source["active_docs"] if "active_docs" in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -78,9 +77,9 @@ def gen(self): # count tokens in history for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( - i["response"] - ) + tokens_batch = num_tokens_from_string( + i["prompt"] + ) + num_tokens_from_string(i["response"]) if tokens_current_history + tokens_batch < self.token_limit: tokens_current_history += tokens_batch messages_combine.append( @@ -95,14 +94,19 @@ def gen(self): # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key # ) # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) - agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key) + agent = Agent( + llm_name=settings.LLM_NAME, + gpt_model=self.gpt_model, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, + ) completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} def search(self): return self._get_data() - + def get_params(self): return { "question": self.question, @@ -112,5 +116,5 @@ def get_params(self): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/tools/agent.py b/application/tools/agent.py index af23b99e2..d4077e45d 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,130 +1,149 @@ +import json + +from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator -from application.core.settings import settings from application.tools.tool_manager import ToolManager -from application.core.mongo_db import MongoDB -import json -tool_tg = { - "name": "telegram_send_message", - "description": "Send a notification to telegram about current chat", - "parameters": { - "type": "object", - "properties": { - "text": { - "type": "string", - "description": "Text to send in the notification" - } - }, - "required": ["text"], - "additionalProperties": False - } -} - -tool_crypto = { - "name": "cryptoprice_get", - "description": "Retrieve the price of a specified cryptocurrency in a given currency", - "parameters": { - "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The cryptocurrency symbol (e.g. BTC)" - }, - "currency": { - "type": "string", - "description": "The currency in which you want the price (e.g. USD)" - } - }, - "required": ["symbol", "currency"], - "additionalProperties": False - } -} class Agent: def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): # Initialize the LLM with the provided parameters - self.llm = LLMCreator.create_llm(llm_name, api_key=api_key, user_api_key=user_api_key) + self.llm = LLMCreator.create_llm( + llm_name, api_key=api_key, user_api_key=user_api_key + ) self.gpt_model = gpt_model # Static tool configuration (to be replaced later) - self.tools = [ - { - "type": "function", - "function": tool_crypto - } - ] - self.tool_config = { - } - + self.tools = [] + self.tool_config = {} + def _get_user_tools(self, user="local"): mongo = MongoDB.get_client() db = mongo["docsgpt"] user_tools_collection = db["user_tools"] user_tools = user_tools_collection.find({"user": user, "status": True}) user_tools = list(user_tools) - for tool in user_tools: - tool.pop("_id") - user_tools = {tool["name"]: tool for tool in user_tools} - return user_tools - - def _simple_tool_agent(self, messages): - tools_dict = self._get_user_tools() - # combine all tool_actions into one list - self.tools.extend([ + tools_by_id = {str(tool["_id"]): tool for tool in user_tools} + return tools_by_id + + def _prepare_tools(self, tools_dict): + self.tools = [ { "type": "function", - "function": tool_action + "function": { + "name": f"{action['name']}_{tool_id}", + "description": action["description"], + "parameters": { + **action["parameters"], + "properties": { + k: { + key: value + for key, value in v.items() + if key != "filled_by_llm" and key != "value" + } + for k, v in action["parameters"]["properties"].items() + if v.get("filled_by_llm", False) + }, + "required": [ + key + for key in action["parameters"]["required"] + if key in action["parameters"]["properties"] + and action["parameters"]["properties"][key].get( + "filled_by_llm", False + ) + ], + }, + }, } - for tool in tools_dict.values() - for tool_action in tool["actions"] - ]) + for tool_id, tool in tools_dict.items() + for action in tool["actions"] + if action["active"] + ] + + def _execute_tool_action(self, tools_dict, call): + call_id = call.id + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + tool_data = tools_dict[tool_id] + action_data = next( + action for action in tool_data["actions"] if action["name"] == action_name + ) + + for param, details in action_data["parameters"]["properties"].items(): + if param not in call_args and "value" in details: + call_args[param] = details["value"] + + tm = ToolManager(config={}) + tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"]) + print(f"Executing tool: {action_name} with args: {call_args}") + return tool.execute_action(action_name, **call_args), call_id + + def _simple_tool_agent(self, messages): + tools_dict = self._get_user_tools() + self._prepare_tools(tools_dict) resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) if isinstance(resp, str): - # Yield the response if it's a string and exit yield resp return + if resp.message.content: + yield resp.message.content + return while resp.finish_reason == "tool_calls": - # Append the assistant's message to the conversation - messages.append(json.loads(resp.model_dump_json())['message']) - # Handle each tool call + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) tool_calls = resp.message.tool_calls for call in tool_calls: - tm = ToolManager(config={}) - call_name = call.function.name - call_args = json.loads(call.function.arguments) - call_id = call.id - # Determine the tool name and load it - tool_name = call_name.split("_")[0] - tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config']) - # Execute the tool's action - resp_tool = tool.execute_action(call_name, **call_args) - # Append the tool's response to the conversation - messages.append( - { - "role": "tool", - "content": str(resp_tool), - "tool_call_id": call_id - } - ) + try: + tool_response, call_id = self._execute_tool_action(tools_dict, call) + messages.append( + { + "role": "tool", + "content": str(tool_response), + "tool_call_id": call_id, + } + ) + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) # Generate a new response from the LLM after processing tools - resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + resp = self.llm.gen( + model=self.gpt_model, messages=messages, tools=self.tools + ) # If no tool calls are needed, generate the final response if isinstance(resp, str): yield resp + elif resp.message.content: + yield resp.message.content else: - completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) + completion = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) for line in completion: yield line + return + def gen(self, messages): # Generate initial response from the LLM if self.llm.supports_tools(): - self._simple_tool_agent(messages) + resp = self._simple_tool_agent(messages) + for line in resp: + yield line else: resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) for line in resp: - yield line \ No newline at end of file + yield line diff --git a/application/tools/base.py b/application/tools/base.py index 00cfee3a0..fd7b4a852 100644 --- a/application/tools/base.py +++ b/application/tools/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class Tool(ABC): @abstractmethod def execute_action(self, action_name: str, **kwargs): diff --git a/application/tools/cryptoprice.py b/application/tools/implementations/cryptoprice.py similarity index 82% rename from application/tools/cryptoprice.py rename to application/tools/implementations/cryptoprice.py index d7cf61e13..7b88c866d 100644 --- a/application/tools/cryptoprice.py +++ b/application/tools/implementations/cryptoprice.py @@ -1,21 +1,25 @@ -from application.tools.base import Tool import requests +from application.tools.base import Tool + class CryptoPriceTool(Tool): + """ + CryptoPrice + A tool for retrieving cryptocurrency prices using the CryptoCompare public API + """ + def __init__(self, config): self.config = config def execute_action(self, action_name, **kwargs): - actions = { - "cryptoprice_get": self.get_price - } + actions = {"cryptoprice_get": self._get_price} if action_name in actions: return actions[action_name](**kwargs) else: raise ValueError(f"Unknown action: {action_name}") - def get_price(self, symbol, currency): + def _get_price(self, symbol, currency): """ Fetches the current price of a given cryptocurrency symbol in the specified currency. Example: @@ -32,17 +36,17 @@ def get_price(self, symbol, currency): return { "status_code": response.status_code, "price": data[currency.upper()], - "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully." + "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.", } else: return { "status_code": response.status_code, - "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}." + "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.", } else: return { "status_code": response.status_code, - "message": "Failed to retrieve price." + "message": "Failed to retrieve price.", } def get_actions_metadata(self): @@ -55,16 +59,16 @@ def get_actions_metadata(self): "properties": { "symbol": { "type": "string", - "description": "The cryptocurrency symbol (e.g. BTC)" + "description": "The cryptocurrency symbol (e.g. BTC)", }, "currency": { "type": "string", - "description": "The currency in which you want the price (e.g. USD)" - } + "description": "The currency in which you want the price (e.g. USD)", + }, }, "required": ["symbol", "currency"], - "additionalProperties": False - } + "additionalProperties": False, + }, } ] diff --git a/application/tools/telegram.py b/application/tools/implementations/telegram.py similarity index 59% rename from application/tools/telegram.py rename to application/tools/implementations/telegram.py index 8210d8e71..a32bbe881 100644 --- a/application/tools/telegram.py +++ b/application/tools/implementations/telegram.py @@ -1,16 +1,22 @@ -from application.tools.base import Tool import requests +from application.tools.base import Tool + class TelegramTool(Tool): + """ + Telegram Bot + A flexible Telegram tool for performing various actions (e.g., sending messages, images). + Requires a bot token and chat ID for configuration + """ + def __init__(self, config): self.config = config - self.chat_id = config.get("chat_id", "142189016") - self.token = config.get("token", "YOUR_TG_TOKEN") + self.token = config.get("token", "") def execute_action(self, action_name, **kwargs): actions = { - "telegram_send_message": self.send_message, - "telegram_send_image": self.send_image + "telegram_send_message": self._send_message, + "telegram_send_image": self._send_image, } if action_name in actions: @@ -18,17 +24,17 @@ def execute_action(self, action_name, **kwargs): else: raise ValueError(f"Unknown action: {action_name}") - def send_message(self, text): + def _send_message(self, text, chat_id): print(f"Sending message: {text}") url = f"https://api.telegram.org/bot{self.token}/sendMessage" - payload = {"chat_id": self.chat_id, "text": text} + payload = {"chat_id": chat_id, "text": text} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Message sent"} - def send_image(self, image_url): + def _send_image(self, image_url, chat_id): print(f"Sending image: {image_url}") url = f"https://api.telegram.org/bot{self.token}/sendPhoto" - payload = {"chat_id": self.chat_id, "photo": image_url} + payload = {"chat_id": chat_id, "photo": image_url} response = requests.post(url, data=payload) return {"status_code": response.status_code, "message": "Image sent"} @@ -36,18 +42,22 @@ def get_actions_metadata(self): return [ { "name": "telegram_send_message", - "description": "Send a notification to telegram chat", + "description": "Send a notification to Telegram chat", "parameters": { "type": "object", "properties": { "text": { "type": "string", - "description": "Text to send in the notification" - } + "description": "Text to send in the notification", + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the notification to", + }, }, "required": ["text"], - "additionalProperties": False - } + "additionalProperties": False, + }, }, { "name": "telegram_send_image", @@ -57,23 +67,20 @@ def get_actions_metadata(self): "properties": { "image_url": { "type": "string", - "description": "URL of the image to send" - } + "description": "URL of the image to send", + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the image to", + }, }, "required": ["image_url"], - "additionalProperties": False - } - } + "additionalProperties": False, + }, + }, ] def get_config_requirements(self): return { - "chat_id": { - "type": "string", - "description": "Telegram chat ID to send messages to" - }, - "token": { - "type": "string", - "description": "Bot token for authentication" - } + "token": {"type": "string", "description": "Bot token for authentication"}, } diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py index 10231cb28..3e0766cfe 100644 --- a/application/tools/tool_manager.py +++ b/application/tools/tool_manager.py @@ -1,10 +1,11 @@ import importlib import inspect -import pkgutil import os +import pkgutil from application.tools.base import Tool + class ToolManager: def __init__(self, config): self.config = config @@ -12,11 +13,13 @@ def __init__(self, config): self.load_tools() def load_tools(self): - tools_dir = os.path.dirname(__file__) + tools_dir = os.path.join(os.path.dirname(__file__), "implementations") for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): - if name == 'base' or name.startswith('__'): + if name == "base" or name.startswith("__"): continue - module = importlib.import_module(f'application.tools.{name}') + module = importlib.import_module( + f"application.tools.implementations.{name}" + ) for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: tool_config = self.config.get(name, {}) @@ -24,13 +27,13 @@ def load_tools(self): def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - tools_dir = os.path.dirname(__file__) - module = importlib.import_module(f'application.tools.{tool_name}') + module = importlib.import_module( + f"application.tools.implementations.{tool_name}" + ) for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: return obj(tool_config) - def execute_action(self, tool_name, action_name, **kwargs): if tool_name not in self.tools: raise ValueError(f"Tool '{tool_name}' not loaded") diff --git a/frontend/public/toolIcons/tool_cryptoprice.svg b/frontend/public/toolIcons/tool_cryptoprice.svg new file mode 100644 index 000000000..6a4226949 --- /dev/null +++ b/frontend/public/toolIcons/tool_cryptoprice.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/public/toolIcons/tool_telegram.svg b/frontend/public/toolIcons/tool_telegram.svg new file mode 100644 index 000000000..27536dedf --- /dev/null +++ b/frontend/public/toolIcons/tool_telegram.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 4e7112d04..8a7f9ae23 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -18,6 +18,11 @@ const endpoints = { FEEDBACK_ANALYTICS: '/api/get_feedback_analytics', LOGS: `/api/get_user_logs`, MANAGE_SYNC: '/api/manage_sync', + GET_AVAILABLE_TOOLS: '/api/available_tools', + GET_USER_TOOLS: '/api/get_tools', + CREATE_TOOL: '/api/create_tool', + UPDATE_TOOL_STATUS: '/api/update_tool_status', + UPDATE_TOOL: '/api/update_tool', }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 942318ae4..ab91a0a42 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -35,6 +35,16 @@ const userService = { apiClient.post(endpoints.USER.LOGS, data), manageSync: (data: any): Promise => apiClient.post(endpoints.USER.MANAGE_SYNC, data), + getAvailableTools: (): Promise => + apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS), + getUserTools: (): Promise => + apiClient.get(endpoints.USER.GET_USER_TOOLS), + createTool: (data: any): Promise => + apiClient.post(endpoints.USER.CREATE_TOOL, data), + updateToolStatus: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data), + updateTool: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL, data), }; export default userService; diff --git a/frontend/src/assets/cogwheel.svg b/frontend/src/assets/cogwheel.svg new file mode 100644 index 000000000..f5299b8ba --- /dev/null +++ b/frontend/src/assets/cogwheel.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/components/SettingsBar.tsx b/frontend/src/components/SettingsBar.tsx index f617c6e82..c6970600d 100644 --- a/frontend/src/components/SettingsBar.tsx +++ b/frontend/src/components/SettingsBar.tsx @@ -13,6 +13,7 @@ const useTabs = () => { t('settings.apiKeys.label'), t('settings.analytics.label'), t('settings.logs.label'), + t('settings.tools.label'), ]; return tabs; }; diff --git a/frontend/src/index.css b/frontend/src/index.css index 9bd2ec964..ecd4da1eb 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -68,6 +68,15 @@ body.dark { .table-default td:last-child { @apply border-r-0; /* Ensure no right border on the last column */ } + + .table-default th, + .table-default td { + min-width: 150px; + max-width: 320px; + overflow: auto; + scrollbar-width: thin; + scrollbar-color: grey transparent; + } } /*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 134604633..4048f6324 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -73,6 +73,9 @@ }, "logs": { "label": "Logs" + }, + "tools": { + "label": "Tools" } }, "modals": { diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx new file mode 100644 index 000000000..330cf3bb2 --- /dev/null +++ b/frontend/src/modals/AddToolModal.tsx @@ -0,0 +1,136 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import Exit from '../assets/exit.svg'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import ConfigToolModal from './ConfigToolModal'; + +export default function AddToolModal({ + message, + modalState, + setModalState, + getUserTools, +}: { + message: string; + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + getUserTools: () => void; +}) { + const [availableTools, setAvailableTools] = React.useState( + [], + ); + const [selectedTool, setSelectedTool] = React.useState( + null, + ); + const [configModalState, setConfigModalState] = + React.useState('INACTIVE'); + + const getAvailableTools = () => { + userService + .getAvailableTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setAvailableTools(data.data); + }); + }; + + const handleAddTool = (tool: AvailableTool) => { + if (Object.keys(tool.configRequirements).length === 0) { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: {}, + actions: tool.actions, + status: true, + }) + .then((res) => { + if (res.status === 200) { + getUserTools(); + setModalState('INACTIVE'); + } + }); + } else { + setModalState('INACTIVE'); + setConfigModalState('ACTIVE'); + } + }; + + React.useEffect(() => { + if (modalState === 'ACTIVE') getAvailableTools(); + }, [modalState]); + return ( + <> +
+
+
+ +
+

+ Select a tool to set up +

+
+ {availableTools.map((tool, index) => ( +
{ + setSelectedTool(tool); + handleAddTool(tool); + }} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + setSelectedTool(tool); + handleAddTool(tool); + } + }} + > +
+
+ +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ ))} +
+
+
+
+
+ + + ); +} diff --git a/frontend/src/modals/ConfigToolModal.tsx b/frontend/src/modals/ConfigToolModal.tsx new file mode 100644 index 000000000..96bb15be8 --- /dev/null +++ b/frontend/src/modals/ConfigToolModal.tsx @@ -0,0 +1,95 @@ +import React from 'react'; + +import Exit from '../assets/exit.svg'; +import Input from '../components/Input'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import userService from '../api/services/userService'; + +export default function ConfigToolModal({ + modalState, + setModalState, + tool, + getUserTools, +}: { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + tool: AvailableTool | null; + getUserTools: () => void; +}) { + const [authKey, setAuthKey] = React.useState(''); + + const handleAddTool = (tool: AvailableTool) => { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: true, + }) + .then(() => { + setModalState('INACTIVE'); + getUserTools(); + }); + }; + return ( +
+
+
+ +
+

+ Tool Config +

+

+ Type: {tool?.name} +

+
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+
+ + +
+
+
+
+
+ ); +} diff --git a/frontend/src/modals/types/index.ts b/frontend/src/modals/types/index.ts new file mode 100644 index 000000000..044ba4386 --- /dev/null +++ b/frontend/src/modals/types/index.ts @@ -0,0 +1,11 @@ +export type AvailableTool = { + name: string; + displayName: string; + description: string; + configRequirements: object; + actions: { + name: string; + description: string; + parameters: object; + }[]; +}; diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx index f91a33559..6375b4e27 100644 --- a/frontend/src/settings/Documents.tsx +++ b/frontend/src/settings/Documents.tsx @@ -215,18 +215,22 @@ const Documents: React.FC = ({ {document.type === 'remote' ? 'Pre-loaded' : 'Private'} -
+
{document.type !== 'remote' && ( - Delete { event.stopPropagation(); handleDeleteDocument(index, document); }} - /> + > + Delete + )} {document.syncFrequency && (
diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx new file mode 100644 index 000000000..0de4dab96 --- /dev/null +++ b/frontend/src/settings/ToolConfig.tsx @@ -0,0 +1,293 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import ArrowLeft from '../assets/arrow-left.svg'; +import Input from '../components/Input'; +import { UserTool } from './types'; + +export default function ToolConfig({ + tool, + setTool, + handleGoBack, +}: { + tool: UserTool; + setTool: (tool: UserTool) => void; + handleGoBack: () => void; +}) { + const [authKey, setAuthKey] = React.useState( + tool.config?.token || '', + ); + + const handleCheckboxChange = (actionIndex: number, property: string) => { + setTool({ + ...tool, + actions: tool.actions.map((action, index) => { + if (index === actionIndex) { + return { + ...action, + parameters: { + ...action.parameters, + properties: { + ...action.parameters.properties, + [property]: { + ...action.parameters.properties[property], + filled_by_llm: + !action.parameters.properties[property].filled_by_llm, + }, + }, + }, + }; + } + return action; + }), + }); + }; + + const handleSaveChanges = () => { + userService + .updateTool({ + id: tool.id, + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: tool.status, + }) + .then(() => { + handleGoBack(); + }); + }; + return ( +
+
+ +

Back to all tools

+
+
+

+ Type +

+

+ {tool.name} +

+
+
+ {Object.keys(tool?.config).length !== 0 && ( +

+ Authentication +

+ )} +
+ {Object.keys(tool?.config).length !== 0 && ( +
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+ )} + +
+
+
+
+

+ Actions +

+
+ {tool.actions.map((action, actionIndex) => { + return ( +
+
+

+ {action.name} +

+ +
+
+ { + setTool({ + ...tool, + actions: tool.actions.map((act, index) => { + if (index === actionIndex) { + return { + ...act, + description: e.target.value, + }; + } + return act; + }), + }); + }} + borderVariant="thin" + > +
+
+ + + + + + + + + + + + {Object.entries(action.parameters?.properties).map( + (param, index) => { + const uniqueKey = `${actionIndex}-${param[0]}`; + return ( + + + + + + + + ); + }, + )} + +
Field NameField TypeFilled by LLMFIeld descriptionValue
{param[0]}{param[1].type} + + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + description: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + value: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > +
+
+
+ ); + })} +
+
+
+ ); +} diff --git a/frontend/src/settings/Tools.tsx b/frontend/src/settings/Tools.tsx new file mode 100644 index 000000000..d29ba42c6 --- /dev/null +++ b/frontend/src/settings/Tools.tsx @@ -0,0 +1,157 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import CogwheelIcon from '../assets/cogwheel.svg'; +import Input from '../components/Input'; +import AddToolModal from '../modals/AddToolModal'; +import { ActiveState } from '../models/misc'; +import { UserTool } from './types'; +import ToolConfig from './ToolConfig'; + +export default function Tools() { + const [searchTerm, setSearchTerm] = React.useState(''); + const [addToolModalState, setAddToolModalState] = + React.useState('INACTIVE'); + const [userTools, setUserTools] = React.useState([]); + const [selectedTool, setSelectedTool] = React.useState(null); + + const getUserTools = () => { + userService + .getUserTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setUserTools(data.tools); + }); + }; + + const updateToolStatus = (toolId: string, newStatus: boolean) => { + userService + .updateToolStatus({ id: toolId, status: newStatus }) + .then(() => { + setUserTools((prevTools) => + prevTools.map((tool) => + tool.id === toolId ? { ...tool, status: newStatus } : tool, + ), + ); + }) + .catch((error) => { + console.error('Failed to update tool status:', error); + }); + }; + + const handleSettingsClick = (tool: UserTool) => { + setSelectedTool(tool); + }; + + const handleGoBack = () => { + setSelectedTool(null); + getUserTools(); + }; + + React.useEffect(() => { + getUserTools(); + }, []); + return ( +
+ {selectedTool ? ( + + ) : ( +
+
+
+
+ setSearchTerm(e.target.value)} + /> +
+ +
+
+ {userTools + .filter((tool) => + tool.displayName + .toLowerCase() + .includes(searchTerm.toLowerCase()), + ) + .map((tool, index) => ( +
+
+
+ + +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ +
+
+ ))} +
+
+ +
+ )} +
+ ); +} diff --git a/frontend/src/settings/index.tsx b/frontend/src/settings/index.tsx index 15c7ce086..89f29a29c 100644 --- a/frontend/src/settings/index.tsx +++ b/frontend/src/settings/index.tsx @@ -7,8 +7,8 @@ import SettingsBar from '../components/SettingsBar'; import i18n from '../locale/i18n'; import { Doc } from '../models/misc'; import { - selectSourceDocs, selectPaginatedDocuments, + selectSourceDocs, setPaginatedDocuments, setSourceDocs, } from '../preferences/preferenceSlice'; @@ -17,6 +17,7 @@ import APIKeys from './APIKeys'; import Documents from './Documents'; import General from './General'; import Logs from './Logs'; +import Tools from './Tools'; import Widgets from './Widgets'; export default function Settings() { @@ -100,6 +101,8 @@ export default function Settings() { return ; case t('settings.logs.label'): return ; + case t('settings.tools.label'): + return ; default: return null; } diff --git a/frontend/src/settings/types/index.ts b/frontend/src/settings/types/index.ts index 52a58f236..322bdfeb3 100644 --- a/frontend/src/settings/types/index.ts +++ b/frontend/src/settings/types/index.ts @@ -18,3 +18,32 @@ export type LogData = { retriever_params: Record; timestamp: string; }; + +export type UserTool = { + id: string; + name: string; + displayName: string; + description: string; + status: boolean; + config: { + [key: string]: string; + }; + actions: { + name: string; + description: string; + parameters: { + properties: { + [key: string]: { + type: string; + description: string; + filled_by_llm: boolean; + value: string; + }; + }; + additionalProperties: boolean; + required: string[]; + type: string; + }; + active: boolean; + }[]; +};