diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index eef03e7d8..5de00e106 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -2,8 +2,8 @@
 
 On the server side, run one of the following commands:
     vLLM OpenAI API server
-    python -m vllm.entrypoints.openai.api_server \
-        --model <your_model> --swap-space 16 \
+    vllm serve <your_model> \
+        --swap-space 16 \
         --disable-log-requests
 
     (TGI backend)
@@ -17,7 +17,7 @@
         --dataset-path <path to dataset> \
         --request-rate <request_rate> \ # By default <request_rate> is inf
         --num-prompts <num_prompts> # By default <num_prompts> is 1000
-    
+
     when using tgi backend, add
         --endpoint /generate_stream
     to the end of the command above.
@@ -44,6 +44,11 @@
 except ImportError:
     from backend_request_func import get_tokenizer
 
+try:
+    from vllm.utils import FlexibleArgumentParser
+except ImportError:
+    from argparse import ArgumentParser as FlexibleArgumentParser
+
 
 @dataclass
 class BenchmarkMetrics:
@@ -72,7 +77,6 @@ def sample_sharegpt_requests(
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
-
     # Load the dataset.
     with open(dataset_path) as f:
         dataset = json.load(f)
@@ -191,6 +195,7 @@ async def get_request(
         if request_rate == float("inf"):
             # If the request rate is infinity, then we don't need to wait.
             continue
+
         # Sample the request interval from the exponential distribution.
         interval = np.random.exponential(1.0 / request_rate)
         # The next request will be sent after the interval.
@@ -214,7 +219,7 @@ def calculate_metrics(
         # We use the tokenizer to count the number of output tokens for all
         # serving backends instead of looking at len(outputs[i].itl) since
        # multiple output tokens may be bundled together
-        # Note: this may inflate the output token count slightly
+        # Note : this may inflate the output token count slightly
         output_len = len(
             tokenizer(outputs[i].generated_text,
                       add_special_tokens=False).input_ids)
@@ -511,7 +516,7 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
+    parser = FlexibleArgumentParser(
         description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 6248d8468..092c3c6cb 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -109,7 +109,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 
 ```{argparse}
 :module: vllm.entrypoints.openai.cli_args
-:func: make_arg_parser
+:func: create_parser_for_docs
 :prog: -m vllm.entrypoints.openai.api_server
 ```
 
diff --git a/setup.py b/setup.py
index b2ae6def8..24d8d9a45 100644
--- a/setup.py
+++ b/setup.py
@@ -450,4 +450,9 @@ def _read_requirements(filename: str) -> List[str]:
     },
     cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
     package_data=package_data,
+    entry_points={
+        "console_scripts": [
+            "vllm=vllm.scripts:main",
+        ],
+    },
 )
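Note: the "console_scripts" entry above is what puts a `vllm` command on PATH after installation. As a rough sketch (not part of the patch), the wrapper that setuptools generates for "vllm=vllm.scripts:main" behaves like:

    import sys

    from vllm.scripts import main

    if __name__ == "__main__":
        # Dispatch to the CLI defined in vllm/scripts.py below.
        sys.exit(main())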
diff --git a/tests/tgis/test_hub.py b/tests/tgis/test_hub.py
new file mode 100644
index 000000000..5ecde4bc6
--- /dev/null
+++ b/tests/tgis/test_hub.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+import pytest
+from huggingface_hub.utils import LocalEntryNotFoundError
+
+from vllm.tgis_utils.hub import (convert_files, download_weights, weight_files,
+                                 weight_hub_files)
+
+
+def test_convert_files():
+    model_id = "bigscience/bloom-560m"
+    local_pt_files = download_weights(model_id, extension=".bin")
+    local_pt_files = [Path(p) for p in local_pt_files]
+    local_st_files = [
+        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+        for p in local_pt_files
+    ]
+    convert_files(local_pt_files, local_st_files, discard_names=[])
+
+    found_st_files = weight_files(model_id)
+
+    assert all([str(p) in found_st_files for p in local_st_files])
+
+
+def test_weight_hub_files():
+    filenames = weight_hub_files("bigscience/bloom-560m")
+    assert filenames == ["model.safetensors"]
+
+
+def test_weight_hub_files_llm():
+    filenames = weight_hub_files("bigscience/bloom")
+    assert filenames == [
+        f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)
+    ]
+
+
+def test_weight_hub_files_empty():
+    filenames = weight_hub_files("bigscience/bloom", ".errors")
+    assert filenames == []
+
+
+def test_download_weights():
+    files = download_weights("bigscience/bloom-560m")
+    local_files = weight_files("bigscience/bloom-560m")
+    assert files == local_files
+
+
+def test_weight_files_error():
+    with pytest.raises(LocalEntryNotFoundError):
+        weight_files("bert-base-uncased")
\ No newline at end of file
diff --git a/tests/utils.py b/tests/utils.py
index f2b2d22b1..0acb4c0f5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,7 +13,7 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.utils import get_open_port
+from vllm.utils import FlexibleArgumentParser, get_open_port
 
 # Path to root of repository so that utilities can be imported by ray workers
 VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -74,7 +74,9 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
 
             cli_args = cli_args + ["--port", str(get_open_port())]
 
-        parser = make_arg_parser()
+        parser = FlexibleArgumentParser(
+            description="vLLM's remote OpenAI server.")
+        parser = make_arg_parser(parser)
         args = parser.parse_args(cli_args)
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 12f8ad63c..df1a1adfa 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -8,7 +8,7 @@
 
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -30,10 +30,14 @@
 from vllm.logger import init_logger
 from vllm.tgis_utils.args import add_tgis_args, postprocess_tgis_args
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
+logger = init_logger(__name__)
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
@@ -67,50 +71,35 @@ async def _force_log():
 
     logger.info("gRPC server stopped")
 
 
-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    parser = add_tgis_args(parser)
-    parsed_args = parser.parse_args()
-    parsed_args = postprocess_tgis_args(parsed_args)
-    return parsed_args
-
+router = APIRouter()
 
 # Add prometheus asgi middleware to route /metrics requests
 route = Mount("/metrics", make_asgi_app())
 # Workaround for 307 Redirect for /metrics
 route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
-
+router.routes.append(route)
 
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
 
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
 
 
-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
-    models = await openai_serving_chat.show_available_models()
+    models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)
 
 
-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
@@ -126,7 +115,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -140,7 +129,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -151,8 +140,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
 
     app.add_middleware(
         CORSMiddleware,
@@ -162,6 +153,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         allow_headers=args.allowed_headers,
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:
 
         @app.middleware("http")
" f"Must be a function or a class.") + return app + + +def run_server(args, llm_engine=None): + app = build_app(args) + logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) @@ -195,6 +198,8 @@ async def authentication(request: Request, call_next): else: served_model_names = [args.model] + global engine, engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) # Enforce pixel values as image input type for vision language models @@ -206,8 +211,9 @@ async def authentication(request: Request, call_next): "Only --image-input-type 'pixel_values' is supported for serving " "vision language models with the vLLM API server.") - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) event_loop: Optional[asyncio.AbstractEventLoop] try: @@ -223,6 +229,11 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + global openai_serving_chat + global openai_serving_completion + global openai_serving_embedding + global async_llm_engine + openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, args.response_role, @@ -247,3 +258,16 @@ async def authentication(request: Request, call_next): ssl_certfile=args.ssl_certfile, ssl_ca_certs=args.ssl_ca_certs, ssl_cert_reqs=args.ssl_cert_reqs) + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + parser = add_tgis_args(parser) + args = parser.parse_args() + args = postprocess_tgis_args(args) + + run_server(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 0b0fcdc5a..cc3031c66 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -10,7 +10,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import LoRAModulePath -from vllm.tgis_utils.args import EnvVarArgumentParser +from vllm.utils import FlexibleArgumentParser class LoRAParserAction(argparse.Action): @@ -23,9 +23,7 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, lora_list) -def make_arg_parser(): - parser = EnvVarArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--host", type=nullable_str, default=None, @@ -114,3 +112,9 @@ def make_arg_parser(): parser = AsyncEngineArgs.add_cli_args(parser) return parser + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/vllm/scripts.py b/vllm/scripts.py new file mode 100644 index 000000000..c1ba1c910 --- /dev/null +++ b/vllm/scripts.py @@ -0,0 +1,344 @@ +# The CLI entrypoint to vLLM. 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 0b0fcdc5a..cc3031c66 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -10,7 +10,7 @@
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import LoRAModulePath
-from vllm.tgis_utils.args import EnvVarArgumentParser
+from vllm.utils import FlexibleArgumentParser
 
 
 class LoRAParserAction(argparse.Action):
@@ -23,9 +23,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         setattr(namespace, self.dest, lora_list)
 
 
-def make_arg_parser():
-    parser = EnvVarArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
+def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser.add_argument("--host",
                         type=nullable_str,
                         default=None,
@@ -114,3 +112,9 @@ def make_arg_parser():
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser
+
+
+def create_parser_for_docs() -> FlexibleArgumentParser:
+    parser_for_docs = FlexibleArgumentParser(
+        prog="-m vllm.entrypoints.openai.api_server")
+    return make_arg_parser(parser_for_docs)
diff --git a/vllm/scripts.py b/vllm/scripts.py
new file mode 100644
index 000000000..c1ba1c910
--- /dev/null
+++ b/vllm/scripts.py
@@ -0,0 +1,344 @@
+# The CLI entrypoint to vLLM.
+import argparse
+import os
+import signal
+import sys
+from pathlib import Path
+from typing import Optional
+
+from openai import OpenAI
+
+from vllm.entrypoints.openai.api_server import run_server
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.utils import FlexibleArgumentParser
+
+
+def register_signal_handlers():
+
+    def signal_handler(sig, frame):
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTSTP, signal_handler)
+
+
+def serve(args: argparse.Namespace) -> None:
+    # EngineArgs expects the model name to be passed as --model.
+    args.model = args.model_tag
+
+    run_server(args)
+
+
+def interactive_cli(args: argparse.Namespace) -> None:
+    register_signal_handlers()
+
+    base_url = args.url
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
+    openai_client = OpenAI(api_key=api_key, base_url=base_url)
+
+    if args.model_name:
+        model_name = args.model_name
+    else:
+        available_models = openai_client.models.list()
+        model_name = available_models.data[0].id
+
+    print(f"Using model: {model_name}")
+
+    if args.command == "complete":
+        complete(model_name, openai_client)
+    elif args.command == "chat":
+        chat(args.system_prompt, model_name, openai_client)
+
+
+def tgis_cli(args: argparse.Namespace) -> None:
+    register_signal_handlers()
+
+    if args.command == "download-weights":
+        download_weights(args.model_name, args.revision, args.token,
+                         args.extension, args.auto_convert)
+    elif args.command == "convert-to-safetensors":
+        convert_to_safetensors(args.model_name, args.revision)
+    elif args.command == "convert-to-fast-tokenizer":
+        convert_to_fast_tokenizer(args.model_name, args.revision,
+                                  args.output_path)
+
+
+def complete(model_name: str, client: OpenAI) -> None:
+    print("Please enter prompt to complete:")
+    while True:
+        input_prompt = input("> ")
+
+        completion = client.completions.create(model=model_name,
+                                               prompt=input_prompt)
+        output = completion.choices[0].text
+        print(output)
+
+
+def chat(system_prompt: Optional[str], model_name: str,
+         client: OpenAI) -> None:
+    conversation = []
+    if system_prompt is not None:
+        conversation.append({"role": "system", "content": system_prompt})
+
+    print("Please enter a message for the chat model:")
+    while True:
+        input_message = input("> ")
+        message = {"role": "user", "content": input_message}
+        conversation.append(message)
+
+        chat_completion = client.chat.completions.create(model=model_name,
+                                                         messages=conversation)
+
+        response_message = chat_completion.choices[0].message
+        output = response_message.content
+
+        conversation.append(response_message)
+        print(output)
+
+
+def download_weights(
+    model_name: str,
+    revision: Optional[str] = None,
+    token: Optional[str] = None,
+    extension: str = ".safetensors",
+    auto_convert: bool = True,
+) -> None:
+    from vllm.tgis_utils import hub
+
+    print(extension)
+    meta_exts = [".json", ".py", ".model", ".md"]
+
+    extensions = extension.split(",")
+
+    if len(extensions) == 1 and extensions[0] not in meta_exts:
+        extensions.extend(meta_exts)
+
+    files = hub.download_weights(model_name,
+                                 extensions,
+                                 revision=revision,
+                                 auth_token=token)
+
+    if auto_convert and ".safetensors" in extensions:
+        if not hub.local_weight_files(hub.get_model_path(model_name, revision),
+                                      ".safetensors"):
+            if ".bin" not in extensions:
+                print(".safetensors weights not found, "
+                      "downloading pytorch weights to convert...")
+                hub.download_weights(model_name,
+                                     ".bin",
+                                     revision=revision,
+                                     auth_token=token)
+
+            print(".safetensors weights not found, "
+                  "converting from pytorch weights...")
+            convert_to_safetensors(model_name, revision)
+        elif not any(f.endswith(".safetensors") for f in files):
+            print(".safetensors weights not found on hub, "
+                  "but were found locally. Remove them first to re-convert")
+    if auto_convert:
+        convert_to_fast_tokenizer(model_name, revision)
+
+
+def convert_to_safetensors(
+    model_name: str,
+    revision: Optional[str] = None,
+):
+    from vllm.tgis_utils import hub
+
+    # Get local pytorch file paths
+    model_path = hub.get_model_path(model_name, revision)
+    local_pt_files = hub.local_weight_files(model_path, ".bin")
+    local_pt_index_files = hub.local_index_files(model_path, ".bin")
+    if len(local_pt_index_files) > 1:
+        print(
+            f"Found more than one .bin.index.json file: {local_pt_index_files}"
+        )
+        return
+    if not local_pt_files:
+        print("No pytorch .bin files found to convert")
+        return
+
+    local_pt_files = [Path(f) for f in local_pt_files]
+    local_pt_index_file = local_pt_index_files[
+        0] if local_pt_index_files else None
+
+    # Safetensors final filenames
+    local_st_files = [
+        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+        for p in local_pt_files
+    ]
+
+    if any(os.path.exists(p) for p in local_st_files):
+        print("Existing .safetensors weights found, "
+              "remove them first to reconvert")
+        return
+
+    try:
+        import transformers
+
+        config = transformers.AutoConfig.from_pretrained(
+            model_name,
+            revision=revision,
+        )
+        architecture = config.architectures[0]
+
+        class_ = getattr(transformers, architecture)
+
+        # Name for this variable depends on transformers version
+        discard_names = getattr(class_, "_tied_weights_keys", [])
+        discard_names.extend(
+            getattr(class_, "_keys_to_ignore_on_load_missing", []))
+
+    except Exception:
+        discard_names = []
+
+    if local_pt_index_file:
+        local_pt_index_file = Path(local_pt_index_file)
+        st_prefix = local_pt_index_file.stem.removeprefix(
+            "pytorch_").removesuffix(".bin.index")
+        local_st_index_file = (local_pt_index_file.parent /
+                               f"{st_prefix}.safetensors.index.json")
+
+        if os.path.exists(local_st_index_file):
+            print("Existing .safetensors.index.json file found, "
+                  "remove it first to reconvert")
+            return
+
+        hub.convert_index_file(local_pt_index_file, local_st_index_file,
+                               local_pt_files, local_st_files)
+
+    # Convert pytorch weights to safetensors
+    hub.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+def convert_to_fast_tokenizer(
+    model_name: str,
+    revision: Optional[str] = None,
+    output_path: Optional[str] = None,
+):
+    from vllm.tgis_utils import hub
+
+    # Check for existing "tokenizer.json"
+    model_path = hub.get_model_path(model_name, revision)
+
+    if os.path.exists(os.path.join(model_path, "tokenizer.json")):
+        print(f"Model {model_name} already has a fast tokenizer")
+        return
+
+    if output_path is not None:
+        if not os.path.isdir(output_path):
+            print(f"Output path {output_path} must exist and be a directory")
+            return
+    else:
+        output_path = model_path
+
+    import transformers
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,
+                                                           revision=revision)
+    tokenizer.save_pretrained(output_path)
+
+    print(f"Saved tokenizer to {output_path}")
+
+
+def _add_query_options(
+        parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="http://localhost:8000/v1",
+        help="url of the running OpenAI-Compatible RESTful API server")
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help=("The model name used in prompt completion, default to "
+              "the first model in list models API call."))
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help=(
+            "API key for OpenAI services. If provided, this api key "
+            "will overwrite the api key obtained through environment variables."
+        ))
+    return parser
+
+
+def main():
+    parser = FlexibleArgumentParser(description="vLLM CLI")
+    subparsers = parser.add_subparsers(required=True)
+
+    serve_parser = subparsers.add_parser(
+        "serve",
+        help="Start the vLLM OpenAI Compatible API server",
+        usage="vllm serve <model_tag> [options]")
+    serve_parser.add_argument("model_tag",
+                              type=str,
+                              help="The model tag to serve")
+    serve_parser = make_arg_parser(serve_parser)
+    serve_parser.set_defaults(dispatch_function=serve)
+
+    complete_parser = subparsers.add_parser(
+        "complete",
+        help=("Generate text completions based on the given prompt "
+              "via the running API server"),
+        usage="vllm complete [options]")
+    _add_query_options(complete_parser)
+    complete_parser.set_defaults(dispatch_function=interactive_cli,
+                                 command="complete")
+
+    chat_parser = subparsers.add_parser(
+        "chat",
+        help="Generate chat completions via the running API server",
+        usage="vllm chat [options]")
+    _add_query_options(chat_parser)
+    chat_parser.add_argument(
+        "--system-prompt",
+        type=str,
+        default=None,
+        help=("The system prompt to be added to the chat template, "
+              "used for models that support system prompts."))
+    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
+
+    download_weights_parser = subparsers.add_parser(
+        "download-weights",
+        help=("Download the weights of a given model"),
+        usage="vllm download-weights <model_name> [options]")
+    download_weights_parser.add_argument("model_name")
+    download_weights_parser.add_argument("--revision")
+    download_weights_parser.add_argument("--token")
+    download_weights_parser.add_argument("--extension", default=".safetensors")
+    download_weights_parser.add_argument("--auto_convert", default=True)
+    download_weights_parser.set_defaults(dispatch_function=tgis_cli,
+                                         command="download-weights")
+
+    convert_to_safetensors_parser = subparsers.add_parser(
+        "convert-to-safetensors",
+        help=("Convert model weights to safetensors"),
+        usage="vllm convert-to-safetensors <model_name> [options]")
+    convert_to_safetensors_parser.add_argument("model_name")
+    convert_to_safetensors_parser.add_argument("--revision")
+    convert_to_safetensors_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-safetensors")
+
+    convert_to_fast_tokenizer_parser = subparsers.add_parser(
+        "convert-to-fast-tokenizer",
+        help=("Convert to fast tokenizer"),
+        usage="vllm convert-to-fast-tokenizer <model_name> [options]")
+    convert_to_fast_tokenizer_parser.add_argument("model_name")
+    convert_to_fast_tokenizer_parser.add_argument("--revision")
+    convert_to_fast_tokenizer_parser.add_argument("--output_path")
+    convert_to_fast_tokenizer_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-fast-tokenizer")
+
+    args = parser.parse_args()
+    # One of the sub commands should be executed.
+    if hasattr(args, "dispatch_function"):
+        args.dispatch_function(args)
+    else:
+        parser.print_help()
+
+
+if __name__ == "__main__":
+    main()
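Note: `vllm complete` and `vllm chat` are thin wrappers around the OpenAI Python client. A sketch of what one `complete` round trip amounts to, assuming a server is already running at the default --url:

    from openai import OpenAI

    # Point the standard OpenAI client at the local vLLM server.
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
    # Pick the first served model, as interactive_cli() does by default.
    model = client.models.list().data[0].id
    completion = client.completions.create(model=model, prompt="Hello")
    print(completion.choices[0].text)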
diff --git a/vllm/tgis_utils/hub.py b/vllm/tgis_utils/hub.py
new file mode 100644
index 000000000..4361b189f
--- /dev/null
+++ b/vllm/tgis_utils/hub.py
@@ -0,0 +1,270 @@
+import concurrent
+import datetime
+import glob
+import json
+import logging
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import torch
+from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
+from huggingface_hub.utils import LocalEntryNotFoundError
+from safetensors.torch import (_find_shared_tensors, _is_complete, load_file,
+                               save_file)
+from tqdm import tqdm
+
+TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE") == "true"
+
+logger = logging.getLogger(__name__)
+
+
+def weight_hub_files(model_name,
+                     extension=".safetensors",
+                     revision=None,
+                     auth_token=None):
+    """Get the safetensors filenames on the hub"""
+    exts = [extension] if isinstance(extension, str) else extension
+    api = HfApi()
+    info = api.model_info(model_name, revision=revision, token=auth_token)
+    filenames = [
+        s.rfilename for s in info.siblings if any(
+            s.rfilename.endswith(ext) and len(s.rfilename.split("/")) == 1
+            and "arguments" not in s.rfilename and "args" not in s.rfilename
+            and "training" not in s.rfilename for ext in exts)
+    ]
+    return filenames
+
+
+def weight_files(model_name, extension=".safetensors", revision=None):
+    """Get the local safetensors filenames"""
+    filenames = weight_hub_files(model_name, extension)
+    files = []
+    for filename in filenames:
+        cache_file = try_to_load_from_cache(model_name,
+                                            filename=filename,
+                                            revision=revision)
+        if cache_file is None:
+            raise LocalEntryNotFoundError(
+                f"File {filename} of model {model_name} not found in "
+                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+                f"Please run `vllm download-weights {model_name}` first.")
+        files.append(cache_file)
+
+    return files
+
+
+def download_weights(model_name,
+                     extension=".safetensors",
+                     revision=None,
+                     auth_token=None):
+    """Download the safetensors files from the hub"""
+    filenames = weight_hub_files(model_name,
+                                 extension,
+                                 revision=revision,
+                                 auth_token=auth_token)
+
+    download_function = partial(
+        hf_hub_download,
+        repo_id=model_name,
+        local_files_only=False,
+        revision=revision,
+        token=auth_token,
+    )
+
+    print(f"Downloading {len(filenames)} files for model {model_name}")
+    executor = ThreadPoolExecutor(max_workers=5)
+    futures = [
+        executor.submit(download_function, filename=filename)
+        for filename in filenames
+    ]
+    files = [
+        future.result()
+        for future in tqdm(concurrent.futures.as_completed(futures),
+                           total=len(futures))
+    ]
+
+    return files
" + f"Please run `vllm \ + download-weights {model_name}` first.") + files.append(cache_file) + + return files + + +def download_weights(model_name, + extension=".safetensors", + revision=None, + auth_token=None): + """Download the safetensors files from the hub""" + filenames = weight_hub_files(model_name, + extension, + revision=revision, + auth_token=auth_token) + + download_function = partial( + hf_hub_download, + repo_id=model_name, + local_files_only=False, + revision=revision, + token=auth_token, + ) + + print(f"Downloading {len(filenames)} files for model {model_name}") + executor = ThreadPoolExecutor(max_workers=5) + futures = [ + executor.submit(download_function, filename=filename) + for filename in filenames + ] + files = [ + future.result() + for future in tqdm(concurrent.futures.as_completed(futures), + total=len(futures)) + ] + + return files + + +def get_model_path(model_name: str, revision: Optional[str] = None): + """Get path to model dir in local huggingface hub (model) cache""" + config_file = "config.json" + err = None + try: + config_path = try_to_load_from_cache( + model_name, + config_file, + cache_dir=os.getenv("TRANSFORMERS_CACHE" + ), # will fall back to HUGGINGFACE_HUB_CACHE + revision=revision, + ) + if config_path is not None: + return config_path.removesuffix(f"/{config_file}") + except ValueError as e: + err = e + + if os.path.isfile(f"{model_name}/{config_file}"): + return model_name # Just treat the model name as an explicit model path + + if err is not None: + raise err + + raise ValueError( + f"Weights not found in local cache for model {model_name}") + + +def local_weight_files(model_path: str, extension=".safetensors"): + """Get the local safetensors filenames""" + ext = "" if extension is None else extension + return glob.glob(f"{model_path}/*{ext}") + + +def local_index_files(model_path: str, extension=".safetensors"): + """Get the local .index.json filename""" + ext = "" if extension is None else extension + return glob.glob(f"{model_path}/*{ext}.index.json") + + +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + # _find_shared_tensors returns a list of sets of names of tensors that + # have the same data, including sets with one element that aren't shared + if len(shared) == 1: + continue + + complete_names = set( + [name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + raise RuntimeError(f"Error while trying to find names to remove \ + to save state dict, but found no suitable name to \ + keep for saving amongst: {shared}. None is covering \ + the entire storage.Refusing to save/load the model \ + since you could be storing much more \ + memory than needed. Please refer to\ + https://huggingface.co/docs/safetensors/torch_shared_tensors \ + for more information. 
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
+    """
+    Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention.
+    Forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
+    """
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_index_file(source_file: Path, dest_file: Path,
+                       pt_files: List[Path], sf_files: List[Path]):
+    weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)}
+
+    logger.info(
+        "Converting pytorch .bin.index.json files to .safetensors.index.json")
+    with open(source_file, "r") as f:
+        index = json.load(f)
+
+    index["weight_map"] = {
+        k: weight_file_map[v]
+        for k, v in index["weight_map"].items()
+    }
+
+    with open(dest_file, "w") as f:
+        json.dump(index, f, indent=4)
+
+
+def convert_files(pt_files: List[Path],
+                  sf_files: List[Path],
+                  discard_names: List[str] = None):
+    assert len(pt_files) == len(sf_files)
+
+    # Filter non-inference files
+    pairs = [
+        p for p in zip(pt_files, sf_files) if not any(s in p[0].name for s in [
+            "arguments",
+            "args",
+            "training",
+            "optimizer",
+            "scheduler",
+            "index",
+        ])
+    ]
+
+    N = len(pairs)
+
+    if N == 0:
+        logger.warning("No pytorch .bin weight files found to convert")
+        return
+
+    logger.info("Converting %d pytorch .bin files to .safetensors...", N)
+
+    for i, (pt_file, sf_file) in enumerate(pairs):
+        file_count = f"{i + 1}/{N}"
+        logger.info('Converting: [%s] "%s"', file_count, pt_file.name)
+        start = datetime.datetime.now()
+        convert_file(pt_file, sf_file, discard_names)
+        elapsed = datetime.datetime.now() - start
+        logger.info('Converted: [%s] "%s" -- Took: %s', file_count,
+                    sf_file.name, elapsed)
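Note: tests/tgis/test_hub.py exercises this .bin -> .safetensors flow end to end. A condensed sketch of the same flow using only the helpers above (the model id is just an example):

    from pathlib import Path

    from vllm.tgis_utils import hub

    # Download the pytorch .bin shards, then derive the matching
    # .safetensors filenames and convert in place.
    pt_files = [
        Path(p) for p in hub.download_weights("bigscience/bloom-560m",
                                              extension=".bin")
    ]
    st_files = [
        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
        for p in pt_files
    ]
    hub.convert_files(pt_files, st_files, discard_names=[])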
diff --git a/vllm/utils.py b/vllm/utils.py
index ffe921e65..e40aaf278 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,3 +1,4 @@
+import argparse
 import asyncio
 import datetime
 import enum
@@ -775,3 +776,27 @@ def wrapper(*args, **kwargs) -> Any:
 
     wrapper.has_run = False  # type: ignore[attr-defined]
     return wrapper
+
+
+class FlexibleArgumentParser(argparse.ArgumentParser):
+    """ArgumentParser that allows both underscore and dash in names."""
+
+    def parse_args(self, args=None, namespace=None):
+        if args is None:
+            args = sys.argv[1:]
+
+        # Convert underscores to dashes and vice versa in argument names
+        processed_args = []
+        for arg in args:
+            if arg.startswith('--'):
+                if '=' in arg:
+                    key, value = arg.split('=', 1)
+                    key = '--' + key[len('--'):].replace('_', '-')
+                    processed_args.append(f'{key}={value}')
+                else:
+                    processed_args.append('--' +
+                                          arg[len('--'):].replace('_', '-'))
+            else:
+                processed_args.append(arg)
+
+        return super().parse_args(processed_args, namespace)
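Note: FlexibleArgumentParser rewrites only option names (underscores become dashes), never values, so both spellings of a flag parse to the same destination. A small sketch (the flag is just an example):

    from vllm.utils import FlexibleArgumentParser

    parser = FlexibleArgumentParser()
    parser.add_argument("--swap-space", type=int)

    # Underscore, dash, and --flag=value forms all resolve to args.swap_space.
    assert parser.parse_args(["--swap_space", "16"]).swap_space == 16
    assert parser.parse_args(["--swap-space=16"]).swap_space == 16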