Fixes
njhill authored and joerunde committed Mar 11, 2024
1 parent 1810585 commit ab322bb
Showing 5 changed files with 200 additions and 28 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -98,6 +98,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY --from=build /workspace/vllm/thirdparty_files /workspace/vllm/thirdparty_files
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
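
The runtime image now carries the vendored third-party sources under /workspace/vllm/thirdparty_files next to the compiled *.so extensions. As a rough sketch only (the repository's actual wiring, e.g. a PYTHONPATH setting elsewhere in the Dockerfile, may differ), a vendored-sources directory like this can be exposed to the interpreter by prepending it to sys.path; the path constant is taken from the COPY line above, everything else is an assumption.

# Illustrative sketch only: exposing a vendored-sources directory at runtime.
# The directory path matches the COPY destination above; the sys.path wiring
# itself is an assumption, not code from this commit.
import os
import sys

THIRDPARTY_DIR = "/workspace/vllm/thirdparty_files"

if os.path.isdir(THIRDPARTY_DIR) and THIRDPARTY_DIR not in sys.path:
    # Prepend so the vendored copies take precedence over site-packages.
    sys.path.insert(0, THIRDPARTY_DIR)
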
44 changes: 32 additions & 12 deletions benchmarks/kernels/benchmark_mixtral_moe.py
@@ -12,22 +12,39 @@

def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs, method=method)


def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100
best_configs = {}

for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
best_configs.update(
run_grid(bs=bs,
method=method,
d_model=d_model,
num_total_experts=num_total_experts,
top_k=top_k,
tp_size=tp_size,
model_intermediate_size=model_intermediate_size,
num_layers=num_layers))

device_name = torch.cuda.get_device_name().replace(" ", "_")
filename = f"E={num_total_experts},N={model_intermediate_size//tp_size},device_name={device_name}.json"
print(f"writing combined configs to file {filename}")
with open(filename, 'w') as fd:
json.dump(best_configs, fd, indent=4)


def run_grid(bs: int, method, d_model: int, num_total_experts: int, top_k: int,
tp_size: int, model_intermediate_size: int,
             num_layers: int) -> dict:
num_calls = 100
num_warmup_trials = 1
num_trials = 1

@@ -64,7 +81,7 @@ def run_grid(bs, method):
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print(f'warming up')
print('warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
@@ -82,7 +99,7 @@ def run_grid(bs, method):
continue

# trial
print(f'benchmarking')
print('benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
@@ -109,11 +126,14 @@ def run_grid(bs, method):

print("best_time_us", best_time_us)
print("best_config", best_config)
bs_best_config = {str(bs): best_config}

filename = "/tmp/config.jsonl"
print(f"writing config to file {filename}")
with open(filename, "a") as f:
f.write(json.dumps({str(bs): best_config}) + "\n")
f.write(json.dumps(bs_best_config) + "\n")

return bs_best_config


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
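
main() now aggregates the per-batch-size results returned by run_grid() and writes them to a single combined file named E=<experts>,N=<intermediate size / tp>,device_name=<gpu>.json, keyed by batch size, in addition to the per-run /tmp/config.jsonl append. A hypothetical consumer of that combined file might look like the sketch below; load_best_config and its nearest-key fallback are illustrative assumptions, not the loader vLLM itself uses.

# Hypothetical reader for the combined config file written by main() above.
# The file maps a benchmarked batch size (as a string key) to a Triton kernel
# config dict. This helper and its nearest-key fallback are assumptions made
# for illustration only.
import json

def load_best_config(path: str, batch_size: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Fall back to the closest benchmarked batch size when there is no exact match.
    key = min(configs, key=lambda k: abs(k - batch_size))
    return configs[key]

# e.g. cfg = load_best_config("E=8,N=7168,device_name=<gpu>.json", 24)
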
36 changes: 20 additions & 16 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -11,7 +11,6 @@
from grpc.aio import ServicerContext
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.logger import init_logger
from vllm.config import ModelConfig
from vllm.entrypoints.grpc.pb import generation_pb2_grpc
@@ -21,6 +20,8 @@
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.sampling_params import LogitsProcessor
from vllm.tgis_utils.logits_processors import MinTokensLogitsProcessor, TypicalLogitsWarperWrapper
from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Logprob
from vllm import AsyncLLMEngine, SamplingParams, RequestOutput, CompletionOutput

logger = init_logger(__name__)
@@ -398,7 +399,7 @@ def _convert_reason(output: CompletionOutput, max_is_token_limit: bool,
def _convert_tokens(
self,
token_ids: list[int],
logprobs_list: Optional[list[Dict[int, float]]],
logprobs_list: Optional[list[Dict[int, Logprob]]],
include_logprobs: bool,
top_n_tokens: int,
token_infos: MutableSequence[TokenInfo], # OUT
@@ -414,20 +415,23 @@ def _convert_tokens(
token_info = TokenInfo(text=text)
if logprobs_list is not None:
logprobs = logprobs_list[i]
if include_logprobs:
token_info.logprob = logprobs[token_ids[i]]
if top_n_tokens:
items = sorted(logprobs.items(),
key=lambda item: item[1],
reverse=True)[:top_n_tokens]
#TODO later use get_lora_tokenizer here
tt_texts = self.tokenizer.convert_ids_to_tokens(
[tid for tid, _ in items])
token_info.top_tokens.extend(
TokenInfo.TopToken(
text=tt_text,
logprob=logprob,
) for tt_text, (_, logprob) in zip(tt_texts, items))
# Logprobs entry will be None for first prompt token
if logprobs is not None:
if include_logprobs:
token_info.logprob = logprobs[token_ids[i]].logprob
if top_n_tokens:
items = sorted(logprobs.items(),
key=lambda item: item[1].logprob,
reverse=True)[:top_n_tokens]
#TODO later use get_lora_tokenizer here
tt_texts = self.tokenizer.convert_ids_to_tokens(
[tid for tid, _ in items])
token_info.top_tokens.extend(
TokenInfo.TopToken(
text=tt_text,
logprob=logprob.logprob,
)
for tt_text, (_, logprob) in zip(tt_texts, items))
token_infos.append(token_info)

async def _validate_prompt_and_tokenize(
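
Two things change in the logprobs handling above: each logprobs entry is now a vllm.sequence.Logprob object rather than a bare float (hence the .logprob accesses), and the entry for the first prompt token is None and must be skipped. The self-contained sketch below mirrors that logic; the dataclass is a stand-in for the real Logprob class, of which only a .logprob field is assumed.

# Self-contained sketch of the top-n selection above. The dataclass is a
# stand-in for vllm.sequence.Logprob; only its .logprob field is assumed.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class Logprob:
    logprob: float

def top_n_tokens(logprobs: Optional[Dict[int, Logprob]],
                 n: int) -> List[Tuple[int, float]]:
    if logprobs is None:
        # e.g. the first prompt token, which has no logprobs entry
        return []
    items = sorted(logprobs.items(),
                   key=lambda item: item[1].logprob,
                   reverse=True)[:n]
    return [(token_id, lp.logprob) for token_id, lp in items]

print(top_n_tokens({1: Logprob(-0.1), 2: Logprob(-2.3), 3: Logprob(-0.5)}, 2))
# -> [(1, -0.1), (3, -0.5)]
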
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -175,6 +175,7 @@ async def validation_exception_handler(_, exc):
@app.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
return Response(status_code=200)


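With the added call, /health returns 200 only after the engine's own check_health() completes without raising, so the endpoint now reflects engine health rather than just process liveness. A minimal external probe, assuming the server's default host and port, could look like this:

# Minimal external probe for the /health endpoint. The host and port are
# assumptions (vLLM api_server defaults); adjust to the deployment.
import urllib.request

def is_healthy(base_url: str = "http://localhost:8000") -> bool:
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=5) as resp:
            return resp.status == 200
    except Exception:
        return False

if __name__ == "__main__":
    print("healthy" if is_healthy() else "unhealthy")
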
146 changes: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
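
The batch-size keys in this new config file line up with the bs values swept in benchmark_mixtral_moe.py above, and each value is one Triton launch configuration (block sizes, group size, warps, stages). A quick, purely illustrative consistency check, with the file path left as a placeholder:

# Illustrative sanity check: the keys of the generated config file should match
# the batch sizes swept by the benchmark. The path below is a placeholder.
import json

EXPECTED_BS = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
               2048, 3072, 4096]

with open("E=8,N=7168,device_name=<gpu>.json") as f:
    found = sorted(int(k) for k in json.load(f))

assert found == EXPECTED_BS, f"unexpected batch-size keys: {found}"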
