[feat] support zbv in mixtral benchmark; (#6083)
* [feat] support zbv in mixtral benchmark;
* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;
* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
* [feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv
* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;
* [feat] Linear1D_Col/Row support zbv WeightGradStore;
* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;
* [fix] fix test case; moe error in second iter
* [feat] EPMixtralSparseMoeBlock (op in MOE) support zbv;
* [fix] fix bwd b; now bwd w runs only for layers replaced by Linear1D_Col/Row; other layers perform a full bwd;
* [fix] debug zbv llama test;
* [fix] rm use_zbv flag in Shardconfig; rm debug info;
* [fix] add & fix llama test
* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);
* [fix] fix failing case test_shard_llama
* [fix] fix test_shard_llama
* [fix] fix llama modeling policy;
* [fix] fix test_shard_llama ci;
* [fix] fix test zerobubble
* [fix] fix handle name; rm useless comments;
* [fix] fix send recv signature;
* [fix] fix comment in llama & benchmark
* [feat] support no tensor parallel Linear in shardformer; add tests for using and not using WeightGradStore
* [fix] fix linear (no tp) ops func name;
1 parent dac0e07; commit aed20fb. Showing 16 changed files with 940 additions and 153 deletions.
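The central mechanism behind these changes is splitting each sharded linear layer's backward pass into two phases: an immediate backward-for-inputs (bwd b) and a deferred backward-for-weights (bwd w) that is parked in WeightGradStore until the zero-bubble (zbv) schedule calls for it. The sketch below is a minimal illustration of that pattern using a custom autograd function; it is not ColossalAI's actual Linear1D_Col/Row implementation, and the names _LinearWithDeferredWeightGrad and _weight_grad are hypothetical. It assumes the WeightGradStore class added by this commit (shown in the diff further down) is in scope.

import torch

# Simplified sketch of the B/W split; not ColossalAI's actual code.
class _LinearWithDeferredWeightGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, use_zbv):
        ctx.save_for_backward(x, weight)
        ctx.use_zbv = use_zbv
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # B: the input gradient is needed by upstream layers right away.
        grad_input = grad_output @ weight
        if ctx.use_zbv:
            # W is deferred: record the tensors plus a weight-grad kernel
            # so the scheduler can run it later via WeightGradStore.pop().
            def _weight_grad(total_input, grad_out, grad=None):
                gw = grad_out.t() @ total_input
                if grad is None:
                    return gw   # first backward: caller assigns weight.grad
                grad.add_(gw)   # otherwise accumulate in place
            WeightGradStore.put(x, grad_output, weight, _weight_grad)
            grad_weight = None  # autograd leaves weight.grad untouched
        else:
            # Eager fallback: compute the weight gradient immediately.
            grad_weight = grad_output.t() @ x
        return grad_input, grad_weight, None

Returning None for the weight gradient keeps autograd from writing to weight.grad; the deferred kernel fills or accumulates it only when the schedule later reaches the W pass.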
@@ -0,0 +1,32 @@ (new file)

import queue


class WeightGradStore:
    """Buffers deferred weight-gradient computations for the zero-bubble
    (zbv) pipeline schedule: put() records the tensors needed for a weight
    gradient during backward-for-inputs, flush() seals the entries for the
    current micro-batch, and pop() runs the stored computations later
    during backward-for-weights."""

    cache = []
    # One queue per model chunk (the zbv schedule runs two virtual chunks
    # per pipeline stage).
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # Defer the eventual func(total_input, grad_output, weight.grad)
        # call; only record its arguments here.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Move everything recorded since the last flush into the queue
        # for this chunk, one list per micro-batch.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if weight.grad is not None:
                    # Accumulate into the existing gradient in place.
                    func(total_input, grad_output, weight.grad)
                else:
                    # First backward: weight.grad is None, so assign the
                    # returned grad_weight to weight.grad.
                    grad_weight = func(total_input, grad_output)
                    weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")
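To make the put/flush/pop lifecycle concrete, here is a small self-contained walkthrough of the class above; linear_weight_grad is an illustrative stand-in for the kernel a Linear1D_Col/Row layer would register, not an API from this commit.

import torch

# Illustrative weight-grad kernel matching the func signature put() expects:
# returns dW on the first backward, accumulates into weight.grad afterwards.
def linear_weight_grad(total_input, grad_output, grad=None):
    grad_weight = grad_output.t() @ total_input
    if grad is None:
        return grad_weight
    grad.add_(grad_weight)

weight = torch.nn.Parameter(torch.randn(4, 8))
x = torch.randn(2, 8)    # stands in for the saved forward input
dy = torch.randn(2, 4)   # stands in for the incoming output gradient

# During backward-for-inputs (B): defer the weight gradient.
WeightGradStore.put(x, dy, weight, linear_weight_grad)
WeightGradStore.flush(chunk=0)   # seal this micro-batch's entries

# Later, during backward-for-weights (W):
WeightGradStore.pop(chunk=0)
print(weight.grad.shape)         # torch.Size([4, 8])

Note that flush() partitions the cache by micro-batch: everything put() since the last flush() is retrieved by a single pop(), matching one W pass per micro-batch in the schedule.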