Skip to content

Commit

Permalink
Remove retry_strategy in LM and handle no-docstring functions in ReAct (
Browse files Browse the repository at this point in the history
#1725)

* Remove retry_strategy in LM and handle no-docstring functions in ReAct

* exponential_backoff_retry tests
  • Loading branch information
okhat authored Oct 30, 2024
1 parent 94f1fd7 commit 064f6d2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
4 changes: 1 addition & 3 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
cache: bool = True,
launch_kwargs: Optional[Dict[str, Any]] = None,
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 8,
num_retries: int = 3,
**kwargs,
):
"""
Expand Down Expand Up @@ -186,7 +186,6 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s
kwargs = ujson.loads(request)
return litellm.completion(
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
cache=cache,
**kwargs,
)
Expand Down Expand Up @@ -223,7 +222,6 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
**kwargs,
)

Expand Down
8 changes: 4 additions & 4 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
annotations_func = func if inspect.isfunction(func) else func.__call__
self.func = func
self.name = name or getattr(func, '__name__', type(func).__name__)
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description")
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "")
self.args = {
k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel)
else get_annotation_name(v)
Expand Down Expand Up @@ -50,10 +50,10 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
tools["finish"] = Tool(func=lambda **kwargs: "Completed.", name="finish", desc=finish_desc, args=finish_args)

for idx, tool in enumerate(tools.values()):
desc = tool.desc.replace("\n", " ")
args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str})
desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format."
instr.append(f"({idx+1}) {tool.name}, {desc}")
desc = (f", whose description is <desc>{tool.desc}</desc>." if tool.desc else ".").replace('\n', " ")
desc += f" It takes arguments {args} in JSON format."
instr.append(f"({idx+1}) {tool.name}{desc}")

signature_ = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
Expand Down
4 changes: 2 additions & 2 deletions tests/clients/test_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_lm_chat_respects_max_retries():

assert litellm_completion_api.call_count == 1
assert litellm_completion_api.call_args[1]["max_retries"] == 17
assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
# assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"


def test_lm_completions_respects_max_retries():
Expand All @@ -22,4 +22,4 @@ def test_lm_completions_respects_max_retries():

assert litellm_completion_api.call_count == 1
assert litellm_completion_api.call_args[1]["max_retries"] == 17
assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
# assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"

0 comments on commit 064f6d2

Please sign in to comment.