
Truncated assistant message gets a 0 assistant mask #34494

Closed
2 of 4 tasks
Butanium opened this issue Oct 29, 2024 · 7 comments

Butanium commented Oct 29, 2024

System Info

  • transformers version: 4.45.2
  • Platform: Linux-5.4.0-163-generic-x86_64-with-glibc2.31
  • Python version: 3.11.10
  • Huggingface_hub version: 0.26.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: NVIDIA GeForce GTX TITAN X

Who can help?

@yonigottesman

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I modified the Gemma template to wrap the assistant content in a {% generation %} block, which is what return_assistant_tokens_mask uses to locate assistant tokens:

{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{{ '<start_of_turn>' + role + '\n'}}{% generation %}{{message['content'] | trim}}{% endgeneration %}{{ '<end_of_turn>\n' }}{% else %}{% set role = message['role'] %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}

However, if a model message gets truncated, the mask is all 0:
[screenshot: rendered tokens, with no assistant tokens highlighted]

from transformers import AutoTokenizer
from IPython.display import display, HTML

better_template = "<copy above>"  # the modified Gemma template shown above
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great, thank you!"},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
tokens = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
    max_length=20,  # truncation cuts the conversation mid assistant turn
    truncation=True,
)
# Render each token, highlighting the ones flagged in the assistant mask.
highlighted_tokens = [
    (
        f"<span style='color: red; border: 1px solid red; padding: 2px;'>{token.replace('<', '&lt;').replace('>', '&gt;')}</span>"
        if mask
        else token.replace("<", "&lt;").replace(">", "&gt;")
    )
    for token, mask in zip(
        tokenizer.convert_ids_to_tokens(tokens["input_ids"]), tokens["assistant_masks"]
    )
]

md = "".join(highlighted_tokens)
display(HTML(md))

# %%
tokens["assistant_masks"]
# %%

Expected behavior

I'd expect the mask to have 1s on the partial model response.
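A quick check capturing this expectation (a minimal sketch, reusing the tokens dict from the reproduction above):

# An assistant turn is cut mid-message, but its surviving tokens should
# still be flagged, so the mask should not be all zeros.
assert any(tokens["assistant_masks"]), "expected 1s on the partial assistant response"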

Butanium added the bug label on Oct 29, 2024
Butanium (Author) commented

Also @yonigottesman, assistant_masks is not converted to a tensor even if I do:

tokens = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
    return_tensors ="pt"
)
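Checking the returned types shows the mismatch (the same observation is shown with outputs in a later comment):

type(tokens["input_ids"])        # torch.Tensor
type(tokens["assistant_masks"])  # list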

yonigottesman (Contributor) commented

@Butanium you are right, there is a bug in my code; I will fix and update.
BTW, you should include the {{ '<end_of_turn>\n' }} inside the generation block, as you want the model to learn to output this string when it's done.

Butanium (Author) commented

Thank you! I edited my template to include the <end_of_turn> but not the \n, as those are different tokens.
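For reference, the generation block of the edited template would then look like this (a sketch of the change discussed above, with the newline emitted outside the block):

{% generation %}{{ message['content'] | trim }}{{ '<end_of_turn>' }}{% endgeneration %}{{ '\n' }}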

Butanium (Author) commented Nov 5, 2024

Thanks for the fix @yonigottesman! Should I open another issue regarding the return type of assistant_masks? It does not get converted to a PyTorch tensor by default:

tokens = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
    return_tensors ="pt"
)
In [14]: type(tokens.input_ids)
Out[14]: torch.Tensor

In [15]: type(tokens.assistant_masks)
Out[15]: list

yonigottesman (Contributor) commented

Hi, can you explain why you would want a tensor in this case?
In my case I use this list to later create a "labels" tensor with -100s, as sketched below.
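Something along those lines (a minimal sketch of that pattern, assuming the tokens dict from the earlier reproduction):

import torch

# Ignore non-assistant positions in the loss by setting their label to -100.
input_ids = torch.tensor(tokens["input_ids"])
mask = torch.tensor(tokens["assistant_masks"]).bool()
labels = input_ids.clone()
labels[~mask] = -100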

Butanium (Author) commented

@yonigottesman I'd expect assistant_masks to have the same behavior as attention_mask for consistency (i.e. return a tensor of 0s and 1s). In my case, I convert it to a bool tensor to take the predictions for the assistant tokens only:

import torch as th

assistant_masks = th.tensor(batch_tokens["assistant_masks"]).bool().to(device)
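With that mask, predictions can then be restricted to assistant positions (a hedged sketch; logits here is a hypothetical [batch, seq_len, vocab] tensor from a forward pass):

# Boolean indexing keeps only the positions where the mask is True.
assistant_logits = logits[assistant_masks]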

BernardZach pushed a commit to BernardZach/transformers that referenced this issue on Dec 5, 2024: "Fix assistant tokens when truncated"
Butanium (Author) commented

Hi @yonigottesman, sorry to bother you with this. I'm wondering: should I open a separate issue for this tensor conversion problem?
