Merge branch 'add-custom-model' into dev
JoelNiklaus committed Dec 18, 2024
2 parents dd12702 + b7106e4 commit 2a5472d
Showing 7 changed files with 51 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
## Endpoints-based Models
### InferenceEndpointModel
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

### TGI ModelClient
17 changes: 14 additions & 3 deletions examples/custom_models/local_mt_model.py
@@ -24,6 +24,7 @@
from typing import Optional

import pycountry
import torch
from tqdm import tqdm
from transformers import (
AutoModelForSeq2SeqLM,
@@ -86,6 +87,7 @@ def __init__(self, config, env_config) -> None:
self.model = config.model
self.model_definition_file_path = config.model_definition_file_path
self.batch_size = 32
self.device = "cuda" if torch.cuda.is_available() else "cpu"

self.model_info = ModelInfo(
model_name=config.model,
@@ -106,6 +108,9 @@ def __init__(self, config, env_config) -> None:
else:
raise ValueError(f"Unsupported model: {config.model}")

self._model.to(self.device)
self._model.eval()

def _convert_to_iso3(self, lang_code: str) -> str:
"""Convert 2-letter ISO code to 3-letter ISO code."""
try:
@@ -166,7 +171,9 @@ def get_langs(task_name: str) -> tuple[str, str]:
current_requests = dataset.sorted_data[split_start:split_end]

# Process in batches
for batch_idx in range(0, len(current_requests), batch_size):
for batch_idx in tqdm(
range(0, len(current_requests), batch_size), desc="Batches", position=1, disable=False
):
batch = current_requests[batch_idx : batch_idx + batch_size]

# Batch tokenize all inputs together instead of concatenating pre-tokenized inputs
@@ -178,15 +185,19 @@ def get_langs(task_name: str) -> tuple[str, str]:
if self.model_type == "seamless-4mt":
tokenizer_kwargs["src_lang"] = src_lang

input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).values()
input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values()

tgt_langs = [get_langs(r.task_name)[1] for r in batch]
assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same"

generation_sizes = [r.generation_size for r in batch]
assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same"

# Use unpacked values directly
generate_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": generation_sizes[0],
}
if self.model_type == "seamless-4mt":
generate_kwargs["tgt_lang"] = tgt_langs[0]
@@ -212,7 +223,7 @@ def tokenizer(self):
return self._tokenizer

def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False)
return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False).to(self.device)

@property
def add_special_tokens(self) -> bool:
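
Note: the changes to local_mt_model.py above move the custom MT model onto the GPU when one is available and keep each tokenized batch on the same device. A minimal standalone sketch of that pattern follows; the checkpoint name and example sentences are placeholders and not taken from this commit.

# Standalone sketch of the device-placement pattern added above.
# "google/madlad400-3b-mt" and the example sentences are illustrative placeholders.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("google/madlad400-3b-mt")
model = AutoModelForSeq2SeqLM.from_pretrained("google/madlad400-3b-mt")
model.to(device)   # mirrors self._model.to(self.device)
model.eval()       # mirrors self._model.eval()

batch = ["<2de> The weather is nice today.", "<2de> Where is the station?"]
inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
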
19 changes: 12 additions & 7 deletions src/lighteval/main_endpoint.py
@@ -146,6 +146,13 @@ def inference_endpoint(
str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
free_endpoint: Annotated[
bool,
Option(
help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
rich_help_panel=HELP_PANEL_NAME_4,
),
] = False,
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -200,9 +207,7 @@ def inference_endpoint(
"""

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModelConfig,
)
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -220,10 +225,10 @@ def inference_endpoint(
parallelism_manager = ParallelismManager.NONE # since we're using inference endpoints in remote

# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

model_config = InferenceEndpointModelConfig.from_path(model_config_path)
if free_endpoint:
model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
else:
model_config = InferenceEndpointModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
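
Note: the new free_endpoint option (which Typer should expose as a --free-endpoint / --no-free-endpoint switch) only decides which config class is loaded from the YAML file. A hedged sketch of that selection as a plain function; the YAML path is the serverless example referenced in the test change at the bottom of this commit.

# Sketch of the branch added in inference_endpoint().
from lighteval.models.endpoints.endpoint_model import (
    InferenceEndpointModelConfig,
    ServerlessEndpointModelConfig,
)


def build_model_config(model_config_path: str, free_endpoint: bool = False):
    # Serverless (free) endpoints only need a model name; dedicated inference
    # endpoints keep using the richer InferenceEndpointModelConfig.
    if free_endpoint:
        return ServerlessEndpointModelConfig.from_path(model_config_path)
    return InferenceEndpointModelConfig.from_path(model_config_path)


config = build_model_config("examples/model_configs/serverless_model.yaml", free_endpoint=True)
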
29 changes: 21 additions & 8 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -75,10 +75,18 @@


@dataclass
class InferenceModelConfig:
model: str
class ServerlessEndpointModelConfig:
model_name: str
add_special_tokens: bool = True

@classmethod
def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
import yaml

with open(path, "r") as f:
config = yaml.safe_load(f)["model"]
return cls(**config["base_params"])


@dataclass
class InferenceEndpointModelConfig:
@@ -150,7 +158,7 @@ class InferenceEndpointModel(LightevalModel):
"""

def __init__( # noqa: C901
self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
) -> None:
self.reuse_existing = getattr(config, "reuse_existing", False)
self._max_length = None
@@ -280,10 +288,10 @@ def __init__( # noqa: C901
else: # Free inference client
self.endpoint = None
self.endpoint_name = None
self.name = config.model
self.name = config.model_name
self.revision = "default"
self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
self.client = InferenceClient(model=config.model, token=env_config.token)
self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
self.client = InferenceClient(model=config.model_name, token=env_config.token)

self.use_async = True # set to False for debug - async use is faster

@@ -293,7 +301,7 @@ def __init__( # noqa: C901
self.model_info = ModelInfo(
model_name=self.name,
model_sha=self.revision,
model_dtype=config.model_dtype or "default",
model_dtype=getattr(config, "model_dtype", "default"),
model_size=-1,
)

@@ -545,7 +553,12 @@ def loglikelihood(
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)

logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
if self.endpoint: # inference endpoint
logits = [
t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
] # to check
else: # serverless endpoint
logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

greedy_tokens = torch.tensor(logits).argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
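
Note: the new ServerlessEndpointModelConfig.from_path reads the model -> base_params block of the YAML file and passes those keys to the dataclass, so a config only needs a model_name (add_special_tokens defaults to True). A small round-trip sketch, writing an equivalent YAML from Python rather than relying on the exact contents of the shipped example file.

# Hedged sketch: build a YAML file with the layout from_path() expects and load it back.
import tempfile

import yaml

from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig

config_dict = {
    "model": {
        "base_params": {
            # Model name borrowed from the test fixture below; any HF model id should work.
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
        }
    }
}

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as tmp:
    yaml.safe_dump(config_dict, tmp)
    yaml_path = tmp.name

config = ServerlessEndpointModelConfig.from_path(yaml_path)
print(config.model_name, config.add_special_tokens)  # meta-llama/Llama-3.1-8B-Instruct True
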
4 changes: 2 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -29,7 +29,7 @@
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModel,
InferenceEndpointModelConfig,
InferenceModelConfig,
ServerlessEndpointModelConfig,
)
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -87,7 +87,7 @@ def load_model( # noqa: C901
if isinstance(config, TGIModelConfig):
return load_model_with_tgi(config)

if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
return load_model_with_inference_endpoints(config, env_config=env_config)

if isinstance(config, BaseModelConfig):
2 changes: 1 addition & 1 deletion tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
},
),
(
"examples/model_configs/endpoint_model_lite.yaml",
"examples/model_configs/serverless_model.yaml",
{
"model_name": "meta-llama/Llama-3.1-8B-Instruct",
# Defaults:
