Skip to content

Commit

Permalink
add args
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed May 31, 2024
1 parent d83420a commit ccf8b47
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 53 deletions.
10 changes: 5 additions & 5 deletions libs/infinity_emb/infinity_emb/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ class EngineArgs:
batch_size: int = MANAGER.batch_size[0]
revision: Optional[str] = MANAGER.revision[0]
trust_remote_code: bool = MANAGER.trust_remote_code[0]
engine: InferenceEngine = MANAGER.engine[0]
engine: InferenceEngine = InferenceEngine[MANAGER.engine[0]]
model_warmup: bool = MANAGER.model_warmup[0]
vector_disk_cache_path: str = ""
device: Device = MANAGER.device[0]
device: Device = Device[MANAGER.device[0]]
compile: bool = MANAGER.compile[0]
bettertransformer: bool = MANAGER.bettertransformer[0]
dtype: Dtype = MANAGER.dtype[0]
pooling_method: PoolingMethod = MANAGER.pooling_method[0]
dtype: Dtype = Dtype[MANAGER.dtype[0]]
pooling_method: PoolingMethod = PoolingMethod[MANAGER.pooling_method[0]]
lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0]
embedding_dtype: EmbeddingDtype = MANAGER.embedding_dtype[0]
embedding_dtype: EmbeddingDtype = EmbeddingDtype[MANAGER.embedding_dtype[0]]
served_model_name: str = MANAGER.served_model_name[0]

def __post_init__(self):
Expand Down
56 changes: 21 additions & 35 deletions libs/infinity_emb/infinity_emb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@

import os
from functools import cached_property
from typing import TypeVar

from infinity_emb.primitives import (
Device,
Dtype,
EmbeddingDtype,
EnumType,
InferenceEngine,
PoolingMethod,
)

EnumTypeLike = TypeVar("EnumTypeLike", bound=EnumType)


class __Infinity_EnvManager:
def __init__(self):
Expand Down Expand Up @@ -171,50 +175,32 @@ def redirect_slash(self):
def log_level(self):
return self._optional_infinity_var("log_level", default="info")

def _typed_multiple(self, name: str, cls: type["EnumTypeLike"]) -> list["str"]:
    """Read the multi-valued setting ``name`` and validate it against ``cls``.

    The raw strings are fetched through ``_optional_infinity_var_multiple``,
    falling back to ``[cls.default_value()]`` as the default. Every entry
    must be a valid value of the enum ``cls``.

    Args:
        name: Name of the setting, passed through to
            ``_optional_infinity_var_multiple``.
        cls: Enum class whose values define the accepted strings and whose
            ``default_value()`` supplies the fallback.

    Returns:
        The validated entries, unchanged, as returned by
        ``_optional_infinity_var_multiple``.

    Raises:
        ValueError: If any entry is not a valid value of ``cls``.
    """
    result = self._optional_infinity_var_multiple(
        name, default=[cls.default_value()]
    )
    # Validate explicitly instead of via `assert`: assert statements are
    # removed when Python runs with -O, which would silently disable the check.
    for value in result:
        try:
            cls(value)
        except ValueError as err:
            allowed = [member.value for member in cls.__members__.values()]
            raise ValueError(
                f"Invalid value {value!r} for {name!r}; expected one of {allowed}"
            ) from err
    return result

@cached_property
def dtype(self) -> list[Dtype]:
return [
Dtype(v)
for v in self._optional_infinity_var_multiple(
"dtype", default=[Dtype.default_value()]
)
]
def dtype(self) -> list[str]:
return self._typed_multiple("dtype", cls=Dtype)

@cached_property
def engine(self) -> list[InferenceEngine]:
return [
InferenceEngine(v)
for v in self._optional_infinity_var_multiple(
"engine", default=[InferenceEngine.default_value()]
)
]
def engine(self) -> list[str]:
return self._typed_multiple("engine", InferenceEngine)

@cached_property
def pooling_method(self) -> list[PoolingMethod]:
return [
PoolingMethod(v)
for v in self._optional_infinity_var_multiple(
"pooling_method", default=[PoolingMethod.default_value()]
)
]
def pooling_method(self) -> list[str]:
return self._typed_multiple("pooling_method", PoolingMethod)

@cached_property
def device(self) -> list[Device]:
return [
Device(v)
for v in self._optional_infinity_var_multiple(
"device", default=[Device.default_value()]
)
]
def device(self) -> list[str]:
return self._typed_multiple("device", Device)

@cached_property
def embedding_dtype(self) -> list[EmbeddingDtype]:
return [
EmbeddingDtype(v)
for v in self._optional_infinity_var_multiple(
"embedding_dtype", default=[EmbeddingDtype.default_value()]
)
]
def embedding_dtype(self) -> list[str]:
return self._typed_multiple("embedding_dtype", EmbeddingDtype)


MANAGER = __Infinity_EnvManager()
26 changes: 13 additions & 13 deletions libs/infinity_emb/infinity_emb/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def names_enum(cls) -> enum.Enum:
cls.__name__ + "__names", {k: k for k in cls.__members__.keys()}
)

@abstractmethod
def default_value(self) -> str:
...
@staticmethod
def default_value() -> str:
raise NotImplementedError


class InferenceEngine(EnumType):
Expand All @@ -58,8 +58,8 @@ class InferenceEngine(EnumType):
optimum = "optimum"
debugengine = "debugengine"

@classmethod
def default_value(self):
@staticmethod
def default_value():
return InferenceEngine.torch.value


Expand All @@ -70,8 +70,8 @@ class Device(EnumType):
tensorrt = "tensorrt"
auto = "auto"

@classmethod
def default_value(self):
@staticmethod
def default_value():
return Device.auto.value

def resolve(self) -> Optional[str]:
Expand All @@ -87,8 +87,8 @@ class Dtype(EnumType):
fp8: str = "fp8"
auto: str = "auto"

@classmethod
def default_value(self):
@staticmethod
def default_value():
return Dtype.auto.value


Expand All @@ -97,8 +97,8 @@ class EmbeddingDtype(EnumType):
# int8: str = "int8"
# binary: str = "binary"

@classmethod
def default_value(self):
@staticmethod
def default_value():
return EmbeddingDtype.float32.value


Expand All @@ -107,8 +107,8 @@ class PoolingMethod(EnumType):
cls: str = "cls"
auto: str = "auto"

@classmethod
def default_value(self):
@staticmethod
def default_value():
return PoolingMethod.auto.value


Expand Down

0 comments on commit ccf8b47

Please sign in to comment.