Skip to content

Commit

Permalink
Fix assistants v1 initialization #128
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed May 29, 2024
1 parent 75af243 commit b1b8917
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
53 changes: 50 additions & 3 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from agency_swarm.user import User
from agency_swarm.util.files import determine_file_type
from agency_swarm.util.shared_state import SharedState
from openai.types.beta.threads.runs.tool_call import ToolCall
from openai.types.beta.threads.runs.tool_call import ToolCall, FunctionToolCall, CodeInterpreterToolCall, FileSearchToolCall


from agency_swarm.util.streaming import AgencyEventHandler

Expand Down Expand Up @@ -358,7 +359,17 @@ def on_text_delta(self, delta, snapshot):
@override
def on_tool_call_created(self, tool_call: ToolCall):
if isinstance(tool_call, dict):
tool_call = ToolCall(**tool_call)
if "type" not in tool_call:
tool_call["type"] = "function"

if tool_call["type"] == "function":
tool_call = FunctionToolCall(**tool_call)
elif tool_call["type"] == "code_interpreter":
tool_call = CodeInterpreterToolCall(**tool_call)
elif tool_call["type"] == "file_search" or tool_call["type"] == "retrieval":
tool_call = FileSearchToolCall(**tool_call)
else:
raise ValueError("Invalid tool call type: " + tool_call["type"])

# TODO: add support for code interpreter and retirieval tools
if tool_call.type == "function":
Expand All @@ -370,7 +381,17 @@ def on_tool_call_created(self, tool_call: ToolCall):
@override
def on_tool_call_done(self, snapshot: ToolCall):
if isinstance(snapshot, dict):
snapshot = ToolCall(**snapshot)
if "type" not in snapshot:
snapshot["type"] = "function"

if snapshot["type"] == "function":
snapshot = FunctionToolCall(**snapshot)
elif snapshot["type"] == "code_interpreter":
snapshot = CodeInterpreterToolCall(**snapshot)
elif snapshot["type"] == "file_search":
snapshot = FileSearchToolCall(**snapshot)
else:
raise ValueError("Invalid tool call type: " + snapshot["type"])

self.message_output = None

Expand Down Expand Up @@ -561,6 +582,19 @@ def on_text_delta(self, delta, snapshot):

@override
def on_tool_call_created(self, tool_call):
if isinstance(tool_call, dict):
if "type" not in tool_call:
tool_call["type"] = "function"

if tool_call["type"] == "function":
tool_call = FunctionToolCall(**tool_call)
elif tool_call["type"] == "code_interpreter":
tool_call = CodeInterpreterToolCall(**tool_call)
elif tool_call["type"] == "file_search" or tool_call["type"] == "retrieval":
tool_call = FileSearchToolCall(**tool_call)
else:
raise ValueError("Invalid tool call type: " + tool_call["type"])

# TODO: add support for code interpreter and retirieval tools

if tool_call.type == "function":
Expand All @@ -569,6 +603,19 @@ def on_tool_call_created(self, tool_call):

@override
def on_tool_call_delta(self, delta, snapshot):
if isinstance(snapshot, dict):
if "type" not in snapshot:
snapshot["type"] = "function"

if snapshot["type"] == "function":
snapshot = FunctionToolCall(**snapshot)
elif snapshot["type"] == "code_interpreter":
snapshot = CodeInterpreterToolCall(**snapshot)
elif snapshot["type"] == "file_search":
snapshot = FileSearchToolCall(**snapshot)
else:
raise ValueError("Invalid tool call type: " + snapshot["type"])

self.message_output.cprint_update(str(snapshot.function))

@override
Expand Down
14 changes: 3 additions & 11 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,9 @@ def init_oai(self):
self.tool_resources = self.tool_resources or self.assistant.tool_resources.model_dump()

for tool in self.assistant.tools:
if tool.type == "function":
# function tools must be added manually
continue
elif tool.type == "file_search":
self.add_tool(FileSearch)
elif tool.type == "code_interpreter":
self.add_tool(CodeInterpreter)
elif tool.type == "retrieval":
self.add_tool(Retrieval)
else:
raise Exception("Invalid tool type.")
# update assistants created with v1
if tool.type == "retrieval":
self.client.beta.assistants.update(self.id, tools=self.get_oai_tools())

# # update assistant if parameters are different
# if not self._check_parameters(self.assistant.model_dump()):
Expand Down

0 comments on commit b1b8917

Please sign in to comment.