Backward pass #6

Open
Coluding opened this issue Nov 23, 2023 · 3 comments
Coluding commented Nov 23, 2023

Hi!

First of all, thanks for your great implementation; I like it a lot.
I was wondering whether you have also implemented a backward pass for the model somewhere, since you only show the forward pass in this repo (please correct me if I am wrong).
The reason I am asking is that I want to train a reversible dilated encoder model from scratch, and your code seems well suited for the attention mechanism.

Thanks in advance and kind regards!

fkodom commented Nov 24, 2023

@Coluding Not sure I understand. Could you elaborate a bit?

The forward pass is implemented here, and the backward pass is handled automatically by PyTorch autograd. Are you thinking there is a more efficient way to perform the backward pass? In that case, it could make sense to implement it manually here, too. Or is there some other reason that I'm overlooking?
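
For reference, a minimal sketch of what "autograd handles the backward pass" looks like from the outside. The import path, constructor arguments, and input shapes below are illustrative assumptions rather than the repo's documented API:

```python
# Minimal autograd sketch (illustrative; the DilatedAttention constructor
# arguments and input shapes here are assumptions, not the repo's documented API).
import torch
from dilated_attention_pytorch import DilatedAttention  # assumed import path

attn = DilatedAttention(segment_lengths=[2048, 4096, 8192], dilation_rates=[1, 2, 4])

batch, seq_len, num_heads, head_dim = 1, 8192, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, requires_grad=True)
k = torch.randn(batch, seq_len, num_heads, head_dim, requires_grad=True)
v = torch.randn(batch, seq_len, num_heads, head_dim, requires_grad=True)

out = attn(q, k, v)   # forward pass from this repo
out.sum().backward()  # backward pass generated automatically by autograd
print(q.grad.shape)   # gradients flow back to the query tensor
```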

Coluding commented Nov 24, 2023

Hi @fkodom !

I was just wondering whether you have tested the backward pass and what implications it has for memory usage and sequence length.

Best regards!

fkodom commented Nov 29, 2023

@Coluding Yes, the backward pass works and scales roughly the same as the forward pass (linear with sequence length). You can test that with a slightly modified benchmark.py script:

INFO:root:Benchmark dilated attention...
INFO:root:Sequence length 4096: (5.918e-02 ± 2.241e-04) s
INFO:root:Sequence length 8192: (5.902e-02 ± 3.026e-05) s
INFO:root:Sequence length 16384: (5.900e-02 ± 3.374e-05) s
INFO:root:Sequence length 32768: (5.903e-02 ± 3.294e-05) s
INFO:root:Sequence length 65536: (5.905e-02 ± 3.805e-05) s
INFO:root:Sequence length 131072: (5.898e-02 ± 3.360e-05) s
INFO:root:Sequence length 262144: (5.897e-02 ± 2.324e-05) s
INFO:root:Sequence length 524288: (5.896e-02 ± 2.756e-05) s

^^ In that script, I dynamically choose the batch size so that the total number of tokens is constant across all sequence lengths. So the runtime stays roughly constant when the forward/backward pass scales linearly with sequence length.
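
A rough sketch of that kind of constant-token-budget timing loop (illustrative only; this is not the repo's actual benchmark.py, and `attn` stands in for the dilated attention module):

```python
# Sketch of a forward+backward timing loop with a fixed token budget
# (illustrative; not the actual benchmark.py from this repo).
import timeit
import torch

TOTAL_TOKENS = 2**19          # batch_size * seq_len is held constant
NUM_HEADS, HEAD_DIM = 8, 64

def time_forward_backward(attn: torch.nn.Module, seq_len: int) -> float:
    batch_size = TOTAL_TOKENS // seq_len
    q = torch.randn(batch_size, seq_len, NUM_HEADS, HEAD_DIM, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    def step():
        attn(q, k, v).sum().backward()  # gradients accumulate; fine for timing
        torch.cuda.synchronize()

    return min(timeit.repeat(step, number=1, repeat=5))

# for seq_len in (4096, 8192, 16384, 32768):
#     print(seq_len, time_forward_backward(attn, seq_len))
```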

I haven't explicitly profiled memory, but AFAIK it should scale the same as the forward pass as well.
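
If someone wants to check the memory side explicitly, one way (assuming a CUDA device; `attn` is again a stand-in for the dilated attention module) would be to compare the peak allocation for forward + backward at a few sequence lengths:

```python
# Rough peak-memory check for forward + backward on CUDA (illustrative sketch).
import torch

def peak_memory_mb(attn: torch.nn.Module, seq_len: int, batch_size: int = 1) -> float:
    torch.cuda.reset_peak_memory_stats()
    q = torch.randn(batch_size, seq_len, 8, 64, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    attn(q, k, v).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# e.g. compare peak_memory_mb(attn, 8192) against peak_memory_mb(attn, 16384)
```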
