distributed TP model forward output's requires_grad is False #115
Comments
I can confirm that the intermediate
Made some progress: it seems the activation no longer requires grad after going through the lm-head.
OK, it's caused by the last lm-head being sharded: the all-gather of the activations along the column dimension doesn't propagate gradients.
A simple workaround for now would be to not shard the lm-head; another solution would be to run the all-gather through a custom torch.autograd.Function with the backward pass implemented. I'll test out the first solution for now.
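For reference, the second option can be sketched as a custom torch.autograd.Function whose backward returns the local slice of the incoming gradient. This is a minimal sketch, not code from this repository; the class name, the last-dimension (column) sharding assumption, and the usage line are mine.

```python
import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """All-gather shards along the last dim while keeping the autograd graph intact."""

    @staticmethod
    def forward(ctx, shard):
        world_size = dist.get_world_size()
        ctx.rank = dist.get_rank()
        ctx.shard_size = shard.shape[-1]
        gathered = [torch.empty_like(shard) for _ in range(world_size)]
        # dist.all_gather records no autograd history, hence this custom Function.
        dist.all_gather(gathered, shard.contiguous())
        return torch.cat(gathered, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank owns one column block of the gathered tensor, so its gradient
        # is the matching slice of grad_output.
        start = ctx.rank * ctx.shard_size
        return grad_output[..., start : start + ctx.shard_size].contiguous()


# Usage on each rank (hypothetical variable names):
# full_logits = AllGatherWithGrad.apply(local_logits_shard)
```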
Hi, thanks for the nice work!
I've been trying to optimize the performance of the TP wrapper here, and the first thing that came to mind was balancing out the compute on each rank using distributed / multiprocessing (as opposed to threading).
I've been wrapping my model like the following:
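The original snippet isn't included in this extract; a minimal sketch of that kind of wrapping, assuming the tensor_parallel package's `tp.tensor_parallel` entry point (the checkpoint and device list are placeholders), might look like:

```python
# Hypothetical sketch of the wrapping described above; checkpoint name and
# device list are placeholders, not the original code from this issue.
import transformers
import tensor_parallel as tp

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])  # one shard per listed device
```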
But it seems the final model output from the forward pass no longer requires grad, which makes it impossible to call `loss.backward()`. Code below to reproduce:
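The reproduction code isn't included in this extract either; a minimal sketch of what such a repro might look like, again assuming the tensor_parallel package and a small GPT-2 checkpoint (both placeholders, not the issue's original script):

```python
# Hypothetical repro sketch; the checkpoint, devices, and input shape are assumptions.
import torch
import transformers
import tensor_parallel as tp

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])

input_ids = torch.randint(0, 50257, (1, 16), device="cuda:0")  # 50257 = GPT-2 vocab size
outputs = model(input_ids=input_ids, labels=input_ids)

print(outputs.logits.requires_grad)  # reported as False: the gathered logits left the graph
outputs.loss.backward()              # RuntimeError: element 0 of tensors does not require grad
```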
Finally, pasting some of my system configs here