diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 9028178a..51ab5384 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -16,12 +16,10 @@ OpenAIModelInfo, RerankInput, ) -from infinity_emb.inference import ( - DeviceTypeHint, -) +from infinity_emb.inference import Device, DeviceTypeHint from infinity_emb.inference.caching_layer import INFINITY_CACHE_VECTORS from infinity_emb.log_handler import UVICORN_LOG_LEVELS, logger -from infinity_emb.transformer.utils import InferenceEngineTypeHint +from infinity_emb.transformer.utils import InferenceEngine, InferenceEngineTypeHint def create_server( @@ -244,10 +242,10 @@ def _start_uvicorn( batch_size=batch_size, revision=revision, trust_remote_code=trust_remote_code, - engine=engine, # type: ignore + engine=InferenceEngine[engine.value], # type: ignore model_warmup=model_warmup, vector_disk_cache_path=vector_disk_cache_path, - device=device, # type: ignore + device=Device[device.value], # type: ignore lengths_via_tokenize=lengths_via_tokenize, )