Skip to content

Commit

Permalink
[feat] support zbv in mixtral benchmark; (#6083)
Browse files Browse the repository at this point in the history
* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;
  • Loading branch information
duanjunwen authored Oct 31, 2024
1 parent dac0e07 commit aed20fb
Show file tree
Hide file tree
Showing 16 changed files with 940 additions and 153 deletions.
162 changes: 105 additions & 57 deletions colossalai/pipeline/schedule/zero_bubble_pp.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion colossalai/pipeline/stage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def distribute_layers(

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks

# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
Expand Down
32 changes: 32 additions & 0 deletions colossalai/pipeline/weight_grad_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import queue


class WeightGradStore:

cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]

@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))

@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []

@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
Expand All @@ -11,6 +11,7 @@
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"LinearWithGradAccum",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",
Expand Down
176 changes: 165 additions & 11 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import functools

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.pipeline.weight_grad_store import WeightGradStore

from .utils import is_share_sp_tp

try:
Expand Down Expand Up @@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
Expand All @@ -143,6 +148,13 @@ def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
Expand All @@ -164,24 +176,160 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None


class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""

@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)

return output

@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)

total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None, None, None

Expand Down Expand Up @@ -1043,12 +1191,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
)


def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)


def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)


def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):
Expand Down
Loading

0 comments on commit aed20fb

Please sign in to comment.