Update on naive_attn module #21

Open · wants to merge 18 commits into base: main
42 changes: 42 additions & 0 deletions Dockerfile.rocm.perf.benchmark
@@ -0,0 +1,42 @@
# CONTEXT {'gpu_vendor': 'AMD', 'guest_os': 'UBUNTU'}
ARG BASE_DOCKER=rocm/pytorch:latest
FROM $BASE_DOCKER
USER root

# env
ENV WORKSPACE_DIR=/workspace
ENV MAX_JOBS=64
# vLLM uses different env variable names depending on the branch, so enable both.
ENV VLLM_USE_TRITON_FLASH_ATTN=1
ENV VLLM_USE_FLASH_ATTN_TRITON=1
ENV HIP_FORCE_DEV_KERNARG=0
ENV OPTIMIZE_EPILOGUE=1

# tunableOps
ENV PYTORCH_TUNABLEOP_ENABLED=0

WORKDIR $WORKSPACE_DIR

# python deps
RUN pip install --upgrade pip
RUN pip install pandas

# TODO: remove once the Triton shipped in BASE_DOCKER is updated to version 2.3.0
ARG TRITON_COMMIT="bbe6246"
RUN pip uninstall triton -y && \
    git clone https://github.com/triton-lang/triton.git && \
    cd triton && git checkout ${TRITON_COMMIT} && \
    cd python && python setup.py install

# vllm
RUN git clone https://github.com/ROCm/vllm.git -b perf_benchmark_navi &&\
cd vllm &&\
pip install -U -r requirements-rocm.txt &&\
python3 setup.py install

# gradlib
RUN cd vllm/gradlib \
&& pip install .

RUN cd /opt/rocm/share/amd_smi &&\
pip install .

# record configuration for posterity
RUN pip list
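
For reference, a minimal sketch of building and running this image on a ROCm host. The image tag vllm-rocm-perf and the --shm-size value are illustrative assumptions, not part of this PR:

docker build -f Dockerfile.rocm.perf.benchmark -t vllm-rocm-perf .
docker run -it --rm --device=/dev/kfd --device=/dev/dri --group-add video \
    --shm-size 16G vllm-rocm-perf bash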
14 changes: 14 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -16,6 +16,19 @@ def main(args: argparse.Namespace):

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
    # Navi exception: cap the context length for very large GPTQ models on GPUs with limited VRAM.
    if args.model in ("TheBloke/Llama-2-70B-Chat-GPTQ",
                      "MaziyarPanahi/Meta-Llama-3-70B-Instruct-GPTQ"):
        total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        if total_mem_gb < 40:
            effective_max_model_len = 2000
        elif total_mem_gb < 50:  # e.g. a W7900 with 48 GB
            effective_max_model_len = 3000
        else:
            effective_max_model_len = None  # derived from the model config
    else:
        effective_max_model_len = None

    print("effective_max_model_len", effective_max_model_len)

llm = LLM(model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
@@ -30,6 +43,7 @@ def main(args: argparse.Namespace):
worker_use_ray=args.worker_use_ray,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
max_model_len=effective_max_model_len,
block_size=args.block_size)

sampling_params = SamplingParams(
277 changes: 277 additions & 0 deletions benchmarks/benchmark_latency_vllm_DLM.py
@@ -0,0 +1,277 @@
"""TODO: put verbose info of libs device"""
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
import csv

def main(args: argparse.Namespace):
print(args)

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
worker_use_ray=args.worker_use_ray,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)

sampling_params_prefill = SamplingParams(
n=args.n,
#temperature=0.0 if args.use_beam_search else 1.0,
temperature=0.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=1,
)
print(sampling_params_prefill)

sampling_params_decoding = SamplingParams(
n=args.n,
#temperature=0.0 if args.use_beam_search else 1.0,
temperature=0.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)
print(sampling_params_decoding)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()

def run_to_completion(profile_dir: Optional[str] = None, sampling_params=sampling_params_decoding):
if profile_dir:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)

            # Write the trace into the requested profile directory (matches the message printed by the caller).
            Path(profile_dir).mkdir(parents=True, exist_ok=True)
            p.export_chrome_trace(str(Path(profile_dir) / "prof_pt.json"))
            print(p.key_averages())
else:
start_time = time.perf_counter()
llm_out = llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
latency = end_time - start_time
if args.out_token_num:
for request_output in llm_out:
out_tkn_ids = [output.token_ids for output in request_output.outputs]
print("INFO: number of tokens", len(out_tkn_ids[0]))
return latency

print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None, sampling_params=sampling_params_decoding)

if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = Path(
"."
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir, sampling_params=sampling_params_decoding)
return

# Benchmark.
latencies_prefill = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies_prefill.append(run_to_completion(profile_dir=None, sampling_params=sampling_params_prefill))
latencies_prefill = np.array(latencies_prefill)
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies_prefill, percentages)
    latencies_prefill = np.mean(latencies_prefill) * 1000  # mean prefill latency in ms
    print(f'Avg prefill latency: {latencies_prefill} ms')


latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None, sampling_params=sampling_params_decoding))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90]
    latencies = np.mean(latencies) * 1000  # mean end-to-end (prefill + decode) latency in ms
    print(f'Avg prefill+decoding latency: {latencies} ms')

    # These percentiles were computed over the prefill-only iterations above.
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile prefill latency: {percentile} seconds')
sep = " , "


    # Output tokens per second; `latencies` is the mean end-to-end latency in ms.
    throughput = 1000 / np.mean(latencies) * args.batch_size * args.output_len
    # Per-token decode latency in ms: end-to-end minus prefill, averaged over the generated tokens.
    latencies_decoding = (latencies - latencies_prefill) / (args.output_len - 1)

with open(args.csv, mode='a') as csv_latency:
        csv_latency.write("model, prefill/decoding, latency, throughput, tp, batch, input, output\n")
prefill_csv = args.model +sep+ \
"PREFILL" +sep+ \
str(latencies_prefill) +sep+\
str(0) +sep+\
str(args.tensor_parallel_size) +sep+\
str(args.batch_size) +sep+\
str(args.input_len) +sep+\
str(args.output_len) +"\n"

decoding_csv = args.model +sep+ \
"DECODING" +sep+ \
str(latencies_decoding) +sep+\
str(throughput) +sep+\
str(args.tensor_parallel_size) +sep+\
str(args.batch_size) +sep+\
str(args.input_len) +sep+\
str(args.output_len) +"\n"

print(prefill_csv)
print(decoding_csv)
csv_latency.write(prefill_csv)
csv_latency.write(decoding_csv)
if args.verbose:
            # _cuda_getCompiledVersion() returns an int, so cast it before concatenating.
            csv_latency.write("torch: " + torch.__version__ + "\n")
            csv_latency.write("rocm/cuda: " + str(torch._C._cuda_getCompiledVersion()) + "\n")
            csv_latency.write("device: " + torch.cuda.get_device_name(0) + "\n")

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=5,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=5,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--profile',
action='store_true',
help='profile the generation process of a single batch')
parser.add_argument(
'--profile-result-dir',
type=str,
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If set, prefill requests can be chunked based on the '
        'max_num_batched_tokens')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
help="If specified, use nsight to profile ray workers",
)
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
    parser.add_argument(
        "--csv",
        type=str,
        required=True,
        help="CSV file to append the benchmark results to")
    parser.add_argument(
        "--verbose",
        action='store_true',
        help="also record torch, ROCm/CUDA, and device info in the CSV")
    parser.add_argument(
        "--out-token-num",
        action='store_true',
        help="print the number of generated tokens per request")
args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion csrc/attention/attention_utils.cuh
@@ -33,7 +33,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
}

// Finalize the reduction across lanes.
4 changes: 2 additions & 2 deletions requirements-common.txt
@@ -2,7 +2,7 @@ cmake >= 3.21
ninja # For faster builds.
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy
numpy==1.26.4
requests
py-cpuinfo
transformers >= 4.39.1 # Required for StarCoder2 & Llava.
@@ -11,4 +11,4 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0
#outlines == 0.0.34 # Requires torch >= 2.1.0
4 changes: 4 additions & 0 deletions run.sh
@@ -0,0 +1,4 @@
torchrun --standalone --nnodes 1 --nproc-per-node 8 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 8 -tp 8 --input-len 128 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 8 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 8 -tp 8 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 8 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 32 -tp 8 --input-len 128 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 8 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 32 -tp 8 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
6 changes: 6 additions & 0 deletions run_batch1to32sweep.sh
@@ -0,0 +1,6 @@
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 1 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 2 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 4 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 8 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 16 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
torchrun --standalone --nnodes 1 --nproc-per-node 4 /workspace/vllm/benchmarks/benchmark_latency_vllm_DLM.py --model NousResearch/Meta-Llama-3-70B --batch-size 32 -tp 4 --input-len 2048 --output-len 128 --num-iters-warmup 5 --num-iters 5 --trust-remote-code --dtype half --csv result.csv
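
After the sweep, the appended result.csv can be summarized with pandas (installed in the Docker image above). This is a minimal sketch, not part of the PR: the column names are illustrative, and data rows are selected by their PREFILL/DECODING tag because the script rewrites its header line on every append.

from io import StringIO

import pandas as pd

# Illustrative column names, in the order written by benchmark_latency_vllm_DLM.py.
cols = ["model", "phase", "latency_ms", "throughput_tok_s",
        "tp", "batch", "input_len", "output_len"]

with open("result.csv") as f:
    rows = [line for line in f if " PREFILL " in line or " DECODING " in line]

# Fields are separated by " , ", so strip the surrounding whitespace while parsing.
df = pd.read_csv(StringIO("".join(rows)), sep=r"\s*,\s*",
                 engine="python", names=cols)

summary = df.groupby(["model", "phase", "batch", "input_len"])[
    ["latency_ms", "throughput_tok_s"]].mean()
print(summary)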