Skip to content

Commit

Permalink
feat: Allow custom model configuration
Browse files Browse the repository at this point in the history
eg local models usage with increased context
  • Loading branch information
RezaRahemtola committed Nov 26, 2024
1 parent 788410e commit 3f14654
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions libertai_agents/libertai_agents/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,47 @@
class ModelConfiguration(BaseModel):
vm_url: str
context_length: int


class FullModelConfiguration(ModelConfiguration):
constructor: typing.Type[Model]


MODEL_IDS: list[ModelId] = list(typing.get_args(ModelId))

MODELS_CONFIG: dict[ModelId, ModelConfiguration] = {
"NousResearch/Hermes-3-Llama-3.1-8B": ModelConfiguration(vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
context_length=4096,
constructor=HermesModel),
"mistralai/Mistral-Nemo-Instruct-2407": ModelConfiguration(vm_url="https://curated.aleph.cloud/vm/2c4ad0bf343fb12924936cbc801732d95ce90f84cd895aa8bee82c0a062815c2/completion",
context_length=4096,
constructor=MistralModel)
MODELS_CONFIG: dict[ModelId, FullModelConfiguration] = {
"NousResearch/Hermes-3-Llama-3.1-8B": FullModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/84df52ac4466d121ef3bb409bb14f315de7be4ce600e8948d71df6485aa5bcc3/completion",
context_length=4096,
constructor=HermesModel),
"mistralai/Mistral-Nemo-Instruct-2407": FullModelConfiguration(
vm_url="https://curated.aleph.cloud/vm/2c4ad0bf343fb12924936cbc801732d95ce90f84cd895aa8bee82c0a062815c2/completion",
context_length=4096,
constructor=MistralModel)
}


def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
def get_model(model_id: ModelId, hf_token: str | None = None,
custom_configuration: ModelConfiguration | None = None) -> Model:
"""
Get one of the available models
:param model_id: HuggingFace ID of the model, must be one of the supported models
:param hf_token: Optional access token, required to use gated models
:param custom_configuration: Optional model configuration (useful to use a local model)
:return: An instance of the model
"""
model_configuration = MODELS_CONFIG.get(model_id)

if model_configuration is None:
# Fetching our full configuration with the model constructor
full_config = MODELS_CONFIG.get(model_id)

if full_config is None:
raise ValueError(f'model_id must be one of {MODEL_IDS}')

if hf_token is not None:
login(hf_token)

return model_configuration.constructor(model_id=model_id, **model_configuration.dict(exclude={'constructor'}))
# Using our configuration if the user didn't pass a custom model config
configuration = custom_configuration if custom_configuration is not None else full_config

return full_config.constructor(model_id=model_id, **configuration.dict(exclude={'constructor'}))

0 comments on commit 3f14654

Please sign in to comment.