Skip to content

Commit

Permalink
fix - no args function call with anthropic api
Browse files Browse the repository at this point in the history
  • Loading branch information
ananis25 committed Jan 28, 2025
1 parent 4227ea4 commit 49d20d8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/magentic/chat_model/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")

def parse_args(self, chunks: Iterable[str]) -> FunctionCall[T]:
args_json = "".join(chunks)
# Anthropic message stream returns empty string for function call with no arguments
args_json = "".join(chunks) or "{}"
model = self._model.model_validate_json(args_json)
supplied_params = [
param
Expand Down
27 changes: 26 additions & 1 deletion tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,10 @@ def plus_with_annotated(
return a + b


def return_constant() -> int:
return 100


class IntModel(BaseModel):
value: int

Expand Down Expand Up @@ -807,6 +811,13 @@ def plus_with_config_openai_strict(a: int, b: int) -> int:
"strict": True,
},
),
(
return_constant,
{
"name": "return_constant",
"parameters": {"type": "object", "properties": {}},
},
),
],
)
def test_function_call_function_schema(function, json_schema):
Expand Down Expand Up @@ -855,12 +866,26 @@ def test_function_call_function_schema_with_default_value():
'{"a": {"value": 1}, "b": {"value": 2}}',
FunctionCall(plus_with_basemodel, IntModel(value=1), IntModel(value=2)),
),
(
return_constant,
"{}",
FunctionCall(return_constant),
),
]

function_call_function_schema_args_empty_string_test_case = [
(
return_constant,
"",
FunctionCall(return_constant),
),
]


@pytest.mark.parametrize(
("function", "args_str", "expected_args"),
function_call_function_schema_args_test_cases,
function_call_function_schema_args_test_cases
+ function_call_function_schema_args_empty_string_test_case,
)
def test_function_call_function_schema_parse_args(function, args_str, expected_args):
parsed_args = FunctionCallFunctionSchema(function).parse_args(args_str)
Expand Down

0 comments on commit 49d20d8

Please sign in to comment.