Skip to content

Commit

Permalink
Added custom gpt mentions to agency
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Feb 2, 2024
1 parent 5ba60c1 commit 219d465
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 46 deletions.
116 changes: 87 additions & 29 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self,
self.ceo = None
self.agents = []
self.agents_and_threads = {}
self.main_recipients = []
self.shared_files = shared_files if shared_files else []
self.settings_path = settings_path
self.settings_callbacks = settings_callbacks
Expand All @@ -82,20 +83,21 @@ def __init__(self,
self.user = User()
self.main_thread = Thread(self.user, self.ceo)

def get_completion(self, message: str, message_files=None, yield_messages=True):
def get_completion(self, message: str, message_files=None, yield_messages=True, recipient_agent=None):
"""
Retrieves the completion for a given message from the main thread.
Parameters:
message (str): The message for which completion is to be retrieved.
message_files (list, optional): A list of file ids to be sent as attachments with the message. Defaults to None.
yield_messages (bool, optional): Flag to determine if intermediate messages should be yielded. Defaults to True.
recipient_agent (Agent, optional): The agent to which the message should be sent. Defaults to the first agent in the agency chart.
Returns:
Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread.
"""
gen = self.main_thread.get_completion(message=message, message_files=message_files,
yield_messages=yield_messages)
yield_messages=yield_messages, recipient_agent=recipient_agent)

if not yield_messages:
while True:
Expand Down Expand Up @@ -133,49 +135,77 @@ def demo_gradio(self, height=600, dark_mode=True):
else:
js = js.replace("{theme}", "light")

message_files = []
message_file_ids = []
message_file_names = None
recipient_agents = [agent.name for agent in self.main_recipients]
recipient_agent = self.main_recipients[0]

with gr.Blocks(js=js) as demo:
chatbot = gr.Chatbot(height=height)
msg = gr.Textbox()
file_upload = gr.Files(label="Upload File", type="filepath")
with gr.Row():
with gr.Column(scale=9):
dropdown = gr.Dropdown(label="Recipient Agent", choices=recipient_agents,
value=recipient_agent.name)
msg = gr.Textbox(label="Your Message", lines=4)
with gr.Column(scale=1):
file_upload = gr.Files(label="Files", type="filepath")
button = gr.Button(value="Send", variant="primary")

def handle_dropdown_change(selected_option):
nonlocal recipient_agent
recipient_agent = self.get_agent_by_name(selected_option)

def handle_file_upload(file_list):
nonlocal message_files
message_files = []
nonlocal message_file_ids
nonlocal message_file_names
message_file_ids = []
message_file_names = []
if file_list:
try:
for file_obj in file_list:
# copy file to current directory
# path = "./" + os.path.basename(file_obj)
# shutil.copyfile(file_obj.name, path)
# print(f"Uploading file: {path}")
with open(file_obj.name, 'rb') as f:
# Upload the file to OpenAI
file = self.main_thread.client.files.create(
file=f,
purpose="assistants"
)
message_files.append(file.id)
message_file_ids.append(file.id)
message_file_names.append(file.filename)
print(f"Uploaded file ID: {file.id}")
return message_files
return message_file_ids
except Exception as e:
print(f"Error: {e}")
return str(e)

return "No files uploaded"

def user(user_message, history):
if history is None:
history = []

original_user_message = user_message

# Append the user message with a placeholder for bot response
user_message = "👤 User: " + user_message.strip()
return "", history + [[user_message, None]]
if recipient_agent:
user_message = f"👤 User 🗣️ @{recipient_agent.name}:\n" + user_message.strip()
else:
user_message = f"👤 User:" + user_message.strip()

def bot(history):
nonlocal message_files
print("Message files: ", message_files)
# Replace this with your actual chatbot logic
gen = self.get_completion(message=history[-1][0], message_files=message_files)
nonlocal message_file_names
if message_file_names:
user_message += "\n\n📎 Files:\n" + "\n".join(message_file_names)

return original_user_message, history + [[user_message, None]]

def bot(original_message, history):
nonlocal message_file_ids
nonlocal message_file_names
nonlocal recipient_agent
print("Message files: ", message_file_ids)
# Replace this with your actual chatbot logic
gen = self.get_completion(message=original_message, message_files=message_file_ids, recipient_agent=recipient_agent)
message_file_ids = []
message_file_names = []
try:
# Yield each message from the generator
for bot_message in gen:
Expand All @@ -185,16 +215,23 @@ def bot(history):
message = bot_message.get_sender_emoji() + " " + bot_message.get_formatted_content()

history.append((None, message))
yield history
yield "", history
except StopIteration:
# Handle the end of the conversation if necessary
message_files = []

pass

# Chain the events
button.click(
user,
inputs=[msg, chatbot],
outputs=[msg, chatbot]
).then(
bot, [msg, chatbot], [msg, chatbot]
)
dropdown.change(handle_dropdown_change, dropdown)
file_upload.change(handle_file_upload, file_upload)
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
bot, [msg, chatbot], [msg, chatbot]
)

# Enable queuing for streaming intermediate outputs
Expand Down Expand Up @@ -324,12 +361,20 @@ def _parse_agency_chart(self, agency_chart):
If a node is a list, it iterates through the agents in the list, adding them to the agency and establishing communication
threads between them. It raises an exception if the agency chart is invalid or if multiple CEOs are defined.
"""
if not isinstance(agency_chart, list):
raise Exception("Invalid agency chart.")

if len(agency_chart) == 0:
raise Exception("Agency chart cannot be empty.")

for node in agency_chart:
if isinstance(node, Agent):
if self.ceo:
raise Exception("Only 1 ceo is supported for now.")
self.ceo = node
self._add_agent(self.ceo)
if not self.ceo:
self.ceo = node
self._add_agent(self.ceo)
else:
self._add_agent(node)
self._add_main_recipient(node)

elif isinstance(node, list):
for i, agent in enumerate(node):
Expand All @@ -352,7 +397,6 @@ def _parse_agency_chart(self, agency_chart):
"agent": agent.name,
"recipient_agent": other_agent.name,
}

else:
raise Exception("Invalid agency chart.")

Expand All @@ -379,6 +423,20 @@ def _add_agent(self, agent):
else:
return self.get_agent_ids().index(agent.id)

def _add_main_recipient(self, agent):
"""
Adds an agent to the agency's list of main recipients.
Parameters:
agent (Agent): The agent to be added to the agency's list of main recipients.
This method adds an agent to the agency's list of main recipients. These are agents that can be directly contacted by the user.
"""
main_recipient_ids = [agent.id for agent in self.main_recipients]

if agent.id not in main_recipient_ids:
self.main_recipients.append(agent)

def _read_instructions(self, path):
"""
Reads shared instructions from a specified file and stores them in the agency.
Expand Down
3 changes: 3 additions & 0 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ def get_class_folder_path(self):
return os.path.abspath(os.path.realpath(os.path.dirname(class_file)))

def add_shared_instructions(self, instructions: str):
if not instructions:
return

if self._shared_instructions is None:
self._shared_instructions = instructions
else:
Expand Down
30 changes: 18 additions & 12 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ def init_thread(self):
self.thread = self.client.beta.threads.create()
self.id = self.thread.id

def get_completion(self, message: str, message_files=None, yield_messages=True):
def get_completion(self, message: str, message_files=None, yield_messages=True, recipient_agent=None):
if not self.thread:
self.init_thread()

if not recipient_agent:
recipient_agent = self.recipient_agent

# Determine the sender's name based on the agent type
sender_name = "user" if isinstance(self.agent, User) else self.agent.name
playground_url = f'https://platform.openai.com/playground?assistant={self.recipient_agent._assistant.id}&mode=assistant&thread={self.thread.id}'
print(f'THREAD:[ {sender_name} -> {self.recipient_agent.name} ]: URL {playground_url}')
playground_url = f'https://platform.openai.com/playground?assistant={recipient_agent.assistant.id}&mode=assistant&thread={self.thread.id}'
print(f'THREAD:[ {sender_name} -> {recipient_agent.name} ]: URL {playground_url}')

# send message
self.client.beta.threads.messages.create(
Expand All @@ -44,12 +47,12 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):
)

if yield_messages:
yield MessageOutput("text", self.agent.name, self.recipient_agent.name, message)
yield MessageOutput("text", self.agent.name, recipient_agent.name, message)

# create run
self.run = self.client.beta.threads.runs.create(
thread_id=self.thread.id,
assistant_id=self.recipient_agent.id,
assistant_id=recipient_agent.id,
)

while True:
Expand All @@ -67,10 +70,10 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):
tool_outputs = []
for tool_call in tool_calls:
if yield_messages:
yield MessageOutput("function", self.recipient_agent.name, self.agent.name,
yield MessageOutput("function", recipient_agent.name, self.agent.name,
str(tool_call.function))

output = self.execute_tool(tool_call)
output = self.execute_tool(tool_call, recipient_agent)
if inspect.isgenerator(output):
try:
while True:
Expand All @@ -81,7 +84,7 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):
output = e.value
else:
if yield_messages:
yield MessageOutput("function_output", tool_call.function.name, self.recipient_agent.name,
yield MessageOutput("function_output", tool_call.function.name, recipient_agent.name,
output)

tool_outputs.append({"tool_call_id": tool_call.id, "output": str(output)})
Expand All @@ -103,12 +106,15 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):
message = messages.data[0].content[0].text.value

if yield_messages:
yield MessageOutput("text", self.recipient_agent.name, self.agent.name, message)
yield MessageOutput("text", recipient_agent.name, self.agent.name, message)

return message

def execute_tool(self, tool_call):
funcs = self.recipient_agent.functions
def execute_tool(self, tool_call, recipient_agent=None):
if not recipient_agent:
recipient_agent = self.recipient_agent

funcs = recipient_agent.functions
func = next((func for func in funcs if func.__name__ == tool_call.function.name), None)

if not func:
Expand All @@ -117,7 +123,7 @@ def execute_tool(self, tool_call):
try:
# init tool
func = func(**eval(tool_call.function.arguments))
func.caller_agent = self.recipient_agent
func.caller_agent = recipient_agent
# get outputs from the tool
output = func.run()

Expand Down
22 changes: 17 additions & 5 deletions tests/demos/demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,36 @@

import gradio as gr

from agency_swarm import set_openai_key, Agent
sys.path.insert(0, './agency-swarm')

sys.path.insert(0, '../agency-swarm')
from agency_swarm import set_openai_key, Agent

from agency_swarm.agency.agency import Agency
from agency_swarm.tools.oai import Retrieval

ceo = Agent(name="CEO",
description="Responsible for client communication, task planning and management.",
instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md
instructions="Analyze uploaded files with myfiles_browser tool.", # can be a file like ./instructions.md
tools=[Retrieval])


test_agent = Agent(name="Test Agent1",
description="Responsible for testing.",
instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md
tools=[Retrieval])

test_agent2 = Agent(name="Test Agent2",
description="Responsible for testing.",
instructions="Read files with myfiles_browser tool.", # can be a file like ./instructions.md
tools=[Retrieval])



agency = Agency([
ceo,
ceo, test_agent, test_agent2
], shared_instructions="")

# agency.demo_gradio()

agency.demo_gradio(height=900)
print(agency.get_completion("Hello", recipient_agent=test_agent, yield_messages=False))

0 comments on commit 219d465

Please sign in to comment.