T5 Conversion from Original Tensorflow Produces Rubbish Text #7791
Comments
Hey @agemagician - did you train your model using the "newer" T5 model (see #6285 for reference) or the "original" T5 model?
No, this is the original T5 model. I just double-checked the training script as well as the operative_config.
Ok! From a first check of your Google Colab it looks like the model was correctly converted to PT. Do you think you could check whether it might be the tokenizer that does not work correctly? Could you maybe run an integration test for some inputs?
I have loaded the original T5 tokenizer, then encoded the data and performed generation using PyTorch, to make sure the input is the same for both the original T5 script and the PyTorch script, and the results are still rubbish. I have checked the original T5 tokenizer and the PyTorch tokenizer and they produce the same encoding/decoding. The only difference is that the PyTorch tokenizer doesn't append EOS. I have added a new section to the Colab, "Part IIII: Check tokenizers", which performs these tests.
Since the input is the same for both the original T5 script and the PyTorch script, I think the issue should be in one of the following:
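A minimal sketch of the tokenizer cross-check described above, assuming the converted model lives in `./pytorch_model` and the SentencePiece model in `./checkpoint` (paths and the example string are placeholders, not taken from the thread):

```python
# Hypothetical cross-check: encode the same string with the HF tokenizer and the
# original Mesh TF SentencePiece vocabulary and compare the resulting ids.
from transformers import T5Tokenizer
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

text = "javascript documentation generation: function foo ( ) { return 1 ; }"

hf_tok = T5Tokenizer.from_pretrained("./pytorch_model")  # assumed local path
mesh_vocab = SentencePieceVocabulary("./checkpoint/code_spm_unigram_40M.model", extra_ids=100)

hf_ids = hf_tok.encode(text)        # transformers 3.0.2 does not append EOS here
mesh_ids = mesh_vocab.encode(text)  # the Mesh TF pipeline appends EOS (id 1) later, at prediction time

print(hf_ids == mesh_ids)
print(hf_ids, mesh_ids)
```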
Thanks, I hope to be able to take a look at this soon!
@patrickvonplaten Any update on fixing this issue? We have started to release our models for the following tasks:

for the following languages:

https://github.com/agemagician/CodeTrans

However, we are using the original T5 library for now, as huggingface transformers still produces rubbish text after conversion. It would be really useful if we could integrate and use huggingface transformers for this project too.
Will take a look today!
@agemagician - I looked into it. It's quite nightmarish to debug in mesh tensorflow ... :-/ Sadly, I couldn't find the bug and it's getting very time-consuming. I'm going to spend some time now integrating mT5 and T5v1.1, so I'll still be working with the mesh tensorflow library, and I hope to be able to come back to this problem. A couple of things I found out:

- … is actually not the same for Hugging Face T5 and Mesh TF T5 => I suspect the tokenizers behave differently here, or mesh tf does something under the hood with the input text.
- Overall, the problem is that …

In case you want to take a deeper look, here are the simplified scripts I used for debugging.

For the mesh tf model:

```python
import t5
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

t5_model = t5.models.MtfModel(
    model_dir="./checkpoint",
    batch_size=16,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=None,
    iterations_per_loop=100,
    tpu=None
)

vocab_model_path = 'gs://t5_convert_tranformers/spm/code_spm_unigram_40M.model'
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)

t5_model.predict(
    input_file="input.txt",
    output_file="output.txt",
    vocabulary=vocab,
    temperature=0
)
```

and HF:

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
input_text = "javascript documentation generation: function isStandardBrowserEnv ( ) { if ( typeof navigator !== 'undefined' && ( navigator . product === 'ReactNative' || navigator . product === 'NativeScript' || navigator . product === 'NS' ) ) { return false ; } return ( typeof window !== 'undefined' && typeof document !== 'undefined' ) ; }"
model = T5ForConditionalGeneration.from_pretrained("./pytorch_model").to("cuda")
tok = T5Tokenizer.from_pretrained("./pytorch_model")
#input_ids = tok(input_text, return_tensors="pt").input_ids.to("cuda")
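# hard-coded input ids (presumably the Mesh TF tokenization of the text above,
# including the trailing EOS id 1), used so both frameworks see exactly the same input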
input_ids = torch.tensor([[69, 8316, 3952, 12059, 171, 69, 34, 11451, 7798,
6614, 5, 6, 12, 29, 5, 644, 16747, 494,
20, 3910, 36, 129, 5, 16747, 4, 1668, 232,
20, 23435, 6462, 36, 194, 16747, 4, 1668, 232,
20, 6462, 2769, 36, 194, 16747, 4, 1668, 232,
20, 4759, 36, 6, 6, 12, 30, 181, 9,
16, 30, 5, 644, 1066, 494, 20, 3910, 36,
129, 644, 722, 494, 20, 3910, 36, 6, 9,
16, 1]], dtype=torch.long, device="cuda")
output = model.generate(input_ids, num_beams=4)
print(tok.batch_decode(output))
```

Then my folders had the following files (same as in your notebook):

```
$ ls checkpoint
checkpoint  code_spm_unigram_40M.model  graph.pbtxt  model.ckpt-16000.data-00000-of-00002  model.ckpt-16000.data-00001-of-00002  model.ckpt-16000.index  model.ckpt-16000.meta  operative_config.gin

$ ls pytorch_model
config.json  pytorch_model.bin  special_tokens_map.json  spiece.model  tokenizer_config.json
```

with all the pytorch models converted from the mesh tf spm and mesh tf checkpoint (as you've done in the colab). And then one has to put a lot of … Also, see here for debugging advice: tensorflow/mesh#235. Maybe by some miracle I'll find the problem over the next two weeks while looking further into mesh tensorflow. Sorry not to be of more help here.
Hi @patrickvonplaten, thanks a lot for looking into this issue. I have also tested our protein model "prot_t5_xl_bfd" for protein sequence generation and it has the same issue. Our next 11B model for protein sequences, "prot_t5_xxl_bfd", will also have the same issue. Do you know if this issue exists only in the decoder, or in both the encoder and the decoder? I have also checked mT5 and T5v1.1 and they seem to have the same issue as our current models, so if you work on T5v1.1, you will very likely find the issue and the solution for both the ProtTrans models and the ProtCode models. Thanks again for your time; I will leave this issue open until you finish the T5v1.1 implementation.
It's both the encoder and the decoder. Even the same encoder input yielded a different encoder output.
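A minimal sketch of how one might compare the two encoders on identical ids (the Mesh TF activations are assumed to have been dumped to a `.npy` file beforehand; the paths, ids, and that dump are illustrative, not part of the thread):

```python
import numpy as np
import torch
from transformers import T5ForConditionalGeneration

# same ids that were fed to the Mesh TF encoder (trailing 1 = EOS); placeholder values
input_ids = torch.tensor([[69, 8316, 3952, 1]])

model = T5ForConditionalGeneration.from_pretrained("./pytorch_model").eval()
with torch.no_grad():
    hf_enc = model.encoder(input_ids=input_ids)[0]  # (1, seq_len, d_model)

# hypothetical .npy dump of the Mesh TF encoder output for the same ids
mesh_enc = torch.from_numpy(np.load("mesh_encoder_output.npy"))

print(torch.max(torch.abs(hf_enc - mesh_enc)))  # ~0 if the two encoders agree
```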
This is really bad for the ProtTrans project.
I think I got T5v1.1 working now: #8488. But that code will certainly not work with your example, since the feed-forward layer has different weights... Let me take another look at this issue in a bit. Could you maybe provide me with a code example where I just need to download one of your pretrained checkpoints and run everything in a single script? Something along these lines:

```python
#!/usr/bin/env python3
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # or any {'0', '1', '2'}
import t5 # noqa: E402
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary # noqa: E402
from transformers import T5Tokenizer # noqa: E402
from transformers.convert_t5_v1_1_original_tf_checkpoint_to_pytorch import ( # noqa: E402
convert_tf_checkpoint_to_pytorch,
)
from transformers.modeling_t5v2 import T5Config, T5v2ForConditionalGeneration # noqa: E402
path_to_tf_checkpoint = "/home/patrick/hugging_face/t5v1.1/t5_mesh_checkpoints"
tok = T5Tokenizer.from_pretrained("t5-small")
tok.save_pretrained(path_to_tf_checkpoint)
config = T5Config.from_pretrained("t5-small")
config.d_ff = 1024
config.num_decoder_layers = 8
config.num_layers = 8
config.num_heads = 6
config.save_pretrained(path_to_tf_checkpoint)
convert_tf_checkpoint_to_pytorch(path_to_tf_checkpoint, path_to_tf_checkpoint + "/config.json", path_to_tf_checkpoint)
t5_model = t5.models.MtfModel(
model_dir=path_to_tf_checkpoint,
batch_size=1,
tpu=None,
sequence_length={"inputs": 4, "targets": 4},
)
vocab_model_path = path_to_tf_checkpoint + "/sentencepiece.model"
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)
score = t5_model.score(
inputs=["Hello there"],
targets=["Hi I am"],
vocabulary=vocab,
)
model = T5v2ForConditionalGeneration.from_pretrained(path_to_tf_checkpoint, return_dict=True)
input_ids = tok("Hello there", return_tensors="pt").input_ids
labels = tok("Hi I am", return_tensors="pt").input_ids
# input_ids and labels are ok!
loss = model(input_ids, labels=labels).loss
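# assumption: the mesh `score` is the summed log-likelihood of the target tokens,
# while the HF `loss` is the mean cross-entropy per token, so -loss * target_length
# should match the mesh score up to a small tolerance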
assert -(labels.shape[-1] * loss.item()) - score[0][0] < 1e-4
```

If all the code were in one file, this would really help me save time debugging. Otherwise, maybe we can have a quick call early next week (Monday maybe?) to discuss how best to tackle the error. I got a bit lost in the colab notebook. I'm sure it's not that hard to fix, actually.
Great @patrickvonplaten, "you are the best"! I have created a Colab that runs your code and downloads one of the CodeTrans models:

Important notes:

Let me know if anything else is required.
Should be fixed now. Everything is explained in the PR.
Woohoo, thanks a lot @patrickvonplaten, you are the best 😄
Environment info

- transformers version: 3.0.2

Who can help
Text Generation: @TevenLeScao
T5: @patrickvonplaten
Information
Model I am using (Bert, XLNet ...):
T5
The problem arises when using:
The tasks I am working on are:
To reproduce
Steps to reproduce the behavior:
https://colab.research.google.com/drive/112Jt7VFwHHT-QmMxFPJ764GNJBn0d5eX?usp=sharing
Expected behavior
We have started a big project for source code tasks (generation, summarisation, documentation, etc.) using language models. Using the T5 text-to-text library, the model produces correct predictions; however, after we converted the TensorFlow checkpoint to huggingface, the output text is rubbish.
I am not sure if we are doing something wrong during conversion, or whether there is a problem in loading and converting the weights from the original TensorFlow checkpoint to PyTorch.
The above Colab reproduces the issue.
Important Note: We are using a copy of the "adapt_t5_for_covid_19_3b" branch, which should fix the conversion problem, with only one small modification: setting is_tied to false.
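For reference, a minimal sketch of how the stock conversion script in transformers is typically invoked (this is not the modified "adapt_t5_for_covid_19_3b" branch used here, and the paths are placeholders):

```python
# Sketch only: standard T5 Mesh TF -> PyTorch conversion with the stock transformers
# converter; the modified branch mentioned above changes this behaviour (is_tied=False).
from transformers.convert_t5_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

convert_tf_checkpoint_to_pytorch(
    "./checkpoint/model.ckpt-16000",  # Mesh TF checkpoint prefix (placeholder)
    "./pytorch_model/config.json",    # T5Config matching the checkpoint's operative_config.gin
    "./pytorch_model",                # output dir for pytorch_model.bin
)
```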
Your help is highly appreciated.