Skip to content

Commit

Permalink
Added printing openapi schemas for agents
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Jan 23, 2024
1 parent ed4535b commit 71272df
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 131 deletions.
267 changes: 139 additions & 128 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,79 +73,6 @@ def __init__(self, agency_chart: List, shared_instructions: str = "", shared_fil
self.user = User()
self.main_thread = Thread(self.user, self.ceo)

def _init_agents(self):
"""
Initializes all agents in the agency with unique IDs, shared instructions, and OpenAI models.
This method iterates through each agent in the agency, assigns a unique ID, adds shared instructions, and initializes the OpenAI models for each agent.
There are no input parameters.
There are no output parameters as this method is used for internal initialization purposes within the Agency class.
"""
if self.settings_callbacks:
loaded_settings = self.settings_callbacks["load"]()
with open(self.agents[0].get_settings_path(), 'w') as f:
json.dump(loaded_settings, f, indent=4)

for agent in self.agents:
if "temp_id" in agent.id:
agent.id = None
agent.add_shared_instructions(self.shared_instructions)

if self.shared_files:
if isinstance(agent.files_folder, str):
agent.files_folder = [agent.files_folder]
agent.files_folder += self.shared_files
elif isinstance(agent.files_folder, list):
agent.files_folder += self.shared_files

agent.init_oai()

if self.settings_callbacks:
with open(self.agents[0].get_settings_path(), 'r') as f:
settings = f.read()
settings = json.loads(settings)
self.settings_callbacks["save"](settings)

def _init_threads(self):
"""
Initializes threads for communication between agents within the agency.
This method creates Thread objects for each pair of interacting agents as defined in the agents_and_threads attribute of the Agency. Each thread facilitates communication and task execution between an agent and its designated recipient agent.
No input parameters.
Output Parameters:
This method does not return any value but updates the agents_and_threads attribute with initialized Thread objects.
"""
# load thread ids
loaded_thread_ids = {}
if self.threads_callbacks:
loaded_thread_ids = self.threads_callbacks["load"]()

for agent_name, threads in self.agents_and_threads.items():
for other_agent, items in threads.items():
self.agents_and_threads[agent_name][other_agent] = self.ThreadType(
self.get_agent_by_name(items["agent"]),
self.get_agent_by_name(
items["recipient_agent"]))

if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]:
self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent]
elif self.threads_callbacks:
self.agents_and_threads[agent_name][other_agent].init_thread()

# save thread ids
if self.threads_callbacks:
loaded_thread_ids = {}
for agent_name, threads in self.agents_and_threads.items():
loaded_thread_ids[agent_name] = {}
for other_agent, thread in threads.items():
loaded_thread_ids[agent_name][other_agent] = thread.id

self.threads_callbacks["save"](loaded_thread_ids)

def get_completion(self, message: str, message_files=None, yield_messages=True):
"""
Retrieves the completion for a given message from the main thread.
Expand Down Expand Up @@ -244,6 +171,92 @@ def run_demo(self):
except StopIteration as e:
pass

def get_openapi_schema(self, url: str):
"""Returns the OpenAPI schema for the agency from the CEO agent, that you can use to integrate with custom gpts.
Parameters:
url (str): Your server url where the api will be hosted.
"""

return self.ceo.get_openapi_schema(url)


def plot_agency_chart(self):
pass

def _init_agents(self):
"""
Initializes all agents in the agency with unique IDs, shared instructions, and OpenAI models.
This method iterates through each agent in the agency, assigns a unique ID, adds shared instructions, and initializes the OpenAI models for each agent.
There are no input parameters.
There are no output parameters as this method is used for internal initialization purposes within the Agency class.
"""
if self.settings_callbacks:
loaded_settings = self.settings_callbacks["load"]()
with open(self.agents[0].get_settings_path(), 'w') as f:
json.dump(loaded_settings, f, indent=4)

for agent in self.agents:
if "temp_id" in agent.id:
agent.id = None
agent.add_shared_instructions(self.shared_instructions)

if self.shared_files:
if isinstance(agent.files_folder, str):
agent.files_folder = [agent.files_folder]
agent.files_folder += self.shared_files
elif isinstance(agent.files_folder, list):
agent.files_folder += self.shared_files

agent.init_oai()

if self.settings_callbacks:
with open(self.agents[0].get_settings_path(), 'r') as f:
settings = f.read()
settings = json.loads(settings)
self.settings_callbacks["save"](settings)

def _init_threads(self):
"""
Initializes threads for communication between agents within the agency.
This method creates Thread objects for each pair of interacting agents as defined in the agents_and_threads attribute of the Agency. Each thread facilitates communication and task execution between an agent and its designated recipient agent.
No input parameters.
Output Parameters:
This method does not return any value but updates the agents_and_threads attribute with initialized Thread objects.
"""
# load thread ids
loaded_thread_ids = {}
if self.threads_callbacks:
loaded_thread_ids = self.threads_callbacks["load"]()

for agent_name, threads in self.agents_and_threads.items():
for other_agent, items in threads.items():
self.agents_and_threads[agent_name][other_agent] = self.ThreadType(
self.get_agent_by_name(items["agent"]),
self.get_agent_by_name(
items["recipient_agent"]))

if agent_name in loaded_thread_ids and other_agent in loaded_thread_ids[agent_name]:
self.agents_and_threads[agent_name][other_agent].id = loaded_thread_ids[agent_name][other_agent]
elif self.threads_callbacks:
self.agents_and_threads[agent_name][other_agent].init_thread()

# save thread ids
if self.threads_callbacks:
loaded_thread_ids = {}
for agent_name, threads in self.agents_and_threads.items():
loaded_thread_ids[agent_name] = {}
for other_agent, thread in threads.items():
loaded_thread_ids[agent_name][other_agent] = thread.id

self.threads_callbacks["save"](loaded_thread_ids)

def _parse_agency_chart(self, agency_chart):
"""
Parses the provided agency chart to initialize and organize agents within the agency.
Expand Down Expand Up @@ -311,57 +324,6 @@ def _add_agent(self, agent):
else:
return self.get_agent_ids().index(agent.id)

def get_agent_by_name(self, agent_name):
"""
Retrieves an agent from the agency based on the agent's name.
Parameters:
agent_name (str): The name of the agent to be retrieved.
Returns:
Agent: The agent object with the specified name.
Raises:
Exception: If no agent with the given name is found in the agency.
"""
for agent in self.agents:
if agent.name == agent_name:
return agent
raise Exception(f"Agent {agent_name} not found.")

def get_agents_by_names(self, agent_names):
"""
Retrieves a list of agent objects based on their names.
Parameters:
agent_names: A list of strings representing the names of the agents to be retrieved.
Returns:
A list of Agent objects corresponding to the given names.
"""
return [self.get_agent_by_name(agent_name) for agent_name in agent_names]

def get_agent_ids(self):
"""
Retrieves the IDs of all agents currently in the agency.
Returns:
List[str]: A list containing the unique IDs of all agents.
"""
return [agent.id for agent in self.agents]

def get_agent_names(self):
"""
Retrieves the names of all agents in the agency.
Parameters:
None
Returns:
List[str]: A list of names of all agents currently part of the agency.
"""
return [agent.name for agent in self.agents]

def _read_instructions(self, path):
"""
Reads shared instructions from a specified file and stores them in the agency.
Expand All @@ -375,9 +337,6 @@ def _read_instructions(self, path):
with open(path, 'r') as f:
self.shared_instructions = f.read()

def plot_agency_chart(self):
pass

def _create_special_tools(self):
"""
Creates and assigns 'SendMessage' tools to each agent based on the agency's structure.
Expand Down Expand Up @@ -483,7 +442,8 @@ def _create_get_response_tool(self, agent: Agent, recipient_agents: List[Agent])

class GetResponse(BaseTool):
"""This tool allows you to check the status of a task or get a response from a specified recipient agent, if the task has been completed. You must always use 'SendMessage' tool first."""
recipient: recipients = Field(..., description=f"Recipient agent that you want to check the status of. Valid recipients are: {recipient_names}")
recipient: recipients = Field(...,
description=f"Recipient agent that you want to check the status of. Valid recipients are: {recipient_names}")
caller_agent_name: str = Field(default=agent.name,
description="The agent calling this tool. Defaults to your name. Do not change it.")

Expand All @@ -508,6 +468,57 @@ def run(self):

return GetResponse

def get_agent_by_name(self, agent_name):
"""
Retrieves an agent from the agency based on the agent's name.
Parameters:
agent_name (str): The name of the agent to be retrieved.
Returns:
Agent: The agent object with the specified name.
Raises:
Exception: If no agent with the given name is found in the agency.
"""
for agent in self.agents:
if agent.name == agent_name:
return agent
raise Exception(f"Agent {agent_name} not found.")

def get_agents_by_names(self, agent_names):
"""
Retrieves a list of agent objects based on their names.
Parameters:
agent_names: A list of strings representing the names of the agents to be retrieved.
Returns:
A list of Agent objects corresponding to the given names.
"""
return [self.get_agent_by_name(agent_name) for agent_name in agent_names]

def get_agent_ids(self):
"""
Retrieves the IDs of all agents currently in the agency.
Returns:
List[str]: A list containing the unique IDs of all agents.
"""
return [agent.id for agent in self.agents]

def get_agent_names(self):
"""
Retrieves the names of all agents in the agency.
Parameters:
None
Returns:
List[str]: A list of names of all agents currently part of the agency.
"""
return [agent.name for agent in self.agents]

def get_recipient_names(self):
"""
Retrieves the names of all agents in the agency.
Expand Down
66 changes: 66 additions & 0 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,72 @@ def _parse_schemas(self):
else:
raise Exception("Schemas folder path must be a string or list of strings.")

def get_openapi_schema(self, url):
"""Get openapi schema that contains all tools from the agent as different api paths. Make sure to call this after agency has been initialized."""
if self.assistant is None:
raise Exception("Assistant is not initialized. Please initialize the agency first, before using this method")

schema = {
"openapi": "3.1.0",
"info": {
"title": self.name,
"description": self.description if self.description else "",
"version": "v1.0.0"
},
"servers": [
{
"url": url,
}
],
"paths": {},
"components": {
"schemas": {},
"securitySchemes": {
"apiKey": {
"type": "apiKey"
}
}
},
}

for tool in self.tools:
if issubclass(tool, BaseTool):
openai_schema = tool.openai_schema
defs = {}
if '$defs' in openai_schema['parameters']:
defs = openai_schema['parameters']['$defs']
del openai_schema['parameters']['$defs']

schema['paths']["/" + tool.__name__] = {
"post": {
"description": openai_schema['description'],
"operationId": tool.__name__,
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/" + tool.__name__
}
}
},
"required": True,
},
"deprecated": False,
"security": [
{
"apiKey": []
}
]
}
}

if defs:
schema['components']['schemas'][tool.__name__] = openai_schema['parameters'],
schema['components']['schemas'].update(defs)

return json.dumps(schema, indent=2).replace("#/$defs/", "#/components/schemas/")

# --- Settings Methods ---

def _check_parameters(self, assistant_settings):
Expand Down
5 changes: 2 additions & 3 deletions agency_swarm/threads/thread_async.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from agency_swarm.threads import Thread
import threading
from typing import Literal

from agency_swarm.agents import Agent
from agency_swarm.messages import MessageOutput
from agency_swarm.threads import Thread
from agency_swarm.user import User
from agency_swarm.util.oai import get_openai_client


class ThreadAsync(Thread):
Expand Down

0 comments on commit 71272df

Please sign in to comment.