
[Tool parsing] Improve / correct mistral tool parsing #10333

Merged

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Nov 14, 2024

This PR is heavily inspired by / copied from what @gcalmettes nicely summarized in #9059 (comment) and in the following messages. Thanks a ton for the great investigation and the excellent ideas on how to improve Mistral function calling.

Based on @gcalmettes's idea in #9059 (comment), both tekken models (mistral-nemo) and spm models (mistral-8b) now output the [TOOL_CALLS] token so that it can be consumed by the tool parser, which allows for more robust function-call parsing, e.g.:

vllm serve mistralai/Ministral-8B-Instruct-2410 --tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral --enable-auto-tool-choice

and then query the model, e.g. via:

import requests
import json

url = 'http://<your-node>:8000/v1/chat/completions'
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer token'
}

model = "mistralai/Ministral-8B-Instruct-2410"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'"
                    },
                    "state": {
                        "type": "string",
                        "description": "The state abbreviation, e.g. 'CA' for California"
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit for temperature",
                        "enum": ["celsius", "fahrenheit"]
                    }
                },
                "required": ["city", "state", "unit"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "rewrite",
            "description": "Rewrite a given text for improved clarity",
            "parameters": {
                "type": "object",
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The input text to rewrite"
                    }
                }
            }
        }
    }
]

messages = [
    {"role": "system", "content": "You are an assistant."},
    {
        "role": "user",
        "content": "Could you please rewrite the below article?\n\nMy English needs improvving, maybe I make erors."
    },
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "bbc5b7ede",
                "type": "function",
                "function": {
                    "name": "rewrite",
                    "arguments": '{"text": "My English needs improvving, maybe I make erors."}'
                }
            }
        ]
    },
    {
        "role": "tool",
        "content": '{"action":"rewrite","outcome":"My English needs improving, maybe I make errors."}',
        "tool_call_id": "bbc5b7ede",
        "name": "rewrite"
    },
    {
        "role": "assistant",
        "content": "---\n\nMy English needs improving, maybe I make errors."
    },
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in Fahrenheit?"
    },
]

data = {
    "model": model,
    "messages": messages,
    "tools": tools
}

response = requests.post(url, headers=headers, data=json.dumps(data))
print(response.json())
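
For completeness, the tool call can then be read back from the response. This is a hedged sketch (not part of the PR itself), assuming the standard OpenAI-compatible chat-completions response shape:

# Hedged sketch: reading the tool call back from the response above.
result = response.json()
message = result["choices"][0]["message"]
for tool_call in message.get("tool_calls") or []:
    # "arguments" is a JSON-encoded string, e.g. '{"city": "Dallas", ...}'
    print(tool_call["function"]["name"], tool_call["function"]["arguments"])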

This PR should then also finally close: #9059.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Nov 14, 2024
@@ -58,17 +61,62 @@
},
"required": ["city", "state", "unit"]
}
},
Contributor Author

Make the test much more difficult and complex to show the community the extent to which function calling can be used with Mistral models.


model_output = outputs[0].outputs[0].text.strip()
assert model_output.startswith(tool_parser.bot_token), model_output
parsed_message = tool_parser.extract_tool_calls(model_output, None)
Contributor Author

Cleaner to let the parser take care of correctly extracting the dict
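
As a rough mental model of what the parser does here (a conceptual sketch only, with an illustrative helper name and payload shape, not the actual MistralToolParser code): the raw completion starts with the [TOOL_CALLS] bot token, followed by a JSON list of calls.

import json

BOT_TOKEN = "[TOOL_CALLS]"

def extract_tool_calls_sketch(model_output: str):
    # Conceptual sketch only, not the vLLM implementation.
    if not model_output.startswith(BOT_TOKEN):
        return None  # plain text answer, nothing to extract
    payload = model_output[len(BOT_TOKEN):].strip()
    # e.g. payload == '[{"name": "get_current_weather", "arguments": {...}}]'
    return json.loads(payload)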

break
request.messages[i][
"tool_calls"] = validated_tool_calls
maybe_serialize_tool_calls(request)
Contributor Author

Moving this out of serving_chat.py just to clean up the method a bit. That is a very general method, and the error correction here is very Mistral-specific, so it is probably better placed in tokenizers.mistral.py.

Contributor

@gcalmettes gcalmettes Nov 14, 2024

Good point!

I had originally thought about putting it directly in the Mistral Tokenizer but did not in the end, because the same problem would occur for any other future model whose tokenizer does not rely on Jinja chat templates (none right now, so this was highly hypothetical).
Factoring the logic into a function like you did is a good solution that still works with other non-chat-template models 👍


request.messages[i]["tool_calls"] = validated_tool_calls


def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
Contributor Author

As proposed by @gcalmettes here: #9059 (comment)

We no longer strip the [TOOL_CALLS] token for either tekken or spm, so that function calls can be correctly parsed.
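
Conceptually, the detokenization filter now keeps the tool-calls token while still dropping the other special tokens. A simplified sketch of the idea (the token names and the set of special tokens are assumptions, not the real MistralTokenizer code):

# Simplified illustration of the filtering idea, not the actual implementation.
TOOL_CALLS = "[TOOL_CALLS]"
SPECIAL_TOKENS = {"<s>", "</s>", TOOL_CALLS}  # assumed example set

def filter_special_tokens(tokens: list[str]) -> list[str]:
    # Keep [TOOL_CALLS] so the tool parser can detect function calls,
    # but continue to drop every other special token.
    return [t for t in tokens if t == TOOL_CALLS or t not in SPECIAL_TOKENS]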

Member

@ywang96 ywang96 left a comment

Thanks for making this PR! I think it's a lot cleaner now.

@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 14, 2024
@@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
        if self.is_tekken:
            tokens = [
                t for t in tokens
                if t not in self.tokenizer._all_special_tokens
                if (t is SpecialTokens.tool_calls
Contributor

Note that after further testing on my end, I found an edge case where not skipping the [TOOL_CALLS] token here can potentially mess up the intended output:

  • when requiring structured output by specifying response_format=json_object or response_format=json_schema, the [TOOL_CALLS] token is still emitted in some cases even though we are not providing any tools to the model, and therefore the generated output is no longer valid JSON. I have tested and observed this with all the structured-output backends supported by vLLM (lm-format-enforcer / outlines). Note that this only happens if the system prompt does not mention that we expect JSON responses from the model.

If we could find a way to filter out the SpecialTokens.tool_calls token except when function calling is required (based on the presence of tools in the request, for example), that would be best. However, I haven't yet found a clean way to pass this information to the convert_tokens_to_string method without changing the method's signature ...

I have an easily reproducible example of this problem that I can share with you.

Contributor Author

Thanks for the note! It would be great if you could share an easy repro.

Contributor

@gcalmettes gcalmettes Nov 15, 2024

@patrickvonplaten please find below a scenario where it will break (and, further below, the small change in prompt that makes the code work, thanks to the added guidance to the model). Note that the code requires lm-format-enforcer version 0.10.9 so that it is compatible with the MistralTokenizer.

However, after further investigation, I now know how to fix it (I'm preparing a PR, I'll tag you for review)! In fact, the problem was present before but "masked" by the fact that the [TOOL_CALLS] token was skipped in the convert_tokens_to_string method, so your PR made it possible to expose the problem 😉. (The root cause is that all the structured-output libraries filter out the special tokens to build their tree of possible tokens, e.g. this check in lm-format-enforcer, but the current vLLM MistralTokenizer does not correctly populate the methods that the libraries use for that. The fix is easy, and I have tested it with success.)

"""
vllm server started with the following arguments:
    --guided-decoding-backend=lm-format-enforcer 
    --enable-auto-tool-choice 
    --tool-call-parser=mistral 
    --tokenizer-mode=mistral
"""

from openai import OpenAI
from pydantic import BaseModel

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="none",
)

class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

completion = client.beta.chat.completions.parse(
    model="mistralai/Pixtral-12B-2409",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)

# the response will break as `[TOOL_CALLS]` is present at the beginning of the response
event = completion.choices[0].message.parsed
print(event.__dict__)

Guiding the model to output JSON by changing the system prompt as below is enough for the model not to produce a [TOOL_CALLS] token:

{"role": "system", "content": "Extract the event information. Respond as JSON."},

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 15, 2024 00:25
@DarkLight1337 DarkLight1337 merged commit 11cd1ae into vllm-project:main Nov 15, 2024
62 checks passed
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 20, 2024
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
@K-Mistele
Contributor

Out of curiosity, was this PR included in 0.6.4?

@gcalmettes
Contributor

gcalmettes commented Dec 3, 2024

@K-Mistele yes.
But you might want to use 0.6.4.post1 if you are also using guided decoding (0.6.4.post1 also includes #10363).

sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels
frontend · ready (ONLY add when PR is ready to merge/full CI is needed)
Development

Successfully merging this pull request may close these issues.

[Bug]: "--tokenizer-mode", "mistral" not compatible with openai API tool use tests
5 participants