Skip to content

Commit

Permalink
Merge branch 'dev/async-agency'
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Jan 27, 2024
2 parents 9468a4a + 68fc4b7 commit b2f5d4e
Show file tree
Hide file tree
Showing 8 changed files with 707 additions and 124 deletions.
2 changes: 1 addition & 1 deletion agency_swarm/agency/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .agency import Agency
from .agency import Agency
321 changes: 218 additions & 103 deletions agency_swarm/agency/agency.py

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def init_oai(self):
self._update_assistant()
self._update_settings()
return self

# create assistant if settings.json does not exist or assistant with the same name does not exist
self.assistant = self.client.beta.assistants.create(
name=self.name,
Expand Down Expand Up @@ -315,6 +316,74 @@ def _parse_schemas(self):
else:
raise Exception("Schemas folder path must be a string or list of strings.")

def get_openapi_schema(self, url):
"""Get openapi schema that contains all tools from the agent as different api paths. Make sure to call this after agency has been initialized."""
if self.assistant is None:
raise Exception("Assistant is not initialized. Please initialize the agency first, before using this method")

schema = {
"openapi": "3.1.0",
"info": {
"title": self.name,
"description": self.description if self.description else "",
"version": "v1.0.0"
},
"servers": [
{
"url": url,
}
],
"paths": {},
"components": {
"schemas": {},
"securitySchemes": {
"apiKey": {
"type": "apiKey"
}
}
},
}

for tool in self.tools:
if issubclass(tool, BaseTool):
openai_schema = tool.openai_schema
defs = {}
if '$defs' in openai_schema['parameters']:
defs = openai_schema['parameters']['$defs']
del openai_schema['parameters']['$defs']

schema['paths']["/" + openai_schema['name']] = {
"post": {
"description": openai_schema['description'],
"operationId": openai_schema['name'],
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": openai_schema['parameters']
}
},
"required": True,
},
"deprecated": False,
"security": [
{
"apiKey": []
}
],
"x-openai-isConsequential": False,
}
}

if defs:
schema['components']['schemas'].update(**defs)

print(openai_schema)

schema = json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/")

return schema

# --- Settings Methods ---

def _check_parameters(self, assistant_settings):
Expand Down
27 changes: 16 additions & 11 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@ class Thread:
def __init__(self, agent: Literal[Agent, User], recipient_agent: Agent):
self.agent = agent
self.recipient_agent = recipient_agent

self.client = get_openai_client()

def init_thread(self):
if self.id:
self.thread = self.client.beta.threads.retrieve(self.id)
else:
self.thread = self.client.beta.threads.create()
self.id = self.thread.id

def get_completion(self, message: str, message_files=None, yield_messages=True):
if not self.thread:
if self.id:
self.thread = self.client.beta.threads.retrieve(self.id)
else:
self.thread = self.client.beta.threads.create()
self.id = self.thread.id
# Determine the sender's name based on the agent type
sender_name = "user" if isinstance(self.agent, User) else self.agent.name
playground_url = f'https://platform.openai.com/playground?assistant={self.recipient_agent._assistant.id}&mode=assistant&thread={self.thread.id}'
print(f'THREAD:[ {sender_name} -> {self.recipient_agent.name} ]: URL {playground_url}')
self.init_thread()

# Determine the sender's name based on the agent type
sender_name = "user" if isinstance(self.agent, User) else self.agent.name
playground_url = f'https://platform.openai.com/playground?assistant={self.recipient_agent._assistant.id}&mode=assistant&thread={self.thread.id}'
print(f'THREAD:[ {sender_name} -> {self.recipient_agent.name} ]: URL {playground_url}')

# send message
self.client.beta.threads.messages.create(
Expand Down Expand Up @@ -65,7 +70,7 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):
yield MessageOutput("function", self.recipient_agent.name, self.agent.name,
str(tool_call.function))

output = self._execute_tool(tool_call)
output = self.execute_tool(tool_call)
if inspect.isgenerator(output):
try:
while True:
Expand Down Expand Up @@ -102,7 +107,7 @@ def get_completion(self, message: str, message_files=None, yield_messages=True):

return message

def _execute_tool(self, tool_call):
def execute_tool(self, tool_call):
funcs = self.recipient_agent.functions
func = next((func for func in funcs if func.__name__ == tool_call.function.name), None)

Expand Down
82 changes: 82 additions & 0 deletions agency_swarm/threads/thread_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import threading
from typing import Literal

from agency_swarm.agents import Agent
from agency_swarm.threads import Thread
from agency_swarm.user import User


class ThreadAsync(Thread):
def __init__(self, agent: Literal[Agent, User], recipient_agent: Agent):
super().__init__(agent, recipient_agent)
self.pythread = None
self.response = None

def worker(self, message: str, message_files=None):
gen = super().get_completion(message=message, message_files=message_files,
yield_messages=False) # yielding is not supported in async mode
while True:
try:
next(gen)
except StopIteration as e:
self.response = f"""{self.recipient_agent.name}'s Response: '{e.value}'"""
break

return

def get_completion_async(self, message: str, message_files=None):
if self.pythread and self.pythread.is_alive():
return "System Notification: 'Agent is busy, so your message was not received. Please always use 'GetResponse' tool to check for status first, before using 'SendMessage' tool again for the same agent.'"
elif self.pythread and not self.pythread.is_alive():
self.pythread.join()
self.pythread = None
self.response = None

run = self.get_last_run()

if run and run.status in ['queued', 'in_progress', 'requires_action']:
return "System Notification: 'Agent is busy, so your message was not received. Please always use 'GetResponse' tool to check for status first, before using 'SendMessage' tool again for the same agent.'"

self.pythread = threading.Thread(target=self.worker,
args=(message, message_files))

self.pythread.start()

return "System Notification: 'Task has started. Please notify the user that they can tell you to check the status later. You can do this with the 'GetResponse' tool, after you have been instructed to do so. Don't mention the tool itself to the user. "

def check_status(self, run=None):
if not run:
run = self.get_last_run()

if not run:
return "System Notification: 'Agent is ready to receive a message. Please send a message with the 'SendMessage' tool.'"

# check run status
if run.status in ['queued', 'in_progress', 'requires_action']:
return "System Notification: 'Task is not completed yet. Please tell the user to wait and try again later.'"

if run.status == "failed":
return f"System Notification: 'Agent run failed with error: {run.last_error.message}. You may send another message with the 'SendMessage' tool.'"

messages = self.client.beta.threads.messages.list(
thread_id=self.id,
order="desc",
)

return f"""{self.recipient_agent.name}'s Response: '{messages.data[0].content[0].text.value}'"""

def get_last_run(self):
if not self.thread:
self.init_thread()

runs = self.client.beta.threads.runs.list(
thread_id=self.thread.id,
order="desc",
)

if len(runs.data) == 0:
return None

run = runs.data[0]

return run
2 changes: 1 addition & 1 deletion agency_swarm/tools/BaseTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class BaseTool(OpenAISchema, ABC):
caller_agent: Optional[Any] = Field(
None, description="The agent that called this tool. Please ignore this field."
None, description="The agent that called this tool. This field will be removed from schema."
)

@classmethod
Expand Down
Loading

0 comments on commit b2f5d4e

Please sign in to comment.