From ef2adb04287a0eb2421536204dc4928f18072fc1 Mon Sep 17 00:00:00 2001
From: lvhan028
Date: Sun, 18 Jun 2023 16:02:54 +0800
Subject: [PATCH] add chatbot (#2)

---
 .pre-commit-config.yaml                       |   2 +-
 llmdeploy/__init__.py                         |   1 +
 llmdeploy/serve/__init__.py                   |   1 +
 llmdeploy/serve/client.py                     |  38 ++
 llmdeploy/serve/fastertransformer/__init__.py |   3 +
 llmdeploy/serve/fastertransformer/chatbot.py  | 440 ++++++++++++++++++
 .../fastertransformer/service_docker_up.sh    |   0
 llmdeploy/serve/fastertransformer/utils.py    | 165 +++++++
 llmdeploy/version.py                          |  30 ++
 requirements.txt                              |   5 +
 setup.py                                      |  38 ++
 11 files changed, 722 insertions(+), 1 deletion(-)
 create mode 100644 llmdeploy/__init__.py
 create mode 100644 llmdeploy/serve/__init__.py
 create mode 100644 llmdeploy/serve/client.py
 create mode 100644 llmdeploy/serve/fastertransformer/__init__.py
 create mode 100644 llmdeploy/serve/fastertransformer/chatbot.py
 create mode 100644 llmdeploy/serve/fastertransformer/service_docker_up.sh
 create mode 100644 llmdeploy/serve/fastertransformer/utils.py
 create mode 100644 llmdeploy/version.py
 create mode 100644 requirements.txt
 create mode 100644 setup.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 48ab2fb453..b171aed7d2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,4 +50,4 @@ repos:
     rev: v0.2.0
     hooks:
       - id: check-copyright
-        args: []
+        args: ["llmdeploy"]
diff --git a/llmdeploy/__init__.py b/llmdeploy/__init__.py
new file mode 100644
index 0000000000..ef101fec61
--- /dev/null
+++ b/llmdeploy/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/llmdeploy/serve/__init__.py b/llmdeploy/serve/__init__.py
new file mode 100644
index 0000000000..ef101fec61
--- /dev/null
+++ b/llmdeploy/serve/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/llmdeploy/serve/client.py b/llmdeploy/serve/client.py
new file mode 100644
index 0000000000..017e461336
--- /dev/null
+++ b/llmdeploy/serve/client.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+import fire
+
+from llmdeploy.serve.fastertransformer.chatbot import Chatbot
+
+
+def input_prompt():
+    print('\ndouble enter to end input >>> ', end='')
+    sentinel = ''  # ends when this string is seen
+    return '\n'.join(iter(input, sentinel))
+
+
+def main(triton_server_addr: str, model_name: str, session_id: int):
+    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
+    chatbot = Chatbot(triton_server_addr,
+                      model_name,
+                      log_level=log_level,
+                      display=True)
+    nth_round = 1
+    while True:
+        prompt = input_prompt()
+        if prompt == 'exit':
+            exit(0)
+        elif prompt == 'end':
+            chatbot.end(session_id)
+        else:
+            request_id = f'{session_id}-{nth_round}'
+            for status, res, tokens in chatbot.stream_infer(
+                    session_id, prompt, request_id=request_id):
+                continue
+            print(f'session {session_id}, {status}, {tokens}, {res}')
+            nth_round += 1
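
Note: the client above is interactive. For reference, a minimal non-interactive sketch of the same flow is shown below; the server address '0.0.0.0:33337', the model name 'vicuna' and the prompts are placeholder values, and a running `fastertransformer` Triton service is assumed.

# A minimal, non-interactive version of the loop in client.py (sketch only).
from llmdeploy.serve.fastertransformer.chatbot import Chatbot, StatusCode

chatbot = Chatbot('0.0.0.0:33337', 'vicuna', display=True)
session_id = 1

for nth_round, prompt in enumerate(['hello', 'summarize what you just said'],
                                   start=1):
    # exhaust the generator; with display=True the reply streams to stdout
    for status, res, tokens in chatbot.stream_infer(
            session_id, prompt, request_id=f'{session_id}-{nth_round}'):
        continue
    assert status == StatusCode.TRITON_STREAM_END, \
        f'inference failed: {status}'
    print(f'\nround {nth_round} finished with {tokens} generated tokens')

# release the resources the server keeps for this session
chatbot.end(session_id)
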
diff --git a/llmdeploy/serve/fastertransformer/__init__.py b/llmdeploy/serve/fastertransformer/__init__.py
new file mode 100644
index 0000000000..096504a9a2
--- /dev/null
+++ b/llmdeploy/serve/fastertransformer/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from llmdeploy.serve.fastertransformer.chatbot import \
+    Chatbot  # noqa: F401,F403
diff --git a/llmdeploy/serve/fastertransformer/chatbot.py b/llmdeploy/serve/fastertransformer/chatbot.py
new file mode 100644
index 0000000000..457b146624
--- /dev/null
+++ b/llmdeploy/serve/fastertransformer/chatbot.py
@@ -0,0 +1,440 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import logging
+import queue
+import random
+import threading
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
+from typing import List, Union
+
+import google.protobuf.json_format
+import mmengine
+import numpy as np
+import tritonclient.grpc as grpcclient
+from tritonclient.grpc.service_pb2 import ModelInferResponse
+
+from llmdeploy.serve.fastertransformer.utils import (Postprocessor,
+                                                     Preprocessor,
+                                                     prepare_tensor)
+
+
+@dataclass
+class Session:
+    session_id: Union[int, str]
+    request_id: str = ''
+    prev: str = ''  # history of the session in text format
+    round_prev: str = ''  # previous generated text in the current round
+    sequence_length: int = 0  # the total generated token number in the session
+    response: str = ''
+    status: int = None  # status of the session
+
+
+class StatusCode(Enum):
+    TRITON_STREAM_END = 0  # end of streaming
+    TRITON_STREAM_ING = 1  # response is in streaming
+    TRITON_SERVER_ERR = -1  # triton server's error
+    TRITON_SESSION_CLOSED = -2  # session has been closed
+    TRITON_SESSION_OUT_OF_LIMIT = -3  # request length out of limit
+    TRITON_SESSION_INVALID_ARG = -4  # invalid argument
+
+
+def stream_callback(que, result, error):
+    if error:
+        print(error)
+        que.put(dict(errcode=StatusCode.TRITON_SERVER_ERR, errmsg=f'{error}'))
+    else:
+        que.put(result.get_response(as_json=True))
+
+
+def get_logger(log_file=None, log_level=logging.INFO):
+    from .utils import get_logger
+    logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
+    return logger
+
+
+class Chatbot:
+    """Chatbot for LLaMA series models with fastertransformer as inference
+    engine.
+
+    Args:
+        tritonserver_addr (str): communicating address '<ip>:<port>' of
+            triton inference server
+        model_name (str): name of the to-be-deployed model
+        session_len (int): the maximum context length of the model
+        top_p (float): If set to float < 1, only the smallest set of most
+            probable tokens with probabilities that add up to top_p or higher
+            are kept for generation.
+        top_k (int): The number of the highest probability vocabulary tokens to
+            keep for top-k-filtering
+        temperature (float): to modulate the next token probability
+        repetition_penalty (float): The parameter for repetition penalty.
+            1.0 means no penalty
+        stop_words (list): List of token ids that stops the generation
+        bad_words (list): List of token ids that are not allowed to be
+            generated.
+        log_level (int): the level of the log
+        display (bool): display the generated text on console or not
+    """
+
+    def __init__(self,
+                 tritonserver_addr: str,
+                 model_name: str,
+                 session_len: int = 2048,
+                 top_p: float = 1.0,
+                 top_k: int = 40,
+                 temperature: float = 1.0,
+                 repetition_penalty: float = 1.0,
+                 stop_words: List = None,
+                 bad_words: List = None,
+                 log_level: int = logging.INFO,
+                 display: bool = False):
+        self._session = None
+        self.tritonserver_addr = tritonserver_addr
+        self.model_name = model_name
+        if stop_words is not None:
+            stop_words = np.array(stop_words, dtype=np.int32)
+        if bad_words is not None:
+            bad_words = np.array(bad_words, dtype=np.int32)
+
+        self.cfg = mmengine.Config(
+            dict(session_len=session_len,
+                 top_p=top_p,
+                 top_k=top_k,
+                 temperature=temperature,
+                 repetition_penalty=repetition_penalty,
+                 stop_words=stop_words,
+                 bad_words=bad_words))
+        self.preprocess = Preprocessor(tritonserver_addr)
+        self.postprocess = Postprocessor(tritonserver_addr)
+        self.log_level = log_level
+        self.display = display
+
+    def stream_infer(self,
+                     session_id: int,
+                     prompt: str,
+                     request_id: str = '',
+                     request_output_len: int = None,
+                     sequence_start: bool = False,
+                     sequence_end: bool = False,
+                     *args,
+                     **kwargs):
+        """Start a new round of conversation for a session.
+
+        Args:
+            session_id (int): the unique id of a session
+            prompt (str): user's prompt in this round conversation
+            request_id (str): the unique id of this round conversation
+            request_output_len (int): the expected generated token numbers
+            sequence_start (bool): start flag of a session
+            sequence_end (bool): end flag of a session
+        Returns:
+            iterator: The generated content by chatbot
+        """
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+
+        logger = get_logger(log_level=self.log_level)
+        logger.info(f'session {session_id}, request_id {request_id}, '
+                    f'request_output_len {request_output_len}')
+
+        if self._session is None:
+            sequence_start = True
+            self._session = Session(session_id=session_id)
+        elif self._session.status == 0:
+            logger.error(f'session {session_id} has been ended. Please set '
+                         f'`sequence_start` to True if you want to restart it')
+            yield StatusCode.TRITON_SESSION_CLOSED, '', 0
+            return
+
+        self._session.status = 1
+        self._session.request_id = request_id
+        self._session.response = ''
+
+        prompt = self._get_prompt(prompt, sequence_start)
+        for status, res, tokens in self._stream_infer(self._session, prompt,
+                                                      request_output_len,
+                                                      sequence_start,
+                                                      sequence_end):
+            yield status, res, tokens
+        self._session.prev = self._session.prev + self._session.round_prev
+
+    def end(self, session_id: int, *args, **kwargs):
+        """End a session. Triton inference server will release the session's
+        occupied resource when it is ended.
+
+        Args:
+            session_id (int): the unique id of a session
+
+        Returns:
+            StatusCode: the status of the end request
+        """
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+
+        logger = get_logger(log_level=self.log_level)
+        logger.info(f'end session: {session_id}')
+
+        if self._session is None:
+            logger.error(
+                f"session {session_id} doesn't exist. It cannot be ended")
+            return StatusCode.TRITON_SESSION_INVALID_ARG
+        if self._session.session_id != session_id:
+            logger.error(f'you cannot end session {session_id}, because this '
+                         f'session is {self._session.session_id}')
+            return StatusCode.TRITON_SESSION_INVALID_ARG
+        if self._session.status == 0:
+            logger.warning(f'session {session_id} has already been ended')
+            return StatusCode.TRITON_SESSION_CLOSED
+
+        self._session.status = 0
+        for status, _, _ in self._stream_infer(self._session,
+                                               prompt='',
+                                               request_output_len=0,
+                                               sequence_start=False,
+                                               sequence_end=True):
+            if status != StatusCode.TRITON_STREAM_END:
+                return status
+        return StatusCode.TRITON_STREAM_END
+
+    def cancel(self, session_id: int, *args, **kwargs):
+        """Cancel the session during generating tokens.
+
+        Args:
+            session_id (int): the unique id of a session
+
+        Returns:
+            StatusCode: the status of the cancel request
+        """
+        assert isinstance(session_id, int), \
+            f'INT session id is required, but got {type(session_id)}'
+        logger = get_logger(log_level=self.log_level)
+        logger.info(f'cancel session: {session_id}')
+
+        if self._session is None:
+            logger.error(
+                f"session {session_id} doesn't exist. It cannot be cancelled")
+            return StatusCode.TRITON_SESSION_INVALID_ARG
+        if self._session.session_id != session_id:
+            logger.error(
+                f'you cannot cancel session {session_id}, because this '
+                f'session is {self._session.session_id}')
+            return StatusCode.TRITON_SESSION_INVALID_ARG
+        if self._session.status == 0:
+            logger.error(f'session {session_id} has already been ended. '
+                         f'It cannot be cancelled')
+            return StatusCode.TRITON_SESSION_CLOSED
+
+        prev_session = self._session
+        for status, res, _ in self._stream_infer(self._session,
+                                                 prompt='',
+                                                 request_output_len=0,
+                                                 sequence_start=False,
+                                                 sequence_end=False,
+                                                 cancel=True):
+            if status.value < 0:
+                break
+        if status == StatusCode.TRITON_STREAM_END:
+            logger.info(f'cancel session {session_id} successfully')
+            if prev_session.prev:
+                logger.warn(f'TODO: start to recover session {session_id}')
+        else:
+            logger.info(f'cancel session {session_id} failed: {res}')
+        return status
+
+    def _get_prompt(self, prompt: str, sequence_start: bool):
+        if self.model_name == 'vicuna':
+            if sequence_start:
+                return f'USER: {prompt} ASSISTANT:'
+            else:
+                return f'USER: {prompt} ASSISTANT:'
+
+    def _stream_infer(self,
+                      session: Session,
+                      prompt: str,
+                      request_output_len: int = 512,
+                      sequence_start: bool = True,
+                      sequence_end: bool = False,
+                      cancel: bool = False):
+        logger = get_logger(log_level=self.log_level)
+        logger.info(f'session {session.session_id}, '
+                    f'request id {session.request_id}, '
+                    f'request_output_len {request_output_len}, '
+                    f'start {sequence_start}, '
+                    f'end {sequence_end}, cancel {cancel}')
+
+        assert request_output_len is None or \
+            isinstance(request_output_len, int), \
+            f'request_output_len is supposed to be None or int, ' \
+            f'but got {type(request_output_len)}'
+
+        input_ids, input_lengths = self.preprocess(prompt)
+        input_tokens = input_lengths.squeeze()
+
+        if request_output_len is None:
+            request_output_len = max(
+                128,
+                self.cfg.session_len - session.sequence_length - input_tokens)
+
+        if input_tokens + request_output_len + \
+                session.sequence_length > self.cfg.session_len:
+            errmsg = f'session {session.session_id}, ' \
+                     f'out of max sequence length {self.cfg.session_len}, ' \
+                     f'#input tokens {input_tokens}, ' \
+                     f'history tokens {session.sequence_length}, ' \
+                     f'request length {request_output_len}'
+            yield StatusCode.TRITON_SESSION_OUT_OF_LIMIT, errmsg, 0
+        logger.info(f'session {session.session_id}, '
+                    f'input tokens: {input_tokens}, '
+                    f'request tokens: {request_output_len}, '
+                    f'history tokens: {session.sequence_length}')
+
+        preseq_length = session.sequence_length
+        session.round_prev = ''
+
+        que = queue.Queue()
+        producer = threading.Thread(target=self._stream_producer,
+                                    args=(self.tritonserver_addr, session, que,
+                                          self.cfg, input_ids, input_lengths,
+                                          request_output_len, sequence_start,
+                                          sequence_end, preseq_length, cancel))
+        producer.start()
+        for state, res, tokens in self.stream_consumer(self.postprocess, que,
+                                                       session, preseq_length,
+                                                       cancel, logger,
+                                                       self.display):
+            if state.value < 0:
+                yield state, res, 0
+            else:
+                yield state, res, tokens - input_tokens
+        producer.join()
+        self._session = que.get()
+        curseq_length = self._session.sequence_length
+        logger.info(f'session {session.session_id}, pre seq_len '
+                    f'{preseq_length}, cur seq_len {curseq_length}, '
+                    f'diff {curseq_length - preseq_length}')
+
+    @staticmethod
+    def _stream_producer(tritonserver_addr, session, que, cfg, input_ids,
+                         input_lengths, request_output_len, sequence_start,
+                         sequence_end, preseq_length, cancel):
+        request_output_len = np.full(input_lengths.shape,
+                                     request_output_len).astype(np.uint32)
+
+        callback = partial(stream_callback, que)
+        with grpcclient.InferenceServerClient(tritonserver_addr) as client:
+            inputs = [
+                prepare_tensor('input_ids', input_ids),
+                prepare_tensor('input_lengths', input_lengths),
+                prepare_tensor('request_output_len', request_output_len),
+                prepare_tensor('runtime_top_k',
+                               cfg.top_k * np.ones((1, 1), dtype=np.uint32)),
+                prepare_tensor('runtime_top_p',
+                               cfg.top_p * np.ones((1, 1), dtype=np.float32)),
+                prepare_tensor(
+                    'temperature',
+                    cfg.temperature * np.ones((1, 1), dtype=np.float32)),
+                prepare_tensor(
+                    'repetition_penalty',
+                    cfg.repetition_penalty * np.ones(
+                        (1, 1), dtype=np.float32)),
+                prepare_tensor('step',
+                               preseq_length * np.ones((1, 1), dtype=np.int32))
+            ]
+            if cfg.stop_words is not None:
+                inputs += [prepare_tensor('stop_words_list', cfg.stop_words)]
+            if cfg.bad_words is not None:
+                inputs += [prepare_tensor('bad_words_list', cfg.bad_words)]
+
+            inputs += [
+                prepare_tensor(
+                    'session_len',
+                    cfg.session_len *
+                    np.ones([input_ids.shape[0], 1], dtype=np.uint32)),
+                prepare_tensor('START', (1 if sequence_start else 0) * np.ones(
+                    (1, 1), dtype=np.int32)),
+                prepare_tensor('END', (1 if sequence_end else 0) * np.ones(
+                    (1, 1), dtype=np.int32)),
+                prepare_tensor(
+                    'CORRID',
+                    session.session_id * np.ones((1, 1), dtype=np.uint64)),
+                prepare_tensor('STOP', (1 if cancel else 0) * np.ones(
+                    (1, 1), dtype=np.int32))
+            ]
+            if sequence_start:
+                random_seed = random.getrandbits(64)
+                inputs += [
+                    prepare_tensor(
+                        'random_seed',
+                        random_seed * np.ones((1, 1), dtype=np.uint64))
+                ]
+            client.start_stream(callback)
+            client.async_stream_infer('fastertransformer',
+                                      inputs,
+                                      sequence_id=session.session_id,
+                                      request_id=session.request_id,
+                                      sequence_start=sequence_start,
+                                      sequence_end=sequence_end)
+        que.put(None)
+
+    @staticmethod
+    def stream_consumer(postprocess, res_queue, session, preseq_length, cancel,
+                        logger, display):
+
+        def process_response(res):
+            if session.ai_says is None:
+                return res, True
+            index = res.find(session.ai_says)
+            if index == -1:
+                return res, False
+            res = res[index + len(session.ai_says):].replace(session.eoa, '')
+            return res, True
+
+        while True:
+            result = res_queue.get()
+            if result is None:
+                yield StatusCode.TRITON_STREAM_END, session.response, \
+                    session.sequence_length - preseq_length
+                break
+            if 'errcode' in result:
+                logger.error(f'got error from fastertransformer, code '
+                             f"{result['errcode']}, {result['errmsg']}, "
+                             f'token {session.sequence_length}')
+                session.sequence_length = preseq_length
+                yield result['errcode'], result['errmsg'], 0
+                break
+            if cancel:
+                continue
+            try:
+                message = ModelInferResponse()
+                google.protobuf.json_format.Parse(json.dumps(result), message)
+                result = grpcclient.InferResult(message)
+                sequence_length = result.as_numpy('sequence_length')
+                output_ids = result.as_numpy('output_ids')
+
+                session.sequence_length = sequence_length.squeeze()
+                sequence_length = sequence_length - preseq_length
+
+                output_ids = output_ids.reshape((1, 1, output_ids.shape[-1]))
+                sequence_length = sequence_length.reshape(
+                    (1, sequence_length.shape[-1]))
+                output_str = postprocess(output_ids[:, :, preseq_length:],
+                                         sequence_length)
+                text = output_str[0].decode()
+                if display:
+                    new_text = text[len(session.round_prev):]
+                    print(new_text, end='', flush=True)
+                session.round_prev = text
+                yield (StatusCode.TRITON_STREAM_ING, session.response,
+                       sequence_length.squeeze())
+            except Exception as e:
+                logger.error(f'catch exception: {e}')
+
+        # put session back to the queue so that `_stream_infer` can fetch it
+        # and update `self._session` with the latest state
+        while not res_queue.empty():
+            res_queue.get()
+        res_queue.put(session)
+        if display:
+            print('\n')
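
Note: callers of `stream_infer` see the `StatusCode` values defined above as the first element of every yielded tuple. The helper below is a sketch (not part of the patch) showing how they are meant to be interpreted: keep reading on TRITON_STREAM_ING, stop on TRITON_STREAM_END, and treat every negative value as an error.

from llmdeploy.serve.fastertransformer.chatbot import StatusCode


def keep_streaming(status: StatusCode) -> bool:
    """Return True while the caller should keep consuming stream_infer()."""
    if status == StatusCode.TRITON_STREAM_ING:
        return True   # partial result, more chunks will follow
    if status == StatusCode.TRITON_STREAM_END:
        return False  # normal completion of this round
    # TRITON_SERVER_ERR, TRITON_SESSION_CLOSED, TRITON_SESSION_OUT_OF_LIMIT
    # and TRITON_SESSION_INVALID_ARG all carry negative values
    raise RuntimeError(f'generation failed: {status.name} ({status.value})')
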
diff --git a/llmdeploy/serve/fastertransformer/service_docker_up.sh b/llmdeploy/serve/fastertransformer/service_docker_up.sh
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/llmdeploy/serve/fastertransformer/utils.py b/llmdeploy/serve/fastertransformer/utils.py
new file mode 100644
index 0000000000..7cfa57566f
--- /dev/null
+++ b/llmdeploy/serve/fastertransformer/utils.py
@@ -0,0 +1,165 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Union
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import np_to_triton_dtype
+
+logger_initialized = {}
+
+
+def get_logger(name: str,
+               log_file: Optional[str] = None,
+               log_level: int = logging.INFO,
+               file_mode: str = 'w'):
+    """Initialize and get a logger by name.
+
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified, a FileHandler will also be added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level.
+        file_mode (str): The file mode used in opening log file.
+            Defaults to 'w'.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    # use logger in mmengine if exists.
+    try:
+        from mmengine.logging import MMLogger
+        if MMLogger.check_instance_created(name):
+            logger = MMLogger.get_instance(name)
+        else:
+            logger = MMLogger.get_instance(name,
+                                           logger_name=name,
+                                           log_file=log_file,
+                                           log_level=log_level,
+                                           file_mode=file_mode)
+        return logger
+
+    except Exception:
+        pass
+
+    logger = logging.getLogger(name)
+    if name in logger_initialized:
+        return logger
+    # handle hierarchical names
+    # e.g., logger "a" is initialized, then logger "a.b" will skip the
+    # initialization since it is a child of "a".
+    for logger_name in logger_initialized:
+        if name.startswith(logger_name):
+            return logger
+
+    # handle duplicate logs to the console
+    for handler in logger.root.handlers:
+        if type(handler) is logging.StreamHandler:
+            handler.setLevel(logging.ERROR)
+
+    stream_handler = logging.StreamHandler()
+    handlers = [stream_handler]
+
+    if log_file is not None:
+        # Here, the default behaviour of the official logger is 'a'. Thus, we
+        # provide an interface to change the file mode to the default
+        # behaviour.
+        file_handler = logging.FileHandler(log_file, file_mode)
+        handlers.append(file_handler)
+
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    for handler in handlers:
+        handler.setFormatter(formatter)
+        handler.setLevel(log_level)
+        logger.addHandler(handler)
+
+    logger.setLevel(log_level)
+    logger_initialized[name] = True
+
+    return logger
+
+
+def prepare_tensor(name, input_tensor):
+    t = grpcclient.InferInput(name, list(input_tensor.shape),
+                              np_to_triton_dtype(input_tensor.dtype))
+    t.set_data_from_numpy(input_tensor)
+    return t
+
+
+class Preprocessor:
+
+    def __init__(self, tritonserver_addr: str):
+        self.tritonserver_addr = tritonserver_addr
+        self.model_name = 'preprocessing'
+
+    def __call__(self, *args, **kwargs):
+        return self.infer(*args, **kwargs)
+
+    def infer(self, prompts: Union[str, List[str]]) -> tuple:
+        """Tokenize the input prompts.
+
+        Args:
+            prompts(str | List[str]): user's prompt, or a batch of prompts
+
+        Returns:
+            Tuple(numpy.ndarray, numpy.ndarray): prompt's token ids and
+                the length of the ids
+        """
+        if isinstance(prompts, str):
+            input0 = [[prompts]]
+        elif isinstance(prompts, List):
+            input0 = [[prompt] for prompt in prompts]
+        else:
+            assert 0, f'str or List[str] prompts are expected but got ' \
+                f'{type(prompts)}'
+
+        input0_data = np.array(input0).astype(object)
+        output0_len = np.ones_like(input0).astype(np.uint32)
+        inputs = [
+            prepare_tensor('QUERY', input0_data),
+            prepare_tensor('REQUEST_OUTPUT_LEN', output0_len)
+        ]
+
+        with grpcclient.InferenceServerClient(self.tritonserver_addr) as \
+                client:
+            result = client.infer(self.model_name, inputs)
+            output0 = result.as_numpy('INPUT_ID')
+            output1 = result.as_numpy('REQUEST_INPUT_LEN')
+        return output0, output1
+
+
+class Postprocessor:
+
+    def __init__(self, tritonserver_addr: str):
+        self.tritonserver_addr = tritonserver_addr
+
+    def __call__(self, *args, **kwargs):
+        return self.infer(*args, **kwargs)
+
+    def infer(self, output_ids: np.ndarray, seqlen: np.ndarray):
+        """De-tokenize token ids into text.
+
+        Args:
+            output_ids(np.ndarray): tokens' id
+            seqlen(np.ndarray): sequence length
+
+        Returns:
+            np.ndarray: the decoded texts
+        """
+        inputs = [
+            prepare_tensor('TOKENS_BATCH', output_ids),
+            prepare_tensor('sequence_length', seqlen)
+        ]
+        inputs[0].set_data_from_numpy(output_ids)
+        inputs[1].set_data_from_numpy(seqlen)
+        model_name = 'postprocessing'
+        with grpcclient.InferenceServerClient(self.tritonserver_addr) \
+                as client:
+            result = client.infer(model_name, inputs)
+            output0 = result.as_numpy('OUTPUT')
+        return output0
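
Note: the two wrappers above are thin gRPC clients for the `preprocessing` and `postprocessing` Triton models. The round-trip sketch below assumes a server is reachable at the placeholder address and mirrors the reshapes done in `stream_consumer`; shapes may need adjusting to your deployed model configuration.

from llmdeploy.serve.fastertransformer.utils import Postprocessor, Preprocessor

tritonserver_addr = '0.0.0.0:33337'  # placeholder, point it at your server
preprocess = Preprocessor(tritonserver_addr)
postprocess = Postprocessor(tritonserver_addr)

# tokenize: returns the token ids and their lengths as numpy arrays
input_ids, input_lengths = preprocess('Hello, world!')
print(input_ids.shape, input_lengths.squeeze())

# de-tokenize: reshape to (batch, beam, len) / (batch, 1) as stream_consumer does
output_ids = input_ids.reshape((1, 1, input_ids.shape[-1]))
seqlen = input_lengths.reshape((1, input_lengths.shape[-1]))
texts = postprocess(output_ids, seqlen)
print(texts[0].decode())  # should be close to the original prompt
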
diff --git a/llmdeploy/version.py b/llmdeploy/version.py
new file mode 100644
index 0000000000..079f9e66d8
--- /dev/null
+++ b/llmdeploy/version.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+__version__ = '0.0.1'
+short_version = __version__
+
+
+def parse_version_info(version_str: str) -> Tuple:
+    """Parse version from a string.
+
+    Args:
+        version_str (str): A string that represents version info.
+
+    Returns:
+        tuple: A sequence of integers and strings that represents the version.
+    """
+    _version_info = []
+    for x in version_str.split('.'):
+        if x.isdigit():
+            _version_info.append(int(x))
+        elif x.find('rc') != -1:
+            patch_version = x.split('rc')
+            _version_info.append(int(patch_version[0]))
+            _version_info.append(f'rc{patch_version[1]}')
+    return tuple(_version_info)
+
+
+version_info = parse_version_info(__version__)
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
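
Note: for reference, this is how `parse_version_info` treats plain and release-candidate version strings; the '0.1.0rc2' value is only illustrative of the parsing rule.

from llmdeploy.version import parse_version_info, version_info

# purely numeric components become ints ...
assert parse_version_info('0.0.1') == (0, 0, 1)
# ... and an 'rc' suffix is split into an int plus an 'rcN' string
assert parse_version_info('0.1.0rc2') == (0, 1, 0, 'rc2')

print(version_info)  # (0, 0, 1) for the '0.0.1' declared above
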
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000..19d9ee7593
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+fire
+mmengine
+numpy
+setuptools
+tritonclient==2.33.0
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000..c18383ea02
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,38 @@
+import os
+
+from setuptools import find_packages, setup
+
+pwd = os.path.dirname(__file__)
+version_file = 'llmdeploy/version.py'
+
+
+def readme():
+    with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f:
+        content = f.read()
+    return content
+
+
+def get_version():
+    with open(os.path.join(pwd, version_file), 'r') as f:
+        exec(compile(f.read(), version_file, 'exec'))
+    return locals()['__version__']
+
+
+if __name__ == '__main__':
+    setup(name='llmdeploy',
+          version=get_version(),
+          description='triton inference service of llama',
+          long_description=readme(),
+          long_description_content_type='text/markdown',
+          author='OpenMMLab',
+          author_email='openmmlab@gmail.com',
+          packages=find_packages(
+              exclude=('llmdeploy/serve/fastertransformer/triton_models', )),
+          classifiers=[
+              'Programming Language :: Python :: 3.8',
+              'Programming Language :: Python :: 3.9',
+              'Programming Language :: Python :: 3.10',
+              'Intended Audience :: Developers',
+              'Intended Audience :: Education',
+              'Intended Audience :: Science/Research',
+          ])
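
Note: after installing the package in editable mode, the quick check below confirms the version metadata and the `Chatbot` export are importable. It is a sketch: the endpoint address is a placeholder, and no Triton server is needed just to construct the object, since the gRPC connection is only opened when inference is requested.

# A quick post-install sanity check; no Triton server is required for this.
from llmdeploy.serve.fastertransformer import Chatbot
from llmdeploy.version import __version__, version_info

print(__version__, version_info)  # e.g. '0.0.1' (0, 0, 1)

# Constructing a Chatbot only records the endpoint and generation config.
bot = Chatbot('0.0.0.0:33337', 'vicuna')
print(bot.cfg.session_len)  # 2048 by default
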