Add tokenization and truncation functionality to Tokenize() function. #39
Changes from all commits: 040a69a, 570f7aa, 9b93eef, 1f3e41e
@@ -0,0 +1,15 @@
gen-client:
	# Compile protos
	pip install grpcio-tools==1.60.0 mypy-protobuf==3.5.0 'types-protobuf>=3.20.4' --no-cache-dir
	mkdir pb || true
	python -m grpc_tools.protoc -I../../proto --python_out=pb \
		--grpc_python_out=pb --mypy_out=pb ../../proto/generation.proto
	find pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch pb/__init__.py

install: gen-client
	pip install pip --upgrade
	pip install -e . --no-cache-dir

test:
	pytest -sv test_server_tokenize_truncate.py
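
For orientation, a minimal sketch of how the stubs generated by the gen-client target can be exercised against a running server. It reuses the service, message, and field names from the test below and assumes a server is listening on localhost:8033 as configured there; this is an illustration, not part of the change.

import grpc
# generated by `make gen-client` into the pb/ package
from pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2

# Open a plaintext channel to the gRPC server (host/port assumed, as in the test)
channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)

# Batch a single tokenize request asking for tokens and character offsets
batch = pb2.BatchedTokenizeRequest(
    model_id="unused",
    requests=[pb2.TokenizeRequest(text="The very long story is written")],
    return_tokens=True,
    return_offsets=True,
)
for resp in stub.Tokenize(batch).responses:
    print(resp.token_count, list(resp.tokens))

channel.close()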
@@ -0,0 +1,191 @@
# imports for guided decoding tests
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import grpc
# to install pb, run Makefile to compile grpc protobuf
from .pb import generation_pb2_grpc as gpb2, generation_pb2 as pb2
from vllm.transformers_utils.tokenizer import get_tokenizer

from ..utils import ServerRunner

# Config. vars for gRPC
SERVER = 'localhost'
PORT = 8033

# The tokenizer was tested using the following model:
MODEL_NAME = "facebook/opt-125m"


@pytest.fixture(scope="module")
def server():
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16"
    ])
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()


# Fixture to create a gRPC stub for the GenerationService
@pytest.fixture(scope="module")
def grpc_stub():
    channel = grpc.insecure_channel(f"{SERVER}:{PORT}")
    stub = gpb2.GenerationServiceStub(channel)
    yield stub
    channel.close()


# Test cases
@pytest.mark.parametrize("test_case", [
    {
        "name": "Tokenize with offsets",
        "request": {
            "text": "The very long story is written",
            "return_offsets": True,
        },
        "response": {
            "tokenCount": 7,
            "offsets": [
                {"start": 0, "end": 0},
                {"start": 0, "end": 3},
                {"start": 3, "end": 8},
                {"start": 8, "end": 13},
                {"start": 13, "end": 19},
                {"start": 19, "end": 22},
                {"start": 22, "end": 30},
            ],
        },
    },
    {
        "name": "Tokenize with tokens and offsets",
        "request": {
            "text": "The very long story is written",
            "return_tokens": True,
            "return_offsets": True,
        },
        "response": {
            "tokenCount": 7,
            "tokens": [
                "</s>",
                "The",
                "Ġvery",
                "Ġlong",
                "Ġstory",
                "Ġis",
                "Ġwritten"
            ],
            "offsets": [
                {"start": 0, "end": 0},
                {"start": 0, "end": 3},
                {"start": 3, "end": 8},
                {"start": 8, "end": 13},
                {"start": 13, "end": 19},
                {"start": 19, "end": 22},
                {"start": 22, "end": 30},
            ],
        },
    },
    {
        "name": "Tokenize with tokens and truncation",
        "request": {
            "text": "The very long story is written by a very long story",
            "return_tokens": True,
            "truncate_input_tokens": 10,
        },
        "response": {
            "tokenCount": 10,
            "tokens": [
                "Ġvery",
                "Ġlong",
                "Ġstory",
                "Ġis",
                "Ġwritten",
                "Ġby",
                "Ġa",
                "Ġvery",
                "Ġlong",
                "Ġstory",
            ],
        },
    },
    {
        "name": "Tokenize, trunc and offset for a request with no text message",
        "request": {
            "text": "",
            "return_offsets": True,
            "return_tokens": True,
            "truncate_input_tokens": 10,
        },
        "response": {
            "tokenCount": 1,
            "tokens": [
                "</s>"
            ],
        },
    },
    {
        "name": "A request without text ('') and parameters",
        "request": {
            "text": ""
        },
        "response": {
            "tokenCount": 1
        },
    },
    {
        "name": "A request without text (None) and parameters",
        "request": {
            "text": None
        },
        "response": {
            "tokenCount": 1
        },
    },
])
def test_tokenization(server, grpc_stub, test_case):
    """Test tokenization with the given test case."""
    text = test_case['request']['text']
    truncate_input_tokens = test_case['request'].get('truncate_input_tokens',
                                                     None)

    # Construct the request
    batch = pb2.BatchedTokenizeRequest(
        model_id="unused",
        requests=[pb2.TokenizeRequest(text=text)],
        return_tokens=test_case['request'].get('return_tokens', False),
        return_offsets=test_case['request'].get('return_offsets', False),
        truncate_input_tokens=truncate_input_tokens
    )

    try:
        responses = grpc_stub.Tokenize(batch).responses
    except grpc.RpcError as e:
        # Print debug message in case of connection failure
        print(f"Failed to connect to the gRPC server: {e}")
        pytest.fail(f"gRPC call failed with error: {e}")

    # Verify the response
    expected_response = test_case['response']

Reviewer comment: Better replace with

    for resp in responses:
        assert resp.token_count == expected_response['tokenCount'], \
            "Token count mismatch"
        if 'tokens' in expected_response:
            assert resp.tokens == expected_response['tokens'], \
                "Tokens mismatch"
        if 'offsets' in expected_response:
            expected_offsets = expected_response['offsets']
            assert len(resp.offsets) == len(expected_offsets), \
                "Offset length mismatch"
            for resp_offset, exp_offset in zip(resp.offsets, expected_offsets):
                assert resp_offset.start == exp_offset.get('start', None), \
                    "Start offset mismatch"
                assert resp_offset.end == exp_offset.get('end', None), \
                    "End offset mismatch"
    print("Test case passed: ", test_case["name"])


if __name__ == "__main__":
    pytest.main([__file__])
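
As a reading aid, the expected offsets in the cases above are character spans into the request text, with the leading {"start": 0, "end": 0} belonging to the special "</s>" token. A quick sanity check of that interpretation, using the same fixture text:

text = "The very long story is written"
offsets = [
    {"start": 0, "end": 0},    # "</s>" special token, empty span
    {"start": 0, "end": 3},    # "The"
    {"start": 3, "end": 8},    # " very"
    {"start": 8, "end": 13},   # " long"
    {"start": 13, "end": 19},  # " story"
    {"start": 19, "end": 22},  # " is"
    {"start": 22, "end": 30},  # " written"
]
for o in offsets:
    # Each span slices the original text back out of the input string
    print(repr(text[o["start"]:o["end"]]))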
@@ -559,33 +559,75 @@ async def _validate_prompt_and_tokenize(
        return input_ids, max_is_token_limit

    @log_rpc_handler_errors
    async def Tokenize(self, request: BatchedTokenizeRequest,
                       context: ServicerContext) -> BatchedTokenizeResponse:
    async def Tokenize(
        self, request: BatchedTokenizeRequest, context: ServicerContext
    ) -> BatchedTokenizeResponse:
        """
        Handles tokenization requests by tokenizing input texts and
        returning tokenized results. If request.truncate_input_tokens is
        provided, the tokenization will contain the truncated results.

        Args:
            request (BatchedTokenizeRequest): The tokenization request
                containing texts to be tokenized.
            context (ServicerContext): The context for the RPC call.

        Returns:
            BatchedTokenizeResponse: The response containing the
                tokenized results.
        """
        # Log the incoming tokenization request for metrics
        service_metrics.observe_tokenization_request(request)
        # TODO implement these
        if request.return_offsets:
            await context.abort(StatusCode.INVALID_ARGUMENT,
                                "return_offsets not yet supported")
        if request.truncate_input_tokens:
            await context.abort(StatusCode.INVALID_ARGUMENT,
                                "truncate_input_tokens not yet supported")

        # Initialize an empty list to store individual tokenization responses
        responses: List[TokenizeResponse] = []

        # TODO maybe parallelize, also move convert_ids_to_tokens
        # into the other threads
        # TODO: maybe parallelize, also move convert_ids_to_tokens into the
        # other threads
        for req in request.requests:
            token_ids = await self.tokenizer_group.encode_async(req.text)
            responses.append(
                TokenizeResponse(
                    token_count=len(token_ids),
                    tokens=None if not request.return_tokens else
                    self.tokenizer.convert_ids_to_tokens(token_ids)))

            batch_encoding = self.tokenizer.encode_plus(
                text=req.text,
                return_offsets_mapping=request.return_offsets
            )  # Tokenize the input text and get offset_mapping

            # Tokenize the input text async

Reviewer comment: Comment doesn't match what the code does.

            token_ids = batch_encoding.input_ids
            token_count = len(token_ids)

            # Truncate the token count if truncate_input_tokens
            if 1 <= request.truncate_input_tokens < token_count:

Reviewer comment: I think it would be easier to read

                token_count = request.truncate_input_tokens

            # Initialize tokens from ids
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
            offsets = None  # Initialize offsets to None

            # Offset calc. steps
            if request.return_offsets:
                offsets = [
                    {'start': start, 'end': end}
                    for start, end in batch_encoding.offset_mapping
                    if start is not None and end is not None
                ]
                # Truncate offset list if request.truncate_input_tokens
                offsets = offsets[-token_count:]

            # Return a token list (truncated if request.truncate_input_tokens)
            tokens = tokens[-token_count:] if request.return_tokens else None

            # Append the response for the current request
            responses.append(TokenizeResponse(token_count=token_count,
                                              tokens=tokens,
                                              offsets=offsets))

        # Create a batched response containing all individual responses
        response = BatchedTokenizeResponse(responses=responses)

        # Log the current tokenization response for metrics
        service_metrics.observe_tokenization_response(response)
        return response

        # Return the batched tokenization response

Reviewer comment: I think you can remove most of these comments because it's fairly obvious what the code does.

        return response

    @log_rpc_handler_errors
    async def ModelInfo(self, request: ModelInfoRequest,
                        context: ServicerContext) -> ModelInfoResponse:

@@ -596,7 +638,6 @@ async def ModelInfo(self, request: ModelInfoRequest,
            max_new_tokens=self.max_max_new_tokens,
        )


async def start_grpc_server(engine: AsyncLLMEngine,
                            args: argparse.Namespace) -> aio.Server:
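
To see the handler's truncation logic in isolation: when truncate_input_tokens falls between 1 and the token count, only the trailing token_count tokens and offsets are kept. A standalone sketch of that behaviour using the Hugging Face tokenizer for the model from the test; it mirrors the handler's slicing but is not the server code itself.

from transformers import AutoTokenizer

# Tokenizer assumed to match the test model; the server uses its own self.tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

text = "The very long story is written by a very long story"
batch_encoding = tokenizer.encode_plus(text=text, return_offsets_mapping=True)

token_ids = batch_encoding.input_ids
token_count = len(token_ids)

truncate_input_tokens = 10
if 1 <= truncate_input_tokens < token_count:
    token_count = truncate_input_tokens

# Keep only the trailing token_count entries, as the handler does
tokens = tokenizer.convert_ids_to_tokens(token_ids)[-token_count:]
offsets = [{"start": s, "end": e}
           for s, e in batch_encoding.offset_mapping][-token_count:]

print(token_count)  # 10
print(tokens)       # the last 10 tokens, matching the truncation test case above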
Reviewer comment: Perhaps you can add a new variable request = test_case['request'] to eliminate some repetition.
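
A possible shape for that refactor of test_tokenization, purely illustrative: it only pulls the request dict into a local variable, and everything else stays as in the test above.

def test_tokenization(server, grpc_stub, test_case):
    """Test tokenization with the given test case."""
    # Pull the request and expected response out once to avoid repeated lookups
    request = test_case['request']
    expected_response = test_case['response']

    batch = pb2.BatchedTokenizeRequest(
        model_id="unused",
        requests=[pb2.TokenizeRequest(text=request['text'])],
        return_tokens=request.get('return_tokens', False),
        return_offsets=request.get('return_offsets', False),
        truncate_input_tokens=request.get('truncate_input_tokens', None),
    )
    # ... rest of the test body unchanged ...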