update doc and launch params (#733)
Co-authored-by: shihaobai <[email protected]>
shihaobai authored Feb 17, 2025
1 parent 250d7ad commit 9452b39
Showing 8 changed files with 58 additions and 17 deletions.
16 changes: 4 additions & 12 deletions docs/EN/source/getting_started/installation.rst
@@ -26,7 +26,7 @@ The easiest way to install Lightllm is by using the official image. You can dire
  $
  $ # Run the image
  $ docker run -it --gpus all -p 8080:8080 \
- $ --shm-size 1g -v your_local_path:/data/ \
+ $ --shm-size 32g -v your_local_path:/data/ \
  $ ghcr.io/modeltc/lightllm:main /bin/bash
  You can also manually build and run the image from the source:
@@ -39,7 +39,7 @@ You can also manually build and run the image from the source:
  $
  $ # Run the image
  $ docker run -it --gpus all -p 8080:8080 \
- $ --shm-size 1g -v your_local_path:/data/ \
+ $ --shm-size 32g -v your_local_path:/data/ \
  $ <image_name> /bin/bash
  Alternatively, you can use a script to automatically build and run the image:
@@ -81,16 +81,8 @@ NOTE: If you are using torch with cuda 11.x instead, run `pip install nvidia-ncc
  .. note::

      The Lightllm code has been tested on various GPUs, including V100, A100, A800, 4090, and H800.
-     If you are using A100, A800, or similar GPUs, it is recommended to install triton==3.0.0:
+     If you are using A100, A800, or similar GPUs, it is recommended to install triton==3.1.0:

      .. code-block:: console
-         $ pip install triton==3.0.0 --no-deps
-     If you are using H800, V100, or similar GPUs, it is recommended to install triton-nightly:
-
-     .. code-block:: console
-         $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
-     For more details, refer to: `issue <https://github.com/triton-lang/triton/issues/3619>`_ and `fix PR <https://github.com/triton-lang/triton/pull/3638>`_
+         $ pip install triton==3.1.0 --no-deps
10 changes: 10 additions & 0 deletions docs/EN/source/getting_started/quickstart.rst
@@ -53,6 +53,16 @@ After downloading the Llama-2-7b-chat model, use the following command in the te
  .. note::
      The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine.

+ For the DeepSeek-R1 model on H200, it can be launched with the following command:
+
+ .. code-block:: console
+
+     $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 8 --graph_max_batch_size 100
+
+ .. note::
+     LOADWORKER specifies the number of threads used for model loading, which can speed up model loading. The --graph_max_batch_size parameter specifies how many CUDA graphs to capture; graphs will be captured for batch sizes ranging from 1 to 100.
+
+
  3. (Optional) Test the Model Service
  --------------------------------------

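To sanity-check a server launched with the command above, a minimal client request might look like the sketch below. It assumes the service is listening on localhost:8080 and exposes lightllm's /generate endpoint with an inputs/parameters JSON payload; adjust host, port, and fields to your deployment.

# Minimal smoke test for a freshly launched lightllm server.
# Assumptions: the server listens on localhost:8080 and accepts a JSON body
# with "inputs" and "parameters" on the /generate endpoint.
import json
import urllib.request

payload = {
    "inputs": "What is AI?",
    "parameters": {"max_new_tokens": 32},
}
req = urllib.request.Request(
    "http://localhost:8080/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))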
11 changes: 11 additions & 0 deletions lightllm/common/quantization/__init__.py
@@ -26,6 +26,17 @@ def _parse_network_config(self, network_config):
          activation_scheme = network_config.get("activation_scheme", "dynamic")
          self.static_activation = activation_scheme == "static"
          self.hf_quantization_config = hf_quantization_config
+         self.hf_quantization_method = hf_quantization_config["quant_method"]
+         self._mapping_quant_method()
+
+     def _mapping_quant_method(self):
+         if self.hf_quantization_method == "fp8":
+             block_size = self.hf_quantization_config.get("weight_block_size", None)
+             if block_size == [128, 128]:
+                 self.quant_type = "vllm-fp8w8a8-b128"
+             else:
+                 # TODO: more quant method
+                 pass

      def _parse_custom_cfg(self, custom_cfg_path):
          self.quant_cfg = collections.defaultdict(dict)
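As a standalone illustration of the dispatch added in _mapping_quant_method, the sketch below applies the same rule to a plain dict shaped like a HuggingFace quantization_config. The helper name and example configs are hypothetical; only the fp8 / [128, 128] weight-block case maps to a concrete quant type, exactly as in the diff.

# Hypothetical standalone version of the quant-method mapping shown above.
# Only fp8 weights with 128x128 blocks resolve to a quant type; everything
# else is left unmapped, mirroring the TODO branch in the diff.
from typing import Optional


def map_quant_type(hf_quantization_config: dict) -> Optional[str]:
    if hf_quantization_config["quant_method"] == "fp8":
        block_size = hf_quantization_config.get("weight_block_size", None)
        if block_size == [128, 128]:
            return "vllm-fp8w8a8-b128"
    return None


if __name__ == "__main__":
    print(map_quant_type({"quant_method": "fp8", "weight_block_size": [128, 128]}))  # vllm-fp8w8a8-b128
    print(map_quant_type({"quant_method": "awq"}))  # None (not mapped yet)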
12 changes: 9 additions & 3 deletions lightllm/distributed/communication_op.py
@@ -24,7 +24,8 @@
  import torch.distributed as dist
  from torch.distributed import ReduceOp
  from lightllm.utils.log_utils import init_logger
- from functools import partial
+ from lightllm.utils.device_utils import has_nvlink
+ from lightllm.utils.envs_utils import get_env_start_args

  original_all_reduce = torch.distributed.all_reduce
  original_all_gather_into_tensor = torch.distributed.all_gather_into_tensor
@@ -67,10 +68,13 @@ def lightllm_capture_graph(self):
          yield

      def set_custom_reduce(self):
-         ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in ["ON", "TRUE", "1"]
+         ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "True").upper() in ["ON", "TRUE", "1"]
          world_size = dist.get_world_size()
          ranks = list(range(world_size))

+         if not has_nvlink() or world_size not in [2, 4, 6, 8]:
+             ENABLE_VLLM_REDUCE = False
+
          # Create a new NCCL group to prevent the original all_reduce from hanging with cudagraph
          if self.device_group is None:
              self.device_group = dist.new_group(ranks, backend="nccl")
@@ -93,11 +97,13 @@ def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_

      def set_custom_gather(self):
          ENABLE_CUSTOM_GATHER = os.getenv("ENABLE_CUSTOM_GATHER", "False").upper() in ["ON", "TRUE", "1"]
+         args = get_env_start_args()
          world_size = dist.get_world_size()
          ranks = list(range(world_size))
          if self.device_group is None:
              self.device_group = dist.new_group(ranks, backend="nccl")
-         if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL:
+
+         if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL or args.disable_custom_allreduce:
              cpu_group = dist.new_group(ranks, backend="gloo")
              self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
              logger.info("Enable Custom ALLGather.")
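To make the new default-on behaviour easier to follow, here is the gating restated in isolation: the vLLM custom all-reduce is now enabled by default, but is turned off when NVLink is absent or the world size is not one of 2, 4, 6, or 8. This is a simplified sketch with has_nvlink and world_size passed in explicitly, not the code path itself.

# Simplified restatement of the custom all-reduce gating added above.
# has_nvlink and world_size are plain arguments here so the decision can be
# inspected without initializing torch.distributed.
import os


def should_enable_vllm_reduce(has_nvlink: bool, world_size: int) -> bool:
    enabled = os.getenv("ENABLE_VLLM_REDUCE", "True").upper() in ["ON", "TRUE", "1"]
    if not has_nvlink or world_size not in [2, 4, 6, 8]:
        enabled = False
    return enabled


if __name__ == "__main__":
    print(should_enable_vllm_reduce(True, 8))   # True by default
    print(should_enable_vllm_reduce(False, 8))  # False: no NVLink
    print(should_enable_vllm_reduce(True, 3))   # False: unsupported world size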
3 changes: 2 additions & 1 deletion lightllm/server/api_cli.py
@@ -159,6 +159,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
      parser.add_argument(
          "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
      )
+     parser.add_argument("--disable_custom_allreduce", action="store_true", help="Whether to disable custom allreduce.")
      parser.add_argument(
          "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
      )
@@ -225,7 +226,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
      parser.add_argument(
          "--graph_max_len_in_batch",
          type=int,
-         default=8192,
+         default=0,
          help="""Maximum sequence length that can be captured by the cuda graph for decoding stage.
          The default value is 0, which falls back to max_req_total_len at startup. It will turn into eager mode if it encounters a larger value. """,
      )
3 changes: 3 additions & 0 deletions lightllm/server/api_start.py
@@ -94,6 +94,9 @@ def normal_or_p_d_start(args):
      if not args.enable_chunked_prefill:
          args.chunked_prefill_size = 0

+     if args.graph_max_len_in_batch == 0:
+         args.graph_max_len_in_batch = args.max_req_total_len
+
      # These modes cannot be enabled at the same time.
      assert [
          args.enable_chunked_prefill,
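Combined with the api_cli.py change above, the new default of 0 for --graph_max_len_in_batch behaves as an "auto" value that is resolved to max_req_total_len at startup. A small sketch of that resolution, using an argparse-style namespace purely for illustration:

# Illustration of the new startup resolution: graph_max_len_in_batch == 0 means
# "follow max_req_total_len"; any explicit value is kept unchanged.
from argparse import Namespace


def resolve_graph_max_len_in_batch(args: Namespace) -> Namespace:
    if args.graph_max_len_in_batch == 0:
        args.graph_max_len_in_batch = args.max_req_total_len
    return args


if __name__ == "__main__":
    auto = resolve_graph_max_len_in_batch(Namespace(graph_max_len_in_batch=0, max_req_total_len=16384))
    explicit = resolve_graph_max_len_in_batch(Namespace(graph_max_len_in_batch=8192, max_req_total_len=16384))
    print(auto.graph_max_len_in_batch)      # 16384, follows max_req_total_len
    print(explicit.graph_max_len_in_batch)  # 8192, kept as given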
17 changes: 17 additions & 0 deletions lightllm/utils/device_utils.py
@@ -1,5 +1,6 @@
  import os
  from functools import lru_cache
+ import subprocess


  def set_current_device_id(device_id: int):
@@ -103,3 +104,19 @@ def init_p2p(device_index):
  @lru_cache(maxsize=None)
  def kv_trans_use_p2p():
      return os.getenv("KV_TRANS_USE_P2P", "False").upper() in ["1", "TRUE", "ON"]
+
+
+ def has_nvlink():
+     try:
+         # Call nvidia-smi to get the topology matrix
+         result = subprocess.check_output(["nvidia-smi", "topo", "--matrix"])
+         result = result.decode("utf-8")
+
+         # Check if the output contains 'NVLink'
+         if "NVLink" in result:
+             return True
+         else:
+             return False
+     except (subprocess.CalledProcessError, FileNotFoundError):
+         # If nvidia-smi fails or is not installed, assume no NVLink
+         return False
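A quick way to exercise the new helper without GPU hardware is to stub out the nvidia-smi call; the sketch below does that with unittest.mock and placeholder strings (not real nvidia-smi output), purely to show the control flow of has_nvlink.

# Exercising has_nvlink's control flow by stubbing subprocess.check_output.
# The fake topology text is a placeholder, not genuine nvidia-smi output.
import subprocess
from unittest import mock

from lightllm.utils.device_utils import has_nvlink

# Placeholder output that mentions NVLink -> detection reports True.
with mock.patch("subprocess.check_output", return_value=b"GPU0  GPU1  NVLink"):
    print(has_nvlink())  # True

# nvidia-smi exits with an error -> detection falls back to False.
with mock.patch("subprocess.check_output",
                side_effect=subprocess.CalledProcessError(1, "nvidia-smi")):
    print(has_nvlink())  # False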
3 changes: 2 additions & 1 deletion requirements.txt
@@ -84,4 +84,5 @@ ujson==5.10.0
  frozendict==2.4.6
  atomics==1.0.3
  easydict==1.13
- gunicorn==23.0.0
+ gunicorn==23.0.0
+ vllm==0.7.2
