[Feature] ZeroBubble support for MoeHybridPlugin #6077

Merged · 6 commits · Oct 9, 2024
Changes from all commits
11 changes: 8 additions & 3 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -295,8 +295,11 @@ def __init__(
        if self.pp_size > 1:
            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
            assert (
-               pp_style == "interleaved" or pp_style == "zbv"
-           ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+               pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+           ), "num_model_chunks must be 1 when using 1f1b"
+           assert (
+               pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+           ), "num_model_chunks must be 2 when using zero bubble pipeline"
            assert (
                num_microbatches is not None or microbatch_size is not None
            ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -309,6 +312,7 @@ def __init__(
                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                num_model_chunks=num_model_chunks,
                num_layers_per_stage=num_layers_per_stage,
+               use_zbv=(pp_style == "zbv"),
            )

if pp_style == "interleaved":
Expand All @@ -329,7 +333,8 @@ def __init__(
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
self.schedule = ZeroBubbleVPipeScheduler(
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
self.scheduler = ZeroBubbleVPipeScheduler(
schedule=scheduler_nodes,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
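For orientation, the sketch below shows how this new zbv path is driven end to end; it mirrors the plugin arguments used in the test file further down. The concrete values (4 stages, toy memory costs) are illustrative assumptions, and the snippet must run inside an initialized distributed context (colossalai.launch) with enough ranks.

```python
# A minimal sketch, assuming a 4-rank distributed context is already launched.
# Argument names follow MoeHybridParallelPlugin / PipelineGraph as used in this PR.
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.pipeline.schedule.v_schedule import PipelineGraph

pp_size = num_microbatches = 4
graph = PipelineGraph(
    n_stage=pp_size,
    n_micro=num_microbatches,
    f_cost=1, b_cost=1, w_cost=1, c_cost=1,  # relative runtimes of fwd / bwd-input / bwd-weight / comm
    f_mem=2, b_mem=-1, w_mem=-1,  # toy deltas: forward allocates, the two backward phases free
)
plugin = MoeHybridParallelPlugin(
    pp_size=pp_size,
    num_microbatches=num_microbatches,
    pp_style="zbv",  # the new assertions above require exactly num_model_chunks == 2 here
    num_model_chunks=2,
    scheduler_nodes=graph.get_v_schedule(),
    ep_size=1,
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)
# then: model, optimizer, ... = booster.boost(model, optimizer)
# and: booster.execute_pipeline(data_iter, model, criterion, optimizer, return_loss=True)
```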
30 changes: 23 additions & 7 deletions colossalai/shardformer/policies/mixtral.py
@@ -258,14 +258,30 @@ def get_held_layers(self) -> List[Module]:
        stage_manager = self.pipeline_stage_manager

        held_layers = []
-       layers_per_stage = stage_manager.distribute_layers(len(module.layers))
-       if stage_manager.is_first_stage():
-           held_layers.append(module.embed_tokens)
-       start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
-       held_layers.extend(module.layers[start_idx:end_idx])
-       if stage_manager.is_last_stage():
-           held_layers.append(module.norm)
+       if stage_manager.is_interleave:
+           assert stage_manager.num_model_chunks is not None
+           layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+           stage_indices = stage_manager.get_stage_index(layers_per_stage)
+           stage_manager.stage_indices = stage_indices
+           if stage_manager.is_first_stage(ignore_chunk=True):
+               held_layers.append(module.embed_tokens)
+           for start_idx, end_idx in stage_indices:
+               held_layers.extend(module.layers[start_idx:end_idx])
+           if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+               not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+           ):
+               # for zbv, norm goes to the first stage, which runs the last forward chunk
+               # for interleaved, norm goes to the last stage, which runs the last forward chunk
+               held_layers.append(module.norm)
+       else:
+           layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+           if stage_manager.is_first_stage():
+               held_layers.append(module.embed_tokens)
+           start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+           held_layers.extend(module.layers[start_idx:end_idx])
+           if stage_manager.is_last_stage():
+               held_layers.append(module.norm)
        return held_layers


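The zbv branch above is easier to see with concrete numbers. Below is a toy reproduction of the V-shaped placement, assuming the usual layout in which chunk 0 walks down the stages and chunk 1 walks back up; `v_assignment` is a hypothetical helper for illustration, not the library's API.

```python
# Hypothetical illustration of V-shaped chunk placement (not ColossalAI API).
def v_assignment(num_layers: int, pp_size: int) -> dict:
    per = num_layers // (2 * pp_size)  # layers per (stage, chunk), assuming an even split
    held = {s: [] for s in range(pp_size)}
    for s in range(pp_size):
        held[s] += list(range(s * per, (s + 1) * per))  # chunk 0: stages 0 -> P-1
        start = num_layers - (s + 1) * per
        held[s] += list(range(start, start + per))  # chunk 1: stages P-1 -> 0
    return held

print(v_assignment(num_layers=8, pp_size=4))
# {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
# Stage 0 holds the first layer (next to embed_tokens) and the last layer (next to
# norm), so the model's final forward chunk finishes on the *first* stage; that is
# why get_held_layers appends module.norm on is_first_stage when use_zbv is set.
```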
280 changes: 171 additions & 109 deletions tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -7,17 +7,30 @@
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 8
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
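# NOTE: these sizes make the model under test deliberately tiny: hidden_size =
# HIDDEN_SIZE_PER_HEAD * NUM_HEADS = 16, 8 decoder layers, 4 experts with top-1
# routing; just enough structure to split into two chunks across four pp stages.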


class MlpModel(nn.Module):
@@ -730,127 +743,176 @@ def criterion_base(x, *args, **kwargs):
    assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)


-# TODO:4) support Hybrid base 3)
+# TODO:3) support booster & Hybrid base 2)
def run_with_hybridplugin(test_config):
    pass


-# TODO:5) support MoEHybrid base 3)
+# TODO:4) support booster & MoEHybrid base 2)
@parameterize(
-    "test_config",
+    "config",
    [
-        {
-            "pp_style": "zbv",
-            "tp_size": 1,
-            "ep_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "precision": "bf16",
-            "num_model_chunks": 2,
-        },
+        (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
    ],
)
-def run_with_moehybridplugin(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
-    # test_config["use_lazy_init"] = False
-    test_config["initial_scale"] = 2**16
-    model_list = [
-        "transformers_bert",
-    ]
-    clear_layout_converter()
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name in model_list:
-            # base param
-            model = model_fn()
-            data = data_gen_fn()
-            print(f"data {data}")
-            criterion = loss_fn
-            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
-
-            output = model(**data)
-            loss = criterion(output)
-            loss.backward()
-            optimizer.step()
-            print(f"output {output}")
-
-            # # pp param
-            # model_pp = deepcopy(model)
-            # data_pp = deepcopy(data)
-            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
-
-            # # init pipeline graph
-            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
-            # mem_f = 34 * h + 5 * a * s
-            # mem_w = -32 * h
-            # mem_b = -mem_w - mem_f
-            # graph = PipelineGraph(
-            #     n_stage=test_config["pp_size"],
-            #     n_micro=test_config["num_microbatches"],
-            #     f_cost=1,
-            #     b_cost=1,
-            #     w_cost=1,
-            #     c_cost=1,
-            #     f_mem=mem_f,
-            #     b_mem=mem_b,
-            #     w_mem=mem_w,
-            #     # max_mem=mem_f * (p * 2 + m_offset),
-            # )
-
-            # zbv_schedule = graph.get_v_schedule()
-
-            # test_config["scheduler_nodes"] = zbv_schedule
-            # plugin = MoeHybridParallelPlugin(
-            #     **test_config
-            # )
-            # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
-            #     model = model_pp,
-            #     optimizer = optimizer_pp,
-            #     criterion = criterion,
-            #     dataloader = data_pp,
-            # )
-
-            # output_pp = plugin.execute_pipeline(
-            #     data_iter=iter(data),
-            #     model=model,
-            #     criterion=criterion,
-            #     optimizer=optimizer,
-            #     return_loss = True,
-            #     return_outputs = True,
-            # )
-
-
-# TODO:6) support booster & Hybrid base 4)
-
-
-# TODO:7) support booster & MoEHybrid base 4)
-@parameterize(
-    "test_config",
-    [
-        {
-            "pp_style": "zbv",
-            "tp_size": 1,
-            "ep_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "precision": "bf16",
-            "num_model_chunks": 2,
-        },
-    ],
-)
-def run_with_booster_moehybridplugin(test_config):
-    pass
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
    test_config = config
    stage, ep_size, pp_size, tp_size, sp_size = config
    num_microbatches = pp_size
    dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    ########
    # init base model
    ########
    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )
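    # NOTE: flash_attention_2 supports only fp16/bf16 inputs, which is why dtype and
    # precision are pinned to torch.float16 / "fp16" above before the model is built.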

    # init model with the same seed
    seed_all(10086)

    torch_model = MixtralModel(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    # init schedule
    h, a, s = config.hidden_size, config.num_attention_heads, 1024
    mem_f = 34 * h + 5 * a * s
    mem_w = -32 * h
    mem_b = -mem_w - mem_f
    graph = PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatches,
        f_cost=1,
        b_cost=1,
        w_cost=1,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
        # max_mem=mem_f * (p * 2 + m_offset),
    )

    zbv_schedule = graph.get_v_schedule()
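    # NOTE: the v-schedule search only cares about the ratios of these costs. The
    # memory terms use the rough per-microbatch activation estimate above
    # (mem_f = 34*h + 5*a*s), and the negative b/w terms encode that the two
    # backward phases release what the forward pass allocated (mem_b = -mem_w - mem_f).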

    # init MoeHybridPlugin
    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
        pp_style="zbv",
        scheduler_nodes=zbv_schedule,
        num_model_chunks=2,
    )

    dp_size = plugin.dp_size

    booster = Booster(plugin=plugin)

    ########
    # init pp model
    ########

    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # gen random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # pp inputs for stages other than the first don't matter, but they must be replicated for the torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicates the input
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicates the input

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x.last_hidden_state.mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )
            # stage 0 chunk 0
            parallel_output = None
            if (
                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
            ):
                parallel_output = sharded_output["loss"]
            else:
                parallel_output = torch.tensor(12345.0, device="cuda")
            # broadcast along pp axis
            dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)

        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run the unsharded model on all dp inputs (each rank drew a different one)
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # average the dp grads, mirroring what the zero optimizer does
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()
        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
    print(f"rank {dist.get_rank()} config {test_config} test passed")
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # run_fwd_bwd_iter_input()
-   run_fwd_bwd_vschedule_with_optim()
-   # run_with_moehybridplugin()
-   # run_with_booster_moehybridplugin()
+   # run_fwd_bwd_vschedule_with_optim()
+   run_with_booster_moehybridplugin()


@pytest.mark.dist
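The tail of the file is collapsed in this view. A sketch of the entry point, under the assumption that it follows the standard ColossalAI test pattern (the decorators and spawn helper are imported at the top of the file; the function name and rank count are guesses, not part of the diff):

```python
# Assumed completion of the collapsed test entry point; not verbatim from the PR.
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
    spawn(run_dist, nprocs=4)  # 4 ranks cover pp_size=4 and the ep/tp/sp config tuples


if __name__ == "__main__":
    test_pp()
```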