Improve code and add support for InternLM2
xusenlin committed Jan 19, 2024
1 parent 389aed2 commit 5b45f34
Showing 27 changed files with 591 additions and 686 deletions.
8 changes: 4 additions & 4 deletions .env.example
@@ -14,14 +14,14 @@ STREAM_INTERVERL=2
PROMPT_NAME=

# device related
DEVICE=cuda
DEVICE_MAP=
DEVICE=

# "auto", "cuda:0", "cuda:1", ...
DEVICE_MAP=auto
GPUS=
NUM_GPUs=1
DTYPE=half

# patch related
PATCH_TYPE=

# api related
API_PREFIX=/v1
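
The device block now prefers `DEVICE_MAP` (e.g. `auto`, `cuda:0`) over the old single `DEVICE=cuda` switch, together with an explicit `DTYPE`. Below is a minimal sketch of how these variables might be read and turned into loader arguments; it assumes `python-dotenv` for parsing, and the helper name `build_load_kwargs` is hypothetical rather than part of this commit.

```python
# Minimal sketch (not the project's actual settings module) of how the new
# DEVICE_MAP / DTYPE / NUM_GPUs variables could feed the loader.
import os

from dotenv import load_dotenv  # pip install python-dotenv


def build_load_kwargs() -> dict:
    load_dotenv()  # read .env into the process environment
    device_map = os.getenv("DEVICE_MAP") or None  # e.g. "auto", "cuda:0"
    dtype = os.getenv("DTYPE", "half")            # "half", "bfloat16", "float32"
    num_gpus = int(os.getenv("NUM_GPUs", "1"))
    return {"device_map": device_map, "dtype": dtype, "num_gpus": num_gpus}


if __name__ == "__main__":
    print(build_load_kwargs())
```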
4 changes: 4 additions & 0 deletions README.md
@@ -20,6 +20,9 @@

## 📢 News

+ 【2024.01.19】 Added support for the [InternLM2](https://github.com/InternLM/InternLM) model; see the [launch guide](https://github.com/xusenlinzy/api-for-open-llm/blob/master/docs/SCRIPT.md#internlm2)


+ 【2023.12.21】 Added request forwarding for the [TGI](https://github.com/huggingface/text-generation-inference) generation API and the [TEI](https://github.com/huggingface/text-embeddings-inference) embedding API


@@ -113,6 +116,7 @@
| [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B) | Qwen | 7B | en, zh | [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) |
| [baichuan-13b-chat](https://github.com/baichuan-inc/Baichuan-13B) | Baichuan | 13B | en, zh | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) |
| [InternLM](https://github.com/InternLM/InternLM) | InternLM | 7B | en, zh | [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) |
| [InternLM2](https://github.com/InternLM/InternLM) | InternLM2 | 20B | en, zh | [internlm/internlm2-chat-20b](https://huggingface.co/internlm/internlm2-chat-20b) |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | GLM | 6/130B | en, zh | [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) |
| [baichaun-7b](https://github.com/baichuan-inc/baichuan-7B) | Baichuan | 7B | en, zh | [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) |
| [Guanaco](https://github.com/artidoro/qlora/tree/main) | LLaMA | 7/33/65B | en | [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) |
113 changes: 113 additions & 0 deletions api/adapter/loader.py
@@ -0,0 +1,113 @@
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Optional,
    Tuple,
    Any,
)

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)

from .patcher import (
    patch_config,
    patch_tokenizer,
    patch_model,
)

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


def _load_model_and_tokenizer(
    model_name_or_path: str,
    use_fast_tokenizer: Optional[bool] = False,
    dtype: Optional[str] = None,
    device_map: Optional[Any] = None,
    load_in_8bit: Optional[bool] = False,
    load_in_4bit: Optional[bool] = False,
    rope_scaling: Optional[str] = None,
    flash_attn: Optional[bool] = False,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    r"""
    Loads a pretrained model and tokenizer for inference.
    """
    config_kwargs = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=use_fast_tokenizer,
        trust_remote_code=True,
    )
    patch_tokenizer(tokenizer)

    config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
    patch_config(
        config,
        config_kwargs,
        dtype,
        rope_scaling=rope_scaling,
        flash_attn=flash_attn,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
    )

    if device_map:
        config_kwargs["device_map"] = device_map

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    patch_model(model)
    model.eval()

    return model, tokenizer


def load_model_and_tokenizer(
    model_name: str,
    model_name_or_path: str,
    use_fast_tokenizer: Optional[bool] = False,
    dtype: Optional[str] = None,
    device_map: Optional[Any] = None,
    load_in_8bit: Optional[bool] = False,
    load_in_4bit: Optional[bool] = False,
    rope_scaling: Optional[str] = None,
    flash_attn: Optional[bool] = False,
    **kwargs,
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    try:
        model, tokenizer = _load_model_and_tokenizer(
            model_name_or_path,
            use_fast_tokenizer,
            dtype,
            device_map,
            load_in_8bit,
            load_in_4bit,
            rope_scaling,
            flash_attn,
        )
    except Exception:
        # Fall back to the legacy per-model adapters if the generic loader fails.
        from .model import load_model_and_tokenizer_old

        model, tokenizer = load_model_and_tokenizer_old(
            model_name,
            model_name_or_path,
            dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            device_map=device_map,
            **kwargs,
        )

    return model, tokenizer
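
For reference, a hedged usage sketch of the new unified loader; the checkpoint, dtype, and prompt are placeholders chosen for illustration.

```python
# Illustrative only: checkpoint, dtype, and prompt are placeholders.
from api.adapter.loader import load_model_and_tokenizer

model, tokenizer = load_model_and_tokenizer(
    model_name="internlm2",
    model_name_or_path="internlm/internlm2-chat-20b",
    dtype="bfloat16",
    device_map="auto",   # dispatched across available GPUs by accelerate
    load_in_4bit=False,
    flash_attn=False,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```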
4 changes: 3 additions & 1 deletion api/adapter/model.py
@@ -1,3 +1,5 @@
""" this file is overdated and will be used """

import os
import sys
from typing import List, Optional, Any, Dict, Tuple
@@ -277,7 +279,7 @@ def get_model_adapter(model_name: str) -> BaseModelAdapter:
    raise ValueError(f"No valid model adapter for {model_name}")


def load_model(
def load_model_and_tokenizer_old(
    model_name: str,
    model_name_or_path: Optional[str] = None,
    adapter_model: Optional[str] = None,
182 changes: 182 additions & 0 deletions api/adapter/patcher.py
@@ -0,0 +1,182 @@
""" from https://github.com/hiyouga/LLaMA-Factory/blob/main/src/llmtuner/model/patcher.py """
from __future__ import annotations

import importlib.metadata
import importlib.util
import os
from types import MethodType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
)

import torch
from loguru import logger
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
BitsAndBytesConfig,
)
from transformers.utils import (
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
from transformers.utils.versions import require_version

if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer


_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
except:
_is_bf16_available = False


def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None


def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"


def is_flash_attn2_available():
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32


def _configure_rope(config: "PretrainedConfig", rope_scaling: str = None) -> None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return

scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": rope_scaling, "factor": scaling_factor})
logger.info(f"Using {rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}.")


def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
if not is_flash_attn2_available():
logger.warning("FlashAttention2 is not installed.")
return

config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster and inference.")


def _configure_quantization(
config_kwargs: Dict[str, Any],
load_in_8bits: bool = False,
load_in_4bits: bool = False,
) -> None:

if load_in_8bits:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
logger.info("Quantizing model to 8 bit.")

elif load_in_4bits:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=config_kwargs.get("torch_dtype", torch.float16),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
logger.info("Quantizing model to 4 bit.")

if load_in_8bits or load_in_4bits:
config_kwargs["device_map"] = {"": get_current_device()}
else:
config_kwargs["device_map"] = get_current_device()


def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info(f"Add eos token: {tokenizer.eos_token}")

if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token = tokenizer.unk_token
else:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Add pad token: {tokenizer.pad_token}")


def patch_config(
config: "PretrainedConfig",
config_kwargs: Dict[str, Any],
compute_dtype: Optional[str] = None,
**kwargs,
):
if compute_dtype is None: # priority: bf16 > fp16 > fp32
compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
else:
_DTYPE_MAP = {
"half": torch.float16,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
compute_dtype = _DTYPE_MAP.get(compute_dtype, torch.float16)

config_kwargs["torch_dtype"] = compute_dtype

if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, compute_dtype == dtype)

rope_scaling = kwargs.get("rope_scaling", None)
if rope_scaling is not None:
_configure_rope(config, rope_scaling)

if kwargs.get("flash_attn", False):
_configure_flashattn(config_kwargs)

_configure_quantization(
config_kwargs,
kwargs.get("load_in_8bit", False),
kwargs.get("load_in_4bit", False),
)


def patch_model(model: "PreTrainedModel") -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)


def get_current_device() -> torch.device:
r"""
Gets the current available device.
"""
if is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else:
device = "cpu"

return torch.device(device)
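
The patcher helpers can also be exercised on their own. Below is a rough sketch of what `patch_config` contributes to the keyword arguments that are later forwarded to `from_pretrained`; the checkpoint path is only an example, and the 4-bit case assumes `bitsandbytes` is installed.

```python
# Standalone illustration of patch_config's effect on config_kwargs.
from transformers import AutoConfig

from api.adapter.patcher import patch_config

config_kwargs = {"trust_remote_code": True}
config = AutoConfig.from_pretrained("internlm/internlm2-chat-20b", **config_kwargs)

patch_config(config, config_kwargs, "half", load_in_4bit=True)

# Expect roughly:
#   config_kwargs["torch_dtype"]         -> torch.float16 ("half")
#   config_kwargs["quantization_config"] -> BitsAndBytesConfig(load_in_4bit=True, ...)
#   config_kwargs["device_map"]          -> {"": <current device>} for quantized loads
print({k: v for k, v in config_kwargs.items() if k != "trust_remote_code"})
```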