Commit

Make sure generation can happen on the GPU.
JoelNiklaus committed Dec 18, 2024
1 parent bd08781 commit b7106e4
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions examples/custom_models/local_mt_model.py
@@ -24,6 +24,7 @@
from typing import Optional

import pycountry
import torch
from tqdm import tqdm
from transformers import (
AutoModelForSeq2SeqLM,
@@ -86,6 +87,7 @@ def __init__(self, config, env_config) -> None:
self.model = config.model
self.model_definition_file_path = config.model_definition_file_path
self.batch_size = 32
self.device = "cuda" if torch.cuda.is_available() else "cpu"

self.model_info = ModelInfo(
model_name=config.model,
@@ -106,6 +108,9 @@ def __init__(self, config, env_config) -> None:
else:
raise ValueError(f"Unsupported model: {config.model}")

self._model.to(self.device)
self._model.eval()

def _convert_to_iso3(self, lang_code: str) -> str:
"""Convert 2-letter ISO code to 3-letter ISO code."""
try:
@@ -166,7 +171,9 @@ def get_langs(task_name: str) -> tuple[str, str]:
current_requests = dataset.sorted_data[split_start:split_end]

# Process in batches
for batch_idx in range(0, len(current_requests), batch_size):
for batch_idx in tqdm(
range(0, len(current_requests), batch_size), desc="Batches", position=1, disable=False
):
batch = current_requests[batch_idx : batch_idx + batch_size]

# Batch tokenize all inputs together instead of concatenating pre-tokenized inputs
@@ -178,15 +185,19 @@ def get_langs(task_name: str) -> tuple[str, str]:
if self.model_type == "seamless-4mt":
tokenizer_kwargs["src_lang"] = src_lang

input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).values()
input_ids, attention_mask = self._tokenizer(**tokenizer_kwargs).to(self.device).values()

tgt_langs = [get_langs(r.task_name)[1] for r in batch]
assert set(tgt_langs) == {tgt_langs[0]}, "All target languages must be the same"

generation_sizes = [r.generation_size for r in batch]
assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same"

# Use unpacked values directly
generate_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": generation_sizes[0],
}
if self.model_type == "seamless-4mt":
generate_kwargs["tgt_lang"] = tgt_langs[0]
@@ -212,7 +223,7 @@ def tokenizer(self):
return self._tokenizer

def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False)
return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False).to(self.device)

@property
def add_special_tokens(self) -> bool:
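For context, here is a minimal standalone sketch of the device-placement pattern this commit applies: pick CUDA when available, move the model there once, and move each batch of tokenized inputs to the same device before calling generate. The checkpoint name, texts, and max_new_tokens value below are placeholders for illustration, not values taken from the original file.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Prefer the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder checkpoint; the sketch only illustrates the device handling.
checkpoint = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
model.eval()

texts = ["Hello world.", "How are you?"]

# Tokenize the whole batch and move the resulting tensors to the model's device.
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

If the model and the input tensors live on different devices, generate raises a device-mismatch error, which is what moving both the model (in __init__) and the tokenizer output (per batch) to self.device avoids.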
