From 219d46540337834dba162434f0b3be7fc22ab804 Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Fri, 2 Feb 2024 09:24:55 +0400 Subject: [PATCH] Added custom gpt mentions to agency --- agency_swarm/agency/agency.py | 116 ++++++++++++++++++++++++--------- agency_swarm/agents/agent.py | 3 + agency_swarm/threads/thread.py | 30 +++++---- tests/demos/demo_gradio.py | 22 +++++-- 4 files changed, 125 insertions(+), 46 deletions(-) diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index e4a8540f..b5126d87 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -62,6 +62,7 @@ def __init__(self, self.ceo = None self.agents = [] self.agents_and_threads = {} + self.main_recipients = [] self.shared_files = shared_files if shared_files else [] self.settings_path = settings_path self.settings_callbacks = settings_callbacks @@ -82,7 +83,7 @@ def __init__(self, self.user = User() self.main_thread = Thread(self.user, self.ceo) - def get_completion(self, message: str, message_files=None, yield_messages=True): + def get_completion(self, message: str, message_files=None, yield_messages=True, recipient_agent=None): """ Retrieves the completion for a given message from the main thread. @@ -90,12 +91,13 @@ def get_completion(self, message: str, message_files=None, yield_messages=True): message (str): The message for which completion is to be retrieved. message_files (list, optional): A list of file ids to be sent as attachments with the message. Defaults to None. yield_messages (bool, optional): Flag to determine if intermediate messages should be yielded. Defaults to True. + recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart. Returns: Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread. """ gen = self.main_thread.get_completion(message=message, message_files=message_files, - yield_messages=yield_messages) + yield_messages=yield_messages, recipient_agent=recipient_agent) if not yield_messages: while True: @@ -133,32 +135,44 @@ def demo_gradio(self, height=600, dark_mode=True): else: js = js.replace("{theme}", "light") - message_files = [] + message_file_ids = [] + message_file_names = None + recipient_agents = [agent.name for agent in self.main_recipients] + recipient_agent = self.main_recipients[0] with gr.Blocks(js=js) as demo: chatbot = gr.Chatbot(height=height) - msg = gr.Textbox() - file_upload = gr.Files(label="Upload File", type="filepath") + with gr.Row(): + with gr.Column(scale=9): + dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents, + value=recipient_agent.name) + msg = gr.Textbox(label="Your Message", lines=4) + with gr.Column(scale=1): + file_upload = gr.Files(label="Files", type="filepath") + button = gr.Button(value="Send", variant="primary") + + def handle_dropdown_change(selected_option): + nonlocal recipient_agent + recipient_agent = self.get_agent_by_name(selected_option) def handle_file_upload(file_list): - nonlocal message_files - message_files = [] + nonlocal message_file_ids + nonlocal message_file_names + message_file_ids = [] + message_file_names = [] if file_list: try: for file_obj in file_list: - # copy file to current directory - # path = "./" + os.path.basename(file_obj) - # shutil.copyfile(file_obj.name, path) - # print(f"Uploading file: {path}") with open(file_obj.name, 'rb') as f: # Upload the file to OpenAI file = self.main_thread.client.files.create( file=f, purpose="assistants" ) - message_files.append(file.id) + message_file_ids.append(file.id) + message_file_names.append(file.filename) print(f"Uploaded file ID: {file.id}") - return message_files + return message_file_ids except Exception as e: print(f"Error: {e}") return str(e) @@ -166,16 +180,32 @@ def handle_file_upload(file_list): return "No files uploaded" def user(user_message, history): + if history is None: + history = [] + + original_user_message = user_message + # Append the user message with a placeholder for bot response - user_message = "šŸ‘¤ User: " + user_message.strip() - return "", history + [[user_message, None]] + if recipient_agent: + user_message = f"šŸ‘¤ User šŸ—£ļø @{recipient_agent.name}:\n" + user_message.strip() + else: + user_message = f"šŸ‘¤ User:" + user_message.strip() - def bot(history): - nonlocal message_files - print("Message files: ", message_files) - # Replace this with your actual chatbot logic - gen = self.get_completion(message=history[-1][0], message_files=message_files) + nonlocal message_file_names + if message_file_names: + user_message += "\n\nšŸ“Ž Files:\n" + "\n".join(message_file_names) + + return original_user_message, history + [[user_message, None]] + def bot(original_message, history): + nonlocal message_file_ids + nonlocal message_file_names + nonlocal recipient_agent + print("Message files: ", message_file_ids) + # Replace this with your actual chatbot logic + gen = self.get_completion(message=original_message, message_files=message_file_ids, recipient_agent=recipient_agent) + message_file_ids = [] + message_file_names = [] try: # Yield each message from the generator for bot_message in gen: @@ -185,16 +215,23 @@ def bot(history): message = bot_message.get_sender_emoji() + " " + bot_message.get_formatted_content() history.append((None, message)) - yield history + yield "", history except StopIteration: # Handle the end of the conversation if necessary - message_files = [] + pass - # Chain the events + button.click( + user, + inputs=[msg, chatbot], + outputs=[msg, chatbot] + ).then( + bot, [msg, chatbot], [msg, chatbot] + ) + dropdown.change(handle_dropdown_change, dropdown) file_upload.change(handle_file_upload, file_upload) msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( - bot, chatbot, chatbot + bot, [msg, chatbot], [msg, chatbot] ) # Enable queuing for streaming intermediate outputs @@ -324,12 +361,20 @@ def _parse_agency_chart(self, agency_chart): If a node is a list, it iterates through the agents in the list, adding them to the agency and establishing communication threads between them. It raises an exception if the agency chart is invalid or if multiple CEOs are defined. """ + if not isinstance(agency_chart, list): + raise Exception("Invalid agency chart.") + + if len(agency_chart) == 0: + raise Exception("Agency chart cannot be empty.") + for node in agency_chart: if isinstance(node, Agent): - if self.ceo: - raise Exception("Only 1 ceo is supported for now.") - self.ceo = node - self._add_agent(self.ceo) + if not self.ceo: + self.ceo = node + self._add_agent(self.ceo) + else: + self._add_agent(node) + self._add_main_recipient(node) elif isinstance(node, list): for i, agent in enumerate(node): @@ -352,7 +397,6 @@ def _parse_agency_chart(self, agency_chart): "agent": agent.name, "recipient_agent": other_agent.name, } - else: raise Exception("Invalid agency chart.") @@ -379,6 +423,20 @@ def _add_agent(self, agent): else: return self.get_agent_ids().index(agent.id) + def _add_main_recipient(self, agent): + """ + Adds an agent to the agency's list of main recipients. + + Parameters: + agent (Agent): The agent to be added to the agency's list of main recipients. + + This method adds an agent to the agency's list of main recipients. These are agents that can be directly contacted by the user. + """ + main_recipient_ids = [agent.id for agent in self.main_recipients] + + if agent.id not in main_recipient_ids: + self.main_recipients.append(agent) + def _read_instructions(self, path): """ Reads shared instructions from a specified file and stores them in the agency. diff --git a/agency_swarm/agents/agent.py b/agency_swarm/agents/agent.py index 8bdf50f5..722be130 100644 --- a/agency_swarm/agents/agent.py +++ b/agency_swarm/agents/agent.py @@ -486,6 +486,9 @@ def get_class_folder_path(self): return os.path.abspath(os.path.realpath(os.path.dirname(class_file))) def add_shared_instructions(self, instructions: str): + if not instructions: + return + if self._shared_instructions is None: self._shared_instructions = instructions else: diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 924896c1..e7c10060 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -26,14 +26,17 @@ def init_thread(self): self.thread = self.client.beta.threads.create() self.id = self.thread.id - def get_completion(self, message: str, message_files=None, yield_messages=True): + def get_completion(self, message: str, message_files=None, yield_messages=True, recipient_agent=None): if not self.thread: self.init_thread() + if not recipient_agent: + recipient_agent = self.recipient_agent + # Determine the sender's name based on the agent type sender_name = "user" if isinstance(self.agent, User) else self.agent.name - playground_url = f'https://platform.openai.com/playground?assistant={self.recipient_agent._assistant.id}&mode=assistant&thread={self.thread.id}' - print(f'THREAD:[ {sender_name} -> {self.recipient_agent.name} ]: URL {playground_url}') + playground_url = f'https://platform.openai.com/playground?assistant={recipient_agent.assistant.id}&mode=assistant&thread={self.thread.id}' + print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {playground_url}') # send message self.client.beta.threads.messages.create( @@ -44,12 +47,12 @@ def get_completion(self, message: str, message_files=None, yield_messages=True): ) if yield_messages: - yield MessageOutput("text", self.agent.name, self.recipient_agent.name, message) + yield MessageOutput("text", self.agent.name, recipient_agent.name, message) # create run self.run = self.client.beta.threads.runs.create( thread_id=self.thread.id, - assistant_id=self.recipient_agent.id, + assistant_id=recipient_agent.id, ) while True: @@ -67,10 +70,10 @@ def get_completion(self, message: str, message_files=None, yield_messages=True): tool_outputs = [] for tool_call in tool_calls: if yield_messages: - yield MessageOutput("function", self.recipient_agent.name, self.agent.name, + yield MessageOutput("function", recipient_agent.name, self.agent.name, str(tool_call.function)) - output = self.execute_tool(tool_call) + output = self.execute_tool(tool_call, recipient_agent) if inspect.isgenerator(output): try: while True: @@ -81,7 +84,7 @@ def get_completion(self, message: str, message_files=None, yield_messages=True): output = e.value else: if yield_messages: - yield MessageOutput("function_output", tool_call.function.name, self.recipient_agent.name, + yield MessageOutput("function_output", tool_call.function.name, recipient_agent.name, output) tool_outputs.append({"tool_call_id": tool_call.id, "output": str(output)}) @@ -103,12 +106,15 @@ def get_completion(self, message: str, message_files=None, yield_messages=True): message = messages.data[0].content[0].text.value if yield_messages: - yield MessageOutput("text", self.recipient_agent.name, self.agent.name, message) + yield MessageOutput("text", recipient_agent.name, self.agent.name, message) return message - def execute_tool(self, tool_call): - funcs = self.recipient_agent.functions + def execute_tool(self, tool_call, recipient_agent=None): + if not recipient_agent: + recipient_agent = self.recipient_agent + + funcs = recipient_agent.functions func = next((func for func in funcs if func.__name__ == tool_call.function.name), None) if not func: @@ -117,7 +123,7 @@ def execute_tool(self, tool_call): try: # init tool func = func(**eval(tool_call.function.arguments)) - func.caller_agent = self.recipient_agent + func.caller_agent = recipient_agent # get outputs from the tool output = func.run() diff --git a/tests/demos/demo_gradio.py b/tests/demos/demo_gradio.py index b506483c..3600f377 100644 --- a/tests/demos/demo_gradio.py +++ b/tests/demos/demo_gradio.py @@ -2,24 +2,36 @@ import gradio as gr -from agency_swarm import set_openai_key, Agent +sys.path.insert(0, './agency-swarm') -sys.path.insert(0, '../agency-swarm') +from agency_swarm import set_openai_key, Agent from agency_swarm.agency.agency import Agency from agency_swarm.tools.oai import Retrieval ceo = Agent(name="CEO", description="Responsible for client communication, task planning and management.", - instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md + instructions="Analyze uploaded files with myfiles_browser tool.", # can be a file like ./instructions.md tools=[Retrieval]) +test_agent = Agent(name="Test Agent1", + description="Responsible for testing.", + instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md + tools=[Retrieval]) + +test_agent2 = Agent(name="Test Agent2", + description="Responsible for testing.", + instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md + tools=[Retrieval]) + + agency = Agency([ - ceo, + ceo, test_agent, test_agent2 ], shared_instructions="") +# agency.demo_gradio() -agency.demo_gradio(height=900) +print(agency.get_completion("Hello", recipient_agent=test_agent, yield_messages=False))