diff --git a/packages/openassistants/openassistants/eval/interaction.py b/packages/openassistants/openassistants/eval/interaction.py index 8826d41..8b64fa1 100644 --- a/packages/openassistants/openassistants/eval/interaction.py +++ b/packages/openassistants/openassistants/eval/interaction.py @@ -1,7 +1,7 @@ import abc import asyncio import textwrap -from typing import Any, Dict, List, Literal, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar from openassistants.core.assistant import Assistant from openassistants.data_models.chat_messages import ( @@ -14,7 +14,7 @@ from openassistants.data_models.function_output import DataFrameOutput, TextOutput from openassistants.functions.base import IFunction from openassistants.utils.async_utils import last_value -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, InstanceOf class InteractionCheckError(Exception): @@ -98,6 +98,9 @@ async def run( self, assistant: Assistant, ancestor_response: List["FunctionInteractionResponse"], + select_function: bool = True, + invoke_function: bool = True, + get_function_spec: bool = True, ) -> "FunctionInteractionResponse": history: List[OpasMessage] = [ m # type: ignore @@ -115,18 +118,32 @@ async def run( ), ) + co_selection = ( + self.run_function_selection(assistant, history + [user_message]) + if select_function + else as_coroutine(None) + ) + + co_invocation = ( + self.run_function_invocation( + assistant, history + [user_message, user_input_response] + ) + if invoke_function + else as_coroutine((None, None)) + ) + + co_function = ( + self.get_function(assistant) if get_function_spec else as_coroutine(None) + ) + ( assistant_selection, - # assistant_infilling, (assistant_function_invocation, function_response), function_spec, ) = await asyncio.gather( - self.run_function_selection(assistant, history + [user_message]), - # self.run_function_infilling(client, history + [user_message]), - self.run_function_invocation( - assistant, history + [user_message, user_input_response] - ), - self.get_function(assistant), + co_selection, + co_invocation, + co_function, ) interaction_response = FunctionInteractionResponse( @@ -157,9 +174,9 @@ class FunctionInteractionResponseNode(BaseModel): assistant_selection: OpasAssistantMessage assistant_infilling: OpasAssistantMessage user_input_response: OpasUserMessage - assistant_function_invocation: OpasAssistantMessage - function_response: OpasFunctionMessage - function_spec: IFunction + assistant_function_invocation: Optional[OpasAssistantMessage] = None + function_response: Optional[OpasFunctionMessage] = None + function_spec: Optional[InstanceOf[IFunction]] = None class FunctionInteractionResponse(FunctionInteractionResponseNode): @@ -242,3 +259,10 @@ def pretty_repr(self, include_summary=True, include_dataframe=True) -> str: child_status = textwrap.indent(child_status, "| ") return this_status + child_status + + +T = TypeVar("T") + + +async def as_coroutine(v: T) -> T: + return v diff --git a/packages/openassistants/openassistants/functions/base.py b/packages/openassistants/openassistants/functions/base.py index ff058a5..cdce1d9 100644 --- a/packages/openassistants/openassistants/functions/base.py +++ b/packages/openassistants/openassistants/functions/base.py @@ -76,7 +76,9 @@ def get_signature(self) -> str: """ if len(self.get_sample_questions()) > 0: - documentation += "\n".join(f"* {q}" for q in self.get_sample_questions()) + documentation += "Example Questions:\n" + "\n".join( + f"* {q}" for q in self.get_sample_questions() + ) # Construct the function signature signature = f"""\