Skip to content

Commit

Permalink
Revamp ReAct, adjust Bootstrap, adjust ChatAdapter (#1713)
Browse files Browse the repository at this point in the history
* JsonAdapter: Handle JSON formatting in demo's outputs

* Adjustmetns for JsonAdapter

* Revamp ReAct, adjust Bootstrap (handle repeat calls to a module; transpose order for max_rounds), adjust ChatAdapter (handle incomplete demos better)

* Remove ReAct tests (outdated format)

* Remove react tests (outdated)
  • Loading branch information
okhat authored Oct 29, 2024
1 parent 6fe6935 commit 803dff0
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 429 deletions.
6 changes: 3 additions & 3 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .base import Adapter
from .chat_adapter import ChatAdapter
from .json_adapter import JsonAdapter
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JsonAdapter
43 changes: 28 additions & 15 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values):
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in signature.fields.items()
if field_name in values
}

return format_fields(fields_with_values)



def format_blob(blob):
Expand Down Expand Up @@ -228,21 +240,22 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
content.append(formatted_fields)

if role == "user":
# def type_info(v):
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
# if v.annotation is not str else ""
#
# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
# )

content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
+ ", and then ending with the marker for `completed`."
)
def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

if not incomplete:
content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
)

# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`{f}`" for f in signature.output_fields)
# + ", and then ending with the marker for `completed`."
# )

return {"role": role, "content": "\n\n".join(content).strip()}

Expand Down
13 changes: 13 additions & 0 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values):
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in signature.fields.items()
if field_name in values
}

return format_fields(role='user', fields_with_values=fields_with_values)



def parse_value(value, annotation):
Expand Down Expand Up @@ -241,6 +253,7 @@ def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

# TODO: Consider if not incomplete:
content.append(
"Respond with a JSON object in the following order of fields: "
+ ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items())
Expand Down
193 changes: 82 additions & 111 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,101 @@
import dsp
import dspy
from dspy.signatures.signature import ensure_signature

from ..primitives.program import Module
from .predict import Predict
import inspect

# TODO: Simplify a lot.
# TODO: Divide Action and Action Input like langchain does for ReAct.
from pydantic import BaseModel
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.adapters.json_adapter import get_annotation_name
from typing import Callable, Any, get_type_hints, get_origin, Literal

# TODO: There's a lot of value in having a stopping condition in the LM calls at `\n\nObservation:`
class Tool:
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
annotations_func = func if inspect.isfunction(func) else func.__call__
self.func = func
self.name = name or getattr(func, '__name__', type(func).__name__)
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description")
self.args = {
k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel)
else get_annotation_name(v)
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
}

# TODO [NEW]: When max_iters is about to be reached, reduce the set of available actions to only the Finish action.
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)


class ReAct(Module):
def __init__(self, signature, max_iters=5, num_results=3, tools=None):
super().__init__()
def __init__(self, signature, tools: list[Callable], max_iters=5):
"""
Tools is either a list of functions, callable classes, or dspy.Tool instances.
"""

self.signature = signature = ensure_signature(signature)
self.max_iters = max_iters

self.tools = tools or [dspy.Retrieve(k=num_results)]
self.tools = {tool.name: tool for tool in self.tools}

self.input_fields = self.signature.input_fields
self.output_fields = self.signature.output_fields

assert len(self.output_fields) == 1, "ReAct only supports one output field."
tools = [t if isinstance(t, Tool) or hasattr(t, 'input_variable') else Tool(t) for t in tools]
tools = {tool.name: tool for tool in tools}

inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()])
outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()])
inputs_ = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
outputs_ = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
instr = [f"{signature.instructions}\n"] if signature.instructions else []

instr = []

if self.signature.instructions is not None:
instr.append(f"{self.signature.instructions}\n")

instr.extend([
f"You will be given {inputs_} and you will respond with {outputs_}.\n",
"To do this, you will interleave Thought, Action, and Observation steps.\n",
"Thought can reason about the current situation, and Action can be the following types:\n",
f"You will be given {inputs_} and your goal is to finish with {outputs_}.\n",
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n",
"Thought can reason about the current situation, and Tool Name can be the following types:\n",
])

self.tools["Finish"] = dspy.Example(
name="Finish",
input_variable=outputs_.strip("`"),
desc=f"returns the final {outputs_} and finishes the task",
finish_desc = f"Signals that the final outputs, i.e. {outputs_}, are now available and marks the task as complete."
finish_args = {} #k: v.annotation for k, v in signature.output_fields.items()}
tools["finish"] = Tool(func=lambda **kwargs: kwargs, name="finish", desc=finish_desc, args=finish_args)

for idx, tool in enumerate(tools.values()):
desc = tool.desc.replace("\n", " ")
args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str})
desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format."
instr.append(f"({idx+1}) {tool.name}, {desc}")

signature_ = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
.append("trajectory", dspy.InputField(), type_=str)
.append("next_thought", dspy.OutputField(), type_=str)
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
)

for idx, tool in enumerate(self.tools):
tool = self.tools[tool]
instr.append(
f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}",
)

instr = "\n".join(instr)
self.react = [
Predict(dspy.Signature(self._generate_signature(i), instr))
for i in range(1, max_iters + 1)
]

def _generate_signature(self, iters):
signature_dict = {}
for key, val in self.input_fields.items():
signature_dict[key] = val

for j in range(1, iters + 1):
IOField = dspy.OutputField if j == iters else dspy.InputField

signature_dict[f"Thought_{j}"] = IOField(
prefix=f"Thought {j}:",
desc="next steps to take based on last observation",
)

tool_list = " or ".join(
[
f"{tool.name}[{tool.input_variable}]"
for tool in self.tools.values()
if tool.name != "Finish"
],
)
signature_dict[f"Action_{j}"] = IOField(
prefix=f"Action {j}:",
desc=f"always either {tool_list} or, when done, Finish[<answer>], where <answer> is the answer to the question itself.",
)

if j < iters:
signature_dict[f"Observation_{j}"] = IOField(
prefix=f"Observation {j}:",
desc="observations based on action",
format=dsp.passages2text,
)

return signature_dict

def act(self, output, hop):
try:
action = output[f"Action_{hop+1}"]
action_name, action_val = action.strip().split("\n")[0].split("[", 1)
action_val = action_val.rsplit("]", 1)[0]

if action_name == "Finish":
return action_val

result = self.tools[action_name](action_val) #result must be a str, list, or tuple
# Handle the case where 'passages' attribute is missing
output[f"Observation_{hop+1}"] = getattr(result, "passages", result)

except Exception:
output[f"Observation_{hop+1}"] = (
"Failed to parse action. Bad formatting or incorrect action name."
)
# raise e

def forward(self, **kwargs):
args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs}

for hop in range(self.max_iters):
# with dspy.settings.context(show_guidelines=(i <= 2)):
output = self.react[hop](**args)
output[f'Action_{hop + 1}'] = output[f'Action_{hop + 1}'].split('\n')[0]

if action_val := self.act(output, hop):
break
args.update(output)
fallback_signature = (
dspy.Signature({**signature.input_fields, **signature.output_fields})
.append("trajectory", dspy.InputField(), type_=str)
)

observations = [args[key] for key in args if key.startswith("Observation")]
self.tools = tools
self.react = dspy.Predict(signature_)
self.extract = dspy.ChainOfThought(fallback_signature)

def forward(self, **input_args):
trajectory = {}

def format(trajectory_: dict[str, Any], last_iteration: bool):
adapter = dspy.settings.adapter or dspy.ChatAdapter()
blob = adapter.format_fields(dspy.Signature(f"{', '.join(trajectory_.keys())} -> x"), trajectory_)
warning = f"\n\nWarning: The maximum number of iterations ({self.max_iters}) has been reached."
warning += " You must now produce the finish action."
return blob + (warning if last_iteration else "")

for idx in range(self.max_iters):
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters-1)))

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as e:
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"

if pred.next_tool_name == "finish":
break

# assumes only 1 output field for now - TODO: handling for multiple output fields
return dspy.Prediction(observations=observations, **{list(self.output_fields.keys())[0]: action_val or ""})
extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False))
return dspy.Prediction(trajectory=trajectory, **extract)
5 changes: 4 additions & 1 deletion dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ def __call__(self, *args, **kwargs):

def forward(
self,
query_or_queries: Union[str, List[str]],
query_or_queries: Union[str, List[str]] = None,
query: Optional[str] = None,
k: Optional[int] = None,
by_prob: bool = True,
with_metadata: bool = False,
**kwargs,
) -> Union[List[str], Prediction, List[Prediction]]:
query_or_queries = query_or_queries or query

# queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
# queries = [query.strip().split('\n')[0].strip() for query in queries]

Expand Down
Loading

0 comments on commit 803dff0

Please sign in to comment.