diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index a34a60669031..2cad504f3391 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -140,7 +140,7 @@ jobs:
- name: Install Colossal-AI
run: |
- CUDA_EXT=1 pip install -v -e .
+ BUILD_EXT=1 pip install -v -e .
pip install -r requirements/requirements-test.txt
- name: Store Colossal-AI Cache
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 03f9c53f1d28..ae1a5275e5da 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -12,7 +12,7 @@ jobs:
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
- image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 90
steps:
@@ -55,7 +55,7 @@ jobs:
if: steps.check-avai.outputs.avai == 'true'
run: |
[ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
- CUDA_EXT=1 pip install -v -e .
+ BUILD_EXT=1 pip install -v -e .
cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/
pip install -r requirements/requirements-test.txt
diff --git a/.github/workflows/example_check_on_dispatch.yml b/.github/workflows/example_check_on_dispatch.yml
index 02e30f52a459..bba321fd2d59 100644
--- a/.github/workflows/example_check_on_dispatch.yml
+++ b/.github/workflows/example_check_on_dispatch.yml
@@ -45,9 +45,9 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.manual_check_matrix_preparation.outputs.matrix)}}
container:
- image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
- timeout-minutes: 10
+ timeout-minutes: 15
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 6d6952aa169a..fcff8e569ff7 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -77,7 +77,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.detect-changed-example.outputs.matrix)}}
container:
- image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 20
concurrency:
diff --git a/.github/workflows/example_check_on_schedule.yml b/.github/workflows/example_check_on_schedule.yml
index 919fa5092a6c..abb9479492e7 100644
--- a/.github/workflows/example_check_on_schedule.yml
+++ b/.github/workflows/example_check_on_schedule.yml
@@ -34,7 +34,7 @@ jobs:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
- image: hpcaitech/pytorch-cuda:2.0.0-11.7.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
timeout-minutes: 10
steps:
- name: 📚 Checkout
diff --git a/.github/workflows/run_chatgpt_examples.yml b/.github/workflows/run_chatgpt_examples.yml
index f9e9f400962e..bb0ceb4a8296 100644
--- a/.github/workflows/run_chatgpt_examples.yml
+++ b/.github/workflows/run_chatgpt_examples.yml
@@ -18,7 +18,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
- image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/github_actions/chat:/data/scratch/github_actions/chat --shm-size=10.24gb
timeout-minutes: 30
defaults:
diff --git a/.github/workflows/run_chatgpt_unit_tests.yml b/.github/workflows/run_chatgpt_unit_tests.yml
index ec5c8ffa319f..7986889e006b 100644
--- a/.github/workflows/run_chatgpt_unit_tests.yml
+++ b/.github/workflows/run_chatgpt_unit_tests.yml
@@ -20,7 +20,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
- image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /data/scratch/chatgpt:/data/scratch/chatgpt
timeout-minutes: 30
defaults:
diff --git a/.github/workflows/run_colossalqa_unit_tests.yml b/.github/workflows/run_colossalqa_unit_tests.yml
index 763db277289f..00944b92d9b6 100644
--- a/.github/workflows/run_colossalqa_unit_tests.yml
+++ b/.github/workflows/run_colossalqa_unit_tests.yml
@@ -19,7 +19,7 @@ jobs:
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
container:
- image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
volumes:
- /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
- /data/scratch/llama-tiny:/data/scratch/llama-tiny
@@ -51,4 +51,4 @@ jobs:
TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt
TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt
TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/*
- SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path
\ No newline at end of file
+ SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path
diff --git a/MANIFEST.in b/MANIFEST.in
index ad26b634ac3e..f0a5611efc7d 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,4 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
-recursive-include op_builder *.py
+recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi
diff --git a/README.md b/README.md
index 13757eece7db..3963fe2fb5d6 100644
--- a/README.md
+++ b/README.md
@@ -398,10 +398,10 @@ pip install colossalai
**Note: only Linux is supported for now.**
-However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
+However, if you want to build the PyTorch extensions during installation, you can set `BUILD_EXT=1`.
```bash
-CUDA_EXT=1 pip install colossalai
+BUILD_EXT=1 pip install colossalai
```
**Otherwise, CUDA kernels will be built during runtime when you actually need them.**
@@ -429,7 +429,7 @@ By default, we do not compile CUDA/C++ kernels. ColossalAI will build them durin
If you want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
```shell
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
For Users with CUDA 10.2, you can still build ColossalAI from source. However, you need to manually download the cub library and copy it to the corresponding directory.
@@ -445,7 +445,7 @@ unzip 1.8.0.zip
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
# install
-CUDA_EXT=1 pip install .
+BUILD_EXT=1 pip install .
```
(back to top)
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index c0e257f54a07..e67e16231cc2 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -49,12 +49,13 @@ def _preprocess(
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
- sequences = [s + t for s, t in zip(sources, targets)]
+ sequences = [s + t + tokenizer.eos_token for s, t in zip(sources, targets)]
sequences_token = tokenizer(
- sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
+
sources_token = tokenizer(
- sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
+ sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False
)
assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
@@ -65,7 +66,8 @@ def _preprocess(
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
- labels[i][-pad_len:] = IGNORE_INDEX
+ if pad_len > 0:
+ labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][: pad_len + source_len] = IGNORE_INDEX
diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py
index d6966689885e..330e4e0e395e 100644
--- a/applications/Chat/coati/trainer/ppo.py
+++ b/applications/Chat/coati/trainer/ppo.py
@@ -10,7 +10,7 @@
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
from .base import OnPolicyTrainer
from .callbacks import Callback
@@ -105,7 +105,7 @@ def __init__(
self.critic_optim = critic_optim
self.offload_inference_models = offload_inference_models
- self.device = get_current_device()
+ self.device = get_accelerator().get_current_device()
def _before_fit(
self,
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 7129edb060ef..95f01678640c 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -6,7 +6,6 @@
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.utils import get_current_device
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from .ddp import DDPStrategy
@@ -158,9 +157,19 @@ def __init__(
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
+ # ColossalAI changed the get_current_device API in version 0.3.4 and newer
+ try:
+ from colossalai.accelerator import get_accelerator
+
+ chunk_init_device = get_accelerator().get_current_device()
+ except ImportError:
+ from colossalai.utils import get_current_device
+
+ chunk_init_device = get_current_device()
+
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
- chunk_init_device=get_current_device(),
+ chunk_init_device=chunk_init_device,
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
diff --git a/applications/Chat/examples/train_sft.sh b/applications/Chat/examples/train_sft.sh
index 0fb4da3d3ce8..b7d176847d9c 100755
--- a/applications/Chat/examples/train_sft.sh
+++ b/applications/Chat/examples/train_sft.sh
@@ -25,4 +25,4 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--accumulation_steps 8 \
--lr 2e-5 \
--max_datasets_size 512 \
- --max_epochs 1
+ --max_epochs 1
\ No newline at end of file
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
index a2cfb2ef6264..327651f4e645 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
@@ -1,20 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-import numpy as np
import os
-import random
from dataclasses import dataclass
-from typing import Dict, List, Union, Sequence, Optional, Iterator, Callable
+from typing import Dict, Iterator, List, Optional, Sequence, Union
import torch
-from datasets import dataset_dict, load_from_disk
+import torch.nn.functional as F
from datasets import Dataset as HFDataset
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import ConcatDataset, Dataset, DataLoader, DistributedSampler
+from datasets import dataset_dict, load_from_disk
+from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
-import torch.nn.functional as F
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
@@ -62,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
tokenizer: PreTrainedTokenizer
max_length: int = 4096
ignore_index: int = -100
+ padding: str = "max_length"
def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
"""
@@ -106,10 +103,11 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch
batch_first=True,
padding_value=self.ignore_index,
) # (bsz, max_len)
- # pad to max
- to_pad = self.max_length - input_ids.size(1)
- input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
- labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
+ if self.padding == "max_length":
+ # pad to max
+ to_pad = self.max_length - input_ids.size(1)
+ input_ids = F.pad(input_ids, (0, to_pad), value=self.tokenizer.pad_token_id)
+ labels = F.pad(labels, (0, to_pad), value=self.ignore_index)
elif self.tokenizer.padding_side == "left":
reversed_input_ids = [seq.flip(dims=(0,)) for seq in batch_input_ids]
reversed_input_ids = torch.nn.utils.rnn.pad_sequence(
@@ -171,49 +169,3 @@ def __len__(self) -> int:
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
-
-
-def setup_distributed_dataloader(
- dataset: DatasetType,
- batch_size: int = 1,
- shuffle: bool = False,
- seed: int = 1024,
- drop_last: bool = False,
- pin_memory: bool = False,
- num_workers: int = 0,
- collate_fn: Callable[[Sequence[Dict[str, Union[str, List[int]]]]], Dict[str, torch.Tensor]] = None,
- process_group: Optional[ProcessGroup] = None,
- **kwargs,
-) -> DataLoader:
- """
- Setup dataloader for distributed training.
- """
- _kwargs = kwargs.copy()
- process_group = process_group or _get_default_group()
- sampler = StatefulDistributedSampler(
- dataset=dataset,
- num_replicas=process_group.size(),
- rank=process_group.rank(),
- shuffle=shuffle,
- seed=seed,
- drop_last=drop_last,
- )
-
- # Deterministic dataloader
- def seed_worker(worker_id: int) -> None:
- worker_seed = seed
- np.random.seed(worker_seed)
- torch.manual_seed(worker_seed)
- random.seed(worker_seed)
-
- return DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- sampler=sampler,
- num_workers=num_workers,
- collate_fn=collate_fn,
- pin_memory=pin_memory,
- drop_last=drop_last,
- worker_init_fn=seed_worker,
- **_kwargs,
- )
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
index 1926ec78aba8..6c048c3b18cf 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
@@ -1,15 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
+import math
from types import MethodType
from typing import Optional, Tuple
import torch
+import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
-from flash_attn.ops.rms_norm import rms_norm
+from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
@@ -19,194 +19,334 @@
repeat_kv,
)
+from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
logger = get_dist_logger()
+if get_accelerator().name == "cuda":
+ from flash_attn.bert_padding import pad_input, unpad_input
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+ from flash_attn.ops.rms_norm import rms_norm
-def _prepare_decoder_attention_mask(
- self: LlamaModel,
- attention_mask: torch.BoolTensor,
- input_shape: torch.Size,
- inputs_embeds: torch.Tensor,
- past_key_values_length: int,
-) -> Optional[torch.Tensor]:
- """
- Decoder attetion mask
- """
- if past_key_values_length > 0 and attention_mask is not None:
- attention_mask = torch.cat(
- tensors=(
- torch.full(
- size=(input_shape[0], past_key_values_length),
- fill_value=True,
- dtype=attention_mask.dtype,
- device=attention_mask.device,
+ def _prepare_decoder_attention_mask(
+ self: LlamaModel,
+ attention_mask: torch.BoolTensor,
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ) -> Optional[torch.Tensor]:
+ """
+ Decoder attention mask
+ """
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ tensors=(
+ torch.full(
+ size=(input_shape[0], past_key_values_length),
+ fill_value=True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
),
- attention_mask,
- ),
- dim=-1,
- ) # (bsz, past_key_values_length + q_len)
- if attention_mask is not None and torch.all(attention_mask):
- return None # Faster
- return attention_mask
-
-
-def attention_forward(
- self: LlamaAttention,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- **kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """
- Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
- """
- if output_attentions:
- logger.warning(
- "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
- "return `None` instead."
- )
-
- bsz, q_len, _ = hidden_states.size()
+ dim=-1,
+ ) # (bsz, past_key_values_length + q_len)
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # Faster
+ return attention_mask
- if self.config.pretraining_tp > 1:
- q_slicing, kv_slicing = (
- dim // self.config.pretraining_tp
- for dim in (
- self.num_heads * self.head_dim,
- self.num_key_value_heads * self.head_dim,
+ def attention_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
+ """
+ if output_attentions:
+ logger.warning(
+ "Argument `output_attentions` is not supported for flash-attention patched `LlamaAttention`, "
+ "return `None` instead."
)
- ) # `Tuple[int, int]`
- q_slices, k_slices, v_slices = (
- proj.weight.split(slicing, dim=0)
- for proj, slicing in (
- (self.q_proj, q_slicing),
- (self.k_proj, kv_slicing),
- (self.v_proj, kv_slicing),
+
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ q_slicing, kv_slicing = (
+ dim // self.config.pretraining_tp
+ for dim in (
+ self.num_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ )
+ ) # `Tuple[int, int]`
+ q_slices, k_slices, v_slices = (
+ proj.weight.split(slicing, dim=0)
+ for proj, slicing in (
+ (self.q_proj, q_slicing),
+ (self.k_proj, kv_slicing),
+ (self.v_proj, kv_slicing),
+ )
+ ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+ q, k, v = (
+ torch.cat(
+ [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
+ dim=-1,
+ )
+ for slices in (q_slices, k_slices, v_slices)
)
- ) # Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor], Tuple[torch.Tensor]]
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+ else:
+ q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
+ # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
+ # (bsz, q_len, num_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim),
+ # (bsz, q_len, num_key_value_heads * head_dim)
+
+ # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
+ # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
q, k, v = (
- torch.cat(
- [F.linear(hidden_states, slices[i]) for i in range(self.config.pretraining_tp)],
- dim=-1,
+ states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
+ for states, num_heads in (
+ (q, self.num_heads),
+ (k, self.num_key_value_heads),
+ (v, self.num_key_value_heads),
)
- for slices in (q_slices, k_slices, v_slices)
- )
- # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
- # (bsz, q_len, num_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim)
- else:
- q, k, v = (proj(hidden_states) for proj in (self.q_proj, self.k_proj, self.v_proj))
- # `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` of shape:
- # (bsz, q_len, num_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim),
- # (bsz, q_len, num_key_value_heads * head_dim)
-
- # (bsz, q_len, num_heads * head_dim) -> (bsz, num_heads, q_len, head_dim);
- # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim);
- # (bsz, q_len, num_key_value_heads * head_dim) -> (bsz, num_key_value_heads, q_len, head_dim)
- q, k, v = (
- states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
- for states, num_heads in (
- (q, self.num_heads),
- (k, self.num_key_value_heads),
- (v, self.num_key_value_heads),
)
- )
- kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
- past_kv_len = 0
- if past_key_value is not None:
- # if `past_key_value` is not None, `kv_len` > `q_len`.
- past_kv_len = past_key_value[0].shape[-2]
- kv_len += past_kv_len
-
- # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
- cos, sin = self.rotary_emb(v, seq_len=kv_len)
- # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
- q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
- if past_key_value is not None:
- # reuse k, v, self_attention
- k = torch.cat([past_key_value[0], k], dim=2)
- v = torch.cat([past_key_value[1], v], dim=2)
-
- past_key_value = (k, v) if use_cache else None
-
- # repeat k/v heads if n_kv_heads < n_heads
- k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
- # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
- v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
- # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
-
- key_padding_mask = attention_mask
- # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
- q, k, v = (states.transpose(1, 2) for states in (q, k, v))
-
- if past_kv_len > 0:
- q = torch.cat(
- tensors=(
- torch.full(
- size=(bsz, past_kv_len, self.num_heads, self.head_dim),
- fill_value=0.0,
- dtype=q.dtype,
- device=q.device,
+ kv_len = k.shape[-2] # initially, `kv_len` == `q_len`
+ past_kv_len = 0
+ if past_key_value is not None:
+ # if `past_key_value` is not None, `kv_len` > `q_len`.
+ past_kv_len = past_key_value[0].shape[-2]
+ kv_len += past_kv_len
+
+ # two `torch.Tensor` objs of shape (1, 1, kv_len, head_dim)
+ cos, sin = self.rotary_emb(v, seq_len=kv_len)
+ # (bsz, num_heads, q_len, head_dim), (bsz, num_key_value_heads, q_len, head_dim)
+ q, k = apply_rotary_pos_emb(q=q, k=k, cos=cos, sin=sin, position_ids=position_ids)
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ k = torch.cat([past_key_value[0], k], dim=2)
+ v = torch.cat([past_key_value[1], v], dim=2)
+
+ past_key_value = (k, v) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ k = repeat_kv(hidden_states=k, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+ v = repeat_kv(hidden_states=v, n_rep=self.num_key_value_groups)
+ # (bsz, num_key_value_heads, q_len, head_dim) -> (bsz, num_heads, q_len, head_dim)
+
+ key_padding_mask = attention_mask
+ # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim)
+ q, k, v = (states.transpose(1, 2) for states in (q, k, v))
+
+ if past_kv_len > 0:
+ q = torch.cat(
+ tensors=(
+ torch.full(
+ size=(bsz, past_kv_len, self.num_heads, self.head_dim),
+ fill_value=0.0,
+ dtype=q.dtype,
+ device=q.device,
+ ),
+ q,
),
- q,
- ),
- dim=1,
- ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
-
- if key_padding_mask is None:
- # (bsz, past_kv_len + q_len, num_heads, head_dim)
- output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
- output = rearrange(output, pattern="... h d -> ... (h d)") # (bsz, past_kv_len + q_len, num_heads * head_dim)
- else:
- q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
- kv, _, cu_kv_lens, max_kv_len = unpad_input(
- hidden_states=torch.stack(tensors=(k, v), dim=2),
- attention_mask=key_padding_mask,
- )
- output_unpad = flash_attn_varlen_kvpacked_func(
- q=q,
- kv=kv,
- cu_seqlens_q=cu_q_lens,
- cu_seqlens_k=cu_kv_lens,
- max_seqlen_q=max_q_len,
- max_seqlen_k=max_kv_len,
- dropout_p=0.0,
- softmax_scale=None,
- causal=True,
- )
- output = pad_input(
- hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
- indices=indices,
- batch=bsz,
- seqlen=past_kv_len + q_len,
- ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
-
- if past_kv_len > 0:
- # Strip off the zero query outputs.
- output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
- output = self.o_proj(output) # (bsz, q_len, hidden_size)
- return output, None, past_key_value
-
-
-def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
- """
- Formard function for RMS Norm
- """
- return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
-
-
-def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
- for name, module in model.named_modules():
- if isinstance(module, LlamaAttention):
- module.forward = MethodType(attention_forward, module)
- if isinstance(module, LlamaModel):
- module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
- if isinstance(module, LlamaRMSNorm):
- module.forward = MethodType(rms_norm_forward, module)
+ dim=1,
+ ) # (bsz, past_kv_len + q_len, num_heads, head_dim)
+
+ if key_padding_mask is None:
+ # (bsz, past_kv_len + q_len, num_heads, head_dim)
+ output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True) # (bsz, )
+ output = rearrange(
+ output, pattern="... h d -> ... (h d)"
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+ else:
+ q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
+ kv, _, cu_kv_lens, max_kv_len = unpad_input(
+ hidden_states=torch.stack(tensors=(k, v), dim=2),
+ attention_mask=key_padding_mask,
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q=q,
+ kv=kv,
+ cu_seqlens_q=cu_q_lens,
+ cu_seqlens_k=cu_kv_lens,
+ max_seqlen_q=max_q_len,
+ max_seqlen_k=max_kv_len,
+ dropout_p=0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output = pad_input(
+ hidden_states=rearrange(output_unpad, pattern="nnz h d -> nnz (h d)"),
+ indices=indices,
+ batch=bsz,
+ seqlen=past_kv_len + q_len,
+ ) # (bsz, past_kv_len + q_len, num_heads * head_dim)
+
+ if past_kv_len > 0:
+ # Strip off the zero query outputs.
+ output = output[:, past_kv_len:, ...] # (bsz, q_len, num_heads * head_dim)
+ output = self.o_proj(output) # (bsz, q_len, hidden_size)
+ return output, None, past_key_value
+
+ def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Forward function for RMS Norm
+ """
+ return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
+
+ def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.forward = MethodType(attention_forward, module)
+ if isinstance(module, LlamaModel):
+ module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
+ if isinstance(module, LlamaRMSNorm):
+ module.forward = MethodType(rms_norm_forward, module)
+
+elif get_accelerator().name == "npu":
+ import torch_npu
+
+ class NPULlamaAttention(LlamaAttention):
+ use_flash: bool = True
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.setup()
+
+ def setup(self):
+ self._softmax_scale = 1 / math.sqrt(self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split(
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+ )
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if not self.use_flash:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+ else:
+ attn_output, *_ = torch_npu.npu_fusion_attention(
+ query_states,
+ key_states,
+ value_states,
+ self.num_heads,
+ "BNSD",
+ atten_mask=attention_mask.bool(),
+ scale=self._softmax_scale,
+ padding_mask=None,
+ pre_tockens=65535,
+ next_tockens=0,
+ keep_prob=1.0,
+ inner_precise=0,
+ )
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum(
+ [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
+ )
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ class NPURMSNorm(LlamaRMSNorm):
+ def forward(self, hidden_states):
+ return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
+
+ def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ module.__class__ = NPULlamaAttention
+ module.setup()
+ if isinstance(module, LlamaRMSNorm):
+ module.__class__ = NPURMSNorm
diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
index 9f6c9c1cc6f3..21d769f3c49f 100644
--- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
+++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py
@@ -17,7 +17,7 @@
def unwrap(model):
if hasattr(model, "module"):
- return unwrap_model(model.module)
+ return model.unwrap()
else:
return model
diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py
index 123290d45eab..77e18d8b5939 100644
--- a/applications/Colossal-LLaMA-2/inference_example.py
+++ b/applications/Colossal-LLaMA-2/inference_example.py
@@ -1,17 +1,16 @@
import argparse
-import os
import torch
+from colossal_llama2.dataset.conversation import default_conversation
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
from colossalai.logging import get_dist_logger
-from transformers import AutoTokenizer, AutoModelForCausalLM
logger = get_dist_logger()
def load_model(model_path, device="cuda", **kwargs):
- logger.info(
- "Please check whether the tokenizer and model weights are properly stored in the same folder."
- )
+ logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
model.to(device)
@@ -27,31 +26,50 @@ def load_model(model_path, device="cuda", **kwargs):
def generate(args):
model, tokenizer = load_model(model_path=args.model_path, device=args.device)
- BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
- input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
-
- inputs = tokenizer(args.input_txt, return_tensors='pt').to(args.device)
- output = model.generate(**inputs,
- max_new_tokens=args.max_new_tokens,
- do_sample=args.do_sample,
- temperature=args.temperature,
- top_k=args.top_k,
- top_p=args.top_p,
- num_return_sequences=1)
- response = tokenizer.decode(output.cpu()[0], skip_special_tokens=True)[len(input_txt):]
+ if args.prompt_style == "sft":
+ conversation = default_conversation.copy()
+ conversation.append_message("Human", args.input_txt)
+ input_txt = conversation.get_prompt()
+ else:
+ BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
+ input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
+
+ inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
+ num_input_tokens = inputs["input_ids"].shape[-1]
+ output = model.generate(
+ **inputs,
+ max_new_tokens=args.max_new_tokens,
+ do_sample=args.do_sample,
+ temperature=args.temperature,
+ top_k=args.top_k,
+ top_p=args.top_p,
+ num_return_sequences=1,
+ )
+ response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
- parser.add_argument('--model_path', type=str, default="hpcai-tech/Colossal-LLaMA-2-7b-base", help="HF repo name or local path of the model")
- parser.add_argument('--device', type=str, default="cuda:0", help="Set the device")
- parser.add_argument('--max_new_tokens', type=int, default=512, help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt")
- parser.add_argument('--do_sample', type=bool, default=True, help="Set whether or not to use sampling")
- parser.add_argument('--temperature', type=float, default=0.3, help="Set temperature value")
- parser.add_argument('--top_k', type=int, default=50, help="Set top_k value for top-k-filtering")
- parser.add_argument('--top_p', type=float, default=0.95, help="Set top_p value for generation")
- parser.add_argument('--input_txt', type=str, default="明月松间照,", help="The prompt input to the model")
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="hpcai-tech/Colossal-LLaMA-2-7b-base",
+ help="HF repo name or local path of the model",
+ )
+ parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
+ parser.add_argument(
+ "--max_new_tokens",
+ type=int,
+ default=512,
+ help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
+ )
+ parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
+ parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
+ parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
+ parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
+ parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
+ parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
args = parser.parse_args()
- generate(args)
\ No newline at end of file
+ generate(args)
diff --git a/applications/Colossal-LLaMA-2/train.example.sh b/applications/Colossal-LLaMA-2/train.example.sh
index 276d9ce99d42..6a1c887bf6cc 100644
--- a/applications/Colossal-LLaMA-2/train.example.sh
+++ b/applications/Colossal-LLaMA-2/train.example.sh
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
--warmup_steps 100 \
--use_grad_checkpoint \
--use_flash_attn \
+ --pad_token "unk"
diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index 41b4ef031b46..2e4bab75a085 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -1,45 +1,40 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
+Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
"""
-import json
import argparse
+import json
import os
import resource
from contextlib import nullcontext
-from tqdm import tqdm
import torch
import torch.distributed as dist
+from colossal_llama2.dataset.loader import (
+ DataCollatorForSupervisedDataset,
+ StatefulDistributedSampler,
+ load_tokenized_dataset,
+)
+from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
+from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
+from colossal_llama2.utils.froze import freeze_non_embeds_parameters
+from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter
-from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
+from tqdm import tqdm
+from transformers import LlamaForCausalLM, LlamaTokenizer
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
-from colossalai.booster.plugin import (
- GeminiPlugin,
- LowLevelZeroPlugin,
- HybridParallelPlugin,
-)
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
-from colossal_llama2.dataset.loader import (
- load_tokenized_dataset,
- setup_distributed_dataloader,
- DataCollatorForSupervisedDataset,
- StatefulDistributedSampler,
-)
-
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-
def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
@@ -90,6 +85,7 @@ def main() -> None:
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
@@ -115,6 +111,12 @@ def main() -> None:
default=False,
help="Use flash-attention",
)
+ parser.add_argument(
+ "--use_neft",
+ action="store_true",
+ default=False,
+ help="Use NEFTune",
+ )
parser.add_argument(
"--freeze_non_embeds_params",
action="store_true",
@@ -123,6 +125,8 @@ def main() -> None:
)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--zero", type=int, default=1)
+ parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
+ parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
args = parser.parse_args()
with open(args.config_file, "w") as f:
@@ -132,6 +136,7 @@ def main() -> None:
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
+ accelerator = get_accelerator()
coordinator = DistCoordinator()
# ==============================
@@ -149,6 +154,7 @@ def main() -> None:
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
+ enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -156,6 +162,7 @@ def main() -> None:
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
+ enable_gradient_accumulation=(args.accumulation_steps > 1),
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -189,7 +196,10 @@ def main() -> None:
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
- tokenizer.pad_token = tokenizer.unk_token
+ if args.pad_token == "eos":
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.pad_token == "unk":
+ tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
@@ -200,29 +210,36 @@ def main() -> None:
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
- dataloader = setup_distributed_dataloader(
+ data_collator = DataCollatorForSupervisedDataset(
+ tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
+ )
+ dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.micro_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
+ distributed_sampler_cls=StatefulDistributedSampler,
)
coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
init_ctx = (
- LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
+ LazyInitContext(default_device=get_current_device())
+ if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+ else nullcontext()
)
with init_ctx:
- model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
+ model = LlamaForCausalLM.from_pretrained(args.pretrained)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
+ # This is essential; otherwise gradient checkpointing will not work.
+ model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
@@ -244,12 +261,14 @@ def main() -> None:
adamw_mode=True,
)
+ if args.warmup_steps is None:
+ args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
+ coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
+
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
- total_steps=args.num_epochs * len(dataloader),
- warmup_steps=args.warmup_steps
- if args.warmup_steps is not None
- else int(args.num_epochs * len(dataloader) * 0.025),
+ total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
+ warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
@@ -265,11 +284,9 @@ def main() -> None:
torch.set_default_dtype(torch.float)
- if args.load_checkpoint is None:
- coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
- booster.load_model(model, args.pretrained, strict=False)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
+ coordinator.print_on_master(
+ f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ )
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
@@ -296,87 +313,109 @@ def main() -> None:
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
+ f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
- num_steps_per_epoch = len(dataloader)
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ num_steps_per_epoch = len(dataloader) // args.accumulation_steps
# If resume training, set the sampler start index to the correct value
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
- with tqdm(
- iterable=enumerate(dataloader, start=start_step),
+ pbar = tqdm(
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
- initial=start_step,
- ) as pbar:
- for step, batch in pbar:
- batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+ initial=start_step // args.accumulation_steps,
+ )
+ total_loss = torch.tensor(0.0, device=get_current_device())
+ for step, batch in enumerate(dataloader, start=start_step):
+ batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
- batch_output = model(**batch)
+ batch_output = model(**batch)
- loss = batch_output.loss
+ loss = batch_output.loss / args.accumulation_steps
+ total_loss.add_(loss.data)
- booster.backward(loss=loss, optimizer=optimizer)
+ booster.backward(loss=loss, optimizer=optimizer)
+ if (step + 1) % args.accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
- all_reduce_mean(tensor=loss)
- pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
+ all_reduce_mean(tensor=total_loss)
+ pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
if coordinator.is_master():
- global_step = epoch * num_steps_per_epoch + step
- writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=global_step)
+ global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
+ writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
- # Save modeling.
-
- if (args.save_interval > 0 and (step + 1) % args.save_interval == 0) or (step + 1) == len(dataloader):
- coordinator.print_on_master("\nStart saving model checkpoint with running states")
- save_checkpoint(
- save_dir=args.save_dir,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- epoch=epoch,
- step=step + 1,
- batch_size=args.micro_batch_size,
- coordinator=coordinator,
- )
- coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
- )
-
- # Delete CUDA cache.
- # del batch, batch_labels, batch_output, loss
- torch.cuda.empty_cache()
+ total_loss.fill_(0.0)
+ pbar.update()
+ # Save modeling.
+
+ if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
+ step + 1
+ ) == len(dataloader):
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")
+
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune before saving model.")
+ deactivate_neftune(model, handle)
+
+ accelerator.empty_cache()
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.micro_batch_size,
+ coordinator=coordinator,
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )
+
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)
+
+ # Delete cache.
+ # del batch, batch_labels, batch_output, loss
+ accelerator.empty_cache()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
start_step = 0
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune.")
+ deactivate_neftune(model, handle)
+
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(
- f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
- )
+ coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+ coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
diff --git a/applications/Colossal-LLaMA-2/train_sft.example.sh b/applications/Colossal-LLaMA-2/train_sft.example.sh
index dcb11515d48f..d87f9ef82f4f 100755
--- a/applications/Colossal-LLaMA-2/train_sft.example.sh
+++ b/applications/Colossal-LLaMA-2/train_sft.example.sh
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR="${PARENT_TENSORBOARD_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
-colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_sft.py \
+colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.py \
--pretrained $PRETRAINED_MODEL_PATH \
--dataset ${dataset[@]} \
--plugin "zero2" \
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
--use_grad_checkpoint \
--use_flash_attn \
--use_neft \
+ --pad_token "eos"
diff --git a/applications/Colossal-LLaMA-2/train_sft.py b/applications/Colossal-LLaMA-2/train_sft.py
deleted file mode 100644
index fd9e1cd3e747..000000000000
--- a/applications/Colossal-LLaMA-2/train_sft.py
+++ /dev/null
@@ -1,403 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
-"""
-
-import argparse
-import json
-import os
-import resource
-from contextlib import nullcontext
-
-import torch
-import torch.distributed as dist
-from colossal_llama2.dataset.loader import (
- DataCollatorForSupervisedDataset,
- StatefulDistributedSampler,
- load_tokenized_dataset,
- setup_distributed_dataloader,
-)
-from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
-from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
-from colossal_llama2.utils.froze import freeze_non_embeds_parameters
-from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
-from colossalai.cluster import DistCoordinator
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device
-
-
-def get_model_numel(model: torch.nn.Module) -> int:
- return sum(p.numel() for p in model.parameters())
-
-
-def format_numel_str(numel: int) -> str:
- B = 1024**3
- M = 1024**2
- K = 1024
- if numel >= B:
- return f"{numel / B:.2f} B"
- elif numel >= M:
- return f"{numel / M:.2f} M"
- elif numel >= K:
- return f"{numel / K:.2f} K"
- else:
- return f"{numel}"
-
-
-def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
- dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
- tensor.div_(dist.get_world_size())
- return tensor
-
-
-def main() -> None:
- # ==============================
- # Parse Arguments
- # ==============================
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--pretrained",
- type=str,
- default=None,
- help="Address of the pre-trained modeling",
- )
- parser.add_argument("--dataset", nargs="+", default=[])
- parser.add_argument(
- "--plugin",
- type=str,
- default="gemini",
- choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
- help="Choose which plugin to use",
- )
- parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
- parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
- parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
- parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
- parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
- parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
- parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
- parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
- parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
- parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default="fp16",
- choices=["fp16", "bf16"],
- help="Mixed precision",
- )
- parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
- parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
- parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
- parser.add_argument(
- "--use_grad_checkpoint",
- action="store_true",
- default=False,
- help="Use gradient checkpointing",
- )
- parser.add_argument(
- "--use_flash_attn",
- action="store_true",
- default=False,
- help="Use flash-attention",
- )
- parser.add_argument(
- "--use_neft",
- action="store_true",
- default=False,
- help="Use NEFTune",
- )
- parser.add_argument(
- "--freeze_non_embeds_params",
- action="store_true",
- default=False,
- help="Freeze non embeddings parameters",
- )
- parser.add_argument("--tp", type=int, default=1)
- parser.add_argument("--zero", type=int, default=1)
- args = parser.parse_args()
-
- with open(args.config_file, "w") as f:
- json.dump(args.__dict__, f, indent=4)
-
- # ==============================
- # Initialize Distributed Training
- # ==============================
- colossalai.launch_from_torch({})
- coordinator = DistCoordinator()
-
- # ==============================
- # Initialize Tensorboard
- # ==============================
- if coordinator.is_master():
- os.makedirs(args.tensorboard_dir, exist_ok=True)
- writer = SummaryWriter(args.tensorboard_dir)
-
- # ==============================
- # Initialize Booster
- # ==============================
- if args.plugin == "gemini":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "gemini_auto":
- plugin = GeminiPlugin(
- precision=args.mixed_precision,
- placement_policy="auto",
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "zero2_cpu":
- plugin = LowLevelZeroPlugin(
- stage=2,
- precision=args.mixed_precision,
- initial_scale=2**16,
- cpu_offload=True,
- max_norm=args.grad_clip,
- )
- elif args.plugin == "3d":
- plugin = HybridParallelPlugin(
- tp_size=args.tp,
- pp_size=1,
- zero_stage=args.zero,
- max_norm=args.grad_clip,
- precision=args.mixed_precision,
- )
- else:
- raise ValueError(f"Unknown plugin {args.plugin}")
-
- booster = Booster(plugin=plugin)
-
- # ======================================================
- # Initialize Tokenizer, Dataset, Collator and Dataloader
- # ======================================================
- tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.add_bos_token = False
- tokenizer.add_eos_token = False
-
- coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
- coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
- coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
-
- coordinator.print_on_master(f"Load dataset: {args.dataset}")
-
- dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
- data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
- dataloader = setup_distributed_dataloader(
- dataset=dataset,
- batch_size=args.micro_batch_size,
- shuffle=True,
- drop_last=True,
- collate_fn=data_collator,
- )
- coordinator.print_on_master(
- f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
-
- # ======================================================
- # Initialize Model, Objective, Optimizer and LR Scheduler
- # ======================================================
- init_ctx = (
- LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
- )
- with init_ctx:
- model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
- # Freeze part of parameters.
- if args.freeze_non_embeds_params:
- freeze_non_embeds_parameters(model=model)
-
- if args.use_grad_checkpoint:
- model.gradient_checkpointing_enable()
- coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
- if args.use_flash_attn:
- replace_with_flash_attention(model=model)
- coordinator.print_on_master(msg="Flash-attention enabled successfully")
-
- model_numel = get_model_numel(model)
- coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
-
- optimizer = HybridAdam(
- model_params=filter(lambda p: p.requires_grad, model.parameters())
- if args.freeze_non_embeds_params
- else model.parameters(),
- lr=args.lr,
- betas=(0.9, 0.95),
- weight_decay=args.weight_decay,
- adamw_mode=True,
- )
-
- if args.warmup_steps is None:
- args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
- coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
-
- lr_scheduler = CosineAnnealingWarmupLR(
- optimizer=optimizer,
- total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
- warmup_steps=args.warmup_steps,
- eta_min=0.1 * args.lr,
- )
-
- # Flash attention will be disabled because it does NOT support fp32.
- default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
- torch.set_default_dtype(default_dtype)
- model, optimizer, _, dataloader, lr_scheduler = booster.boost(
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- dataloader=dataloader,
- )
-
- torch.set_default_dtype(torch.float)
-
- if args.load_checkpoint is None:
- coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
- booster.load_model(model, args.pretrained, strict=False)
-
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
- coordinator.print_on_master(
- f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- start_epoch = 0
- start_step = 0
- sampler_start_idx = 0
- if args.load_checkpoint is not None:
- if "modeling" in args.load_checkpoint:
- coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
- booster.load_model(model, args.load_checkpoint)
- else:
- coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
- start_epoch, start_step, sampler_start_idx = load_checkpoint(
- load_dir=args.load_checkpoint,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- )
- coordinator.print_on_master(
- f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
- )
- coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
-
- coordinator.print_on_master(
- f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
- )
- coordinator.print_on_master(
- f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
- )
-
- if args.use_neft:
- coordinator.print_on_master("Activate NEFTune.")
- model, handle = activate_neftune(model)
-
- num_steps_per_epoch = len(dataloader) // args.accumulation_steps
- # If resume training, set the sampler start index to the correct value
- assert isinstance(dataloader.sampler, StatefulDistributedSampler)
- dataloader.sampler.set_start_index(start_index=sampler_start_idx)
-
- for epoch in range(start_epoch, args.num_epochs):
- dataloader.sampler.set_epoch(epoch=epoch)
- pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
- total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
- for step, batch in enumerate(dataloader):
- batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
-
- batch_output = model(**batch)
-
- loss = batch_output.loss / args.accumulation_steps
- total_loss += loss.item()
-
- booster.backward(loss=loss, optimizer=optimizer)
-
- if (step + 1) % args.accumulation_steps == 0:
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- all_reduce_mean(tensor=total_loss)
- pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
- if coordinator.is_master():
- global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
- writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
- writer.add_scalar(
- tag="Learning Rate",
- scalar_value=lr_scheduler.get_last_lr()[0],
- global_step=global_step,
- )
- total_loss.fill_(0.0)
- pbar.update()
- # Save modeling.
-
- if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
- step + 1
- ) == len(dataloader):
- coordinator.print_on_master("\nStart saving model checkpoint with running states")
-
- if args.use_neft:
- coordinator.print_on_master("Deactivate NEFTune before saving model.")
- deactivate_neftune(model, handle)
-
- save_checkpoint(
- save_dir=args.save_dir,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- epoch=epoch,
- step=step + 1,
- batch_size=args.micro_batch_size,
- coordinator=coordinator,
- )
- coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
- )
-
- if args.use_neft:
- coordinator.print_on_master("Activate NEFTune.")
- model, handle = activate_neftune(model)
-
- # Delete CUDA cache.
- # del batch, batch_labels, batch_output, loss
- torch.cuda.empty_cache()
-
- # the continue epochs are not resumed, so we need to reset the sampler start index and start step
- dataloader.sampler.set_start_index(start_index=0)
- start_step = 0
-
- if args.use_neft:
- coordinator.print_on_master("Deactivate NEFTune.")
- deactivate_neftune(model, handle)
-
- # Final save.
- coordinator.print_on_master("Start saving final model checkpoint")
- booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
- coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
-
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
-
-
-if __name__ == "__main__":
- main()
diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py
index f293c4f699cd..9c70c0d2a1ad 100644
--- a/applications/ColossalEval/colossal_eval/models/chatglm.py
+++ b/applications/ColossalEval/colossal_eval/models/chatglm.py
@@ -3,6 +3,8 @@
import torch
+from colossalai.utils import get_current_device
+
from .huggingface import HuggingFaceModel
IGNORE_INDEX = -100
@@ -126,9 +128,9 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t
"""
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
- torch.cuda.current_device()
+ get_current_device()
)
outputs = self.model(input_ids)[0]
@@ -197,7 +199,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str
truncation=True,
return_tensors="pt",
max_length=self.model_max_length - max_new_tokens,
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
# Set output_scores=True to get prediction scores.
outputs = self.model.generate(
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
index 741c884f0043..fff697e21e34 100644
--- a/applications/ColossalEval/colossal_eval/models/huggingface.py
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -11,6 +11,7 @@
from colossalai.logging import DistributedLogger
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils import get_current_device
from .base import BaseModel
@@ -128,12 +129,12 @@ def _load_model(
self.model = AutoModel.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
- self.model.to(torch.cuda.current_device())
+ self.model.to(get_current_device())
if peft_path is not None:
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
else:
- self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
if peft_path is not None:
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
self.model.eval()
@@ -155,11 +156,11 @@ def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[t
"""
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
- torch.cuda.current_device()
+ get_current_device()
)
- attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
outputs = self.model(input_ids, attention_mask=attention_mask)[0]
@@ -464,7 +465,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str
return_tensors="pt",
return_token_type_ids=False,
max_length=self.model_max_length - max_new_tokens,
- ).to(torch.cuda.current_device())
+ ).to(get_current_device())
# Set output_scores=True to get prediction scores.
outputs = self.model.generate(
@@ -598,12 +599,12 @@ def _load_model(
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, sharded_parameters = shard_former.optimize(self.model)
- self.model.to(torch.cuda.current_device())
+ self.model.to(get_current_device())
if peft_path is not None:
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
else:
- self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+ self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
if peft_path is not None:
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
index 5b09f9de8da6..a340f3bfd281 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/inference.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -8,6 +8,7 @@
from colossal_eval import dataset, models, utils
import colossalai
+from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig
@@ -82,6 +83,7 @@ def rm_and_merge(
def main(args):
colossalai.launch_from_torch(config={}, seed=42)
+ accelerator = get_accelerator()
world_size = dist.get_world_size()
rank = dist.get_rank()
@@ -235,10 +237,10 @@ def main(args):
),
)
- logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+ logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
del model_
- torch.cuda.empty_cache()
+ accelerator.empty_cache()
dist.barrier()
if rank == 0:
diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md
new file mode 100644
index 000000000000..be50a8f9f251
Binary files /dev/null and b/applications/ColossalMoE/README.md differ
diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
new file mode 100644
index 000000000000..d08dfd5f8120
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -0,0 +1,629 @@
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.checkpoint_io import CheckpointIndexFile
+from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
+from colossalai.checkpoint_io.index_file import CheckpointIndexFile
+from colossalai.checkpoint_io.utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ load_shard_state_dict,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict_shards,
+ search_tp_partition_dim,
+ sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.moe import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+
+try:
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+ _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+
+class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
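+    """Hybrid-parallel checkpoint IO extended for Mixtral MoE models.
+
+    Expert-parallel (EP) rank 0 saves all parameters and buffers, while other EP ranks save only
+    their local expert parameters; optimizer states are gathered and re-sharded with awareness
+    of MoE tensors.
+    """
+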
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True,
+ ) -> None:
+ super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
+ moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
+ self.ep_group = moe_info.ep_group
+ self.ep_size = moe_info.ep_size
+ self.ep_rank = moe_info.ep_rank
+ self.real_dp_rank = moe_info.dp_rank
+
+ @staticmethod
+ def _model_sharder(
+ model: nn.Module,
+ prefix: str = "",
+ keep_vars: bool = False,
+ size_per_shard: int = 1024,
+ param_name_pattern: Optional[str] = None,
+ ) -> Iterator[Tuple[OrderedDict, int]]:
+        # An internal method that breaks the model's state_dict into shards of limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ # Save parameters.
+ for name, param in model.named_parameters():
+ if param is None:
+ continue
+ if param_name_pattern is not None and param_name_pattern not in name:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Save buffers.
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in model._non_persistent_buffers_set:
+ buffer = buf if keep_vars else buf.detach()
+ block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+ if block is not None:
+ yield block, block_size
+
+ # Save extra states.
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if (
+ getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+ is not torch.nn.Module.get_extra_state
+ ):
+ extra_state = model.get_extra_state()
+ block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_model(
+ self,
+ model: ModelWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+ - Multiple files that store state tensors of models.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin".
+        If pipeline parallelism is not used, the filenames are in the form of "pytorch_model.-000XX.bin".
+
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of file to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+
+ assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+ model = model.unwrap()
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
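+        # Model replicas are identical across the (MoE) data-parallel group, so only dp_rank 0 saves.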
+ if self.real_dp_rank != 0:
+ dist.barrier()
+ return
+
+ # ep_rank 0 saves all the parameters and buffers.
+ # other ep_ranks save only experts
+ ep_param_pattern = "experts." if self.ep_rank != 0 else None
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
+ model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
+ )
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.tp_rank == 0
+
+ if self.pp_size == 1 and self.ep_size == 1:
+ # When pipeline is not used, save the model shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+                        f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ dist.barrier()
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+ weights_name = weights_name.replace(
+ ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
+ )
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ dist.barrier()
+ return
+
+ dist.barrier()
+
+        # The global master rank integrates the index files and cleans the folder.
+ if self.coordinator.is_master():
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for weight, weight_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(weight, weight_filename)
+
+ final_index_file.write_index_file(final_index_file_path)
+ save_config_file(model, checkpoint)
+ rmtree(tmp_index_file_folder)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ @staticmethod
+ def gather_from_sharded_optimizer_state(
+ state: OrderedDict,
+ param: torch.Tensor,
+ original_shape: torch.Size,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ use_zero: bool,
+ inplace: bool,
+ is_moe_param: bool,
+ device: torch.device = torch.device("cpu"),
+ ) -> OrderedDict:
+ """
+ With given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+ Args:
+ state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+ param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+ original_shape (torch.Size): The size of parameter before sharding.
+ dp_group (ProcessGroup): The process group of data parallel.
+ tp_group (ProcessGroup): The process group of tensor parallel.
+ use_zero (bool): Whether Zero is used.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+ device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
+
+ Returns:
+ OrderedDict: The complete optimizer state of given parameter.
+ """
+ dp_size = dist.get_world_size(dp_group)
+ tp_size = dist.get_world_size(tp_group)
+ current_shape = param.shape
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # First gather Zero shards.
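+                # MoE (expert) params are not Zero-sharded along this dp group, so they skip the gather.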
+ if use_zero and not is_moe_param:
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+ dist.all_gather(gather_tensor, v, group=dp_group)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+
+ # Then gather TP shards.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+ if partition_dim is not None:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.all_gather(gather_tensor, v, group=tp_group)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ @staticmethod
+ def _optimizer_sharder(
+ optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ size_per_shard: int = 1024,
+ only_moe_param: bool = False,
+ ):
+        # An internal method that breaks the optimizer's state_dict into shards of limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+ master_to_working_map = optimizer.get_master_to_working_map()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info["param2id"][id(working_param)]
+ original_shape = param_info["param2shape"][id(working_param)]
+ state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+ state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False,
+ is_moe_param=is_moe_tensor(working_param),
+ )
+
+ if only_moe_param and not is_moe_tensor(working_param):
+ continue
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files that store state tensors of optimizers.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin".
+        If pipeline parallelism is not used, the filenames are in the form of "pytorch_optim.-000XX.bin".
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of file to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of states when zero is not used.
+        # In this case only let the device with dp_rank == 0 save the states.
+ if not self.use_zero and self.real_dp_rank != 0:
+ dist.barrier()
+ return
+
+ # Then collect the sharded states along dp_group(if using zero)/tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
+ optimizer,
+ use_zero=self.use_zero,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ size_per_shard=size_per_shard,
+ only_moe_param=self.ep_rank != 0,
+ )
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
+
+ if self.pp_size == 1 and self.ep_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+                        f"The optimizer is split into checkpoint shards. "
+                        f"You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ dist.barrier()
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
+
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ dist.barrier()
+ return
+
+ dist.barrier()
+
+        # The global master rank integrates the index files and cleans the folder.
+ if self.coordinator.is_master():
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+                    f"The optimizer is split into checkpoint shards. "
+                    f"You can find where each parameter has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+        # IDs are obtained from the param2id mapping stored in optimizer.param_info.
+ id_map = {}
+ master_to_working_map = optimizer.get_master_to_working_map()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
+ updated_groups.append(new_pg)
+        # Append the extra expert-parallel (EP) param group that exists locally but was not saved.
+ if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+                # If the file holding this param's states has already been loaded, skip it.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(
+ state,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True,
+ is_moe_param=is_moe_tensor(working_param),
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def shard_from_complete_optimizer_state(
+ self,
+ state: OrderedDict,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ is_moe_param: bool,
+ ) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+ slice out the sharded optimizer states kept by current device.
+
+ Args:
+ state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+ current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # Shard state along tensor parallel group.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+ if partition_dim is not None:
+ slice_size = current_shape[partition_dim]
+ v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+ # Shard state along data parallel group when using Zero.
+ if self.use_zero and not is_moe_param:
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ raise NotImplementedError
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ raise NotImplementedError
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
+ raise NotImplementedError
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
new file mode 100644
index 000000000000..a2b78a2bd18c
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
@@ -0,0 +1,92 @@
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+from colossalai.lazy import LazyInitContext
+from colossalai.moe import MOE_MANAGER
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
+
+
+class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
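+    """Expert-parallel (EP) variant of Mixtral's sparse MoE block.
+
+    Each EP rank keeps only its own slice of the experts; token activations are exchanged with
+    uneven all-to-all collectives before and after the local expert computation.
+    """
+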
+ def __init__(self, config):
+ super().__init__(config)
+ self.setup_ep()
+
+ def setup_ep(self):
+ _, moe_info = MOE_MANAGER.get_info(self.num_experts)
+ ep_group = moe_info.ep_group
+ self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+ self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+ assert self.num_experts % self.ep_size == 0
+ self.ep_group = ep_group
+ self.num_experts_per_ep = self.num_experts // self.ep_size
+ self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
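+        # Keep only this rank's slice of experts and release the rest to free memory; the held
+        # expert parameters are tagged with MoE parallel info so they are recognized as MoE tensors.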
+ held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+ set_tensors_to_none(self.experts, exclude=set(held_experts))
+ for p in self.experts.parameters():
+ set_moe_tensor_info(p, moe_info)
+
+ @staticmethod
+ def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+ LazyInitContext.materialize(module)
+ module.__class__ = EPMixtralSparseMoeBlock
+ module.setup_ep()
+ return module
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ # router_logits: (batch * sequence_length, n_experts)
+ router_logits = self.gate(hidden_states)
+
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ # we cast back to the input dtype
+ routing_weights = routing_weights.to(hidden_states.dtype)
+
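+        # Dispatch: sort tokens by their selected expert, exchange per-expert token counts,
+        # then all-to-all the activations so each EP rank receives only tokens for its local experts.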
+ selected_experts = selected_experts.t().reshape(-1)
+ selected_experts_idx = selected_experts.argsort()
+ dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
+ input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+ output_split_sizes = torch.zeros_like(input_split_sizes)
+ dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
+
+ input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+ output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+ # compute expert output
+ output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+ if output_states.size(0) > 0:
+ if self.num_experts_per_ep == 1:
+ # no need to split
+ expert = self.experts[self.expert_start_idx]
+ output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
+ output_states = expert.w2(output_states)
+ else:
+ output_states_splits = output_states.split(output_split_sizes.tolist())
+ output_states_list = []
+ for i, split_states in enumerate(output_states_splits):
+ if split_states.size(0) == 0:
+ continue
+ expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+ split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
+ split_states = expert.w2(split_states)
+ output_states_list.append(split_states)
+ output_states = torch.cat(output_states_list)
+ output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
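+        # Combine: send expert outputs back to their source ranks, restore the original token order,
+        # and sum the top-k expert outputs weighted by the routing probabilities.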
+ dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+ recover_experts_idx = torch.empty_like(selected_experts_idx)
+ recover_experts_idx[selected_experts_idx] = torch.arange(
+ selected_experts_idx.size(0), device=selected_experts_idx.device
+ )
+ dispatch_states = dispatch_states[recover_experts_idx]
+ k_hidden_states = dispatch_states.chunk(self.top_k)
+ output_states = k_hidden_states[0] * routing_weights[:, 0, None]
+ for i in range(1, self.top_k):
+ output_states += k_hidden_states[i] * routing_weights[:, i, None]
+ output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
+ return output_states, router_logits
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
new file mode 100644
index 000000000000..218b05b27fad
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
@@ -0,0 +1,557 @@
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, Module
+from transformers.models.mixtral.modeling_mixtral import (
+ MixtralDecoderLayer,
+ MixtralForCausalLM,
+ MixtralModel,
+ MoeCausalLMOutputWithPast,
+ _prepare_4d_causal_attention_mask,
+ load_balancing_loss_func,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from colossalai.shardformer.shard import ShardConfig
+
+from .mixtral_layer import EPMixtralSparseMoeBlock
+
+__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
+
+
+class MixtralPolicy(Policy):
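+    """Shardformer policy for Mixtral.
+
+    Replaces each decoder layer's sparse MoE block with the expert-parallel EPMixtralSparseMoeBlock
+    and optionally fuses RMSNorm; tensor parallelism and sequence parallelism are not supported.
+    """
+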
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ raise NotImplementedError(
+                "Mixtral doesn't support sequence parallelism now; the sequence parallelism flag will be ignored."
+ )
+
+ if self.shard_config.enable_tensor_parallelism:
+ raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
+
+ # expert parallel
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="block_sparse_moe",
+ target_module=EPMixtralSparseMoeBlock,
+ )
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ ],
+ policy=policy,
+ target_key=MixtralDecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=FusedRMSNorm,
+ ),
+ policy=policy,
+ target_key=MixtralModel,
+ )
+
+ if self.shard_config.enable_flash_attention:
+ raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=model_cls
+ )
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "MixtralModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class MixtralModelPolicy(MixtralPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MixtralModel,
+ new_forward=MixtralPipelineForwards.mixtral_model_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in Mixtral model"""
+ return []
+
+
+class MixtralForCausalLMPolicy(MixtralPolicy):
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for casual lm
+ new_item = {
+ MixtralForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=MixtralForCausalLM,
+ new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ llama_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if (
+ id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ and self.pipeline_stage_manager.num_stages > 1
+ ):
+ # tie weights
+ return [
+ {
+ 0: llama_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
+
+
+class MixtralPipelineForwards:
+ """
+    This class serves as a micro library for forward function substitution of Mixtral models
+    under the pipeline-parallel setting.
+ """
+
+ @staticmethod
+ def mixtral_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ past_router_logits: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+        # TODO(jianghai): recording of kv-value tensors is left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+        # For the first stage, hidden_states holds the input embeddings;
+        # for later stages, it is the output of the previous stage.
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_router_logits = () if output_router_logits else None
+ next_decoder_cache = None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ output_attentions,
+ output_router_logits,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_value,
+ output_attentions,
+ output_router_logits,
+ use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+ if output_router_logits:
+ all_router_logits += (layer_outputs[-1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+
+ if output_router_logits and past_router_logits is not None:
+ all_router_logits = past_router_logits + all_router_logits
+ if stage_manager.is_last_stage():
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
+ if v is not None
+ )
+        # always return dict for intermediate stage
+ return {
+ "hidden_states": hidden_states,
+ "past_router_logits": all_router_logits,
+ }
+
+ @staticmethod
+ def mixtral_for_causal_lm_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_router_logits: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ past_router_logits: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, MixtralForCausalLM
+
+ >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_router_logits = (
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
+ )
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = MixtralPipelineForwards.mixtral_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ output_router_logits=output_router_logits,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ past_router_logits=past_router_logits,
+ )
+ past_key_values = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ aux_loss = None
+ if output_router_logits:
+ aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
+ if labels is not None:
+ loss += self.router_aux_loss_coef * aux_loss
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ if output_router_logits:
+ output = (aux_loss,) + output
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ aux_loss=aux_loss,
+ logits=logits,
+ past_key_values=None,
+ hidden_states=outputs[0],
+ attentions=None,
+ router_logits=outputs[-1],
+ )
+ else:
+ out = {}
+ hidden_states = outputs.get("hidden_states")
+ out["hidden_states"] = hidden_states
+ if output_router_logits:
+ out["past_router_logits"] = outputs["past_router_logits"]
+ return out
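
The label shifting at the end of `mixtral_for_causal_lm_forward` follows the usual causal-LM convention: the logits at position `i` are scored against the token at position `i + 1`. A minimal, self-contained sketch with made-up shapes (not part of this patch) makes the alignment concrete:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 10
logits = torch.randn(batch, seq_len, vocab_size)          # stand-in for the lm_head output
labels = torch.randint(0, vocab_size, (batch, seq_len))   # same shape as input_ids

# Tokens at positions < n predict token n: drop the last logit and the first label,
# then flatten so CrossEntropyLoss sees (N, vocab_size) against (N,).
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())
```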
diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py
new file mode 100644
index 000000000000..a2a0a7e78239
--- /dev/null
+++ b/applications/ColossalMoE/colossal_moe/utils.py
@@ -0,0 +1,84 @@
+import json
+import os
+from typing import Any, Dict, Tuple, Union
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+
+def move_to_cuda(batch, device):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+ """
+ Load file in JSON format
+ """
+ with open(file=file_path, mode="r", encoding="utf-8") as fp:
+ return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+ """
+ Save as JSON format
+ """
+ with open(file=file_path, mode="w", encoding="utf-8") as fp:
+ json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+ save_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+ epoch: int,
+ step: int,
+ batch_size: int,
+ coordinator: DistCoordinator,
+) -> None:
+ """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+ os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+ booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+ booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "sample_start_index": step * batch_size,
+ }
+ if coordinator.is_master():
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+ load_dir: Union[str, os.PathLike],
+ booster: Booster,
+ model: torch.nn.Module,
+ optimizer: Optimizer,
+ lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+ """
+    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+ """
+
+ # Update booster params states.
+ booster.load_model(model, os.path.join(load_dir, "modeling"))
+ booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+ booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+ running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ running_states["sample_start_index"],
+ )
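
The `running_states.json` written by `save_checkpoint` is what `load_checkpoint` reads back to resume at the right epoch and step. A hedged round-trip sketch of the JSON helpers, assuming the `colossal_moe` package above is installed and using illustrative values:

```python
import os
import tempfile

from colossal_moe.utils import load_json, save_json

# Illustrative running states for a checkpoint taken at step 1000 with batch size 8.
states = {"epoch": 0, "step": 1000, "sample_start_index": 1000 * 8}
path = os.path.join(tempfile.mkdtemp(), "running_states.json")
save_json(states, path)
assert load_json(path) == states
# load_checkpoint() returns (epoch, step, sample_start_index) from this file so that
# a resumed run can skip the samples that were already consumed.
```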
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
new file mode 100644
index 000000000000..46ff70ff33ab
--- /dev/null
+++ b/applications/ColossalMoE/infer.py
@@ -0,0 +1,111 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from transformers import AutoTokenizer
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="mistralai/Mixtral-8x7B-v0.1",
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="ep",
+ choices=["ep"],
+        help="Parallel methods.",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="bf16",
+ choices=["fp32", "bf16", "fp16"],
+ help="The mixed precision training.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
+ )
+ parser.add_argument(
+ "--use_layernorm_kernel",
+ action="store_true",
+ help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+
+ config = MixtralConfig.from_pretrained(args.model_name)
+ ep_size = min(dist.get_world_size(), config.num_local_experts)
+ # Set plugin
+ if args.plugin == "ep":
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=1,
+ ep_size=ep_size,
+ zero_stage=1,
+ precision=args.precision,
+ custom_policy=MixtralForCausalLMPolicy(),
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ enable_fused_normalization=args.use_layernorm_kernel,
+ enable_jit_fused=args.use_kernel,
+ )
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
+
+ # Build mixtral model
+ model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master("Finished loading model")
+
+ # Prepare tokenizer and dataloader
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+ # Set booster
+ booster = Booster(plugin=plugin)
+ model, _, _, _, _ = booster.boost(model=model)
+    coordinator.print_on_master("Finished initializing booster")
+
+ model.eval()
+
+ if coordinator.rank == 0:
+ text = ["Hello my name is"]
+ else:
+ text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
+ tokenizer.pad_token = tokenizer.unk_token
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
+
+ with torch.no_grad():
+ outputs = model.module.generate(**inputs, max_new_tokens=20)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ print(f"[{coordinator.rank}] {outputs}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh
new file mode 100644
index 000000000000..0487fe9c1562
--- /dev/null
+++ b/applications/ColossalMoE/infer.sh
@@ -0,0 +1,7 @@
+NUM_GPU=2
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+
+# ep
+torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
+ --model_name $MODEL \
+ --plugin "ep" \
diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt
new file mode 100644
index 000000000000..9a5738c412b9
--- /dev/null
+++ b/applications/ColossalMoE/requirements.txt
@@ -0,0 +1,5 @@
+colossalai >= 0.3.3
+torch >= 1.8.1
+transformers == 4.36.0
+sentencepiece
+datasets
diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py
new file mode 100644
index 000000000000..275f59e10a06
--- /dev/null
+++ b/applications/ColossalMoE/setup.py
@@ -0,0 +1,43 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+def fetch_version():
+ with open("version.txt", "r") as f:
+ return f.read().strip()
+
+
+setup(
+ name="colossal_moe",
+ version=fetch_version(),
+ packages=find_packages(
+ exclude=(
+ "tests",
+ "benchmarks",
+ "*.egg-info",
+ )
+ ),
+ description="Colossal-AI MoE",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
+ ],
+)
diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py
new file mode 100644
index 000000000000..57589ab20d22
--- /dev/null
+++ b/applications/ColossalMoE/tests/test_mixtral_layer.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
+from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+import colossalai
+from colossalai.moe import MOE_MANAGER
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_mixtral_moe_layer():
+ torch.cuda.set_device(dist.get_rank())
+ MOE_MANAGER.setup(
+ parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
+ )
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ )
+ torch.manual_seed(0)
+ orig_model = MixtralSparseMoeBlock(config).cuda()
+ x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
+ orig_output, orig_logits = orig_model(x)
+ model = deepcopy(orig_model)
+ model = EPMixtralSparseMoeBlock.from_native_module(model)
+ ep_output, ep_logits = model(x)
+ assert_close(orig_logits, ep_logits)
+ assert_close(orig_output, ep_output)
+ orig_loss = orig_output.mean()
+ orig_loss.backward()
+ ep_loss = ep_output.mean()
+ ep_loss.backward()
+ assert_close(orig_loss, ep_loss)
+ name_to_p = {n: p for n, p in orig_model.named_parameters()}
+ for n, ep_p in model.named_parameters():
+ p = name_to_p[n]
+ if ep_p.grad is not None:
+ assert_close(p.grad, ep_p.grad)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch({}, rank, world_size, "localhost", port)
+ check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mixtral_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mixtral_moe_layer(2)
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
new file mode 100644
index 000000000000..822e7410f016
--- /dev/null
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -0,0 +1,146 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from torch.optim import Adam
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_model_equal(model1, model2):
+ assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+ for p1, p2 in zip(model1.parameters(), model2.parameters()):
+ assert torch.equal(p1.half(), p2.half())
+
+
+def get_optimizer_snapshot(optim):
+ state = {id(k): deepcopy(v) for k, v in optim.state.items()}
+ param_groups = []
+ for group in optim.param_groups:
+ params = [id(p) for p in group["params"]]
+ new_group = {"params": params}
+ for k, v in group.items():
+ if k != "params":
+ new_group[k] = v
+ param_groups.append(new_group)
+ return {
+ "state": state,
+ "param_groups": param_groups,
+ }
+
+
+def check_optimizer_snapshot_equal(snapshot1, snapshot2):
+ # check param_groups
+ assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
+ for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+ assert set(group1.keys()) == set(group2.keys())
+ for k in group1.keys():
+ assert group1[k] == group2[k]
+ # check state
+ assert set(snapshot1["state"].keys()) == set(
+ snapshot2["state"].keys()
+ ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+ for pid in snapshot1["state"].keys():
+ state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+ assert set(state1.keys()) == set(state2.keys())
+ for k in state1.keys():
+ if isinstance(state1[k], torch.Tensor):
+ assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
+ else:
+ assert state1[k] == state2[k]
+
+
+def check_mixtral_moe_layer():
+ torch.cuda.set_device(dist.get_rank())
+ config = MixtralConfig(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_local_experts=n_experts,
+ num_experts_per_tok=top_k,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ )
+ torch.manual_seed(0)
+ input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+ orig_model = MixtralForCausalLM(config).cuda()
+ model = deepcopy(orig_model)
+ optimizer = Adam(model.parameters(), lr=1e-3)
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=2,
+ ep_size=2,
+ custom_policy=MixtralForCausalLMPolicy(),
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ microbatch_size=1,
+ zero_stage=1,
+ )
+ booster = Booster(plugin=plugin)
+ model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+ # initialize grads
+ data_iter = iter(
+ [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
+ )
+ booster.execute_pipeline(
+ data_iter,
+ model,
+ lambda outputs, inputs: outputs.loss,
+ optimizer,
+ )
+
+ # check save model
+ booster.save_model(model, "mixtral_model", shard=True)
+ dist.barrier()
+ if dist.get_rank() == 0:
+ saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+ check_model_equal(orig_model, saved_model)
+ saved_model.save_pretrained("mixtral_hf_model")
+ dist.barrier()
+
+ # check load model
+ new_model = MixtralForCausalLM(config).cuda()
+ new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+ new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+ booster.load_model(new_model, "mixtral_hf_model")
+ check_model_equal(model, new_model)
+
+ # check save optimizer
+ optimizer.step()
+ for group in optimizer.param_groups:
+ group["lr"] = 0.1
+ snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+ dist.barrier()
+ # reset optimizer state
+ for state in optimizer.unwrap().state.values():
+ for v in state.values():
+ if isinstance(v, torch.Tensor):
+ v.zero_()
+ booster.load_optimizer(optimizer, "mixtral_optim")
+ loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+ check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+ colossalai.launch({}, rank, world_size, "localhost", port)
+ check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [4])
+def test_mixtral_moe_layer(world_size: int):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_mixtral_moe_layer(4)
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
new file mode 100644
index 000000000000..c567038ec252
--- /dev/null
+++ b/applications/ColossalMoE/train.py
@@ -0,0 +1,295 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
+from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+from transformers.models.mixtral import MixtralForCausalLM
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+@torch.no_grad()
+def get_global_loss(loss, booster):
+ global_loss = loss.clone().detach()
+ dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
+ global_loss.div_(booster.plugin.dp_size)
+ return global_loss
+
+
+class RandomDataset(Dataset):
+ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
+ self.num_samples = num_samples
+ self.max_length = max_length
+ self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+ self.attention_mask = torch.ones_like(self.input_ids)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.input_ids[idx],
+ "attention_mask": self.attention_mask[idx],
+ "labels": self.input_ids[idx],
+ }
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="mistralai/Mixtral-8x7B-v0.1",
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="hybrid",
+ choices=["hybrid"],
+ help="Parallel methods.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./outputs",
+ help="The path of your saved model after finetuning.",
+ )
+ parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="Batch size (per dp group) for the training dataloader.",
+ )
+ parser.add_argument(
+ "--save_interval",
+ type=int,
+ default=1000,
+        help="The interval (steps) of saving checkpoints.",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="bf16",
+ choices=["fp32", "bf16", "fp16"],
+ help="The mixed precision training.",
+ )
+ parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+
+ # optim
+ parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+
+ # lr scheduler
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+ parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
+
+ # zero stage for all plugins
+ parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
+ # hybrid plugin
+ parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
+ parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
+ parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
+ parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
+
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
+ )
+ parser.add_argument(
+ "--use_layernorm_kernel",
+ action="store_true",
+ help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
+ )
+
+ # load balance
+ parser.add_argument(
+        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
+ )
+ parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
+ # communicate overlap
+ parser.add_argument(
+ "--comm_overlap",
+ action="store_true",
+        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
+ )
+ # hierarchical all-to-all
+ parser.add_argument(
+ "--hierarchical_alltoall",
+ action="store_true",
+        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+
+ # Set plugin
+ if args.plugin == "hybrid":
+ plugin = MoeHybridParallelPlugin(
+ tp_size=1,
+ pp_size=args.pp_size,
+ ep_size=args.ep_size,
+ microbatch_size=args.microbatch_size,
+ custom_policy=MixtralForCausalLMPolicy(),
+ enable_fused_normalization=args.use_layernorm_kernel,
+ enable_jit_fused=args.use_kernel,
+ precision=args.precision,
+ zero_stage=args.zero_stage,
+ checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+ )
+
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
+
+ # Build Mixtral model
+ model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master("Finished initializing model")
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Prepare tokenizer and dataloader
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+ dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
+ collate_fn = None
+ dataloader = plugin.prepare_dataloader(
+ dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Set optimizer
+ optimizer = HybridAdam(
+ model_params=model.parameters(),
+ lr=args.lr,
+ betas=(0.9, 0.95),
+ weight_decay=args.weight_decay,
+ adamw_mode=True,
+ )
+
+ # Set lr scheduler
+ lr_scheduler = CosineAnnealingWarmupLR(
+ optimizer=optimizer,
+ total_steps=args.num_epochs * len(dataloader),
+ warmup_steps=args.warmup_steps
+ if args.warmup_steps is not None
+ else int(args.num_epochs * len(dataloader) * 0.025),
+ eta_min=0.1 * args.lr,
+ )
+
+ # Set booster
+ booster = Booster(plugin=plugin)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ dataloader=dataloader,
+ )
+ use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    coordinator.print_on_master("Finished initializing booster")
+
+ # Load ckpt
+ if args.load_checkpoint is not None:
+ load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
+        coordinator.print_on_master("Finished loading checkpoint")
+
+ # Start finetuning
+ coordinator.print_on_master(f"Start finetuning")
+ for epoch in range(args.num_epoch):
+ model.train()
+ train_dataloader_iter = iter(dataloader)
+ total_len = len(train_dataloader_iter)
+ with tqdm(
+ range(total_len),
+ desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
+            disable=not is_pp_last_stage if use_pipeline else not coordinator.is_master(),
+ ) as pbar:
+ for step in pbar:
+ if use_pipeline:
+ # Forward pass
+ outputs = booster.execute_pipeline(
+ train_dataloader_iter,
+ model,
+ lambda x, y: x.loss,
+ optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs["loss"]
+ global_loss = get_global_loss(loss, booster)
+ if coordinator._local_rank == "0":
+ pbar.set_postfix({"Loss": global_loss.item()})
+ else:
+ # Forward pass
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data, torch.cuda.current_device())
+ outputs = model(**data)
+ loss = outputs["loss"]
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({"loss": loss.item()})
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Apply load balance
+ # if (
+ # args.load_balance
+ # and args.load_balance_interval > 0
+ # and (step + 1) % args.load_balance_interval == 0
+ # ):
+ # coordinator.print_on_master(f"Apply load balance")
+ # apply_load_balance(model, optimizer)
+                # save checkpoint
+ if (step + 1) % args.save_interval == 0:
+ coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+ save_checkpoint(
+ args.output_path,
+ booster,
+ model,
+ optimizer,
+ lr_scheduler,
+ epoch,
+ step,
+ args.batch_size,
+ coordinator,
+ )
+
+        # save checkpoint at the end of each epoch
+ booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
+ coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+
+ # Finish training
+    coordinator.print_on_master("Finished training")
+
+
+if __name__ == "__main__":
+ main()
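
`get_global_loss` at the top of `train.py` averages the per-rank loss across the data-parallel group: an all-reduce SUM followed by division by `dp_size`. A toy single-process sketch of the same arithmetic (no `torch.distributed` involved; the loss values are made up):

```python
import torch

# Pretend dp_size == 4 and each data-parallel rank reported one of these losses.
per_rank_losses = torch.tensor([2.0, 4.0, 6.0, 8.0])
global_loss = per_rank_losses.sum() / per_rank_losses.numel()
print(global_loss.item())  # 5.0 -- the value every rank sees after the reduce
```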
diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh
new file mode 100644
index 000000000000..bee7f5c8fdf8
--- /dev/null
+++ b/applications/ColossalMoE/train.sh
@@ -0,0 +1,19 @@
+NUM_GPU=8
+MODEL="mistralai/Mixtral-8x7B-v0.1"
+SEQ_LENGTH=2048
+BATCH_SIZE=1
+LR=0.00001
+
+# hybrid
+# torchrun --standalone --nproc_per_node $NUM_GPU \
+colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
+ train.py \
+ --num_epoch 1 \
+ --model_name $MODEL \
+ --plugin "hybrid" \
+ --batch_size $BATCH_SIZE \
+ --lr $LR \
+ --zero_stage 1 \
+ --pp_size 2 \
+ --dp_size 1 \
+ --ep_size 8 \
diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt
new file mode 100644
index 000000000000..3eefcb9dd5b3
--- /dev/null
+++ b/applications/ColossalMoE/version.txt
@@ -0,0 +1 @@
+1.0.0
diff --git a/colossalai/__init__.py b/colossalai/__init__.py
index 7da55590305b..6b7f5d055207 100644
--- a/colossalai/__init__.py
+++ b/colossalai/__init__.py
@@ -1,4 +1,5 @@
from .initialize import launch, launch_from_openmpi, launch_from_slurm, launch_from_torch
+from . import accelerator
try:
# .version will be created by setup.py
diff --git a/colossalai/accelerator/README.md b/colossalai/accelerator/README.md
new file mode 100644
index 000000000000..8c644493b03a
--- /dev/null
+++ b/colossalai/accelerator/README.md
@@ -0,0 +1,20 @@
+# 🚀 Accelerator
+
+## 🔗 Table of Contents
+
+- [🚀 Accelerator](#-accelerator)
+ - [🔗 Table of Contents](#-table-of-contents)
+ - [📚 Introduction](#-introduction)
+ - [📌 Design and Acknowledgement](#-design-and-acknowledgement)
+
+## 📚 Introduction
+
+This module offers a layer of abstraction for ColossalAI. With this module, the user can easily switch between different accelerator backends, such as Nvidia GPUs, Huawei NPUs, etc. This module is an attempt to make users' code portable across different hardware platforms with a simple `auto_set_accelerator()` API.
+
+## 📌 Design and Acknowledgement
+
+Our `accelerator` module is heavily inspired by [`deepspeed/accelerator`](https://www.deepspeed.ai/tutorials/accelerator-abstraction-interface/). We found that it is a very well-designed and well-structured module that can be easily integrated into our project. We would like to thank the DeepSpeed team for their great work.
+
+We implemented this accelerator module from scratch. At the same time, we made our own modifications:
+1. we updated the accelerator API names to be aligned with PyTorch's native API names.
+2. we did not include the `op builder` in the `accelerator`. Instead, we have reconstructed our `kernel` module to automatically match the accelerator and its corresponding kernel implementations, so as to make modules less tangled.
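
A short usage sketch of the API described above, assuming `colossalai` with this accelerator module is installed (the tensor shape and the printed backend are illustrative):

```python
import torch

from colossalai.accelerator import auto_set_accelerator, get_accelerator

# Probe backends in priority order (CUDA, then NPU, then CPU) and pick the first available.
auto_set_accelerator()
acc = get_accelerator()

print(acc.name, acc.communication_backend)  # e.g. "cuda nccl" on a GPU machine
x = torch.randn(4, 4, device=acc.get_current_device())
```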
diff --git a/colossalai/accelerator/__init__.py b/colossalai/accelerator/__init__.py
new file mode 100644
index 000000000000..1405133affe2
--- /dev/null
+++ b/colossalai/accelerator/__init__.py
@@ -0,0 +1,15 @@
+from .api import auto_set_accelerator, get_accelerator, set_accelerator
+from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
+from .cuda_accelerator import CudaAccelerator
+from .npu_accelerator import NpuAccelerator
+
+__all__ = [
+ "get_accelerator",
+ "set_accelerator",
+ "auto_set_accelerator",
+ "BaseAccelerator",
+ "CudaAccelerator",
+ "NpuAccelerator",
+ "CpuAccelerator",
+]
diff --git a/colossalai/accelerator/api.py b/colossalai/accelerator/api.py
new file mode 100644
index 000000000000..02b3055d7380
--- /dev/null
+++ b/colossalai/accelerator/api.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+from collections import OrderedDict
+from typing import Union
+
+from .base_accelerator import BaseAccelerator
+from .cpu_accelerator import CpuAccelerator
+from .cuda_accelerator import CudaAccelerator
+from .npu_accelerator import NpuAccelerator
+
+__all__ = ["set_accelerator", "auto_set_accelerator", "get_accelerator"]
+
+
+_ACCELERATOR = None
+
+
+# we use ordered dictionary here to associate the
+# order with device check priority
+# i.e. auto_set_accelerator will check cuda first
+_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)
+
+
+def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
+ """
+ Set the global accelerator for the current process.
+
+ Args:
+ accelerator (Union[str, BaseAccelerator]): the type of accelerator to which the current device belongs.
+ """
+
+ global _ACCELERATOR
+
+ if isinstance(accelerator, str):
+ _ACCELERATOR = _ACCELERATOR_MAPPING[accelerator]()
+ elif isinstance(accelerator, BaseAccelerator):
+ _ACCELERATOR = accelerator
+ else:
+ raise TypeError("accelerator must be either a string or an instance of BaseAccelerator")
+
+
+def auto_set_accelerator() -> None:
+ """
+ Automatically check if any accelerator is available.
+    If an accelerator is available, set it as the global accelerator.
+ """
+ global _ACCELERATOR
+
+ for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
+ try:
+ accelerator = accelerator_cls()
+ if accelerator_name == "cpu" or accelerator.is_available():
+ _ACCELERATOR = accelerator
+ break
+        except Exception:
+ pass
+
+ if _ACCELERATOR is None:
+ raise RuntimeError("No accelerator is available.")
+
+
+def get_accelerator() -> BaseAccelerator:
+ """
+ Return the accelerator for the current process. If the accelerator is not initialized, it will be initialized
+ to the default accelerator type.
+
+ Returns: the accelerator for the current process.
+ """
+ global _ACCELERATOR
+
+ if _ACCELERATOR is None:
+ auto_set_accelerator()
+ return _ACCELERATOR
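
To pin a backend explicitly instead of relying on auto-detection, a minimal sketch using the functions above (the CPU backend is chosen here purely for illustration):

```python
from colossalai.accelerator import CpuAccelerator, get_accelerator, set_accelerator

# Force the CPU backend regardless of what hardware is present.
set_accelerator("cpu")
acc = get_accelerator()
assert acc.name == "cpu" and acc.communication_backend == "gloo"

# An accelerator instance is accepted as well as a string key.
set_accelerator(CpuAccelerator())
```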
diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py
new file mode 100644
index 000000000000..33c113999018
--- /dev/null
+++ b/colossalai/accelerator/base_accelerator.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python
+
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+
+__all__ = ["BaseAccelerator"]
+
+
+class BaseAccelerator(ABC):
+ support_set_device: bool = True
+
+ def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
+ self._name = name
+ self._communication_backend = communication_backend
+ self._is_synchronous = is_synchronous
+
+ # =======================
+ # immutable attributes
+ # =======================
+
+ @property
+ def name(self) -> str:
+ """
+ Return the name of the accelerator.
+ """
+ return self._name
+
+ @property
+ def communication_backend(self) -> str:
+ """
+ Return the name of the backend communication library.
+ """
+ return self._communication_backend
+
+ @property
+ def is_synchronous(self) -> bool:
+ """
+ Return whether the accelerator is a synchronous device.
+ """
+ return self._is_synchronous
+
+ def __repr__(self) -> str:
+ cls_name = self.__class__.__name__
+ return f"{cls_name}(name={self._name}, communication_backend={self._communication_backend}, is_synchronous={self._is_synchronous})"
+
+ # =======================
+ # device APIs
+ # =======================
+ @abstractmethod
+ def get_version(self) -> str:
+ """
+ Return the version of the accelerator which torch is built against.
+ """
+
+ @abstractmethod
+ def get_current_device(self) -> torch.device:
+ """
+ Return the current device.
+ """
+
+ @abstractmethod
+ def current_device(self) -> int:
+ """
+ Return the current device index.
+ """
+
+ @abstractmethod
+ def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+ """
+ Bind the current process to a device.
+ """
+
+ @abstractmethod
+ def get_device_name(self, device: Union[torch.device, int]) -> str:
+ """
+ Return the name of the device.
+ """
+
+ @abstractmethod
+ def synchronize(self, device: Union[torch.device, int] = None):
+ """
+ Synchronize the current process.
+ """
+
+ @abstractmethod
+ def is_available(self):
+ """
+ Check if the accelerator is available.
+ """
+
+ @abstractmethod
+ def device_count(self):
+ """
+ Return the number of devices on the machine.
+ """
+
+ def set_to_device(self, models: Any) -> Any:
+ """
+ Send model to device.
+
+ :param models: nn.module or a list of module
+ """
+ if isinstance(models, list) and len(models) > 1:
+ ret = []
+ for model in models:
+ ret.append(model.to(self.get_current_device()))
+ return ret
+ elif isinstance(models, list):
+ return models[0].to(self.get_current_device())
+ else:
+ return models.to(self.get_current_device())
+
+ @abstractmethod
+ def get_device_capability(self, device=None) -> Tuple[int, int]:
+ """
+ Gets the capability of a device.
+ """
+
+ @abstractmethod
+ def get_device_name(self, device=None) -> str:
+ """
+ Gets the name of a device.
+ """
+
+ @abstractmethod
+ def get_device_properties(self, device):
+ """
+ Gets the properties of a device.
+ """
+
+ @abstractmethod
+ def utilization(self, device=None) -> int:
+ """
+ Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
+ """
+
+ # =======================
+ # random number generator APIs
+ # =======================
+ @abstractmethod
+ def get_rng_state(self, device="cuda") -> torch.Tensor:
+ """
+ Returns the random number generator state of the specified device as a ByteTensor.
+ """
+
+ @abstractmethod
+ def get_rng_state_all(self) -> List[torch.Tensor]:
+ """
+ Returns a list of ByteTensor representing the random number states of all devices.
+ """
+
+ @abstractmethod
+ def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
+ """
+ Sets the random number generator state of the specified device.
+ """
+
+ @abstractmethod
+ def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+ """
+ Sets the random number generator state of all devices.
+ """
+
+ @abstractmethod
+ def manual_seed(self, seed: int) -> None:
+ """
+ Sets the seed for generating random numbers for the current device.
+ """
+
+ @abstractmethod
+ def manual_seed_all(self, seed: int) -> None:
+ """
+ Sets the seed for generating random numbers on all devices.
+ """
+
+ @abstractmethod
+ def seed(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number for the current device.
+ """
+
+ @abstractmethod
+ def seed_all(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number on all devices.
+ """
+
+ @abstractmethod
+ def initial_seed(self) -> int:
+ """
+ Returns the current random seed of the current device.
+ """
+
+ # =======================
+ # memory management APIs
+ # =======================
+ @abstractmethod
+ def empty_cache(self) -> None:
+ """
+ Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
+ """
+
+ @abstractmethod
+ def memory_stats(self, device=None) -> Dict[str, Any]:
+ """
+ Returns a dictionary of CUDA memory allocator statistics for a given device.
+ """
+
+ @abstractmethod
+ def memory_summary(self, device=None, abbreviated=False) -> str:
+ """
+ Returns a human-readable printout of the current memory allocator statistics for a given device.
+ """
+
+ @abstractmethod
+ def memory_snapshot(self):
+ """
+ Returns a snapshot of the CUDA memory allocator state across all devices.
+ """
+
+ @abstractmethod
+ def memory_allocated(self, device=None) -> int:
+ """
+ Returns the current device memory occupied by tensors in bytes for a given device.
+ """
+
+ @abstractmethod
+ def max_memory_allocated(self, device=None) -> int:
+ """
+ Returns the maximum device memory occupied by tensors in bytes for a given device.
+ """
+
+ @abstractmethod
+ def reset_max_memory_allocated(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
+ """
+
+ @abstractmethod
+ def reset_max_memory_cached(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
+ """
+
+ @abstractmethod
+ def memory_reserved(self, device=None) -> int:
+ """
+ Returns the current device memory managed by the caching allocator in bytes for a given device.
+ """
+
+ @abstractmethod
+ def max_memory_reserved(self, device=None) -> int:
+ """
+ Returns the maximum device memory managed by the caching allocator in bytes for a given device.
+ """
+
+ @abstractmethod
+ def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+ """
+ Set memory fraction for a process.
+ """
+
+ @abstractmethod
+ def reset_peak_memory_stats(self, device=None) -> None:
+ """
+ Resets the "peak" stats tracked by the device memory allocator.
+ """
+
+ # =======================
+ # streams and events APIs
+ # =======================
+
+ @abstractmethod
+ def Stream(self, device=None, priority=0, **kwargs):
+ """
+ A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+ """
+
+ @abstractmethod
+ def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+ """
+ device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+ """
+
+ @abstractmethod
+ def current_stream(self, device=None):
+ """
+ Returns the currently selected Stream for a given device.
+ """
+
+ @abstractmethod
+ def default_stream(self, device=None):
+ """
+ Returns the default Stream for a given device.
+ """
+
+ @abstractmethod
+ def set_stream(self, stream_):
+ """
+        Sets the current stream. This is a wrapper API to set the stream.
+ """
+
+ @abstractmethod
+ def stream(self, stream_):
+ """
+ Wrapper around the Context-manager StreamContext that selects a given stream.
+ """
+
+ # =======================
+ # amp APIs
+ # =======================
+ @abstractmethod
+ def autocast(
+ self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+ ) -> Callable:
+ """
+ Return autocast function
+ """
diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py
new file mode 100644
index 000000000000..080aa61e8e3a
--- /dev/null
+++ b/colossalai/accelerator/cpu_accelerator.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python
+
+import resource
+from contextlib import nullcontext
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import psutil
+import torch
+
+from .base_accelerator import BaseAccelerator
+
+__all__ = ["CpuAccelerator"]
+
+
+class CpuAccelerator(BaseAccelerator):
+ support_set_device: bool = False
+ """
+ Accelerator class for cpu.
+ """
+
+ def __init__(self):
+ super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
+
+ # =======================
+ # device APIs
+ # =======================
+ def get_version(self) -> str:
+ """
+ Return the version of the accelerator which torch is built against.
+ """
+ return ""
+
+ def get_current_device(self) -> torch.device:
+ """
+ Return the current device.
+ """
+ return torch.device("cpu")
+
+ def current_device(self) -> int:
+ """
+ Return the current device index.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+ """
+ Bind the current process to a device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def get_device_name(self, device: Union[torch.device, int]) -> str:
+ """
+ Return the name of the device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def synchronize(self, device: Union[torch.device, int] = None):
+ """
+ Synchronize the current process.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def is_available(self):
+ """
+ Check if the accelerator is available.
+ """
+ return True
+
+ def device_count(self):
+ """
+ Return the number of devices on the machine.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def get_device_capability(self, device=None) -> Tuple[int, int]:
+ """
+ Gets the cuda capability of a device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def get_device_name(self, device=None) -> str:
+ """
+ Gets the name of a device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def get_device_properties(self, device):
+ """
+ Gets the properties of a device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def utilization(self, device=None) -> int:
+ """
+ Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ # =======================
+ # random number generator APIs
+ # =======================
+    def get_rng_state(self, device=None) -> torch.Tensor:
+        """
+        Returns the random number generator state of the CPU as a ByteTensor.
+        """
+        return torch.get_rng_state()
+
+ def get_rng_state_all(self) -> List[torch.Tensor]:
+ """
+ Returns a list of ByteTensor representing the random number states of all devices.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+    def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
+        """
+        Sets the random number generator state of the CPU.
+        """
+        torch.set_rng_state(new_state)
+
+ def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+ """
+ Sets the random number generator state of all devices.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def manual_seed(self, seed: int) -> None:
+ """
+ Sets the seed for generating random numbers for the current GPU.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def manual_seed_all(self, seed: int) -> None:
+ """
+        Set the random seed for all processes.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def seed(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number for the current GPU.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def seed_all(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number on all GPUs.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def initial_seed(self) -> int:
+ """
+ Returns the current random seed of the current GPU.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ # =======================
+ # memory management APIs
+ # =======================
+
+ def empty_cache(self) -> None:
+ """
+ Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def memory_stats(self, device=None) -> Dict[str, Any]:
+ """
+ Returns a dictionary of CUDA memory allocator statistics for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def memory_summary(self, device=None, abbreviated=False) -> str:
+ """
+ Returns a human-readable printout of the current memory allocator statistics for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def memory_snapshot(self):
+ """
+ Returns a snapshot of the CUDA memory allocator state across all devices.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def memory_allocated(self, device=None) -> int:
+ """
+ Returns the current GPU memory occupied by tensors in bytes for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def max_memory_allocated(self, device=None) -> int:
+ """
+ Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def reset_max_memory_allocated(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def reset_max_memory_cached(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def memory_reserved(self, device=None) -> int:
+ """
+        Returns the resident set size (RSS) of the current process in bytes.
+ """
+ return psutil.Process().memory_info().rss
+
+ def max_memory_reserved(self, device=None) -> int:
+ """
+        Returns the peak resident set size of the current process, as reported by resource.getrusage (ru_maxrss).
+ """
+ return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+
+ def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+ """
+ Set memory fraction for a process.
+ """
+ max_memory = int(psutil.virtual_memory().total * fraction)
+ _, hard = resource.getrlimit(resource.RLIMIT_AS)
+ resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
+
+ def reset_peak_memory_stats(self, device=None) -> None:
+ """
+ Resets the "peak" stats tracked by the CUDA memory allocator.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ # =======================
+ # streams and events APIs
+ # =======================
+
+ def Stream(self, device=None, priority=0, **kwargs):
+ """
+ A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+ """
+ CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def current_stream(self, device=None):
+ """
+ Returns the currently selected Stream for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def default_stream(self, device=None):
+ """
+ Returns the default Stream for a given device.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def set_stream(self, stream_):
+ """
+        Sets the current stream. This is a wrapper API to set the stream.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ def stream(self, stream_):
+ """
+ Wrapper around the Context-manager StreamContext that selects a given stream.
+ """
+ raise RuntimeError("this method is not supported for cpu accelerator")
+
+ # =======================
+ # amp APIs
+ # =======================
+ def autocast(
+ self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+ ) -> Callable:
+ """
+ Return autocast function
+ """
+ return nullcontext
diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py
new file mode 100644
index 000000000000..f1ab487d4f58
--- /dev/null
+++ b/colossalai/accelerator/cuda_accelerator.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+
+from .base_accelerator import BaseAccelerator
+
+__all__ = ["CudaAccelerator"]
+
+
+class CudaAccelerator(BaseAccelerator):
+ """
+ Accelerator class for Nvidia CUDA devices.
+ """
+
+ def __init__(self):
+ super().__init__(name="cuda", communication_backend="nccl", is_synchronous=False)
+
+ # =======================
+ # device APIs
+ # =======================
+ def get_version(self) -> str:
+ """
+ Return the version of the accelerator which torch is built against.
+ """
+ return torch.version.cuda
+
+ def get_current_device(self) -> torch.device:
+ """
+ Return the current device.
+ """
+ return torch.device(f"cuda:{torch.cuda.current_device()}")
+
+ def current_device(self) -> int:
+ """
+ Return the current device index.
+ """
+ return torch.cuda.current_device()
+
+ def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+ """
+ Bind the current process to a device.
+ """
+ if device is None:
+ if not dist.is_initialized():
+ raise RuntimeError("Cannot get current device when distributed is not initialized.")
+ device = dist.get_rank() % self.device_count()
+ torch.cuda.set_device(device)
+
+ def get_device_name(self, device: Union[torch.device, int]) -> str:
+ """
+ Return the name of the device.
+ """
+ return torch.cuda.get_device_name(device)
+
+ def synchronize(self, device: Union[torch.device, int] = None):
+ """
+ Synchronize the current process.
+ """
+ torch.cuda.synchronize(device)
+
+ def is_available(self):
+ """
+ Check if the accelerator is available.
+ """
+ return torch.cuda.is_available()
+
+ def device_count(self):
+ """
+ Return the number of devices on the machine.
+ """
+ return torch.cuda.device_count()
+
+ def get_device_capability(self, device=None) -> Tuple[int, int]:
+ """
+ Gets the cuda capability of a device.
+ """
+ return torch.cuda.get_device_capability(device)
+
+ def get_device_name(self, device=None) -> str:
+ """
+ Gets the name of a device.
+ """
+ return torch.cuda.get_device_name(device)
+
+ def get_device_properties(self, device):
+ """
+ Gets the properties of a device.
+ """
+ return torch.cuda.get_device_properties(device)
+
+ def utilization(self, device=None) -> int:
+ """
+ Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
+ """
+ return torch.cuda.utilization(device)
+
+ # =======================
+ # random number generator APIs
+ # =======================
+ def get_rng_state(self, device="cuda") -> torch.Tensor:
+ """
+ Returns the random number generator state of the specified GPU as a ByteTensor.
+ """
+ return torch.cuda.get_rng_state(device)
+
+ def get_rng_state_all(self) -> List[torch.Tensor]:
+ """
+ Returns a list of ByteTensor representing the random number states of all devices.
+ """
+ return torch.cuda.get_rng_state_all()
+
+ def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
+ """
+ Sets the random number generator state of the specified GPU.
+ """
+ torch.cuda.set_rng_state(new_state, device)
+
+ def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+ """
+ Sets the random number generator state of all devices.
+ """
+ torch.cuda.set_rng_state_all(new_states)
+
+ def manual_seed(self, seed: int) -> None:
+ """
+ Sets the seed for generating random numbers for the current GPU.
+ """
+ torch.cuda.manual_seed(seed)
+
+ def manual_seed_all(self, seed: int) -> None:
+ """
+        Sets the random seed for all GPUs.
+ """
+ torch.cuda.manual_seed_all(seed)
+
+ def seed(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number for the current GPU.
+ """
+ torch.cuda.seed()
+
+ def seed_all(self) -> None:
+ """
+ Sets the seed for generating random numbers to a random number on all GPUs.
+ """
+ torch.cuda.seed_all()
+
+ def initial_seed(self) -> int:
+ """
+ Returns the current random seed of the current GPU.
+ """
+ return torch.cuda.initial_seed()
+
+ # =======================
+ # memory management APIs
+ # =======================
+
+ def empty_cache(self) -> None:
+ """
+ Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
+ """
+ torch.cuda.empty_cache()
+
+ def memory_stats(self, device=None) -> Dict[str, Any]:
+ """
+ Returns a dictionary of CUDA memory allocator statistics for a given device.
+ """
+ return torch.cuda.memory_stats(device=device)
+
+ def memory_summary(self, device=None, abbreviated=False) -> str:
+ """
+ Returns a human-readable printout of the current memory allocator statistics for a given device.
+ """
+ return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
+
+ def memory_snapshot(self):
+ """
+ Returns a snapshot of the CUDA memory allocator state across all devices.
+ """
+ return torch.cuda.memory_snapshot()
+
+ def memory_allocated(self, device=None) -> int:
+ """
+ Returns the current GPU memory occupied by tensors in bytes for a given device.
+ """
+ return torch.cuda.memory_allocated(device=device)
+
+ def max_memory_allocated(self, device=None) -> int:
+ """
+ Returns the maximum GPU memory occupied by tensors in bytes for a given device.
+ """
+ return torch.cuda.max_memory_allocated(device=device)
+
+ def reset_max_memory_allocated(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
+ """
+ torch.cuda.reset_max_memory_allocated(device=device)
+
+ def reset_max_memory_cached(self, device=None) -> None:
+ """
+ Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
+ """
+ torch.cuda.reset_max_memory_cached(device=device)
+
+ def memory_reserved(self, device=None) -> int:
+ """
+ Returns the current GPU memory managed by the caching allocator in bytes for a given device.
+ """
+ return torch.cuda.memory_reserved(device=device)
+
+ def max_memory_reserved(self, device=None) -> int:
+ """
+ Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
+ """
+ return torch.cuda.max_memory_reserved(device=device)
+
+ def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+ """
+ Set memory fraction for a process.
+ """
+ torch.cuda.set_per_process_memory_fraction(fraction, device=device)
+
+ def reset_peak_memory_stats(self, device=None) -> None:
+ """
+ Resets the "peak" stats tracked by the CUDA memory allocator.
+ """
+ torch.cuda.reset_peak_memory_stats(device=device)
+
+ # =======================
+ # streams and events APIs
+ # =======================
+
+ def Stream(self, device=None, priority=0, **kwargs):
+ """
+ A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
+ """
+ return torch.cuda.Stream(device, priority, **kwargs)
+
+ def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+ """
+ CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
+ """
+ return torch.cuda.Event(enable_timing, blocking, interprocess)
+
+ def current_stream(self, device=None):
+ """
+ Returns the currently selected Stream for a given device.
+ """
+ return torch.cuda.current_stream(device)
+
+ def default_stream(self, device=None):
+ """
+ Returns the default Stream for a given device.
+ """
+ return torch.cuda.default_stream(device)
+
+ def set_stream(self, stream_):
+ """
+        Sets the current stream. This is a wrapper API to set the stream.
+ """
+ torch.cuda.set_stream(stream_)
+
+ def stream(self, stream_):
+ """
+ Wrapper around the Context-manager StreamContext that selects a given stream.
+ """
+ return torch.cuda.stream(stream_)
+
+ # =======================
+ # amp APIs
+ # =======================
+ def autocast(
+ self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+ ) -> Callable:
+ """
+ Return autocast function
+ """
+ return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
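+
+
+# Illustrative usage sketch (not part of this class's API): callers normally obtain the
+# active accelerator through the package-level `get_accelerator()` factory, e.g.
+#
+#   from colossalai.accelerator import get_accelerator
+#
+#   acc = get_accelerator()  # resolves to CudaAccelerator on CUDA machines
+#   x = torch.zeros(8, device=acc.get_current_device())
+#   with acc.autocast():
+#       y = x * 2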
diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py
new file mode 100644
index 000000000000..b28492968eeb
--- /dev/null
+++ b/colossalai/accelerator/npu_accelerator.py
@@ -0,0 +1,288 @@
+#!/usr/bin/env python
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+
+from .base_accelerator import BaseAccelerator
+
+try:
+ import torch_npu # noqa
+except ImportError:
+ pass
+
+
+__all__ = ["NpuAccelerator"]
+
+
+class NpuAccelerator(BaseAccelerator):
+ """
+ Accelerator class for Huawei NPU devices.
+ """
+
+ def __init__(self):
+ super().__init__(name="npu", communication_backend="hccl", is_synchronous=False)
+
+ # =======================
+ # device APIs
+ # =======================
+ def get_version(self) -> str:
+ """
+ Return the version of the accelerator which torch is built against.
+ """
+ return torch.version.cann
+
+ def get_current_device(self) -> torch.device:
+ """
+ Return the current device.
+ """
+ return torch.device(f"npu:{torch.npu.current_device()}")
+
+ def current_device(self) -> int:
+ """
+ Return the current device index.
+ """
+ return torch.npu.current_device()
+
+ def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
+ """
+ Bind the current process to a device.
+ """
+ if device is None:
+ if not dist.is_initialized():
+ raise RuntimeError("Cannot get current device when distributed is not initialized.")
+ device = dist.get_rank() % self.device_count()
+ torch.npu.set_device(device)
+
+ def get_device_name(self, device: Union[torch.device, int]) -> str:
+ """
+ Return the name of the device.
+ """
+ return torch.npu.get_device_name(device)
+
+ def synchronize(self, device: Union[torch.device, int] = None):
+ """
+ Synchronize the current process.
+ """
+ torch.npu.synchronize(device)
+
+ def is_available(self):
+ """
+ Check if the accelerator is available.
+ """
+ return torch.npu.is_available()
+
+ def device_count(self):
+ """
+ Return the number of devices on the machine.
+ """
+ return torch.npu.device_count()
+
+ def get_device_capability(self, device=None) -> Tuple[int, int]:
+ """
+ Gets the npu capability of a device.
+ """
+ return torch.npu.get_device_capability(device)
+
+ def get_device_name(self, device=None) -> str:
+ """
+ Gets the name of a device.
+ """
+ return torch.npu.get_device_name(device)
+
+ def get_device_properties(self, device):
+ """
+ Gets the properties of a device.
+ """
+ return torch.npu.get_device_properties(device)
+
+ def utilization(self, device=None) -> int:
+ """
+        Returns the percent of time over the past sample period during which one or more kernels were executing on the NPU.
+ """
+ return torch.npu.utilization(device)
+
+ # =======================
+ # random number generator APIs
+ # =======================
+ def get_rng_state(self, device="npu") -> torch.Tensor:
+ """
+        Returns the random number generator state of the specified NPU as a ByteTensor.
+ """
+ return torch.npu.get_rng_state(device)
+
+ def get_rng_state_all(self) -> List[torch.Tensor]:
+ """
+ Returns a list of ByteTensor representing the random number states of all devices.
+ """
+ return torch.npu.get_rng_state_all()
+
+ def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
+ """
+        Sets the random number generator state of the specified NPU.
+ """
+ torch.npu.set_rng_state(new_state, device)
+
+ def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
+ """
+ Sets the random number generator state of all devices.
+ """
+ torch.npu.set_rng_state_all(new_states)
+
+ def manual_seed(self, seed: int) -> None:
+ """
+        Sets the seed for generating random numbers for the current NPU.
+ """
+ torch.npu.manual_seed(seed)
+
+ def manual_seed_all(self, seed: int) -> None:
+ """
+        Sets the random seed for all NPUs.
+ """
+ torch.npu.manual_seed_all(seed)
+
+ def seed(self) -> None:
+ """
+        Sets the seed for generating random numbers to a random number for the current NPU.
+ """
+ torch.npu.seed()
+
+ def seed_all(self) -> None:
+ """
+        Sets the seed for generating random numbers to a random number on all NPUs.
+ """
+ torch.npu.seed_all()
+
+ def initial_seed(self) -> int:
+ """
+        Returns the current random seed of the current NPU.
+ """
+ return torch.npu.initial_seed()
+
+ # =======================
+ # memory management APIs
+ # =======================
+
+ def empty_cache(self) -> None:
+ """
+        Releases all unoccupied cached memory currently held by the caching allocator so that it can be used by other NPU applications.
+ """
+ torch.npu.empty_cache()
+
+ def memory_stats(self, device=None) -> Dict[str, Any]:
+ """
+ Returns a dictionary of npu memory allocator statistics for a given device.
+ """
+ return torch.npu.memory_stats(device=device)
+
+ def memory_summary(self, device=None, abbreviated=False) -> str:
+ """
+ Returns a human-readable printout of the current memory allocator statistics for a given device.
+ """
+ return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
+
+ def memory_snapshot(self):
+ """
+ Returns a snapshot of the npu memory allocator state across all devices.
+ """
+ return torch.npu.memory_snapshot()
+
+ def memory_allocated(self, device=None) -> int:
+ """
+        Returns the current NPU memory occupied by tensors in bytes for a given device.
+ """
+ return torch.npu.memory_allocated(device=device)
+
+ def max_memory_allocated(self, device=None) -> int:
+ """
+        Returns the maximum NPU memory occupied by tensors in bytes for a given device.
+ """
+ return torch.npu.max_memory_allocated(device=device)
+
+ def reset_max_memory_allocated(self, device=None) -> None:
+ """
+        Resets the starting point in tracking maximum NPU memory occupied by tensors for a given device.
+ """
+ torch.npu.reset_max_memory_allocated(device=device)
+
+ def reset_max_memory_cached(self, device=None) -> None:
+ """
+        Resets the starting point in tracking maximum NPU memory managed by the caching allocator for a given device.
+ """
+ torch.npu.reset_max_memory_cached(device=device)
+
+ def memory_reserved(self, device=None) -> int:
+ """
+        Returns the current NPU memory managed by the caching allocator in bytes for a given device.
+ """
+ return torch.npu.memory_reserved(device=device)
+
+ def max_memory_reserved(self, device=None) -> int:
+ """
+        Returns the maximum NPU memory managed by the caching allocator in bytes for a given device.
+ """
+ return torch.npu.max_memory_reserved(device=device)
+
+ def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
+ """
+ Set memory fraction for a process.
+ """
+ torch.npu.set_per_process_memory_fraction(fraction, device=device)
+
+ def reset_peak_memory_stats(self, device=None) -> None:
+ """
+ Resets the "peak" stats tracked by the npu memory allocator.
+ """
+ torch.npu.reset_peak_memory_stats(device=device)
+
+ # =======================
+ # streams and events APIs
+ # =======================
+
+ def Stream(self, device=None, priority=0, **kwargs):
+ """
+        An NPU stream is a linear sequence of execution that belongs to a specific device, independent from other streams.
+ """
+ return torch.npu.Stream(device, priority, **kwargs)
+
+ def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+ """
+        NPU events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize NPU streams.
+ """
+ return torch.npu.Event(enable_timing, blocking, interprocess)
+
+ def current_stream(self, device=None):
+ """
+ Returns the currently selected Stream for a given device.
+ """
+ return torch.npu.current_stream(device)
+
+ def default_stream(self, device=None):
+ """
+ Returns the default Stream for a given device.
+ """
+ return torch.npu.default_stream(device)
+
+ def set_stream(self, stream_):
+ """
+        Sets the current stream. This is a wrapper API to set the stream.
+ """
+ torch.npu.set_stream(stream_)
+
+ def stream(self, stream_):
+ """
+ Wrapper around the Context-manager StreamContext that selects a given stream.
+ """
+ return torch.npu.stream(stream_)
+
+ # =======================
+ # amp APIs
+ # =======================
+ def autocast(
+ self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
+ ) -> Callable:
+ """
+ Return autocast function
+ """
+ return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 439d13dcfc11..fc4c884d4c5d 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -7,8 +7,8 @@
import torch
from torch import Tensor
+from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"]
@@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
- self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
+ self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
self._verbose = verbose
if self._verbose:
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 86ba919ee696..5cd8035d7987 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,7 +5,7 @@
import torch
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
from .base_grad_scaler import BaseGradScaler
@@ -37,14 +37,20 @@ def __init__(
hysteresis: int = 2,
verbose: bool = False,
):
super().__init__(initial_scale, verbose)
if min_scale:
- self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
+ self._min_scale = torch.tensor(
+ [min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+ )
else:
self._min_scale = None
if max_scale:
- self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
+ self._max_scale = torch.tensor(
+ [max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
+ )
else:
self._max_scale = None
@@ -117,7 +123,7 @@ def state_dict(self):
return state_dict
def load_state_dict(self, state_dict):
- self._scale = state_dict["scale"].to(get_current_device())
+ self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
index 9ce272356797..2e7c8a281916 100644
--- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
+++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py
@@ -5,8 +5,8 @@
import torch.distributed as dist
from torch import Tensor
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
@@ -40,7 +40,7 @@ def __init__(
max_scale=max_scale,
)
self.optim_state = OptimState.UNSCALED
- self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+ self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
@property
def loss_scale(self) -> float:
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
index 601bf2926d99..fe8439269f48 100644
--- a/colossalai/auto_parallel/offload/amp_optimizer.py
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -4,10 +4,10 @@
import torch
from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region import Region
@@ -79,7 +79,9 @@ def __init__(
hysteresis=hysteresis,
max_scale=max_scale,
)
- self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+ self._found_overflow: torch.Tensor = torch.zeros(
+ 1, dtype=torch.int64, device=get_accelerator().get_current_device()
+ )
self._logger = get_dist_logger()
def _set_grad_ptr(self):
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index a6628e29c2bc..3ad210de9f0a 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@
import torch
from torch.fx.node import Node
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
@@ -57,7 +57,10 @@ def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error
if memory_budget > 0:
self.memory_budget = memory_budget * self.error_factor
else:
- self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
+ self.memory_budget = (
+ torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+ * self.error_factor
+ )
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
self.comp_power: float = self._extract_computing_power()
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 443c4094c0e1..c757a878d97a 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -5,8 +5,8 @@
from torch import Tensor
from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.utils.device import autocast
from .mixed_precision_base import MixedPrecision
@@ -89,7 +89,7 @@ def __init__(self, module: nn.Module):
super().__init__(module)
def forward(self, *args, **kwargs):
- with autocast():
+ with get_accelerator().autocast():
return self.module(*args, **kwargs)
diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
index d2dd00453e32..27285f95ce52 100644
--- a/colossalai/booster/plugin/dp_plugin_base.py
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -21,7 +21,16 @@ def __init__(self) -> None:
self.world_size = dist.get_world_size()
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -45,7 +54,8 @@ def prepare_dataloader(
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
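+        # Allow callers to supply a custom distributed sampler class; fall back to
+        # torch's DistributedSampler by default.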
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index a891db422d67..95b96bbfd9ed 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -15,6 +15,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
+from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
@@ -27,8 +28,6 @@
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils import get_current_device
-from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -366,11 +365,11 @@ def __init__(
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
- if IS_NPU_AVAILABLE:
+ if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
- chunk_init_device=(chunk_init_device or get_current_device()),
+ chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
placement_policy=placement_policy,
enable_gradient_accumulation=enable_gradient_accumulation,
shard_param_frac=shard_param_frac,
@@ -455,9 +454,18 @@ def control_device(self) -> bool:
def supported_devices(self) -> List[str]:
return ["cuda", "npu"]
-
+
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -485,8 +493,12 @@ def prepare_dataloader(
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
- sampler = DistributedSampler(
- dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(
+ dataset,
+ num_replicas=zero_world_size * extra_dp_world_size,
+ rank=zero_rank * extra_dp_world_size + extra_dp_rank,
+ shuffle=shuffle,
)
# Deterministic dataloader
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index e1593cf6b26c..da67e6b41fbf 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,5 +1,6 @@
import ctypes
import random
+import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
@@ -18,6 +19,7 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
+from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
@@ -28,7 +30,6 @@
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
-from colossalai.utils.device import get_current_device
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -82,7 +83,7 @@ def __init__(
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
- module = module.to(get_current_device())
+ module = module.to(get_accelerator().get_current_device())
# setting input type cast when using mixed precision
self.convert_fn = None
@@ -345,7 +346,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
- total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+ total_norm_cuda = torch.tensor(
+ [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+ )
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if self.pp_size > 1:
@@ -385,7 +388,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
- [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+ [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
@@ -542,7 +545,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
- total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+ total_norm_cuda = torch.tensor(
+ [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+ )
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -586,7 +591,7 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
- [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+ [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
@@ -798,7 +803,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
- total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+ total_norm_cuda = torch.tensor(
+ [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+ )
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -838,7 +845,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.tensor(
- [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
+ [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if dp_size > 1:
# compute norm in dp process group
@@ -1128,7 +1135,12 @@ def configure(
tp_process_group=self.tp_group,
)
else:
- assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+ if self.dp_size == 1:
+ warnings.warn(
+                    "Using ZeRO optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
+                    "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
+                )
+
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer,
@@ -1193,7 +1205,16 @@ def execute_pipeline(
return outputs
def prepare_dataloader(
- self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ distributed_sampler_cls=None,
+ **kwargs,
):
r"""
Prepare a dataloader for distributed training. The dataloader will be wrapped by
@@ -1217,7 +1238,8 @@ def prepare_dataloader(
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
- sampler = DistributedSampler(
+ distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
+ sampler = distributed_sampler_cls(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
)
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 89102820cd38..d21496f0b758 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -12,6 +12,7 @@
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
+from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
@@ -24,7 +25,6 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
@@ -52,7 +52,7 @@ def __init__(self, module: nn.Module, precision: str) -> None:
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
- module = module.to(get_current_device())
+ module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index e976d0aaf014..45e5a23c1b22 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -22,7 +22,7 @@
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ def __init__(
self,
tp_size: int,
pp_size: int,
+ ep_size: int,
extra_dp_size: int = 1,
precision: str = "fp16",
zero_stage: int = 0,
@@ -181,6 +182,7 @@ def __init__(
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
+ checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -188,10 +190,26 @@ def __init__(
if enable_sequence_parallelism:
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+ assert (
+ dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+ self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
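+        # Register the fixed EP/DP/PP layout with the global MOE_MANAGER so that MoE
+        # modules constructed later share the same parallel configuration.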
+ MOE_MANAGER.setup(
+ parallel="EP",
+ mode="fixed",
+ fixed_dp_size=self.real_dp_size,
+ fixed_ep_size=ep_size,
+ fixed_pp_size=pp_size,
+ use_ep_inside=use_ep_inside,
+ )
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ self.ep_size = ep_size
+ self.moe_info = MOE_MANAGER.get_info(0)[1]
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
@@ -200,6 +218,7 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
+ self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@@ -323,7 +342,10 @@ def seed_worker(worker_id):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
- self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ if self.checkpoint_io is None:
+ self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ else:
+ self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 780117598e18..71232421586d 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -9,7 +9,7 @@
from colossalai.interface import ModelWrapper
-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
__all__ = ["CheckpointIO"]
@@ -90,7 +90,15 @@ def load_model(
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
- self.load_unsharded_model(model, checkpoint, strict)
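+            # Prefer a safetensors weight file if present, then a plain torch
+            # weight file, and finally fall back to loading the path as given.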
+ path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+ if path.is_file():
+ self.load_unsharded_model(model, str(path), strict)
+ else:
+ path = Path(checkpoint, WEIGHTS_NAME)
+ if path.is_file():
+ self.load_unsharded_model(model, str(path), strict)
+ else:
+ self.load_unsharded_model(model, checkpoint, strict)
return origin_model
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index b7900bc0f217..36df30335dd7 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,7 +1,7 @@
import copy
-from functools import reduce
import logging
import os
+from functools import reduce
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -14,6 +14,7 @@
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -445,7 +446,11 @@ def save_sharded_optimizer(
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
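+            # Rebuild param_groups from the live optimizer so that hyper-parameters
+            # updated after initialization (e.g. the learning rate) are persisted.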
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
@@ -504,7 +509,11 @@ def save_sharded_optimizer(
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(optimizer.param_info, group_file_path)
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ save_param_groups({"param_groups": param_groups}, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
@@ -713,12 +722,16 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
tp_group=self.tp_group,
use_zero=self.use_zero,
inplace=False,
- device=torch.device("cuda"),
+ device=get_current_device(),
)
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
- state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
@@ -729,7 +742,11 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
# Only the master rank do the saving.
if self.coordinator.is_master():
- state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+ param_groups = [
+ {**group, "params": group_info["params"]}
+ for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+ ]
+ state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -838,7 +855,7 @@ def gather_from_sharded_optimizer_state(
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
- v = v.cuda()
+ v = v.to(get_current_device())
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 25076b742c26..aaeaad3828f5 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -6,12 +6,12 @@
from pathlib import Path
from typing import Dict, Union
-import torch
import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
+from colossalai.utils import set_seed
def launch(
@@ -47,17 +47,18 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
- if IS_NPU_AVAILABLE and backend == "nccl":
- backend = "hccl"
+ cur_accelerator = get_accelerator()
+
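+    # Use the collective backend registered by the detected accelerator
+    # (e.g. nccl for CUDA, hccl for NPU).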
+ backend = cur_accelerator.communication_backend
# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
- if torch.cuda.is_available() or IS_NPU_AVAILABLE:
- # if local rank is not given, calculate automatically
- set_device(local_rank)
+ # if local rank is not given, calculate automatically
+ if cur_accelerator.support_set_device:
+ cur_accelerator.set_device(local_rank)
set_seed(seed)
diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
index 8933fc0a3c2f..e69de29bb2d1 100644
--- a/colossalai/kernel/__init__.py
+++ b/colossalai/kernel/__init__.py
@@ -1,7 +0,0 @@
-from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
-
-__all__ = [
- "LayerNorm",
- "FusedScaleMaskSoftmax",
- "MultiHeadAttention",
-]
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
deleted file mode 100644
index 2b1b366b1c02..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu
+++ /dev/null
@@ -1,63 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#include "column_remap.cuh"
-#include "util.cuh"
-
-const int SHUF_BLOCKSIZE_X = 256;
-const int SHUF_BLOCKSIZE_Y = 16;
-
-__global__ void column_remap_kernel
-(
- const half* __restrict__ x,
- half* __restrict__ x_new,
- const int x_width,
- const int x_height,
- const uint32_t* x_map
-)
-{
- int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
- int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
- if (x_column >= x_width) return;
- //if (x_row >= x_height) return;
-
- int x_stride = x_width;
- int x_idx = x_row * x_stride + x_column;
-
- int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
- int x_idx_end = x_row_end * x_stride + x_column;
-
- int s_column = x_map[x_column];
- int s_idx = x_row * x_stride + s_column;
-
- while (x_idx < x_idx_end)
- {
- x_new[x_idx] = x[s_idx];
- x_idx += x_stride;
- s_idx += x_stride;
- }
-}
-
-// Remap columns in x to correspond to sequential group index before matmul
-//
-// perform x -> seq_x such that seq_x @ seq_w == x @ w
-
-void column_remap_cuda
-(
- const half* x,
- half* x_new,
- const int x_height,
- const int x_width,
- const uint32_t* x_map
-)
-{
- dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
-
- dim3 blocks
- (
- (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
- (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
- 1
- );
-
-    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
-}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
deleted file mode 100644
index 0364e38c4779..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh
+++ /dev/null
@@ -1,19 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _column_remap_cuh
-#define _column_remap_cuh
-
-#include
-#include
-#include
-
-void column_remap_cuda
-(
- const half* x,
- half* x_new,
- const int x_height,
- const int x_width,
- const uint32_t* x_map
-);
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
deleted file mode 100644
index c5258813e147..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh
+++ /dev/null
@@ -1,58 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _cuda_compat_cuh
-#define _cuda_compat_cuh
-
-// atomicAdd for half types, to support CC < 7.x
-
-__device__ __forceinline__ void atomicAdd_half(half* address, half val)
-{
- unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
- unsigned int old = *address_as_ui;
- unsigned int assumed;
-
- do
- {
- assumed = old;
- __half_raw hsum;
- hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
- half tmpres = __hadd(hsum, val);
- hsum = __half_raw(tmpres);
- old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
- old = atomicCAS(address_as_ui, assumed, old);
- }
- while (assumed != old);
-}
-
-// atomicAdd for half2 types
-
-__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
-{
- unsigned int* address_as_ui = (unsigned int*)address;
- unsigned int old = *address_as_ui;
- unsigned int assumed;
- do
- {
- assumed = old;
- half2 old_val = *((half2*)&old);
- half2 new_val = __hadd2(old_val, val);
- old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
- }
- while (assumed != old);
-}
-
-//
-
-#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
-#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
-
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
-
-#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
-#endif
-
-#endif
-#endif
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
deleted file mode 100644
index 4416027c8387..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu
+++ /dev/null
@@ -1,75 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#define _cuda_buffers_cu
-#include "cuda_buffers.cuh"
-
-CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
-// __constant__ half2 q4_table[16][256];
-// half2 q4_table_host[16][256];
-// bool q4_table_init = false;
-
-CudaBuffers::CudaBuffers
-(
- int _device,
- int _temp_state_size,
- half* _temp_state,
- half* _temp_dq
-) :
- device(_device),
- temp_state_size(_temp_state_size),
- temp_state(_temp_state),
- temp_dq(_temp_dq)
-{
- cudaSetDevice(_device);
-
- cudaStreamCreate(&alt_stream_1);
- cudaStreamCreate(&alt_stream_2);
- cudaStreamCreate(&alt_stream_3);
- cudaEventCreate(&alt_stream_1_done);
- cudaEventCreate(&alt_stream_2_done);
- cudaEventCreate(&alt_stream_3_done);
-}
-
-CudaBuffers::~CudaBuffers()
-{
- cudaStreamDestroy(alt_stream_1);
- cudaStreamDestroy(alt_stream_2);
- cudaStreamDestroy(alt_stream_3);
- cudaEventDestroy(alt_stream_1_done);
- cudaEventDestroy(alt_stream_2_done);
- cudaEventDestroy(alt_stream_3_done);
-}
-
-CudaBuffers* get_buffers(const int device_index)
-{
- return g_buffers[device_index];
-}
-
-void prepare_buffers_cuda
-(
- int _device,
- int _temp_state_size,
- half* _temp_state,
- half* _temp_dq
-)
-{
- CudaBuffers* buffers = new CudaBuffers
- (
- _device,
- _temp_state_size,
- _temp_state,
- _temp_dq
- );
-
- g_buffers[_device] = buffers;
-}
-
-void cleanup_buffers_cuda()
-{
- for (int i = 0; i < CUDA_MAX_DEVICES; i++)
- {
- if (!g_buffers[i]) continue;
- delete g_buffers[i];
- g_buffers[i] = NULL;
- }
-}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
deleted file mode 100644
index 0bf2057c665c..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh
+++ /dev/null
@@ -1,55 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _cuda_buffers_cuh
-#define _cuda_buffers_cuh
-
-#include
-#include
-#include
-#include
-
-const int CUDA_MAX_DEVICES = 16;
-
-// #ifndef _cuda_buffers_cu
-// extern __constant__ half2 q4_table[16][256];
-// #endif
-
-class CudaBuffers
-{
-public:
- int device;
-
- half* temp_state; // [max_hidden_rows * intermediate_size]
- int temp_state_size;
- half* temp_dq; // size of largest quant tensor * 8
-
- cudaStream_t alt_stream_1;
- cudaStream_t alt_stream_2;
- cudaStream_t alt_stream_3;
- cudaEvent_t alt_stream_1_done;
- cudaEvent_t alt_stream_2_done;
- cudaEvent_t alt_stream_3_done;
-
- CudaBuffers
- (
- int _device,
- int _temp_state_size,
- half* _temp_state,
- half* _temp_dq
- );
- ~CudaBuffers();
-};
-
-CudaBuffers* get_buffers(const int device_index);
-
-void prepare_buffers_cuda
-(
- int _device,
- int _temp_state_size,
- half* _temp_state,
- half* _temp_dq
-);
-
-void cleanup_buffers_cuda();
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
deleted file mode 100644
index 5cd2e8553ef6..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh
+++ /dev/null
@@ -1,49 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _hip_compat_cuh
-#define _hip_compat_cuh
-
-// Workaround for a bug in hipamd, backported from upstream.
-__device__ __forceinline__ __half __compat_hrcp(__half x) {
- return __half_raw{
- static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
-}
-
-__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
- return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)),
- static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))};
-}
-
-#define hrcp __compat_hrcp
-#define h2rcp __compat_h2rcp
-
-// Workaround for hipify_python using rocblas instead of hipblas.
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
- hipblasOperation_t transA,
- hipblasOperation_t transB,
- int m,
- int n,
- int k,
- const half* alpha,
- const half* AP,
- int lda,
- const half* BP,
- int ldb,
- const half* beta,
- half* CP,
- int ldc) {
- return hipblasHgemm(handle, transA, transB, m, n, k,
- reinterpret_cast(alpha),
- reinterpret_cast(AP), lda,
- reinterpret_cast(BP), ldb,
- reinterpret_cast(beta),
- reinterpret_cast(CP), ldc);
-}
-
-#define rocblas_handle hipblasHandle_t
-#define rocblas_operation_none HIPBLAS_OP_N
-#define rocblas_get_stream hipblasGetStream
-#define rocblas_set_stream hipblasSetStream
-#define rocblas_hgemm __compat_hipblasHgemm
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
deleted file mode 100644
index bcc0e43901de..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp
+++ /dev/null
@@ -1,254 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "util.cuh"
-#include "tuning.h"
-#include "cuda_buffers.cuh"
-#include "q4_matrix.cuh"
-#include "q4_matmul.cuh"
-#include "column_remap.cuh"
-
-// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
-// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
-// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
-
-void check_cuda(cudaError_t ret)
-{
- switch (ret)
- {
- case cudaSuccess:
- break;
-
- case cudaUnspecified:
- printf(" **** Unspecified error\n");
- TORCH_CHECK(false, "CUDA error");
- break;
-
- default:
- printf(" **** CUDA error\n"); \
- printf(" **** %s\n", cudaGetErrorString(ret)); \
- TORCH_CHECK(false, "CUDA error"); \
- break;
- }
-}
-
-// Some decluttering macros
-
-#define STRINGIFY_(__x) #__x
-#define STRINGIFY(__x) STRINGIFY_(__x)
-#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
-#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
-#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
-#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
-#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
-#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
-
-#define TORCH_CHECK_DEVICE_INDEX(__index) \
-do { \
- TORCH_CHECK(__index >= 0, "no device index"); \
- TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
-} while(0)
-
-#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
-do { \
- TORCH_CHECK_DTYPE(__w, kInt); \
- TORCH_CHECK_DTYPE(__w_scales, kHalf); \
- TORCH_CHECK_DTYPE(__w_zeros, kInt); \
- TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
- TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
- TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
- TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
-} while(0)
-
-int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
-{
- int groupsize = w.size(0) * 8 / w_zeros.size(0);
- TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
- return groupsize;
-}
-
-
-// Tuning parameters
-
-ExLlamaTuning tuningParams;
-
-void set_tuning_params
-(
- int matmul_recons_thd,
- bool matmul_fused_remap,
- bool matmul_no_half2
-)
-{
- tuningParams.matmul_recons_thd = matmul_recons_thd;
- tuningParams.matmul_fused_remap = matmul_fused_remap;
- tuningParams.matmul_no_half2 = matmul_no_half2;
-}
-
-
-// Release all unmanaged objects allocated by the extension
-
-void cleanup()
-{
- cleanup_buffers_cuda();
- g_q4_free_matrices();
-}
-
-
-// Prepare buffers for forward pass
-
-void prepare_buffers
-(
- torch::Device device,
- torch::Tensor temp_state,
- torch::Tensor temp_dq
-)
-{
- int device_index = device.index();
- TORCH_CHECK_DEVICE_INDEX(device_index);
- const at::cuda::OptionalCUDAGuard device_guard(device);
-
- prepare_buffers_cuda
- (
- device_index,
- // buffer size used for sanity checks
- temp_state.numel(),
- (half*) temp_state.data_ptr(),
- (half*) temp_dq.data_ptr()
- );
-}
-
-
-// Create Q4Matrix, return handle
-
-uintptr_t make_q4
-(
- torch::Tensor qweight,
- torch::Tensor qzeros,
- torch::Tensor scales,
- torch::Tensor g_idx,
- int device
-)
-{
- TORCH_CHECK_DTYPE(qweight, kInt);
- TORCH_CHECK_DTYPE(qzeros, kInt);
- TORCH_CHECK_DTYPE(scales, kHalf);
- TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
- TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
- TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
- TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
-
- int width = qweight.size(1);
- int height = qweight.size(0) * 8;
- int groups = qzeros.size(0);
-
- Q4Matrix* m = new Q4Matrix
- (
- height,
- width,
- groups,
-
- (uint32_t*) qweight.data_ptr(),
- (uint32_t*) qzeros.data_ptr(),
- (half*) scales.data_ptr(),
- g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
-
- device
- );
-
- g_q4_keep_matrix(m);
- return reinterpret_cast (m);
-}
-
-
-// Matmul half @ quant -> half
-
-void q4_matmul
-(
- torch::Tensor x,
- uintptr_t w,
- torch::Tensor out
-)
-{
- Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
-
- TORCH_CHECK_DTYPE(x, kHalf);
- TORCH_CHECK_DTYPE(out, kHalf);
- TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
- TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
-
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-
- int x_height = x.size(0);
-
- if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
- {
- q4_matmul_cuda
- (
- &tuningParams,
- (half*) x.data_ptr(),
- x_height,
- wm,
- (half*) out.data_ptr()
- );
- }
- else
- {
- q4_matmul_recons_cuda
- (
- &tuningParams,
- (half*) x.data_ptr(),
- x_height,
- wm,
- (half*) out.data_ptr(),
- at::cuda::getCurrentCUDABlasHandle()
- );
- }
-}
-
-
-// Remap columns in half tensor
-
-void column_remap
-(
- torch::Tensor x,
- torch::Tensor x_new,
- torch::Tensor x_map
-)
-{
- TORCH_CHECK_DTYPE(x, kHalf);
- TORCH_CHECK_DTYPE(x_new, kHalf);
- TORCH_CHECK_DTYPE(x_map, kInt);
- TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
-
- int height = x.size(0);
- int width = x.size(1);
-
- TORCH_CHECK_BUFFER_SIZE(x_new, height * width);
-
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-
- column_remap_cuda
- (
- (half*) x.data_ptr(),
- (half*) x_new.data_ptr(),
- height,
- width,
- (uint32_t*) x_map.data_ptr()
- );
-}
-
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
-{
- m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
- m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
- m.def("cleanup", &cleanup, "cleanup");
- m.def("make_q4", &make_q4, "make_q4");
- m.def("q4_matmul", &q4_matmul, "q4_matmul");
-}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
deleted file mode 100644
index 2fd5ab0b36cd..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh
+++ /dev/null
@@ -1,294 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _matrix_cuh
-#define _matrix_cuh
-
-#include <cuda_runtime.h>
-#include <cuda_fp16.h>
-
-class MatrixView_half
-{
-public:
- const half* data;
- const int height;
- const int width;
-
- __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
- : data(data), height(height), width(width)
- { }
-
- __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
- __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
- __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
- __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
-};
-
-class MatrixView_half_rw
-{
-public:
- half* data;
- const int height;
- const int width;
-
- __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
- : data(data), height(height), width(width)
- { }
-
- __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
- __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
- __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
- __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
- __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
-};
-
-class MatrixView_q4_row
-{
-public:
- const uint32_t* data;
- const int height;
- const int width;
-
- __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
- : data(data), height(height), width(width)
- { }
-
- __device__ __forceinline__ int item(int row, int column) const
- {
- int shift = (column & 0x07) * 4;
- return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
- }
-};
-
-class MatrixView_q4_column
-{
-public:
- const uint32_t* data;
- const int height;
- const int width;
-
- __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
- : data(data), height(height), width(width)
- { }
-
- __device__ __forceinline__ int item(int row, int column) const
- {
- int shift = (row & 0x07) * 4;
- return (data[row / 8 * width + column] >> shift) & 0x0f;
- }
-
- __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
- __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
-};
-
-// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
-
-// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
-
-__device__ __forceinline__ half2 dot_product_8
-(
- const half2 acc,
- MatrixView_half& h_,
- const int h_row,
- const int h_column, // divisible by 8
- MatrixView_q4_column& v_,
- const int v_row, // divisible by 8
- const int v_column,
- const half2 v_scale_2,
- const uint32_t v_zero, // + 1 (!!)
- const int count
-)
-{
- const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
- const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
- half2 result = acc;
-
- for (int i = 0; i < count; i++)
- {
- uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
- half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
- half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
- half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
- half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
- half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
- half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
- half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
- half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
-
- half2 v_01 = __halves2half2(v_0, v_1);
- half2 v_23 = __halves2half2(v_2, v_3);
- half2 v_45 = __halves2half2(v_4, v_5);
- half2 v_67 = __halves2half2(v_6, v_7);
-
-// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
-// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
-// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
-// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
-
- half2 tmp = __hmul2(*h_ptr++, v_01);
- tmp = __hfma2(*h_ptr++, v_23, tmp);
- tmp = __hfma2(*h_ptr++, v_45, tmp);
- tmp = __hfma2(*h_ptr++, v_67, tmp);
- result = __hfma2(v_scale_2, tmp, result);
- }
-
- return result;
-}
-
-__device__ __forceinline__ half dot_product_8_h
-(
- const half acc,
- MatrixView_half& h_,
- const int h_row,
- const int h_column, // divisible by 8
- MatrixView_q4_column& v_,
- const int v_row, // divisible by 8
- const int v_column,
- const half v_scale,
- const uint32_t v_zero, // + 1 (!!)
- const int count
-)
-{
- const half* h_ptr = h_.item_ptr(h_row, h_column);
- const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
- half result = acc;
-
- for (int i = 0; i < count; i++)
- {
- uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
- half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
- half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
- half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
- half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
- half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
- half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
- half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
- half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
-
- half tmp = __hmul(*h_ptr++, v_0);
- tmp = __hfma(*h_ptr++, v_1, tmp);
- tmp = __hfma(*h_ptr++, v_2, tmp);
- tmp = __hfma(*h_ptr++, v_3, tmp);
- tmp = __hfma(*h_ptr++, v_4, tmp);
- tmp = __hfma(*h_ptr++, v_5, tmp);
- tmp = __hfma(*h_ptr++, v_6, tmp);
- tmp = __hfma(*h_ptr++, v_7, tmp);
- result = __hfma(v_scale, tmp, result);
- }
-
- return result;
-}
-
-// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
-
-__device__ __forceinline__ half2 dot_product_8_x_map
-(
- const half2 acc,
- MatrixView_half& h_,
- const int h_row,
- const int h_column, // divisible by 8
- MatrixView_q4_column& v_,
- const int v_row, // divisible by 8
- const int v_column,
- const half2 v_scale_2,
- const uint32_t v_zero, // + 1 (!!)
- const int count,
- const uint32_t* x_map
-)
-{
- const half* h_ptr = h_.item_ptr(h_row, 0);
- const uint32_t* x_map_ptr = x_map + h_column;
- const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
- half2 result = acc;
-
- for (int i = 0; i < count; i++)
- {
- uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
- half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
- half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
- half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
- half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
- half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
- half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
- half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
- half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
-
- half2 v_01 = __halves2half2(v_0, v_1);
- half2 v_23 = __halves2half2(v_2, v_3);
- half2 v_45 = __halves2half2(v_4, v_5);
- half2 v_67 = __halves2half2(v_6, v_7);
-
- half h_0 = h_ptr[*x_map_ptr++];
- half h_1 = h_ptr[*x_map_ptr++];
- half h_2 = h_ptr[*x_map_ptr++];
- half h_3 = h_ptr[*x_map_ptr++];
- half h_4 = h_ptr[*x_map_ptr++];
- half h_5 = h_ptr[*x_map_ptr++];
- half h_6 = h_ptr[*x_map_ptr++];
- half h_7 = h_ptr[*x_map_ptr++];
-
- half2 h_01 = __halves2half2(h_0, h_1);
- half2 h_23 = __halves2half2(h_2, h_3);
- half2 h_45 = __halves2half2(h_4, h_5);
- half2 h_67 = __halves2half2(h_6, h_7);
-
- half2 tmp = __hmul2(h_01, v_01);
- tmp = __hfma2(h_23, v_23, tmp);
- tmp = __hfma2(h_45, v_45, tmp);
- tmp = __hfma2(h_67, v_67, tmp);
- result = __hfma2(v_scale_2, tmp, result);
- }
-
- return result;
-}
-
-__device__ __forceinline__ half dot_product_8_x_map_h
-(
- const half acc,
- MatrixView_half& h_,
- const int h_row,
- const int h_column, // divisible by 8
- MatrixView_q4_column& v_,
- const int v_row, // divisible by 8
- const int v_column,
- const half v_scale,
- const uint32_t v_zero, // + 1 (!!)
- const int count,
- const uint32_t* x_map
-)
-{
- const half* h_ptr = h_.item_ptr(h_row, 0);
- const uint32_t* x_map_ptr = x_map + h_column;
- const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
- half result = acc;
-
- for (int i = 0; i < count; i++)
- {
- uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
- half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
- half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
- half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
- half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
- half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
- half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
- half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
- half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
-
- half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
- tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
- result = __hfma(v_scale, tmp, result);
- }
-
- return result;
-}
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
deleted file mode 100644
index f47daeb0e877..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu
+++ /dev/null
@@ -1,260 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#include "q4_matmul.cuh"
-#include "column_remap.cuh"
-#include "util.cuh"
-#include "matrix.cuh"
-#include "cu_compat.cuh"
-#include "cuda_buffers.cuh"
-#if defined(USE_ROCM)
-#include "hip_compat.cuh"
-#endif
-
-const int THREADS_X = 32; // Block size and thread count along columns in w and out
-const int THREADS_Y = 1; // Block size and thread count along rows in x and out
-
-typedef void (*fp_q4_matmul_kernel)
-(
- const half*,
- const uint32_t*,
- half*,
- const half*,
- const uint32_t*,
- const int,
- const int,
- const int,
- const int,
- const int,
- const uint32_t*,
- bool
-);
-
-template<bool use_half2, bool use_groupsize, bool use_x_map>
-__global__ void q4_matmul_kernel
-(
- const half* __restrict__ x,
- const uint32_t* __restrict__ w,
- half* __restrict__ out,
- const half* __restrict__ w_scales,
- const uint32_t* __restrict__ w_zeros,
- const int height,
- const int dim,
- const int width,
- const int groupsize,
- const int block_size_z,
- const uint32_t* __restrict__ x_map,
- bool no_zero
-)
-{
- // Start of block
-
- int x_column = block_size_z * blockIdx.z;
- int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
-
- int w_column = THREADS_X * blockIdx.x + threadIdx.x;
- int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
-
- int iterations = (x_column_end - x_column) / 8;
-
- // Views
-
- MatrixView_half x_(x, height, dim);
- MatrixView_half w_scales_(w_scales, dim / groupsize, width);
- MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
- MatrixView_q4_column w_(w, dim, width);
- MatrixView_half_rw out_(out, height, width);
-
- // Zero output
-
- if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
- {
- *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
- __syncthreads();
- }
-
- // Loop over part of x row (and w column)
-
- half2 acc = {};
- half acc_h = {};
-
- if constexpr (use_groupsize)
- {
- // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
- // could be slightly faster
-
- for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
- {
- if constexpr (use_half2)
- {
- half2 w_scale = w_scales_.item_half2half2(group, w_column);
- uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-
- if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
- else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
- }
- else
- {
- half w_scale = w_scales_.item(group, w_column);
- uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-
- if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
- else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
- }
- }
- }
- else
- {
- // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
-
- for (int k = x_column; k < x_column + iterations * 8; k += 8)
- {
- if constexpr (use_half2)
- {
- int group = k / groupsize;
- half2 w_scale = w_scales_.item_half2half2(group, w_column);
- uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-
- if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
- else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
- }
- else
- {
- int group = k / groupsize;
- half w_scale = w_scales_.item(group, w_column);
- uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-
- if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
- else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
- }
- }
- }
-
- // Add to block result
-
- if constexpr (use_half2)
- {
- half result = __hadd(__low2half(acc), __high2half(acc));
- atomicAdd(out_.item_ptr(x_row, w_column), result);
- }
- else
- {
- atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
- }
-}
-
-fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
-{
- // <bool use_half2, bool use_groupsize, bool use_x_map>
- if (tuningParams->matmul_no_half2) {
- if (block_size_z % groupsize == 0) {
- if (x_map) return q4_matmul_kernel<false, true, true>;
- else return q4_matmul_kernel<false, true, false>;
- } else {
- if (x_map) return q4_matmul_kernel<false, false, true>;
- else return q4_matmul_kernel<false, false, false>;
- }
- } else {
- if (block_size_z % groupsize == 0)
- {
- if (x_map) return q4_matmul_kernel<true, true, true>;
- else return q4_matmul_kernel<true, true, false>;
- } else {
- if (x_map) return q4_matmul_kernel<true, false, true>;
- else return q4_matmul_kernel<true, false, false>;
- }
- }
-};
-
-// Compute y = x @ w
-
-void q4_matmul_cuda
-(
- ExLlamaTuning* tuningParams,
- const half* x,
- const int x_height,
- const Q4Matrix* w,
- half* out,
- bool no_zero,
- cudaStream_t alt_stream
-)
-{
- int height = x_height;
- int dim = w->height;
- int width = w->width;
-
- cudaSetDevice(w->device);
-
- uint32_t* x_map = w->cuda_x_map;
- const half* x_mapped = x;
- if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
- {
- CudaBuffers* buffers = get_buffers(w->device);
- column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
- x_mapped = buffers->temp_state;
- x_map = NULL;
- }
-
- int block_size_z;
- if (w->width == 4096) block_size_z = 384; // 7B
- else if (w->width == 11008) block_size_z = 256;
- else if (w->width == 5120) block_size_z = 384; // 13B
- else if (w->width == 13824) block_size_z = 256;
- else if (w->width == 6656) block_size_z = 256; // 33B
- else if (w->width == 17920) block_size_z = 128;
- else block_size_z = 256;
-
- //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
-
- dim3 threads(THREADS_X, THREADS_Y, 1);
-
- dim3 blocks
- (
- (width + threads.x - 1) / threads.x,
- (height + threads.y - 1) / threads.y,
- (dim + block_size_z - 1) / block_size_z
- );
-
- fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-
- kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
-}
-
-void q4_matmul_recons_cuda
-(
- ExLlamaTuning* tuningParams,
- const half* x,
- const int x_height,
- Q4Matrix* w,
- half* out,
- const cublasHandle_t handle,
- bool no_zero
-)
-{
- int height = x_height;
- int dim = w->height;
- int width = w->width;
-
- cudaSetDevice(w->device);
- CudaBuffers* buffers = get_buffers(w->device);
-
- const half* x_mapped = x;
- if (w->cuda_x_map)
- {
- TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
- column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
- x_mapped = buffers->temp_state;
- }
-
- w->reconstruct(buffers->temp_dq);
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
- const float alpha = 1.0f;
- const float beta = no_zero ? 1.0f : 0.0f;
- cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
- x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
-#else
- const half alpha = __float2half(1.0f);
- const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
- cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
-#endif
-}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
deleted file mode 100644
index 09f3e1a63362..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh
+++ /dev/null
@@ -1,43 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _q4_matmul_cuh
-#define _q4_matmul_cuh
-
-#include
-#include
-#include
-#include
-#include
-
-#include "q4_matrix.cuh"
-#include "tuning.h"
-
-// Workaround for hipify_python using rocblas instead of hipblas.
-#if defined(USE_ROCM)
-#include
-#define rocblas_handle hipblasHandle_t
-#endif
-
-void q4_matmul_cuda
-(
- ExLlamaTuning* tuningParams,
- const half* x,
- const int x_height,
- const Q4Matrix* w,
- half* out,
- bool no_zero = false,
- cudaStream_t alt_stream = NULL
-);
-
-void q4_matmul_recons_cuda
-(
- ExLlamaTuning* tuningParams,
- const half* x,
- const int x_height,
- Q4Matrix* w,
- half* out,
- const cublasHandle_t handle,
- bool no_zero = false
-);
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
deleted file mode 100644
index 9c61143f565e..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu
+++ /dev/null
@@ -1,225 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#include "q4_matrix.cuh"
-#include <vector>
-#include "util.cuh"
-#include "matrix.cuh"
-
-using namespace std;
-
-const int UNSHUF_BLOCKSIZE_X = 64;
-
-const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
-const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
-
-vector<Q4Matrix*> g_q4_matrices;
-
-void g_q4_keep_matrix(Q4Matrix* m)
-{
- g_q4_matrices.push_back(m);
-}
-
-void g_q4_free_matrices()
-{
- for (const auto& m : g_q4_matrices) delete m;
- g_q4_matrices.clear();
-}
-
-Q4Matrix::Q4Matrix
-(
- const int _height,
- const int _width,
- const int _groups,
-
- uint32_t* _qweight,
- uint32_t* _qzeros,
- half* _scales,
- uint32_t* _g_idx,
-
- const int _device
-) :
- height(_height),
- width(_width),
- groups(_groups),
- device(_device)
-{
- cudaSetDevice(device);
-
- cuda_qweight = _qweight;
- cuda_qzeros = _qzeros;
- cuda_scales = _scales;
-
- groupsize = height / groups;
-
- if (_g_idx) make_sequential(_g_idx);
-}
-
-Q4Matrix::~Q4Matrix()
-{
-}
-
-// Make sequential
-
-__global__ void make_sequential_kernel
-(
- const uint32_t* __restrict__ w,
- uint32_t* __restrict__ w_new,
- const uint32_t* __restrict__ x_map,
- const int w_height,
- const int w_width
-)
-{
- const uint64_t* w2 = (uint64_t*) w;
- uint64_t* w_new2 = (uint64_t*) w_new;
- int w2_stride = w_width >> 1;
-
- int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
- if (w2_column >= w2_stride) return;
-
- int w_new2_row = blockIdx.y;
-
- int x_map_idx = w_new2_row << 3;
-
- uint64_t dst = 0;
-
- #pragma unroll
- for (int i = 0; i < 8; i++)
- {
- int source_row = x_map[x_map_idx++];
-
- int w2_row = source_row >> 3;
- int w2_subrow = source_row & 0x07;
- int w2_row_shift = w2_subrow << 2;
- int wnew2_row_shift = i << 2;
-
- uint64_t src = w2[w2_row * w2_stride + w2_column];
- src >>= w2_row_shift;
- src &= 0x0000000f0000000f;
- src <<= wnew2_row_shift;
- dst |= src;
- }
-
- w_new2[w_new2_row * w2_stride + w2_column] = dst;
-}
-
-void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
-{
- uint32_t* cuda_new_qweight = NULL;
- cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
- cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
-
- uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
- uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
- uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
-
- // Group histogram
-
- for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
-
- // Group map
-
- for (int i = 0, acc = 0; i < groups; i++)
- {
- short tmp = cpu_g_idx_map[i];
- cpu_g_idx_map[i] = acc;
- acc += tmp;
- }
-
- // X map (inverse)
-
- for (int row = 0; row < height; row++)
- {
- uint32_t target_group = cpu_g_idx[row];
- uint32_t target_row = cpu_g_idx_map[target_group];
- cpu_g_idx_map[target_group]++;
- cpu_x_map_inv[row] = target_row;
- }
-
- // X map
-
- for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
-
- // Move to CUDA
-
- cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
-
- // Rearrange rows in w
-
- dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
- dim3 blocks
- (
- (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
- height / 8,
- 1
- );
-
- make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
-
- // Replace qweights
-
- cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
-
- // Cleanup
-
- cudaDeviceSynchronize();
- cudaFree(cuda_new_qweight);
- free(cpu_g_idx_map);
- free(cpu_x_map);
- free(cpu_x_map_inv);
-}
-
-__global__ void reconstruct_kernel
-(
- const uint32_t* __restrict__ w,
- half* __restrict__ out, // (y)
- const half* __restrict__ w_scales,
- const uint32_t* __restrict__ w_zeros,
- const int height,
- const int width,
- const int groupsize
-)
-{
- // Start of block
-
- int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
- int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
- if (column >= width) return;
-
- // Views
-
- MatrixView_q4_column w_(w, height, width);
- MatrixView_half_rw out_(out, height, width);
- MatrixView_half w_scales_(w_scales, height / groupsize, width);
- MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
-
- // Groupsize version
-
- int group = row / groupsize;
-
- half w_scale = w_scales_.item(group, column);
- uint32_t w_zero = w_zeros_.item(group, column) + 1;
-
- uint32_t w_read = w_.item_uint32_t(row, column);
- half* out_ptr = out_.item_ptr(row, column);
-
- #pragma unroll
- for (int s = 0; s < 32; s += 4)
- {
- half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
- *out_ptr = w_item; out_ptr += out_.width;
- }
-}
-
-void Q4Matrix::reconstruct(half* out)
-{
- dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
-
- dim3 blocks
- (
- (width + threads.x - 1) / threads.x,
- (height / 8 + threads.y - 1) / threads.y,
- 1
- );
-
- reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
-}
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
deleted file mode 100644
index 50cb72a41518..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh
+++ /dev/null
@@ -1,53 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _q4_matrix_cuh
-#define _q4_matrix_cuh
-
-#include
-#include
-#include
-
-class Q4Matrix
-{
-public:
-
- int device;
-
- int height;
- int width;
- int groups;
- int groupsize;
-
- uint32_t* cuda_qweight = NULL;
- uint32_t* cuda_qzeros = NULL;
- half* cuda_scales = NULL;
- uint32_t* cuda_x_map = NULL;
-
- Q4Matrix
- (
- const int _height,
- const int _width,
- const int _groups,
-
- uint32_t* _qweight,
- uint32_t* _qzeros,
- half* _scales,
- uint32_t* _g_idx,
-
- const int _device
- );
-
- ~Q4Matrix();
-
- void reconstruct(half* out);
-
-private:
-
- void make_sequential(const uint32_t* cpu_g_idx);
-
-};
-
-void g_q4_keep_matrix(Q4Matrix* m);
-void g_q4_free_matrices();
-
-#endif
\ No newline at end of file
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
deleted file mode 100644
index e413b8a96c11..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h
+++ /dev/null
@@ -1,12 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _tuning_h
-#define _tuning_h
-
-struct ExLlamaTuning {
- int matmul_recons_thd;
- bool matmul_fused_remap;
- bool matmul_no_half2;
-};
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
deleted file mode 100644
index 7b397573214b..000000000000
--- a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh
+++ /dev/null
@@ -1,33 +0,0 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
-
-#ifndef _util_cuh
-#define _util_cuh
-
-#include
-#include
-#include
-#include
-
-#if defined(USE_ROCM)
-#define cudaUnspecified hipErrorUnknown
-#else
-#define cudaUnspecified cudaErrorApiFailureBase
-#endif
-
-// React to failure on return code != cudaSuccess
-
-#define _cuda_check(fn) \
-do { \
- {_cuda_err = fn;} \
- if (_cuda_err != cudaSuccess) goto _cuda_fail; \
-} while(false)
-
-// React to failure on return code == 0
-
-#define _alloc_check(fn) \
-do { \
- if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
- else _cuda_err = cudaSuccess; \
-} while(false)
-
-#endif
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
deleted file mode 100644
index 58d26235a9cc..000000000000
--- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
+++ /dev/null
@@ -1,191 +0,0 @@
-#include "block_reduce.h"
-#include "cuda_util.h"
-#include "kernels.h"
-#include "ls_cub.cuh"
-
-ls::cub::CachingDeviceAllocator g_allocator(true);
-
-template <typename T>
-__global__ void ls_cross_entropy_fw_kernel(
- const T *__restrict__ inputs, const int *__restrict__ targets,
- float *__restrict__ outputs, float *__restrict__ nll_loss_outputs,
- const int padding_idx, const float epsilon, const int vocab_size) {
- /* step1: compute each thread's max_logit and sum_exp_logit, store in
- * max_input, sum_exp_logit */
- const int block_start = blockIdx.x * vocab_size;
- const int left_idx = block_start + threadIdx.x;
- const int right_idx = (blockIdx.x + 1) * vocab_size;
- float max_input[1] = {REDUCE_FLOAT_INF_NEG};
- float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
- int target_tid = targets[blockIdx.x];
-
- if (target_tid == padding_idx) {
- if (threadIdx.x == 0) {
- nll_loss_outputs[blockIdx.x] = 0.f;
- outputs[blockIdx.x] = 0.f;
- }
- return;
- }
-
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
- }
- blockReduce<ReduceType::kMax, 1>(max_input);
- __shared__ float s_max_input;
- if (threadIdx.x == 0) {
- s_max_input = max_input[0];
- }
- __syncthreads();
-
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- float logit = static_cast<float>(inputs[i]) - s_max_input;
- sum_logits[0] += logit;
- sum_logits[1] += expf(logit);
- }
-
- blockReduce<ReduceType::kSum, 2>(sum_logits);
- __shared__ float s_sum_logit;
- __shared__ float s_sum_exp;
- if (threadIdx.x == 0) {
- s_sum_logit = sum_logits[0];
- s_sum_exp = sum_logits[1];
- }
- __syncthreads();
-
- float eps_i = epsilon / (vocab_size - 1);
- if (threadIdx.x == 0) {
- // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max)
- float nll_loss = logf(s_sum_exp) -
- static_cast<float>(inputs[block_start + target_tid]) +
- s_max_input;
- nll_loss_outputs[blockIdx.x] = nll_loss;
- float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit;
- outputs[blockIdx.x] =
- (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss;
- }
-}
-
-template <typename T>
-__global__ void ls_cross_entropy_bw_kernel(
- const float *__restrict__ grad_outputs, const T *__restrict__ inputs,
- const int *__restrict__ targets, T *__restrict__ grad_inputs,
- const int padding_idx, const float epsilon, const int vocab_size) {
- /* step1: compute each thread's max_logit and sum_exp_logit, store in
- * max_input, sum_exp_logit */
- const int block_start = blockIdx.x * vocab_size;
- const int left_idx = block_start + threadIdx.x;
- const int right_idx = (blockIdx.x + 1) * vocab_size;
- float max_input[1] = {REDUCE_FLOAT_INF_NEG};
- float sum_logits[1] = {0.f};
- const float grad_out = static_cast<float>(grad_outputs[0]);
- int target_tid = targets[blockIdx.x];
-
- if (target_tid == padding_idx) {
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- grad_inputs[i] = 0.f;
- }
- return;
- }
-
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- max_input[0] = fmaxf(max_input[0], static_cast<float>(inputs[i]));
- }
- blockReduce<ReduceType::kMax, 1>(max_input);
- __shared__ float s_max_input;
- if (threadIdx.x == 0) {
- s_max_input = max_input[0];
- }
- __syncthreads();
-
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- float logit = static_cast<float>(inputs[i]) - s_max_input;
- sum_logits[0] += expf(logit);
- }
-
- blockReduce<ReduceType::kSum, 1>(sum_logits);
- __shared__ float s_sum_exp;
- if (threadIdx.x == 0) {
- s_sum_exp = sum_logits[0];
- }
- __syncthreads();
-
- float eps_i = epsilon / (vocab_size - 1);
- float nll_weight = 1.0 - epsilon - eps_i;
-
- for (int i = left_idx; i < right_idx; i += blockDim.x) {
- float prob = expf(static_cast<float>(inputs[i]) - s_max_input) / s_sum_exp;
- float grad = 0;
- grad += (vocab_size * prob - 1) * eps_i;
- grad += prob * nll_weight;
- if ((i - block_start) == target_tid) {
- grad -= nll_weight;
- }
- grad_inputs[i] = grad_out * grad;
- }
-}
-
-template <typename T>
-void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr,
- float *outputs_ptr, float *nll_loss_ptr,
- float *loss_buffer, const int padding_idx,
- const float epsilon, const int batch_size,
- const int seq_len, const int vocab_size,
- cudaStream_t stream) {
- int grid_dim = batch_size * seq_len;
- float *nll_loss_buffer = loss_buffer + grid_dim;
- ls_cross_entropy_fw_kernel<T><<<grid_dim, MAX_THREADS, 0, stream>>>(
- inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx,
- epsilon, vocab_size);
-
- int num_items = grid_dim;
- void *d_temp_storage = NULL;
- size_t temp_storage_bytes = 0;
- CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
- loss_buffer, outputs_ptr,
- num_items, stream));
- CHECK_GPU_ERROR(
- g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes));
- CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
- loss_buffer, outputs_ptr,
- num_items, stream));
- CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
- nll_loss_buffer, nll_loss_ptr,
- num_items, stream));
- CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage));
-}
-
-template void launch_cross_entropy_fw<float>(
- const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
- float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
- const float epsilon, const int batch_size, const int seq_len,
- const int vocab_size, cudaStream_t stream);
-
-template void launch_cross_entropy_fw<__half>(
- const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr,
- float *nll_loss_ptr, float *loss_buffer, const int padding_idx,
- const float epsilon, const int batch_size, const int seq_len,
- const int vocab_size, cudaStream_t stream);
-
-template <typename T>
-void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr,
- const int *targets_ptr, T *grad_inputs_ptr,
- const int padding_idx, const float epsilon,
- const int batch_size, const int seq_len,
- const int vocab_size, cudaStream_t stream) {
- int grid_dim = batch_size * seq_len;
- ls_cross_entropy_bw_kernel<T><<<grid_dim, MAX_THREADS, 0, stream>>>(
- grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx,
- epsilon, vocab_size);
-}
-
-template void launch_cross_entropy_bw<float>(
- const float *grad_outputs_ptr, const float *inputs_ptr,
- const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx,
- const float epsilon, const int batch_size, const int seq_len,
- const int vocab_size, cudaStream_t stream);
-
-template void launch_cross_entropy_bw<__half>(
- const float *grad_outputs_ptr, const __half *inputs_ptr,
- const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx,
- const float epsilon, const int batch_size, const int seq_len,
- const int vocab_size, cudaStream_t stream);
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
deleted file mode 100644
index 09f34763f9b2..000000000000
--- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
+++ /dev/null
@@ -1,88 +0,0 @@
-/* Copyright 2021 The LightSeq Team
- Copyright Microsoft DeepSpeed
- This file is adapted from Microsoft DeepSpeed
- Licensed under the MIT License.
-*/
-#include "cublas_wrappers.h"
-
-int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
- cublasOperation_t transb, int m, int n, int k,
- const float *alpha, const float *beta, const float *A,
- const float *B, float *C, cublasGemmAlgo_t algo) {
- cublasStatus_t status =
- cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha,
- (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k,
- (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n,
- (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo);
-
- if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr,
- "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
- m, n, k, (int)status);
- return EXIT_FAILURE;
- }
- return 0;
-}
-
-int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa,
- cublasOperation_t transb, int m, int n, int k,
- const float *alpha, const float *beta, const __half *A,
- const __half *B, __half *C, cublasGemmAlgo_t algo) {
- cublasStatus_t status = cublasGemmEx(
- handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A,
- CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F,
- (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C,
- CUDA_R_16F, m, CUDA_R_32F, algo);
-
- if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr,
- "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
- m, n, k, (int)status);
- return EXIT_FAILURE;
- }
- return 0;
-}
-
-int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
- const float *alpha, const float *beta,
- const float *A, const float *B, float *C,
- cublasOperation_t op_A, cublasOperation_t op_B,
- int stride_A, int stride_B, int stride_C,
- int batch, cublasGemmAlgo_t algo) {
- cublasStatus_t status = cublasGemmStridedBatchedEx(
- handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F,
- (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F,
- (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C,
- batch, CUDA_R_32F, algo);
-
- if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr,
- "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, "
- "error: %d) \n",
- batch, m, n, k, (int)status);
- return EXIT_FAILURE;
- }
- return 0;
-}
-
-int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k,
- const float *alpha, const float *beta,
- const __half *A, const __half *B, __half *C,
- cublasOperation_t op_A, cublasOperation_t op_B,
- int stride_A, int stride_B, int stride_C,
- int batch, cublasGemmAlgo_t algo) {
- cublasStatus_t status = cublasGemmStridedBatchedEx(
- handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F,
- (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F,
- (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C,
- batch, CUDA_R_32F, algo);
-
- if (status != CUBLAS_STATUS_SUCCESS) {
- fprintf(stderr,
- "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
- m, n, k, (int)status);
- return EXIT_FAILURE;
- }
-
- return 0;
-}
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
deleted file mode 100644
index e5ac17308640..000000000000
--- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
+++ /dev/null
@@ -1,169 +0,0 @@
-#include
-#include
-#include
-#include "cuda_util.h"
-
-/* GPU function guard */
-std::string _cudaGetErrorString(cudaError_t error) {
- return cudaGetErrorString(error);
-}
-
-std::string _cudaGetErrorString(cublasStatus_t error) {
- switch (error) {
- case CUBLAS_STATUS_SUCCESS:
- return "CUBLAS_STATUS_SUCCESS";
-
- case CUBLAS_STATUS_NOT_INITIALIZED:
- return "CUBLAS_STATUS_NOT_INITIALIZED";
-
- case CUBLAS_STATUS_ALLOC_FAILED:
- return "CUBLAS_STATUS_ALLOC_FAILED";
-
- case CUBLAS_STATUS_INVALID_VALUE:
- return "CUBLAS_STATUS_INVALID_VALUE";
-
- case CUBLAS_STATUS_ARCH_MISMATCH:
- return "CUBLAS_STATUS_ARCH_MISMATCH";
-
- case CUBLAS_STATUS_MAPPING_ERROR:
- return "CUBLAS_STATUS_MAPPING_ERROR";
-
- case CUBLAS_STATUS_EXECUTION_FAILED:
- return "CUBLAS_STATUS_EXECUTION_FAILED";
-
- case CUBLAS_STATUS_INTERNAL_ERROR:
- return "CUBLAS_STATUS_INTERNAL_ERROR";
-
- case CUBLAS_STATUS_NOT_SUPPORTED:
- return "CUBLAS_STATUS_NOT_SUPPORTED";
-
- case CUBLAS_STATUS_LICENSE_ERROR:
- return "CUBLAS_STATUS_LICENSE_ERROR";
- }
- return "CUBLAS_UNKNOW";
-}
-
-template <typename T>
-void check_gpu_error(T result, char const *const func, const char *const file,
- int const line) {
- if (result) {
- throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" +
- std::to_string(line) +
- "): " + (_cudaGetErrorString(result)) + "\n");
- }
-}
-
-template void check_gpu_error<cudaError_t>(cudaError_t result,
- char const *const func,
- const char *const file,
- int const line);
-template void check_gpu_error<cublasStatus_t>(cublasStatus_t result,
- char const *const func,
- const char *const file,
- int const line);
-
-template <typename T>
-void print_vec(const T *outv, std::string outn, int num_output_ele) {
- std::cout << outn << ": ";
- std::vector<T> hout(num_output_ele, (T)0);
- cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T),
- cudaMemcpyDeviceToHost);
- for (int i = 0; i < num_output_ele; i++) {
- std::cout << hout[i] << ", ";
- }
- std::cout << std::endl;
-}
-
-template <>
-void print_vec<__half>(const __half *outv, std::string outn,
- int num_output_ele) {
- std::cout << outn << ": ";
- std::vector<__half> hout(num_output_ele, (__half)0.f);
- cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half),
- cudaMemcpyDeviceToHost);
- for (int i = 0; i < num_output_ele; i++) {
- std::cout << __half2float(hout[i]) << ", ";
- }
- std::cout << std::endl;
-}
-
-template void print_vec<float>(const float *outv, std::string outn,
- int num_output_ele);
-
-template void print_vec<int>(const int *outv, std::string outn,
- int num_output_ele);
-
-template void print_vec<__half>(const __half *outv, std::string outn,
- int num_output_ele);
-
-template <typename T>
-T *cuda_malloc(size_t ele_num) {
- size_t byte_size = ele_num * sizeof(T);
- T *pdata = nullptr;
- CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size));
- return pdata;
-}
-
-template float *cuda_malloc<float>(size_t ele_num);
-
-template __half *cuda_malloc<__half>(size_t ele_num);
-
-template uint8_t *cuda_malloc<uint8_t>(size_t ele_num);
-
-void cuda_free(void *pdata) {
- if (pdata != nullptr) {
- cudaFree(pdata);
- }
-}
-
-template <typename T>
-struct _isnan {
- __device__ bool operator()(T a) const { return isnan(a); }
-};
-
-template <>
-struct _isnan<__half> {
- __device__ bool operator()(const __half a) const { return __hisnan(a); }
-};
-
-template <typename T>
-struct _isinf {
- __device__ bool operator()(T a) const { return isinf(a); }
-};
-
-template <>
-struct _isinf<__half> {
- __device__ bool operator()(const __half a) const { return __hisinf(a); }
-};
-
-template <typename T>
-void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
- std::string file, int line, cudaStream_t stream) {
- // check_nan_inf = 0 for checking nan
- // check_nan_inf = 1 for checking inf
- bool res = false;
- std::string msg = file + "(" + std::to_string(line) + "): ";
- if (check_nan_inf) {
- msg += "nan.";
- res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
- data_ptr + dsize, _isnan<T>(), false,
- thrust::logical_or<bool>());
- } else {
- msg += "inf.";
- res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr,
- data_ptr + dsize, _isinf<T>(), false,
- thrust::logical_or<bool>());
- }
- if (res) {
- throw std::runtime_error(msg);
- }
- std::cout << msg << " [check pass]." << std::endl;
-}
-
-template void check_nan_inf<float>(const float *data_ptr, int dsize,
- bool check_nan_inf, std::string file,
- int line, cudaStream_t stream);
-
-template void check_nan_inf<__half>(const __half *data_ptr, int dsize,
- bool check_nan_inf, std::string file,
- int line, cudaStream_t stream);
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
deleted file mode 100644
index ce0b017f12e1..000000000000
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ /dev/null
@@ -1,1002 +0,0 @@
-#include
-#include
-
-#include "kernels.h"
-
-#include <cooperative_groups.h>
-
-
-namespace cg = cooperative_groups;
-
-curandStatePhilox4_32_10_t *curandstate;
-
-/**
- * @brief element-wise activation function on device, like Relu, Gelu
- *
- * @tparam enum class ActivationType, kRelu, kGelu
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape and type with input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_kernel(T x);
-
-template <>
-__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
- float cdf =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
- return x * cdf;
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
- __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
- float2 tmp_pow = __half22float2(val_pow3);
- float2 tmp = __half22float2(val);
-
- tmp.x =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
- tmp.y =
- 0.5f *
- (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
- return __hmul2(val, __float22half2_rn(tmp));
-}
-
-template <>
-__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
- return fmaxf(x, 0);
-}
-
-template <>
-__device__ __half2
-activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
- return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
- fmaxf(0.f, __half2float(x.y)));
-}
-
-/**
- * @brief element-wise activation backward function on device
- *
- * @tparam enum class ActivationType
- * @tparam input type
- * @param any shape of float and __half2
- * @return same shape of input
- */
-template <ActivationType act_type, typename T>
-__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(float grad,
- float x) {
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * (dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
- __half grad, __half x_half) {
- float x = __half2float(x_half);
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
-
- float x2mul = x * x * mul_param;
- float tan_h = tanhf(sqrt_param * (x + x * x2mul));
- float dg1 = 0.5f * (1.0f + tan_h);
- float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
- float dg3 = dg2 * 3 * x2mul;
- return grad * __float2half(dg1 + dg2 + dg3);
-}
-
-template <>
-__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(float grad,
- float x) {
- return x > 0.f ? grad : 0.f;
-}
-
-template <>
-__device__ __half
-activation_bwd_kernel<ActivationType::kRelu, __half>(__half grad, __half x) {
- const __half half_zero = __float2half(0.f);
- return x > half_zero ? grad : half_zero;
-}
-
-template <>
-__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
- __half2 grad2, __half2 x_half2) {
- const __half half_zero = __float2half(0.f);
- return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
- x_half2.y > half_zero ? grad2.y : half_zero);
-}
-
-/**
- * @brief init curand states in global memory
- *
- * @thread grid_dim * block*dim to suuport any size of states
- * @param state persistant curand states
- * @param seed seed to init states
- * @return void
- */
-__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
- int seed) {
- /* Each thread gets same seed, a different sequence
- number, no offset */
- int id = threadIdx.x + blockIdx.x * blockDim.x;
- curand_init(seed, id, 0, &state[id]);
-}
-
-void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
- cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
- int grid_dim = total_count >> 9;
- curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
- curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
-}
-
-/**
- * @brief element-wise dropout, store dropped position in mask, it's not
- * in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out any size of float and __half
- * @param in same with out
- * @param mask uint8 type, same size with out
- * @param seed seed to curand
- * @return void
- */
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- float *__restrict__ out,
- const float *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
-
- float4 input4 = data4[i];
- float4 res4;
- res4.x = input4.x * scale * m[0];
- res4.y = input4.y * scale * m[1];
- res4.z = input4.z * scale * m[2];
- res4.w = input4.w * scale * m[3];
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_kernel(const int total_count, const float ratio,
- __half *__restrict__ out,
- const __half *__restrict__ in,
- uint8_t *__restrict__ mask, const int seed) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-/**
- * @brief element-wise dropout backward with dropout mask, it's
- * not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param in any size of float and __half
- * @param mask uint8 type, same size with in
- * @return void
- */
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- float *out, const float *in,
- const uint8_t *__restrict__ mask) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *in4 = reinterpret_cast<const float4 *>(in);
- const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);
-
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- m4[0] = mask4[i];
-
- float4 input4 = in4[i];
- float4 res4;
- res4.x = input4.x * scale * static_cast<float>(m[0]);
- res4.y = input4.y * scale * static_cast<float>(m[1]);
- res4.z = input4.z * scale * static_cast<float>(m[2]);
- res4.w = input4.w * scale * static_cast<float>(m[3]);
- out4[i] = res4;
-}
-
-__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
- __half *out, const __half *in,
- const uint8_t *__restrict__ mask) {
- const __half scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);
-
- uint8_t m[8];
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- m8[0] = mask8[i];
-
- float4 val_float4 = vals_float4[i];
- float4 out_float4;
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
- out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
- out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
- out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
- out4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
- int total_count, float ratio, cudaStream_t stream,
- bool backward) {
- int grid_dim = total_count >> 12;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-template <>
-void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
- int total_count, float ratio,
- cudaStream_t stream, bool backward) {
- int grid_dim = total_count >> 13;
- if (!backward) {
- ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count());
- } else {
- ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count, ratio,
- out, vals, mask);
- }
-}
-
-/**
- * @brief fused bias, dropout, and residual at the end of Attention and FFN,
- * store dropped position in mask, it's not in-place
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param residual [batch_size, seq_len, hidden_size], float and __half
- * @param seed seed to curand
- * @param hidden_size hidden size
- * @return void
- */
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const float *__restrict__ residual,
- const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 output4;
-
- output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
- output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
- output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
- output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;
-
- out4[i] = output4;
-}
-
-__global__ void ls_dropout_res_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const __half *__restrict__ residual,
- const int seed, const int hidden_size) {
- const __half scale = 1. / (1. - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = static_cast<uint8_t>(rand.x > ratio);
- m[1] = static_cast<uint8_t>(rand.y > ratio);
- m[2] = static_cast<uint8_t>(rand.z > ratio);
- m[3] = static_cast<uint8_t>(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = static_cast<uint8_t>(rand.x > ratio);
- m[5] = static_cast<uint8_t>(rand.y > ratio);
- m[6] = static_cast<uint8_t>(rand.z > ratio);
- m[7] = static_cast<uint8_t>(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = m8[0];
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- const float4 res4 = residual4[i];
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
- const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
- __half2 scale_mask_1 =
- __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
- __half2 scale_mask_2 =
- __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
- __half2 scale_mask_3 =
- __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
- __half2 scale_mask_4 =
- __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
- out_half2[0] =
- __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
- out_half2[1] =
- __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
- out_half2[2] =
- __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
- out_half2[3] =
- __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
- uint8_t *mask, const float *bias,
- const float *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 12;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
- uint8_t *mask, const __half *bias,
- const __half *residual, int total_count,
- int dim, float ratio,
- cudaStream_t stream) {
- int grid_dim = total_count >> 13;
- ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias, residual,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
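For orientation, a minimal host-side sketch of how a launcher like the one above is typically driven; the shapes, buffer names, and the 0.1 dropout ratio are illustrative assumptions, not taken from the deleted file, and the launcher declaration is assumed to come from kernels.h.

// Illustrative driver only: assumes pre-allocated device buffers and a valid CUDA stream.
void run_dropout_res_bias_fp32(float *out, const float *vals, uint8_t *mask,
                               const float *bias, const float *residual,
                               cudaStream_t stream) {
  const int batch_size = 8, seq_len = 512, hidden_size = 1024;  // assumed shapes
  const int total_count = batch_size * seq_len * hidden_size;   // flattened element count
  // Inside the launcher: 1024 threads/block, 4 floats/thread, so grid_dim = total_count >> 12 plus a tail block.
  launch_ls_dropout_res_bias<float>(out, vals, mask, bias, residual,
                                    total_count, hidden_size, 0.1f, stream);
}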
-
-/**
- * @brief fused bias and dropout backward at the end of Attention and FFN
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, float *__restrict__ in_grad,
- float *__restrict__ bias_grad, const float *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- // every block generate 8 bias result
- __shared__ float tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- float val = out_grad[idx];
- val *= scale * static_cast<float>(mask[idx]);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- float sum = 0;
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-__global__ void ls_dropout_bias_bwd_kernel(
- const int row_size, const float ratio, __half *__restrict__ in_grad,
- __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
- __shared__ __half2 tile[8][129];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
- const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
- __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
- int stride = hidden_size * 128;
- __half2 local_sum = __float2half2_rn(0.f);
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- for (int r = threadIdx.y; r < row_size; r += 128) {
- __half2 val = out_grad2[idx];
- __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
- val *= scale * m2;
- local_sum += val;
- in_grad2[idx] = val;
- idx += stride;
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
-
- __half2 sum = __float2half2_rn(0.f);
- int tid = threadIdx.y * blockDim.x + threadIdx.x;
- int x = tid >> 7;
- int y = tid & (127);
- if (y < 32) {
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- sum += tile[x][y + i * 32];
- }
- }
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (y == 0) tile[0][x] = sum;
- __syncthreads();
-
- if (threadIdx.x < 8) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
- bias_grad2[pos] = tile[0][threadIdx.x];
- }
-}
-
-template <typename T>
-void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template <>
-void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
- const __half *out_grad, const uint8_t *mask,
- int row_size, int dim, float ratio,
- cudaStream_t stream) {
- dim >>= 1;
- dim3 grid_dim((dim - 1) / 8 + 1);
- dim3 block_dim(8, 128);
- ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
-}
-
-template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
- const float *out_grad,
- const uint8_t *mask, int row_size,
- int dim, float ratio,
- cudaStream_t stream);
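As a scalar reference for what the fused backward above computes (a sketch, not part of the deleted source): in_grad is the rescaled, masked output gradient, and bias_grad is its column-wise sum over all rows.

// Naive CPU reference for ls_dropout_bias_bwd (float only, illustrative).
#include <cstdint>
void dropout_bias_bwd_ref(float *in_grad, float *bias_grad, const float *out_grad,
                          const uint8_t *mask, int row_size, int dim, float ratio) {
  const float scale = 1.f / (1.f - ratio);
  for (int c = 0; c < dim; ++c) bias_grad[c] = 0.f;
  for (int r = 0; r < row_size; ++r) {
    for (int c = 0; c < dim; ++c) {
      const int idx = r * dim + c;
      const float val = out_grad[idx] * scale * static_cast<float>(mask[idx]);
      in_grad[idx] = val;   // gradient w.r.t. the pre-dropout activation
      bias_grad[c] += val;  // bias gradient is the per-column sum
    }
  }
}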
-
-/**
- * @brief fused bias, activation, and dropout at the end of first ffn
- *
- * @thread
- * gridDim.x = hidden_size / 8
- * blockDim.x = 8
- * blockDim.y = 1024 / 8 = 128
- *
- * @tparam act_type activation function, like kRelu, kGelu
- * @param total_count total elements
- * @param ratio drop ratio
- * @param out [batch_size, seq_len, hidden_size], float and __half
- * @param in [batch_size, seq_len, hidden_size], float and __half
- * @param mask [batch_size, seq_len, hidden_size], uint8 type
- * @param bias [hidden_size], ffn bias
- * @param seed seed to curand
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, float *__restrict__ out,
- const float *__restrict__ in, uint8_t *__restrict__ mask,
- const float *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 4 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
- uint8_t m[4];
-
- float4 *out4 = reinterpret_cast<float4 *>(out);
- const float4 *data4 = reinterpret_cast<const float4 *>(in);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
- float4 rand = curand_uniform4(&state);
-
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
-
- int bias_i = i % (hidden_size >> 2);
- uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
- mask4[i] = m4[0];
- const float4 input4 = data4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 output4;
-
- output4.x =
- activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
- output4.y =
- activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
- output4.z =
- activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
- output4.w =
- activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];
-
- out4[i] = output4;
-}
-
-template <ActivationType act_type>
-__global__ void ls_dropout_act_bias_kernel(
- const int total_count, const float ratio, __half *__restrict__ out,
- const __half *__restrict__ in, uint8_t *__restrict__ mask,
- const __half *__restrict__ bias, const int seed, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
-
- int i = blockIdx.x * blockDim.x + threadIdx.x;
-
- if (i * 8 >= total_count) return;
-
- curandStatePhilox4_32_10_t state;
- curand_init(seed, i, 0, &state);
-
- const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
- float4 *outs_float4 = reinterpret_cast<float4 *>(out);
- const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
- uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);
-
- uint8_t m[8];
- float4 rand = curand_uniform4(&state);
- m[0] = (uint8_t)(rand.x > ratio);
- m[1] = (uint8_t)(rand.y > ratio);
- m[2] = (uint8_t)(rand.z > ratio);
- m[3] = (uint8_t)(rand.w > ratio);
- rand = curand_uniform4(&state);
- m[4] = (uint8_t)(rand.x > ratio);
- m[5] = (uint8_t)(rand.y > ratio);
- m[6] = (uint8_t)(rand.z > ratio);
- m[7] = (uint8_t)(rand.w > ratio);
- uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
- mask8[i] = *m8;
-
- int bias_i = i % (hidden_size >> 3);
- float4 val_float4 = vals_float4[i];
- const float4 b4 = __ldg(&bias4[bias_i]);
- float4 out_float4;
-
- __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
- __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
- const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
-
- __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
- __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
- __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
- __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
- out_half2[0] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
- scale_mask_1);
- out_half2[1] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
- scale_mask_2);
- out_half2[2] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
- scale_mask_3);
- out_half2[3] = __hmul2(
- activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
- scale_mask_4);
- outs_float4[i] = out_float4;
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kGelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
- float *out, const float *vals, uint8_t *mask, const float *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 10;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
-
-template <>
-void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
- __half *out, const __half *vals, uint8_t *mask, const __half *bias,
- int total_count, int dim, float ratio, cudaStream_t stream) {
- int grid_dim = total_count >> 11;
- ls_dropout_act_bias_kernel<ActivationType::kRelu>
- <<<grid_dim + 1, 256, 0, stream>>>(
- total_count, ratio, out, vals, mask, bias,
- std::chrono::duration_cast<std::chrono::microseconds>(
- std::chrono::system_clock::now().time_since_epoch())
- .count(),
- dim);
-}
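The four launchers above all reduce to the same elementwise recipe; here is a scalar reference with kRelu, written as a sketch only (the std::mt19937 RNG and the per-element mask draw are simplifying assumptions, not the curand path used by the kernels):

// Naive CPU reference for ls_dropout_act_bias with kRelu (float only, illustrative).
#include <cstdint>
#include <random>
void dropout_relu_bias_ref(float *out, const float *in, uint8_t *mask, const float *bias,
                           int total_count, int dim, float ratio, unsigned seed) {
  const float scale = 1.f / (1.f - ratio);
  std::mt19937 gen(seed);
  std::uniform_real_distribution<float> uni(0.f, 1.f);
  for (int i = 0; i < total_count; ++i) {
    float v = in[i] + bias[i % dim];     // add the ffn bias
    v = v > 0.f ? v : 0.f;               // kRelu activation
    mask[i] = uni(gen) > ratio ? 1 : 0;  // keep-mask, same convention as the kernels
    out[i] = v * scale * mask[i];        // inverted dropout
  }
}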
-
-/**
- * @brief fused bias, activation, and dropout backward
- *
- * @thread
- * gridDim.x = total_count / 1024
- * blockDim.x = 1024
- *
- * @tparam act_type kRelu
- * @param row_size batch_size * seq_len
- * @param ratio dropout ratio
- * @param in_grad [batch_size, seq_len, hidden_size], input grad
- * @param bias_grad [hidden_size], bias grad
- * @param out_grad [batch_size, seq_len, hidden_size], output grad
- * @param mask [batch_size, seq_len, hidden_size], dropout mask
- * @param hidden_size
- * @return void
- */
-template <ActivationType act_type, typename T>
-__global__ void ls_dropout_act_bias_bwd_kernel(
- const int row_size, const float ratio, T *in_grad,
- T *__restrict__ bias_grad, const T *__restrict__ input,
- const T *__restrict__ bias, const T *out_grad,
- const uint8_t *__restrict__ mask, const int hidden_size) {
- const float scale = 1.f / (1.f - ratio);
- __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
- int stride = hidden_size * WARP_SIZE;
- float local_sum = 0;
-
- int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
- if (col_idx < hidden_size) {
- for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
- float val = out_grad[idx];
- float in = input[idx];
- float b = bias[idx % hidden_size];
- val = activation_bwd_kernel<act_type, float>(
- val * scale * static_cast<float>(mask[idx]), in + b);
- local_sum += val;
- in_grad[idx] = val;
- idx += stride;
- }
- }
-
- tile[threadIdx.x][threadIdx.y] = local_sum;
- __syncthreads();
- float sum = tile[threadIdx.y][threadIdx.x];
- __syncthreads();
-
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
- __syncthreads();
-
- if (threadIdx.y == 0) {
- int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
- bias_grad[pos] = tile[0][threadIdx.x];
- }
-}
-
-// @brief fused bias, activation, and dropout backward
-// It is deprecated for precision reason. Keep it for future optimization.
-//
-// template <ActivationType act_type>
-// __global__ void ls_dropout_act_bias_bwd_kernel(
-// const int row_size, const float ratio, __half * in_grad,
-// __half *__restrict__ bias_grad, const __half *__restrict__ input, const
-// __half *__restrict__ bias, const __half * out_grad, const uint8_t
-// *__restrict__ mask, const int hidden_size) {
-// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
-// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];
-
-// cg::thread_block b = cg::this_thread_block();
-// cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
-// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
-// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
-// const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
-// const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
-// const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);
-
-// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-
-// int stride = hidden_size * WARP_SIZE;
-// __half2 local_sum = __float2half2_rn(0.f);
-
-// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
-// if (col_idx < hidden_size) {
-// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
-// __half2 val = out_grad2[idx];
-// __half2 in2 = input2[idx];
-// __half2 b2 = bias2[idx % hidden_size ];
-// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
-// val = activation_bwd_kernel<act_type, __half2>(val * scale
-// *
-// m2,
-// in2+b2);
-// local_sum += val;
-// in_grad2[idx] = val;
-// idx += stride;
-// }
-// }
-
-// tile[threadIdx.x][threadIdx.y] = local_sum;
-// __syncthreads();
-// __half2 sum = tile[threadIdx.y][threadIdx.x];
-// __syncthreads();
-
-// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
-// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
-// __syncthreads();
-
-// if (threadIdx.y == 0) {
-// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
-// bias_grad2[pos] = tile[0][threadIdx.x];
-// }
-// }
-
-template <ActivationType act_type, typename T>
-void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
- const T *bias, const T *out_grad,
- const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream) {
- dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
- dim3 block_dim(WARP_SIZE, WARP_SIZE);
- ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
- row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
-}
-
-// template <>
-// void launch_ls_dropout_act_bias_bwd(
-// __half *in_grad, __half *bias_grad,const __half *input, const __half
-// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int
-// dim, float ratio, cudaStream_t stream) {
-// dim >>= 1;
-// dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
-// dim3 block_dim(WARP_SIZE, WARP_SIZE);
-// ls_dropout_act_bias_bwd_kernel
-// <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
-// bias_grad,
-// input, bias,out_grad, mask, dim);
-// }
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
- float *in_grad, float *bias_grad, const float *input, const float *bias,
- const float *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
-
-template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
- __half *in_grad, __half *bias_grad, const __half *input, const __half *bias,
- const __half *out_grad, const uint8_t *mask, int row_size, int dim,
- float ratio, cudaStream_t stream);
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu
deleted file mode 100644
index 625b02cd25d9..000000000000
--- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu
+++ /dev/null
@@ -1,232 +0,0 @@
-#include <cooperative_groups.h>
-
-#include "kernels.h"
-
-namespace cg = cooperative_groups;
-
-/**
-@brief: fuse_transpose_bias
-Calculate the sum of elements in each column of the matrix.
-
-@thread
-gridDim.x = ceil(cols / WARP_SIZE)
-blockDim.x = WARP_SIZE
-blockDim.y = WARP_SIZE
-
-@param
-inp: [rows, cols]
-out: [cols]
-rows: the number of rows in the matrix
-cols: the number of cols in the matrix
-*/
-template <typename T>
-__global__ void column_sum_reduce(const T *__restrict__ inp,
- T *__restrict__ out, int rows, int cols) {
- __shared__ float tile[WARP_SIZE][WARP_SIZE];
-
- cg::thread_block b = cg::this_thread_block();
- cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-
- int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
- int y_stride = cols * WARP_SIZE;
- float localSum = 0;
-
- // Loop across matrix row
- // TODO: optimize to log complexity
- if (idx < cols) {
- int offset = flat_2dim(threadIdx.y, idx, cols);
- for (int r = threadIdx.y; r < rows; r += WARP_SIZE) {
- localSum += (float)inp[offset];
- offset += y_stride;
- }
- }
-
- // The sum of a row in tile is equal to the sum of a col in original matrix
- tile[threadIdx.x][threadIdx.y] = localSum;
-
- __syncthreads();
-
- // Sum the shared buffer.
- // The change of threadIdx.x is continuous
- float sum = tile[threadIdx.y][threadIdx.x];
-
- __syncthreads();
-
- // Calculate the sum of a row in tile
- for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
-
- if (threadIdx.x == 0) {
- int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE);
- if (pos < cols) out[pos] = sum;
- }
-}
-
-// [r, c] -> [c]
-template <>
-void launch_fuse_transpose_bias_kernel<float>(const float *inp, float *out,
- int rows, int cols,
- cudaStream_t stream) {
- dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
- dim3 block_dim(WARP_SIZE, WARP_SIZE);
-
- column_sum_reduce<float>
- <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
-}
-
-template <>
-void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out,
- int rows, int cols,
- cudaStream_t stream) {
- dim3 grid_dim((cols - 1) / WARP_SIZE + 1);
- dim3 block_dim(WARP_SIZE, WARP_SIZE);
-
- column_sum_reduce<__half>
- <<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
-}
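For clarity, the warp-tiled reduction above is equivalent to this scalar column sum (a sketch, not taken from the file):

// Naive CPU reference for column_sum_reduce: out[c] = sum over rows of inp[r * cols + c].
void column_sum_ref(const float *inp, float *out, int rows, int cols) {
  for (int c = 0; c < cols; ++c) {
    float acc = 0.f;  // the kernel also accumulates in float, even for __half inputs
    for (int r = 0; r < rows; ++r) acc += inp[r * cols + c];
    out[c] = acc;
  }
}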
-
-/**
-@brief: fused_add2
-Add two matrix inp1 and inp2 to out.
-
-@thread
-gridDim.x = batch_size * seq_len
-blockDim.x = min(hidden_dim, MAX_THREADS)
-
-@param
-inp1: [batch_size, seq_len, hidden_dim]
-inp2: [batch_size, seq_len, hidden_dim]
-out: [batch_size, seq_len, hidden_dim]
-batch_size: the size of the current batch
-seq_len: the sequence length of the current batch
-hidden_dim: dim of the hidden tensor
-*/
-template <typename T>
-__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
- int hidden_dim);
-
-template <>
-__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
- const float *inp2, int hidden_dim) {
- int row_id = blockIdx.x;
- int offset = flat_2dim(row_id, 0, hidden_dim);
-
- const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
- const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
- float4 *out_4 = reinterpret_cast<float4 *>(out);
- float4 vinp1;
- float4 vinp2;
- float4 val;
-
- for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
- vinp1 = inp1_4[offset + i];
- vinp2 = inp2_4[offset + i];
- val.x = vinp1.x + vinp2.x;
- val.y = vinp1.y + vinp2.y;
- val.z = vinp1.z + vinp2.z;
- val.w = vinp1.w + vinp2.w;
- out_4[offset + i] = val;
- }
-}
-
-template <>
-__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
- const __half *inp2, int hidden_dim) {
- int row_id = blockIdx.x;
- int offset = flat_2dim(row_id, 0, hidden_dim);
-
- const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
- const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
- float4 *out_4 = reinterpret_cast