Skip to content

Commit

Permalink
Default to ordered tool calls (#697)
Browse files Browse the repository at this point in the history
  • Loading branch information
mskarlin committed Nov 18, 2024
1 parent 0af021a commit f59b3ab
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 37 deletions.
9 changes: 7 additions & 2 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,13 @@ async def step(
# If the action has empty tool_calls, the agent can later take that into account
msgs = cast(
list[Message],
await self.exec_tool_calls(action, state=self.state, handle_tool_exc=True),
)
await self.exec_tool_calls(
action,
ordered=True, # PQA Environment currently not safe for parallel tool calls
state=self.state,
handle_tool_exc=True,
),
) or [Message(content=f"No tool calls input in tool request {action}.")]
return (
msgs,
0, # Reward is computed in post-processing, use 0 as a placeholder
Expand Down
112 changes: 77 additions & 35 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import importlib
import itertools
import json
Expand All @@ -26,7 +27,7 @@
from tantivy import Index

from paperqa.agents import SearchIndex, agent_query
from paperqa.agents.env import settings_to_tools
from paperqa.agents.env import PaperQAEnvironment, settings_to_tools
from paperqa.agents.main import FAKE_AGENT_TYPE
from paperqa.agents.models import AgentStatus, AnswerResponse, QueryRequest
from paperqa.agents.search import (
Expand Down Expand Up @@ -744,46 +745,87 @@ def test_agent_prompt_collection_validations(
AgentSettings(**kwargs)


@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
@pytest.mark.asyncio
async def test_deepcopy_env(agent_test_settings: Settings) -> None:
await get_directory_index(settings=agent_test_settings) # Trigger build
async def test_sequential_tool_calls(agent_test_settings: Settings):

SLEEP_TIME = 2.0

async def fake_gather_evidence(*args, **kwargs) -> str: # noqa: ARG001
await asyncio.sleep(SLEEP_TIME)
return "fake evidence"

question = "How can you use XAI for chemical property prediction?"
env = GradablePaperQAEnvironment(
env = PaperQAEnvironment(
query=QueryRequest(query=question, settings=agent_test_settings),
docs=Docs(),
)

# 1. Rollout until after gather evidence
await env.reset()
for tool_call in (
ToolCall.from_name(
"paper_search",
query="XAI for chemical property prediction",
min_year=2018,
max_year=2024,
),
ToolCall.from_name("gather_evidence", question=question),
):
await env.step(ToolRequestMessage(tool_calls=[tool_call]))

# 2. Now we deepcopy the environment
env_copy = deepcopy(env)
assert env.state == env_copy.state

# 3. Generate an answer for both, and confirm they are identical
gen_answer_action = ToolRequestMessage(
tool_calls=[ToolCall.from_name("gen_answer", question=question)]
)
_, _, done, _ = await env.step(gen_answer_action)
assert done
assert not env.state.answer.could_not_answer
assert env.state.answer.used_contexts
_, _, done, _ = await env_copy.step(gen_answer_action)
assert done
assert not env_copy.state.answer.could_not_answer
assert env_copy.state.answer.used_contexts
assert sorted(env.state.answer.used_contexts) == sorted(
env_copy.state.answer.used_contexts
gather_tool = next(
tool for tool in env.tools if tool.info.name == GatherEvidence.TOOL_FN_NAME
)

with patch.object(gather_tool, "_tool_fn", fake_gather_evidence):
tic = time.time()
await env.step(
ToolRequestMessage(
tool_calls=[
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
ToolCall.from_name(
"gather_evidence",
question="XAI for chemical property prediction",
),
]
)
)

assert time.time() - tic > 2 * SLEEP_TIME # since they are sequential


class TestGradablePaperQAEnvironment:
    """Tests exercising ``GradablePaperQAEnvironment``."""

    @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
    @pytest.mark.asyncio
    async def test_deepcopy_env(self, agent_test_settings: Settings) -> None:
        """Check that a deep-copied environment answers like the original.

        The environment is rolled out through a paper search and an evidence
        gathering step, deep-copied, and then both the original and the copy
        are asked to generate an answer. Both must finish, must produce an
        answer, and must have used the same contexts.
        """
        # Trigger build of the directory index before stepping the environment
        await get_directory_index(settings=agent_test_settings)

        question = "How can you use XAI for chemical property prediction?"
        query_request = QueryRequest(query=question, settings=agent_test_settings)
        env = GradablePaperQAEnvironment(query=query_request, docs=Docs())

        # 1. Rollout until after gather evidence
        await env.reset()
        search_call = ToolCall.from_name(
            "paper_search",
            query="XAI for chemical property prediction",
            min_year=2018,
            max_year=2024,
        )
        gather_call = ToolCall.from_name("gather_evidence", question=question)
        for rollout_call in (search_call, gather_call):
            await env.step(ToolRequestMessage(tool_calls=[rollout_call]))

        # 2. Now we deepcopy the environment
        cloned_env = deepcopy(env)
        assert env.state == cloned_env.state

        # 3. Generate an answer for both (original first, then the copy),
        #    and confirm they are identical
        answer_call = ToolCall.from_name("gen_answer", question=question)
        gen_answer_action = ToolRequestMessage(tool_calls=[answer_call])
        for environment in (env, cloned_env):
            _, _, done, _ = await environment.step(gen_answer_action)
            assert done
            assert not environment.state.answer.could_not_answer
            assert environment.state.answer.used_contexts

        original_contexts = sorted(env.state.answer.used_contexts)
        cloned_contexts = sorted(cloned_env.state.answer.used_contexts)
        assert original_contexts == cloned_contexts

0 comments on commit f59b3ab

Please sign in to comment.