Merge pull request #50 from huggingface/tgi-support
TGI support
IlyasMoutawwakil authored Sep 8, 2023
2 parents 119e9f8 + 0578183 commit b805d49
Showing 12 changed files with 369 additions and 44 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ cython_debug/
.vscode/
*.ipynb
runs/
sweeps/
sweeps/
data/
version.txt
34 changes: 34 additions & 0 deletions examples/text_generation_inference_llama.yaml
@@ -0,0 +1,34 @@
defaults:
- backend: text-generation-inference # default backend
- benchmark: inference # default benchmark
- experiment # inheriting experiment schema
- _self_ # for hydra 1.1 compatibility
- override hydra/job_logging: colorlog # colorful logging
- override hydra/hydra_logging: colorlog # colorful logging

hydra:
run:
dir: runs/${experiment_name}
sweep:
dir: sweeps/${experiment_name}
job:
chdir: true
env_set:
CUDA_VISIBLE_DEVICES: 0,1

experiment_name: text_generation_inference
model: NousResearch/Llama-2-7b-hf
device: cuda

backend:
no_weights: true
initial_isolation_check: false
continous_isolation_check: false
torch_dtype: float16

benchmark:
input_shapes:
batch_size: 32
sequence_length: 128

new_tokens: 1000
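
Assuming the package's standard Hydra-based CLI entry point, an experiment using this config could be launched with something like: optimum-benchmark --config-dir examples --config-name text_generation_inference_llama. Any field above can then be overridden from the command line, e.g. benchmark.input_shapes.batch_size=1.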
21 changes: 16 additions & 5 deletions optimum_benchmark/backends/base.py
@@ -154,6 +154,17 @@ def seed(self) -> None:
# torch.backends.cudnn.deterministic = True # same as above
# torch.backends.cudnn.benchmark = False # might reduce performance

def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
if self.is_diffusion_pipeline():
# diffusion pipelines expect a list of strings as input
return input
else:
# models expect tensors on the target device as input
for key, value in input.items():
input[key] = value.to(self.device)

return input

# compiling in openvino requires input shapes
def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
pass
@@ -162,10 +173,10 @@ def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
pass

def forward(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
return self.pretrained_model(**input, **kwargs)

def generate(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
return self.pretrained_model.generate(**input, **kwargs)

def train(
@@ -193,14 +204,14 @@ def model_shapes(self) -> Dict[str, int]:

def delete_pretrained_model(self) -> None:
if hasattr(self, "pretrained_model"):
LOGGER.info("\t+ Deleting pretrained model")
del self.pretrained_model

gc.collect()

def delete_model_cache(self) -> None:
LOGGER.info("\t+ Deleting model cache")
model_cache_path = f"models/{self.model}".replace("/", "--")
model_cache_path = os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), model_cache_path)
model_cache_folder = f"models/{self.model}".replace("/", "--")
model_cache_path = os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), model_cache_folder)
shutil.rmtree(model_cache_path, ignore_errors=True)

def clean(self) -> None:
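
Note: with this change, forward and generate receive their extra kwargs as an explicit dictionary instead of unpacked keyword arguments. A minimal sketch of the new calling convention (backend stands for any configured Backend instance; the kwarg values are illustrative):

outputs = backend.forward(inputs, {})
generations = backend.generate(inputs, {"max_new_tokens": 1000, "do_sample": False})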
17 changes: 8 additions & 9 deletions optimum_benchmark/backends/pytorch/backend.py
Expand Up @@ -21,7 +21,6 @@
LOGGER = getLogger("pytorch")



class PyTorchBackend(Backend[PyTorchConfig]):
NAME: str = "pytorch"

@@ -223,21 +222,21 @@ def prepare_for_profiling(self, input_names: List[str]) -> None:
LOGGER.info("\t+ Wrapping model with FXProfilingWrapper")
self.pretrained_model = FXProfilingWrapper(self.pretrained_model)

def forward(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
if self.is_diffusion_pipeline():
return super().forward(input, **kwargs)
return super().forward(input, kwargs)
else:
# TODO: autocast as whole can be managed by one config/kwargs
# TODO: autocast as a whole could be managed by a single config/kwargs entry?
with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.config.amp_autocast):
return super().forward(input, **kwargs)
return super().forward(input, kwargs)

def generate(self, input: Dict[str, torch.Tensor], **kwargs) -> "ModelOutput":
def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
if self.is_diffusion_pipeline():
return super().generate(input, **kwargs)
return super().generate(input, kwargs)
else:
# TODO: autocast as whole can be managed by one config/kwargs
# TODO: autocast as a whole could be managed by a single config/kwargs entry?
with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.config.amp_autocast):
return super().generate(input, **kwargs)
return super().generate(input, kwargs)

@record_if_available
def train(
209 changes: 209 additions & 0 deletions optimum_benchmark/backends/text_generation_inference/backend.py
@@ -0,0 +1,209 @@
import os
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, List

import torch
from accelerate import init_empty_weights
from huggingface_hub import InferenceClient
from transformers import GenerationConfig

if TYPE_CHECKING:
from huggingface_hub.inference._text_generation import TextGenerationResponse

import docker
import docker.errors
import docker.types

from ..base import Backend
from ..pytorch.utils import randomize_weights
from .config import TGIConfig

# backend logger
LOGGER = getLogger("text-generation-inference")


class TGIBackend(Backend[TGIConfig]):
NAME: str = "text-generation-inference"

def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]):
super().__init__(model, task, device, hub_kwargs)
self.validate_task()

automodel = self.automodel_class.__name__
LOGGER.info(f"\t+ Infered AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")

def validate_task(self) -> None:
if self.task not in ["text-generation", "text2text-generation"]:
raise NotImplementedError(f"TGI does not support task {self.task}")

def configure(self, config: TGIConfig) -> None:
super().configure(config)
self.config = config

if self.config.no_weights:
# creates dummy model
self.load_model_from_config()
self.save_model_snapshot()
else:
self.load_model_from_pretrained()
self.delete_pretrained_model()

LOGGER.info("\t+ Modifying generation config")
self.modify_generation_config()

LOGGER.info("\t+ Starting Docker client")
self.docker_client = docker.from_env()

try:
LOGGER.info("\t+ Checking if TGI image exists")
self.docker_client.images.get(f"{self.config.image}:{self.config.version}")
except docker.errors.ImageNotFound:
LOGGER.info("\t+ TGI image not found, pulling it")
self.docker_client.images.pull(f"{self.config.image}:{self.config.version}")

LOGGER.info("\t+ Building TGI command")
self.command = [
"--model-id",
self.model,
"--revision",
self.hub_kwargs["revision"],
]

if self.config.quantization is not None:
self.command.extend(["--quantize", self.config.quantization])
if self.config.torch_dtype is not None:
self.command.extend(["--torch-dtype", self.config.torch_dtype])
if self.hub_kwargs.get("trust_remote_code", False):
self.command.append("--trust-remote-code")
if self.config.disable_custom_kernels:
self.command.append("--disable-custom-kernels")

if self.device.type == "cuda":
device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", self.device.index or 0)
LOGGER.info(f"\t+ Starting TGI container on CUDA device(s): {device_ids}")
device_requests = [docker.types.DeviceRequest(device_ids=[str(device_ids)], capabilities=[["gpu"]])]
else:
LOGGER.info("\t+ Starting TGI container on CPU device")
device_requests = None

self.tgi_container = self.docker_client.containers.run(
image=f"{self.config.image}:{self.config.version}",
command=self.command,
shm_size=self.config.shm_size,
volumes={self.config.volume: {"bind": "/data", "mode": "rw"}},
ports={"80/tcp": (self.config.address, self.config.port)},
device_requests=device_requests,
detach=True,
)

LOGGER.info("\t+ Waiting for TGI server to be ready")
for line in self.tgi_container.logs(stream=True):
tgi_log = line.decode("utf-8").strip()
if not tgi_log:
continue
elif "Connected" in tgi_log:
LOGGER.info("\t+ TGI server is ready")
break
else:
LOGGER.info(f"\t {tgi_log}")

LOGGER.info("\t+ Creating InferenceClient")
self.client = InferenceClient(model=f"http://{self.config.address}:{self.config.port}")

def load_model_from_config(self) -> None:
LOGGER.info("\t+ Initializing empty weights model on device: meta")
with init_empty_weights():
self.pretrained_model = self.automodel_class.from_config(
config=self.pretrained_config,
torch_dtype=getattr(torch, self.config.torch_dtype),
trust_remote_code=self.hub_kwargs.get("trust_remote_code", False),
)
# could add model dispatching to accelerate saving and support bigger models
LOGGER.info(f"\t+ Materializing model on device: {self.device}")
self.pretrained_model.to_empty(device=self.device)
LOGGER.info("\t+ Randomizing model weights")
randomize_weights(self.pretrained_model)
LOGGER.info("\t+ Tying weights")
self.pretrained_model.tie_weights()

@property
def model_snapshot_path(self) -> str:
model_cache_folder = f"models/{self.model}".replace("/", "--")
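# e.g. "NousResearch/Llama-2-7b-hf" resolves to "models--NousResearch--Llama-2-7b-hf" under the mounted volume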
model_cache_path = f"{self.config.volume}/{model_cache_folder}"
snapshot_ref = open(f"{model_cache_path}/refs/{self.hub_kwargs.get('revision', 'main')}", "r").read().strip()
return f"{model_cache_path}/snapshots/{snapshot_ref}"

def save_model_snapshot(self) -> None:
LOGGER.info("\t+ Saving pretrained model snapshot")
self.pretrained_model.save_pretrained(self.model_snapshot_path, safe_serialization=True)

def load_model_from_pretrained(self) -> None:
LOGGER.info("\t+ Downloading pretrained model")
with init_empty_weights():
self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)

def modify_generation_config(self) -> None:
# this should, theoretically, make the generated output's sequence length fully controlled by max_new_tokens
# instead of stopping at the first eos_token_id/pad_token_id
generation_config = GenerationConfig.from_pretrained(self.model, **self.hub_kwargs)
generation_config.eos_token_id = -100
generation_config.pad_token_id = -101
generation_config.save_pretrained(self.model_snapshot_path)

def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
return {"prompt": self.pretrained_processor.batch_decode(input["input_ids"].tolist())}

def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> List["TextGenerationResponse"]:
output = []
with ThreadPoolExecutor(max_workers=len(input["prompt"])) as executor:
futures = [
executor.submit(
self.client.text_generation,
decoder_input_details=True,
prompt=input["prompt"][i],
max_new_tokens=1,
details=True,
)
for i in range(len(input["prompt"]))
]
for future in futures:
output.append(future.result())

return output

def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> List["TextGenerationResponse"]:
output = []
with ThreadPoolExecutor(max_workers=len(input["prompt"])) as executor:
futures = [
executor.submit(
self.client.text_generation,
max_new_tokens=kwargs["max_new_tokens"],
do_sample=kwargs["do_sample"],
prompt=input["prompt"][i],
details=True,
)
for i in range(len(input["prompt"]))
]
for i in range(len(input["prompt"])):
output.append(futures[i].result())
if len(output[-1].details.tokens) < kwargs["max_new_tokens"]:
LOGGER.warning(
f"\t+ Generated {len(output[-1].details.tokens)} tokens instead of {kwargs['max_new_tokens']}"
" tokens. Benchmark results might be inaccurate."
)

return output

def clean(self) -> None:
super().clean()

if hasattr(self, "tgi_container"):
LOGGER.info("\t+ Stoping TGI container")
self.tgi_container.stop()
LOGGER.info("\t+ Waiting for TGI container to stop")
self.tgi_container.wait()

if hasattr(self, "docker_client"):
LOGGER.info("\t+ Closing docker client")
self.docker_client.close()
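
For a quick sanity check, the TGI server launched by this backend can also be queried directly with huggingface_hub's InferenceClient; a minimal sketch, assuming the default address and port from the config below:

from huggingface_hub import InferenceClient

client = InferenceClient(model="http://127.0.0.1:1111")
# details=True returns a TextGenerationResponse with token-level details
response = client.text_generation("Hello, my name is", max_new_tokens=16, details=True)
print(response.generated_text)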
34 changes: 34 additions & 0 deletions optimum_benchmark/backends/text_generation_inference/config.py
@@ -0,0 +1,34 @@
import os
from dataclasses import dataclass
from typing import Optional

from ..config import BackendConfig


@dataclass
class TGIConfig(BackendConfig):
name: str = "tgi"
version: str = "1.0.3"
_target_: str = "optimum_benchmark.backends.text_generation_inference.backend.TGIBackend"

# server options
image: str = "ghcr.io/huggingface/text-generation-inference"
volume: str = f"{os.path.expanduser('~')}/.cache/huggingface/hub"
shm_size: str = "1g"
address: str = "127.0.0.1"
port: int = 1111

# torch options
no_weights: bool = False # True, False
torch_dtype: Optional[str] = None # None, float32, float16, bfloat16
# optimization options
disable_custom_kernels: bool = False # True, False
# quantization options
quantization: Optional[str] = None # None, bitsandbytes-nf4, bitsandbytes-fp4

def __post_init__(self):
super().__post_init__()

if self.torch_dtype is not None:
if self.torch_dtype not in ["float32", "float16", "bfloat16"]:
raise ValueError(f"Invalid value for torh_dtype: {self.torch_dtype}")