diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index a34a60669031..2cad504f3391 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 03f9c53f1d28..ae1a5275e5da 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -12,7 +12,7 @@ jobs: if: github.repository == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny timeout-minutes: 90 steps: @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml index 02e30f52a459..bba321fd2d59 100644 --- a/.github/workflows/example_check_on_dispatch.yml +++ b/.github/workflows/example_check_on_dispatch.yml @@ -45,9 +45,9 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: 📚 Checkout uses: actions/checkout@v3 diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 6d6952aa169a..fcff8e569ff7 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -77,7 +77,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/examples-data:/data/ timeout-minutes: 20 concurrency: diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml index 919fa5092a6c..abb9479492e7 100644 --- a/.github/workflows/example_check_on_schedule.yml +++ b/.github/workflows/example_check_on_schedule.yml @@ -34,7 +34,7 @@ jobs: fail-fast: false matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}} container: - image: hpcaitech/pytorch-cuda:2.0.0-11.7.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 timeout-minutes: 10 steps: - name: 📚 Checkout diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml index f9e9f400962e..bb0ceb4a8296 100644 --- a/.github/workflows/run_chatgpt_examples.yml +++ b/.github/workflows/run_chatgpt_examples.yml @@ -18,7 +18,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml index ec5c8ffa319f..7986889e006b 100644 --- a/.github/workflows/run_chatgpt_unit_tests.yml +++ b/.github/workflows/run_chatgpt_unit_tests.yml @@ -20,7 +20,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt timeout-minutes: 30 defaults: diff --git a/.github/workflows/run_colossalqa_unit_tests.yml b/.github/workflows/run_colossalqa_unit_tests.yml index 763db277289f..00944b92d9b6 100644 --- a/.github/workflows/run_colossalqa_unit_tests.yml +++ b/.github/workflows/run_colossalqa_unit_tests.yml @@ -19,7 +19,7 @@ jobs: github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' runs-on: [self-hosted, gpu] container: - image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 + image: hpcaitech/pytorch-cuda:2.1.0-12.1.0 volumes: - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa - /data/scratch/llama-tiny:/data/scratch/llama-tiny @@ -51,4 +51,4 @@ jobs: TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/* - SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path \ No newline at end of file + SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path diff --git a/MANIFEST.in b/MANIFEST.in index ad26b634ac3e..f0a5611efc7d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include *.txt README.md recursive-include requirements *.txt recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi -recursive-include op_builder *.py +recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/README.md b/README.md index 13757eece7db..3963fe2fb5d6 100644 --- a/README.md +++ b/README.md @@ -398,10 +398,10 @@ pip install colossalai **Note: only Linux is supported for now.** -However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`. +However, if you want to build the PyTorch extensions during installation, you can set `BUILD_EXT=1`. ```bash -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` **Otherwise, CUDA kernels will be built during runtime when you actually need them.** @@ -429,7 +429,7 @@ By default, we do not compile CUDA/C++ kernels. ColossalAI will build them durin If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer): ```shell -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory. @@ -445,7 +445,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ```

(back to top)

diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index c0e257f54a07..e67e16231cc2 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -49,12 +49,13 @@ def _preprocess( max_length: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Preprocess the data by tokenizing.""" - sequences = [s + t for s, t in zip(sources, targets)] + sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)] sequences_token = tokenizer( - sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False ) + sources_token = tokenizer( - sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False ) assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently" @@ -65,7 +66,8 @@ def _preprocess( if tokenizer.padding_side == "right": # |prompt|completion|eos|pad| labels[i][:source_len] = IGNORE_INDEX - labels[i][-pad_len:] = IGNORE_INDEX + if pad_len>0: + labels[i][-pad_len:] = IGNORE_INDEX elif tokenizer.padding_side == "left": # |pad|prompt|completion|eos| labels[i][: pad_len + source_len] = IGNORE_INDEX diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index d6966689885e..330e4e0e395e 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -10,7 +10,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .base import OnPolicyTrainer from .callbacks import Callback @@ -105,7 +105,7 @@ def __init__( self.critic_optim = critic_optim self.offload_inference_models = offload_inference_models - self.device = get_current_device() + self.device = get_accelerator().get_current_device() def _before_fit( self, diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 7129edb060ef..95f01678640c 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -6,7 +6,6 @@ import colossalai from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.utils import get_current_device from colossalai.zero.gemini.gemini_ddp import GeminiDDP from .ddp import DDPStrategy @@ -158,9 +157,19 @@ def __init__( warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") + # colossalai has changed api for get_current_device in 0.3.4 version or newer + try: + from colossalai.accelerator import get_accelerator + + chunk_init_device = get_accelerator().get_current_device() + except: + from colossalai.utils import get_current_device + + chunk_init_device = get_current_device() + # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( - chunk_init_device=get_current_device(), + chunk_init_device=chunk_init_device, placement_policy=placement_policy, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh index 0fb4da3d3ce8..b7d176847d9c 100755 --- a/applications/Chat/examples/train_sft.sh +++ b/applications/Chat/examples/train_sft.sh @@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \ --accumulation_steps 8 \ --lr 2e-5 \ --max_datasets_size 512 \ - --max_epochs 1 + --max_epochs 1 \ No newline at end of file diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py index a2cfb2ef6264..327651f4e645 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py @@ -1,20 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import numpy as np import os -import random from dataclasses import dataclass -from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable +from typing import Dict, Iterator, List, Optional, Sequence, Union import torch -from datasets import dataset_dict, load_from_disk +import torch.nn.functional as F from datasets import Dataset as HFDataset -from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler +from datasets import dataset_dict, load_from_disk +from torch.utils.data import ConcatDataset, Dataset, DistributedSampler from transformers.tokenization_utils import PreTrainedTokenizer -import torch.nn.functional as F DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset] PathType = Union[str, os.PathLike] @@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object): tokenizer: PreTrainedTokenizer max_length: int = 4096 ignore_index: int = -100 + padding: str = "max_length" def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]: """ @@ -106,10 +103,11 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch batch_first=True, padding_value=self.ignore_index, ) # (bsz, max_len) - # pad to max - to_pad = self.max_length - input_ids.size(1) - input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) - labels = F.pad(labels, (0, to_pad), value=self.ignore_index) + if self.padding == "max_length": + # pad to max + to_pad = self.max_length - input_ids.size(1) + input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id) + labels = F.pad(labels, (0, to_pad), value=self.ignore_index) elif self.tokenizer.padding_side == "left": reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids] reversed_input_ids = torch.nn.utils.rnn.pad_sequence( @@ -171,49 +169,3 @@ def __len__(self) -> int: def set_start_index(self, start_index: int) -> None: self.start_index = start_index - - -def setup_distributed_dataloader( - dataset: DatasetType, - batch_size: int = 1, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None, - process_group: Optional[ProcessGroup] = None, - **kwargs, -) -> DataLoader: - """ - Setup dataloader for distributed training. - """ - _kwargs = kwargs.copy() - process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler( - dataset=dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle, - seed=seed, - drop_last=drop_last, - ) - - # Deterministic dataloader - def seed_worker(worker_id: int) -> None: - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - worker_init_fn=seed_worker, - **_kwargs, - ) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 1926ec78aba8..6c048c3b18cf 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import math from types import MethodType from typing import Optional, Tuple import torch +import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func -from flash_attn.ops.rms_norm import rms_norm +from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -19,194 +19,334 @@ repeat_kv, ) +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger logger = get_dist_logger() +if get_accelerator().name == "cuda": + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func + from flash_attn.ops.rms_norm import rms_norm -def _prepare_decoder_attention_mask( - self: LlamaModel, - attention_mask: torch.BoolTensor, - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, -) -> Optional[torch.Tensor]: - """ - Decoder attetion mask - """ - if past_key_values_length > 0 and attention_mask is not None: - attention_mask = torch.cat( - tensors=( - torch.full( - size=(input_shape[0], past_key_values_length), - fill_value=True, - dtype=attention_mask.dtype, - device=attention_mask.device, + def _prepare_decoder_attention_mask( + self: LlamaModel, + attention_mask: torch.BoolTensor, + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ) -> Optional[torch.Tensor]: + """ + Decoder attetion mask + """ + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + tensors=( + torch.full( + size=(input_shape[0], past_key_values_length), + fill_value=True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, ), - attention_mask, - ), - dim=-1, - ) # (bsz, past_key_values_length + q_len) - if attention_mask is not None and torch.all(attention_mask): - return None # Faster - return attention_mask - - -def attention_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. - """ - if output_attentions: - logger.warning( - "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " - "return `None` instead." - ) - - bsz, q_len, _ = hidden_states.size() + dim=-1, + ) # (bsz, past_key_values_length + q_len) + if attention_mask is not None and torch.all(attention_mask): + return None # Faster + return attention_mask - if self.config.pretraining_tp > 1: - q_slicing, kv_slicing = ( - dim // self.config.pretraining_tp - for dim in ( - self.num_heads * self.head_dim, - self.num_key_value_heads * self.head_dim, + def attention_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. + """ + if output_attentions: + logger.warning( + "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, " + "return `None` instead." ) - ) # `Tuple[int, int]` - q_slices, k_slices, v_slices = ( - proj.weight.split(slicing, dim=0) - for proj, slicing in ( - (self.q_proj, q_slicing), - (self.k_proj, kv_slicing), - (self.v_proj, kv_slicing), + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + q_slicing, kv_slicing = ( + dim // self.config.pretraining_tp + for dim in ( + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ) + ) # `Tuple[int, int]` + q_slices, k_slices, v_slices = ( + proj.weight.split(slicing, dim=0) + for proj, slicing in ( + (self.q_proj, q_slicing), + (self.k_proj, kv_slicing), + (self.v_proj, kv_slicing), + ) + ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + q, k, v = ( + torch.cat( + [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], + dim=-1, + ) + for slices in (q_slices, k_slices, v_slices) ) - ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]] + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + else: + q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) + # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: + # (bsz, q_len, num_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim), + # (bsz, q_len, num_key_value_heads * head_dim) + + # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); + # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) q, k, v = ( - torch.cat( - [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)], - dim=-1, + states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) + for states, num_heads in ( + (q, self.num_heads), + (k, self.num_key_value_heads), + (v, self.num_key_value_heads), ) - for slices in (q_slices, k_slices, v_slices) - ) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - else: - q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj)) - # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape: - # (bsz, q_len, num_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim), - # (bsz, q_len, num_key_value_heads * head_dim) - - # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim); - # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim) - q, k, v = ( - states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2) - for states, num_heads in ( - (q, self.num_heads), - (k, self.num_key_value_heads), - (v, self.num_key_value_heads), ) - ) - kv_len = k.shape[-2] # initially, `kv_len` == `q_len` - past_kv_len = 0 - if past_key_value is not None: - # if `past_key_value` is not None, `kv_len` > `q_len`. - past_kv_len = past_key_value[0].shape[-2] - kv_len += past_kv_len - - # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) - cos, sin = self.rotary_emb(v, seq_len=kv_len) - # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) - q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) - if past_key_value is not None: - # reuse k, v, self_attention - k = torch.cat([past_key_value[0], k], dim=2) - v = torch.cat([past_key_value[1], v], dim=2) - - past_key_value = (k, v) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) - # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) - - key_padding_mask = attention_mask - # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) - q, k, v = (states.transpose(1, 2) for states in (q, k, v)) - - if past_kv_len > 0: - q = torch.cat( - tensors=( - torch.full( - size=(bsz, past_kv_len, self.num_heads, self.head_dim), - fill_value=0.0, - dtype=q.dtype, - device=q.device, + kv_len = k.shape[-2] # initially, `kv_len` == `q_len` + past_kv_len = 0 + if past_key_value is not None: + # if `past_key_value` is not None, `kv_len` > `q_len`. + past_kv_len = past_key_value[0].shape[-2] + kv_len += past_kv_len + + # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim) + cos, sin = self.rotary_emb(v, seq_len=kv_len) + # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim) + q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + k = torch.cat([past_key_value[0], k], dim=2) + v = torch.cat([past_key_value[1], v], dim=2) + + past_key_value = (k, v) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups) + # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim) + + key_padding_mask = attention_mask + # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) + q, k, v = (states.transpose(1, 2) for states in (q, k, v)) + + if past_kv_len > 0: + q = torch.cat( + tensors=( + torch.full( + size=(bsz, past_kv_len, self.num_heads, self.head_dim), + fill_value=0.0, + dtype=q.dtype, + device=q.device, + ), + q, ), - q, - ), - dim=1, - ) # (bsz, past_kv_len + q_len, num_heads, head_dim) - - if key_padding_mask is None: - # (bsz, past_kv_len + q_len, num_heads, head_dim) - output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) - output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim) - else: - q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) - kv, _, cu_kv_lens, max_kv_len = unpad_input( - hidden_states=torch.stack(tensors=(k, v), dim=2), - attention_mask=key_padding_mask, - ) - output_unpad = flash_attn_varlen_kvpacked_func( - q=q, - kv=kv, - cu_seqlens_q=cu_q_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_q_len, - max_seqlen_k=max_kv_len, - dropout_p=0.0, - softmax_scale=None, - causal=True, - ) - output = pad_input( - hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), - indices=indices, - batch=bsz, - seqlen=past_kv_len + q_len, - ) # (bsz, past_kv_len + q_len, num_heads * head_dim) - - if past_kv_len > 0: - # Strip off the zero query outputs. - output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) - output = self.o_proj(output) # (bsz, q_len, hidden_size) - return output, None, past_key_value - - -def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Formard function for RMS Norm - """ - return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) - - -def replace_with_flash_attention(model: LlamaForCausalLM) -> None: - for name, module in model.named_modules(): - if isinstance(module, LlamaAttention): - module.forward = MethodType(attention_forward, module) - if isinstance(module, LlamaModel): - module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) - if isinstance(module, LlamaRMSNorm): - module.forward = MethodType(rms_norm_forward, module) + dim=1, + ) # (bsz, past_kv_len + q_len, num_heads, head_dim) + + if key_padding_mask is None: + # (bsz, past_kv_len + q_len, num_heads, head_dim) + output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, ) + output = rearrange( + output, pattern="... h d -> ... (h d)" + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + else: + q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask) + kv, _, cu_kv_lens, max_kv_len = unpad_input( + hidden_states=torch.stack(tensors=(k, v), dim=2), + attention_mask=key_padding_mask, + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_q=cu_q_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_q_len, + max_seqlen_k=max_kv_len, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + output = pad_input( + hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"), + indices=indices, + batch=bsz, + seqlen=past_kv_len + q_len, + ) # (bsz, past_kv_len + q_len, num_heads * head_dim) + + if past_kv_len > 0: + # Strip off the zero query outputs. + output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim) + output = self.o_proj(output) # (bsz, q_len, hidden_size) + return output, None, past_key_value + + def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Formard function for RMS Norm + """ + return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon) + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.forward = MethodType(attention_forward, module) + if isinstance(module, LlamaModel): + module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module) + if isinstance(module, LlamaRMSNorm): + module.forward = MethodType(rms_norm_forward, module) + +elif get_accelerator().name == "npu": + import torch_npu + + class NPULlamaAttention(LlamaAttention): + use_flash: bool = True + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.setup() + + def setup(self): + self._softmax_scale = 1 / math.sqrt(self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if not self.use_flash: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + else: + attn_output, *_ = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + self.num_heads, + "BNSD", + atten_mask=attention_mask.bool(), + scale=self._softmax_scale, + padding_mask=None, + pre_tockens=65535, + next_tockens=0, + keep_prob=1.0, + inner_precise=0, + ) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum( + [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class NPURMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0] + + def replace_with_flash_attention(model: LlamaForCausalLM) -> None: + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + module.__class__ = NPULlamaAttention + module.setup() + if isinstance(module, LlamaRMSNorm): + module.__class__ = NPURMSNorm diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 9f6c9c1cc6f3..21d769f3c49f 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -17,7 +17,7 @@ def unwrap(model): if hasattr(model, "module"): - return unwrap_model(model.module) + return model.unwrap() else: return model diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py index 123290d45eab..77e18d8b5939 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA-2/inference_example.py @@ -1,17 +1,16 @@ import argparse -import os import torch +from colossal_llama2.dataset.conversation import default_conversation +from transformers import AutoModelForCausalLM, AutoTokenizer + from colossalai.logging import get_dist_logger -from transformers import AutoTokenizer, AutoModelForCausalLM logger = get_dist_logger() def load_model(model_path, device="cuda", **kwargs): - logger.info( - "Please check whether the tokenizer and model weights are properly stored in the same folder." - ) + logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.") model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs) model.to(device) @@ -27,31 +26,50 @@ def load_model(model_path, device="cuda", **kwargs): def generate(args): model, tokenizer = load_model(model_path=args.model_path, device=args.device) - BASE_INFERENCE_SUFFIX = "\n\n->\n\n" - input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}" - - inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device) - output = model.generate(**inputs, - max_new_tokens=args.max_new_tokens, - do_sample=args.do_sample, - temperature=args.temperature, - top_k=args.top_k, - top_p=args.top_p, - num_return_sequences=1) - response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):] + if args.prompt_style == "sft": + conversation = default_conversation.copy() + conversation.append_message("Human", args.input_txt) + input_txt = conversation.get_prompt() + else: + BASE_INFERENCE_SUFFIX = "\n\n->\n\n" + input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}" + + inputs = tokenizer(input_txt, return_tensors="pt").to(args.device) + num_input_tokens = inputs["input_ids"].shape[-1] + output = model.generate( + **inputs, + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + num_return_sequences=1, + ) + response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True) logger.info(f"Question: {input_txt} \n\n Answer: \n{response}") return response if __name__ == "__main__": parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.") - parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model") - parser.add_argument('--device', type=str, default="cuda:0", help="Set the device") - parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt") - parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling") - parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value") - parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering") - parser.add_argument('--top_p', type=float, default=0.95, help="Set top_p value for generation") - parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model") + parser.add_argument( + "--model_path", + type=str, + default="hpcai-tech/Colossal-LLaMA-2-7b-base", + help="HF repo name or local path of the model", + ) + parser.add_argument("--device", type=str, default="cuda:0", help="Set the device") + parser.add_argument( + "--max_new_tokens", + type=int, + default=512, + help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt", + ) + parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling") + parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value") + parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering") + parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation") + parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model") + parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt") args = parser.parse_args() - generate(args) \ No newline at end of file + generate(args) diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh index 276d9ce99d42..6a1c887bf6cc 100644 --- a/applications/Colossal-LLaMA-2/train.example.sh +++ b/applications/Colossal-LLaMA-2/train.example.sh @@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train. --warmup_steps 100 \ --use_grad_checkpoint \ --use_flash_attn \ + --pad_token "unk" diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 41b4ef031b46..2e4bab75a085 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -1,45 +1,40 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Continual Pre-training of LLaMA-2 developed by Colossal-AI Team +Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team """ -import json import argparse +import json import os import resource from contextlib import nullcontext -from tqdm import tqdm import torch import torch.distributed as dist +from colossal_llama2.dataset.loader import ( + DataCollatorForSupervisedDataset, + StatefulDistributedSampler, + load_tokenized_dataset, +) +from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint +from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention +from colossal_llama2.utils.froze import freeze_non_embeds_parameters +from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune from torch.utils.tensorboard import SummaryWriter -from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig +from tqdm import tqdm +from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster -from colossalai.booster.plugin import ( - GeminiPlugin, - LowLevelZeroPlugin, - HybridParallelPlugin, -) +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossal_llama2.dataset.loader import ( - load_tokenized_dataset, - setup_distributed_dataloader, - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, -) - -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.froze import freeze_non_embeds_parameters - def get_model_numel(model: torch.nn.Module) -> int: return sum(p.numel() for p in model.parameters()) @@ -90,6 +85,7 @@ def main() -> None: parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") parser.add_argument("--config_file", type=str, default="config_file", help="Config file") parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=4096, help="Model max length") @@ -115,6 +111,12 @@ def main() -> None: default=False, help="Use flash-attention", ) + parser.add_argument( + "--use_neft", + action="store_true", + default=False, + help="Use NEFTune", + ) parser.add_argument( "--freeze_non_embeds_params", action="store_true", @@ -123,6 +125,8 @@ def main() -> None: ) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--zero", type=int, default=1) + parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos") + parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length") args = parser.parse_args() with open(args.config_file, "w") as f: @@ -132,6 +136,7 @@ def main() -> None: # Initialize Distributed Training # ============================== colossalai.launch_from_torch({}) + accelerator = get_accelerator() coordinator = DistCoordinator() # ============================== @@ -149,6 +154,7 @@ def main() -> None: precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), ) elif args.plugin == "gemini_auto": plugin = GeminiPlugin( @@ -156,6 +162,7 @@ def main() -> None: placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin( @@ -189,7 +196,10 @@ def main() -> None: # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.unk_token + if args.pad_token == "eos": + tokenizer.pad_token = tokenizer.eos_token + elif args.pad_token == "unk": + tokenizer.pad_token = tokenizer.unk_token tokenizer.add_bos_token = False tokenizer.add_eos_token = False @@ -200,29 +210,36 @@ def main() -> None: coordinator.print_on_master(f"Load dataset: {args.dataset}") dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( + data_collator = DataCollatorForSupervisedDataset( + tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode + ) + dataloader = plugin.prepare_dataloader( dataset=dataset, batch_size=args.micro_batch_size, shuffle=True, drop_last=True, collate_fn=data_collator, + distributed_sampler_cls=StatefulDistributedSampler, ) coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() ) with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) + model = LlamaForCausalLM.from_pretrained(args.pretrained) # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) + # this is essential, otherwise the grad checkpoint will not work. + model.train() if args.use_grad_checkpoint: model.gradient_checkpointing_enable() @@ -244,12 +261,14 @@ def main() -> None: adamw_mode=True, ) + if args.warmup_steps is None: + args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + lr_scheduler = CosineAnnealingWarmupLR( optimizer=optimizer, - total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps - if args.warmup_steps is not None - else int(args.num_epochs * len(dataloader) * 0.025), + total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), + warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr, ) @@ -265,11 +284,9 @@ def main() -> None: torch.set_default_dtype(torch.float) - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") + coordinator.print_on_master( + f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) @@ -296,87 +313,109 @@ def main() -> None: coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" + f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB" ) coordinator.print_on_master( f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" ) - num_steps_per_epoch = len(dataloader) + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + num_steps_per_epoch = len(dataloader) // args.accumulation_steps # If resume training, set the sampler start index to the correct value assert isinstance(dataloader.sampler, StatefulDistributedSampler) dataloader.sampler.set_start_index(start_index=sampler_start_idx) for epoch in range(start_epoch, args.num_epochs): dataloader.sampler.set_epoch(epoch=epoch) - with tqdm( - iterable=enumerate(dataloader, start=start_step), + pbar = tqdm( desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, - initial=start_step, - ) as pbar: - for step, batch in pbar: - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(dataloader, start=start_step): + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - batch_output = model(**batch) + batch_output = model(**batch) - loss = batch_output.loss + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) - booster.backward(loss=loss, optimizer=optimizer) + booster.backward(loss=loss, optimizer=optimizer) + if (step + 1) % args.accumulation_steps == 0: optimizer.step() lr_scheduler.step() optimizer.zero_grad() - all_reduce_mean(tensor=loss) - pbar.set_postfix({"Loss": f"{loss.item():.4f}"}) + all_reduce_mean(tensor=total_loss) + pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) if coordinator.is_master(): - global_step = epoch * num_steps_per_epoch + step - writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step) + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) writer.add_scalar( tag="Learning Rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=global_step, ) - # Save modeling. - - if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) - - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() + total_loss.fill_(0.0) + pbar.update() + # Save modeling. + + if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( + step + 1 + ) == len(dataloader): + coordinator.print_on_master("\nStart saving model checkpoint with running states") + + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) + + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.micro_batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) + + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) + + # Delete cache. + # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(start_index=0) start_step = 0 + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune.") + deactivate_neftune(model, handle) + # Final save. coordinator.print_on_master("Start saving final model checkpoint") booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master( - f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}" - ) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh index dcb11515d48f..d87f9ef82f4f 100755 --- a/applications/Colossal-LLaMA-2/train_sft.example.sh +++ b/applications/Colossal-LLaMA-2/train_sft.example.sh @@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}" TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}" CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json" -colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \ +colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \ --pretrained $PRETRAINED_MODEL_PATH \ --dataset ${dataset[@]} \ --plugin "zero2" \ @@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_ --use_grad_checkpoint \ --use_flash_attn \ --use_neft \ + --pad_token "eos" diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py deleted file mode 100644 index fd9e1cd3e747..000000000000 --- a/applications/Colossal-LLaMA-2/train_sft.py +++ /dev/null @@ -1,403 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team -""" - -import argparse -import json -import os -import resource -from contextlib import nullcontext - -import torch -import torch.distributed as dist -from colossal_llama2.dataset.loader import ( - DataCollatorForSupervisedDataset, - StatefulDistributedSampler, - load_tokenized_dataset, - setup_distributed_dataloader, -) -from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint -from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention -from colossal_llama2.utils.froze import freeze_non_embeds_parameters -from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - - -def get_model_numel(model: torch.nn.Module) -> int: - return sum(p.numel() for p in model.parameters()) - - -def format_numel_str(numel: int) -> str: - B = 1024**3 - M = 1024**2 - K = 1024 - if numel >= B: - return f"{numel / B:.2f} B" - elif numel >= M: - return f"{numel / M:.2f} M" - elif numel >= K: - return f"{numel / K:.2f} K" - else: - return f"{numel}" - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) - tensor.div_(dist.get_world_size()) - return tensor - - -def main() -> None: - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument( - "--pretrained", - type=str, - default=None, - help="Address of the pre-trained modeling", - ) - parser.add_argument("--dataset", nargs="+", default=[]) - parser.add_argument( - "--plugin", - type=str, - default="gemini", - choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"], - help="Choose which plugin to use", - ) - parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") - parser.add_argument("--save_interval", type=int, default=1000, help="Save interval") - parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") - parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory") - parser.add_argument("--config_file", type=str, default="config_file", help="Config file") - parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") - parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps") - parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process") - parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=4096, help="Model max length") - parser.add_argument( - "--mixed_precision", - type=str, - default="fp16", - choices=["fp16", "bf16"], - help="Mixed precision", - ) - parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") - parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") - parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") - parser.add_argument( - "--use_grad_checkpoint", - action="store_true", - default=False, - help="Use gradient checkpointing", - ) - parser.add_argument( - "--use_flash_attn", - action="store_true", - default=False, - help="Use flash-attention", - ) - parser.add_argument( - "--use_neft", - action="store_true", - default=False, - help="Use NEFTune", - ) - parser.add_argument( - "--freeze_non_embeds_params", - action="store_true", - default=False, - help="Freeze non embeddings parameters", - ) - parser.add_argument("--tp", type=int, default=1) - parser.add_argument("--zero", type=int, default=1) - args = parser.parse_args() - - with open(args.config_file, "w") as f: - json.dump(args.__dict__, f, indent=4) - - # ============================== - # Initialize Distributed Training - # ============================== - colossalai.launch_from_torch({}) - coordinator = DistCoordinator() - - # ============================== - # Initialize Tensorboard - # ============================== - if coordinator.is_master(): - os.makedirs(args.tensorboard_dir, exist_ok=True) - writer = SummaryWriter(args.tensorboard_dir) - - # ============================== - # Initialize Booster - # ============================== - if args.plugin == "gemini": - plugin = GeminiPlugin( - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "gemini_auto": - plugin = GeminiPlugin( - precision=args.mixed_precision, - placement_policy="auto", - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin( - stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip, - ) - elif args.plugin == "zero2_cpu": - plugin = LowLevelZeroPlugin( - stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip, - ) - elif args.plugin == "3d": - plugin = HybridParallelPlugin( - tp_size=args.tp, - pp_size=1, - zero_stage=args.zero, - max_norm=args.grad_clip, - precision=args.mixed_precision, - ) - else: - raise ValueError(f"Unknown plugin {args.plugin}") - - booster = Booster(plugin=plugin) - - # ====================================================== - # Initialize Tokenizer, Dataset, Collator and Dataloader - # ====================================================== - tokenizer = LlamaTokenizer.from_pretrained(args.pretrained) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.add_bos_token = False - tokenizer.add_eos_token = False - - coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") - coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}") - coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}") - - coordinator.print_on_master(f"Load dataset: {args.dataset}") - - dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") - data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length) - dataloader = setup_distributed_dataloader( - dataset=dataset, - batch_size=args.micro_batch_size, - shuffle=True, - drop_last=True, - collate_fn=data_collator, - ) - coordinator.print_on_master( - f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - - # ====================================================== - # Initialize Model, Objective, Optimizer and LR Scheduler - # ====================================================== - init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext() - ) - with init_ctx: - model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained)) - # Freeze part of parameters. - if args.freeze_non_embeds_params: - freeze_non_embeds_parameters(model=model) - - if args.use_grad_checkpoint: - model.gradient_checkpointing_enable() - coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") - if args.use_flash_attn: - replace_with_flash_attention(model=model) - coordinator.print_on_master(msg="Flash-attention enabled successfully") - - model_numel = get_model_numel(model) - coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") - - optimizer = HybridAdam( - model_params=filter(lambda p: p.requires_grad, model.parameters()) - if args.freeze_non_embeds_params - else model.parameters(), - lr=args.lr, - betas=(0.9, 0.95), - weight_decay=args.weight_decay, - adamw_mode=True, - ) - - if args.warmup_steps is None: - args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) - coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") - - lr_scheduler = CosineAnnealingWarmupLR( - optimizer=optimizer, - total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), - warmup_steps=args.warmup_steps, - eta_min=0.1 * args.lr, - ) - - # Flash attention will be disabled because it does NOT support fp32. - default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 - torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost( - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - dataloader=dataloader, - ) - - torch.set_default_dtype(torch.float) - - if args.load_checkpoint is None: - coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}") - booster.load_model(model, args.pretrained, strict=False) - - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") - coordinator.print_on_master( - f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - start_epoch = 0 - start_step = 0 - sampler_start_idx = 0 - if args.load_checkpoint is not None: - if "modeling" in args.load_checkpoint: - coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}") - booster.load_model(model, args.load_checkpoint) - else: - coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}") - start_epoch, start_step, sampler_start_idx = load_checkpoint( - load_dir=args.load_checkpoint, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - coordinator.print_on_master( - f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}" - ) - coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}") - - coordinator.print_on_master( - f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB" - ) - coordinator.print_on_master( - f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - num_steps_per_epoch = len(dataloader) // args.accumulation_steps - # If resume training, set the sampler start index to the correct value - assert isinstance(dataloader.sampler, StatefulDistributedSampler) - dataloader.sampler.set_start_index(start_index=sampler_start_idx) - - for epoch in range(start_epoch, args.num_epochs): - dataloader.sampler.set_epoch(epoch=epoch) - pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch) - total_loss = torch.tensor(0.0).to(torch.cuda.current_device()) - for step, batch in enumerate(dataloader): - batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} - - batch_output = model(**batch) - - loss = batch_output.loss / args.accumulation_steps - total_loss += loss.item() - - booster.backward(loss=loss, optimizer=optimizer) - - if (step + 1) % args.accumulation_steps == 0: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - all_reduce_mean(tensor=total_loss) - pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"}) - if coordinator.is_master(): - global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps - writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) - writer.add_scalar( - tag="Learning Rate", - scalar_value=lr_scheduler.get_last_lr()[0], - global_step=global_step, - ) - total_loss.fill_(0.0) - pbar.update() - # Save modeling. - - if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or ( - step + 1 - ) == len(dataloader): - coordinator.print_on_master("\nStart saving model checkpoint with running states") - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune before saving model.") - deactivate_neftune(model, handle) - - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.micro_batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) - - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) - - # Delete CUDA cache. - # del batch, batch_labels, batch_output, loss - torch.cuda.empty_cache() - - # the continue epochs are not resumed, so we need to reset the sampler start index and start step - dataloader.sampler.set_start_index(start_index=0) - start_step = 0 - - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune.") - deactivate_neftune(model, handle) - - # Final save. - coordinator.print_on_master("Start saving final model checkpoint") - booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) - coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}") - - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") - - -if __name__ == "__main__": - main() diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py index f293c4f699cd..9c70c0d2a1ad 100644 --- a/applications/ColossalEval/colossal_eval/models/chatglm.py +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -3,6 +3,8 @@ import torch +from colossalai.utils import get_current_device + from .huggingface import HuggingFaceModel IGNORE_INDEX = -100 @@ -126,9 +128,9 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) outputs = self.model(input_ids)[0] @@ -197,7 +199,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str truncation=True, return_tensors="pt", max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. outputs = self.model.generate( diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 741c884f0043..fff697e21e34 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -11,6 +11,7 @@ from colossalai.logging import DistributedLogger from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils import get_current_device from .base import BaseModel @@ -128,12 +129,12 @@ def _load_model( self.model = AutoModel.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) self.model.eval() @@ -155,11 +156,11 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t """ input_ids = torch.nn.utils.rnn.pad_sequence( input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id - ).to(torch.cuda.current_device()) + ).to(get_current_device()) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to( - torch.cuda.current_device() + get_current_device() ) - attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device()) + attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device()) outputs = self.model(input_ids, attention_mask=attention_mask)[0] @@ -464,7 +465,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return_tensors="pt", return_token_type_ids=False, max_length=self.model_max_length - max_new_tokens, - ).to(torch.cuda.current_device()) + ).to(get_current_device()) # Set output_scores=True to get prediction scores. outputs = self.model.generate( @@ -598,12 +599,12 @@ def _load_model( self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) self.model, sharded_parameters = shard_former.optimize(self.model) - self.model.to(torch.cuda.current_device()) + self.model.to(get_current_device()) if peft_path is not None: raise NotImplementedError("ShardFormer for PEFT models is not implemented.") else: - self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device()) + self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device()) if peft_path is not None: self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 5b09f9de8da6..a340f3bfd281 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -8,6 +8,7 @@ from colossal_eval import dataset, models, utils import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig @@ -82,6 +83,7 @@ def rm_and_merge( def main(args): colossalai.launch_from_torch(config={}, seed=42) + accelerator = get_accelerator() world_size = dist.get_world_size() rank = dist.get_rank() @@ -235,10 +237,10 @@ def main(args): ), ) - logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") + logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ - torch.cuda.empty_cache() + accelerator.empty_cache() dist.barrier() if rank == 0: diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md new file mode 100644 index 000000000000..be50a8f9f251 Binary files /dev/null and b/applications/ColossalMoE/README.md differ diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py new file mode 100644 index 000000000000..d08dfd5f8120 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py @@ -0,0 +1,629 @@ +import copy +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.checkpoint_io import CheckpointIndexFile +from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.moe import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose) + moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size] + self.ep_group = moe_info.ep_group + self.ep_size = moe_info.ep_size + self.ep_rank = moe_info.ep_rank + self.real_dp_rank = moe_info.dp_rank + + @staticmethod + def _model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + param_name_pattern: Optional[str] = None, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + if param_name_pattern is not None and param_name_pattern not in name: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + if self.real_dp_rank != 0: + dist.barrier() + return + + # ep_rank 0 saves all the parameters and buffers. + # other ep_ranks save only experts + ep_param_pattern = "experts." if self.ep_rank != 0 else None + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern + ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + weights_name = weights_name.replace( + ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors" + ) + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + @staticmethod + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + is_moe_param: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # First gather Zero shards. + if use_zero and not is_moe_param: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().to(device) + + return state_ + + @staticmethod + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + size_per_shard: int = 1024, + only_moe_param: bool = False, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + is_moe_param=is_moe_tensor(working_param), + ) + + if only_moe_param and not is_moe_tensor(working_param): + continue + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.real_dp_rank != 0: + dist.barrier() + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + size_per_shard=size_per_shard, + only_moe_param=self.ep_rank != 0, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.real_dp_rank == 0 and self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + # ep param groups + if len(optimizer.optim.param_groups) == len(saved_groups) + 1: + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + is_moe_param=is_moe_tensor(working_param), + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + is_moe_param: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero and not is_moe_param: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): + raise NotImplementedError diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py new file mode 100644 index 000000000000..a2b78a2bd18c --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py @@ -0,0 +1,92 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from colossalai.lazy import LazyInitContext +from colossalai.moe import MOE_MANAGER +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_info + + +class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + super().__init__(config) + self.setup_ep() + + def setup_ep(self): + _, moe_info = MOE_MANAGER.get_info(self.num_experts) + ep_group = moe_info.ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + set_moe_tensor_info(p, moe_info) + + @staticmethod + def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + module.setup_ep() + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + # compute expert output + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py new file mode 100644 index 000000000000..218b05b27fad --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -0,0 +1,557 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.models.mixtral.modeling_mixtral import ( + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MoeCausalLMOutputWithPast, + _prepare_4d_causal_attention_mask, + load_balancing_loss_func, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.shard import ShardConfig + +from .mixtral_layer import EPMixtralSparseMoeBlock + +__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] + + +class MixtralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in mixtral.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class MixtralModelPolicy(MixtralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralModel, + new_forward=MixtralPipelineForwards.mixtral_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class MixtralForCausalLMPolicy(MixtralPolicy): + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralForCausalLM, + new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class MixtralPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def mixtral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + output_router_logits, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if output_router_logits and past_router_logits is not None: + all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + + @staticmethod + def mixtral_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MixtralPipelineForwards.mixtral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_logits=past_router_logits, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + if output_router_logits: + out["past_router_logits"] = outputs["past_router_logits"] + return out diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py new file mode 100644 index 000000000000..a2a0a7e78239 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/utils.py @@ -0,0 +1,84 @@ +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. + booster.load_model(model, os.path.join(load_dir, "modeling")) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py new file mode 100644 index 000000000000..46ff70ff33ab --- /dev/null +++ b/applications/ColossalMoE/infer.py @@ -0,0 +1,111 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="ep", + choices=["ep"], + help="Parallel methos.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + config = MixtralConfig.from_pretrained(args.model_name) + ep_size = min(dist.get_world_size(), config.num_local_experts) + # Set plugin + if args.plugin == "ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + ep_size=ep_size, + zero_stage=1, + precision=args.precision, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish load model") + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Set booster + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + coordinator.print_on_master(f"Finish init booster") + + model.eval() + + if coordinator.rank == 0: + text = ["Hello my name is"] + else: + text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] + tokenizer.pad_token = tokenizer.unk_token + inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) + + with torch.no_grad(): + outputs = model.module.generate(**inputs, max_new_tokens=20) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(f"[{coordinator.rank}] {outputs}") + + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh new file mode 100644 index 000000000000..0487fe9c1562 --- /dev/null +++ b/applications/ColossalMoE/infer.sh @@ -0,0 +1,7 @@ +NUM_GPU=2 +MODEL="mistralai/Mixtral-8x7B-v0.1" + +# ep +torchrun --standalone --nproc_per_node $NUM_GPU infer.py \ + --model_name $MODEL \ + --plugin "ep" \ diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt new file mode 100644 index 000000000000..9a5738c412b9 --- /dev/null +++ b/applications/ColossalMoE/requirements.txt @@ -0,0 +1,5 @@ +colossalai >= 0.3.3 +torch >= 1.8.1 +transformers == 4.36.0 +sentencepiece +datasets diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py new file mode 100644 index 000000000000..275f59e10a06 --- /dev/null +++ b/applications/ColossalMoE/setup.py @@ -0,0 +1,43 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_moe", + version=fetch_version(), + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI MoE", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py new file mode 100644 index 000000000000..57589ab20d22 --- /dev/null +++ b/applications/ColossalMoE/tests/test_mixtral_layer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock +from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.moe import MOE_MANAGER +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + MOE_MANAGER.setup( + parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1 + ) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + ) + torch.manual_seed(0) + orig_model = MixtralSparseMoeBlock(config).cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output, orig_logits = orig_model(x) + model = deepcopy(orig_model) + model = EPMixtralSparseMoeBlock.from_native_module(model) + ep_output, ep_logits = model(x) + assert_close(orig_logits, ep_logits) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(2) diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py new file mode 100644 index 000000000000..822e7410f016 --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -0,0 +1,146 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from torch.optim import Adam +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_model_equal(model1, model2): + assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert torch.equal(p1.half(), p2.half()) + + +def get_optimizer_snapshot(optim): + state = {id(k): deepcopy(v) for k, v in optim.state.items()} + param_groups = [] + for group in optim.param_groups: + params = [id(p) for p in group["params"]] + new_group = {"params": params} + for k, v in group.items(): + if k != "params": + new_group[k] = v + param_groups.append(new_group) + return { + "state": state, + "param_groups": param_groups, + } + + +def check_optimizer_snapshot_equal(snapshot1, snapshot2): + # check param_groups + assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) + for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]): + assert set(group1.keys()) == set(group2.keys()) + for k in group1.keys(): + assert group1[k] == group2[k] + # check state + assert set(snapshot1["state"].keys()) == set( + snapshot2["state"].keys() + ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}" + for pid in snapshot1["state"].keys(): + state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid] + assert set(state1.keys()) == set(state2.keys()) + for k in state1.keys(): + if isinstance(state1[k], torch.Tensor): + assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}" + else: + assert state1[k] == state2[k] + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, + ) + torch.manual_seed(0) + input_ids = torch.randint(0, 100, (2, tokens)).cuda() + orig_model = MixtralForCausalLM(config).cuda() + model = deepcopy(orig_model) + optimizer = Adam(model.parameters(), lr=1e-3) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=2, + ep_size=2, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + microbatch_size=1, + zero_stage=1, + ) + booster = Booster(plugin=plugin) + model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) + # initialize grads + data_iter = iter( + [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}] + ) + booster.execute_pipeline( + data_iter, + model, + lambda outputs, inputs: outputs.loss, + optimizer, + ) + + # check save model + booster.save_model(model, "mixtral_model", shard=True) + dist.barrier() + if dist.get_rank() == 0: + saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda() + check_model_equal(orig_model, saved_model) + saved_model.save_pretrained("mixtral_hf_model") + dist.barrier() + + # check load model + new_model = MixtralForCausalLM(config).cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) + booster.load_model(new_model, "mixtral_hf_model") + check_model_equal(model, new_model) + + # check save optimizer + optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 + snapshot = get_optimizer_snapshot(optimizer.unwrap()) + booster.save_optimizer(optimizer, "mixtral_optim", shard=True) + dist.barrier() + # reset optimizer state + for state in optimizer.unwrap().state.values(): + for v in state.values(): + if isinstance(v, torch.Tensor): + v.zero_() + booster.load_optimizer(optimizer, "mixtral_optim") + loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap()) + check_optimizer_snapshot_equal(snapshot, loaded_snapshot) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(4) diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py new file mode 100644 index 000000000000..c567038ec252 --- /dev/null +++ b/applications/ColossalMoE/train.py @@ -0,0 +1,295 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +@torch.no_grad() +def get_global_loss(loss, booster): + global_loss = loss.clone().detach() + dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group) + global_loss.div_(booster.plugin.dp_size) + return global_loss + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["hybrid"], + help="Parallel methods.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help=" The interval (steps) of saving checkpoints.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # optim + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + + # lr scheduler + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + + # zero stage for all plugins + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin") + parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin") + parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + + # load balance + parser.add_argument( + "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable." + ) + parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.") + # communicate overlap + parser.add_argument( + "--comm_overlap", + action="store_true", + help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + ) + # hierarchical all-to-all + parser.add_argument( + "--hierarchical_alltoall", + action="store_true", + help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + if args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + ep_size=args.ep_size, + microbatch_size=args.microbatch_size, + custom_policy=MixtralForCausalLMPolicy(), + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + precision=args.precision, + zero_stage=args.zero_stage, + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + ) + + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build Mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish init model") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + dataset = RandomDataset(num_samples=100, tokenizer=tokenizer) + collate_fn = None + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Set optimizer + optimizer = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # Set lr scheduler + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Set booster + booster = Booster(plugin=plugin) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + coordinator.print_on_master(f"Finish init booster") + + # Load ckpt + if args.load_checkpoint is not None: + load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler) + coordinator.print_on_master(f"Finish load optimizer") + + # Start finetuning + coordinator.print_on_master(f"Start finetuning") + for epoch in range(args.num_epoch): + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage, + ) as pbar: + for step in pbar: + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + global_loss = get_global_loss(loss, booster) + if coordinator._local_rank == "0": + pbar.set_postfix({"Loss": global_loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Apply load balance + # if ( + # args.load_balance + # and args.load_balance_interval > 0 + # and (step + 1) % args.load_balance_interval == 0 + # ): + # coordinator.print_on_master(f"Apply load balance") + # apply_load_balance(model, optimizer) + # save ckeckpoint + if (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + save_checkpoint( + args.output_path, + booster, + model, + optimizer, + lr_scheduler, + epoch, + step, + args.batch_size, + coordinator, + ) + + # save checkpoint at the end of each epochs + booster.save_model(model, args.output_path, shard=True, size_per_shard=5120) + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + + # Finish training + coordinator.print_on_master(f"Finish training") + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh new file mode 100644 index 000000000000..bee7f5c8fdf8 --- /dev/null +++ b/applications/ColossalMoE/train.sh @@ -0,0 +1,19 @@ +NUM_GPU=8 +MODEL="mistralai/Mixtral-8x7B-v0.1" +SEQ_LENGTH=2048 +BATCH_SIZE=1 +LR=0.00001 + +# hybrid +# torchrun --standalone --nproc_per_node $NUM_GPU \ +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \ + train.py \ + --num_epoch 1 \ + --model_name $MODEL \ + --plugin "hybrid" \ + --batch_size $BATCH_SIZE \ + --lr $LR \ + --zero_stage 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 8 \ diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt new file mode 100644 index 000000000000..3eefcb9dd5b3 --- /dev/null +++ b/applications/ColossalMoE/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/colossalai/__init__.py b/colossalai/__init__.py index 7da55590305b..6b7f5d055207 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -1,4 +1,5 @@ from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch +from . import accelerator try: # .version will be created by setup.py diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md new file mode 100644 index 000000000000..8c644493b03a --- /dev/null +++ b/colossalai/accelerator/README.md @@ -0,0 +1,20 @@ +# 🚀 Accelerator + +## 🔗 Table of Contents + +- [🚀 Accelerator](#-accelerator) + - [🔗 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [📌 Design and Acknowledgement](#-design-and-acknowledgement) + +## 📚 Introduction + +This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platform with a simple `auto_set_accelerator()` API. + +## 📌 Design and Acknowledgement + +Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work. + +We implemented this accelerator module from scratch. At the same time, we have implemented our own modifications: +1. we updated the accelerator API names to be aligned with PyTorch's native API names. +2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled. diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py new file mode 100644 index 000000000000..1405133affe2 --- /dev/null +++ b/colossalai/accelerator/__init__.py @@ -0,0 +1,15 @@ +from .api import auto_set_accelerator, get_accelerator, set_accelerator +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = [ + "get_accelerator", + "set_accelerator", + "auto_set_accelerator", + "BaseAccelerator", + "CudaAccelerator", + "NpuAccelerator", + "CpuAccelerator", +] diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py new file mode 100644 index 000000000000..02b3055d7380 --- /dev/null +++ b/colossalai/accelerator/api.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +from collections import OrderedDict +from typing import Union + +from .base_accelerator import BaseAccelerator +from .cpu_accelerator import CpuAccelerator +from .cuda_accelerator import CudaAccelerator +from .npu_accelerator import NpuAccelerator + +__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"] + + +_ACCELERATOR = None + + +# we use ordered dictionary here to associate the +# order with device check priority +# i.e. auto_set_accelerator will check cuda first +_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator) + + +def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: + """ + Set the global accelerator for the current process. + + Args: + accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs. + """ + + global _ACCELERATOR + + if isinstance(accelerator, str): + _ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]() + elif isinstance(accelerator, BaseAccelerator): + _ACCELERATOR = accelerator + else: + raise TypeError("accelerator must be either a string or an instance of BaseAccelerator") + + +def auto_set_accelerator() -> None: + """ + Automatically check if any accelerator is available. + If an accelerator is availabe, set it as the global accelerator. + """ + global _ACCELERATOR + + for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items(): + try: + accelerator = accelerator_cls() + if accelerator_name == "cpu" or accelerator.is_available(): + _ACCELERATOR = accelerator + break + except: + pass + + if _ACCELERATOR is None: + raise RuntimeError("No accelerator is available.") + + +def get_accelerator() -> BaseAccelerator: + """ + Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized + to the default accelerator type. + + Returns: the accelerator for the current process. + """ + global _ACCELERATOR + + if _ACCELERATOR is None: + auto_set_accelerator() + return _ACCELERATOR diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py new file mode 100644 index 000000000000..33c113999018 --- /dev/null +++ b/colossalai/accelerator/base_accelerator.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +__all__ = ["BaseAccelerator"] + + +class BaseAccelerator(ABC): + support_set_device: bool = True + + def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None: + self._name = name + self._communication_backend = communication_backend + self._is_synchronous = is_synchronous + + # ======================= + # immutable attributes + # ======================= + + @property + def name(self) -> str: + """ + Return the name of the accelerator. + """ + return self._name + + @property + def communication_backend(self) -> str: + """ + Return the name of the backend communication library. + """ + return self._communication_backend + + @property + def is_synchronous(self) -> bool: + """ + Return whether the accelerator is a synchronous device. + """ + return self._is_synchronous + + def __repr__(self) -> str: + cls_name = self.__class__.__name__ + return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})" + + # ======================= + # device APIs + # ======================= + @abstractmethod + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + + @abstractmethod + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + + @abstractmethod + def current_device(self) -> int: + """ + Return the current device index. + """ + + @abstractmethod + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + + @abstractmethod + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + + @abstractmethod + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + + @abstractmethod + def is_available(self): + """ + Check if the accelerator is available. + """ + + @abstractmethod + def device_count(self): + """ + Return the number of devices on the machine. + """ + + def set_to_device(self, models: Any) -> Any: + """ + Send model to device. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(self.get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(self.get_current_device()) + else: + return models.to(self.get_current_device()) + + @abstractmethod + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the capability of a device. + """ + + @abstractmethod + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + + @abstractmethod + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + + @abstractmethod + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc. + """ + + # ======================= + # random number generator APIs + # ======================= + @abstractmethod + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified device as a ByteTensor. + """ + + @abstractmethod + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + + @abstractmethod + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified device. + """ + + @abstractmethod + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + + @abstractmethod + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current device. + """ + + @abstractmethod + def manual_seed_all(self, seed: int) -> None: + """ + Sets the seed for generating random numbers on all devices. + """ + + @abstractmethod + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current device. + """ + + @abstractmethod + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all devices. + """ + + @abstractmethod + def initial_seed(self) -> int: + """ + Returns the current random seed of the current device. + """ + + # ======================= + # memory management APIs + # ======================= + @abstractmethod + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi. + """ + + @abstractmethod + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + + @abstractmethod + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + + @abstractmethod + def memory_allocated(self, device=None) -> int: + """ + Returns the current device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum device memory occupied by tensors in bytes for a given device. + """ + + @abstractmethod + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory occupied by tensors for a given device. + """ + + @abstractmethod + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device. + """ + + @abstractmethod + def memory_reserved(self, device=None) -> int: + """ + Returns the current device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum device memory managed by the caching allocator in bytes for a given device. + """ + + @abstractmethod + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + + @abstractmethod + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the device memory allocator. + """ + + # ======================= + # streams and events APIs + # ======================= + + @abstractmethod + def Stream(self, device=None, priority=0, **kwargs): + """ + A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + + @abstractmethod + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + + @abstractmethod + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + + @abstractmethod + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + + @abstractmethod + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + + @abstractmethod + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + + # ======================= + # amp APIs + # ======================= + @abstractmethod + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py new file mode 100644 index 000000000000..080aa61e8e3a --- /dev/null +++ b/colossalai/accelerator/cpu_accelerator.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python + +import resource +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import psutil +import torch + +from .base_accelerator import BaseAccelerator + +__all__ = ["CpuAccelerator"] + + +class CpuAccelerator(BaseAccelerator): + support_set_device: bool = False + """ + Accelerator class for cpu. + """ + + def __init__(self): + super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return "" + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device("cpu") + + def current_device(self) -> int: + """ + Return the current device index. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def is_available(self): + """ + Check if the accelerator is available. + """ + return True + + def device_count(self): + """ + Return the number of devices on the machine. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device=None) -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.set_rng_state(new_state) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return psutil.Process().memory_info().rss + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + max_memory = int(psutil.virtual_memory().total * fraction) + _, hard = resource.getrlimit(resource.RLIMIT_AS) + resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard)) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + raise RuntimeError("this method is not supported for cpu accelerator") + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return nullcontext diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py new file mode 100644 index 000000000000..f1ab487d4f58 --- /dev/null +++ b/colossalai/accelerator/cuda_accelerator.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +__all__ = ["CudaAccelerator"] + + +class CudaAccelerator(BaseAccelerator): + """ + Accelerator class for Nvidia CUDA devices. + """ + + def __init__(self): + super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cuda + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"cuda:{torch.cuda.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.cuda.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.cuda.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.cuda.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.cuda.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.cuda.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.cuda.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the cuda capability of a device. + """ + return torch.cuda.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.cuda.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.cuda.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.cuda.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="cuda") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.cuda.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.cuda.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.cuda.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.cuda.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.cuda.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.cuda.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.cuda.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.cuda.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.cuda.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of CUDA memory allocator statistics for a given device. + """ + return torch.cuda.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.cuda.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the CUDA memory allocator state across all devices. + """ + return torch.cuda.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.cuda.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.cuda.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.cuda.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.cuda.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + torch.cuda.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details. + """ + return torch.cuda.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams. + """ + return torch.cuda.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.cuda.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.cuda.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.cuda.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.cuda.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py new file mode 100644 index 000000000000..b28492968eeb --- /dev/null +++ b/colossalai/accelerator/npu_accelerator.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +from .base_accelerator import BaseAccelerator + +try: + import torch_npu # noqa +except ImportError: + pass + + +__all__ = ["NpuAccelerator"] + + +class NpuAccelerator(BaseAccelerator): + """ + Accelerator class for Huawei NPU devices. + """ + + def __init__(self): + super().__init__(name="npu", communication_backend="hccl", is_synchronous=False) + + # ======================= + # device APIs + # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cann + + def get_current_device(self) -> torch.device: + """ + Return the current device. + """ + return torch.device(f"npu:{torch.npu.current_device()}") + + def current_device(self) -> int: + """ + Return the current device index. + """ + return torch.npu.current_device() + + def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None: + """ + Bind the current process to a device. + """ + if device is None: + if not dist.is_initialized(): + raise RuntimeError("Cannot get current device when distributed is not initialized.") + device = dist.get_rank() % self.device_count() + torch.npu.set_device(device) + + def get_device_name(self, device: Union[torch.device, int]) -> str: + """ + Return the name of the device. + """ + return torch.npu.get_device_name(device) + + def synchronize(self, device: Union[torch.device, int] = None): + """ + Synchronize the current process. + """ + torch.npu.synchronize(device) + + def is_available(self): + """ + Check if the accelerator is available. + """ + return torch.npu.is_available() + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.npu.device_count() + + def get_device_capability(self, device=None) -> Tuple[int, int]: + """ + Gets the npu capability of a device. + """ + return torch.npu.get_device_capability(device) + + def get_device_name(self, device=None) -> str: + """ + Gets the name of a device. + """ + return torch.npu.get_device_name(device) + + def get_device_properties(self, device): + """ + Gets the properties of a device. + """ + return torch.npu.get_device_properties(device) + + def utilization(self, device=None) -> int: + """ + Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi + """ + return torch.npu.utilization(device) + + # ======================= + # random number generator APIs + # ======================= + def get_rng_state(self, device="npu") -> torch.Tensor: + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + return torch.npu.get_rng_state(device) + + def get_rng_state_all(self) -> List[torch.Tensor]: + """ + Returns a list of ByteTensor representing the random number states of all devices. + """ + return torch.npu.get_rng_state_all() + + def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None: + """ + Sets the random number generator state of the specified GPU. + """ + torch.npu.set_rng_state(new_state, device) + + def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None: + """ + Sets the random number generator state of all devices. + """ + torch.npu.set_rng_state_all(new_states) + + def manual_seed(self, seed: int) -> None: + """ + Sets the seed for generating random numbers for the current GPU. + """ + torch.npu.manual_seed(seed) + + def manual_seed_all(self, seed: int) -> None: + """ + Set the random seed for the all processes. + """ + torch.npu.manual_seed_all(seed) + + def seed(self) -> None: + """ + Sets the seed for generating random numbers to a random number for the current GPU. + """ + torch.npu.seed() + + def seed_all(self) -> None: + """ + Sets the seed for generating random numbers to a random number on all GPUs. + """ + torch.npu.seed_all() + + def initial_seed(self) -> int: + """ + Returns the current random seed of the current GPU. + """ + return torch.npu.initial_seed() + + # ======================= + # memory management APIs + # ======================= + + def empty_cache(self) -> None: + """ + Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi. + """ + torch.npu.empty_cache() + + def memory_stats(self, device=None) -> Dict[str, Any]: + """ + Returns a dictionary of npu memory allocator statistics for a given device. + """ + return torch.npu.memory_stats(device=device) + + def memory_summary(self, device=None, abbreviated=False) -> str: + """ + Returns a human-readable printout of the current memory allocator statistics for a given device. + """ + return torch.npu.memory_summary(device=device, abbreviated=abbreviated) + + def memory_snapshot(self): + """ + Returns a snapshot of the npu memory allocator state across all devices. + """ + return torch.npu.memory_snapshot() + + def memory_allocated(self, device=None) -> int: + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.memory_allocated(device=device) + + def max_memory_allocated(self, device=None) -> int: + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.npu.max_memory_allocated(device=device) + + def reset_max_memory_allocated(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device. + """ + torch.npu.reset_max_memory_allocated(device=device) + + def reset_max_memory_cached(self, device=None) -> None: + """ + Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device. + """ + torch.npu.reset_max_memory_cached(device=device) + + def memory_reserved(self, device=None) -> int: + """ + Returns the current GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.memory_reserved(device=device) + + def max_memory_reserved(self, device=None) -> int: + """ + Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. + """ + return torch.npu.max_memory_reserved(device=device) + + def set_per_process_memory_fraction(self, fraction: float, device=None) -> None: + """ + Set memory fraction for a process. + """ + torch.npu.set_per_process_memory_fraction(fraction, device=device) + + def reset_peak_memory_stats(self, device=None) -> None: + """ + Resets the "peak" stats tracked by the npu memory allocator. + """ + torch.npu.reset_peak_memory_stats(device=device) + + # ======================= + # streams and events APIs + # ======================= + + def Stream(self, device=None, priority=0, **kwargs): + """ + A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details. + """ + return torch.npu.Stream(device, priority, **kwargs) + + def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + """ + npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams. + """ + return torch.npu.Event(enable_timing, blocking, interprocess) + + def current_stream(self, device=None): + """ + Returns the currently selected Stream for a given device. + """ + return torch.npu.current_stream(device) + + def default_stream(self, device=None): + """ + Returns the default Stream for a given device. + """ + return torch.npu.default_stream(device) + + def set_stream(self, stream_): + """ + Sets the current stream.This is a wrapper API to set the stream. + """ + torch.npu.set_stream(stream_) + + def stream(self, stream_): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.npu.stream(stream_) + + # ======================= + # amp APIs + # ======================= + def autocast( + self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True + ) -> Callable: + """ + Return autocast function + """ + return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 439d13dcfc11..fc4c884d4c5d 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -7,8 +7,8 @@ import torch from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device __all__ = ["BaseGradScaler"] @@ -23,7 +23,7 @@ class BaseGradScaler(ABC): def __init__(self, initial_scale: float, verbose: bool): assert initial_scale > 0 - self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) + self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float) self._verbose = verbose if self._verbose: diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 86ba919ee696..5cd8035d7987 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -5,7 +5,7 @@ import torch -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .base_grad_scaler import BaseGradScaler @@ -37,14 +37,20 @@ def __init__( hysteresis: int = 2, verbose: bool = False, ): + a = get_accelerator() + a.device_count() super().__init__(initial_scale, verbose) if min_scale: - self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float) + self._min_scale = torch.tensor( + [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._min_scale = None if max_scale: - self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float) + self._max_scale = torch.tensor( + [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float + ) else: self._max_scale = None @@ -117,7 +123,7 @@ def state_dict(self): return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].to(get_current_device()) + self._scale = state_dict["scale"].to(get_accelerator().get_current_device()) self._growth_factor = state_dict["growth_factor"] self._backoff_factor = state_dict["backoff_factor"] self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 9ce272356797..2e7c8a281916 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -5,8 +5,8 @@ import torch.distributed as dist from torch import Tensor +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.utils import get_current_device from .base import MixedPrecisionMixin @@ -40,7 +40,7 @@ def __init__( max_scale=max_scale, ) self.optim_state = OptimState.UNSCALED - self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) @property def loss_scale(self) -> float: diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 601bf2926d99..fe8439269f48 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -4,10 +4,10 @@ import torch from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule from .region import Region @@ -79,7 +79,9 @@ def __init__( hysteresis=hysteresis, max_scale=max_scale, ) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + self._found_overflow: torch.Tensor = torch.zeros( + 1, dtype=torch.int64, device=get_accelerator().get_current_device() + ) self._logger = get_dist_logger() def _set_grad_ptr(self): diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index a6628e29c2bc..3ad210de9f0a 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -11,7 +11,7 @@ import torch from torch.fx.node import Node -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .region import Region from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator @@ -57,7 +57,10 @@ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor + self.memory_budget = ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * self.error_factor + ) self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 443c4094c0e1..c757a878d97a 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -5,8 +5,8 @@ from torch import Tensor from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.utils.device import autocast from .mixed_precision_base import MixedPrecision @@ -89,7 +89,7 @@ def __init__(self, module: nn.Module): super().__init__(module) def forward(self, *args, **kwargs): - with autocast(): + with get_accelerator().autocast(): return self.module(*args, **kwargs) diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py index d2dd00453e32..27285f95ce52 100644 --- a/colossalai/booster/plugin/dp_plugin_base.py +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -21,7 +21,16 @@ def __init__(self) -> None: self.world_size = dist.get_world_size() def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -45,7 +54,8 @@ def prepare_dataloader( :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. """ _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) # Deterministic dataloader def seed_worker(worker_id): diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a891db422d67..95b96bbfd9ed 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -15,6 +15,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( get_model_base_filenames, @@ -27,8 +28,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -366,11 +365,11 @@ def __init__( ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" - if IS_NPU_AVAILABLE: + if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, - chunk_init_device=(chunk_init_device or get_current_device()), + chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), placement_policy=placement_policy, enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, @@ -455,9 +454,18 @@ def control_device(self) -> bool: def supported_devices(self) -> List[str]: return ["cuda", "npu"] - + def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -485,8 +493,12 @@ def prepare_dataloader( extra_dp_world_size = self.pg_mesh.size(DP_AXIS) zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) - sampler = DistributedSampler( - dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( + dataset, + num_replicas=zero_world_size * extra_dp_world_size, + rank=zero_rank * extra_dp_world_size + extra_dp_rank, + shuffle=shuffle, ) # Deterministic dataloader diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index e1593cf6b26c..da67e6b41fbf 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,5 +1,6 @@ import ctypes import random +import warnings from contextlib import contextmanager from functools import partial from types import MethodType @@ -18,6 +19,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh @@ -28,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor -from colossalai.utils.device import get_current_device from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -82,7 +83,7 @@ def __init__( self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) # setting input type cast when using mixed precision self.convert_fn = None @@ -345,7 +346,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) if self.pp_size > 1: @@ -385,7 +388,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -542,7 +545,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ # so we need to calculate the norm of 'tp' and 'pp' gradients. total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if self.tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -586,7 +591,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if self.tp_size > 1: # compute norm in tp process group @@ -798,7 +803,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # so we only need to calculate the norm 'tp' of 'pp' gradients. total_norm = super()._compute_grad_norm(gradients, norm_type) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) if tp_size > 1: dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) @@ -838,7 +845,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32 + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 ) if dp_size > 1: # compute norm in dp process group @@ -1128,7 +1135,12 @@ def configure( tp_process_group=self.tp_group, ) else: - assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + if self.dp_size == 1: + warnings.warn( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you are not intended to use cpu_offload, please consider set zero_stage=0." + ) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = HybridParallelZeroOptimizer( optimizer, @@ -1193,7 +1205,16 @@ def execute_pipeline( return outputs def prepare_dataloader( - self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by @@ -1217,7 +1238,8 @@ def prepare_dataloader( :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. """ _kwargs = kwargs.copy() - sampler = DistributedSampler( + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle ) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 89102820cd38..d21496f0b758 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -12,6 +12,7 @@ from torch.utils._pytree import tree_map from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO from colossalai.checkpoint_io.utils import ( get_optimizer_base_filenames, @@ -24,7 +25,6 @@ sharded_optimizer_loading_epilogue, ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.utils import get_current_device from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase @@ -52,7 +52,7 @@ def __init__(self, module: nn.Module, precision: str) -> None: self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) - module = module.to(get_current_device()) + module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None if self.dtype is not None: diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index e976d0aaf014..45e5a23c1b22 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ ) from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MoECheckpintIO +from colossalai.moe import MOE_MANAGER, MoECheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig @@ -150,6 +150,7 @@ def __init__( self, tp_size: int, pp_size: int, + ep_size: int, extra_dp_size: int = 1, precision: str = "fp16", zero_stage: int = 0, @@ -181,6 +182,7 @@ def __init__( overlap_communication: bool = True, use_ep_inside: bool = True, custom_policy: Policy = None, + checkpoint_io: Optional[MoECheckpintIO] = None, ) -> None: assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -188,10 +190,26 @@ def __init__( if enable_sequence_parallelism: assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" - + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + dist.get_world_size() % (tp_size * pp_size * ep_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" + self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=self.real_dp_size, + fixed_ep_size=ep_size, + fixed_pp_size=pp_size, + use_ep_inside=use_ep_inside, + ) self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.ep_size = ep_size + self.moe_info = MOE_MANAGER.get_info(0)[1] self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -200,6 +218,7 @@ def __init__( self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.checkpoint_io = checkpoint_io # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) @@ -323,7 +342,10 @@ def seed_worker(worker_id): ) def get_checkpoint_io(self) -> MoECheckpintIO: - self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + if self.checkpoint_io is None: + self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + else: + self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def configure( diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 780117598e18..71232421586d 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -9,7 +9,7 @@ from colossalai.interface import ModelWrapper -from .utils import has_index_file +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file __all__ = ["CheckpointIO"] @@ -90,7 +90,15 @@ def load_model( if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: - self.load_unsharded_model(model, checkpoint, strict) + path = Path(checkpoint, SAFE_WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + path = Path(checkpoint, WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + self.load_unsharded_model(model, checkpoint, strict) return origin_model diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b7900bc0f217..36df30335dd7 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,7 +1,7 @@ import copy -from functools import reduce import logging import os +from functools import reduce from pathlib import Path from shutil import rmtree from typing import Dict, Iterator, Optional, OrderedDict, Tuple @@ -14,6 +14,7 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -445,7 +446,11 @@ def save_sharded_optimizer( # Store param groups. index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) # Store index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) @@ -504,7 +509,11 @@ def save_sharded_optimizer( # Store param groups. final_index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(optimizer.param_info, group_file_path) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) @@ -713,12 +722,16 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, tp_group=self.tp_group, use_zero=self.use_zero, inplace=False, - device=torch.device("cuda"), + device=get_current_device(), ) if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. - state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states} + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + state_dict = {"param_groups": param_groups, "state": local_states} if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) else: @@ -729,7 +742,11 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, # Only the master rank do the saving. if self.coordinator.is_master(): - state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()} + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + state_dict = {"param_groups": param_groups, "state": dict()} for _states in states_list: state_dict["state"].update(_states) save_state_dict(state_dict, checkpoint, use_safetensors=False) @@ -838,7 +855,7 @@ def gather_from_sharded_optimizer_state( if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. if use_zero: - v = v.cuda() + v = v.to(get_current_device()) gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] dist.all_gather(gather_tensor, v, group=dp_group) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 25076b742c26..aaeaad3828f5 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -6,12 +6,12 @@ from pathlib import Path from typing import Dict, Union -import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed +from colossalai.utils import set_seed def launch( @@ -47,17 +47,18 @@ def launch( if rank == 0: warnings.warn("`config` is deprecated and will be removed soon.") - if IS_NPU_AVAILABLE and backend == "nccl": - backend = "hccl" + cur_accelerator = get_accelerator() + + backend = cur_accelerator.communication_backend # init default process group init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device - if torch.cuda.is_available() or IS_NPU_AVAILABLE: - # if local rank is not given, calculate automatically - set_device(local_rank) + # if local rank is not given, calculate automatically + if cur_accelerator.support_set_device: + cur_accelerator.set_device(local_rank) set_seed(seed) diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 8933fc0a3c2f..e69de29bb2d1 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,7 +0,0 @@ -from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention - -__all__ = [ - "LayerNorm", - "FusedScaleMaskSoftmax", - "MultiHeadAttention", -] diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu deleted file mode 100644 index 2b1b366b1c02..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu +++ /dev/null @@ -1,63 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - if (x_column >= x_width) return; - //if (x_row >= x_height) return; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh deleted file mode 100644 index 0364e38c4779..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh deleted file mode 100644 index c5258813e147..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu deleted file mode 100644 index 4416027c8387..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu +++ /dev/null @@ -1,75 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state_size(_temp_state_size), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state_size, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh deleted file mode 100644 index 0bf2057c665c..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - int temp_state_size; - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t alt_stream_2; - cudaStream_t alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh deleted file mode 100644 index 5cd2e8553ef6..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh +++ /dev/null @@ -1,49 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Workaround for hipify_python using rocblas instead of hipblas. -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} - -#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp deleted file mode 100644 index bcc0e43901de..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include -#include -#include -#include -#include -#include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "q4_matrix.cuh" -#include "q4_matmul.cuh" -#include "column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} - -// Some decluttering macros - -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} - - -// Release all unmanaged objects allocated by the extension - -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} - - -// Prepare buffers for forward pass - -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} - - -// Create Q4Matrix, return handle - -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} - - -// Matmul half @ quant -> half - -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } -} - - -// Remap columns in half tensor - -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh deleted file mode 100644 index 2fd5ab0b36cd..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale - -__device__ __forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu deleted file mode 100644 index f47daeb0e877..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu +++ /dev/null @@ -1,260 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include "util.cuh" -#include "matrix.cuh" -#include "cu_compat.cuh" -#include "cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void (*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; - const float beta = no_zero ? 1.0f : 0.0f; - cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, - x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -#else - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -#endif -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh deleted file mode 100644 index 09f3e1a63362..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "tuning.h" - -// Workaround for hipify_python using rocblas instead of hipblas. -#if defined(USE_ROCM) -#include -#define rocblas_handle hipblasHandle_t -#endif - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero = false -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu deleted file mode 100644 index 9c61143f565e..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matrix.cuh" -#include -#include "util.cuh" -#include "matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks - ( - (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), - height / 8, - 1 - ); - - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - if (column >= width) return; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh deleted file mode 100644 index 50cb72a41518..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h deleted file mode 100644 index e413b8a96c11..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h +++ /dev/null @@ -1,12 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning { - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh deleted file mode 100644 index 7b397573214b..000000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu deleted file mode 100644 index 58d26235a9cc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include "block_reduce.h" -#include "cuda_util.h" -#include "kernels.h" -#include "ls_cub.cuh" - -ls::cub::CachingDeviceAllocator g_allocator(true); - -template -__global__ void ls_cross_entropy_fw_kernel( - const T *__restrict__ inputs, const int *__restrict__ targets, - float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[2] = {0.f, 0.f}; // logit and logit exp - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - if (threadIdx.x == 0) { - nll_loss_outputs[blockIdx.x] = 0.f; - outputs[blockIdx.x] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += logit; - sum_logits[1] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_logit; - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_logit = sum_logits[0]; - s_sum_exp = sum_logits[1]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - if (threadIdx.x == 0) { - // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) - float nll_loss = logf(s_sum_exp) - - static_cast(inputs[block_start + target_tid]) + - s_max_input; - nll_loss_outputs[blockIdx.x] = nll_loss; - float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; - outputs[blockIdx.x] = - (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; - } -} - -template -__global__ void ls_cross_entropy_bw_kernel( - const float *__restrict__ grad_outputs, const T *__restrict__ inputs, - const int *__restrict__ targets, T *__restrict__ grad_inputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[1] = {0.f}; - const float grad_out = static_cast(grad_outputs[0]); - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - for (int i = left_idx; i < right_idx; i += blockDim.x) { - grad_inputs[i] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_exp = sum_logits[0]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - float nll_weight = 1.0 - epsilon - eps_i; - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; - float grad = 0; - grad += (vocab_size * prob - 1) * eps_i; - grad += prob * nll_weight; - if ((i - block_start) == target_tid) { - grad -= nll_weight; - } - grad_inputs[i] = grad_out * grad; - } -} - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - float *nll_loss_buffer = loss_buffer + grid_dim; - ls_cross_entropy_fw_kernel<<>>( - inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, - epsilon, vocab_size); - - int num_items = grid_dim; - void *d_temp_storage = NULL; - size_t temp_storage_bytes = 0; - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR( - g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - nll_loss_buffer, nll_loss_ptr, - num_items, stream)); - CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); -} - -template void launch_cross_entropy_fw( - const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_fw<__half>( - const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - ls_cross_entropy_bw_kernel<<>>( - grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, - epsilon, vocab_size); -} - -template void launch_cross_entropy_bw( - const float *grad_outputs_ptr, const float *inputs_ptr, - const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_bw<__half>( - const float *grad_outputs_ptr, const __half *inputs_ptr, - const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu deleted file mode 100644 index 09f34763f9b2..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = - cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, - (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, - (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, - (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmEx( - handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, - CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, - CUDA_R_16F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " - "error: %d) \n", - batch, m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - - return 0; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu deleted file mode 100644 index e5ac17308640..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "cuda_util.h" - -/* GPU function guard */ -std::string _cudaGetErrorString(cudaError_t error) { - return cudaGetErrorString(error); -} - -std::string _cudaGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "CUBLAS_UNKNOW"; -} - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line) { - if (result) { - throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + - std::to_string(line) + - "): " + (_cudaGetErrorString(result)) + "\n"); - } -} - -template void check_gpu_error(cudaError_t result, - char const *const func, - const char *const file, - int const line); -template void check_gpu_error(cublasStatus_t result, - char const *const func, - const char *const file, - int const line); - -template -void print_vec(const T *outv, std::string outn, int num_output_ele) { - std::cout << outn << ": "; - std::vector hout(num_output_ele, (T)0); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << hout[i] << ", "; - } - std::cout << std::endl; -} - -template <> -void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele) { - std::cout << outn << ": "; - std::vector<__half> hout(num_output_ele, (__half)0.f); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << __half2float(hout[i]) << ", "; - } - std::cout << std::endl; -} - -template void print_vec(const float *outv, std::string outn, - int num_output_ele); - -template void print_vec(const int *outv, std::string outn, - int num_output_ele); - -template void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele); - -template -T *cuda_malloc(size_t ele_num) { - size_t byte_size = ele_num * sizeof(T); - T *pdata = nullptr; - CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); - return pdata; -} - -template float *cuda_malloc(size_t ele_num); - -template __half *cuda_malloc<__half>(size_t ele_num); - -template uint8_t *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata) { - if (pdata != nullptr) { - cudaFree(pdata); - } -} - -template -struct _isnan { - __device__ bool operator()(T a) const { return isnan(a); } -}; - -template <> -struct _isnan<__half> { - __device__ bool operator()(const __half a) const { return __hisnan(a); } -}; - -template -struct _isinf { - __device__ bool operator()(T a) const { return isinf(a); } -}; - -template <> -struct _isinf<__half> { - __device__ bool operator()(const __half a) const { return __hisinf(a); } -}; - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream) { - // check_nan_inf = 0 for checking nan - // check_nan_inf = 1 for checking inf - bool res = false; - std::string msg = file + "(" + std::to_string(line) + "): "; - if (check_nan_inf) { - msg += "nan."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isnan(), false, - thrust::logical_or()); - } else { - msg += "inf."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isinf(), false, - thrust::logical_or()); - } - if (res) { - throw std::runtime_error(msg); - } - std::cout << msg << " [check pass]." << std::endl; -} - -template void check_nan_inf(const float *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); - -template void check_nan_inf<__half>(const __half *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu deleted file mode 100644 index ce0b017f12e1..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ /dev/null @@ -1,1002 +0,0 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu deleted file mode 100644 index 625b02cd25d9..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h deleted file mode 100644 index f7d75f38cc2b..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -#include -#include - -#include "cuda_util.h" - -class Context { - public: - Context() : _stream(nullptr) { - CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); - } - - virtual ~Context() {} - - static Context &Instance() { - static Context _ctx; - return _ctx; - } - - void set_stream(cudaStream_t stream) { - _stream = stream; - CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); - } - - cudaStream_t get_stream() { return _stream; } - - cublasHandle_t get_cublashandle() { return _cublasHandle; } - - private: - cudaStream_t _stream; - cublasHandle_t _cublasHandle; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h deleted file mode 100644 index f4e9befc6588..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "cuda_util.h" - -template -class CrossEntropyLayer { - public: - CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); - - virtual ~CrossEntropyLayer(); - - void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr); - - void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr); - - void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - _loss_buffer = cuda_malloc(_max_batch_tokens * 2); - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_loss_buffer); - } - - const int _padding_idx; - const float _epsilon; - const int _max_batch_tokens; - - size_t _batch_size; - size_t _seq_len; - size_t _vocab_size; - - float *_loss_buffer; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h deleted file mode 100644 index 90255152b2c8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_strided_batched_gemm( - cublasHandle_t handle, int m, int n, int k, const float *alpha, - const float *beta, const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, - int stride_C, int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h deleted file mode 100644 index 1595257be0f5..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line); - -#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) - -template -void print_vec(const T *outv, std::string outn, int num_output_ele); - -template -T *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata); - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream); - -#define CHECK_NAN_INF(ptr, size, stream) \ - check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ - check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h deleted file mode 100644 index 025fbf3f8f15..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h deleted file mode 100644 index 8186da1eed5f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include -#include -#include - -#include - -#include "cublas_wrappers.h" -#include "kernels.h" - -template -class FeedForward { - public: - struct Config { - int outputSize; - int inputSize; - std::array gemm_algos; - Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), - gemm_algos(std::array({99, 99, 99})) {} - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, const T *input_ptr, const T *weights, T *out, - cublasHandle_t &_cublasHandle) { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, - bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, - out, cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, const T *out_grad, const T *input_ptr, - const T *weights, T *weights_grad, T *bias_grad, - cublasHandle_t &_cublasHandle, cudaStream_t &stream, - T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, - bool compute_bias = true) { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, - config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, - weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, - bsz, config_.outputSize, &alpha, &beta, weights, out_grad, - inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); - if (compute_bias) { - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, - config_.outputSize, stream); - } - } - - void reset_size(int outputSize, int inputSize) { - config_.outputSize = outputSize; - config_.inputSize = inputSize; - } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h deleted file mode 100644 index 735e1363cc46..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ /dev/null @@ -1,275 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#define MAX_THREADS 1024 -#define WARP_SIZE 32 - -enum class ActivationType { kRelu, kGelu }; - -void launch_curand_init(int total_count, int dim, cudaStream_t stream); - -template -void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int batch_size, - int hidden_dim, cudaStream_t stream); - -template -void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, const T *vars, const T *means, int batch, - int hidden_dim, cudaStream_t stream[2]); - -template -void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, - int from_len, int to_len, bool mask_future, - cudaStream_t stream); - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream); - -// [b, s, h] -> [b, nh, s, ad] -template -void launch_transform_0213(T *output, const T *vals, int batch_size, - int seq_length, int hidden_dim, int nhead, - cudaStream_t stream); - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template -void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, - int dim_0, int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream); - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template -void launch_transform4d_0213(T *output, const T *vals, int batch_size, - int seq_len, int hidden_dim, int nhead, - int trans_count, cudaStream_t stream); - -template -void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, - float ratio, cudaStream_t stream, bool backward = false); - -template -void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, const T *residual, - int total_count, int dim, float ratio, - cudaStream_t stream); - -template -void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, int total_count, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, - cudaStream_t stream); - -void launch_param_update(const float *input, __half *output, int size, - cudaStream_t stream); - -template -void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, - int sz2, int sz1_1, int sz1_2, cudaStream_t stream); - -template -void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, - int seq_len, int hidden_size, cudaStream_t &stream); - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_lookup_scale_pos_dropout( - T *output, const int *input, const T *embeddings, const T *pos_embeddings, - uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); - -template -void launch_d_lookup_scale_pos_dropout( - T *grad_embeddings, const T *grad_output, const int *input, - const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); - -/* Convert 2-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { - return id1 * dim2 + id2; -} - -/* Convert 3-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, - int dim2, int dim3) { - return id1 * dim2 * dim3 + id2 * dim3 + id3; -} - -/* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, - int id4, int dim2, int dim3, - int dim4) { - // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; - int res = id4; - - int ld = dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 5-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, - int id4, int id5, int dim2, - int dim3, int dim4, - int dim5) { - // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + - // id4*dim5 + dim5; - int res = id5; - - int ld = dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 6-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, - int id4, int id5, int id6, - int dim2, int dim3, int dim4, - int dim5, int dim6) { - // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + - // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; - int res = id6; - - int ld = dim6; - res += id5 * ld; - - ld *= dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_6dim( - int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, - int *id1, int *id2, int *id3, int *id4, int *id5) { - *id5 = src % dim5; - src /= dim5; - - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, - int dim2, int dim3, - int dim4, int *id0, - int *id1, int *id2, - int *id3, int *id4) { - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 4-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, - int dim2, int dim3, - int *id0, int *id1, - int *id2, int *id3) { - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, - int dim2, int *id0, - int *id1, int *id2) { - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 2-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, - int *id0, int *id1) { - *id1 = src % dim1; - *id0 = src / dim1; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh deleted file mode 100644 index 4f65e7b54ba1..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh +++ /dev/null @@ -1,12 +0,0 @@ -// copied from https://github.com/dmlc/dgl/pull/2758 -#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ -#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ - -#define CUB_NS_PREFIX namespace ls { -#define CUB_NS_POSTFIX } -#include "cub/cub.cuh" -#include "cub/util_allocator.cuh" -#undef CUB_NS_POSTFIX -#undef CUB_NS_PREFIX - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h deleted file mode 100644 index a7767e187ffc..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Normalize_Layer { - public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - - private: - Config config_; - T *vars_; - T *means_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h deleted file mode 100644 index b917abaf0336..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h deleted file mode 100644 index d386650e8235..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include - -#include - -#include "cublas_wrappers.h" - -template -class StridedBatchGemm { - public: - struct Config { - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(float param_alpha, float param_beta, cublasOperation_t opA, - cublasOperation_t opB) - : alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(std::array({99, 99, 99})) {} - void SetConfig(int mm, int nn, int kk) { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config &config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, - cublasHandle_t handle) { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm( - handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, - _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, - stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void Backward(int bsz, const T *d_output, const T *_buffer_a, - const T *_buffer_b, cublasHandle_t handle, - T *inpGradA = nullptr, T *inpGradB = nullptr) { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = - (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm( - handle, mb, kb, _config.n, &_config.alpha, &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, - CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = - (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm( - handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, - _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, - stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - - private: - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu deleted file mode 100644 index e2f1869b165e..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ /dev/null @@ -1,1172 +0,0 @@ -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template -__forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, - const T *out_grad, const T *inp_or_out, - const T *gamma, const T *betta, - const T *vars, const T *means, int rows, - int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu deleted file mode 100644 index 3862a699d3c3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ /dev/null @@ -1,365 +0,0 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu deleted file mode 100644 index 04de3c092ee0..000000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ /dev/null @@ -1,314 +0,0 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void bias_add_transform_20314(float *output, - const float *input, - const float *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp deleted file mode 100644 index d08f3dbc74d8..000000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ /dev/null @@ -1,406 +0,0 @@ -#include "multihead_attention_1d.h" - -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif -#include - -#include "context.h" -#include "kernels.h" - -template -MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_size, - int num_heads, - float attn_prob_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm) - : _layer_id(layer_id), - _max_batch_tokens(max_batch_tokens), - _max_seq_len(max_seq_len), - _hidden_size(hidden_size), - _heads(num_heads), - _training(true), - _pre_or_postLayerNorm(pre_or_postLayerNorm), - _qkv_linear( - typename FeedForward::Config(3 * hidden_size, hidden_size)), - _attn_out_linear( - typename FeedForward::Config(hidden_size, hidden_size)), - _attn_ln(typename Normalize_Layer::Config(hidden_size, false), - _max_batch_tokens), - _softmax(typename Softmax::Config(num_heads)), - _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), - _max_batch_tokens * _heads * _max_seq_len), - _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), - _max_batch_tokens * _hidden_size), - _attn_scores(typename StridedBatchGemm::Config( - (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, - CUBLAS_OP_N)), - _attn_context(typename StridedBatchGemm::Config( - T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { - assert(_hidden_size % _heads == 0); -} - -template -MultiHeadAttention::~MultiHeadAttention() { - free_mem_buffer(); -} - -template -void MultiHeadAttention::attn_layer_fw(const T *input_ptr, - const T *input_mask_ptr, - T *output_ptr, T *buffer) { - T *q_tf_ptr = _qkv_ptr; - T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - - if (_pre_or_postLayerNorm) { - _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, - _cublasHandle); - - launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, - _batch_size, _seq_len, 3, _heads / pg_size, - _hidden_size / _heads, _stream); - - // attention scores, q*k - _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle); - - // Softmax + Mask - _softmax.reset_size(_heads / pg_size); - _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, - _seq_len, _stream, true); - - // attn prob dropout. - _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - // attention context, score * v - _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, - _cublasHandle); - - // [b, nh, s, ad] -> [b, s, nh, ad] - launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, - _hidden_size / pg_size, _heads / pg_size, 1, - _stream); - - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, - output_ptr, _cublasHandle); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto output_tensor = torch::from_blob( - output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {output_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, - _attn_ob_ptr, _batch_tokens, _hidden_size, - _stream); - if (!_pre_or_postLayerNorm) { - // in-place ln since ln-input will not be used in post-ln mode - _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } -} - -template -void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, - T *out_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim - - attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); -} - -template -void MultiHeadAttention::attn_layer_bw(const T *input_ptr, - const T *input_mask_ptr, - const T *output_ptr, - const T *grad_output_ptr, - T *grad_input_ptr, T *buffer) { - cudaStream_t streams[2] = {_stream, _stream}; - - const T *q_tf_ptr = _qkv_ptr; - const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - // batch_dim = batch_size * seq_len * hidden_size - // buffer size: batch_dim * 3 + max(batch_dim * 3, - // batch_size * head_num * seq_len * seq_len) - T *grad_residual_ptr = buffer; - buffer += _batch_dim; - - T *grad_input_buf_ptr = buffer; // batch_dim - T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 - buffer += 3 * _batch_dim / pg_size; - - T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 - T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len - // buffer += max(3 * _batch_dim, - // batch_size * head_num * seq_len * seq_len); - - if (_pre_or_postLayerNorm) { - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_output_ptr, _batch_tokens, - _hidden_size, _stream); - } else { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, - grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, - _attn_nb_ptr, _batch_tokens, streams); - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_residual_ptr, _batch_tokens, - _hidden_size, _stream); - } - - // bw of output project - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, - _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - false); - launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - _stream); - - // bw of score * v - _attn_context.Backward( - _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, - grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); - - _attn_prob_dropout.d_dropout(grad_softmax_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - _softmax.reset_size(_heads / pg_size); - _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, - _seq_len, _stream); - - // bw of q * k - _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, - grad_qkv_5d_ptr); - - // [3, b, nh, s, ad] -> [b, s, 3, h] - launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - 3, _stream); - - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, - _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - true); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto grad_input_tensor = - torch::from_blob(grad_input_buf_ptr, - {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {grad_input_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - if (_pre_or_postLayerNorm) { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, - grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, - _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); - } else { - // FIXME later - launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, - _batch_size, _seq_len, _hidden_size, _stream); - } -} - -template -void MultiHeadAttention::Backward(const T *grad_output_ptr, - const T *input_ptr, const T *output_ptr, - const T *input_mask_ptr, - T *grad_input_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *buffer = _shared_mem_ptr; - - /* - buffer size needed by attn bw: - 4 * _batch_dim + max(3 * _batch_dim, - _batch_size * _head_num * _seq_len * _seq_len); - */ - attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, - grad_input_ptr, buffer); -} - -template -void MultiHeadAttention::SetTrainingMode(bool training) { - // Dropout will be skipped when not in training model. - _attn_prob_dropout.SetTrainingMode(training); - _attn_dropout.SetTrainingMode(training); -} - -template -T *MultiHeadAttention::_shared_mem_ptr = nullptr; - -template class MultiHeadAttention; -template class MultiHeadAttention<__half>; - -// x is torch::Tensor -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -static std::unordered_map> s_multihead_attention; - -template -int create_multihead_attention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_dim, int num_heads, - float attn_prob_dropout_ratio, - float hidden_dropout_ratio, - bool pre_or_postLayerNorm, - c10::intrusive_ptr pg_) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - Context::Instance().set_stream(stream); - auto layer = std::make_shared>( - layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, - attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); - - layer->SetPG(pg_); - - s_multihead_attention[layer_id] = layer; - - std::string dtype = (std::is_same::value) ? "half" : "float"; - - return 0; -} - -template -std::vector multihead_attention_fw( - int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, - const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, - const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, - const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, - bool training_mode, bool prelayernorm) { - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - const T *input_ptr = (const T *)input.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - auto output = torch::empty_like(input); - T *out_ptr = (T *)output.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(input.size(0), input.size(1)); - layer->SetTrainingMode(training_mode); - - layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); - layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); - layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); - layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); - layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); - layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); - - layer->Forward(input_ptr, input_mask_ptr, out_ptr); - - return {output}; -} - -template -std::vector multihead_attention_bw( - int layer_id, const torch::Tensor &grad_dec_output, - const torch::Tensor &output, const torch::Tensor &input, - const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, - const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, - const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, - const torch::Tensor &norm_bias) { - auto g_output = grad_dec_output.contiguous(); - CHECK_INPUT(g_output); - CHECK_INPUT(output); - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - auto grad_input = torch::empty_like(input); - auto grad_in_proj_weight = torch::empty_like(in_proj_weight); - auto grad_in_proj_bias = torch::empty_like(in_proj_bias); - auto grad_out_proj_weight = torch::empty_like(out_proj_weight); - auto grad_out_proj_bias = torch::empty_like(out_proj_bias); - auto grad_norm_weight = torch::empty_like(norm_weight); - auto grad_norm_bias = torch::empty_like(norm_bias); - - // inputs. - const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); - const T *input_ptr = (const T *)input.data_ptr(); - const T *output_ptr = (const T *)output.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - // outputs. - T *grad_input_ptr = (T *)grad_input.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); - - layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); - layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); - layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); - layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); - layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); - layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); - - layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, - grad_input_ptr); - - return {grad_input, grad_in_proj_weight, grad_in_proj_bias, - grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, - grad_norm_bias}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multihead_attention_fw_fp32", &multihead_attention_fw, - "Multi-head Attention forward with fp32 (CUDA)"); - m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, - "Multi-head Attention forward with fp16 (CUDA)"); - m.def("multihead_attention_bw_fp32", &multihead_attention_bw, - "Multi-head Attention backward with fp32 (CUDA)"); - m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, - "Multi-head Attention backward with fp16 (CUDA)"); - m.def("create_multihead_attention_fp32", &create_multihead_attention, - "Create Multi-head Attention with fp32 (CUDA)"); - m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, - "Create Multi-head Attention with fp16 (CUDA)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h deleted file mode 100644 index 6505eb31fb9f..000000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif - -#include -#include - -#include "cuda_util.h" -#include "dropout.h" -#include "feed_forward.h" -#include "normalize_layer.h" -#include "softmax.h" -#include "strided_batch_gemm.h" - -template -class MultiHeadAttention { - public: - MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, - int hidden_size, int num_heads, float attn_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm); - - virtual ~MultiHeadAttention(); - - void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); - - void Backward(const T *grad_output_ptr, const T *input_ptr, - const T *output_ptr, const T *input_mask_ptr, - T *grad_input_ptr); - - void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, - T *buffer); - - void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, - const T *output_ptr, const T *grad_output_ptr, - T *grad_input_attn_layer_bwptr, T *buffer); - - void set_cur_batch_shape(int batch_size, int seq_len) { - _batch_size = batch_size; - _seq_len = seq_len; - _batch_tokens = batch_size * seq_len; - _batch_heads = batch_size * _heads / pg_size; - _batch_dim = _batch_tokens * _hidden_size; - _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); - _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); - } - - void SetTrainingMode(bool training); - inline bool IsTrainingMode() const { return _training; } - - void SetPG(c10::intrusive_ptr pg_) { - pg = pg_; - pg_size = 1; - if (pg != c10::detail::UniqueVoidPtr()) { - pg_size = pg->getSize(); - } - allocate_mem_buffer(); - } - - // weights ptr - const T *_attn_qkvw_ptr; - const T *_attn_qkvb_ptr; - const T *_attn_ow_ptr; - const T *_attn_ob_ptr; - const T *_attn_nw_ptr; - const T *_attn_nb_ptr; - - // grads ptr - T *_grad_attn_qkvw_ptr; - T *_grad_attn_qkvb_ptr; - T *_grad_attn_ow_ptr; - T *_grad_attn_ob_ptr; - T *_grad_attn_nw_ptr; - T *_grad_attn_nb_ptr; - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - if (_pre_or_postLayerNorm) { - _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - } else { - _gemmQKV_inp_ptr = nullptr; - } - - _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); - _soft_out_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _ctx_bufB_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - - // buffer size needed by attn bw - size_t smem_size = - 4 * _max_batch_tokens * _hidden_size / pg_size + - std::max(3 * _max_batch_tokens * _hidden_size / pg_size, - _max_batch_tokens * _heads / pg_size * _max_seq_len); - - if (!_shared_mem_ptr) { - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = cuda_malloc(smem_size); - } - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_gemmQKV_inp_ptr); - cuda_free(_qkv_ptr); - cuda_free(_soft_out_ptr); - cuda_free(_ctx_bufB_ptr); - cuda_free(_attn_o_inp_ptr); - - // free shared gpu memory between layers - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = nullptr; - } - - // const parameter between batch - const size_t _layer_id; - const size_t _hidden_size; - const size_t _heads; - const size_t _max_batch_tokens; - const size_t _max_seq_len; - const bool _pre_or_postLayerNorm; - // dynamic parameter between batch - size_t _batch_size; - size_t _seq_len; - size_t _batch_tokens; - size_t _batch_heads; - size_t _batch_dim; - bool _training; - - cublasHandle_t _cublasHandle; - cudaStream_t _stream; - - // layers - FeedForward _qkv_linear; - FeedForward _attn_out_linear; - Normalize_Layer _attn_ln; - Softmax _softmax; - Dropout _attn_prob_dropout; - Dropout _attn_dropout; - StridedBatchGemm _attn_scores; - StridedBatchGemm _attn_context; - - // local GPU memory - T *_gemmQKV_inp_ptr; - T *_qkv_ptr; - T *_soft_out_ptr; - T *_ctx_bufB_ptr; - T *_attn_o_inp_ptr; - // shared GPU memory between layer - static T *_shared_mem_ptr; - - c10::intrusive_ptr pg; - int pg_size; -}; diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp deleted file mode 100644 index 8444272940b4..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -#include "linear.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, - "Linear SiLU (INT8)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu deleted file mode 100644 index a30d02a4cf42..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu +++ /dev/null @@ -1,162 +0,0 @@ -// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu - -#include "linear.h" -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - -#if CUDA_ARCH >= 800 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -#elif CUDA_ARCH >= 750 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - EpilogueOp>; -#elif CUDA_ARCH >= 700 - #define USE_TORCH_SILU - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; -#else - #error "Unsupported cuda arch" -#endif - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } -#ifdef USE_TORCH_SILU -#undef USE_TORCH_SILU - out = torch::silu(out); -#endif - return out; -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h deleted file mode 100644 index b62a27f3f8f3..000000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include - -#include -#include - -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -); diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py deleted file mode 100644 index cad36e598d14..000000000000 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .mha import ColoAttention - -__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py deleted file mode 100644 index 9ee83915b1b4..000000000000 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ /dev/null @@ -1,80 +0,0 @@ -import warnings -from typing import Optional - -import torch - - -def is_ampere_or_better_gpu(): - if torch.cuda.is_available(): - device = torch.device("cuda") - properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer - return True - return False - - -# "Check Ampere GPUs or newer" -HAS_FLASH_ATTN = False -if is_ampere_or_better_gpu(): - HAS_FLASH_ATTN = True -else: - warnings.warn("FlashAttention only supports Ampere GPUs or newer.") - HAS_FLASH_ATTN = False -try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - HAS_FLASH_ATTN = True -except ImportError: - warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") - HAS_FLASH_ATTN = False - -if HAS_FLASH_ATTN: - pass - - from .utils import SeqLenInfo - - def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( - q, - k, - v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, - ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py deleted file mode 100644 index 649e74d61bab..000000000000 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ /dev/null @@ -1,70 +0,0 @@ -import warnings - -HAS_MEM_EFF_ATTN = False -try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - - HAS_MEM_EFF_ATTN = True -except ImportError: - warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") - HAS_MEM_EFF_ATTN = False - -if HAS_MEM_EFF_ATTN: - """ - A general attention module using the flash attention kernels from xformers: - https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha - """ - from typing import Optional - - import torch - - from .utils import SeqLenInfo - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py deleted file mode 100644 index 1c778439d33f..000000000000 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ /dev/null @@ -1,113 +0,0 @@ -import math -from typing import Optional - -import torch -from einops import rearrange - -from ..scaled_softmax import AttnMaskType -from .flash_attn_2 import HAS_FLASH_ATTN -from .mem_eff_attn import HAS_MEM_EFF_ATTN -from .utils import Repad, SeqLenInfo, Unpad - -if HAS_FLASH_ATTN: - from .flash_attn_2 import flash_attention -if HAS_MEM_EFF_ATTN: - from .mem_eff_attn import mem_eff_attention - - -class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): - super().__init__() - assert ( - embed_dim % num_heads == 0 - ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." - if scale is not None: - self.scale = scale - else: - self.scale = 1 / math.sqrt(embed_dim // num_heads) - self.dropout = dropout - - if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN: - raise Exception("flash attention can not support!") - - @staticmethod - def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return Unpad.apply(tensor, indices) - - @staticmethod - def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: - return Repad.apply(tensor, indices, batch_size, seq_len) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None, - ): - attn = None - if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: - attn = flash_attention - else: - attn = mem_eff_attention - - padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 - causal = attn_mask_type is not None and attn_mask_type.value > 1 - - batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] - # unpad - seq_len_info_q = None - seq_len_info_kv = None - if padded: - # bert style, unpad process - assert ( - attn_mask is not None - ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, ( - "attention mask is supposed to have shape (batch_size, seq_len), " - + f"but got {attn_mask.dim()} dimensions." - ) - - # bert style - if tgt_len == src_len: - seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query, key, value = self.unpad( - torch.stack([query, key, value], dim=2), seq_len_info_q.indices - ).unbind(dim=1) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - seq_len_info_kv = seq_len_info_q - else: - seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) - seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) - if batch_size > 1: - query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( - dim=1 - ) - else: - query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - - out = attn( - query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded, - ) - - # repad - if padded: - if batch_size > 1: - out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - - out = rearrange(out, "b s h d -> b s (h d)") - return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py deleted file mode 100644 index 5f01e3ef327d..000000000000 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.utils.device import get_current_device - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py deleted file mode 100644 index 87afc1862847..000000000000 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ /dev/null @@ -1,338 +0,0 @@ -import math -from dataclasses import dataclass - -import torch -from torch import nn -from torch.autograd import Function - - -def check_config(config): - if config.hidden_size % config.nhead != 0: - raise Exception("hidden_size % nhead != 0") - - factor = 8 if config.fp16 else 4 - upbound = factor * 1024 * 4 - if config.hidden_size > upbound: - # as required by ln backward kernel currently - raise Exception(f"hidden_size > {upbound}") - - head_dim = config.hidden_size // config.nhead - if head_dim % factor != 0: - # as required by reshape kernel - raise Exception(f"head_dim({head_dim}) % {factor} != 0") - - -def calc_offset(sizes): - offsets = [0] - tmp = 0 - for x in sizes: - tmp += x - offsets.append(tmp) - return offsets - - -colossal_multihead_attention = None - - -@dataclass -class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision - - -class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward( - ctx, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config, - ): - cuda_module = colossal_multihead_attention - forward_func = ( - cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 - ) - if config.fp16: - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - - (output,) = forward_func( - config.layer_id, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config.training, - config.norm_first, - ) - - if config.is_grad_enabled and config.training: - ctx.save_for_backward( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - ctx.config = config - return output - - @staticmethod - def backward(ctx, grad_output): - assert ctx.config.training - - cuda_module = colossal_multihead_attention - backward_func = ( - cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 - ) - - ( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) = ctx.saved_tensors - - grad_input = None - grad_in_proj_weight = None - grad_in_proj_bias = None - grad_out_proj_weight = None - grad_out_proj_bias = None - grad_norm_weight = None - grad_norm_bias = None - - if ctx.config.fp16: - grad_output = grad_output.to(torch.half) - output = output.to(torch.half) - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - ( - grad_input, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - ) = backward_func( - ctx.config.layer_id, - grad_output, - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - - return ( - grad_input, - None, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - None, - ) - - -class MultiHeadAttention(nn.Module): - """Initialize the MultiHeadAttention. - - Static variable: - - layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, - e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. - - Arguments: - hidden_size: Total dimension of hidden_size. - nhead: Number of parallel attention heads. - batch_size: Batch Size for one forward - max_seq_len: Max length of input sequence - dropout: Dropout probability - norm_first: perform LayerNorms before attention - """ - - layer_id = 0 - - def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): - super(MultiHeadAttention, self).__init__() - - self.config = Config( - batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 - ) - check_config(self.config) - self.pg = pg - self.pg_size = 1 - if self.pg: - self.pg_size = pg.size() - self.config.layer_id = MultiHeadAttention.layer_id - MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 - - # Load cuda modules if needed - global colossal_multihead_attention - if colossal_multihead_attention is None: - from colossalai.kernel.op_builder import MultiHeadAttnBuilder - - multihead_attention = MultiHeadAttnBuilder().load() - colossal_multihead_attention = multihead_attention - - # create the layer in cuda kernels. - cuda_module = colossal_multihead_attention - create_layer_func = ( - cuda_module.create_multihead_attention_fp16 - if self.config.fp16 - else cuda_module.create_multihead_attention_fp32 - ) - - create_layer_func( - self.config.layer_id, - self.config.max_batch_tokens, - self.config.max_seq_len, - self.config.hidden_size, - self.config.nhead, - self.config.attn_prob_dropout_ratio, - self.config.hidden_dropout_ratio, - self.config.norm_first, - self.pg, - ) - - hs = self.config.hidden_size - - self.precision = torch.float32 - if self.config.fp16: - self.precision = torch.half - - self.hs_per_rank = int(hs / self.pg_size) - - self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) - self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) - self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) - self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) - self.norm_weight = nn.Parameter(torch.Tensor(hs)) - self.norm_bias = nn.Parameter(torch.Tensor(hs)) - - self.reset_parameters() - torch.cuda.empty_cache() - - def calc_bound(self, w): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) - bound = 1.0 / math.sqrt(fan_in) - return bound - - def reset_parameters(self): - hs = self.config.hidden_size - - nn.init.zeros_(self.out_proj_bias) - - nn.init.ones_(self.norm_weight) - nn.init.zeros_(self.norm_bias) - - if self.pg_size > 1: - rank_in_pg = torch.distributed.get_rank(self.pg) - attn_qkvw_global = torch.empty(hs * 3, hs) - attn_qkvb_global = torch.empty(hs * 3) - nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw_global) - nn.init.uniform_(attn_qkvb_global, -bound, bound) - - attn_qkvw_global = attn_qkvw_global.cuda() - attn_qkvb_global = attn_qkvb_global.cuda() - torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) - torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) - attn_qkvw_global = attn_qkvw_global.cpu() - attn_qkvb_global = attn_qkvb_global.cpu() - - with torch.no_grad(): - self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : - ] - ) - self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) - ] - ) - - attn_ow_global = torch.empty(hs, hs) - nn.init.xavier_uniform_(attn_ow_global, 1.0) - attn_ow_global = attn_ow_global.cuda() - torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) - attn_ow_global = attn_ow_global.cpu() - with torch.no_grad(): - self.out_proj_weight.copy_( - attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] - ) - - else: - attn_qkvw = self.in_proj_weight.view(-1, hs) - nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw) - nn.init.uniform_(self.in_proj_bias, -bound, bound) - - nn.init.xavier_uniform_(self.out_proj_weight, 1.0) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) - return destination - - def forward(self, hidden_states, encoder_padding_mask): - self.config.training = self.training - self.config.is_grad_enabled = torch.is_grad_enabled() - hidden_states = hidden_states.contiguous() - encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() - - bs, sl, dim = hidden_states.size() - if bs * sl > self.config.max_batch_tokens: - raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") - if sl > self.config.max_seq_len: - raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") - if len(encoder_padding_mask.size()) == 1: - assert bs == 1 and sl == encoder_padding_mask.size(0) - else: - assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - - output = MultiHeadAttention1DFunc.apply( - hidden_states, - encoder_padding_mask, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.norm_weight, - self.norm_bias, - self.config, - ) - - return output.to(self.precision) diff --git a/colossalai/kernel/extensions b/colossalai/kernel/extensions new file mode 120000 index 000000000000..e8eb45a54893 --- /dev/null +++ b/colossalai/kernel/extensions @@ -0,0 +1 @@ +../../extensions \ No newline at end of file diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8bebad894ca4..d392649a62f2 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -1,7 +1,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear -from colossalai.utils import get_current_device from .bias_dropout_add import bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl @@ -46,11 +46,13 @@ def warmup_jit_fusion( ): """Compile JIT functions before the main training steps""" - embed = Embedding(vocab_size, hidden_size).to(get_current_device()) - linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) - linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device()) + embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device()) + linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device()) + linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device()) - x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device()) + x = torch.randint( + vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device() + ) x = embed(x) y, y_bias = linear_1(x) z, z_bias = linear_2(y) @@ -58,8 +60,8 @@ def warmup_jit_fusion( # prop and recomputation for bias_grad, input_grad in zip([True, True], [False, True]): for _ in range(10): - bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device()) - input_ = torch.rand_like(y, dtype=dtype, device=get_current_device()) + bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device()) + input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device()) bias.requires_grad, input_.requires_grad = bias_grad, input_grad bias_gelu_impl(input_, bias) @@ -69,9 +71,9 @@ def warmup_jit_fusion( # prop and recomputation for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): for _ in range(10): - input_ = torch.rand_like(z, dtype=dtype, device=get_current_device()) - residual = torch.rand_like(x, dtype=dtype, device=get_current_device()) - bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device()) + input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device()) + residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device()) + bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device()) input_.requires_grad = input_grad bias.requires_grad = bias_grad residual.requires_grad = residual_grad diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py new file mode 100644 index 000000000000..148c3e3fc08a --- /dev/null +++ b/colossalai/kernel/kernel_loader.py @@ -0,0 +1,109 @@ +import warnings +from typing import List + +from .extensions import ( + CpuAdamArmExtension, + CpuAdamX86Extension, + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, + FusedOptimizerCudaExtension, + LayerNormCudaExtension, + MoeCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, +) +from .extensions.base_extension import _Extension + +__all__ = [ + "KernelLoader", + "CPUAdamLoader", + "LayerNormLoader", + "MoeLoader", + "FusedOptimizerLoader", + "ScaledMaskedSoftmaxLoader", + "ScaledUpperTriangleMaskedSoftmaxLoader", +] + + +class KernelLoader: + """ + An abstract class which offers encapsulation to the kernel loading process. + + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + REGISTRY: List[_Extension] = [] + + @classmethod + def register_extension(cls, extension: _Extension): + """ + This classmethod is an extension point which allows users to register their customized + kernel implementations to the loader. + + Args: + extension (_Extension): the extension to be registered. + """ + cls.REGISTRY.append(extension) + + def load(self, ext_name: str = None): + """ + Load the kernel according to the current machine. + + Args: + ext_name (str): the name of the extension to be loaded. If not specified, the loader + will try to look for an kernel available on the current machine. + """ + exts = [ext_cls() for ext_cls in self.__class__.REGISTRY] + + # look for exts which can be built/loaded on the current machine + + if ext_name: + usable_exts = list(filter(lambda ext: ext.name == ext_name, exts)) + else: + usable_exts = [] + for ext in exts: + if ext.is_hardware_available(): + # make sure the machine is compatible during kernel loading + ext.assert_hardware_compatible() + usable_exts.append(ext) + + assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." + + if len(usable_exts) > 1: + # if more than one usable kernel is found, we will try to load the kernel with the highest priority + usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True) + warnings.warn( + f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}" + ) + return usable_exts[0].load() + + +class CPUAdamLoader(KernelLoader): + REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension] + + +class LayerNormLoader(KernelLoader): + REGISTRY = [LayerNormCudaExtension] + + +class MoeLoader(KernelLoader): + REGISTRY = [MoeCudaExtension] + + +class FusedOptimizerLoader(KernelLoader): + REGISTRY = [FusedOptimizerCudaExtension] + + +class ScaledMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledMaskedSoftmaxCudaExtension] + + +class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension] + + +class FlashAttentionLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder deleted file mode 120000 index db4f9c335065..000000000000 --- a/colossalai/kernel/op_builder +++ /dev/null @@ -1 +0,0 @@ -../../op_builder \ No newline at end of file diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 97ec57fbd007..d2dceb50b240 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -7,7 +7,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes @@ -28,7 +28,7 @@ def load_fused_optim(): global fused_optim if fused_optim is None: - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index 0a8d09be21ea..08f867eee96c 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -1,18 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from colossalai.utils.device import autocast - import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.legacy.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler +autocast = get_accelerator().autocast + class TorchAMPOptimizer(OptimizerWrapper): """A wrapper class which integrate Pytorch AMP with an optimizer diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index 19c3919b6e29..cf0bd4ba2437 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -8,9 +8,9 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks @@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): if isinstance(recv_shapes, torch.Size): recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors) - buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + buffer_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) return buffer_recv, recv_split buffer_recv = [] for recv_shape in recv_shapes: recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors) - tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype) + tensor_recv = torch.empty( + recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype + ) buffer_recv.append(tensor_recv) return buffer_recv, recv_split diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index a61dae56cd42..792a15abdfae 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -3,9 +3,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device, synchronize def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: @@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> current_rank = gpc.get_global_rank() tensor_recv_prev = torch.empty( - buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype ) # send to next rank @@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> req.wait() # To protect against race condition when using batch_isend_irecv(). - synchronize() + get_accelerator().synchronize() return tensor_recv_prev diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 6d77f3753fe8..0b7c0eb74651 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -3,9 +3,9 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 4a3ccfda1bb5..9b2913442225 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -6,8 +6,8 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device class BaseSchedule(ABC): @@ -29,12 +29,12 @@ def __init__(self, data_process_func: Callable = None): def _move_tensor(element): if torch.is_tensor(element): if not element.is_cuda: - return element.to(get_current_device()).detach() + return element.to(get_accelerator().get_current_device()).detach() return element def _move_to_device(self, data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + data = data.to(get_accelerator().get_current_device()) elif isinstance(data, (list, tuple)): data_to_return = [] for element in data: diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 5fd5602e790c..4a23853c137a 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -7,12 +7,12 @@ import torch.cuda import colossalai.legacy.communication as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp.naive_amp import NaiveAMPModel from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils.device import get_current_device from ._base_schedule import BaseSchedule @@ -352,7 +352,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None # Used for tensor meta information communication @@ -584,7 +584,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if not forward_only: output_obj_grads = [[] for _ in range(len(model))] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 4cd7e47c37f1..6e7760218c16 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -6,10 +6,10 @@ import torch.cuda import colossalai.legacy.communication.p2p_v2 as comm +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine -from colossalai.utils.device import get_current_device from ._pipeline_schedule import PipelineSchedule @@ -99,7 +99,7 @@ def forward_backward_step( output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) else: accum_loss = None diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index 4035bd6b54ef..d99a7d3f0c65 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -15,6 +15,7 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from colossalai.accelerator import get_accelerator from colossalai.context import Config, ConfigException from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp @@ -34,7 +35,6 @@ from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device def get_default_parser(): @@ -309,9 +309,9 @@ def initialize( else: if isinstance(model, nn.Module): # first sync model across dp ranks - model.to(get_current_device()) + model.to(get_accelerator().get_current_device()) elif isinstance(model, Callable): - model = model().to(get_current_device()) + model = model().to(get_accelerator().get_current_device()) # optimizer maybe a optimizer_cls if isinstance(optimizer, Callable): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index e1db0fe98a02..aa661664f4e8 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -3,8 +3,8 @@ from torch import dtype, nn +from colossalai.accelerator import get_accelerator from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D @@ -83,7 +83,7 @@ def __init__( embed = ( nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) .to(dtype) - .to(get_current_device()) + .to(get_accelerator().get_current_device()) ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py index f8e317e723f1..58842f481a10 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py @@ -1,6 +1,6 @@ from torch import nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from ..parallel_1d import LayerNorm1D from ..parallel_2d import LayerNorm2D @@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule): def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device()) else: norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) super().__init__(norm) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index b6ec5347f2e2..b38e1c4338b2 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -10,7 +10,7 @@ from torch import Tensor from torch.nn.parameter import Parameter -from colossalai.kernel import LayerNorm +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.context.parallel_context import global_context as gpc @@ -22,7 +22,7 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule @@ -221,7 +221,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -357,7 +357,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -499,7 +499,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -638,7 +638,7 @@ def __init__( # Parameters. # Initialize weight. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -802,7 +802,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -912,7 +914,11 @@ def __init__( self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index f1eff7128e7a..f67ee2e60be1 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def matmul_2d( @@ -250,7 +250,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -399,7 +399,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -556,7 +556,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index f81c5334ad77..4987afa18672 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -18,7 +19,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -82,7 +82,7 @@ def __init__( self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -259,7 +259,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -438,18 +438,24 @@ def __init__( self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -619,7 +625,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -758,7 +766,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -895,11 +903,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1052,7 +1067,7 @@ def __init__( self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 50900c135cab..43328bd033c8 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -5,10 +5,10 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.utils import get_current_device def get_parallel_group(parallel_mode: ParallelMode): @@ -205,7 +205,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[-1]) - C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -362,7 +362,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[0], B.shape[0]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -527,7 +527,7 @@ def forward( B_shape = B.shape B = B.reshape((-1, B_shape[-1])) C_shape = (A.shape[-1], B.shape[-1]) - C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device()) + C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device()) # use circular buffer to store the communication tensor # 2 is enough for all cases @@ -661,7 +661,9 @@ def forward( if row_rank == 0: bias_temp = bias.clone() else: - bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) + bias_temp = torch.zeros( + output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device() + ) src_rank = ( col_rank + dep_rank * tesseract_dim**2 @@ -984,7 +986,7 @@ def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: Par @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] - grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) + grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device()) dist.all_gather( list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b451a4031c25..d9410f1cbcbc 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc @@ -19,7 +20,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple @@ -84,7 +84,7 @@ def __init__( self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter( torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) ) @@ -272,7 +272,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -451,18 +451,24 @@ def __init__( self.weight = Parameter( torch.empty( (self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) - self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.pos_embed = Parameter( torch.zeros( - (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + (1, self.num_patches + 1, self.embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) @@ -632,7 +638,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -772,7 +780,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) @@ -910,11 +918,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.input_size_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -1068,7 +1083,7 @@ def __init__( self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index 16e515f87da3..bb01ec85130a 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce, broadcast from colossalai.legacy.constants import ( INPUT_GROUP_3D, @@ -27,7 +28,6 @@ partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( @@ -69,11 +69,13 @@ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=N self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -202,13 +204,15 @@ def __init__( torch.empty( self.in_features_per_partition, self.out_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -380,11 +384,18 @@ def __init__( self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, + self.in_features_per_partition, + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.has_weight = True if bias: - self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -523,14 +534,16 @@ def __init__( torch.empty( self.out_features_per_partition, self.in_features_per_partition, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + torch.zeros( + self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype + ) ) else: self.bias = None @@ -705,16 +718,24 @@ def __init__( self.weight = nn.Parameter( torch.empty( - (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + (embed_size_per_partition, in_chans, *self.patch_size), + device=get_accelerator().get_current_device(), + dtype=dtype, ) ) - self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype) + ) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype) ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size_per_partition), + device=get_accelerator().get_current_device(), + dtype=dtype, + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -880,7 +901,9 @@ def __init__( self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + torch.empty( + (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer) @@ -1019,7 +1042,7 @@ def __init__( self.weight = Parameter( torch.empty( (self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=dtype, ) ) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index 24d5499e3a5f..4e9bf364d8eb 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -5,11 +5,11 @@ from torch import distributed as dist from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ring_forward from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range -from colossalai.utils import get_current_device class RingQK(torch.autograd.Function): @@ -30,7 +30,7 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): sub_seq_length, sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute local QK^T @@ -71,7 +71,7 @@ def backward(ctx, grad_output): grad_q = torch.zeros_like( sub_q, dtype=sub_q.dtype, - device=get_current_device(), + device=get_accelerator().get_current_device(), ) # compute with local sub_k @@ -105,7 +105,7 @@ def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attent batch_size * num_attention_heads, sub_seq_length, attention_head_size, - device=get_current_device(), + device=get_accelerator().get_current_device(), dtype=attention_score.dtype, ) @@ -142,7 +142,9 @@ def backward(ctx, grad_output): grad_v /= local_world_size # calculate gradient for attention score - grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device()) + grad_attention_score = torch.zeros_like( + attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device() + ) # compute with local sub_k grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1)) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 063b0cd8e2b2..445b7e4cda2a 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,13 +8,12 @@ import torch.nn.functional as F from torch.nn import Parameter -from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.registry import LAYERS +from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax @LAYERS.register_module diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 590ad5ff6085..3a1c2e57b4be 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -7,10 +7,10 @@ from torch import nn as nn from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init -from colossalai.utils.device import get_current_device from ..utils import to_2tuple @@ -173,12 +173,18 @@ def __init__( self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + torch.empty( + (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype + ) + ) + self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype)) + self.cls_token = nn.Parameter( + torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype) ) - self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) - self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + torch.zeros( + (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -242,11 +248,15 @@ def __init__( self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + torch.empty( + self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype + ) ) self.has_weight = True if bias: - self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) + self.bias = nn.Parameter( + torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -287,7 +297,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -333,7 +343,7 @@ def __init__( self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 44f39a6db262..474fd4a2cb9c 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -118,7 +118,7 @@ def backward(ctx, output_grad): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index c57bf26e9139..b423ab3d8699 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -112,7 +112,7 @@ def backward(ctx, output_grad): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index 988317cae3eb..de6a674d61db 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -4,12 +4,12 @@ from torch.nn.functional import cross_entropy from torch.nn.modules.loss import _Loss +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.registry import LOSSES -from colossalai.utils import get_current_device @LOSSES.register_module @@ -80,7 +80,7 @@ def forward(ctx, logits, targets, output_parallel_mode): target_mask = (targets < vocab_start) | (targets > vocab_end) masked_target = targets.clone() - vocab_start masked_target[target_mask] = 0 - arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) predicted_logits[target_mask] = 0.0 @@ -110,7 +110,7 @@ def backward(ctx, output_grad): grad_2d = input_grad.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device()) grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 35a7f0a156ab..0e6731db5a77 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -7,12 +7,12 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_reduce from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import HOOKS from colossalai.legacy.utils import is_no_pp_or_last_stage -from colossalai.utils import get_current_device from ._base_hook import BaseHook from ._commons_ import _format_number @@ -82,8 +82,8 @@ class LossMetric(Metric): def __init__(self, epoch_only): super().__init__(epoch_only=epoch_only) - self.last_step_loss = torch.zeros(1, device=get_current_device()) - self.accum_loss = torch.zeros(1, device=get_current_device()) + self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) self.count = 0 def reset(self) -> None: @@ -164,10 +164,10 @@ class AccuracyMetric(Metric): def __init__(self, epoch_only: bool, accuracy_func: Callable): super().__init__(epoch_only=epoch_only) self.acc = accuracy_func - self.last_step_sum = torch.zeros(1, device=get_current_device()) - self.last_step_correct = torch.zeros(1, device=get_current_device()) - self.accumulated_sum = torch.zeros(1, device=get_current_device()) - self.accumulated_correct = torch.zeros(1, device=get_current_device()) + self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device()) def reset(self) -> None: self.last_step_sum.zero_() @@ -320,10 +320,10 @@ def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int super().__init__(epoch_only=epoch_only) self.ignored_steps = ignored_steps self.cur_steps = 0 - self.accumulated_num_samples = torch.zeros(1, device=get_current_device()) - self.accumulated_used_time = torch.zeros(1, device=get_current_device()) - self.last_step_num_samples = torch.zeros(1, device=get_current_device()) - self.last_step_used_time = torch.zeros(1, device=get_current_device()) + self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device()) + self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device()) self._tflop_per_step = tflop_per_step self._use_local = use_local diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index 9a8051ae937f..d1382cb1e36d 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -6,8 +6,8 @@ import torch from torch.utils.checkpoint import check_backward_validity, detach_variable +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states -from colossalai.utils.device import autocast, get_current_device def copy_to_device(obj, device): @@ -33,7 +33,7 @@ def forward(ctx, run_function, activation_offload=False, *args): check_backward_validity(args) ctx.run_function = run_function ctx.activation_offload = activation_offload - ctx.device = get_current_device() + ctx.device = get_accelerator().get_current_device() # preserve rng states ctx.fwd_cpu_rng_state = torch.get_rng_state() @@ -110,7 +110,7 @@ def backward(ctx, *args): inputs[idx] = tensors[i] detached_inputs = detach_variable(tuple(inputs)) if ctx.had_autocast_in_fwd: - with torch.enable_grad(), autocast(): + with torch.enable_grad(), get_accelerator().autocast()(): outputs = ctx.run_function(*detached_inputs) else: with torch.enable_grad(): @@ -226,7 +226,7 @@ def inner_unpack(packed): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks( + with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks( inner_pack, inner_unpack ): _unused = function(*args) @@ -245,7 +245,7 @@ def inner_unpack(packed): # get device if we need to offload the activation if activation_offload: - device = get_current_device() + device = get_accelerator().get_current_device() # run function with pack and unpack as saved_tensors_hooks with torch.autograd.graph.saved_tensors_hooks(pack, unpack): diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 671bcc3d6ad7..76ec08e96a6d 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -96,9 +96,9 @@ def _calc_l2_norm(grads): global fused_optim if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() norm = 0.0 if len(grads) > 0: diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 2f99a7d2f72e..cfb22d3153d9 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -6,9 +6,9 @@ import torch.distributed as dist from packaging import version +from colossalai.accelerator import get_accelerator from colossalai.legacy.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node if device.type == "cuda": - return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION + return ( + torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory + * _GLOBAL_CUDA_MEM_FRACTION + ) def colo_device_memory_used(device: torch.device) -> int: @@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None: return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio - torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device()) + torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device()) def colo_set_cpu_memory_capacity(size: int) -> None: diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index ad54b989f412..a9e3ffe1a2ec 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -8,7 +8,7 @@ from torch.autograd.profiler import profile from torch.distributed import ReduceOp -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time @@ -177,7 +177,7 @@ def close_profiler(self, group=None): assert current_comm_event is not None, "dist op has not been found" - buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device()) torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) current_comm_event.self_cuda_time = buffer.item() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index e336717f4164..b0360880e7ad 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,7 +3,7 @@ from time import time from typing import List -from colossalai.utils.device import get_current_device +from colossalai.accelerator import get_accelerator from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy @@ -69,7 +69,7 @@ def adjust_layout(self) -> None: # move COMPUTE tensors to CUDA self._cpu_gpu_move_volume += cuda_demand for t in move_to_cuda_tensor_list: - colo_model_data_tensor_move_inline(t, get_current_device()) + colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device()) @property def cpu_gpu_move_volume(self): diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 3aca80cfe56a..6fde91d4a3a3 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -5,8 +5,8 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor @@ -38,7 +38,7 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CUDATensorPlacementPolicy(TensorPlacementPolicy): def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" - super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: return 0, 0 @@ -78,7 +78,7 @@ def evict_tensors( int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index b9d3071a877e..e5a35dea1b94 100644 --- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -4,8 +4,8 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors as flatten +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device from .tensor_shard_strategy import TensorShardStrategy @@ -30,9 +30,11 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist. rank = dist.get_rank(process_group) for i in range(world_size): if i == rank: - buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) + buffer_list.append( + flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device()) + ) else: - buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) + buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device())) dist.all_gather(buffer_list, buffer_list[rank], group=process_group) # Move to target device before splitting buffer # Ensure we utilize maximum PCIE bandwidth diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index ebaef774bd06..fb6ef534be56 100644 --- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -3,11 +3,11 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.shard_utils.commons import get_shard from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device class TensorShardStrategy(BaseShardStrategy): @@ -34,9 +34,9 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr if t.is_sharded: return if t.payload.device.type == "cuda": - assert t.payload.device == get_current_device(), ( + assert t.payload.device == get_accelerator().get_current_device(), ( f"shard tensor on cuda device index {t.payload.device.index}," - f" but current cuda device is {get_current_device()}" + f" but current cuda device is {get_accelerator().get_current_device()}" ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) @@ -50,7 +50,9 @@ def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessG world_size = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer = torch.empty( + payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device() + ) buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) buffer_list[rank].copy_(t.payload) diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 85f2ac2159f4..bb7744a80851 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils.memory import colo_device_memory_capacity @@ -22,7 +23,7 @@ from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.logging import get_dist_logger -from colossalai.utils import disposable, get_current_device +from colossalai.utils import disposable from colossalai.zero.gemini.memory_tracer import MemStatsCollector from ._utils import ( @@ -212,8 +213,12 @@ def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> N self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: with open(filename, "w+") as f: - f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") - f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write( + f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n" + ) + f.write( + f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n" + ) f.write("CUDA model data (GB)\n") f.write("\n") f.write("CUDA non model data (GB)\n") @@ -266,7 +271,8 @@ def _update_memstats(self): # model data is fixed in cuda during training. # cuda margin space can be used to store OS. self._cuda_margin_space = ( - colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + colo_device_memory_capacity(get_accelerator().get_current_device()) + - self._memstats_collector._memstats.max_overall_cuda ) @torch.no_grad() diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 892e9f31ded4..332f44d5397b 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -3,13 +3,13 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.registry import OPHOOKS from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.legacy.zero.gemini.stateful_tensor import TensorState from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.legacy.zero.shard_utils import BaseShardStrategy from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from colossalai.zero.gemini.memory_tracer import MemStatsCollector @@ -33,7 +33,7 @@ def __init__( self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU - self.computing_device = get_current_device() + self.computing_device = get_accelerator().get_current_device() self._memstarts_collector = memstarts_collector self._stateful_tensor_mgr = stateful_tensor_mgr diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 721da69d0741..6dd0a5fc3c52 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,6 +1,7 @@ from .checkpoint import MoECheckpintIO from .experts import MLPExperts -from .layers import SparseMLP +from .layers import SparseMLP, apply_load_balance +from .manager import MOE_MANAGER from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator @@ -14,4 +15,6 @@ "UniformNoiseGenerator", "SparseMLP", "MoECheckpintIO", + "MOE_MANAGER", + "apply_load_balance", ] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index c71e6c1f40c7..01c837ee36ad 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.distributed as dist @@ -11,9 +11,9 @@ def load_moe(): global MOE_KERNEL - from colossalai.kernel.op_builder import MOEBuilder + from colossalai.kernel.kernel_loader import MoeLoader - MOE_KERNEL = MOEBuilder().load() + MOE_KERNEL = MoeLoader().load() class AllGather(torch.autograd.Function): @@ -145,14 +145,8 @@ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: class HierarchicalAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - inputs: Tensor, - groups: Tuple[ProcessGroup, ProcessGroup], - src_rank: int - ) -> Tensor: + def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor: """ Returns: outputs: Tensor @@ -276,8 +270,9 @@ def backward(ctx, tokens_grad): if tokens_grad.dtype != torch.float32: tokens_grad = tokens_grad.to(torch.float32) - d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, - mask, dest_idx) + d_expert, d_logits = MOE_KERNEL.combine_backward( + ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx + ) if d_expert.dtype != ctx.dtype: d_expert = d_expert.to(ctx.dtype) @@ -334,3 +329,68 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: if ctx.ep_size != 1: grad = grad / ctx.ep_size return grad, None + + +def _all_to_all( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + async_op: bool = False, +): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + outputs_shape = list(inputs.shape) + if output_split_sizes is not None: + outputs_shape[0] = sum(output_split_sizes) + outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device) + inputs = inputs.contiguous() + outputs = outputs.contiguous() + handle = dist.all_to_all_single( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op + ) + return outputs, handle + + +class AllToAllUneven(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inputs, + input_split_sizes=None, + output_split_sizes=None, + group=None, + overlap: bool = False, + ): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + ctx.group = group + return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap) + + @staticmethod + def backward(ctx: Any, *grad_outputs): + return ( + _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0], + None, + None, + None, + None, + ) + + +def all_to_all_uneven( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + overlap: bool = False, +): + return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index a8c50eab66e3..b37ffabea41f 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -224,6 +224,7 @@ def save_sharded_model( size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ + torch.cuda.empty_cache() if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -265,6 +266,7 @@ def save_sharded_model( f"index located at {save_index_file}." ) dist.barrier() + torch.cuda.empty_cache() # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -332,10 +334,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None ): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param return optimizer.param_info["param2id"][id(working_param)] @@ -347,7 +351,7 @@ def _get_param_id_from_optimizer_param( master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) id_map[param_id] = param # Read checkpoint index file. @@ -371,14 +375,10 @@ def _get_param_id_from_optimizer_param( new_pg = copy.deepcopy(saved_pg) new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. updated_groups.append(new_pg) - # ep extra group - if MOE_MANAGER.parallel == "EP": + # ep param group + if len(optimizer.optim.param_groups) > len(saved_groups): new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1][ - "params" - ] # Only keep the parameters kept by current pipeline stage. - for param in new_pg["params"]: - param.data = param.data.to(torch.float32) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) @@ -389,7 +389,7 @@ def _get_param_id_from_optimizer_param( for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -400,27 +400,34 @@ def _get_param_id_from_optimizer_param( file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + + # Then shard the loaded optimizer states if using tp/zero. + for pid, state in list(state_dict.items()): + if pid in id_map: + param = id_map[pid] + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + elif ( + hasattr(optimizer, "moe_master_to_working_map") + and id(param) in optimizer.moe_master_to_working_map + ): + working_param = optimizer.moe_master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + working_param, + current_shape=working_param.shape, + original_shape=original_shape, + device="cpu", + inplace=True, + ) + state_dict[pid] = sharded_state + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - device = param.device - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - param, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True, - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -576,6 +583,8 @@ def _optimizer_sharder( if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param @@ -618,6 +627,7 @@ def save_sharded_optimizer( prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + torch.cuda.empty_cache() assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -723,6 +733,7 @@ def save_sharded_optimizer( f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}." ) + torch.cuda.empty_cache() def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 477b76547c7e..8e6ea3884df4 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -67,7 +67,11 @@ def __init__( self.ep_size = 1 if gated: - self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) else: self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a585..2ac5b186d116 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,6 +51,8 @@ def __init__( hidden_size: int, intermediate_size: int, router_top_k: int = 1, + router_loss: bool = True, + router_norm: bool = False, router_capacity_factor_train: float = 1.25, router_capacity_factor_eval: float = 2.0, router_min_capacity: int = 4, @@ -65,15 +67,19 @@ def __init__( enable_kernel: bool = False, enable_comm_overlap: bool = False, enable_hierarchical_comm: bool = False, + return_gate_logits: bool = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts self.gated = mlp_gated + self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() + self.router_loss = router_loss + self.router_norm = router_norm # moe router noisy_func = get_noise_generator(router_noisy_policy, num_experts) @@ -150,9 +156,8 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: tokens = inputs.reshape(-1, self.hidden_size) # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) + gate_logits = F.linear(tokens, self.gate_weight) + gate_output = gate_logits.to(torch.float) # update expert load if self.enable_load_balance == True: @@ -165,7 +170,12 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # the result from the router used_capacity, *route_result_list = self.router( - inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + inputs=gate_output, + use_kernel=self.enable_kernel, + ep_group=self.ep_group, + use_loss=self.router_loss, + use_norm=self.router_norm, + ) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -177,22 +187,15 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n" - "Please use Experts build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n" "Please use Experts build function." + ) if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) @@ -204,7 +207,11 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - return ans + + if self.return_gate_logits: + return ans, gate_logits + else: + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) @@ -212,10 +219,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: return expert_out def _ep_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ Expert Parallel @@ -228,10 +232,14 @@ def _ep_process( """ if not overlap or dist.get_world_size(self.ep_group) == 1: if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_input = HierarchicalAllToAll.apply( + dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank + ) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_output = HierarchicalAllToAll.apply( + expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank + ) return expert_output else: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] @@ -249,7 +257,7 @@ class Capsule: NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -262,13 +270,15 @@ class Capsule: for i in range(NUM_CHUNK + NUM_STAGES - 1): if expert_out is not None: expert_out.handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out.data + output[:, :, offset : offset + chunk_size, :] = expert_out.data offset += chunk_size expert_out = None # all2all last output if _expert_out is not None: - expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) _expert_out = None # all2all next input @@ -288,10 +298,7 @@ class Capsule: return output def _tp_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ without overlap: @@ -326,8 +333,9 @@ class Capsule: NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert ( + dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index c5bb508621b2..e40674c9bb44 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -8,9 +8,9 @@ import torch.nn.functional as F from torch.distributed import ProcessGroup +from colossalai.accelerator import get_accelerator from colossalai.moe._operation import moe_cumsum from colossalai.moe.manager import MOE_MANAGER -from colossalai.utils import get_current_device class MoeRouter(nn.Module, ABC): @@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True, - use_kernel: bool = False): + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + use_kernel: bool = False, + ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -43,9 +45,13 @@ def __init__(self, self._z_loss = None self.use_kernel = use_kernel - def get_capacity(self, logits_shape): + def get_capacity(self, num_tokens, num_experts, ep_group=None): + if ep_group is not None: + num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) + dist.all_reduce(num_tokens_tensor, group=ep_group) + num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 @@ -68,8 +74,9 @@ def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) - assert router_probs.dim() == expert_indices.dim() == 3, \ - "router_probs must be 3D tensor and expert_indices must be 4D tensor" + assert ( + router_probs.dim() == expert_indices.dim() == 3 + ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_indices, num_experts) @@ -122,28 +129,39 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) + low=torch.tensor(0.0, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_loss: bool = False, + use_norm: bool = False, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -161,7 +179,8 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -200,7 +219,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): @@ -216,20 +235,31 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. """ - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_norm: bool = False, + use_loss: bool = True, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -246,8 +276,13 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) + if use_norm: + routing_weights, _ = torch.topk(probs, 2, dim=-1) + probs = probs / routing_weights.sum(dim=-1, keepdim=True) + num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -255,21 +290,22 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = (mask1 + mask2) # loss: [s, e] - cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 + cmask = mask1 + mask2 # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() + if use_loss: + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -336,15 +372,18 @@ class TopKRouter(MoeRouter): oversubscribed / reach capacity. """ - def __init__(self, - num_selected_experts: int, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Optional[Callable] = None, - drop_tks: bool = True): - super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, - drop_tks) + def __init__( + self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True, + ): + super().__init__( + num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks + ) def forward( self, @@ -410,7 +449,7 @@ def forward( # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) + combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 5a17a6e0d769..c642f1a4450f 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -7,13 +7,12 @@ import torch.nn as nn import torch.nn.functional as F +from colossalai.accelerator import get_accelerator from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor -from colossalai.utils import get_current_device class ForceFP32Parameter(torch.nn.Parameter): - def half(self, memory_format=None): return self.data.clone() @@ -30,8 +29,8 @@ class NormalNoiseGenerator: def __init__(self, num_experts: int): self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + loc=torch.tensor(0.0, device=get_accelerator().get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -52,8 +51,8 @@ class UniformNoiseGenerator: def __init__(self, eps: float = 1e-2): self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, device=get_current_device()), + low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()), + high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()), ).rsample def __call__(self, inputs: torch.Tensor): @@ -84,6 +83,8 @@ def get_activation(act: str) -> Callable: return torch.nn.GELU() elif act == "swiglu": return SwiGLU + elif act == "silu": + return torch.nn.SiLU() else: raise NotImplementedError("Unsupported activation function") @@ -142,7 +143,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] epsize_param_dict = dict() for param in model.parameters(): if not is_moe_tensor(param): - ep_size = 1 # set ep_size to 1 for dp parameters + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = get_ep_size(param) if ep_size not in epsize_param_dict: @@ -193,18 +194,13 @@ def create_ep_hierarchical_group( assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually." nproc_per_node = int(nproc_per_node) else: - assert dist.get_world_size() % nproc_per_node == 0, \ - "nproc_per_node should be a divisor of world_size." + assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size." num_node = dist.get_world_size() // nproc_per_node intra_src_rank = None ep_intra_node_group = None for i in range(num_node): - ep_intra_ranks = [ - i * nproc_per_node + j - for j in range(nproc_per_node) - if j in ep_group_ranks - ] + ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks] group = dist.new_group(ep_intra_ranks) if rank in ep_intra_ranks: assert ep_intra_node_group is None @@ -212,10 +208,7 @@ def create_ep_hierarchical_group( intra_src_rank = ep_intra_ranks[0] ep_inter_node_group = None - ep_inter_ranks = [ - ep_group_ranks[0] + i * nproc_per_node - for i in range(num_node) - ] + ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)] if len(ep_inter_ranks) > 1: group = dist.new_group(ep_inter_ranks) if rank in ep_inter_ranks: diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py new file mode 100644 index 000000000000..0b7011e8e2d8 --- /dev/null +++ b/colossalai/nn/layer/colo_attention.py @@ -0,0 +1,209 @@ +import enum +import math +import warnings +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange + +from colossalai.accelerator import get_accelerator +from colossalai.kernel.kernel_loader import FlashAttentionLoader + + +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize( + attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() + ): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class Unpad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None + + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None + + +class ColoAttention(torch.nn.Module): + def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): + super().__init__() + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + if scale is not None: + self.scale = scale + else: + self.scale = 1 / math.sqrt(embed_dim // num_heads) + self.dropout = dropout + + self.attn = FlashAttentionLoader().load() + + @staticmethod + def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return Unpad.apply(tensor, indices) + + @staticmethod + def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: + return Repad.apply(tensor, indices, batch_size, seq_len) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + origin_attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): + """ + ColoAttention + + Args: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + origin_attn_mask: (nheads, q_seqlen, kv_seqlen) + bias: will not be used + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # if flash attention is not applicable, switch to memory effcient attention + if self.attn.__name__ == "flash_attention" and ( + query.dtype not in [torch.float16, torch.bfloat16] or bias != None + ): + warnings.warn( + f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." + ) + self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") + + padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 + causal = attn_mask_type is not None and attn_mask_type.value > 1 + + batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1] + # unpad + seq_len_info_q = None + seq_len_info_kv = None + if padded: + # bert style, unpad process + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) + + # bert style + if tgt_len == src_len: + seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + seq_len_info_kv = seq_len_info_q + else: + seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device) + seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) + if batch_size > 1: + query = rearrange(query, "b s ... -> c (b s) ...", c=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) + else: + query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) + + out = self.attn( + query, + key, + value, + seq_len_info_q=seq_len_info_q, + seq_len_info_kv=seq_len_info_kv, + origin_attn_mask=origin_attn_mask, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) + + # repad + if padded: + if batch_size > 1: + out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) + + if len(out.shape) == 4: + out = rearrange(out, "b s h d -> b s (h d)") + return out diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/nn/layer/layernorm.py similarity index 95% rename from colossalai/kernel/cuda_native/layer_norm.py rename to colossalai/nn/layer/layernorm.py index c7d2a3a45022..1db48faee213 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/nn/layer/layernorm.py @@ -9,7 +9,7 @@ from torch.nn import init from torch.nn.parameter import Parameter -from colossalai.kernel.op_builder.layernorm import LayerNormBuilder +from colossalai.kernel.kernel_loader import LayerNormLoader try: from colossalai._C import layer_norm @@ -29,7 +29,7 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() + layer_norm = LayerNormLoader().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm ctx.save_for_backward(input_, weight_, bias_, mean, invvar) diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py new file mode 100644 index 000000000000..a8d72ddd90c9 --- /dev/null +++ b/colossalai/nn/layer/scaled_softmax.py @@ -0,0 +1,184 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type.value > 1: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type.value > 1: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index 9d1d8f01dd2d..e55e82280a5f 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -6,6 +6,8 @@ else: from torch.optim.lr_scheduler import _LRScheduler +from colossalai.logging import get_dist_logger + class _enable_get_lr_call: def __init__(self, o): @@ -19,7 +21,39 @@ def __exit__(self, type, value, traceback): self.o._get_lr_called_within_step = False -class DelayerScheduler(_LRScheduler): +class TwoStageScheduler(_LRScheduler): + def __init__(self, optimizer, after_scheduler: _LRScheduler, last_epoch=-1): + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer, last_epoch) + + def state_dict(self): + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] + else: + raise NotImplementedError() + return state_dict + + def load_state_dict(self, state_dict): + if "after_scheduler_dict" not in state_dict: + logger = get_dist_logger() + logger.warning( + "after_scheduler_dict is not found, skip loading after_scheduler. This may cause unexpected behavior." + ) + else: + self.after_scheduler.load_state_dict(state_dict["after_scheduler_dict"]) + state_dict = { + key: value + for key, value in state_dict.items() + if key not in ("after_scheduler_type", "after_scheduler_dict") + } + super().load_state_dict(state_dict) + + +class DelayerScheduler(TwoStageScheduler): """Starts with a flat lr schedule until it reaches N epochs then applies the specific scheduler (For example: ReduceLROnPlateau) @@ -35,19 +69,7 @@ def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") self.delay_epochs = delay_epochs - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.delay_epochs: @@ -71,7 +93,7 @@ def step(self, epoch=None): return super(DelayerScheduler, self).step(epoch) -class WarmupScheduler(_LRScheduler): +class WarmupScheduler(TwoStageScheduler): """Starts with a linear warmup lr schedule until it reaches N epochs then applies the specific scheduler (For example: ReduceLROnPlateau). @@ -85,19 +107,7 @@ class WarmupScheduler(_LRScheduler): def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): self.warmup_epochs = int(warmup_epochs) - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.warmup_epochs: @@ -120,7 +130,7 @@ def step(self, epoch=None): return super().step(epoch) -class WarmupDelayerScheduler(_LRScheduler): +class WarmupDelayerScheduler(TwoStageScheduler): """Starts with a linear warmup lr schedule until it reaches N epochs and a flat lr schedule until it reaches M epochs then applies the specific scheduler (For example: ReduceLROnPlateau). @@ -140,19 +150,7 @@ def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") self.warmup_epochs = warmup_epochs self.delay_epochs = delay_epochs - self.after_scheduler = after_scheduler - self.finished = False - super().__init__(optimizer, last_epoch) - - def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} - if isinstance(state_dict["after_scheduler"], _LRScheduler): - state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ - state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() - del state_dict["after_scheduler"] - else: - raise NotImplementedError() - return state_dict + super().__init__(optimizer, after_scheduler, last_epoch) def get_lr(self): if self.last_epoch >= self.warmup_epochs + self.delay_epochs: diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 7d53a1dd6834..5be629fb2045 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,10 +1,9 @@ import math -import platform from typing import Optional import torch -from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder +from colossalai.kernel.kernel_loader import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer @@ -78,7 +77,7 @@ def __init__( default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() + cpu_adam = CPUAdamLoader().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index fcdd3257d700..aeb5cc91bb9e 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -70,9 +70,9 @@ def __init__( self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 3e1d5a7ba539..da8d1608a072 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -77,9 +77,9 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 95a6354208a8..3fae9bbca765 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -72,9 +72,9 @@ def __init__( self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index d34fd601ab25..c9c1f81bfc9a 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,7 +2,7 @@ import torch -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.utils import multi_tensor_applier from .cpu_adam import CPUAdam @@ -85,7 +85,7 @@ def __init__( nvme_offload_dir, ) if torch.cuda.is_available(): - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 72480526bd5c..20f316c2ae48 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -7,10 +7,10 @@ from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule @@ -86,7 +86,7 @@ def load_micro_batch(self) -> Any: """ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): """ diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 53fc43040831..a4ace5e1baad 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -6,10 +6,11 @@ from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -100,7 +101,7 @@ def load_micro_batch(self, model_chunk_id: int) -> Any: assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index d69f28e74be9..bf2f01b10e9b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,10 +6,11 @@ from torch.nn import Module from torch.utils._pytree import tree_map +from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.device import get_current_device +from colossalai.utils import get_current_device from ._utils import ( detach, @@ -110,7 +111,7 @@ def load_micro_batch(self) -> Any: assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size) self.microbatch_offset += self.microbatch_size - return tree_map(partial(to_device, device=get_current_device()), micro_batch) + return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def recv_forward(self, prev_rank: int = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. @@ -317,7 +318,7 @@ def run_forward_only( accum_loss = None if return_loss and self.stage_manager.is_last_stage(): - accum_loss = torch.scalar_tensor(0, device=get_current_device()) + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None for _ in range(self.num_microbatches): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 4b6343adcd3b..0d2cc1b3370d 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -6,7 +6,8 @@ from torch import nn from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup, get_world_size -from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed + +from colossalai.accelerator import get_accelerator class SeqParallelUtils: @@ -109,10 +110,10 @@ def __init__(self, seed: int): # 1. get the current rng state # 2. set the seed and store the rng state # 3. recover the original rng state - device_original_rng_state = get_rng_state() - manual_seed(seed) - self.device_rng_state = get_rng_state() - set_rng_state(device_original_rng_state) + device_original_rng_state = get_accelerator().get_rng_state() + get_accelerator().manual_seed(seed) + self.device_rng_state = get_accelerator().get_rng_state() + get_accelerator().set_rng_state(device_original_rng_state) # to the same for cpu rng state cpu_original_rng_state = torch.get_rng_state() @@ -121,10 +122,10 @@ def __init__(self, seed: int): torch.set_rng_state(cpu_original_rng_state) def _set_device_rng_state(self, rng_state): - set_rng_state(rng_state) + get_accelerator().set_rng_state(rng_state) def _get_device_rng_state(self): - current_state = get_rng_state() + current_state = get_accelerator().get_rng_state() return current_state def _set_cpu_rng_state(self, rng_state): @@ -209,7 +210,7 @@ def is_randomizer_index_synchronized(process_group: ProcessGroup = None): index = Randomizer.index() if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] @@ -231,7 +232,7 @@ def synchronize_index(process_group: ProcessGroup = None): if dist.is_initialized(): # convert the index to tensor - index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device()) + index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device()) # all gather the index gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 00b2037fbdc8..d5c10541a28f 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward( def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index c8a311df7c6d..d13bd34926a5 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f456353742c..055e3096d794 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ def gpt2_for_sequence_classification_forward( def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index ad51bf2c709b..22b0f7a90656 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -530,7 +530,7 @@ def gptj_for_question_answering_forward( def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_attention_heads, attn_head_size, rotary): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 1b53ce4afebb..e10a7ed7da0c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torch.distributed as dist from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -15,14 +14,17 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig + from ..layer import cross_entropy_1d try: from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask + LATEST_VERSION = True except ImportError: LATEST_VERSION = False + class LlamaPipelineForwards: """ This class serves as a micro library for forward function substitution of Llama models @@ -203,7 +205,7 @@ def llama_for_causal_lm_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None + shard_config: ShardConfig = None, ): r""" Args: @@ -279,12 +281,13 @@ def llama_for_causal_lm_forward( if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -417,7 +420,7 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention llama_version = 2 try: @@ -480,7 +483,12 @@ def forward( attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) attn_output = attention( - query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + query_states, + key_states, + value_states, + attn_mask=flash_attention_mask, + attn_mask_type=attn_mask_type, + origin_attn_mask=attention_mask, ) attn_output = self.o_proj(attn_output) @@ -492,7 +500,7 @@ def forward( def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): from transformers import LlamaForCausalLM - + def forward( self: LlamaForCausalLM, input_ids: torch.LongTensor = None, @@ -573,12 +581,13 @@ def forward( if shard_config.enable_tensor_parallelism: new_vocab_size = logits.shape[-1] shift_logits = shift_logits.view(-1, new_vocab_size) - loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group) + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) else: shift_logits = shift_logits.view(-1, self.config.vocab_size) loss = loss_fct(shift_logits, shift_labels) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -590,4 +599,5 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + return forward diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 1ddb26c25d5c..0da1a35a0278 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -6,7 +6,7 @@ def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: MistralAttention, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 71f2ca3353bc..7f6cbbbcf4f3 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ def opt_for_question_answering_forward( def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index f67aa84e4e72..dcb1785207eb 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -593,10 +593,6 @@ def t5_encoder_model_forward( def get_t5_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from transformers.models.t5.modeling_t5 import T5Attention def forward( @@ -632,11 +628,11 @@ def forward( def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) def unshape(states): """reshape""" - return states.view(batch_size, -1, self.inner_dim) + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -653,8 +649,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if key_value_states is None: # self-attn # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=1) - elif past_key_value.shape[1] != key_value_states.shape[1]: + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning # cross-attn @@ -701,10 +697,15 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias_masked = position_bias - position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention( - query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 - ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout, + scale=1.0, + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output) diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 5a50e7379cdc..ab141a74aef8 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def pp_forward( def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel.cuda_native import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 9827d4801f8d..cb8b45ae7d01 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index f2eeb9d69c81..5c148880f980 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -49,7 +49,7 @@ def module_policy(self): if not self.model.config.new_decoder_architecture and self.model.config.multi_query: warnings.warn( - "Falcon dosen't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." + "Falcon doesn't support tensor parallelism when (not new_decoder_architecture and multi_query) is True, will ignore the tensor parallelism flag." ) self.shard_config.enable_tensor_parallelism = False diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 1faa24f71e0b..42bf0825b045 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -46,7 +46,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c16aa6deab3b..c0b8b3375836 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -35,7 +35,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Mistral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) if self.shard_config.enable_tensor_parallelism: @@ -136,7 +136,7 @@ def __init__(self) -> None: def module_policy(self): if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") return super().module_policy() @@ -160,7 +160,7 @@ def module_policy(self): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) @@ -186,7 +186,7 @@ def module_policy(self): } if self.pipeline_stage_manager: - warnings.warn("Mistral dosen't support pipeline parallelism now.") + warnings.warn("Mistral doesn't support pipeline parallelism now.") policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index e2f3a829cc6f..a542808ba794 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -59,7 +59,7 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[OPTDecoder] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 4d906e3f4c04..e183b0632f88 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -66,7 +66,7 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[T5Stack] = ModulePolicyDescription( @@ -263,7 +263,7 @@ def distribute_t5_layers( if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 6ef0e3b34b2b..584d4e2652c0 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -33,7 +33,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False - warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: policy[ViTEmbeddings] = ModulePolicyDescription( diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 6dae99e8cedb..b5b5db79d9de 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -69,13 +69,13 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag." ) # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False - warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") + warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: policy[WhisperEncoderLayer] = ModulePolicyDescription( @@ -302,7 +302,7 @@ def distribute_whisper_layers( if num_decoder_layers == 0: return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages - # the number of stages distributed between encoder and decoder is optmized in this way: + # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 def objective(num_encoder_stages): diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 5301c87b9836..acb9fc4ae8fc 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -7,11 +7,12 @@ from .colo_tensor import _convert_output -WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point} +WHITE_LIST_FUNCS = {torch.Tensor.__getitem__} +NO_HOOK_FUNCS = {torch.Tensor.is_floating_point} def is_no_hook_op(func) -> bool: - return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS + return (func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS) or func in NO_HOOK_FUNCS def filter_colo_parameters(*args, **kwargs): diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index ba6c77056222..5ac3c2b3a57e 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -26,3 +26,5 @@ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1 self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) + self.ep_rank = self.pg.coordinate(self.ep_axis) + self.dp_rank = self.pg.coordinate(self.dp_axis) diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 1fe99cd89a4e..40de43c43b05 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -92,7 +92,10 @@ def pre_op(params: List[torch.Tensor], *args: Any) -> list: @staticmethod def post_op(params: List[torch.Tensor], arg: Any) -> Any: ColoParamOpHookManager._trigger_post_forward(params) - return PostFwdPreBwd.apply(params, arg) + # incase the output is a tuple, we have to flatten it + grad_args, other_args, grad_flags, spec = _flatten_grad_args(arg) + new_grad_args = PostFwdPreBwd.apply(params, *grad_args) + return _merge_args(new_grad_args, other_args, grad_flags, spec) @staticmethod def has_hook() -> bool: @@ -113,7 +116,7 @@ def backward(ctx, *grads): class PostFwdPreBwd(torch.autograd.Function): @staticmethod - def forward(ctx, params, args): + def forward(ctx, params, *args): ctx.params = params return args @@ -142,7 +145,6 @@ def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]: grad_args.append(arg) else: other_args.append(arg) - assert len(grad_args) > 0 return grad_args, other_args, grad_flags, spec diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 7cd24b0adc60..5f6864ff0059 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -9,7 +9,8 @@ import torch import torch.multiprocessing as mp from packaging import version -from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count + +from colossalai.accelerator import get_accelerator def parameterize(argument: str, values: List[Any]) -> Callable: @@ -199,7 +200,7 @@ def test_something(): def _wrap_func(f): def _execute_by_gpu_num(*args, **kwargs): - num_avail_gpu = device_count() + num_avail_gpu = get_accelerator().device_count() if num_avail_gpu >= min_gpus: f(*args, **kwargs) @@ -263,11 +264,11 @@ def test_something(): def _wrap_func(f): def _clear_cache(*args, **kwargs): - empty_cache() - reset_peak_memory_stats() - reset_max_memory_allocated() - reset_max_memory_cached() - synchronize() + get_accelerator().empty_cache() + get_accelerator().reset_peak_memory_stats() + get_accelerator().reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().synchronize() gc.collect() f(*args, **kwargs) diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 0246a35e2a1b..cdba467091be 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -4,20 +4,16 @@ disposable, ensure_path_exists, free_storage, + get_current_device, is_ddp_ignored, set_seed, ) -from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer __all__ = [ "conditional_context", - "get_current_device", - "synchronize", - "empty_cache", - "set_to_cuda", "Timer", "MultiTimer", "multi_tensor_applier", @@ -27,7 +23,6 @@ "_cast_float", "free_storage", "set_seed", + "get_current_device", "is_ddp_ignored", - "set_device", - "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index c43caaff4806..4a1889eb57ff 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -10,6 +10,15 @@ import numpy as np import torch +from colossalai.accelerator import get_accelerator + + +def get_current_device(): + """ + A wrapper function for accelerator's API for backward compatibility. + """ + return get_accelerator().get_current_device() + def ensure_path_exists(filename: str): # ensure the path exists diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py deleted file mode 100644 index c70dbdaa5ee1..000000000000 --- a/colossalai/utils/device.py +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Any, Dict, List, Optional, Tuple, Callable - -import torch -import torch.distributed as dist - -IS_NPU_AVAILABLE: bool = False -try: - import torch_npu # noqa - - IS_NPU_AVAILABLE = torch.npu.is_available() -except ImportError: - pass - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - elif IS_NPU_AVAILABLE: - return torch.device(f"npu:{torch.npu.current_device()}") - else: - return torch.device("cpu") - - -def _dispatch_device_func(fn_name: str, *args, **kwargs): - if torch.cuda.is_available(): - return getattr(torch.cuda, fn_name)(*args, **kwargs) - elif IS_NPU_AVAILABLE: - return getattr(torch.npu, fn_name)(*args, **kwargs) - else: - raise RuntimeError("No device available") - - -# device semantics - - -def can_device_access_peer(device, peer_device) -> bool: - return _dispatch_device_func("can_device_access_peer", device, peer_device) - - -def current_device() -> int: - return _dispatch_device_func("current_device") - - -def current_stream(device=None): - return _dispatch_device_func("current_stream", device) - - -def default_stream(device=None): - return _dispatch_device_func("default_stream", device) - - -def device_count() -> int: - return _dispatch_device_func("device_count") - - -def get_device_capability(device=None) -> Tuple[int, int]: - return _dispatch_device_func("get_device_capability", device) - - -def get_device_name(device=None) -> str: - return _dispatch_device_func("get_device_name", device) - - -def get_device_properties(device): - return _dispatch_device_func("get_device_properties", device) - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % device_count() - _dispatch_device_func("set_device", index) - - -def set_stream(stream_): - return _dispatch_device_func("set_stream", stream_) - - -def stream(stream_): - return _dispatch_device_func("stream", stream_) - - -def synchronize(): - return _dispatch_device_func("synchronize") - - -def utilization(device=None) -> int: - return _dispatch_device_func("utilization", device) - - -# random number generator - - -def get_rng_state(device="cuda") -> torch.Tensor: - return _dispatch_device_func("get_rng_state", device) - - -def get_rng_state_all() -> List[torch.Tensor]: - return _dispatch_device_func("get_rng_state_all") - - -def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: - return _dispatch_device_func("set_rng_state", new_state, device) - - -def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: - return _dispatch_device_func("set_rng_state_all", new_states) - - -def manual_seed(seed: int) -> None: - return _dispatch_device_func("manual_seed", seed) - - -def manual_seed_all(seed: int) -> None: - return _dispatch_device_func("manual_seed_all", seed) - - -def seed() -> None: - return _dispatch_device_func("seed") - - -def seed_all() -> None: - return _dispatch_device_func("seed_all") - - -def initial_seed() -> int: - return _dispatch_device_func("initial_seed") - - -# streams and events - - -def Stream(device=None, priority=0, **kwargs): - return _dispatch_device_func("Stream", device, priority, **kwargs) - - -def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): - return _dispatch_device_func("Event", enable_timing, blocking, interprocess) - - -# memory management - - -def empty_cache() -> None: - return _dispatch_device_func("empty_cache") - - -def memory_stats(device=None) -> Dict[str, Any]: - return _dispatch_device_func("memory_stats", device) - - -def memory_summary(device=None, abbreviated=False) -> str: - return _dispatch_device_func("memory_summary", device, abbreviated) - - -def memory_snapshot(): - return _dispatch_device_func("memory_snapshot") - - -def memory_allocated(device=None) -> int: - return _dispatch_device_func("memory_allocated", device) - - -def max_memory_allocated(device=None) -> int: - return _dispatch_device_func("max_memory_allocated", device) - - -def reset_max_memory_allocated(device=None) -> None: - return _dispatch_device_func("reset_max_memory_allocated", device) - - -def reset_max_memory_cached(device=None) -> None: - return _dispatch_device_func("reset_max_memory_cached", device) - - -def memory_reserved(device=None) -> int: - return _dispatch_device_func("memory_reserved", device) - - -def max_memory_reserved(device=None) -> int: - return _dispatch_device_func("max_memory_reserved", device) - - -def set_per_process_memory_fraction(fraction: float, device=None) -> None: - return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) - - -def reset_peak_memory_stats(device=None) -> None: - return _dispatch_device_func("reset_peak_memory_stats", device) - - -# amp - - -def autocast() -> Callable: - if torch.cuda.is_available(): - return torch.cuda.amp.autocast() - elif IS_NPU_AVAILABLE: - return torch.npu.amp.autocast() - else: - raise RuntimeError("No device available") diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 2f7ccc24c0fc..2feded7751ea 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .device import synchronize +from colossalai.accelerator import get_accelerator class Timer: @@ -21,13 +21,13 @@ def has_history(self): @property def current_time(self) -> float: - synchronize() + get_accelerator().synchronize() return time.time() def start(self): """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 - synchronize() + get_accelerator().synchronize() self._start_time = time.time() self._started = True @@ -44,7 +44,7 @@ def stop(self, keep_in_history: bool = False): Returns: int: Start-stop interval. """ - synchronize() + get_accelerator().synchronize() end_time = time.time() elapsed = end_time - self._start_time if keep_in_history: diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index defc6c4cb150..cad2622f2851 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -6,8 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device -from colossalai.utils.device import IS_NPU_AVAILABLE +from colossalai.accelerator import get_accelerator class TensorState(Enum): @@ -107,7 +106,7 @@ def __init__( self.valid_end = self.shard_size self.dtype = dtype - device = init_device or get_current_device() + device = init_device or get_accelerator().get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero @@ -125,7 +124,7 @@ def __init__( # configure the init device of the shard # no-offload default: fp16, fp32 -> CUDA # offload default: fp16, fp32 -> CPU - self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device() self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() self.shard_mem = self.chunk_mem // self.pg_size @@ -191,11 +190,10 @@ def memory_usage(self) -> Dict[str, int]: def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type + elif self.is_gathered or self.cuda_shard is not None: + return get_accelerator().name else: - if self.is_gathered or self.cuda_shard is not None: - return "npu" if IS_NPU_AVAILABLE else "cuda" - else: - return "cpu" + return "cpu" @property def payload(self) -> torch.Tensor: @@ -297,7 +295,7 @@ def close_chunk(self): self.valid_end = self.utilized_size - self.shard_begin if self.chunk_temp.device.type == "cpu": - self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) + self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device()) self.__update_tensors_ptr() else: self.cuda_global_chunk = self.chunk_temp @@ -334,12 +332,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False): return if device.type == "cuda" or device.type == "npu": - assert device == get_current_device(), "can't move chunk to another device" + assert device == get_accelerator().get_current_device(), "can't move chunk to another device" if self.cuda_shard: return - self.cuda_shard = self.cpu_shard.to(get_current_device()) + self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device()) if not self.pin_memory: self.cpu_shard = None @@ -394,7 +392,9 @@ def reduce(self): if self.extra_dp_group is not None: dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + self.cuda_shard = torch.empty( + self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() + ) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) @@ -533,7 +533,7 @@ def __paired_shard_move(self): # only be called when optimizer state is in CPU memory # the grad and param should be in the same device assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) + temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device()) # avoid to transform FP32 in CPU self.cuda_shard = temp.to(self.dtype) @@ -631,7 +631,7 @@ def init_grad_chunk(self) -> "Chunk": grad_chunk.valid_end = self.valid_end if grad_chunk.chunk_temp.device.type == "cpu": - grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device()) else: grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp grad_chunk.chunk_temp = None diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5f4f37c267aa..5bc662a6189c 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,8 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import free_storage, get_current_device +from colossalai.accelerator import get_accelerator +from colossalai.utils import free_storage from .chunk import Chunk, ChunkFullError, TensorState @@ -20,7 +21,7 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() + self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): @@ -107,7 +108,7 @@ def access_chunk(self, chunk: Chunk) -> None: return self.__sub_memory_usage(chunk.memory_usage) if chunk.device_type == "cpu": - chunk.shard_move(get_current_device()) + chunk.shard_move(get_accelerator().get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) @@ -276,7 +277,10 @@ def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) else: accumulated_grad = ( - chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device()) + .clone() + .detach() + .mul_(chunk.pg_size) ) accumulated_grad_gathered = False diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5217b8036bcd..bc6c9d088094 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group +from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor @@ -27,7 +28,7 @@ is_distributed_tensor, ) from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.utils import _cast_float, free_storage, is_ddp_ignored from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -725,11 +726,13 @@ def load_parameter(chunk_slice, data): chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk - if self.reuse_fp16_chunk: - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.payload.copy_(chunk_32.payload) + + # sync running weights and master weights + if self.master_weights: + for loaded_chunk in chunk_list: + paired_chunk = loaded_chunk.paired_chunk + assert paired_chunk is not None + paired_chunk.payload.copy_(loaded_chunk.payload) for name, buf in persistent_buffers.items(): if buf is not None: @@ -766,7 +769,7 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) + p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision) continue # create a fp16 parameter @@ -815,7 +818,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.to(get_current_device()) + buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ad94593395bb..18367af59d80 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -11,6 +11,7 @@ from torch.nn import Parameter from torch.optim import Optimizer +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper @@ -26,7 +27,7 @@ is_customized_distributed_tensor, is_distributed_tensor, ) -from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP @@ -233,7 +234,7 @@ def _calc_global_norm(self) -> float: grad_chunk.l2_norm = None # clear l2 norm - comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) + comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device()) for group, part_norm in group_to_norm.items(): comm_buffer.fill_(part_norm) dist.all_reduce(comm_buffer, group=group) @@ -314,10 +315,10 @@ def _maybe_move_fp32_params(self): continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(chunk32, get_current_device()) + self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device()) # stores grad now - self.chunk_manager.move_chunk(chunk16, get_current_device()) - self.module.set_chunk_grad_device(chunk16, get_current_device()) + self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device()) fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: @@ -328,7 +329,7 @@ def _maybe_move_fp32_params(self): state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): - state[k] = v.to(get_current_device()) + state[k] = v.to(get_accelerator().get_current_device()) def _register_states_(self): for group in self.optim.param_groups: @@ -551,7 +552,7 @@ def pack_optimizer_states_to_tensor( self, param_id: int, state_names: list, - device: torch.device = get_current_device(), + device: torch.device = get_accelerator().get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -620,7 +621,10 @@ def get_param_groups_for_saving(self) -> list: Return the param_groups in Pytorch format when saving to checkpoint. """ - param_groups = copy.deepcopy(self.param_groups_backup) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(self.optim.param_groups, self.param_groups_backup) + ] # To be compatible with pytorch checkpointing, # store extra hyperparameters used by pytorch Adam optimizer. diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b5e40a817e58..e302805dfbb7 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,6 +1,6 @@ from typing import Optional -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from colossalai.zero.gemini.chunk import ChunkManager from .memory_stats import MemStats @@ -33,4 +33,4 @@ def record_model_data_volume(self) -> None: def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity - return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda + return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 513a6326d5f1..82c8e9dab098 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -5,7 +5,7 @@ import torch -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class MemoryMonitor: @@ -77,7 +77,7 @@ def __init__(self, power: int = 10): super().__init__() self.keep_measuring = False - current_device = get_current_device() + current_device = get_accelerator().get_current_device() def _set_cuda_device(): torch.cuda.set_device(current_device) @@ -116,7 +116,7 @@ def _measure_usage(self): while self.keep_measuring: max_usage = max( max_usage, - colo_device_memory_used(get_current_device()), + colo_device_memory_used(get_accelerator().get_current_device()), ) sleep(self.interval) return max_usage diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index c410ad3793c9..388999549bd8 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -6,8 +6,8 @@ import torch -from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.accelerator import get_accelerator +from colossalai.legacy.utils.memory import colo_device_memory_capacity from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager @@ -85,7 +85,7 @@ def setup_grads_device( # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: - device = get_current_device() + device = get_accelerator().get_current_device() else: device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here @@ -140,7 +140,7 @@ def evict_tensors( int: the volume of memory that is evicted """ start = time() - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. @@ -194,7 +194,7 @@ def setup_grads_device( # init offload optim settings # keep gathered chunks are in CUDA if chunk.keep_gathered: - grads_device_map[p] = get_current_device() + grads_device_map[p] = get_accelerator().get_current_device() else: grads_device_map[p] = torch.device("cpu") diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 5305953fe1ee..b563ea5b2de6 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator from .chunk import Chunk @@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.cuda_shard is not None: shard_temp = chunk.cuda_shard else: - shard_temp = chunk.cpu_shard.to(get_current_device()) + shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device()) shard_temp = shard_temp.to(dtype) - total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 891cae65a47c..a2433d1b261c 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,7 +12,7 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import ( BF16MixedPrecisionMixin, FP16MixedPrecisionMixin, @@ -22,9 +22,6 @@ from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor -# from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device - from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -144,7 +141,7 @@ def __init__( # because they have different parallel strategy # so we need to store them separately in param_groups # instead of working_groups - moe_params = list() + self.working_moe_params = list() # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -156,7 +153,7 @@ def __init__( if self.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): - moe_params.append(param) + self.working_moe_params.append(param) continue group_params.append(param) @@ -171,19 +168,29 @@ def __init__( # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in additional group in optim - if len(moe_params) > 0: + # if there are moe params, store in addtional group in optim + if len(self.working_moe_params) > 0: + self._sync_master_param = False param_group = dict() + # create fp32 master param for key, value in self.optim.param_groups[0].items(): if key != "params": param_group[key] = value - param_group["params"] = moe_params + self.master_moe_params = [] + for param in self.working_moe_params: + self.master_moe_params.append(param.clone().to(torch.float32).detach()) + # create mapping from master to working for optimizer io + self.moe_master_to_working_map = {} + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param + # add to optim + param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) # initialize communication stream for # communication-computation overlapping if self._overlap_communication: - self._comm_stream = device_utils.Stream() + self._comm_stream = get_accelerator().Stream() # reduction hook is only used if overlapping communication # or stage 2 is used @@ -217,7 +224,7 @@ def num_param_groups(self): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required" + assert get_accelerator().name in ["cuda", "npu"], "device is required" for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: @@ -228,7 +235,7 @@ def _sanity_checks(self): def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = "cpu" if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else get_accelerator().get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -340,11 +347,11 @@ def _run_reduction(self): if len(moe_grad_list) > 0: moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing - stream.wait_stream(device_utils.current_stream()) + stream.wait_stream(get_accelerator().current_stream()) else: - stream = device_utils.current_stream() + stream = get_accelerator().current_stream() - with device_utils.stream(stream): + with get_accelerator().stream(stream): group_id = self._bucket_store.current_group_id if self.moe_extra_dp_pg is None: @@ -486,7 +493,7 @@ def backward(self, loss, retain_graph=False): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -505,7 +512,7 @@ def backward_by_grad(self, tensor, grad): # clear reduced grads if self._overlap_communication: - device_utils.synchronize() + get_accelerator().synchronize() self.zero_grad() @@ -596,24 +603,40 @@ def step(self, closure=None): # update the params in the optimizer self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + # update param for moe ep + # move grad to master param and compute norm + if len(self.working_moe_params) > 0: + moe_grads = [] + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + if master_moe_param.grad is not None: + raise RuntimeError("Moe param should not have grad here") + grad = working_moe_param.grad + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) + master_moe_param.grad = grad + working_moe_param.grad = None + moe_grads.append(grad) + grad_partition_groups.append(grad) + norm_group = self._compute_grad_norm(gradients=moe_grads) + norm_groups.append(norm_group) + self.optim.param_groups[-1]["params"] = self.master_moe_params + del moe_grads + # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) - # TODO: we should store master param for ep - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.data = param.data.to(torch.float32) - param.grad = param.grad.to(torch.float32) - # update the parameters self.optim.step() - # release the moe gradm - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.grad = None - param.data = param.data.to(self._dtype) + # release moe grad + if len(self.working_moe_params) > 0: + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.grad = None + working_moe_param.data = ( + master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach() + ) # release the grad grad_partition_groups = [] @@ -621,7 +644,7 @@ def step(self, closure=None): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - device = get_current_device() + device = get_accelerator().get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): @@ -661,7 +684,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float + ) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -673,7 +698,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo # Sum across all model parallel GPUs. total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float ) torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg @@ -765,7 +790,7 @@ def state_dict(self) -> Dict: Dict: the pytorch form state_dict """ zero_state = dict() - device = get_current_device() + device = get_accelerator().get_current_device() for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): @@ -827,7 +852,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i ret_block = dict() ret_block_size = 0 - device = get_current_device() + device = get_accelerator().get_current_device() local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 @@ -886,9 +911,14 @@ def update_master_params(self, model: nn.Module) -> None: master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) else: master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if hasattr(self, "master_moe_params"): + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.copy_(working_moe_param) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self._param_store.working_to_master_param def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + if hasattr(self, "moe_master_to_working_map"): + return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} return self._param_store.master_to_working_param diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 7a0e3b1a0276..e87eafb6eec7 100644 --- a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -45,7 +45,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ## Define Plugin Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously. @@ -149,7 +148,7 @@ model, optimizer, _criterion, _, lr_scheduler = booster.boost( ## Training GPT-2 using hybrid parallelism -In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. +In the previous tutorial, We've explained how to inject various parallelism features into the model and its training components using the Booster and `HybridParallelPlugin`. Now we can start model training. Define a training function. When pipeline parallelism is used, you need to call `booster.execute_pipeline` to schedule the stages of model training. ```python def train_epoch( @@ -204,4 +203,4 @@ Training the gpt-2 model for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md index 18607a34cf65..f9c8fe4758c8 100644 --- a/docs/source/en/get_started/installation.md +++ b/docs/source/en/get_started/installation.md @@ -23,7 +23,7 @@ pip install colossalai If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime. ```shell -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` @@ -39,7 +39,7 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`: @@ -61,7 +61,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md index 11740698057f..ae941b489b90 100644 --- a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md +++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md @@ -43,7 +43,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device ``` ### 定义plugin 定义一个[`HybridParallelPlugin`](../basics/booster_plugins.md)对象,指定所需要使用的并行策略,在该例子中,同时使用了流水线并行和zero1. @@ -201,4 +200,4 @@ def train_epoch( for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) ``` - \ No newline at end of file + diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md index fb6fd90ec4c2..481efe98ac12 100644 --- a/docs/source/zh-Hans/features/1D_tensor_parallel.md +++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md @@ -19,10 +19,8 @@ 当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为 $$ \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] -``` -这就是所谓的行并行方式. $$ - +这就是所谓的行并行方式. 为了计算 $$ Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right] diff --git a/docs/source/zh-Hans/get_started/installation.md b/docs/source/zh-Hans/get_started/installation.md index e75e42530fc1..9e4f34707c13 100755 --- a/docs/source/zh-Hans/get_started/installation.md +++ b/docs/source/zh-Hans/get_started/installation.md @@ -20,10 +20,10 @@ pip install colossalai **注:现在只支持Linux。** -如果你想同时安装PyTorch扩展的话,可以添加`CUDA_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。 +如果你想同时安装PyTorch扩展的话,可以添加`BUILD_EXT=1`。如果不添加的话,PyTorch扩展会在运行时自动安装。 ```shell -CUDA_EXT=1 pip install colossalai +BUILD_EXT=1 pip install colossalai ``` ## 从源安装 @@ -38,10 +38,10 @@ cd ColossalAI pip install -r requirements/requirements.txt # install colossalai -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` -如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`: +如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`BUILD_EXT=1`: ```shell pip install . @@ -60,7 +60,7 @@ unzip 1.8.0.zip cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/ # install -CUDA_EXT=1 pip install . +BUILD_EXT=1 pip install . ``` diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index 5396de6935cb..40b11d649ae0 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -16,10 +16,10 @@ from utils.logger import Logger import colossalai +from colossalai.accelerator import get_accelerator from colossalai.context import ParallelMode from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext @@ -53,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - get_current_device() + get_accelerator().get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -67,7 +67,10 @@ def main(): # build GPT model with ColoInitContext( - device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + device=get_accelerator().get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg, ): config, model, numel = get_model(args, logger) @@ -78,7 +81,7 @@ def main(): elif args.distplan == "CAI_Gemini": gemini_config = dict( strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), + device=get_accelerator().get_current_device(), placement_policy=args.placement, pin_memory=True, hidden_dim=model.config.hidden_size, diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 1a7f8da7f7d0..cc2b2ebc7b88 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -20,11 +20,11 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -386,7 +386,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -401,7 +401,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -578,8 +578,8 @@ def collate_fn(examples): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -613,7 +613,7 @@ def collate_fn(examples): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index ea6dde8bb578..227488abe204 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -21,13 +21,13 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device disable_existing_loggers() logger = get_dist_logger() @@ -385,7 +385,7 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32 + torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32 pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -400,7 +400,7 @@ def main(args): sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) - pipeline.to(get_current_device()) + pipeline.to(get_accelerator().get_current_device()) for example in tqdm( sample_dataloader, @@ -598,8 +598,8 @@ def collate_fn(examples): # Move text_encode and vae to gpu. # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. - vae.to(get_current_device(), dtype=weight_dtype) - text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_accelerator().get_current_device(), dtype=weight_dtype) + text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader)) @@ -633,7 +633,7 @@ def collate_fn(examples): torch.cuda.reset_peak_memory_stats() # Move batch to gpu for key, value in batch.items(): - batch[key] = value.to(get_current_device(), non_blocking=True) + batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True) # Convert images to latent space optimizer.zero_grad() diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index 13df516d4189..5871bbf8748b 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -13,12 +13,12 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index b770bc9cfb95..0780173241aa 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224 def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.accelerator import get_accelerator + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print(f"Limiting GPU memory usage to {size_in_GB} GB") diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 9a26098b3847..26cac977a931 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -6,10 +6,9 @@ import transformers import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -from colossalai.utils.device import get_current_device GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -52,7 +51,7 @@ def data_gen(batch_size: int = 4, seq_len: int = 512): - input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device()) + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) data = dict(input_ids=input_ids, attention_mask=attention_mask) return data @@ -97,9 +96,9 @@ def print_details_info(outputs, model_config, args, whole_end2end): msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n" if torch.cuda.is_available(): - msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n" - msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n" - msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n" + msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n" + msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n" + msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n" print(msg) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index 8f85a936352b..b5228c64efa5 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -5,9 +5,9 @@ from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.inference import InferenceEngine from colossalai.testing import spawn -from colossalai.utils.device import get_current_device INPUT_TEXTS = [ "What is the longest river in the world?", @@ -57,7 +57,7 @@ def run_inference(args): ) inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True) - inputs = {k: v.to(get_current_device()) for k, v in inputs.items()} + inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()} outputs = engine.generate(inputs) if rank == 0: diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index aad12c9c2c59..0b1e77ffff06 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -18,11 +18,11 @@ ) import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -59,7 +59,7 @@ def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -89,8 +89,10 @@ def evaluate_subset(dataloader: DataLoader): object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index e811e1acbf7e..b35112498978 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -7,13 +7,13 @@ from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn -from colossalai.utils import get_current_device def parse_args(): @@ -41,7 +41,7 @@ def train_gpt(args): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = GPTLMLoss() diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 88b76c654b1d..78d090ba29da 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -12,12 +12,12 @@ from packaging import version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device CAI_VERSION = colossalai.__version__ @@ -141,7 +141,11 @@ def main(): criterion = GPTLMLoss() torch.manual_seed(123) if args.distplan.startswith("CAI"): - ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.distplan == "CAI_Gemini" + else nullcontext() + ) # build GPT model with ctx: model = model_builder(args.model_type)(checkpoint=True) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 62804eff8ea5..eb56ee530a0a 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -13,11 +13,11 @@ from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -54,7 +54,7 @@ def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) labels = batch["labels"] @@ -83,8 +83,10 @@ def evaluate_subset(dataloader: DataLoader): object_list = [None, None] dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group) - metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels) - accum_loss.add_(object_list[1].to(get_current_device())) + metric.add_batch( + predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels + ) + accum_loss.add_(object_list[1].to(get_accelerator().get_current_device())) else: batch = move_to_cuda(batch) diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index b2e3f71a5387..ec3df50c4e67 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -5,6 +5,7 @@ from torch.nn import functional as F from torch.nn.parameter import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.base_layer import ParallelLayer @@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide from colossalai.legacy.registry import LAYERS, LOSSES -from colossalai.utils import get_current_device class VocabParallelEmbedding(torch.nn.Module): @@ -96,7 +96,9 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -194,7 +196,7 @@ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None): self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) @@ -439,7 +441,9 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): if position_ids is not None: position_ids = position_ids.view(-1, input_shape[-1]) if position_ids is None: - position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device()) + position_ids = torch.arange( + 0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device() + ) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_embeddings = self.position_embeddings(position_ids) @@ -532,7 +536,7 @@ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx self._weight = None # Allocate weights and initialize. - factory_kwargs = {"device": get_current_device(), "dtype": dtype} + factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index c931e5ba79d9..b8f70ce9c9d8 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,13 +13,12 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Constants @@ -177,7 +176,7 @@ def empty_init(): # Initialize Model and Optimizer # ============================== init_ctx = ( - LazyInitContext(default_device=get_current_device()) + LazyInitContext(default_device=get_accelerator().get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) @@ -208,7 +207,9 @@ def empty_init(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -234,7 +235,7 @@ def empty_init(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index a438833e1680..6b9e8ef28eb7 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/llama2/data_utils.py @@ -8,7 +8,7 @@ from torch.distributed.distributed_c10d import _get_default_group from torch.utils.data import DataLoader, Dataset, DistributedSampler -from colossalai.utils import get_current_device +from colossalai.accelerator import get_accelerator class StatefulDistributedSampler(DistributedSampler): @@ -108,7 +108,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index f7708b1a38ab..66b5400765f7 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -21,13 +21,13 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def get_model_numel(model: nn.Module) -> int: @@ -191,7 +191,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 6b1c92711d48..c2169a730a88 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,9 +5,8 @@ import torch.distributed as dist from torch import Tensor -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.cluster import DistCoordinator -from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=get_current_device()) + tensor = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -86,13 +85,13 @@ def on_step_start(self, step: int) -> None: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - device_utils.synchronize() + get_accelerator().synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index 8d5b7c8db05d..4cdf93e1914b 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -20,13 +20,13 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device MODEL_CONFIGS = { "7b": LlamaConfig(max_position_embeddings=4096), @@ -227,7 +227,9 @@ def main(): config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin init_ctx = ( - LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + LazyInitContext(default_device=get_accelerator().get_current_device()) + if isinstance(plugin, GeminiPlugin) + else nullcontext() ) with init_ctx: diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 65562b386cf9..03b660ecf446 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -14,6 +14,7 @@ from utils import PerformanceEvaluator, get_model_numel import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -21,7 +22,6 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -64,13 +64,15 @@ def __init__( ) self.input_ids.append(encode["input_ids"]) self.attention_mask.append(encode["attention_mask"]) - self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) - self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device()) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device()) repeat_times = num_samples // self.input_ids.shape[0] + 1 self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] else: - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317903..eee3b505a22a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -35,7 +35,7 @@ replace_return_docstrings, ) -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index f354bbea990e..17e7aa46ce85 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -43,7 +43,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index b084361661ac..1ae661f548b8 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -15,6 +15,7 @@ from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -22,7 +23,6 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -61,7 +61,9 @@ class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 7af02e24e6cf..4fac7b5072ed 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -14,12 +14,12 @@ from torch.utils.data import DataLoader, Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import HybridAdam -from colossalai.utils import get_current_device # constants @@ -159,7 +159,11 @@ def __len__(self): logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() + ctx = ( + LazyInitContext(default_device=get_accelerator().get_current_device()) + if args.plugin == "gemini" + else nullcontext() + ) with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index 4407a51c3153..a4733126f3ee 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -13,12 +13,12 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 700e4d2e0cd9..ec6c852b5965 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -13,13 +13,13 @@ from tqdm import tqdm import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin.dp_plugin_base import DPPluginBase from colossalai.cluster import DistCoordinator from colossalai.nn.lr_scheduler import LinearWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl @torch.no_grad() def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: model.eval() - correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) - total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device()) for images, labels in test_dataloader: images = images.cuda() labels = labels.cuda() diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 990822c9feba..e97c9017fe56 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -12,11 +12,11 @@ from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device # ============================== # Prepare Hyperparameters @@ -45,7 +45,7 @@ def evaluate( model.eval() def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=get_accelerator().get_current_device()) for batch in dataloader: batch = move_to_cuda(batch) outputs = model(**batch) diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 9bd23ffc8aba..3f0d048795e6 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -51,13 +51,13 @@ from transformers.utils.versions import require_version import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.tensor import ProcessGroup from colossalai.legacy.utils import get_dataloader from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device from colossalai.zero import GeminiOptimizer require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") @@ -249,9 +249,9 @@ def parse_args(): def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction - cuda_capacity = colo_device_memory_capacity(get_current_device()) + cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) print("Using {} GB of GPU memory".format(size_in_GB)) @@ -265,7 +265,9 @@ def __init__(self, length, batch_size, seq_len, vocab_size): self.vocab_size = vocab_size def generate(self): - input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device()) + input_ids = torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device() + ) attention_mask = torch.ones_like(input_ids) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids} @@ -390,7 +392,7 @@ def main(): if args.init_in_cpu: init_dev = torch.device("cpu") else: - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() cai_version = colossalai.__version__ logger.info(f"using Colossal-AI version {cai_version}") @@ -439,7 +441,9 @@ def main(): except ImportError: # this works for unreleased main branch, and this may be released on 0.2.9 from colossalai.zero import GeminiDDP - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) + model = GeminiDDP( + model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True + ) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 7b0e93d958ca..64260374a0d5 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -3,13 +3,13 @@ import torch import torch.nn as nn -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers.init_method import init_normal, output_init_normal diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 75afeee60ad4..ff81ace39736 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,9 +3,9 @@ import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .linear import Linear from .pooler import Pooler diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index e9ceb8d70cb8..f25fc818981a 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,12 +8,12 @@ from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.kernel import LayerNorm from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from colossalai.nn.optimizer import FusedAdam from colossalai.utils import MultiTimer diff --git a/extensions/README.md b/extensions/README.md new file mode 100644 index 000000000000..6f5feb55c2af --- /dev/null +++ b/extensions/README.md @@ -0,0 +1,140 @@ +# 🔌 Extensions + +## 📌 Table of Contents + +- [🔌 Extensions](#-extensions) + - [📌 Table of Contents](#-table-of-contents) + - [📚 Introduction](#-introduction) + - [🪅 Design](#-design) + - [🛠 API Usage](#-api-usage) + - [🏗 Write a customized extension](#-write-a-customized-extension) + - [✏️ Acknowledgement](#️-acknowledgement) + +## 📚 Introduction + +This module is a designed to offer extensions to the existing ColossalAI framework. It is designed to be a collection of high-performance kernels to speed up the training and inference process. Different from writing an individual kernel, the `extensions` module offers a layer of abstraction to collate kernels written in different compiler backends and for different hardware backends in an organized way. Please see the design and usage in the sections below. + +## 🪅 Design + +The `extensions` module is a sub-module of the `colossalai.kernel` module. This module is put at the project root directory so that it can be imported for AOT (ahead-of-time) build. At the same time, it is symbolically linked at the `colossalai.kernel.extensions` path for runtime build. + +As we want to support multi-backend kernels, we have to consider multiple compiler options such as `torch.jit`, `CUDA`, `triton` and multiple hardware backends such as `CPU`, `GPU` and `NPU`. To make it easy for the users, we have abstract away the kernels into extensions and expose a single loader to the user for each kind of kernel. + +For example, if the user wants to use the CPU Adam kernel, he can just call `load()` on the kernel loader. The kernel loader will automatically select the correct extension based on the current hardware and compiler backend. The user does not need to worry about the details of the kernel implementation. For example, if the user is using ARM CPU, then Arm kernel will be built and loaded. If it is a X86 CPU, then it is the X86 kernel that will be loaded. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +![](https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/extensions.png?raw=true) + +## 🛠 API Usage + +To make the `colossalai.kernel` easy to use, we expose some simple APIs and you can use them based on your scenario. + +- Case 1: Simply load a kernel + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel compatible with the current hardware +kernel = CPUAdamLoader().load() +``` + +- Case 2: Load a specific kernel + +This case applies if you are familiar with the extensions available. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader + +# load the kernel by giving the kernel name +kernel = CPUAdamLoader().load(ext_name="cpu_adam_arm") +``` + +- Case 3: Register your own extension + +This case applies if you know how to write an extension. If you do not know how, you can refer to the section below. + +```python +from colossalai.kernel.kernel_loader import CPUAdamLoader +from colossalai.kernel.base_extension import _Extension + +# create your own extension class +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + # implementation here + ... + +# register your extension +# you can use the priority value to make sure your kernel will be loaded by default +CPUAdamLoader.register_extension(MyExtension) + +# load the kernel +kernel = CPUAdamLoader().load() +``` + +## 🏗 Write a customized extension + +It is easy to write a customized extension. If you have experience writing CUDA/triton kernels, you should get familiar with the process quickly. + +You just need to inherit the `_Extension` base class or other backend-specific classes such as `_CudaExtension` and implement the abstract methods. Then, you need to register your extension to the kernel loader based on the Case 3 above. The kernel loader will automatically select the correct extension based on the priority score, current hardware, compiler backend. + +```python +from colossalai.kernel.base_extension import _Extension + + +class MyExtension(_Extension): + + def __init__(self): + self._name = "my_extension" + self._support_aot = True + self._support_jit = True + self.priority = 10 + + def is_hardware_available(self) -> bool: + """ + Return if the required hardware can be found. + """ + ... + + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + ... + + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + """ + If this kernel can be built AOT, it should return an extension object + to Python setuptools for compilation. + """ + ... + + def build_jit(self) -> Callable: + """ + Build extension kernel just in time. + """ + ... + + def load(self): + """ + The API called by the user to get the kernel. + """ + ... + +``` + +## ✏️ Acknowledgement + +This module is written from scratch but we learnt a lot by looking into [DeepSpeed' +s op_builder](https://github.com/microsoft/DeepSpeed/tree/master/op_builder). We wish to acknowledge their great work and contributions to the open-source community. diff --git a/extensions/__init__.py b/extensions/__init__.py new file mode 100644 index 000000000000..9343cadda194 --- /dev/null +++ b/extensions/__init__.py @@ -0,0 +1,36 @@ +from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .flash_attention import ( + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, +) +from .layernorm import LayerNormCudaExtension +from .moe import MoeCudaExtension +from .optimizer import FusedOptimizerCudaExtension +from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension + +ALL_EXTENSIONS = [ + CpuAdamArmExtension, + CpuAdamX86Extension, + LayerNormCudaExtension, + MoeCudaExtension, + FusedOptimizerCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionXformersCudaExtension, + FlashAttentionNpuExtension, +] + +__all__ = [ + "CpuAdamArmExtension", + "CpuAdamX86Extension", + "LayerNormCudaExtension", + "MoeCudaExtension", + "FusedOptimizerCudaExtension", + "ScaledMaskedSoftmaxCudaExtension", + "ScaledUpperTriangleMaskedSoftmaxCudaExtension", + "FlashAttentionDaoCudaExtension", + "FlashAttentionXformersCudaExtension", + "FlashAttentionNpuExtension", +] diff --git a/extensions/base_extension.py b/extensions/base_extension.py new file mode 100644 index 000000000000..c815a7f2ac4a --- /dev/null +++ b/extensions/base_extension.py @@ -0,0 +1,82 @@ +import hashlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Union + +__all__ = ["_Extension"] + + +class _Extension(ABC): + def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1): + self._name = name + self._support_aot = support_aot + self._support_jit = support_jit + self.priority = priority + + @property + def name(self): + return self._name + + @property + def support_aot(self): + return self._support_aot + + @property + def support_jit(self): + return self._support_jit + + @staticmethod + def get_jit_extension_folder_path(): + """ + Kernels which are compiled during runtime will be stored in the same cache folder for reuse. + The folder is in the path ~/.cache/colossalai/torch_extensions/. + The name of the follows a common format: + torch._- + + The suffix is the hash value of the path of the `colossalai` file. + """ + import torch + + import colossalai + from colossalai.accelerator import get_accelerator + + # get torch version + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + + # get device version + device_name = get_accelerator().name + device_version = get_accelerator().get_version() + + # use colossalai's file path as hash + hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest() + + # concat + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}" + cache_directory = os.path.join(home_directory, extension_directory) + return cache_directory + + @abstractmethod + def is_hardware_available(self) -> bool: + """ + Check if the hardware required by the kernel is available. + """ + + @abstractmethod + def assert_hardware_compatible(self) -> None: + """ + Check if the hardware required by the kernel is compatible. + """ + + @abstractmethod + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + pass + + @abstractmethod + def build_jit(self) -> Callable: + pass + + @abstractmethod + def load(self) -> Callable: + pass diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py new file mode 100644 index 000000000000..3adb65fb8f4e --- /dev/null +++ b/extensions/cpp_extension.py @@ -0,0 +1,134 @@ +import importlib +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension + +__all__ = ["_CppExtension"] + + +class _CppExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=True, support_jit=True, priority=priority) + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op = None + + # build-related variables + self.prebuilt_module_path = "colossalai._C" + self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}" + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("csrc"), path) + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + + # get the current file path + # iteratively check the parent directory + # if the parent directory is "extensions", then the current file path is the root directory + # otherwise, the current file path is inside the root directory + current_file_path = Path(__file__) + while True: + if current_file_path.name == "extensions": + break + else: + current_file_path = current_file_path.parent + extension_module_path = current_file_path + code_abs_path = extension_module_path.joinpath(code_path) + return str(code_abs_path) + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def build_aot(self) -> "CppExtension": + from torch.utils.cpp_extension import CppExtension + + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + def build_jit(self) -> None: + from torch.utils.cpp_extension import load + + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + def load(self): + try: + op_kernel = self.import_op() + except (ImportError, ModuleNotFoundError): + # if import error occurs, it means that the kernel is not pre-built + # so we build it jit + op_kernel = self.build_jit() + + return op_kernel diff --git a/extensions/cpu_adam/__init__.py b/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000000..cfd26a6a20f8 --- /dev/null +++ b/extensions/cpu_adam/__init__.py @@ -0,0 +1,5 @@ +from .cpu_adam_arm import CpuAdamArmExtension +from .cpu_adam_x86 import CpuAdamX86Extension + +__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] + diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py new file mode 100644 index 000000000000..35bff3b55928 --- /dev/null +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -0,0 +1,41 @@ +import platform + +from ..cpp_extension import _CppExtension + + +class CpuAdamArmExtension(_CppExtension): + def __init__(self): + super().__init__(name="cpu_adam_arm") + + def is_hardware_available(self) -> bool: + # only arm allowed + return platform.machine() == "aarch64" + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "aarch64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}" + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/op_builder/cpu_adam.py b/extensions/cpu_adam/cpu_adam_x86.py similarity index 60% rename from op_builder/cpu_adam.py rename to extensions/cpu_adam/cpu_adam_x86.py index 7988aae4be12..a38194167b01 100644 --- a/op_builder/cpu_adam.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -1,19 +1,27 @@ -from .builder import Builder -from .utils import append_nvcc_threads +import platform +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class CPUAdamBuilder(Builder): - NAME = "cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" +class CpuAdamX86Extension(_CudaExtension): def __init__(self): - super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + super().__init__(name="cpu_adam_x86") + + def is_hardware_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_hardware_available() + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_hardware_compatible() # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cpu_adam.cpp"), + self.csrc_abs_path("cuda/cpu_adam.cpp"), ] return ret diff --git a/colossalai/kernel/cuda_native/__init__.py b/extensions/csrc/__init__.py similarity index 86% rename from colossalai/kernel/cuda_native/__init__.py rename to extensions/csrc/__init__.py index f8a974b5fb26..0eac28d23e24 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/extensions/csrc/__init__.py @@ -1,5 +1,4 @@ from .layer_norm import MixedFusedLayerNorm as LayerNorm -from .mha.mha import ColoAttention from .multihead_attention import MultiHeadAttention from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax @@ -8,6 +7,5 @@ "MultiHeadAttention", "FusedScaleMaskSoftmax", "ScaledUpperTriangMaskedSoftmax", - "ColoAttention", "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/extensions/csrc/arm/cpu_adam_arm.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp rename to extensions/csrc/arm/cpu_adam_arm.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/extensions/csrc/arm/cpu_adam_arm.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h rename to extensions/csrc/arm/cpu_adam_arm.h diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/extensions/csrc/cuda/colossal_C_frontend.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp rename to extensions/csrc/cuda/colossal_C_frontend.cpp diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/extensions/csrc/cuda/compat.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/compat.h rename to extensions/csrc/cuda/compat.h diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/extensions/csrc/cuda/cpu_adam.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.cpp rename to extensions/csrc/cuda/cpu_adam.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/extensions/csrc/cuda/cpu_adam.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.h rename to extensions/csrc/cuda/cpu_adam.h diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h rename to extensions/csrc/cuda/include/block_reduce.h diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp rename to extensions/csrc/cuda/layer_norm_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu rename to extensions/csrc/cuda/layer_norm_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/extensions/csrc/cuda/moe_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda.cpp rename to extensions/csrc/cuda/moe_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/extensions/csrc/cuda/moe_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu rename to extensions/csrc/cuda/moe_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu rename to extensions/csrc/cuda/multi_tensor_adam.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh rename to extensions/csrc/cuda/multi_tensor_apply.cuh diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu rename to extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu rename to extensions/csrc/cuda/multi_tensor_lamb.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu rename to extensions/csrc/cuda/multi_tensor_scale_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu rename to extensions/csrc/cuda/multi_tensor_sgd_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/extensions/csrc/cuda/scaled_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h rename to extensions/csrc/cuda/scaled_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/extensions/csrc/cuda/type_shim.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/type_shim.h rename to extensions/csrc/cuda/type_shim.h diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/extensions/csrc/scaled_softmax.py similarity index 94% rename from colossalai/kernel/cuda_native/scaled_softmax.py rename to extensions/csrc/scaled_softmax.py index 26a5bce16d5c..7c220d60dd19 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/extensions/csrc/scaled_softmax.py @@ -6,8 +6,7 @@ import torch import torch.nn as nn -from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader try: from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax @@ -35,7 +34,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs, scale): global scaled_upper_triang_masked_softmax if scaled_upper_triang_masked_softmax: - scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load() + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() scale_t = torch.tensor([scale]) softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) @@ -67,7 +66,7 @@ def forward(ctx, inputs, mask, scale): # build and load kernel if not pre-built global scaled_masked_softmax if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py new file mode 100644 index 000000000000..842cd9713a99 --- /dev/null +++ b/extensions/cuda_extension.py @@ -0,0 +1,109 @@ +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension +from .cpp_extension import _CppExtension +from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list + +__all__ = ["_CudaExtension"] + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 + + +class _CudaExtension(_CppExtension): + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME + + if not CUDA_HOME: + raise AssertionError( + "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions" + ) + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def build_jit(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME, load + + set_cuda_arch_list(CUDA_HOME) + + # get build dir + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + def build_aot(self) -> "CUDAExtension": + from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension + + set_cuda_arch_list(CUDA_HOME) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py new file mode 100644 index 000000000000..18abb6191035 --- /dev/null +++ b/extensions/flash_attention/__init__.py @@ -0,0 +1,20 @@ +from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension +from .flash_attention_npu import FlashAttentionNpuExtension +from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension + +try: + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False + +try: + import xformers # noqa + + HAS_MEM_EFF_ATTN = True +except: + HAS_MEM_EFF_ATTN = False + + +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py new file mode 100644 index 000000000000..1b7f8ac4736a --- /dev/null +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -0,0 +1,93 @@ +from ..base_extension import _Extension + + +class FlashAttentionDaoCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + + def load(self): + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + ) + + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # check if the input is in allowed dtypes + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py new file mode 100644 index 000000000000..58d0f9306e3d --- /dev/null +++ b/extensions/flash_attention/flash_attention_npu.py @@ -0,0 +1,73 @@ +from ..base_extension import _Extension + + +class FlashAttentionNpuExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + try: + import torch_npu # noqa + + return True + except: + return False + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu." + ) + + def load(self): + import torch + from einops import rearrange + + def npu_sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, + ): + """ + The scaled dot product attention. + + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + scale: float. The scaling of QK^T before applying softmax. + Default to 1. + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] + output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=origin_attn_mask, + dropout_p=dropout_p, + is_causal=origin_attn_mask is None, + scale=scale, + ) + output = rearrange(output, "b h s d -> b s (h d)") + return output + + return npu_sdpa_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py new file mode 100644 index 000000000000..27cd823de14b --- /dev/null +++ b/extensions/flash_attention/flash_attention_xformers_cuda.py @@ -0,0 +1,94 @@ +from ..base_extension import _Extension + + +class FlashAttentionXformersCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def load(self): + try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + ) + from typing import Optional + + import torch + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out + + return mem_eff_attention diff --git a/extensions/layernorm/__init__.py b/extensions/layernorm/__init__.py new file mode 100644 index 000000000000..9d1bd2d019ee --- /dev/null +++ b/extensions/layernorm/__init__.py @@ -0,0 +1,3 @@ +from .layernorm_cuda import LayerNormCudaExtension + +__all__ = ["LayerNormCudaExtension"] \ No newline at end of file diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/layernorm/layernorm_cuda.py new file mode 100644 index 000000000000..db5f2fce1368 --- /dev/null +++ b/extensions/layernorm/layernorm_cuda.py @@ -0,0 +1,24 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-maxrregcount=50"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/extensions/moe/__init__.py b/extensions/moe/__init__.py new file mode 100644 index 000000000000..962084d4bdde --- /dev/null +++ b/extensions/moe/__init__.py @@ -0,0 +1,3 @@ +from .moe_cuda import MoeCudaExtension + +__all__ = ['MoeCudaExtension'] \ No newline at end of file diff --git a/op_builder/moe.py b/extensions/moe/moe_cuda.py similarity index 56% rename from op_builder/moe.py rename to extensions/moe/moe_cuda.py index 6f8028b1720c..52883e97fc3a 100644 --- a/op_builder/moe.py +++ b/extensions/moe/moe_cuda.py @@ -1,20 +1,17 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag -class MOEBuilder(Builder): - NAME = "moe" - PREBUILT_IMPORT_PATH = "colossalai._C.moe" - +class MoeCudaExtension(_CudaExtension): def __init__(self): - super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name="moe_cuda") def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] return ret def cxx_flags(self): diff --git a/extensions/optimizer/__init__.py b/extensions/optimizer/__init__.py new file mode 100644 index 000000000000..9c8e87cae5de --- /dev/null +++ b/extensions/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .fused_optimizer_cuda import FusedOptimizerCudaExtension + +__all__ = ['FusedOptimizerCudaExtension'] \ No newline at end of file diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/optimizer/fused_optimizer_cuda.py new file mode 100644 index 000000000000..e065cf34a17d --- /dev/null +++ b/extensions/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/colossal_C_frontend.cpp", + "cuda/multi_tensor_sgd_kernel.cu", + "cuda/multi_tensor_scale_kernel.cu", + "cuda/multi_tensor_adam.cu", + "cuda/multi_tensor_l2norm_kernel.cu", + "cuda/multi_tensor_lamb.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/extensions/softmax/__init__.py b/extensions/softmax/__init__.py new file mode 100644 index 000000000000..9bc50c6cd91c --- /dev/null +++ b/extensions/softmax/__init__.py @@ -0,0 +1,4 @@ +from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension +from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension + +__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] \ No newline at end of file diff --git a/op_builder/scaled_masked_softmax.py b/extensions/softmax/scaled_masked_softmax_cuda.py similarity index 50% rename from op_builder/scaled_masked_softmax.py rename to extensions/softmax/scaled_masked_softmax_cuda.py index d9239a80eef6..5b4208dba895 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/extensions/softmax/scaled_masked_softmax_cuda.py @@ -1,23 +1,20 @@ -from .builder import Builder -from .utils import append_nvcc_threads +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class ScaledMaskedSoftmaxBuilder(Builder): - NAME = "scaled_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" - +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): def __init__(self): - super().__init__( - name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH - ) + super().__init__(name="scaled_masked_softmax_cuda") - # necessary 4 functions def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] + ret = [ + self.csrc_abs_path(fname) + for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + ] return ret def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return [self.get_cuda_home_include()] def cxx_flags(self): return ["-O3"] + self.version_dependent_macros diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 000000000000..d4f27a9218ff --- /dev/null +++ b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/extensions/triton_extension.py b/extensions/triton_extension.py new file mode 100644 index 000000000000..9f0792f8ce68 --- /dev/null +++ b/extensions/triton_extension.py @@ -0,0 +1,21 @@ +from .base_extension import _Extension + +__all__ = ["_TritonExtension"] + + +class _TritonExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=False, support_jit=True, priority=priority) + + def is_hardware_compatible(self) -> bool: + # cuda extension can only be built if cuda is available + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def load(self): + return self.build_jit() diff --git a/op_builder/utils.py b/extensions/utils.py similarity index 100% rename from op_builder/utils.py rename to extensions/utils.py diff --git a/op_builder/README.md b/op_builder/README.md deleted file mode 100644 index 9c33a4a328d7..000000000000 --- a/op_builder/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Build PyTorch Extensions - -## Overview - -Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. - -1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` -2. Build the extension during runtime - -The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. - -These two methods have different advantages and disadvantages. -Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. -Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. - -## PyTorch Extensions in Colossal-AI - -The project [DeepSpeed](https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel-build during either installation or runtime. -We have adapted from DeepSpeed's solution to build extensions. The extension build requires two main functions from PyTorch: - -1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. -2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime - -Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). - -Based on the DeepSpeed's work, we have make several modifications and improvements: - -1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` -2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) -3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading. - -When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py deleted file mode 100644 index 21e216437c47..000000000000 --- a/op_builder/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from .arm_cpu_adam import ArmCPUAdamBuilder -from .cpu_adam import CPUAdamBuilder -from .fused_optim import FusedOptimBuilder -from .layernorm import LayerNormBuilder -from .moe import MOEBuilder -from .multi_head_attn import MultiHeadAttnBuilder -from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder - -ALL_OPS = { - "cpu_adam": CPUAdamBuilder, - "fused_optim": FusedOptimBuilder, - "moe": MOEBuilder, - "multi_head_attn": MultiHeadAttnBuilder, - "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, - "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, - "layernorm": LayerNormBuilder, -} - -__all__ = [ - "ALL_OPS", - "CPUAdamBuilder", - "FusedOptimBuilder", - "MultiHeadAttnBuilder", - "ScaledMaskedSoftmaxBuilder", - "ScaledUpperTrainglemaskedSoftmaxBuilder", - "MOEBuilder", - "MultiTensorSGDBuilder", - "MultiTensorAdamBuilder", - "MultiTensorLambBuilder", - "MultiTensorScaleBuilder", - "MultiTensorL2NormBuilder", - "ArmCPUAdamBuilder", -] diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py deleted file mode 100644 index 18dd519fae46..000000000000 --- a/op_builder/arm_cpu_adam.py +++ /dev/null @@ -1,34 +0,0 @@ -from .builder import Builder - - -class ArmCPUAdamBuilder(Builder): - NAME = "arm_cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" - ext_type = "cpu" - - def __init__(self): - super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam_arm.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes")] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-g", - "-Wno-reorder", - "-fopenmp", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - return [] diff --git a/op_builder/builder.py b/op_builder/builder.py deleted file mode 100644 index d804cb1602e4..000000000000 --- a/op_builder/builder.py +++ /dev/null @@ -1,236 +0,0 @@ -# This code has been adapted from the DeepSpeed library. -# Copyright (c) Microsoft Corporation. - -# Licensed under the MIT License. -import importlib -import os -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional, Union - -from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 - - -class Builder(ABC): - """ - Builder is the base class to build extensions for PyTorch. - - Args: - name (str): the name of the kernel to be built - prebuilt_import_path (str): the path where the extension is installed during pip install - """ - - ext_type: str = "cuda" - - def __init__(self, name: str, prebuilt_import_path: str): - self.name = name - self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # we store the op as an attribute to avoid repeated building and loading - self.cached_op_module = None - - assert prebuilt_import_path.startswith( - "colossalai._C" - ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" - - def relative_to_abs_path(self, code_path: str) -> str: - """ - This function takes in a path relative to the colossalai root directory and return the absolute path. - """ - op_builder_module_path = Path(__file__).parent - - # if we install from source - # the current file path will be op_builder/builder.py - # if we install via pip install colossalai - # the current file path will be colossalai/kernel/op_builder/builder.py - # this is because that the op_builder inside colossalai is a symlink - # this symlink will be replaced with actual files if we install via pypi - # thus we cannot tell the colossalai root directory by checking whether the op_builder - # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): - root_path = op_builder_module_path.parent.parent - else: - root_path = op_builder_module_path.parent.joinpath("colossalai") - - code_abs_path = root_path.joinpath(code_path) - return str(code_abs_path) - - def get_cuda_home_include(self): - """ - return include path inside the cuda home. - """ - from torch.utils.cpp_extension import CUDA_HOME - - if CUDA_HOME is None: - raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") - cuda_include = os.path.join(CUDA_HOME, "include") - return cuda_include - - def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) - - # functions must be overrided begin - @abstractmethod - def sources_files(self) -> List[str]: - """ - This function should return a list of source files for extensions. - """ - raise NotImplementedError - - @abstractmethod - def include_dirs(self) -> List[str]: - """ - This function should return a list of include files for extensions. - """ - - @abstractmethod - def cxx_flags(self) -> List[str]: - """ - This function should return a list of cxx compilation flags for extensions. - """ - - @abstractmethod - def nvcc_flags(self) -> List[str]: - """ - This function should return a list of nvcc compilation flags for extensions. - """ - - # functions must be overrided over - def strip_empty_entries(self, args): - """ - Drop any empty strings from the list of compile and link flags - """ - return [x for x in args if len(x) > 0] - - def import_op(self): - """ - This function will import the op module by its string name. - """ - return importlib.import_module(self.prebuilt_import_path) - - def check_runtime_build_environment(self): - """ - Check whether the system environment is ready for extension compilation. - """ - try: - from torch.utils.cpp_extension import CUDA_HOME - - TORCH_AVAILABLE = True - except ImportError: - TORCH_AVAILABLE = False - CUDA_HOME = None - - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" - ) - - if CUDA_HOME is None: - raise RuntimeError( - "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - # make sure CUDA is available for compilation during - cuda_available = check_cuda_availability() - if not cuda_available: - raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") - - # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not - check_system_pytorch_cuda_match(CUDA_HOME) - - def load(self, verbose: Optional[bool] = None): - """ - load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. - If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the - kernel is built during pip install, it can be accessed through `colossalai._C`. - - Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - if verbose is None: - verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" - # if the kernel has be compiled and cached, we directly use it - if self.cached_op_module is not None: - return self.cached_op_module - - try: - # if the kernel has been pre-built during installation - # we just directly import it - op_module = self.import_op() - if verbose: - print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." - ) - except ImportError: - # check environment - if self.ext_type == "cuda": - self.check_runtime_build_environment() - - # time the kernel compilation - start_build = time.time() - - # construct the build directory - import torch - from torch.utils.cpp_extension import load - - torch_version_major = torch.__version__.split(".")[0] - torch_version_minor = torch.__version__.split(".")[1] - torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser("~") - extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" - build_directory = os.path.join(home_directory, extension_directory) - Path(build_directory).mkdir(parents=True, exist_ok=True) - - if verbose: - print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") - - # load the kernel - op_module = load( - name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose, - ) - - build_duration = time.time() - start_build - - # log jit compilation time - if verbose: - print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") - - # cache the built/loaded kernel - self.cached_op_module = op_module - - return op_module - - def builder(self) -> Union["CUDAExtension", "CppExtension"]: - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CppExtension, CUDAExtension - - if self.ext_type == "cpp": - return CppExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args=self.strip_empty_entries(self.cxx_flags()), - ) - - return CUDAExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - "cxx": self.strip_empty_entries(self.cxx_flags()), - "nvcc": self.strip_empty_entries(self.nvcc_flags()), - }, - ) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py deleted file mode 100644 index 3baa0880d801..000000000000 --- a/op_builder/fused_optim.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import get_cuda_cc_flag - - -class FusedOptimBuilder(Builder): - NAME = "fused_optim" - PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" - - def __init__(self): - super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "colossal_C_frontend.cpp", - "multi_tensor_sgd_kernel.cu", - "multi_tensor_scale_kernel.cu", - "multi_tensor_adam.cu", - "multi_tensor_l2norm_kernel.cu", - "multi_tensor_lamb.cu", - ] - ] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - return ["-O3"] + version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-lineinfo"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/gptq.py b/op_builder/gptq.py deleted file mode 100644 index a17801f8783c..000000000000 --- a/op_builder/gptq.py +++ /dev/null @@ -1,56 +0,0 @@ -import re - -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class GPTQBuilder(Builder): - NAME = "cu_gptq" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" - - def __init__(self): - super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "gptq/linear_gptq.cpp", - "gptq/column_remap.cu", - "gptq/cuda_buffers.cu", - "gptq/q4_matmul.cu", - "gptq/q4_matrix.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-v", - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - "-lcublas", - ] - - for arch in torch.cuda.get_arch_list(): - res = re.search(r"sm_(\d+)", arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 80: - extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py deleted file mode 100644 index 2684c6ddb7f7..000000000000 --- a/op_builder/layernorm.py +++ /dev/null @@ -1,27 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class LayerNormBuilder(Builder): - NAME = "layernorm" - PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" - - def __init__(self): - super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-maxrregcount=50"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros - return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py deleted file mode 100644 index cb8fc489ced1..000000000000 --- a/op_builder/multi_head_attn.py +++ /dev/null @@ -1,46 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" - PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" - - def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "multihead_attention_1d.cpp", - "kernels/cublas_wrappers.cu", - "kernels/transform_kernels.cu", - "kernels/dropout_kernels.cu", - "kernels/normalize_kernels.cu", - "kernels/softmax_kernels.cu", - "kernels/general_kernels.cu", - "kernels/cuda_util.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py deleted file mode 100644 index 1445230acbc1..000000000000 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): - NAME = "scaled_upper_triangle_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" - - def __init__(self): - super().__init__( - name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, - prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, - ) - - def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py deleted file mode 100644 index d562a4c4f626..000000000000 --- a/op_builder/smoothquant.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class SmoothquantBuilder(Builder): - NAME = "cu_smoothquant" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" - - def __init__(self): - super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "smoothquant/binding.cpp", - "smoothquant/linear.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - compute_capability = torch.cuda.get_device_capability() - cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 - - extra_cuda_flags = [ - "-v", - f"-DCUDA_ARCH={cuda_arch}", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) - - def builder(self): - try: - super().builder() - except: - warnings.warn("build smoothquant lib not successful") diff --git a/setup.py b/setup.py index cda1ba7ee7a6..1244bfff0327 100644 --- a/setup.py +++ b/setup.py @@ -5,55 +5,23 @@ from setuptools import find_packages, setup -from op_builder.utils import ( - check_cuda_availability, - check_pytorch_version, - check_system_pytorch_cuda_match, - get_cuda_bare_metal_version, - get_pytorch_version, - set_cuda_arch_list, -) - try: - from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + import torch # noqa + from torch.utils.cpp_extension import BuildExtension TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False - CUDA_HOME = None -# Some constants for installation checks -MIN_PYTORCH_VERSION_MAJOR = 1 -MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 -# a variable to store the op builder -ext_modules = [] - # we do not support windows currently if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") -# check for CUDA extension dependencies -def environment_check_for_cuda_extension_build(): - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" - ) - - if not CUDA_HOME: - raise RuntimeError( - "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - check_system_pytorch_cuda_match(CUDA_HOME) - check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) - check_cuda_availability() - - def fetch_requirements(path) -> List[str]: """ This function reads the requirements file. @@ -98,46 +66,35 @@ def get_version() -> str: # write version into version.py with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") - - # look for pytorch and cuda version - if BUILD_CUDA_EXT: - torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f"{torch_major}.{torch_minor}" - cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) - else: - torch_version = None - cuda_version = None - - # write the version into the python file - if torch_version: - f.write(f'torch = "{torch_version}"\n') - else: - f.write("torch = None\n") - - if cuda_version: - f.write(f'cuda = "{cuda_version}"\n') - else: - f.write("cuda = None\n") - return version -if BUILD_CUDA_EXT: - environment_check_for_cuda_extension_build() - set_cuda_arch_list(CUDA_HOME) +if BUILD_EXT: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) - from op_builder import ALL_OPS + from extensions import ALL_EXTENSIONS op_names = [] + ext_modules = [] - # load all builders - for name, builder_cls in ALL_OPS.items(): - op_names.append(name) - ext_modules.append(builder_cls().builder()) + for ext_cls in ALL_EXTENSIONS: + ext = ext_cls() + if ext.support_aot and ext.is_hardware_available(): + ext.assert_hardware_compatible() + op_names.append(ext.name) + ext_modules.append(ext.build_aot()) # show log - op_name_list = ", ".join(op_names) - print(f"[extension] loaded builders for {op_name_list}") + if len(ext_modules) == 0: + raise RuntimeError("[extension] Could not find any kernel compatible with the current environment.") + else: + op_name_list = ", ".join(op_names) + print(f"[extension] Building extensions{op_name_list}") +else: + ext_modules = [] # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 2c8b260e6498..373ba28b8545 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -5,13 +5,13 @@ from torch.utils._pytree import tree_map import colossalai +from colossalai.accelerator import get_accelerator from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -31,7 +31,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): 64, 8, ), - device=get_current_device(), + device=get_accelerator().get_current_device(), ) criterion = LMLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index aba746f1992d..d577173266da 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -10,12 +10,12 @@ except: NO_CODEGEN = True +from colossalai.accelerator import get_accelerator from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn -from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper @@ -72,7 +72,11 @@ def check_auto_parallel_with_gemini(rank, world_size, port): print("=" * msg_length) gemini_config = dict( - strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + strict_ddp_mode=False, + device=get_accelerator().get_current_device(), + placement_policy="cpu", + pin_memory=True, + search_range_m=128, ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index 6f2fc104fc07..d629e769d715 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -8,13 +8,14 @@ from torch.utils.data import Dataset import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device, set_seed +from colossalai.utils import set_seed from tests.kit.model_zoo import model_zoo @@ -23,7 +24,9 @@ def __init__(self, num_samples: int = 100, max_length: int = 512, vocab_size: in self.num_samples = num_samples self.max_length = max_length set_seed(42) - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.input_ids = torch.randint( + 0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device() + ) self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): @@ -115,6 +118,20 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): @parameterize( "test_args", [ + { + "batch_size": 8, + "num_steps": 4, + "tp": 2, + "pp": 2, + "pp_style": "1f1b", + "num_model_chunks": 1, + "num_microbatches": 4, + "zero": 1, + "precision": "fp16", + "initial_scale": 1, + "max_length": 512, + "gradient_accumulation_step": 2, + }, { "batch_size": 8, "num_steps": 4, diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 286f431d5c8c..861fa0131397 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -5,7 +5,7 @@ from torch.optim import Adam import colossalai -import colossalai.utils.device as device_utils +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin @@ -23,7 +23,7 @@ @clear_cache_before_run() def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: - device = device_utils.get_current_device() + device = get_accelerator().get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) @@ -75,7 +75,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - device_utils.empty_cache() + get_accelerator().empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 49fd85ffba0a..61cac1d8369b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -7,7 +7,6 @@ from utils import shared_tempdir import colossalai -from colossalai.testing import skip_if_not_enough_gpus from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.lazy import LazyInitContext @@ -17,6 +16,7 @@ clear_cache_before_run, parameterize, rerun_if_address_is_in_use, + skip_if_not_enough_gpus, spawn, ) from tests.kit.model_zoo import model_zoo @@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model.config.save_pretrained(save_directory=pretrained_path) extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size) + plugin = GeminiPlugin( + **placement_config, + tp_size=tp_size, + enable_all_optimization=enable_all_optimization, + extra_dp_size=extra_dp_size, + ) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -78,14 +83,21 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha criterion = lambda x: x.mean() enable_all_optimization = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) + plugin = GeminiPlugin( + **placement_config, + precision="fp16", + initial_scale=(2**14), + tp_size=tp_size, + extra_dp_size=extra_dp_size, + enable_all_optimization=enable_all_optimization, + ) booster = Booster(plugin=plugin) model = model_fn() new_model = model_fn() optimizer = HybridAdam(model.parameters(), lr=0.001) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_optimizer = HybridAdam(new_model.parameters(), lr=0.01) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) data = data_gen_fn() @@ -97,6 +109,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha booster.backward(loss, optimizer) optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" @@ -115,6 +129,8 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha check_state_dict_equal( optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False ) + for group in new_optimizer.param_groups: + assert group["lr"] == 0.1 # Check the new model/optimizer can successfully run. data = data_gen_fn() @@ -161,8 +177,13 @@ def run_dist(rank, world_size, port): def test_gemini_ckpIO(): spawn(run_dist, 4) + @pytest.mark.largedist @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_gemini_ckpIO_3d(): - spawn(run_dist, 8) \ No newline at end of file + spawn(run_dist, 8) + + +if __name__ == "__main__": + test_gemini_ckpIO() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index a42b550cd6fc..b5cb31715aed 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -83,7 +83,8 @@ def _preprocess_data(data): optimizer.backward(loss) optimizer.step() - + for group in optimizer.param_groups: + group["lr"] = 0.1 with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index 7d2c81972e5a..079022e930cf 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -2,12 +2,12 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.initialize import launch from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -16,7 +16,7 @@ def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -27,7 +27,7 @@ def check_all_gather(): def check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) @@ -38,7 +38,7 @@ def check_reduce_scatter(): def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) - tensor = tensor.to(get_current_device()) + tensor = tensor.to(get_accelerator().get_current_device()) print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 8a9a73d65f38..f09df9253a38 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -2,6 +2,7 @@ import torch.distributed as dist from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.global_variables import tensor_parallel_env as env @@ -16,13 +17,12 @@ VocabParallelEmbedding1D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear_col(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -68,7 +68,7 @@ def check_linear_col(): print_rank_0("linear_col forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] grad = grad.clone() @@ -91,7 +91,7 @@ def check_linear_col(): def check_linear_row(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -137,7 +137,7 @@ def check_linear_row(): print_rank_0("linear_row forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) dist.broadcast(grad_master, src=0) grad = grad_master.clone() out.backward(grad) @@ -159,7 +159,7 @@ def check_linear_row(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -201,7 +201,7 @@ def check_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -243,7 +243,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -309,7 +309,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -420,7 +420,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -508,7 +508,7 @@ def check_vocab_parallel_loss(): @torch.no_grad() def check_linear_row_stream_inference(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 0bbc72eca809..78bd407b9193 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,5 +1,6 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -16,13 +17,12 @@ VocabParallelEmbedding2D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = HIDDEN_SIZE @@ -74,7 +74,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -103,7 +103,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -139,7 +139,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -154,7 +154,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -201,7 +201,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -274,7 +274,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -321,7 +321,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -371,7 +371,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] # grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -399,7 +399,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -519,7 +519,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -608,7 +608,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) @@ -645,7 +645,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -683,7 +683,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -716,7 +716,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 9c126cefeba8..4506cfee686d 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -3,11 +3,11 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal @@ -27,7 +27,7 @@ def check_AB(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] @@ -35,7 +35,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] @@ -72,7 +72,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -105,7 +105,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index 283e7f68374f..914607614a00 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,6 +1,7 @@ import torch from torch.nn import Parameter +from colossalai.accelerator import get_accelerator from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import ( @@ -17,13 +18,12 @@ VocabParallelEmbedding2p5D, ) from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * def check_linear(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -76,7 +76,7 @@ def check_linear(): print_rank_0("linear forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -104,7 +104,7 @@ def check_linear(): def check_layernorm(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE EPS = 1e-12 @@ -141,7 +141,7 @@ def check_layernorm(): print_rank_0("layer norm forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -156,7 +156,7 @@ def check_layernorm(): def check_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -204,7 +204,7 @@ def check_embed(): def check_patch_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -278,7 +278,7 @@ def check_patch_embed(): def check_vocab_parallel_embed(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -326,7 +326,7 @@ def check_vocab_parallel_embed(): def check_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = NUM_CLASSES @@ -377,7 +377,7 @@ def check_classifier_no_given_weight(): print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -405,7 +405,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -524,7 +524,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -613,7 +613,7 @@ def check_loss(): def check_vocab_parallel_loss(): - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -650,7 +650,7 @@ def check_vocab_parallel_loss(): # def check_attention(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 @@ -689,7 +689,7 @@ def check_vocab_parallel_loss(): # print_rank_0('self attention backward: pass') # def check_mlp(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE @@ -725,7 +725,7 @@ def check_vocab_parallel_loss(): # print_rank_0('mlp backward: pass') # def check_transformerlayer(): -# device = get_current_device() +# device = get_accelerator().get_current_device() # dtype = torch.float32 # INPUT_SIZE = HIDDEN_SIZE # NUM_ATTENTION_HEADS = 2 diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index 992bd6107f08..91a15c81dfe5 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -1,10 +1,10 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D from colossalai.legacy.utils import print_rank_0 -from colossalai.utils import get_current_device from .common import * @@ -25,7 +25,7 @@ def check_AB(): k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device()) + A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(A_master, src=0) A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] @@ -33,7 +33,7 @@ def check_AB(): A.requires_grad = True B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device()) + B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] @@ -70,7 +70,7 @@ def check_AB(): print_rank_0("AB forward: pass") grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i] grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j] @@ -103,7 +103,7 @@ def check_ABT(): tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float - device = get_current_device() + device = get_accelerator().get_current_device() i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) @@ -184,7 +184,7 @@ def check_ATB(): ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index a4a4ae9a5ba4..f9f19a17b9d1 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -5,6 +5,7 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.legacy.core import global_context from colossalai.legacy.nn import ( @@ -23,7 +24,6 @@ from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.legacy.utils import print_rank_0 from colossalai.logging import get_dist_logger -from colossalai.utils import get_current_device from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal @@ -31,7 +31,7 @@ def check_linear(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE OUTPUT_SIZE = 2 * HIDDEN_SIZE @@ -84,7 +84,7 @@ def check_linear(): logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=-1)[j] @@ -119,7 +119,7 @@ def check_linear(): def check_layernorm(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -206,7 +206,7 @@ def check_layernorm(): def check_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -258,7 +258,7 @@ def check_classifier_no_given_weight(): logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, device=get_current_device()) + grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -306,7 +306,7 @@ def check_classifier_no_given_weight(): def check_vocab_parallel_classifier_no_given_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() INPUT_SIZE = HIDDEN_SIZE input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight(): def check_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() dtype = torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -463,7 +463,7 @@ def check_classifier_given_embed_weight(): logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device()) torch.distributed.broadcast(grad_master, src=0) grad = torch.chunk(grad_master, DEPTH, dim=0)[i] grad = torch.chunk(grad, DEPTH, dim=0)[j] @@ -497,7 +497,7 @@ def check_classifier_given_embed_weight(): def check_vocab_parallel_classifier_given_embed_weight(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): def check_patch_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -678,7 +678,7 @@ def check_patch_embed(): def check_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -746,7 +746,7 @@ def check_embed(): def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() - device = get_current_device() + device = get_accelerator().get_current_device() logger = get_dist_logger() torch.float32 @@ -823,7 +823,7 @@ def check_vocab_parallel_embed(): def check_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -876,7 +876,7 @@ def check_loss(): def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() - device = get_current_device() + device = get_accelerator().get_current_device() torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index aa4d5d6ceeb3..f4ad0d6d1671 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -1,9 +1,9 @@ import torch +from colossalai.accelerator import get_accelerator from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn import TransformerSelfAttentionRing -from colossalai.utils import get_current_device def check_selfattention(): @@ -13,10 +13,10 @@ def check_selfattention(): HIDDEN_SIZE = 16 layer = TransformerSelfAttentionRing(16, 8, 8, 0.1) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) - hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) + hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device()) attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( - get_current_device() + get_accelerator().get_current_device() ) layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index a5a2d38577dc..cab111358c9c 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from colossalai.accelerator import get_accelerator from colossalai.legacy.communication import ( recv_backward, recv_forward, @@ -18,7 +19,6 @@ from colossalai.legacy.initialize import launch from colossalai.logging import get_dist_logger from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -73,7 +73,7 @@ def check_forward_backward(output_tensor, output_grad, rank, logger): def check_comm(size, rank, prev_rank, next_rank, logger): dtype = torch.float32 - device = get_current_device() + device = get_accelerator().get_current_device() tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) grad_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) tensor = torch.randn(tensor_shape, dtype=dtype, device=device) diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 9df7cf75aae5..4993df4f3713 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -1,15 +1,15 @@ import pytest import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn -from colossalai.utils.device import get_current_device def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): - frac1 = colo_device_memory_capacity(get_current_device()) + frac1 = colo_device_memory_capacity(get_accelerator().get_current_device()) colo_set_process_memory_fraction(0.5) - frac2 = colo_device_memory_capacity(get_current_device()) + frac2 = colo_device_memory_capacity(get_accelerator().get_current_device()) assert frac2 * 2 == frac1 diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index b5f2be705890..9975cc04ff30 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -4,12 +4,12 @@ from torch.nn.utils import clip_grad_norm_ import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.tensor import ColoTensorSpec, ProcessGroup, distspec from colossalai.legacy.utils.common import clip_grad_norm from colossalai.logging import disable_existing_loggers from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -36,7 +36,7 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: @parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): print(f"{world_size}, {dtype}, {device}, {norm_type}") - cuda_device = get_current_device() + cuda_device = get_accelerator().get_current_device() devices = [cuda_device] * 4 if device == "cpu": devices = [torch.device("cpu")] * 4 diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 721a4796abfd..17b790e3e87a 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,13 +1,22 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "moe_info"): + delattr(param, "moe_info") class MoeModel(nn.Module): @@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose(a, b), \ - (f"expected tensors on rank {i} and {i + 1} not to be equal " - f"but they are, {a} vs {b}") + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name in ep_name, print(f"{local_name} != {ep_name}") + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) + + assert_close(a, b, rtol=rtol, atol=atol) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 3fac624729db..a349bc5a910a 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -4,10 +4,10 @@ import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler BATCH_SIZE = 4 @@ -38,7 +38,7 @@ def run_test(rank, world_size, port): layer_list.append(moe_layer) model = nn.ModuleList(layer_list) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) @@ -52,7 +52,7 @@ def run_test(rank, world_size, port): rank = dist.get_rank() torch.cuda.manual_seed(78 + rank) - data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + data = torch.randn(BATCH_SIZE, DIM, device=get_accelerator().get_current_device()) grad = torch.randn_like(data) MOE_MANAGER.reset_loss() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 255ec7444a2c..62d61a3d4b2c 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -3,10 +3,10 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device BATCH_SIZE = 4 NUM_EXPERTS = 4 @@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data - tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) + tokens = torch.randn( + BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True + ) layer = SparseMLP( hidden_size=hidden_size, @@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f router_top_k=topk, router_capacity_factor_train=1.0, ) - layer = layer.to(get_current_device()) + layer = layer.to(get_accelerator().get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape - grad = torch.randn(ech, device=get_current_device()) + grad = torch.randn(ech, device=get_accelerator().get_current_device()) old_out.backward(grad) # get gradient # save all results diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index bd1103df30d3..d6dad2d7fb41 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -9,11 +9,10 @@ from transformers.models.llama import LlamaConfig import colossalai +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device sys.path.append( os.path.join( @@ -28,7 +27,7 @@ def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) + input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_accelerator().get_current_device()) attention_mask = torch.ones_like(input_ids) return { "input_ids": input_ids, @@ -95,6 +94,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=1, zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -103,6 +103,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=dist.get_world_size(), zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -111,6 +112,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=2, zero_stage=2, extra_dp_size=2, custom_policy=OpenMoeForCausalLMPolicy(), @@ -120,6 +122,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=2, + ep_size=2, zero_stage=1, microbatch_size=1, custom_policy=OpenMoeForCausalLMPolicy(), @@ -130,27 +133,6 @@ def get_model(parallel): def _test_moe_checkpoint(rank, parallel): - if parallel == None: - MOE_MANAGER.setup( - parallel=None, - ) - elif parallel == "ep": - MOE_MANAGER.setup( - parallel="EP", - ) - elif parallel == "ep_zero": - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=2, - ) - elif parallel == "hybrid": - MOE_MANAGER.setup( - parallel="EP", - mode="fixed", - fixed_dp_size=1, - fixed_ep_size=2, - fixed_pp_size=2, - ) model1, booster1, optim1 = get_model(parallel) model2, booster2, optim2 = get_model(parallel) model3, booster3, optim3 = get_model(parallel) @@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel): _test_moe_checkpoint(rank, parallel) +@pytest.mark.skip(reason="This is tested in ColossalMOE") @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"]) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index f87d4c792155..74feeeb59722 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -7,12 +7,12 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from tests.test_moe.moe_utils import MoeGradientHandler @@ -23,8 +23,9 @@ def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_ tp_model (MoeModule) local_model (MoeModule) """ - for (tp_name, tp_param), (local_name, local_param) in \ - zip(tp_model.named_parameters(), local_model.named_parameters()): + for (tp_name, tp_param), (local_name, local_param) in zip( + tp_model.named_parameters(), local_model.named_parameters() + ): assert tp_name == local_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -54,8 +55,7 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: tp_model (MoeModule) ep_model (MoeModule) """ - for (tp_name, tp_param), (ep_name, ep_param) in \ - zip(tp_model.named_parameters(), ep_model.named_parameters()): + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): assert tp_name == ep_name if not is_moe_tensor(tp_param): if assert_grad_flag: @@ -97,8 +97,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ local_model (MoeModule) ep_model (MoeModule) """ - for (local_name, local_param), (ep_name, ep_param) in \ - zip(local_model.named_parameters(), ep_model.named_parameters()): + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): assert local_name == ep_name if "experts" not in local_name: if assert_grad_flag: @@ -141,14 +142,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2, - enable_hierarchical_comm=enable_hierarchical_comm + enable_hierarchical_comm=enable_hierarchical_comm, ) MOE_MANAGER.__init__() MOE_MANAGER.setup(parallel="TP") tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) - ep_model = ep_model.to(get_current_device()) - tp_model = tp_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) + ep_model = ep_model.to(get_accelerator().get_current_device()) + tp_model = tp_model.to(get_accelerator().get_current_device()) + local_model = local_model.to(get_accelerator().get_current_device()) # sync ep param sync_moe_model_param(ep_model) @@ -163,11 +164,11 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size tp_grad_handler = MoeGradientHandler(tp_model) rank = dist.get_rank() - input_data = torch.randn(batch_size, dim, device=get_current_device()) + input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device()) micro_batch_size = batch_size // world_size index = rank * micro_batch_size # NOTE: ep & tp takes in sharded data for each process - shard_data = input_data.detach()[index:index + micro_batch_size] + shard_data = input_data.detach()[index : index + micro_batch_size] out_local = local_model(input_data) MOE_MANAGER.reset_loss() @@ -176,13 +177,15 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size out_ep = ep_model(shard_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_tp, out_ep, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" + assert torch.allclose( + out_tp, out_ep, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}" try: - out_local_slice = out_local[index:index + micro_batch_size] - assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \ - f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" - except AssertionError as e: + out_local_slice = out_local[index : index + micro_batch_size] + assert torch.allclose( + out_ep, out_local_slice, atol=1e-6 + ), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}" + except AssertionError: """ e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1 router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2 @@ -193,8 +196,7 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature. """ warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) out_local.mean().backward() @@ -208,10 +210,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) try: sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - except AssertionError as e: + except AssertionError: warnings.warn( - "EP & TP may result in different behavior from local model. " - "Please check the comments for details." + "EP & TP may result in different behavior from local model. " "Please check the comments for details." ) @@ -219,14 +220,17 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size @pytest.mark.parametrize("num_experts", [4, 64]) @pytest.mark.parametrize("batch_size", [16]) @pytest.mark.parametrize("dim", [64]) -@pytest.mark.parametrize("config", [ - {"enable_hierarchical_comm": False}, - {"enable_hierarchical_comm": True}, -]) +@pytest.mark.parametrize( + "config", + [ + {"enable_hierarchical_comm": False}, + {"enable_hierarchical_comm": True}, + ], +) @rerun_if_address_is_in_use() def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_ep_tp(num_experts=8, batch_size=32, dim=32) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 95c0e715dc34..2f08a335de5a 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,11 +3,11 @@ import torch.nn as nn import colossalai +from colossalai.accelerator import get_accelerator from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device HIDDEN_SIZE = 4 INTERMEDIATE_SIZE = 8 @@ -46,7 +46,7 @@ def run_moe_init(expert_parallel): assert dist.get_rank(parallel_info_dict[1].dp_group) == rank model = nn.ModuleList([exp0, exp1, exp2]) - model = model.to(get_current_device()) + model = model.to(get_accelerator().get_current_device()) sync_moe_model_param(model) # MOE experts layout success when ep_size = 1 diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index 7ba7fa6f6b7d..9f6167692d61 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -4,15 +4,21 @@ from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter -@pytest.mark.parametrize(["router", "num_groups"], [ - (Top1Router(), 1), - (Top2Router(), 1), - # (TopKRouter(num_selected_experts=3), 4), -]) -@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ - (4, 5, 8), - (3, 4, 4), -]) +@pytest.mark.parametrize( + ["router", "num_groups"], + [ + (Top1Router(), 1), + (Top2Router(), 1), + # (TopKRouter(num_selected_experts=3), 4), + ], +) +@pytest.mark.parametrize( + ["batch_size", "seq_len", "num_experts"], + [ + (4, 5, 8), + (3, 4, 4), + ], +) def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): x = torch.randn((batch_size * seq_len, num_experts)).cuda() if num_groups > 1: @@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex router.train() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) router.eval() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index f0795a4c738f..1bff2106675e 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -4,102 +4,75 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) - - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) - - # assert zero model - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.module.named_parameters() - ): - assert zero_name == torch_name - assert torch.allclose(zero_param.data, torch_param.data) - - data = torch.randn(16, 4).cuda() + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters()) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + sync_local_from_ep(zero_model, moe_model) + + data = torch.randn(16, 4).bfloat16().cuda() label = torch.randint(0, 4, (16,)).cuda() - torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) - zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) - assert torch.allclose(torch_out, zero_out) - grad_handler.handle_gradient() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + assert torch.allclose(zero_out, moe_out) - for (zero_name, zero_param), (torch_name, torch_param) in zip( - zero_model.module.named_parameters(), torch_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.module.named_parameters(), zero_model.module.named_parameters() ): - assert zero_name == torch_name - zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) - if hasattr(zero_param, "moe_info"): - assert len(zero_grad_list) == 0 - assert torch.allclose(zero_param.grad, torch_param.grad) + assert moe_name == zero_name + moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(moe_param, "moe_info"): + assert len(moe_grad_list) == 0 + if stage == 1: + zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) + else: + zero_grad = zero_grad_list[0].view(moe_param.grad.shape) + assert torch.allclose( + moe_param.grad, zero_grad, atol=1e-5 + ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" else: - assert len(zero_grad_list) > 0 - torch_grad_list = split_ddp_grad(torch_param.grad, world_size) - if stage == 2: - torch_grad_list = torch_grad_list[local_rank : local_rank + 1] - assert len(zero_grad_list) == len(torch_grad_list) - for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): - assert torch.allclose(zero_grad, torch_grad) + assert len(moe_grad_list) > 0 + assert len(moe_grad_list) == len(zero_grad_list) + for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): + assert torch.allclose(moe_grad, zero_grad) -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank) - run_zero_test(rank, world_size, stage=1) - run_zero_test(rank, world_size, stage=2) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) +def test_moe_zero_model(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_model(world_size=2) + test_moe_zero_model(world_size=2, stage=1) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 0d2e2fb1b2d8..4f6067aaa10a 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -4,89 +4,80 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_optim_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - zero_optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters()) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) - - for _ in range(2): - data = torch.randn(16, 4).cuda() / (local_rank + 1) - label = torch.randint(0, 4, (16,)).cuda() - run_fwd_bwd(torch_model, data, label, criterion, None) - run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - grad_handler.handle_gradient() - - torch_optimizer.step() + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + sync_local_from_ep(zero_model, moe_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() + ): + if ".experts." in moe_name: + continue + assert moe_name == zero_name + assert torch.allclose( + moe_param.data, zero_param.data + ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" + + for _ in range(1): + data = torch.randn(2, 4).bfloat16().cuda() + label = torch.randint(0, 4, (2,)).cuda() + + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, moe_out) + moe_optimizer.step() zero_optimizer.step() - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() ): - assert torch.allclose( - torch_param.data, zero_param.data - ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + assert moe_name == zero_name + if is_moe_tensor(moe_param): + param_size = moe_param.shape[0] + zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] + loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) - torch_optimizer.zero_grad() + moe_optimizer.zero_grad() zero_optimizer.zero_grad() -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") - run_zero_optim_test(rank, world_size, stage=1) - run_zero_optim_test(rank, world_size, stage=2) + seed_all(42 + rank) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): - spawn(run_dist, world_size) +def test_moe_zero_optim(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_optim(world_size=2) + test_moe_zero_optim(world_size=2, stage=1) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 6bbe3e4e8172..6d932156a270 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -8,7 +8,8 @@ import torch from torch import Tensor -from colossalai.utils import get_current_device, multi_tensor_applier +from colossalai.accelerator import get_accelerator +from colossalai.utils import multi_tensor_applier _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), @@ -64,9 +65,9 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class FusedAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -90,9 +91,9 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import CPUAdamBuilder + from colossalai.kernel.kernel_loader import CPUAdamLoader - cpu_optim = CPUAdamBuilder().load() + cpu_optim = CPUAdamLoader().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) @@ -155,7 +156,9 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-3, 1e-3 if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: rtol, atol = 4e-3, 4e-3 - check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + check_adam_kernel( + FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_accelerator().get_current_device(), 3, rtol, atol + ) @pytest.mark.parametrize("adamw", [False, True]) diff --git a/tests/test_optimizer/test_lr_scheduler.py b/tests/test_optimizer/test_lr_scheduler.py new file mode 100644 index 000000000000..e0b084140595 --- /dev/null +++ b/tests/test_optimizer/test_lr_scheduler.py @@ -0,0 +1,20 @@ +import torch.nn as nn +from torch.optim import Adam + +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + + +def test_lr_scheduler_save_load(): + model = nn.Linear(10, 10) + optimizer = Adam(model.parameters(), lr=1e-3) + scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2) + new_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=5, warmup_steps=2) + for _ in range(5): + scheduler.step() + state_dict = scheduler.state_dict() + new_scheduler.load_state_dict(state_dict) + assert state_dict == new_scheduler.state_dict() + + +if __name__ == "__main__": + test_lr_scheduler_save_load() diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index caf6e6bbbd42..6f5e734b7472 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -3,11 +3,11 @@ import torch.distributed as dist import colossalai +from colossalai.accelerator import get_accelerator from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device WORLD_SIZE = 2 @@ -19,7 +19,7 @@ def check_p2p_communication(): rank = dist.get_rank() - tensor = torch.ones(1, device=get_current_device()) + tensor = torch.ones(1, device=get_accelerator().get_current_device()) data = [ "tensor", tensor, diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index a5c465ba0b07..3ec1700045e3 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -4,13 +4,11 @@ import torch from einops import rearrange -from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN -from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN +from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel.cuda_native import ColoAttention - from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32] diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5977c706fdd1..e4dc569b825b 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -4,15 +4,15 @@ from torch.distributed.distributed_c10d import _get_default_group import colossalai +from colossalai.accelerator import get_accelerator from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) + temp = torch.tensor([x], device=get_accelerator().get_current_device()) dist.all_reduce(temp) return temp.item() @@ -66,7 +66,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert my_chunk.cpu_shard.size(0) == 1024 // world_size assert my_chunk.device_type == "cpu" assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) + my_chunk.shard_move(get_accelerator().get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 assert my_chunk.device_type == "cuda" diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 21afff753ae6..3a9742e01566 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -5,11 +5,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd( use_grad_checkpoint: bool = False, master_weights: bool = True, ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 35323e516071..36a803492b6d 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -6,10 +6,10 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd @@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): def exam_gemini_grad_acc( placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool ): - init_device = get_current_device() + init_device = get_accelerator().get_current_device() model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( iter(model_zoo.get_sub_registry(model_name).values()) ) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 152bf289502a..7f3c7176e99e 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -7,11 +7,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd @@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict): def single_chunk_init(model: torch.nn.Module, placement_config: dict): - model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config) + model = GeminiDDP( + model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config + ) return model @@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - init_dev = get_current_device() + init_dev = get_accelerator().get_current_device() model = model_builder().to(init_dev) for torch_p, p in zip(torch_model.parameters(), model.parameters()): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 405d7d789b01..71bb27b4aca1 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -5,11 +5,11 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd @@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. model = GeminiDDP( model, - chunk_init_device=get_current_device(), + chunk_init_device=get_accelerator().get_current_device(), search_range_m=1, pin_memory=True, mixed_precision=mixed_precision, diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e99f6d59ba8e..cf3658bf9920 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -2,8 +2,8 @@ import torch import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.kit.model_zoo import model_zoo @@ -34,7 +34,7 @@ def exam_chunk_manager(): sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, - get_current_device(), + get_accelerator().get_current_device(), hidden_dim=128, search_range_m=1, min_chunk_size_m=0, diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 351ae5f67ff7..11f738615d16 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -7,9 +7,10 @@ from torch.testing import assert_close import colossalai +from colossalai.accelerator import get_accelerator from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import conditional_context, get_current_device +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -28,7 +29,7 @@ def forward(self, x): def exam_zero_1_2_grad_acc(): local_rank = torch.distributed.get_rank() seed_all(2009) - device = get_current_device() + device = get_accelerator().get_current_device() # create model zero1_model = MlpModel().to(device) zero2_model = copy.deepcopy(zero1_model) @@ -71,7 +72,7 @@ def fwd_bwd_func(number, cur_data, check_flag): def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) - device = get_current_device() + device = get_accelerator().get_current_device() # create models zero_model = MlpModel() diff --git a/version.txt b/version.txt index 42045acae20f..c2c0004f0e2a 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.4 +0.3.5