From f59b3ab6f6f3ce371b30d6d3d18f15a1d53c131c Mon Sep 17 00:00:00 2001 From: mskarlin <12701035+mskarlin@users.noreply.github.com> Date: Sun, 17 Nov 2024 15:14:57 -0800 Subject: [PATCH] Default to ordered tool calls (#697) --- paperqa/agents/env.py | 9 +++- tests/test_agents.py | 112 +++++++++++++++++++++++++++++------------- 2 files changed, 84 insertions(+), 37 deletions(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 7791bd291..dc5236a7d 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -165,8 +165,13 @@ async def step( # If the action has empty tool_calls, the agent can later take that into account msgs = cast( list[Message], - await self.exec_tool_calls(action, state=self.state, handle_tool_exc=True), - ) + await self.exec_tool_calls( + action, + ordered=True, # PQA Environment currently not safe for parallel tool calls + state=self.state, + handle_tool_exc=True, + ), + ) or [Message(content=f"No tool calls input in tool request {action}.")] return ( msgs, 0, # Reward is computed in post-processing, use 0 as a placeholder diff --git a/tests/test_agents.py b/tests/test_agents.py index ed693f882..4b9f3f0bc 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import importlib import itertools import json @@ -26,7 +27,7 @@ from tantivy import Index from paperqa.agents import SearchIndex, agent_query -from paperqa.agents.env import settings_to_tools +from paperqa.agents.env import PaperQAEnvironment, settings_to_tools from paperqa.agents.main import FAKE_AGENT_TYPE from paperqa.agents.models import AgentStatus, AnswerResponse, QueryRequest from paperqa.agents.search import ( @@ -744,46 +745,87 @@ def test_agent_prompt_collection_validations( AgentSettings(**kwargs) -@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"]) @pytest.mark.asyncio -async def test_deepcopy_env(agent_test_settings: Settings) -> None: - await get_directory_index(settings=agent_test_settings) # Trigger build +async def test_sequential_tool_calls(agent_test_settings: Settings): + + SLEEP_TIME = 2.0 + + async def fake_gather_evidence(*args, **kwargs) -> str: # noqa: ARG001 + await asyncio.sleep(SLEEP_TIME) + return "fake evidence" question = "How can you use XAI for chemical property prediction?" - env = GradablePaperQAEnvironment( + env = PaperQAEnvironment( query=QueryRequest(query=question, settings=agent_test_settings), docs=Docs(), ) - - # 1. Rollout until after gather evidence await env.reset() - for tool_call in ( - ToolCall.from_name( - "paper_search", - query="XAI for chemical property prediction", - min_year=2018, - max_year=2024, - ), - ToolCall.from_name("gather_evidence", question=question), - ): - await env.step(ToolRequestMessage(tool_calls=[tool_call])) - - # 2. Now we deepcopy the environment - env_copy = deepcopy(env) - assert env.state == env_copy.state - # 3. Generate an answer for both, and confirm they are identical - gen_answer_action = ToolRequestMessage( - tool_calls=[ToolCall.from_name("gen_answer", question=question)] - ) - _, _, done, _ = await env.step(gen_answer_action) - assert done - assert not env.state.answer.could_not_answer - assert env.state.answer.used_contexts - _, _, done, _ = await env_copy.step(gen_answer_action) - assert done - assert not env_copy.state.answer.could_not_answer - assert env_copy.state.answer.used_contexts - assert sorted(env.state.answer.used_contexts) == sorted( - env_copy.state.answer.used_contexts + gather_tool = next( + tool for tool in env.tools if tool.info.name == GatherEvidence.TOOL_FN_NAME ) + + with patch.object(gather_tool, "_tool_fn", fake_gather_evidence): + tic = time.time() + await env.step( + ToolRequestMessage( + tool_calls=[ + ToolCall.from_name( + "gather_evidence", + question="XAI for chemical property prediction", + ), + ToolCall.from_name( + "gather_evidence", + question="XAI for chemical property prediction", + ), + ] + ) + ) + + assert time.time() - tic > 2 * SLEEP_TIME # since they are sequential + + +class TestGradablePaperQAEnvironment: + @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"]) + @pytest.mark.asyncio + async def test_deepcopy_env(self, agent_test_settings: Settings) -> None: + await get_directory_index(settings=agent_test_settings) # Trigger build + + question = "How can you use XAI for chemical property prediction?" + env = GradablePaperQAEnvironment( + query=QueryRequest(query=question, settings=agent_test_settings), + docs=Docs(), + ) + + # 1. Rollout until after gather evidence + await env.reset() + for tool_call in ( + ToolCall.from_name( + "paper_search", + query="XAI for chemical property prediction", + min_year=2018, + max_year=2024, + ), + ToolCall.from_name("gather_evidence", question=question), + ): + await env.step(ToolRequestMessage(tool_calls=[tool_call])) + + # 2. Now we deepcopy the environment + env_copy = deepcopy(env) + assert env.state == env_copy.state + + # 3. Generate an answer for both, and confirm they are identical + gen_answer_action = ToolRequestMessage( + tool_calls=[ToolCall.from_name("gen_answer", question=question)] + ) + _, _, done, _ = await env.step(gen_answer_action) + assert done + assert not env.state.answer.could_not_answer + assert env.state.answer.used_contexts + _, _, done, _ = await env_copy.step(gen_answer_action) + assert done + assert not env_copy.state.answer.could_not_answer + assert env_copy.state.answer.used_contexts + assert sorted(env.state.answer.used_contexts) == sorted( + env_copy.state.answer.used_contexts + )