From b2a1f44f30b5db9ffd1d46858be1333cc0d025f3 Mon Sep 17 00:00:00 2001 From: Isaac Schifferer Date: Mon, 13 Jan 2025 19:31:38 -0500 Subject: [PATCH] Make attn_implementation configurable for huggingface models --- machine/jobs/huggingface/hugging_face_nmt_model_factory.py | 6 +++++- machine/jobs/settings.yaml | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index 81d66c9..d1c9af9 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -65,7 +65,11 @@ def init(self) -> None: ) self._model = cast( PreTrainedModel, - AutoModelForSeq2SeqLM.from_pretrained(self._config.huggingface.parent_model_name, config=config), + AutoModelForSeq2SeqLM.from_pretrained( + self._config.huggingface.parent_model_name, + config=config, + attn_implementation=self._config.huggingface.attn_implementation, + ), ) def create_source_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer: diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index 00d9517..663b942 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -27,6 +27,7 @@ default: tokenizer: add_unk_src_tokens: true add_unk_trg_tokens: true + attn_implementation: sdpa thot_mt: word_alignment_model_type: hmm tokenizer: latin