diff --git a/examples/inference-deployments/mpt/mpt_handler.py b/examples/inference-deployments/mpt/mpt_handler.py
index cf52bf70b..b7c23bbf7 100644
--- a/examples/inference-deployments/mpt/mpt_handler.py
+++ b/examples/inference-deployments/mpt/mpt_handler.py
@@ -13,7 +13,8 @@ class MPTModelHandler():
     DEFAULT_GENERATE_KWARGS = {
-        'max_length': 256,
+        'max_length': 256,  # Counts input + output tokens (deprecated)
+        'max_new_tokens': 256,  # Only counts output tokens
         'use_cache': True,
         'do_sample': True,
         'top_p': 0.95,
@@ -76,7 +77,7 @@ def predict(self, model_requests: List[Dict]):
             model_requests: List of dictionaries that contain forward pass inputs as well
                 as other parameters, such as generate kwargs.
-                ex. [{'input': 'hello world!', 'parameters': {'max_length': 10}]
+                ex. [{'input': 'hello world!', 'parameters': {'max_new_tokens': 10}}]
         """
         generate_inputs = []
         generate_kwargs = {}
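
For context, here is a minimal sketch (not part of the PR) contrasting the two generate kwargs the diff touches. It uses `'gpt2'` as a small stand-in model purely so the snippet runs without MPT weights; the handler itself loads an MPT checkpoint.

```python
# Sketch, not part of this PR: contrasts `max_length` with `max_new_tokens`.
# 'gpt2' is a placeholder model so the example runs without MPT weights.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('hello world!', return_tensors='pt')
prompt_len = inputs['input_ids'].shape[-1]

# `max_length` caps prompt + generated tokens, so a long prompt eats into
# the generation budget; Hugging Face recommends `max_new_tokens` instead.
out_total = model.generate(**inputs, max_length=256)

# `max_new_tokens` caps only the generated continuation, independent of
# prompt length -- this is what the handler's default kwargs now set.
out_new = model.generate(**inputs, max_new_tokens=256)

print(out_total.shape[-1])             # <= 256 tokens including the prompt
print(out_new.shape[-1] - prompt_len)  # <= 256 newly generated tokens
```

With the old default, a prompt close to 256 tokens leaves almost no room for generated text; `max_new_tokens` keeps the output budget fixed regardless of prompt length, which is why the docstring's request example now uses it as well.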