diff --git a/.vscode/launch.json b/.vscode/launch.json index fc4b8128f..5be1f7115 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,6 +11,26 @@ "skipFiles": [ "/**" ] - } + }, + { + "name": "Python Debugger: Flask", + "type": "debugpy", + "request": "launch", + "module": "flask", + "env": { + "FLASK_APP": "application/app.py", + "PYTHONPATH": "${workspaceFolder}", + "FLASK_ENV": "development", + "FLASK_DEBUG": "1", + "FLASK_RUN_PORT": "7091", + "FLASK_RUN_HOST": "0.0.0.0" + + }, + "args": [ + "run", + "--no-debugger" + ], + "cwd": "${workspaceFolder}", + }, ] } \ No newline at end of file diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index c55ffe725..74f8a4a95 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -37,7 +37,7 @@ gpt_model = "" # to have some kind of default behaviour if settings.LLM_NAME == "openai": - gpt_model = "gpt-3.5-turbo" + gpt_model = "gpt-4o-mini" elif settings.LLM_NAME == "anthropic": gpt_model = "claude-2" elif settings.LLM_NAME == "groq": diff --git a/application/api/user/routes.py b/application/api/user/routes.py index af26f7ba3..f2d1be06d 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,14 +1,14 @@ import datetime +import math import os import shutil import uuid -import math from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId -from flask import Blueprint, jsonify, make_response, request, redirect -from flask_restx import inputs, fields, Namespace, Resource +from flask import Blueprint, jsonify, make_response, redirect, request +from flask_restx import fields, inputs, Namespace, Resource from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote @@ -16,9 +16,10 @@ from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api +from application.tools.tool_manager import ToolManager +from application.tts.google_tts import GoogleTTS from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator -from application.tts.google_tts import GoogleTTS mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -30,6 +31,7 @@ token_usage_collection = db["token_usage"] shared_conversations_collections = db["shared_conversations"] user_logs_collection = db["user_logs"] +user_tools_collection = db["user_tools"] user = Blueprint("user", __name__) user_ns = Namespace("user", description="User related operations", path="/") @@ -39,6 +41,9 @@ os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) +tool_config = {} +tool_manager = ToolManager(config=tool_config) + def generate_minute_range(start_date, end_date): return { @@ -1801,4 +1806,296 @@ def post(self): 200, ) except Exception as err: - return make_response(jsonify({"success": False, "error": str(err)}), 400) \ No newline at end of file + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + +@user_ns.route("/api/available_tools") +class AvailableTools(Resource): + @api.doc(description="Get available tools for a user") + def get(self): + try: + tools_metadata = [] + for tool_name, tool_instance in tool_manager.tools.items(): + doc = tool_instance.__doc__.strip() + lines = doc.split("\n", 1) + name = lines[0].strip() + description = lines[1].strip() if len(lines) > 1 else "" + tools_metadata.append( + { + "name": tool_name, + "displayName": name, + "description": description, + "configRequirements": tool_instance.get_config_requirements(), + "actions": tool_instance.get_actions_metadata(), + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "data": tools_metadata}), 200) + + +@user_ns.route("/api/get_tools") +class GetTools(Resource): + @api.doc(description="Get tools created by a user") + def get(self): + try: + user = "local" + tools = user_tools_collection.find({"user": user}) + user_tools = [] + for tool in tools: + tool["id"] = str(tool["_id"]) + tool.pop("_id") + user_tools.append(tool) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "tools": user_tools}), 200) + + +@user_ns.route("/api/create_tool") +class CreateTool(Resource): + @api.expect( + api.model( + "CreateToolModel", + { + "name": fields.String(required=True, description="Name of the tool"), + "displayName": fields.String( + required=True, description="Display name for the tool" + ), + "description": fields.String( + required=True, description="Tool description" + ), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) + ) + @api.doc(description="Create a new tool") + def post(self): + data = request.get_json() + required_fields = [ + "name", + "displayName", + "description", + "actions", + "config", + "status", + ] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + user = "local" + transformed_actions = [] + for action in data["actions"]: + action["active"] = True + if "parameters" in action: + if "properties" in action["parameters"]: + for param_name, param_details in action["parameters"][ + "properties" + ].items(): + param_details["filled_by_llm"] = True + param_details["value"] = "" + transformed_actions.append(action) + try: + new_tool = { + "user": user, + "name": data["name"], + "displayName": data["displayName"], + "description": data["description"], + "actions": transformed_actions, + "config": data["config"], + "status": data["status"], + } + resp = user_tools_collection.insert_one(new_tool) + new_id = str(resp.inserted_id) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"id": new_id}), 200) + + +@user_ns.route("/api/update_tool") +class UpdateTool(Resource): + @api.expect( + api.model( + "UpdateToolModel", + { + "id": fields.String(required=True, description="Tool ID"), + "name": fields.String(description="Name of the tool"), + "displayName": fields.String(description="Display name for the tool"), + "description": fields.String(description="Tool description"), + "config": fields.Raw(description="Configuration of the tool"), + "actions": fields.List( + fields.Raw, description="Actions the tool can perform" + ), + "status": fields.Boolean(description="Status of the tool"), + }, + ) + ) + @api.doc(description="Update a tool by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + update_data = {} + if "name" in data: + update_data["name"] = data["name"] + if "displayName" in data: + update_data["displayName"] = data["displayName"] + if "description" in data: + update_data["description"] = data["description"] + if "actions" in data: + update_data["actions"] = data["actions"] + if "config" in data: + update_data["config"] = data["config"] + if "status" in data: + update_data["status"] = data["status"] + + user_tools_collection.update_one( + {"_id": ObjectId(data["id"]), "user": "local"}, + {"$set": update_data}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@user_ns.route("/api/update_tool_config") +class UpdateToolConfig(Resource): + @api.expect( + api.model( + "UpdateToolConfigModel", + { + "id": fields.String(required=True, description="Tool ID"), + "config": fields.Raw( + required=True, description="Configuration of the tool" + ), + }, + ) + ) + @api.doc(description="Update the configuration of a tool") + def post(self): + data = request.get_json() + required_fields = ["id", "config"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["id"])}, + {"$set": {"config": data["config"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@user_ns.route("/api/update_tool_actions") +class UpdateToolActions(Resource): + @api.expect( + api.model( + "UpdateToolActionsModel", + { + "id": fields.String(required=True, description="Tool ID"), + "actions": fields.List( + fields.Raw, + required=True, + description="Actions the tool can perform", + ), + }, + ) + ) + @api.doc(description="Update the actions of a tool") + def post(self): + data = request.get_json() + required_fields = ["id", "actions"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["id"])}, + {"$set": {"actions": data["actions"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@user_ns.route("/api/update_tool_status") +class UpdateToolStatus(Resource): + @api.expect( + api.model( + "UpdateToolStatusModel", + { + "id": fields.String(required=True, description="Tool ID"), + "status": fields.Boolean( + required=True, description="Status of the tool" + ), + }, + ) + ) + @api.doc(description="Update the status of a tool") + def post(self): + data = request.get_json() + required_fields = ["id", "status"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + user_tools_collection.update_one( + {"_id": ObjectId(data["id"])}, + {"$set": {"status": data["status"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@user_ns.route("/api/delete_tool") +class DeleteTool(Resource): + @api.expect( + api.model( + "DeleteToolModel", + {"id": fields.String(required=True, description="Tool ID")}, + ) + ) + @api.doc(description="Delete a tool by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])}) + if result.deleted_count == 0: + return {"success": False, "message": "Tool not found"}, 404 + except Exception as err: + return {"success": False, "error": str(err)}, 400 + + return {"success": True}, 200 + \ No newline at end of file diff --git a/application/cache.py b/application/cache.py index 33022e45f..76b594c93 100644 --- a/application/cache.py +++ b/application/cache.py @@ -1,8 +1,10 @@ -import redis -import time import json import logging +import time from threading import Lock + +import redis + from application.core.settings import settings from application.utils import get_hash @@ -11,41 +13,47 @@ _redis_instance = None _instance_lock = Lock() + def get_redis_instance(): global _redis_instance if _redis_instance is None: with _instance_lock: if _redis_instance is None: try: - _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2) + _redis_instance = redis.Redis.from_url( + settings.CACHE_REDIS_URL, socket_connect_timeout=2 + ) except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") _redis_instance = None return _redis_instance -def gen_cache_key(*messages, model="docgpt"): + +def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") - messages_str = json.dumps(list(messages), sort_keys=True) - combined = f"{model}_{messages_str}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + combined = f"{model}_{messages_str}_{tools_str}" cache_key = get_hash(combined) return cache_key + def gen_cache(func): - def wrapper(self, model, messages, *args, **kwargs): + def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages, model, tools) redis_client = get_redis_instance() if redis_client: try: cached_response = redis_client.get(cache_key) if cached_response: - return cached_response.decode('utf-8') + return cached_response.decode("utf-8") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - result = func(self, model, messages, *args, **kwargs) - if redis_client: + result = func(self, model, messages, stream, tools, *args, **kwargs) + if redis_client and isinstance(result, str): try: redis_client.set(cache_key, result, ex=1800) except redis.ConnectionError as e: @@ -55,20 +63,22 @@ def wrapper(self, model, messages, *args, **kwargs): except ValueError as e: logger.error(e) return "Error: No user message found in the conversation to generate a cache key." + return wrapper + def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): - cache_key = gen_cache_key(*messages) + cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") - + redis_client = get_redis_instance() if redis_client: try: cached_response = redis_client.get(cache_key) if cached_response: logger.info(f"Cache hit for stream key: {cache_key}") - cached_response = json.loads(cached_response.decode('utf-8')) + cached_response = json.loads(cached_response.decode("utf-8")) for chunk in cached_response: yield chunk time.sleep(0.03) @@ -78,16 +88,16 @@ def wrapper(self, model, messages, stream, *args, **kwargs): result = func(self, model, messages, stream, *args, **kwargs) stream_cache_data = [] - + for chunk in result: stream_cache_data.append(chunk) yield chunk - + if redis_client: try: redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800) logger.info(f"Stream cache saved for key: {cache_key}") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - - return wrapper \ No newline at end of file + + return wrapper diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 4081bcd08..1fa3b5b20 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -17,7 +17,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): self.AI_PROMPT = AI_PROMPT def _raw_gen( - self, baseself, model, messages, stream=False, max_tokens=300, **kwargs + self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] @@ -34,7 +34,7 @@ def _raw_gen( return completion.completion def _raw_gen_stream( - self, baseself, model, messages, stream=True, max_tokens=300, **kwargs + self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] diff --git a/application/llm/base.py b/application/llm/base.py index 1caab5d38..b9b0e5243 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -13,12 +13,12 @@ def _apply_decorator(self, method, decorators, *args, **kwargs): return method(self, *args, **kwargs) @abstractmethod - def _raw_gen(self, model, messages, stream, *args, **kwargs): + def _raw_gen(self, model, messages, stream, tools, *args, **kwargs): pass - def gen(self, model, messages, stream=False, *args, **kwargs): + def gen(self, model, messages, stream=False, tools=None, *args, **kwargs): decorators = [gen_token_usage, gen_cache] - return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs) @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): @@ -26,4 +26,10 @@ def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): def gen_stream(self, model, messages, stream=True, *args, **kwargs): decorators = [stream_cache, stream_token_usage] - return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) \ No newline at end of file + return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) + + def supports_tools(self): + return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools')) + + def _supports_tools(self): + raise NotImplementedError("Subclass must implement _supports_tools method") \ No newline at end of file diff --git a/application/llm/groq.py b/application/llm/groq.py index b5731a905..282d7f477 100644 --- a/application/llm/groq.py +++ b/application/llm/groq.py @@ -1,45 +1,32 @@ from application.llm.base import BaseLLM - +from openai import OpenAI class GroqLLM(BaseLLM): - def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): - from openai import OpenAI - super().__init__(*args, **kwargs) self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1") self.api_key = api_key self.user_api_key = user_api_key - def _raw_gen( - self, - baseself, - model, - messages, - stream=False, - **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - return response.choices[0].message.content + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + return response.choices[0].message.content def _raw_gen_stream( - self, - baseself, - model, - messages, - stream=True, - **kwargs - ): + self, baseself, model, messages, stream=True, tools=None, **kwargs + ): response = self.client.chat.completions.create( model=model, messages=messages, stream=stream, **kwargs ) - for line in response: - # import sys - # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content diff --git a/application/llm/openai.py b/application/llm/openai.py index f85de6eae..cc2285a12 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -25,14 +25,20 @@ def _raw_gen( model, messages, stream=False, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs - ): - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) - - return response.choices[0].message.content + ): + if tools: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, tools=tools, **kwargs + ) + return response.choices[0] + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) + return response.choices[0].message.content def _raw_gen_stream( self, @@ -40,6 +46,7 @@ def _raw_gen_stream( model, messages, stream=True, + tools=None, engine=settings.AZURE_DEPLOYMENT_NAME, **kwargs ): @@ -52,6 +59,9 @@ def _raw_gen_stream( # print(line.choices[0].delta.content, file=sys.stderr) if line.choices[0].delta.content is not None: yield line.choices[0].delta.content + + def _supports_tools(self): + return True class AzureOpenAILLM(OpenAILLM): diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 63947430d..aaf99a126 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -76,7 +76,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -105,7 +105,7 @@ def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): print(result[0]["generated_text"], file=sys.stderr) return result[0]["generated_text"][len(prompt) :] - def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/application/requirements.txt b/application/requirements.txt index de043d2d1..b9d2c33c8 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -43,7 +43,7 @@ multidict==6.1.0 mypy-extensions==1.0.0 networkx==3.3 numpy==1.26.4 -openai==1.55.3 +openai==1.57.0 openapi-schema-validator==0.6.2 openapi-spec-validator==0.6.0 openapi3-parser==1.1.18 diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 8de625dd8..2e3555137 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,7 +1,8 @@ -from application.retriever.base import BaseRetriever from application.core.settings import settings +from application.retriever.base import BaseRetriever +from application.tools.agent import Agent + from application.vectorstore.vector_creator import VectorCreator -from application.llm.llm_creator import LLMCreator @@ -19,7 +20,7 @@ def __init__( user_api_key=None, ): self.question = question - self.vectorstore = source['active_docs'] if 'active_docs' in source else None + self.vectorstore = source["active_docs"] if "active_docs" in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -81,17 +82,23 @@ def gen(self): {"role": "system", "content": i["response"]} ) messages_combine.append({"role": "user", "content": self.question}) - - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + # llm = LLMCreator.create_llm( + # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + # ) + # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + agent = Agent( + llm_name=settings.LLM_NAME, + gpt_model=self.gpt_model, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, ) - completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} def search(self): return self._get_data() - + def get_params(self): return { "question": self.question, @@ -101,5 +108,5 @@ def get_params(self): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/tools/agent.py b/application/tools/agent.py new file mode 100644 index 000000000..d4077e45d --- /dev/null +++ b/application/tools/agent.py @@ -0,0 +1,149 @@ +import json + +from application.core.mongo_db import MongoDB +from application.llm.llm_creator import LLMCreator +from application.tools.tool_manager import ToolManager + + +class Agent: + def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + # Initialize the LLM with the provided parameters + self.llm = LLMCreator.create_llm( + llm_name, api_key=api_key, user_api_key=user_api_key + ) + self.gpt_model = gpt_model + # Static tool configuration (to be replaced later) + self.tools = [] + self.tool_config = {} + + def _get_user_tools(self, user="local"): + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_tools_collection = db["user_tools"] + user_tools = user_tools_collection.find({"user": user, "status": True}) + user_tools = list(user_tools) + tools_by_id = {str(tool["_id"]): tool for tool in user_tools} + return tools_by_id + + def _prepare_tools(self, tools_dict): + self.tools = [ + { + "type": "function", + "function": { + "name": f"{action['name']}_{tool_id}", + "description": action["description"], + "parameters": { + **action["parameters"], + "properties": { + k: { + key: value + for key, value in v.items() + if key != "filled_by_llm" and key != "value" + } + for k, v in action["parameters"]["properties"].items() + if v.get("filled_by_llm", False) + }, + "required": [ + key + for key in action["parameters"]["required"] + if key in action["parameters"]["properties"] + and action["parameters"]["properties"][key].get( + "filled_by_llm", False + ) + ], + }, + }, + } + for tool_id, tool in tools_dict.items() + for action in tool["actions"] + if action["active"] + ] + + def _execute_tool_action(self, tools_dict, call): + call_id = call.id + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + + tool_data = tools_dict[tool_id] + action_data = next( + action for action in tool_data["actions"] if action["name"] == action_name + ) + + for param, details in action_data["parameters"]["properties"].items(): + if param not in call_args and "value" in details: + call_args[param] = details["value"] + + tm = ToolManager(config={}) + tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"]) + print(f"Executing tool: {action_name} with args: {call_args}") + return tool.execute_action(action_name, **call_args), call_id + + def _simple_tool_agent(self, messages): + tools_dict = self._get_user_tools() + self._prepare_tools(tools_dict) + + resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + + if isinstance(resp, str): + yield resp + return + if resp.message.content: + yield resp.message.content + return + + while resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + tool_response, call_id = self._execute_tool_action(tools_dict, call) + messages.append( + { + "role": "tool", + "content": str(tool_response), + "tool_call_id": call_id, + } + ) + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call.id, + } + ) + # Generate a new response from the LLM after processing tools + resp = self.llm.gen( + model=self.gpt_model, messages=messages, tools=self.tools + ) + + # If no tool calls are needed, generate the final response + if isinstance(resp, str): + yield resp + elif resp.message.content: + yield resp.message.content + else: + completion = self.llm.gen_stream( + model=self.gpt_model, messages=messages, tools=self.tools + ) + for line in completion: + yield line + + return + + def gen(self, messages): + # Generate initial response from the LLM + if self.llm.supports_tools(): + resp = self._simple_tool_agent(messages) + for line in resp: + yield line + else: + resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) + for line in resp: + yield line diff --git a/application/tools/base.py b/application/tools/base.py new file mode 100644 index 000000000..fd7b4a852 --- /dev/null +++ b/application/tools/base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + + +class Tool(ABC): + @abstractmethod + def execute_action(self, action_name: str, **kwargs): + pass + + @abstractmethod + def get_actions_metadata(self): + """ + Returns a list of JSON objects describing the actions supported by the tool. + """ + pass + + @abstractmethod + def get_config_requirements(self): + """ + Returns a dictionary describing the configuration requirements for the tool. + """ + pass diff --git a/application/tools/implementations/cryptoprice.py b/application/tools/implementations/cryptoprice.py new file mode 100644 index 000000000..7b88c866d --- /dev/null +++ b/application/tools/implementations/cryptoprice.py @@ -0,0 +1,77 @@ +import requests +from application.tools.base import Tool + + +class CryptoPriceTool(Tool): + """ + CryptoPrice + A tool for retrieving cryptocurrency prices using the CryptoCompare public API + """ + + def __init__(self, config): + self.config = config + + def execute_action(self, action_name, **kwargs): + actions = {"cryptoprice_get": self._get_price} + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def _get_price(self, symbol, currency): + """ + Fetches the current price of a given cryptocurrency symbol in the specified currency. + Example: + symbol = "BTC" + currency = "USD" + returns price in USD. + """ + url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + # data will be like {"USD": } if the call is successful + if currency.upper() in data: + return { + "status_code": response.status_code, + "price": data[currency.upper()], + "message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.", + } + else: + return { + "status_code": response.status_code, + "message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.", + } + else: + return { + "status_code": response.status_code, + "message": "Failed to retrieve price.", + } + + def get_actions_metadata(self): + return [ + { + "name": "cryptoprice_get", + "description": "Retrieve the price of a specified cryptocurrency in a given currency", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The cryptocurrency symbol (e.g. BTC)", + }, + "currency": { + "type": "string", + "description": "The currency in which you want the price (e.g. USD)", + }, + }, + "required": ["symbol", "currency"], + "additionalProperties": False, + }, + } + ] + + def get_config_requirements(self): + # No specific configuration needed for this tool as it just queries a public endpoint + return {} diff --git a/application/tools/implementations/telegram.py b/application/tools/implementations/telegram.py new file mode 100644 index 000000000..a32bbe881 --- /dev/null +++ b/application/tools/implementations/telegram.py @@ -0,0 +1,86 @@ +import requests +from application.tools.base import Tool + + +class TelegramTool(Tool): + """ + Telegram Bot + A flexible Telegram tool for performing various actions (e.g., sending messages, images). + Requires a bot token and chat ID for configuration + """ + + def __init__(self, config): + self.config = config + self.token = config.get("token", "") + + def execute_action(self, action_name, **kwargs): + actions = { + "telegram_send_message": self._send_message, + "telegram_send_image": self._send_image, + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def _send_message(self, text, chat_id): + print(f"Sending message: {text}") + url = f"https://api.telegram.org/bot{self.token}/sendMessage" + payload = {"chat_id": chat_id, "text": text} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Message sent"} + + def _send_image(self, image_url, chat_id): + print(f"Sending image: {image_url}") + url = f"https://api.telegram.org/bot{self.token}/sendPhoto" + payload = {"chat_id": chat_id, "photo": image_url} + response = requests.post(url, data=payload) + return {"status_code": response.status_code, "message": "Image sent"} + + def get_actions_metadata(self): + return [ + { + "name": "telegram_send_message", + "description": "Send a notification to Telegram chat", + "parameters": { + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Text to send in the notification", + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the notification to", + }, + }, + "required": ["text"], + "additionalProperties": False, + }, + }, + { + "name": "telegram_send_image", + "description": "Send an image to the Telegram chat", + "parameters": { + "type": "object", + "properties": { + "image_url": { + "type": "string", + "description": "URL of the image to send", + }, + "chat_id": { + "type": "string", + "description": "Chat ID to send the image to", + }, + }, + "required": ["image_url"], + "additionalProperties": False, + }, + }, + ] + + def get_config_requirements(self): + return { + "token": {"type": "string", "description": "Bot token for authentication"}, + } diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py new file mode 100644 index 000000000..3e0766cfe --- /dev/null +++ b/application/tools/tool_manager.py @@ -0,0 +1,46 @@ +import importlib +import inspect +import os +import pkgutil + +from application.tools.base import Tool + + +class ToolManager: + def __init__(self, config): + self.config = config + self.tools = {} + self.load_tools() + + def load_tools(self): + tools_dir = os.path.join(os.path.dirname(__file__), "implementations") + for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): + if name == "base" or name.startswith("__"): + continue + module = importlib.import_module( + f"application.tools.implementations.{name}" + ) + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + tool_config = self.config.get(name, {}) + self.tools[name] = obj(tool_config) + + def load_tool(self, tool_name, tool_config): + self.config[tool_name] = tool_config + module = importlib.import_module( + f"application.tools.implementations.{tool_name}" + ) + for member_name, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, Tool) and obj is not Tool: + return obj(tool_config) + + def execute_action(self, tool_name, action_name, **kwargs): + if tool_name not in self.tools: + raise ValueError(f"Tool '{tool_name}' not loaded") + return self.tools[tool_name].execute_action(action_name, **kwargs) + + def get_all_actions_metadata(self): + metadata = [] + for tool in self.tools.values(): + metadata.extend(tool.get_actions_metadata()) + return metadata diff --git a/application/usage.py b/application/usage.py index e87ebe385..fe4cd50e6 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,7 @@ import sys from datetime import datetime from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string +from application.utils import num_tokens_from_string, num_tokens_from_object_or_list mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage): def gen_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) - result = func(self, model, messages, stream, **kwargs) - self.token_usage["generated_tokens"] += num_tokens_from_string(result) + if message["content"]: + self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + result = func(self, model, messages, stream, tools, **kwargs) + # check if result is a string + if isinstance(result, str): + self.token_usage["generated_tokens"] += num_tokens_from_string(result) + else: + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) update_token_usage(self.user_api_key, self.token_usage) return result @@ -33,11 +38,11 @@ def wrapper(self, model, messages, stream, **kwargs): def stream_token_usage(func): - def wrapper(self, model, messages, stream, **kwargs): + def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) batch = [] - result = func(self, model, messages, stream, **kwargs) + result = func(self, model, messages, stream, tools, **kwargs) for r in result: batch.append(r) yield r diff --git a/application/utils.py b/application/utils.py index 7099a20a9..690eac5e2 100644 --- a/application/utils.py +++ b/application/utils.py @@ -15,9 +15,21 @@ def get_encoding(): def num_tokens_from_string(string: str) -> int: encoding = get_encoding() - num_tokens = len(encoding.encode(string)) - return num_tokens - + if isinstance(string, str): + num_tokens = len(encoding.encode(string)) + return num_tokens + else: + return 0 + +def num_tokens_from_object_or_list(thing): + if isinstance(thing, list): + return sum([num_tokens_from_object_or_list(x) for x in thing]) + elif isinstance(thing, dict): + return sum([num_tokens_from_object_or_list(x) for x in thing.values()]) + elif isinstance(thing, str): + return num_tokens_from_string(thing) + else: + return 0 def count_tokens_docs(docs): docs_content = "" diff --git a/frontend/public/toolIcons/tool_cryptoprice.svg b/frontend/public/toolIcons/tool_cryptoprice.svg new file mode 100644 index 000000000..6a4226949 --- /dev/null +++ b/frontend/public/toolIcons/tool_cryptoprice.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/public/toolIcons/tool_telegram.svg b/frontend/public/toolIcons/tool_telegram.svg new file mode 100644 index 000000000..27536dedf --- /dev/null +++ b/frontend/public/toolIcons/tool_telegram.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/frontend/src/api/endpoints.ts b/frontend/src/api/endpoints.ts index 4e7112d04..8a7f9ae23 100644 --- a/frontend/src/api/endpoints.ts +++ b/frontend/src/api/endpoints.ts @@ -18,6 +18,11 @@ const endpoints = { FEEDBACK_ANALYTICS: '/api/get_feedback_analytics', LOGS: `/api/get_user_logs`, MANAGE_SYNC: '/api/manage_sync', + GET_AVAILABLE_TOOLS: '/api/available_tools', + GET_USER_TOOLS: '/api/get_tools', + CREATE_TOOL: '/api/create_tool', + UPDATE_TOOL_STATUS: '/api/update_tool_status', + UPDATE_TOOL: '/api/update_tool', }, CONVERSATION: { ANSWER: '/api/answer', diff --git a/frontend/src/api/services/userService.ts b/frontend/src/api/services/userService.ts index 942318ae4..ab91a0a42 100644 --- a/frontend/src/api/services/userService.ts +++ b/frontend/src/api/services/userService.ts @@ -35,6 +35,16 @@ const userService = { apiClient.post(endpoints.USER.LOGS, data), manageSync: (data: any): Promise => apiClient.post(endpoints.USER.MANAGE_SYNC, data), + getAvailableTools: (): Promise => + apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS), + getUserTools: (): Promise => + apiClient.get(endpoints.USER.GET_USER_TOOLS), + createTool: (data: any): Promise => + apiClient.post(endpoints.USER.CREATE_TOOL, data), + updateToolStatus: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data), + updateTool: (data: any): Promise => + apiClient.post(endpoints.USER.UPDATE_TOOL, data), }; export default userService; diff --git a/frontend/src/assets/cogwheel.svg b/frontend/src/assets/cogwheel.svg new file mode 100644 index 000000000..f5299b8ba --- /dev/null +++ b/frontend/src/assets/cogwheel.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/components/SettingsBar.tsx b/frontend/src/components/SettingsBar.tsx index f617c6e82..c6970600d 100644 --- a/frontend/src/components/SettingsBar.tsx +++ b/frontend/src/components/SettingsBar.tsx @@ -13,6 +13,7 @@ const useTabs = () => { t('settings.apiKeys.label'), t('settings.analytics.label'), t('settings.logs.label'), + t('settings.tools.label'), ]; return tabs; }; diff --git a/frontend/src/index.css b/frontend/src/index.css index 91f1403dd..b8cf596e5 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -72,6 +72,15 @@ body.dark { .table-default td:last-child { @apply border-r-0; /* Ensure no right border on the last column */ } + + .table-default th, + .table-default td { + min-width: 150px; + max-width: 320px; + overflow: auto; + scrollbar-width: thin; + scrollbar-color: grey transparent; + } } /*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 8adda1aa7..c6ae80239 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -73,6 +73,9 @@ }, "logs": { "label": "Logs" + }, + "tools": { + "label": "Tools" } }, "modals": { diff --git a/frontend/src/modals/AddToolModal.tsx b/frontend/src/modals/AddToolModal.tsx new file mode 100644 index 000000000..330cf3bb2 --- /dev/null +++ b/frontend/src/modals/AddToolModal.tsx @@ -0,0 +1,136 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import Exit from '../assets/exit.svg'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import ConfigToolModal from './ConfigToolModal'; + +export default function AddToolModal({ + message, + modalState, + setModalState, + getUserTools, +}: { + message: string; + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + getUserTools: () => void; +}) { + const [availableTools, setAvailableTools] = React.useState( + [], + ); + const [selectedTool, setSelectedTool] = React.useState( + null, + ); + const [configModalState, setConfigModalState] = + React.useState('INACTIVE'); + + const getAvailableTools = () => { + userService + .getAvailableTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setAvailableTools(data.data); + }); + }; + + const handleAddTool = (tool: AvailableTool) => { + if (Object.keys(tool.configRequirements).length === 0) { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: {}, + actions: tool.actions, + status: true, + }) + .then((res) => { + if (res.status === 200) { + getUserTools(); + setModalState('INACTIVE'); + } + }); + } else { + setModalState('INACTIVE'); + setConfigModalState('ACTIVE'); + } + }; + + React.useEffect(() => { + if (modalState === 'ACTIVE') getAvailableTools(); + }, [modalState]); + return ( + <> +
+
+
+ +
+

+ Select a tool to set up +

+
+ {availableTools.map((tool, index) => ( +
{ + setSelectedTool(tool); + handleAddTool(tool); + }} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + setSelectedTool(tool); + handleAddTool(tool); + } + }} + > +
+
+ +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ ))} +
+
+
+
+
+ + + ); +} diff --git a/frontend/src/modals/ConfigToolModal.tsx b/frontend/src/modals/ConfigToolModal.tsx new file mode 100644 index 000000000..96bb15be8 --- /dev/null +++ b/frontend/src/modals/ConfigToolModal.tsx @@ -0,0 +1,95 @@ +import React from 'react'; + +import Exit from '../assets/exit.svg'; +import Input from '../components/Input'; +import { ActiveState } from '../models/misc'; +import { AvailableTool } from './types'; +import userService from '../api/services/userService'; + +export default function ConfigToolModal({ + modalState, + setModalState, + tool, + getUserTools, +}: { + modalState: ActiveState; + setModalState: (state: ActiveState) => void; + tool: AvailableTool | null; + getUserTools: () => void; +}) { + const [authKey, setAuthKey] = React.useState(''); + + const handleAddTool = (tool: AvailableTool) => { + userService + .createTool({ + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: true, + }) + .then(() => { + setModalState('INACTIVE'); + getUserTools(); + }); + }; + return ( +
+
+
+ +
+

+ Tool Config +

+

+ Type: {tool?.name} +

+
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+
+ + +
+
+
+
+
+ ); +} diff --git a/frontend/src/modals/types/index.ts b/frontend/src/modals/types/index.ts index 976bf0e9f..458496d28 100644 --- a/frontend/src/modals/types/index.ts +++ b/frontend/src/modals/types/index.ts @@ -1,3 +1,15 @@ +export type AvailableTool = { + name: string; + displayName: string; + description: string; + configRequirements: object; + actions: { + name: string; + description: string; + parameters: object; + }[]; +}; + export type WrapperModalProps = { children?: React.ReactNode; isPerformingTask?: boolean; diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx index 52a633519..8e68f226e 100644 --- a/frontend/src/settings/Documents.tsx +++ b/frontend/src/settings/Documents.tsx @@ -181,7 +181,7 @@ const Documents: React.FC = ({ {loading ? ( ) : ( -
+
diff --git a/frontend/src/settings/ToolConfig.tsx b/frontend/src/settings/ToolConfig.tsx new file mode 100644 index 000000000..0de4dab96 --- /dev/null +++ b/frontend/src/settings/ToolConfig.tsx @@ -0,0 +1,293 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import ArrowLeft from '../assets/arrow-left.svg'; +import Input from '../components/Input'; +import { UserTool } from './types'; + +export default function ToolConfig({ + tool, + setTool, + handleGoBack, +}: { + tool: UserTool; + setTool: (tool: UserTool) => void; + handleGoBack: () => void; +}) { + const [authKey, setAuthKey] = React.useState( + tool.config?.token || '', + ); + + const handleCheckboxChange = (actionIndex: number, property: string) => { + setTool({ + ...tool, + actions: tool.actions.map((action, index) => { + if (index === actionIndex) { + return { + ...action, + parameters: { + ...action.parameters, + properties: { + ...action.parameters.properties, + [property]: { + ...action.parameters.properties[property], + filled_by_llm: + !action.parameters.properties[property].filled_by_llm, + }, + }, + }, + }; + } + return action; + }), + }); + }; + + const handleSaveChanges = () => { + userService + .updateTool({ + id: tool.id, + name: tool.name, + displayName: tool.displayName, + description: tool.description, + config: { token: authKey }, + actions: tool.actions, + status: tool.status, + }) + .then(() => { + handleGoBack(); + }); + }; + return ( +
+
+ +

Back to all tools

+
+
+

+ Type +

+

+ {tool.name} +

+
+
+ {Object.keys(tool?.config).length !== 0 && ( +

+ Authentication +

+ )} +
+ {Object.keys(tool?.config).length !== 0 && ( +
+ + API Key / Oauth + + setAuthKey(e.target.value)} + borderVariant="thin" + placeholder="Enter API Key / Oauth" + > +
+ )} + +
+
+
+
+

+ Actions +

+
+ {tool.actions.map((action, actionIndex) => { + return ( +
+
+

+ {action.name} +

+ +
+
+ { + setTool({ + ...tool, + actions: tool.actions.map((act, index) => { + if (index === actionIndex) { + return { + ...act, + description: e.target.value, + }; + } + return act; + }), + }); + }} + borderVariant="thin" + > +
+
+
+ + + + + + + + + + + {Object.entries(action.parameters?.properties).map( + (param, index) => { + const uniqueKey = `${actionIndex}-${param[0]}`; + return ( + + + + + + + + ); + }, + )} + +
Field NameField TypeFilled by LLMFIeld descriptionValue
{param[0]}{param[1].type} + + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + description: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > + + { + setTool({ + ...tool, + actions: tool.actions.map( + (act, index) => { + if (index === actionIndex) { + return { + ...act, + parameters: { + ...act.parameters, + properties: { + ...act.parameters.properties, + [param[0]]: { + ...act.parameters + .properties[param[0]], + value: e.target.value, + }, + }, + }, + }; + } + return act; + }, + ), + }); + }} + > +
+
+
+ ); + })} +
+
+ + ); +} diff --git a/frontend/src/settings/Tools.tsx b/frontend/src/settings/Tools.tsx new file mode 100644 index 000000000..d29ba42c6 --- /dev/null +++ b/frontend/src/settings/Tools.tsx @@ -0,0 +1,157 @@ +import React from 'react'; + +import userService from '../api/services/userService'; +import CogwheelIcon from '../assets/cogwheel.svg'; +import Input from '../components/Input'; +import AddToolModal from '../modals/AddToolModal'; +import { ActiveState } from '../models/misc'; +import { UserTool } from './types'; +import ToolConfig from './ToolConfig'; + +export default function Tools() { + const [searchTerm, setSearchTerm] = React.useState(''); + const [addToolModalState, setAddToolModalState] = + React.useState('INACTIVE'); + const [userTools, setUserTools] = React.useState([]); + const [selectedTool, setSelectedTool] = React.useState(null); + + const getUserTools = () => { + userService + .getUserTools() + .then((res) => { + return res.json(); + }) + .then((data) => { + setUserTools(data.tools); + }); + }; + + const updateToolStatus = (toolId: string, newStatus: boolean) => { + userService + .updateToolStatus({ id: toolId, status: newStatus }) + .then(() => { + setUserTools((prevTools) => + prevTools.map((tool) => + tool.id === toolId ? { ...tool, status: newStatus } : tool, + ), + ); + }) + .catch((error) => { + console.error('Failed to update tool status:', error); + }); + }; + + const handleSettingsClick = (tool: UserTool) => { + setSelectedTool(tool); + }; + + const handleGoBack = () => { + setSelectedTool(null); + getUserTools(); + }; + + React.useEffect(() => { + getUserTools(); + }, []); + return ( +
+ {selectedTool ? ( + + ) : ( +
+
+
+
+ setSearchTerm(e.target.value)} + /> +
+ +
+
+ {userTools + .filter((tool) => + tool.displayName + .toLowerCase() + .includes(searchTerm.toLowerCase()), + ) + .map((tool, index) => ( +
+
+
+ + +
+
+

+ {tool.displayName} +

+

+ {tool.description} +

+
+
+
+ +
+
+ ))} +
+
+ +
+ )} +
+ ); +} diff --git a/frontend/src/settings/index.tsx b/frontend/src/settings/index.tsx index 15c7ce086..89f29a29c 100644 --- a/frontend/src/settings/index.tsx +++ b/frontend/src/settings/index.tsx @@ -7,8 +7,8 @@ import SettingsBar from '../components/SettingsBar'; import i18n from '../locale/i18n'; import { Doc } from '../models/misc'; import { - selectSourceDocs, selectPaginatedDocuments, + selectSourceDocs, setPaginatedDocuments, setSourceDocs, } from '../preferences/preferenceSlice'; @@ -17,6 +17,7 @@ import APIKeys from './APIKeys'; import Documents from './Documents'; import General from './General'; import Logs from './Logs'; +import Tools from './Tools'; import Widgets from './Widgets'; export default function Settings() { @@ -100,6 +101,8 @@ export default function Settings() { return ; case t('settings.logs.label'): return ; + case t('settings.tools.label'): + return ; default: return null; } diff --git a/frontend/src/settings/types/index.ts b/frontend/src/settings/types/index.ts index 52a58f236..322bdfeb3 100644 --- a/frontend/src/settings/types/index.ts +++ b/frontend/src/settings/types/index.ts @@ -18,3 +18,32 @@ export type LogData = { retriever_params: Record; timestamp: string; }; + +export type UserTool = { + id: string; + name: string; + displayName: string; + description: string; + status: boolean; + config: { + [key: string]: string; + }; + actions: { + name: string; + description: string; + parameters: { + properties: { + [key: string]: { + type: string; + description: string; + filled_by_llm: boolean; + value: string; + }; + }; + additionalProperties: boolean; + required: string[]; + type: string; + }; + active: boolean; + }[]; +}; diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py index 689013c0d..50ddbe294 100644 --- a/tests/llm/test_anthropic.py +++ b/tests/llm/test_anthropic.py @@ -46,6 +46,7 @@ def test_gen_stream(self): {"content": "question"} ] mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")] + mock_tools = Mock() with patch("application.cache.get_redis_instance") as mock_make_redis: mock_redis_instance = mock_make_redis.return_value @@ -53,7 +54,7 @@ def test_gen_stream(self): mock_redis_instance.set = Mock() with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: - responses = list(self.llm.gen_stream("test_model", messages)) + responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools)) self.assertListEqual(responses, ["response_1", "response_2"]) prompt_expected = "### Context \n context \n ### Question \n question" diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index d659d4983..2b893a9a4 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -76,7 +76,7 @@ def test_gen_stream(self): with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, self.messages)) + output = list(self.sagemaker.gen_stream(None, self.messages, tools=None)) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', diff --git a/tests/test_cache.py b/tests/test_cache.py index 4270a1819..af2b5e008 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -12,18 +12,21 @@ def test_make_gen_cache_key(): {'role': 'system', 'content': 'test_system_message'}, ] model = "test_docgpt" + tools = None # Manually calculate the expected hash - expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + expected_combined = f"{model}_{messages_str}_{tools_str}" expected_hash = get_hash(expected_combined) - cache_key = gen_cache_key(*messages, model=model) + cache_key = gen_cache_key(messages, model=model, tools=None) assert cache_key == expected_hash def test_gen_cache_key_invalid_message_format(): # Test when messages is not a list with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context: - gen_cache_key("This is not a list", model="docgpt") + gen_cache_key("This is not a list", model="docgpt", tools=None) assert str(context.exception) == "All messages must be dictionaries." # Test for gen_cache decorator @@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, stream, tools): return "new_result" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "cached_result" # Should return cached result @@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, steam, tools): return "new_result" messages = [ @@ -67,7 +70,7 @@ def mock_function(self, model, messages): ] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "new_result" @@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = cached_chunk @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["chunk1", "chunk2"] # Should return cached chunks @@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [ @@ -117,7 +120,7 @@ def mock_function(self, model, messages, stream): model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["new_chunk"]