-
Notifications
You must be signed in to change notification settings - Fork 218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix DDP
with nf4
#1684
Fix DDP
with nf4
#1684
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1684
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 71caddb with merge base c8eb8d3 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Can confirm that this PR solves my issues reported in #1665 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM just need to fix lint
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM just need to fix lint
@weifengpy @drisspg
Fix #1665
TLDR: Implement
aten.cat.default
so thatNF4Tensor
can be used when usingDDP
.Overview
DDP
syncs params and buffers during__init__
. This dispatches to a call toaten.cat.default
with (potentially) a list of tensors with mixed dtypes ifnf4
tensors fall in the same bucket as regular tensors.Implementing
aten.cat.default
fixes this issue by unpacking thenf4
to their original tensors. Other operations post the sync are already implemented such that the synced modules can be properly reconstructed.Tests
Tests are located in
tests/dtypes/ddp
and can be run by executing therun_ddp_nf4_test.sh
script.This script does the following:
LoraLinear
model (ddp_nf4.py
) with world size 1 to generate a reference checkpoint.ddp_nf4.py
with world size 2 to generate test checkpoints.Example output:
ddp_nf4.py
can be parametrized: