Commit
Merge branch 'add-custom-model' into dev
JoelNiklaus committed Dec 19, 2024
2 parents 2a5472d + a7d176c commit a1b0d6b
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions examples/custom_models/local_mt_model.py
@@ -74,7 +74,7 @@ class LocalMTClient(LightevalModel):
     where src and tgt are ISO language codes (2 or 3 letter codes supported).
 
     Example:
-    ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10
+    ```lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 --save-details
     ```
 
     Note:
@@ -101,6 +101,10 @@ def __init__(self, config, env_config) -> None:
             self._tokenizer = AutoProcessor.from_pretrained(config.model)
             self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model)
             self.model_type = "seamless-4mt"
+            self.batch_size = 1
+            logger.info(
+                "Using batch size of 1 for seamless-4mt model because the target language needs to be set for the entire batch."
+            )
         elif "madlad400" in config.model:
             self._tokenizer = AutoTokenizer.from_pretrained(config.model)
             self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
@@ -176,20 +180,17 @@ def get_langs(task_name: str) -> tuple[str, str]:
         ):
             batch = current_requests[batch_idx : batch_idx + batch_size]
 
-            # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs
+            # Batch tokenize all inputs together instead of concatenating pre-tokenized inputs because of the padding
             batch_texts = [r.context for r in batch]
-            src_lang = get_langs(batch[0].task_name)[0]  # All source languages should be the same in a batch
 
             # This is the tokenization step that really counts, as it actually gets used
             tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True}
             if self.model_type == "seamless-4mt":
+                src_lang = get_langs(batch[0].task_name)[0]
                 tokenizer_kwargs["src_lang"] = src_lang
 
             input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values()
 
-            tgt_langs = [get_langs(r.task_name)[1] for r in batch]
-            assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same"
-
             generation_sizes = [r.generation_size for r in batch]
             assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same"
 
@@ -200,7 +201,8 @@ def get_langs(task_name: str) -> tuple[str, str]:
                 "max_new_tokens": generation_sizes[0],
             }
             if self.model_type == "seamless-4mt":
-                generate_kwargs["tgt_lang"] = tgt_langs[0]
+                tgt_lang = get_langs(batch[0].task_name)[1]
+                generate_kwargs["tgt_lang"] = tgt_lang
 
             output_ids = self._model.generate(**generate_kwargs)
             translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)
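
Why the batch size is pinned to 1 for SeamlessM4T: the Hugging Face processor takes src_lang once per call and generate takes a single tgt_lang for the whole batch, so requests with different language directions cannot share a batch. Below is a minimal sketch of that API shape, not part of the diff; the checkpoint ID matches the docstring example, and the eng/fra language codes are illustrative.

```python
# Minimal sketch (not from the diff) of the SeamlessM4T v2 text-to-text API
# that motivates the batch_size = 1 pin; checkpoint and language codes are
# illustrative.
from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText

processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
model = SeamlessM4Tv2ForTextToText.from_pretrained("facebook/seamless-m4t-v2-large")

# src_lang is fixed per processor call and tgt_lang per generate call, so every
# request in a batch must share one language direction.
inputs = processor(
    text=["The weather is nice today."],
    src_lang="eng",
    return_tensors="pt",
    padding=True,
)
output_ids = model.generate(**inputs, tgt_lang="fra")
print(processor.batch_decode(output_ids, skip_special_tokens=True))
```

With batch_size = 1 the per-batch target language is trivially uniform, which is why the commit can drop the tgt_langs assertion and simply read the target language from the batch's first request at generation time. A possible refinement would be to group requests by language pair before batching to recover throughput.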
