diff --git a/requirements-dev.txt b/requirements-dev.txt
index cf2bb9bef..15e3970a1 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -27,7 +27,7 @@ einops # required for MPT
 httpx
 peft
 requests
-ray
+#ray
 sentence-transformers # required for embedding
 
 # Benchmarking
diff --git a/tests/entrypoints/Makefile b/tests/entrypoints/Makefile
new file mode 100644
index 000000000..d3f51da47
--- /dev/null
+++ b/tests/entrypoints/Makefile
@@ -0,0 +1,15 @@
+gen-client:
+	# Compile protos
+	pip install grpcio-tools==1.60.0 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
+	mkdir pb || true
+	python -m grpc_tools.protoc -I../../proto --python_out=pb \
+		--grpc_python_out=pb --mypy_out=pb ../../proto/generation.proto
+	find pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch pb/__init__.py
+
+install: gen-client
+	pip install pip --upgrade
+	pip install -e . --no-cache-dir
+
+test:
+	pytest -sv test_server_tokenize_truncate.py
\ No newline at end of file
diff --git a/tests/entrypoints/test_server_tokenize_truncate.py b/tests/entrypoints/test_server_tokenize_truncate.py
new file mode 100644
index 000000000..4350d01ec
--- /dev/null
+++ b/tests/entrypoints/test_server_tokenize_truncate.py
@@ -0,0 +1,191 @@
+# Tests for the gRPC server tokenization endpoint (tokenize + truncate)
+import pytest
+# using Ray for overall ease of process management, parallel requests,
+# and debugging.
+import ray
+import grpc
+# to install pb, run Makefile to compile grpc protobuf
+from .pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+from ..utils import ServerRunner
+
+# Config. vars for gRPC
+SERVER = 'localhost'
+PORT = 8033
+
+# The tokenizer is exercised against the following model:
+MODEL_NAME = "facebook/opt-125m"
+
+@pytest.fixture(scope="module")
+def server():
+    ray.init()
+    server_runner = ServerRunner.remote([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16"
+    ])
+    ray.get(server_runner.ready.remote())
+    yield server_runner
+    ray.shutdown()
+
+# Fixture to create a gRPC stub for the GenerationService
+@pytest.fixture(scope="module")
+def grpc_stub():
+    channel = grpc.insecure_channel(f"{SERVER}:{PORT}")
+    stub = gpb2.GenerationServiceStub(channel)
+    yield stub
+    channel.close()
+
+# Test cases
+@pytest.mark.parametrize("test_case", [
+    {
+        "name": "Tokenize with offsets",
+        "request": {
+            "text": "The very long story is written",
+            "return_offsets": True,
+        },
+        "response": {
+            "tokenCount": 7,
+            "offsets": [
+                {"start": 0, "end": 0},
+                {"start": 0, "end": 3},
+                {"start": 3, "end": 8},
+                {"start": 8, "end": 13},
+                {"start": 13, "end": 19},
+                {"start": 19, "end": 22},
+                {"start": 22, "end": 30},
+            ],
+        },
+    },
+    {
+        "name": "Tokenize with tokens and offsets",
+        "request": {
+            "text": "The very long story is written",
+            "return_tokens": True,
+            "return_offsets": True,
+        },
+        "response": {
+            "tokenCount": 7,
+            "tokens": [
+                "</s>",
+                "The",
+                "Ġvery",
+                "Ġlong",
+                "Ġstory",
+                "Ġis",
+                "Ġwritten"
+            ],
+            "offsets": [
+                {"start": 0, "end": 0},
+                {"start": 0, "end": 3},
+                {"start": 3, "end": 8},
+                {"start": 8, "end": 13},
+                {"start": 13, "end": 19},
+                {"start": 19, "end": 22},
+                {"start": 22, "end": 30},
+            ],
+        },
+    },
+    {
+        "name": "Tokenize with tokens and truncation",
+        "request": {
+            "text": "The very long story is written by a very long story",
+            "return_tokens": True,
+            "truncate_input_tokens": 10,
+        },
+        "response": {
+            "tokenCount": 10,
+            "tokens": [
+                "Ġvery",
+                "Ġlong",
+                "Ġstory",
+                "Ġis",
+                "Ġwritten",
+                "Ġby",
+                "Ġa",
+                "Ġvery",
+                "Ġlong",
+                "Ġstory",
+            ],
+        },
+    },
+    {
+        "name": "Tokenize, trunc and offset for a request with no text message",
+        "request": {
+            "text": "",
+            "return_offsets": True,
+            "return_tokens": True,
+            "truncate_input_tokens": 10,
+        },
+        "response": {
+            "tokenCount": 1,
+            "tokens": [
+                "</s>"
+            ],
+        },
+    },
+    {
+        "name": "A request without text ('') and parameters",
+        "request": {
+            "text": ""
+        },
+        "response": {
+            "tokenCount": 1
+        },
+    },
+    {
+        "name": "A request without text (None) and parameters",
+        "request": {
+            "text": None
+        },
+        "response": {
+            "tokenCount": 1
+        },
+    },
+])
+def test_tokenization(server, grpc_stub, test_case):
+    """Test tokenization with the given test case."""
+    text = test_case['request']['text']
+    truncate_input_tokens = test_case['request'].get('truncate_input_tokens',
+                                                     None)
+
+    # Construct the request
+    batch = pb2.BatchedTokenizeRequest(
+        model_id="unused",
+        requests=[pb2.TokenizeRequest(text=text)],
+        return_tokens=test_case['request'].get('return_tokens', False),
+        return_offsets=test_case['request'].get('return_offsets', False),
+        truncate_input_tokens=truncate_input_tokens
+    )
+
+    try:
+        responses = grpc_stub.Tokenize(batch).responses
+    except grpc.RpcError as e:
+        # Print debug message in case of connection failure
+        print(f"Failed to connect to the gRPC server: {e}")
+        pytest.fail(f"gRPC call failed with error: {e}")
+
+    # Verify the response
+    expected_response = test_case['response']
+    for resp in responses:
+        assert resp.token_count == expected_response['tokenCount'],\
+            "Token count mismatch"
+        if 'tokens' in expected_response:
+            assert resp.tokens == expected_response['tokens'],\
+                "Tokens mismatch"
+        if 'offsets' in expected_response:
+            expected_offsets = expected_response['offsets']
+            assert len(resp.offsets) == len(expected_offsets),\
+                "Offset length mismatch"
+            for resp_offset, exp_offset in zip(resp.offsets, expected_offsets):
+                assert resp_offset.start == exp_offset.get('start', None),\
+                    "Start offset mismatch"
+                assert resp_offset.end == exp_offset.get('end', None),\
+                    "End offset mismatch"
+    print("Test case passed: ", test_case["name"])
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/vllm/entrypoints/grpc/grpc_server.py b/vllm/entrypoints/grpc/grpc_server.py
index 931b613af..5f1e23f72 100644
--- a/vllm/entrypoints/grpc/grpc_server.py
+++ b/vllm/entrypoints/grpc/grpc_server.py
@@ -559,33 +559,75 @@ async def _validate_prompt_and_tokenize(
         return input_ids, max_is_token_limit
 
     @log_rpc_handler_errors
-    async def Tokenize(self, request: BatchedTokenizeRequest,
-                       context: ServicerContext) -> BatchedTokenizeResponse:
+    async def Tokenize(
+        self, request: BatchedTokenizeRequest, context: ServicerContext
+    ) -> BatchedTokenizeResponse:
+        """
+        Handles tokenization requests by tokenizing input texts and
+        returning tokenized results. If request.truncate_input_tokens is
+        provided, the returned token counts, tokens and offsets are truncated.
+
+        Args:
+            request (BatchedTokenizeRequest): The tokenization request
+                containing texts to be tokenized.
+            context (ServicerContext): The context for the RPC call.
+
+        Returns:
+            BatchedTokenizeResponse: The response containing the
+                tokenized results.
+        """
+        # Log the incoming tokenization request for metrics
         service_metrics.observe_tokenization_request(request)
 
-        #TODO implement these
-        if request.return_offsets:
-            await context.abort(StatusCode.INVALID_ARGUMENT,
-                                "return_offsets not yet supported")
-        if request.truncate_input_tokens:
-            await context.abort(StatusCode.INVALID_ARGUMENT,
-                                "truncate_input_tokens not yet supported")
+        # Initialize an empty list to store individual tokenization responses
         responses: List[TokenizeResponse] = []
 
-        #TODO maybe parallelize, also move convert_ids_to_tokens
-        # into the other threads
+        # TODO: maybe parallelize, also move convert_ids_to_tokens into the
+        # other threads
         for req in request.requests:
-            token_ids = await self.tokenizer_group.encode_async(req.text)
-            responses.append(
-                TokenizeResponse(
-                    token_count=len(token_ids),
-                    tokens=None if not request.return_tokens else
-                    self.tokenizer.convert_ids_to_tokens(token_ids)))
-
+            # Tokenize the input text, requesting the offset mapping if needed
+            batch_encoding = self.tokenizer.encode_plus(
+                text=req.text,
+                return_offsets_mapping=request.return_offsets
+            )
+
+            token_ids = batch_encoding.input_ids
+            token_count = len(token_ids)
+
+            # Truncate the token count if truncate_input_tokens is set
+            if 1 <= request.truncate_input_tokens < token_count:
+                token_count = request.truncate_input_tokens
+
+            # Convert token ids to token strings
+            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+            offsets = None
+
+            # Compute character offsets if requested
+            if request.return_offsets:
+                offsets = [
+                    {'start': start, 'end': end}
+                    for start, end in batch_encoding.offset_mapping
+                    if start is not None and end is not None
+                ]
+                # Keep only the last token_count offsets after truncation
+                offsets = offsets[-token_count:]
+
+            # Return the token list, truncated if truncate_input_tokens is set
+            tokens = tokens[-token_count:] if request.return_tokens else None
+
+            # Append the response for the current request
+            responses.append(TokenizeResponse(token_count=token_count,
+                                              tokens=tokens,
+                                              offsets=offsets))
+
+        # Create a batched response containing all individual responses
         response = BatchedTokenizeResponse(responses=responses)
+
+        # Log the tokenization response for metrics
         service_metrics.observe_tokenization_response(response)
-        return response
+        # Return the batched tokenization response
+        return response
 
     @log_rpc_handler_errors
     async def ModelInfo(self, request: ModelInfoRequest,
                         context: ServicerContext) -> ModelInfoResponse:
@@ -596,7 +638,6 @@ async def ModelInfo(self, request: ModelInfoRequest,
             max_new_tokens=self.max_max_new_tokens,
         )
 
 
-
 async def start_grpc_server(engine: AsyncLLMEngine,
                             args: argparse.Namespace) -> aio.Server:
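
For reference, below is a minimal client sketch (not part of the diff above) showing how the new return_offsets and truncate_input_tokens options of the Tokenize RPC could be exercised once the pb stubs have been generated with the Makefile. The address, port, and model_id value mirror the test defaults and are assumptions, not requirements of the server.

# Hypothetical standalone client; assumes `make gen-client` has produced the
# pb package and that the gRPC server is listening on localhost:8033.
import grpc

from pb import generation_pb2 as pb2
from pb import generation_pb2_grpc as gpb2

channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)

# Request tokens and character offsets, truncated to the last 10 tokens.
batch = pb2.BatchedTokenizeRequest(
    model_id="unused",
    requests=[pb2.TokenizeRequest(text="The very long story is written")],
    return_tokens=True,
    return_offsets=True,
    truncate_input_tokens=10,
)

for resp in stub.Tokenize(batch).responses:
    print(resp.token_count, list(resp.tokens))
    for offset in resp.offsets:
        print(offset.start, offset.end)

channel.close()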