Skip to content

Commit

Permalink
Merge pull request #46 from vllm-project/main
Browse files Browse the repository at this point in the history
  • Loading branch information
openshift-merge-bot[bot] authored Jun 6, 2024
2 parents cf31bc0 + 89c9207 commit 5f62558
Show file tree
Hide file tree
Showing 47 changed files with 778 additions and 434 deletions.
26 changes: 26 additions & 0 deletions .buildkite/nightly-benchmarks/kickoff-pipeline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env bash

set -euo pipefail

# Install system packages
apt update
apt install -y curl jq

# Install minijinja for templating
curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh
source $HOME/.cargo/env

# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq
if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then
PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')

if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks."
else
echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks."
exit 0
fi
fi

# Upload sample.yaml
buildkite-agent pipeline upload .buildkite/nightly-benchmarks/sample.yaml
39 changes: 39 additions & 0 deletions .buildkite/nightly-benchmarks/sample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
steps:
# NOTE(simon): You can create separate blocks for different jobs
- label: "A100: NVIDIA SMI"
agents:
queue: A100
plugins:
- kubernetes:
podSpec:
containers:
# - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT
# TODO(simon): check latest main branch or use the PR image.
- image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
command:
- bash -c 'nvidia-smi && nvidia-smi topo -m && pwd && ls'
resources:
limits:
nvidia.com/gpu: 8
volumeMounts:
- name: devshm
mountPath: /dev/shm
nodeSelector:
nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
volumes:
- name: devshm
emptyDir:
medium: Memory
# TODO(simon): bring H100 online
# - label: "H100: NVIDIA SMI"
# agents:
# queue: H100
# plugins:
# - docker#v5.11.0:
# image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
# command:
# - bash -c 'nvidia-smi && nvidia-smi topo -m'
# propagate-environment: true
# ipc: host
# gpus: all

2 changes: 1 addition & 1 deletion .buildkite/run-benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
echo '```' >> benchmark_results.md
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines
echo '```' >> benchmark_results.md

# if the agent binary is not found, skip uploading the results, exit 0
Expand Down
12 changes: 7 additions & 5 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py

- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
Expand Down Expand Up @@ -93,14 +93,13 @@ steps:
- label: Models Test
#mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py
- pytest -v -s models -m \"not llava\"

- label: Llava Test
mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py
- pytest -v -s models -m llava

- label: Prefix Caching Test
mirror_hardwares: [amd]
Expand All @@ -124,7 +123,10 @@ steps:

- label: Speculative decoding tests
#mirror_hardwares: [amd]
command: pytest -v -s spec_decode
commands:
# See https://github.com/vllm-project/vllm/issues/5152
- export VLLM_ATTENTION_BACKEND=XFORMERS
- pytest -v -s spec_decode

- label: LoRA Test %N
#mirror_hardwares: [amd]
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ Easy, fast, and cheap LLM serving for everyone

---

**Ray Summit CPF is Open (June 4th to June 20th)!**

There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year!
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals.
This will be a great chance for everyone in the community to get together and learn.
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)

**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**

We are thrilled to announce our fourth vLLM Meetup!
Expand Down Expand Up @@ -104,6 +111,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Dropbox
- Lambda Lab
- NVIDIA
- Sequoia Capital
- Replicate
- Roblox
- RunPod
Expand Down
23 changes: 22 additions & 1 deletion benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class BenchmarkMetrics:
mean_tpot_ms: float
median_tpot_ms: float
p99_tpot_ms: float
mean_itl_ms: float
median_itl_ms: float
p99_itl_ms: float


def sample_sharegpt_requests(
Expand Down Expand Up @@ -200,16 +203,24 @@ def calculate_metrics(
actual_output_lens = []
total_input = 0
completed = 0
itls = []
tpots = []
ttfts = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = len(tokenizer(outputs[i].generated_text).input_ids)
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note: this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
total_input += input_requests[i][1]
if output_len > 1:
tpots.append(
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
completed += 1
else:
Expand All @@ -234,6 +245,9 @@ def calculate_metrics(
mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
)

return metrics, actual_output_lens
Expand Down Expand Up @@ -333,6 +347,10 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("=" * 50)

result = {
Expand All @@ -349,6 +367,9 @@ async def benchmark(
"mean_tpot_ms": metrics.mean_tpot_ms,
"median_tpot_ms": metrics.median_tpot_ms,
"p99_tpot_ms": metrics.p99_tpot_ms,
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def main(args: argparse.Namespace):

if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]
else:
batch_sizes = [args.batch_size]
Expand Down
105 changes: 70 additions & 35 deletions csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,44 @@ using namespace cute;

namespace {

template <typename Arch, typename ElementAB_, typename ElementD_,
typename TileShape, typename WarpShape, typename InstructionShape,
int32_t MainLoopStages>
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};

template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
Expand Down Expand Up @@ -101,7 +136,7 @@ struct cutlass_2x_gemm {
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
Expand All @@ -112,7 +147,7 @@ struct cutlass_2x_gemm {
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel;
>::GemmKernel>;
// clang-format on

using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
Expand Down Expand Up @@ -208,16 +243,16 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm75, int8_t, cutlass::half_t, TileShape,
WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 2>>(out, a, b, a_scales,
b_scales);
}
}

Expand All @@ -235,16 +270,16 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm80, int8_t, cutlass::half_t, TileShape,
WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
}

Expand All @@ -263,32 +298,32 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<
cutlass_2x_gemm<cutlass::arch::Sm89, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::bfloat16_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_scaled_mm_dq_dispatcher<cutlass_2x_gemm<
cutlass::arch::Sm89, cutlass::float_e4m3_t, cutlass::half_t,
TileShape, WarpShape, InstructionShape, 5>>(out, a, b, a_scales,
b_scales);
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
}
}
}
Loading

0 comments on commit 5f62558

Please sign in to comment.