From 53a0e62e1725b406753eeb005e7a21e1611acd93 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Fri, 23 Aug 2024 13:18:42 +0400 Subject: [PATCH] Improved file classification for tool resources --- agency_swarm/agency/agency.py | 4 ++-- agency_swarm/util/files.py | 41 +++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 58d701a7..6738e9b5 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -27,7 +27,7 @@ from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch from agency_swarm.user import User from agency_swarm.util.errors import RefusalError -from agency_swarm.util.files import determine_file_type +from agency_swarm.util.files import determine_file_type, get_tools from agency_swarm.util.shared_state import SharedState from agency_swarm.util.streaming import AgencyEventHandler @@ -349,7 +349,7 @@ def handle_file_upload(file_list): else: attachments.append({ "file_id": file.id, - "tools": tools + "tools": get_tools(file.filename) }) message_file_names.append(file.filename) diff --git a/agency_swarm/util/files.py b/agency_swarm/util/files.py index 07476a34..3da1b5ac 100644 --- a/agency_swarm/util/files.py +++ b/agency_swarm/util/files.py @@ -1,19 +1,38 @@ import mimetypes +code_interpreter_types = [ + "application/csv", "image/jpeg", "image/gif", "image/png", + "application/x-tar", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/xml", "text/xml", "application/zip" +] + +dual_types = [ + "text/x-c", "text/x-csharp", "text/x-c++", "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "text/html", "text/x-java", "application/json", "text/markdown", + "application/pdf", "text/x-php", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "text/x-python", "text/x-script.python", "text/x-ruby", "text/x-tex", + "text/plain", "text/css", "text/javascript", "application/x-sh", + "application/typescript" +] + def determine_file_type(file_path): mime_type, _ = mimetypes.guess_type(file_path) if mime_type: - if mime_type in [ - 'application/json', 'text/csv', 'application/xml', - 'application/vnd.ms-excel', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - 'application/zip' - ]: + if mime_type in code_interpreter_types: return "assistants.code_interpreter" - elif mime_type in [ - 'text/plain', 'text/markdown', 'application/pdf', - 'application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' - ]: - return "assistants.file_search" elif mime_type.startswith('image/'): return "vision" - return "assistants.file_search" \ No newline at end of file + elif mime_type in dual_types: + return "assistants.file_search" + raise ValueError(f"Unsupported file type: {mime_type}") + +def get_tools(file_path): + mime_type, _ = mimetypes.guess_type(file_path) + if mime_type in code_interpreter_types: + return [{"type": "code_interpreter"}] + elif mime_type in dual_types: + return [{"type": "code_interpreter"}, {"type": "retrieval"}] + else: + raise ValueError(f"Unsupported file type: {mime_type}") \ No newline at end of file