
T5 Conversion from Original TensorFlow Produces Rubbish Text #7791

Closed
agemagician opened this issue Oct 14, 2020 · 16 comments · Fixed by #8528

@agemagician

Environment info

  • transformers version: 3.0.2
  • Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.6.0+cu101 (False)
  • Tensorflow version (GPU?): 2.3.0 (False)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help

Text Generation: @TevenLeScao
T5: @patrickvonplaten

Information

Model I am using (Bert, XLNet ...):
T5

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

https://colab.research.google.com/drive/112Jt7VFwHHT-QmMxFPJ764GNJBn0d5eX?usp=sharing

Expected behavior

We have started a big project for source-code tasks (generation, summarisation, documentation, etc.) using language models. Using the T5 text-to-text library, the model produces correct predictions. However, after we convert the TensorFlow checkpoint to Hugging Face, the generated text is rubbish.
I am not sure whether we are doing something wrong during conversion or whether there is a problem in loading and converting the weights from the original TensorFlow checkpoint to PyTorch.

The above Colab reproduces the issue.
Important note: we are using a copy of the "adapt_t5_for_covid_19_3b" branch, which should fix the conversion problem, with only one small modification: setting is_tied to false.
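
For reference, the conversion step boils down to something like the following (a minimal sketch; the paths are placeholders, and the actual conversion in the Colab uses our copy of that branch):

from transformers.convert_t5_original_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

# Placeholder paths: TF checkpoint prefix, T5 config file, and the PyTorch output directory
convert_tf_checkpoint_to_pytorch(
    "./checkpoint/model.ckpt-16000",
    "./config.json",
    "./pytorch_model",
)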

Your help is highly appreciated.

patrickvonplaten self-assigned this Oct 15, 2020
@patrickvonplaten

patrickvonplaten commented Oct 15, 2020

Hey @agemagician - did you train your model using the "newer" T5 model (see #6285 for reference), or is it the "original" T5 model?

@agemagician

agemagician commented Oct 15, 2020

No, this is the original T5 model.

I just double-checked the training script as well as the operative_config:
https://storage.googleapis.com/t5_convert_tranformers/model/operative_config.gin

@patrickvonplaten

Ok! From a first check of your Google Colab it looks like the model was correctly converted to PT (the "Weights not copied to PyTorch model:" message is empty, meaning that all PT weights are initialized).

Do you think you could check whether it might be the tokenizer that does not work correctly? Could you maybe run an integration test for some input_ids to check whether the original T5 implementation yields the same output as the PT version?

@agemagician

agemagician commented Oct 15, 2020

I have loaded the original T5 tokenizer, encoded the data, and performed generation using PyTorch to make sure the input is the same for both the original T5 script and the PyTorch script, and the result is still rubbish.

I have checked the original T5 tokenizer and the PyTorch tokenizer, and they produce the same encoding/decoding. The only difference is that the PyTorch tokenizer doesn't append EOS.

I have added a new section to the Colab, "Part IIII: Check tokenizers", which performs these tests.
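
A minimal sketch of that tokenizer comparison (paths are placeholders; the full checks are in the Colab section above):

from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
from transformers import T5Tokenizer

text = "javascript documentation generation: function isStandardBrowserEnv ( ) { return false ; }"

mesh_vocab = SentencePieceVocabulary("./checkpoint/code_spm_unigram_40M.model", extra_ids=100)
hf_tok = T5Tokenizer.from_pretrained("./pytorch_model")

mesh_ids = mesh_vocab.encode(text)
hf_ids = hf_tok.encode(text)

# Inspect both encodings side by side; in this transformers version the HF tokenizer
# does not append the EOS token, so it has to be added manually if needed
print("mesh tf ids:", mesh_ids)
print("hf ids:     ", hf_ids)
print("hf ids + eos:", hf_ids + [hf_tok.eos_token_id])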

@agemagician

agemagician commented Oct 15, 2020

Since the input is the same for both the original T5 script and the PyTorch script, I think the issue should be in one of the following:

  1. The conversion process.
  2. The generation process.
  3. The loading process.

@patrickvonplaten

Thanks, I hope to be able to take a look at this soon!

@agemagician

agemagician commented Nov 3, 2020

@patrickvonplaten Any update on fixing this issue?

We started to release our models for the following tasks:

  1. api generation
  2. code comment generation
  3. commit generation
  4. function documentation generation
  5. program synthesis
  6. source code summarization
  7. Code generation

for the following languages:

  1. go
  2. java
  3. javascript
  4. php
  5. python
  6. ruby
  7. c#
  8. SQL
  9. LISP

https://github.com/agemagician/CodeTrans

However, we are using the original T5 library for now, as Hugging Face transformers still produces rubbish text after conversion.

It would be really useful if we could integrate and use Hugging Face transformers for this project too.

@patrickvonplaten

Will take a look today!

@patrickvonplaten

patrickvonplaten commented Nov 11, 2020

@agemagician - I looked into it. It's quite nightmarish to debug in mesh tensorflow ... :-/ Sadly, I couldn't find the bug and it's getting very time-consuming. I'm going to spend some time now on integrating mT5 and T5v1.1, so I'll still be working with the mesh tensorflow library, and I hope to be able to come back to this problem. A couple of things I found out:

  1. The input_ids passed to the encoder for
"Code: function isStandardBrowserEnv ( ) { if ( typeof navigator !== 'undefined' && ( navigator . product === 'ReactNative' || navigator . product === 'NativeScript' || navigator . product === 'NS' ) ) { return false ; } return ( typeof window !== 'undefined' && typeof document !== 'undefined' ) ; }
Documentation: Returns true if the browser is a native element ."

are actually not the same for Hugging Face T5 and Mesh TF T5. => I suspect that the tokenizers behave differently here, or that mesh TF does something under the hood with the input text.

  2. Sadly, even if I pass the exact same input_ids to the encoder of both models, the encoder outputs are still different. => This means there is a difference in the architecture. I suspect that mesh TensorFlow handles the relative_attention_bias differently for the EncoderDecoderSelfAttention: in the mesh tensorflow gin config it's set to None, but in our code it's definitely used. I did not manage to check this in more detail here.

=> Overall the problem is that mesh_tensorflow is constantly adding new features that are configurable with the gin config, but some of these new features are not implemented in HF and are therefore not used. So what is probably happening is that a mesh tensorflow-trained model has the exact same weights as the HF implementation but a slightly different architecture that cannot be configured with the HF T5 model. It's very hard for us to keep mesh tensorflow constantly compatible with HF, and we probably won't have the time to make sure it is. The only real solution is to use an HF pre-trained model and train it within our environment, or to make sure before mesh tensorflow training that the model is compatible with HF (by checking the output of the pretrained models).

In case you want to take a deeper look, here are the simplified scripts I used for debugging:

for mesh tf model:

import t5
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

t5_model = t5.models.MtfModel(
    model_dir="./checkpoint",
    batch_size=16,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=None,
    iterations_per_loop=100,
    tpu=None
)

vocab_model_path = 'gs://t5_convert_tranformers/spm/code_spm_unigram_40M.model'
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)

t5_model.predict(
    input_file="input.txt",
    output_file="output.txt",
    vocabulary=vocab,
    temperature=0
)

and HF:

from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

input_text = "javascript documentation generation: function isStandardBrowserEnv ( ) { if ( typeof navigator !== 'undefined' && ( navigator . product === 'ReactNative' || navigator . product === 'NativeScript' || navigator . product === 'NS' ) ) { return false ; } return ( typeof window !== 'undefined' && typeof document !== 'undefined' ) ; }"

model = T5ForConditionalGeneration.from_pretrained("./pytorch_model").to("cuda")
tok = T5Tokenizer.from_pretrained("./pytorch_model")

#input_ids = tok(input_text, return_tensors="pt").input_ids.to("cuda")
input_ids = torch.tensor([[69,  8316,  3952, 12059,   171,    69,    34, 11451,  7798,
        6614,     5,     6,    12,    29,     5,   644, 16747,   494,
          20,  3910,    36,   129,     5, 16747,     4,  1668,   232,
          20, 23435,  6462,    36,   194, 16747,     4,  1668,   232,
          20,  6462,  2769,    36,   194, 16747,     4,  1668,   232,
          20,  4759,    36,     6,     6,    12,    30,   181,     9,
          16,    30,     5,   644,  1066,   494,    20,  3910,    36,
         129,   644,   722,   494,    20,  3910,    36,     6,     9,
          16,     1]], dtype=torch.long, device="cuda")


output = model.generate(input_ids, num_beams=4)

print(tok.batch_decode(output))

Then my folders had the following files (same as in your notebook).

ls checkpoint
checkpoint  code_spm_unigram_40M.model  graph.pbtxt  model.ckpt-16000.data-00000-of-00002  model.ckpt-16000.data-00001-of-00002  model.ckpt-16000.index  model.ckpt-16000.meta  operative_config.gin

and

ls pytorch_model
config.json  pytorch_model.bin  special_tokens_map.json  spiece.model  tokenizer_config.json

with all the PyTorch model files converted from the mesh tf spm and mesh tf checkpoint (as you've done in the Colab).

And then one has to put a lot of mtf.print(x, [x], "output: ", summarize=-1) statements into the mesh tensorflow code - e.g. here: https://github.com/tensorflow/mesh/blob/165d3dc7b4186ee5b6d31c9b17b3df4f7571cf42/mesh_tensorflow/transformer/transformer_layers.py#L729 - but that's very painful ;-)
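
On the Hugging Face side, the encoder activations for the same fixed input_ids can be dumped for comparison against those mtf.print outputs with something like this (a sketch; the path and the shortened input_ids are placeholders):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("./pytorch_model")
model.eval()

# The same fixed ids that are fed to the Mesh TF encoder (shortened placeholder here)
input_ids = torch.tensor([[69, 8316, 3952, 12059, 171, 1]], dtype=torch.long)

with torch.no_grad():
    # T5's encoder stack can be called directly; the first output is the last hidden state
    encoder_hidden = model.encoder(input_ids=input_ids)[0]

print(encoder_hidden.shape)
print(encoder_hidden[0, :3, :5])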

Also, see here for debugging advice: tensorflow/mesh#235

Maybe by some miracle I'll find the problem over the next two weeks while looking further into mesh tensorflow.

Sorry not to be of more help here.

@agemagician

agemagician commented Nov 11, 2020

Hi @patrickvonplaten ,

Thanks a lot for looking into this issue.
We highly appreciate your effort, and we're sorry it took up so much of your time.

I have also tested our protein model "prot_t5_xl_bfd" for protein sequence generation, and it has the same issue. Our upcoming 11B protein model "prot_t5_xxl_bfd" will have the same issue as well.
This means the current results we have from all our T5 models are not correct.

Do you know whether this issue exists only in the decoder or in both the encoder and the decoder? Currently we are only using the encoder of "prot_t5_xl_bfd" for feature extraction, as sketched below.
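
For context, the encoder-only feature extraction we run looks roughly like this (a minimal sketch; the local path and the toy protein sequence are placeholders):

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tok = T5Tokenizer.from_pretrained("./prot_t5_xl_bfd")
model = T5ForConditionalGeneration.from_pretrained("./prot_t5_xl_bfd")
model.eval()

# Protein sequences are passed as space-separated amino acids
input_ids = tok("M S L E Q K", return_tensors="pt").input_ids

with torch.no_grad():
    # Only the encoder stack is used; its last hidden state is the per-residue embedding
    features = model.encoder(input_ids=input_ids)[0]

print(features.shape)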

I have also checked mT5 and T5v1.1, and they seem to have the same issue as our current models, so if you work on T5v1.1, you will very likely find the issue and the solution for both the ProtTrans and CodeTrans models.

Thanks again for your time. I will leave this issue open until you finish the T5v1.1 implementation.

@patrickvonplaten

It's both the encoder and the decoder. Even the same encoder input yielded a different encoder output.

@agemagician

agemagician commented Nov 12, 2020

This is really bad for the ProtTrans project.
Thanks a lot, Patrick, for your clear reply.
I will try to debug it from my side, and I will update you if I find the issue.

@patrickvonplaten

I think I got T5v1.1 working now: #8488. But that code will certainly not work with your example, since the feed-forward layer has different weights (T5v1.1 uses a gated-GELU feed-forward with two input projections, while the original T5 uses a single ReLU projection)...

Let me take another look at this issue in a bit. Could you maybe provide me with a code example where I just need to 1) download one of your pretrained checkpoints and 2) run a code snippet of the following format:

#!/usr/bin/env python3
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # or any {'0', '1', '2'}

import t5  # noqa: E402
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary  # noqa: E402
from transformers import T5Tokenizer  # noqa: E402
from transformers.convert_t5_v1_1_original_tf_checkpoint_to_pytorch import (  # noqa: E402
    convert_tf_checkpoint_to_pytorch,
)
from transformers.modeling_t5v2 import T5Config, T5v2ForConditionalGeneration  # noqa: E402


path_to_tf_checkpoint = "/home/patrick/hugging_face/t5v1.1/t5_mesh_checkpoints"


tok = T5Tokenizer.from_pretrained("t5-small")
tok.save_pretrained(path_to_tf_checkpoint)
config = T5Config.from_pretrained("t5-small")
config.d_ff = 1024
config.num_decoder_layers = 8
config.num_layers = 8
config.num_heads = 6

config.save_pretrained(path_to_tf_checkpoint)

convert_tf_checkpoint_to_pytorch(path_to_tf_checkpoint, path_to_tf_checkpoint + "/config.json", path_to_tf_checkpoint)

t5_model = t5.models.MtfModel(
    model_dir=path_to_tf_checkpoint,
    batch_size=1,
    tpu=None,
    sequence_length={"inputs": 4, "targets": 4},
)

vocab_model_path = path_to_tf_checkpoint + "/sentencepiece.model"
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)

score = t5_model.score(
    inputs=["Hello there"],
    targets=["Hi I am"],
    vocabulary=vocab,
)

model = T5v2ForConditionalGeneration.from_pretrained(path_to_tf_checkpoint, return_dict=True)

input_ids = tok("Hello there", return_tensors="pt").input_ids
labels = tok("Hi I am", return_tensors="pt").input_ids

# input_ids and labels are ok!
loss = model(input_ids, labels=labels).loss

# HF loss is the mean cross-entropy per target token, while the mesh TF score is the
# log-likelihood of the whole target, so rescale by the target length and flip the sign
assert -(labels.shape[-1] * loss.item()) - score[0][0] < 1e-4

If all the code were in one file, that would really help me save time in debugging. Otherwise, maybe we can have a quick call early next week (Monday maybe?) to discuss how best to tackle the error. I got a bit lost in the Colab notebook. I'm sure it's actually not that hard to fix.

@agemagician

Great @patrickvonplaten, "du bist der Beste" (you are the best)!

I have created a Colab that runs your code and downloads one of the CodeTrans models:
https://colab.research.google.com/drive/149F64wSOjm5O-HdLWpdWJE4dAMUA-Waa?usp=sharing

Important notes:

  1. This model uses the original T5 architecture, not v1.1, i.e. the word embeddings are tied, dropout is used, and the activation is ReLU.
  2. It is the base model.

Let me know if anything else is required.

patrickvonplaten linked a pull request Nov 13, 2020 that will close this issue
@patrickvonplaten

Should be fixed now. Everything is explained in the PR.

@agemagician

Woohoo, thanks a lot @patrickvonplaten, you are the best 😄
