Add tokenization and truncation functionality to Tokenize() function. #39

Closed · wants to merge 4 commits
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -27,7 +27,7 @@ einops # required for MPT
httpx
peft
requests
ray
#ray
sentence-transformers # required for embedding

# Benchmarking
15 changes: 15 additions & 0 deletions tests/entrypoints/Makefile
@@ -0,0 +1,15 @@
gen-client:
# Compile protos
pip install grpcio-tools==1.60.0 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
mkdir pb || true
python -m grpc_tools.protoc -I../../proto --python_out=pb \
--grpc_python_out=pb --mypy_out=pb ../../proto/generation.proto
find pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch pb/__init__.py

install: gen-client
pip install pip --upgrade
pip install -e . --no-cache-dir

test:
pytest -sv test_server_tokenize_truncate.py
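
For reference, a minimal sketch (not part of this PR) of how the stubs produced by the gen-client target could be exercised by hand. It assumes the Makefile above has already been run from tests/entrypoints so the pb package exists, and that a server is listening on localhost:8033 as in the test below.

import grpc
# Generated by the gen-client target above; run from tests/entrypoints/
from pb import generation_pb2 as pb2, generation_pb2_grpc as gpb2

# Open a plaintext channel to the locally running gRPC server
channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)

# Tokenize one text and ask for the token strings back
batch = pb2.BatchedTokenizeRequest(
    model_id="unused",
    requests=[pb2.TokenizeRequest(text="The very long story is written")],
    return_tokens=True,
)
for resp in stub.Tokenize(batch).responses:
    print(resp.token_count, list(resp.tokens))

channel.close()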
191 changes: 191 additions & 0 deletions tests/entrypoints/test_server_tokenize_truncate.py
@@ -0,0 +1,191 @@
# imports for guided decoding tests
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import grpc
# to install pb, run Makefile to compile grpc protobuf
from .pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
from vllm.transformers_utils.tokenizer import get_tokenizer

from ..utils import ServerRunner

# Config. vars for gRPC
SERVER = 'localhost'
PORT = 8033

# The tokenizer was tested using the following model:
MODEL_NAME = "facebook/opt-125m"

@pytest.fixture(scope="module")
def server():
ray.init()
server_runner = ServerRunner.remote([
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16"
])
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()

# Fixture to create a gRPC stub for the GenerationService
@pytest.fixture(scope="module")
def grpc_stub():
channel = grpc.insecure_channel(f"{SERVER}:{PORT}")
stub = gpb2.GenerationServiceStub(channel)
yield stub
channel.close()

# Test cases
@pytest.mark.parametrize("test_case", [
{
"name": "Tokenize with offsets",
"request": {
"text": "The very long story is written",
"return_offsets": True,
},
"response": {
"tokenCount": 7,
"offsets": [
{"start": 0, "end": 0},
{"start": 0, "end": 3},
{"start": 3, "end": 8},
{"start": 8, "end": 13},
{"start": 13, "end": 19},
{"start": 19, "end": 22},
{"start": 22, "end": 30},
],
},
},
{
"name": "Tokenize with tokens and offsets",
"request": {
"text": "The very long story is written",
"return_tokens": True,
"return_offsets": True,
},
"response": {
"tokenCount": 7,
"tokens": [
"</s>",
"The",
"Ġvery",
"Ġlong",
"Ġstory",
"Ġis",
"Ġwritten"
],
"offsets": [
{"start": 0, "end": 0},
{"start": 0, "end": 3},
{"start": 3, "end": 8},
{"start": 8, "end": 13},
{"start": 13, "end": 19},
{"start": 19, "end": 22},
{"start": 22, "end": 30},
],
},
},
{
"name": "Tokenize with tokens and truncation",
"request": {
"text": "The very long story is written by a very long story",
"return_tokens": True,
"truncate_input_tokens": 10,
},
"response": {
"tokenCount": 10,
"tokens": [
"Ġvery",
"Ġlong",
"Ġstory",
"Ġis",
"Ġwritten",
"Ġby",
"Ġa",
"Ġvery",
"Ġlong",
"Ġstory",
],
},
},
{
"name": "Tokenize, trunc and offset for a request with no text message",
"request": {
"text": "",
"return_offsets": True,
"return_tokens": True,
"truncate_input_tokens": 10,
},
"response": {
"tokenCount": 1,
"tokens": [
"</s>"
],
},
},
{
"name": "A request without text ('') and parameters",
"request": {
"text" : ""
},
"response": {
"tokenCount": 1
},
},
{
"name": "A request without text (None) and parameters",
"request": {
"text" : None
},
"response": {
"tokenCount": 1
},
},
])
def test_tokenization(server, grpc_stub, test_case):
"""Test tokenization with the given test case."""
text = test_case['request']['text']
truncate_input_tokens = test_case['request'].get('truncate_input_tokens', None)

Reviewer comment (Contributor): Perhaps you can add a new variable request = test_case['request'] to eliminate some repetition. [A sketch of this refactor follows this file's diff.]

# Construct the request
batch = pb2.BatchedTokenizeRequest(
model_id="unused",
requests=[pb2.TokenizeRequest(text=text)],
return_tokens=test_case['request'].get('return_tokens', False),
return_offsets=test_case['request'].get('return_offsets', False),
truncate_input_tokens=truncate_input_tokens
)

try:
responses = grpc_stub.Tokenize(batch).responses
except grpc.RpcError as e:
# Print debug message in case of connection failure
print(f"Failed to connect to the gRPC server: {e}")
pytest.fail(f"gRPC call failed with error: {e}")

# Verify the response
expected_response = test_case['response']
Reviewer comment (Contributor): Better replace with test_case['response'][0] and eliminate the for loop below. [A sketch follows this file's diff.]

for resp in responses:
assert resp.token_count == expected_response['tokenCount'],\
"Token count mismatch"
if 'tokens' in expected_response:
assert resp.tokens == expected_response['tokens'],\
"Tokens mismatch"
if 'offsets' in expected_response:
expected_offsets = expected_response['offsets']
assert len(resp.offsets) == len(expected_offsets),\
"Offset length mismatch"
for resp_offset, exp_offset in zip(resp.offsets, expected_offsets):
assert resp_offset.start == exp_offset.get('start', None),\
"Start offset mismatch"
assert resp_offset.end == exp_offset.get('end', None),\
"End offset mismatch"
print("Test case passed: ", test_case["name"])

if __name__ == "__main__":
pytest.main([__file__])
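
Putting the two reviewer suggestions above together (a request variable to cut repetition, and asserting on the single response instead of looping), the verification in test_tokenization could be tightened roughly as sketched below. This is an illustrative rewrite, not code from the PR; all field and key names are taken from the test above.

def test_tokenization(server, grpc_stub, test_case):
    """Test tokenization with the given test case."""
    request = test_case['request']   # single lookup instead of repeated test_case['request']
    expected = test_case['response']

    batch = pb2.BatchedTokenizeRequest(
        model_id="unused",
        requests=[pb2.TokenizeRequest(text=request['text'])],
        return_tokens=request.get('return_tokens', False),
        return_offsets=request.get('return_offsets', False),
        truncate_input_tokens=request.get('truncate_input_tokens', None),
    )

    # Each batch carries exactly one TokenizeRequest, so assert on the single
    # response directly rather than looping over responses.
    resp = grpc_stub.Tokenize(batch).responses[0]
    assert resp.token_count == expected['tokenCount'], "Token count mismatch"
    if 'tokens' in expected:
        assert list(resp.tokens) == expected['tokens'], "Tokens mismatch"
    if 'offsets' in expected:
        assert [{'start': o.start, 'end': o.end} for o in resp.offsets] == \
            expected['offsets'], "Offsets mismatch"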
81 changes: 61 additions & 20 deletions vllm/entrypoints/grpc/grpc_server.py
@@ -559,33 +559,75 @@ async def _validate_prompt_and_tokenize(
return input_ids, max_is_token_limit

@log_rpc_handler_errors
async def Tokenize(self, request: BatchedTokenizeRequest,
context: ServicerContext) -> BatchedTokenizeResponse:
async def Tokenize(
self, request: BatchedTokenizeRequest, context: ServicerContext
) -> BatchedTokenizeResponse:
"""
Handles tokenization requests by tokenizing input texts and
returning tokenized results. If request.truncate_input_tokens is
provided, the tokenization will contain the truncated results.

Args:
request (BatchedTokenizeRequest): The tokenization request
containing texts to be tokenized.
context (ServicerContext): The context for the RPC call.

Returns:
BatchedTokenizeResponse: The response containing the
tokenized results.
"""
# Log the incoming tokenization request for metrics
service_metrics.observe_tokenization_request(request)
#TODO implement these
if request.return_offsets:
await context.abort(StatusCode.INVALID_ARGUMENT,
"return_offsets not yet supported")
if request.truncate_input_tokens:
await context.abort(StatusCode.INVALID_ARGUMENT,
"truncate_input_tokens not yet supported")

# Initialize an empty list to store individual tokenization responses
responses: List[TokenizeResponse] = []

#TODO maybe parallelize, also move convert_ids_to_tokens
# into the other threads
# TODO: maybe parallelize, also move convert_ids_to_tokens into the
# other threads
for req in request.requests:
token_ids = await self.tokenizer_group.encode_async(req.text)
responses.append(
TokenizeResponse(
token_count=len(token_ids),
tokens=None if not request.return_tokens else
self.tokenizer.convert_ids_to_tokens(token_ids)))

batch_encoding = self.tokenizer.encode_plus(
text=req.text,
return_offsets_mapping=request.return_offsets
) # Tokenize the input text and get offset_mapping

# Tokenize the input text async
Reviewer comment (Contributor): Comment doesn't match what the code does.

token_ids = batch_encoding.input_ids
token_count = len(token_ids)

# Truncate the token count if truncate_input_tokens
if 1 <= request.truncate_input_tokens < token_count:
Reviewer comment (Contributor): I think it would be easier to read 0 < request...tokens < token_count.

token_count = request.truncate_input_tokens

# Initialize Tokens from ids
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
offsets = None # Initialize offsets to None

# Offset calc. steps
if request.return_offsets:
offsets = [
{'start': start, 'end': end}
for start, end in batch_encoding.offset_mapping
if start is not None and end is not None
]
# Truncate offset list if request.truncate_input_tokens
offsets=offsets[-token_count:]

# Return a token list (Truncated if request.truncate_input_tokens)
tokens = tokens[-token_count:] if request.return_tokens else None

# Append the response for the current request
responses.append(TokenizeResponse(token_count=token_count,
tokens=tokens,
offsets=offsets))

# Create a batched response containing all individual responses
response = BatchedTokenizeResponse(responses=responses)

# Log the current tokenization response for metrics
service_metrics.observe_tokenization_response(response)
return response

# Return the batched tokenization response
Reviewer comment (Contributor): I think you can remove most of these comments because it's fairly obvious what the code does.

return response
@log_rpc_handler_errors
async def ModelInfo(self, request: ModelInfoRequest,
context: ServicerContext) -> ModelInfoResponse:
@@ -596,7 +638,6 @@ async def ModelInfo(self, request: ModelInfoRequest,
max_new_tokens=self.max_max_new_tokens,
)


async def start_grpc_server(engine: AsyncLLMEngine,
args: argparse.Namespace) -> aio.Server:

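
To make the truncation behaviour of the new Tokenize handler concrete, the following standalone sketch reproduces the same slicing of tokens and offsets outside the server. It assumes the transformers library and the facebook/opt-125m tokenizer used in the tests; it illustrates the logic in the handler above and is not part of the PR.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

text = "The very long story is written by a very long story"
truncate_input_tokens = 10

# Same call the handler makes: token ids plus character offsets
batch_encoding = tokenizer.encode_plus(text, return_offsets_mapping=True)
token_ids = batch_encoding.input_ids
token_count = len(token_ids)

# Truncate the reported count only when a positive limit is below the count
if 1 <= truncate_input_tokens < token_count:
    token_count = truncate_input_tokens

tokens = tokenizer.convert_ids_to_tokens(token_ids)
offsets = [{"start": start, "end": end}
           for start, end in batch_encoding.offset_mapping]

# Keeping the last token_count entries drops tokens from the left, which is
# why the truncated test case above starts at "Ġvery" rather than "</s>"
print(tokens[-token_count:])
print(offsets[-token_count:])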