Add UL2 data sampling and pretraining #358

Open

janEbert wants to merge 122 commits into base: main from the ul2 branch

Conversation

janEbert (Collaborator)

This adds pretraining using UL2 for encoder-decoder, non-causal decoder-only, and causal decoder-only models.
I have not yet run large-scale tests to see if it yields the desired training improvements, but I wanted to give others the option to take a look at the code already.
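
For context, UL2 mixes several denoising objectives (regular span corruption, extreme denoising, and sequential/Prefix-LM denoising) and draws one of them per training sample. Below is a minimal, purely illustrative sketch of such a denoiser table and sampler; the values and names loosely follow the UL2 paper and are not the configuration used in this PR.

import random

# Illustrative (mean_span_length, corruption_rate, denoiser) triples, roughly
# following the R/S/X split from the UL2 paper; values are assumptions.
DENOISERS = [
    (3.0, 0.15, 'R'),    # regular span corruption
    (8.0, 0.15, 'R'),
    (3.0, 0.50, 'X'),    # extreme denoising: high corruption rate
    (64.0, 0.15, 'X'),   # extreme denoising: very long spans
    (64.0, 0.50, 'X'),
    (None, 0.25, 'S'),   # sequential / Prefix-LM denoising
]

def sample_denoiser(rng=random):
    # Pick one denoiser uniformly at random for the next training sample.
    return rng.choice(DENOISERS)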

Since we create them in the T5 data loader, why not use them?
Handles backward-compatibility, so the rest of the code base does not need to change.
Namely sampling from uniform and normal distributions.
janEbert force-pushed the ul2 branch 3 times, most recently from db95ce8 to 4d9ff77 on December 13, 2022 at 17:24
... which also improve error messages.
Instead, the user should choose a larger maximum sequence length; an error warns them about this.

janEbert commented Dec 14, 2022

Previously, I truncated sequences so that the maximum number of duplicated extra_id tokens would still fit and be accepted by the model, which lost a bit of data most of the time. I have now changed it so that the program simply errors out and asks the user to configure a longer sequence length for the model.

This is probably a worse/undesired solution, so I have kept the other code in for now (commented out).

Note that erroring out is also how the T5Dataset does it.
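
As a concrete illustration, here is a minimal sketch of this kind of length check; the names (num_tokens, num_masked_spans, max_seq_length) are assumptions for the example, not taken from the PR's code.

def assert_sequence_fits(num_tokens, num_masked_spans, max_seq_length):
    # Each masked span contributes one extra_id sentinel token, so in the
    # worst case the sample needs room for the tokens plus all sentinels.
    required = num_tokens + num_masked_spans
    if required > max_seq_length:
        raise ValueError(
            f'Sample needs {required} positions, but the maximum sequence '
            f'length is {max_seq_length}; please configure a larger '
            f'maximum sequence length.'
        )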

Instead of concatenating arrays and lists to get a certain dtype.
    # Note(mingdachen):
    # By default, we set the probabilities to favor shorter ngram sequences.
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
    if favor_longer_ngram:
        pvals = pvals[::-1]
elif sampling_style is SamplingStyle.NORMAL:
    normal_mean = (max_ngrams + 1) / 2

Collaborator: normal_mean is not used, it seems.

For small sequence lengths or low probability/mean ngram values, we could get `max_ngrams` < 1 and `max_predictions_per_seq` < 1, causing no masking to be done.
Now same as in the UL2 paper code snippet.
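
For illustration, here is a hedged sketch of how an n-gram (span) length could be drawn under the two sampling styles discussed above, including a guard for the max_ngrams < 1 edge case; the helper name, the style strings, and the scale of the normal distribution are assumptions, not this PR's code.

import numpy as np

def sample_ngram_length(max_ngrams, sampling_style, favor_longer_ngram=False,
                        rng=np.random):
    # Guard against degenerate settings: with max_ngrams < 1 no masking
    # would ever be performed.
    max_ngrams = max(1, int(max_ngrams))
    if sampling_style == 'weighted':
        # Inverse-length weights that favor shorter ngrams by default,
        # as in the snippet above.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]
        return int(rng.choice(np.arange(1, max_ngrams + 1), p=pvals))
    elif sampling_style == 'normal':
        # Center the distribution on the middle of the valid range.
        normal_mean = (max_ngrams + 1) / 2
        n = int(round(rng.normal(loc=normal_mean, scale=max_ngrams / 4)))
        return int(np.clip(n, 1, max_ngrams))
    raise ValueError(f'unknown sampling style: {sampling_style}')
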

janEbert commented Jan 3, 2023

There were still several issues remaining in the UL2 implementation, most notably that I had only tested with a micro batch size of 1; larger micro batch sizes made the decoder-only models fail. :p
Also notable in terms of the UL2 sampling: there was an issue with the S-denoisers where the mean was not positioned correctly, leading to shorter masks than desired.

The implementation now also follows the seqio implementation from the UL2 paper more closely, which omits the single extra_id token for the Prefix-LM task that we previously added.
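
For illustration, here is a hedged sketch of what an S-denoiser (Prefix-LM) split without a sentinel token could look like; the function and parameter names are assumptions, not the PR's code.

import numpy as np

def s_denoise_split(tokens, mean_noise_density=0.25, rng=np.random):
    # Split the sequence (assumed to have at least two tokens) into a prefix
    # (context) and a suffix (target) with no extra_id sentinel in between;
    # the suffix becomes the prediction target.
    seq_len = len(tokens)
    mean_suffix = max(1, int(round(mean_noise_density * seq_len)))
    suffix_len = int(np.clip(
        rng.normal(loc=mean_suffix, scale=max(1.0, mean_suffix / 4)),
        1, seq_len - 1))
    split = seq_len - suffix_len
    return tokens[:split], tokens[split:]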


janEbert commented Apr 6, 2023

I can finally report results... Comparing standard T5 training against training with UL2 or UL2R, results in lm-eval-harness were almost always better with UL2/UL2R, which should mean this code does improve evaluation results. :)

janEbert added 11 commits April 13, 2023 16:25
DS = DeepSpeed

No idea why this happens; I couldn't explain it after briefly looking into the DeepSpeed source.
That is, the reproduced objective token.
Was missing `max_seq_length_dec`.
This was already the case for encoder-decoders, but is now also the case for decoder-only models.
This also fixes problems with decoder-only attention masks (see the mask sketch after this commit list).
When using the custom fused softmax kernel.
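
As an illustration of the decoder-only attention-mask fix mentioned in the commit list above, here is a hedged sketch of a non-causal (prefix) attention mask for a decoder-only model; this is a generic construction, not the PR's actual implementation.

import torch

def prefix_lm_attention_mask(seq_len, prefix_len):
    # Start from a standard causal (lower-triangular) mask ...
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # ... then allow positions inside the prefix to attend to each other
    # bidirectionally, while target positions remain causal.
    mask[:prefix_len, :prefix_len] = True
    return mask  # True means attention is allowed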