diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9548920a8699..8b62a1e2bd8c 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -295,8 +295,11 @@ def __init__(
         if self.pp_size > 1:
             assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
             assert (
-                pp_style == "interleaved" or pp_style == "zbv"
-            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -309,6 +312,7 @@ def __init__(
                 enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
+                use_zbv=(pp_style == "zbv"),
             )

             if pp_style == "interleaved":
@@ -329,7 +333,8 @@ def __init__(
                     enable_metadata_cache=enable_metadata_cache,
                 )
             elif pp_style == "zbv":
-                self.schedule = ZeroBubbleVPipeScheduler(
+                assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZeroBubbleV"
+                self.scheduler = ZeroBubbleVPipeScheduler(
                     schedule=scheduler_nodes,
                     stage_manager=self.stage_manager,
                     num_model_chunks=num_model_chunks,
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 8e2ca5de0556..af5b15ed5d20 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -258,14 +258,30 @@ def get_held_layers(self) -> List[Module]:
         stage_manager = self.pipeline_stage_manager

         held_layers = []
-        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
-        if stage_manager.is_first_stage():
-            held_layers.append(module.embed_tokens)
-        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
-        held_layers.extend(module.layers[start_idx:end_idx])
-        if stage_manager.is_last_stage():
-            held_layers.append(module.norm)
+        if stage_manager.is_interleave:
+            assert stage_manager.num_model_chunks is not None
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
+            stage_manager.stage_indices = stage_indices
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                held_layers.append(module.embed_tokens)
+            for start_idx, end_idx in stage_indices:
+                held_layers.extend(module.layers[start_idx:end_idx])
+            if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+                not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+            ):
+                # for zbv, when is_first_stage (last fwd), we append norm
+                # for interleaved, when is_last_stage (last fwd), we also append norm
+                held_layers.append(module.norm)
+        else:
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            if stage_manager.is_first_stage():
+                held_layers.append(module.embed_tokens)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+            held_layers.extend(module.layers[start_idx:end_idx])
+            if stage_manager.is_last_stage():
+                held_layers.append(module.norm)

         return held_layers
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 0f2d6c49c749..384ed649055c 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -7,17 +7,30 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel

 import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import assert_loose_close
+
+NUM_BATCH = 8
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_LAYERS = 8
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1


 class MlpModel(nn.Module):
@@ -730,127 +743,176 @@ def criterion_base(x, *args, **kwargs):
     assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)


-# TODO:4) support Hybrid base 3)
+# TODO:3) support booster & Hybrid base 2)
 def run_with_hybridplugin(test_config):
     pass


-# TODO:5) support MoEHybrid base 3)
+# TODO:4) support booster & MoEHybrid base 2)
 @parameterize(
-    "test_config",
+    "config",
     [
-        {
-            "pp_style": "zbv",
-            "tp_size": 1,
-            "ep_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "precision": "bf16",
-            "num_model_chunks": 2,
-        },
+        (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        (1, 2, 1, 2, 1),
+        (1, 2, 1, 1, 2),
     ],
 )
-def run_with_moehybridplugin(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
-    # test_config["use_lazy_init"] = False
-    test_config["initial_scale"] = 2**16
-    model_list = [
-        "transformers_bert",
-    ]
-    clear_layout_converter()
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name in model_list:
-            # base param
-            model = model_fn()
-            data = data_gen_fn()
-            print(f"data {data}")
-            criterion = loss_fn
-            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
-
-            output = model(**data)
-            loss = criterion(output)
-            loss.backward()
-            optimizer.step()
-            print(f"output {output}")
-
-            # # pp param
-            # model_pp = deepcopy(model)
-            # data_pp = deepcopy(data)
-            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
-
-            # # init pipeline graph
-            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
-            # mem_f = 34 * h + 5 * a * s
-            # mem_w = -32 * h
-            # mem_b = -mem_w - mem_f
-            # graph = PipelineGraph(
-            #     n_stage=test_config["pp_size"],
-            #     n_micro=test_config["num_microbatches"],
-            #     f_cost=1,
-            #     b_cost=1,
-            #     w_cost=1,
-            #     c_cost=1,
-            #     f_mem=mem_f,
-            #     b_mem=mem_b,
-            #     w_mem=mem_w,
-            #     # max_mem=mem_f * (p * 2 + m_offset),
-            # )
-
-            # zbv_schedule = graph.get_v_schedule()
-
-            # test_config["scheduler_nodes"] = zbv_schedule
-            # plugin = MoeHybridParallelPlugin(
-            #     **test_config
-            # )
-            # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
-            #     model = model_pp,
-            #     optimizer = optimizer_pp,
-            #     criterion = criterion,
-            #     dataloader = data_pp,
-            # )
-
-            # output_pp = plugin.execute_pipeline(
-            #     data_iter=iter(data),
-            #     model=model,
-            #     criterion=criterion,
-            #     optimizer=optimizer,
-            #     return_loss = True,
-            #     return_outputs = True,
-            # )
-
-
-# TODO:6) support booster & Hybrid base 4)
-
-
-# TODO:7) support booster & MoEHybrid base 4)
-@parameterize(
-    "test_config",
-    [
-        {
-            "pp_style": "zbv",
-            "tp_size": 1,
-            "ep_size": 1,
-            "pp_size": 4,
-            "num_microbatches": 4,
-            "zero_stage": 1,
-            "precision": "bf16",
-            "num_model_chunks": 2,
-        },
-    ],
-)
-def run_with_booster_moehybridplugin(test_config):
-    pass
+def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
+    test_config = config
+    stage, ep_size, pp_size, tp_size, sp_size = config
+    num_microbatches = pp_size
+    dist.get_world_size()
+    rank = dist.get_rank()
+    dtype, precision = torch.float16, "fp16"
+    torch.cuda.set_device(dist.get_rank())
+
+    ########
+    # init base model
+    ########
+    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+    config = MixtralConfig(
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=NUM_LAYERS,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        num_local_experts=NUM_EXPERTS,
+        num_experts_per_tok=TOP_K,
+        attn_implementation="flash_attention_2",
+    )
+
+    # init model with the same seed
+    seed_all(10086)
+
+    torch_model = MixtralModel(config).to(dtype).cuda()
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+    # init schedule
+    h, a, s = config.hidden_size, config.num_attention_heads, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=pp_size,
+        n_micro=num_microbatches,
+        f_cost=1,
+        b_cost=1,
+        w_cost=1,
+        c_cost=1,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+        # max_mem=mem_f * (p * 2 + m_offset),
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    # init MoeHybridPlugin
+    plugin = MoeHybridParallelPlugin(
+        pp_size=pp_size,
+        num_microbatches=pp_size,
+        tp_size=tp_size,
+        sp_size=sp_size,
+        ep_size=ep_size,
+        zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        overlap_communication=False,
+        initial_scale=1,
+        precision=precision,
+        find_unused_parameters=True,
+        pp_style="zbv",
+        scheduler_nodes=zbv_schedule,
+        num_model_chunks=2,
+    )
+
+    dp_size = plugin.dp_size
+
+    booster = Booster(plugin=plugin)
+
+    ########
+    # init pp model
+    ########
+
+    parallel_model = deepcopy(torch_model)
+    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+    # create different input along dp axis
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    parallel_model.train()
+    for _ in range(2):
+        # gen random input
+        input_embeddings = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(
+            input_embeddings, group=plugin.pp_group
+        )  # pp inputs except for the first stage don't matter, but they need to be replicated for the torch model check
+
+        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
+        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input
+
+        # run the model with hybrid parallel
+        if booster.plugin.stage_manager is not None:
+            # for test with pp
+            data_iter = iter([{"inputs_embeds": input_embeddings}])
+            sharded_output = booster.execute_pipeline(
+                data_iter,
+                parallel_model,
+                lambda x, y: x.last_hidden_state.mean(),
+                parallel_optimizer,
+                return_loss=True,
+                return_outputs=True,
+            )
+            # stage 0 chunk 0
+            parallel_output = None
+            if (
+                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
+                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
+            ):
+                parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+            # broadcast along pp axis
+            dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
+
+        else:
+            # for test without pp
+            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
+            parallel_optimizer.backward(parallel_output)
+        parallel_optimizer.step()
+        parallel_optimizer.zero_grad()
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+        # ===================================================================================
+        # run normal model with all dp(different) inputs
+        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # avg dp grads follows zero optimizer
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dp_size
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    print(f"rank {dist.get_rank()} config {test_config} test passed")
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()


 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    # run_fwd_bwd_iter_input()
-    run_fwd_bwd_vschedule_with_optim()
-    # run_with_moehybridplugin()
-    # run_with_booster_moehybridplugin()
+    # run_fwd_bwd_vschedule_with_optim()
+    run_with_booster_moehybridplugin()


 @pytest.mark.dist