From 58d53c87881e8932bcd231d31017a349ba41f88a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Mon, 15 Jan 2024 00:03:28 +0100
Subject: [PATCH] feat: slice & batch over transformer windows to avoid GPU
 OOM errors

---
 changelog.md                                  |  1 +
 .../embeddings/transformer/transformer.py     | 18 ++++++++++++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/changelog.md b/changelog.md
index b8b992cb5..57f7e45d1 100644
--- a/changelog.md
+++ b/changelog.md
@@ -11,6 +11,7 @@
 - Support doc -> list converters with parquet files writer
 - Fixed some OOM errors when writing many outputs to parquet files
 - Both edsnlp & spacy factories are now listed when a factory lookup fails
+- Fixed some GPU OOM errors with the `eds.transformer` pipe when processing really long documents
 
 ## v0.10.3
 
diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
index b39f3a703..2068ac6a5 100644
--- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
+++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
@@ -325,11 +325,25 @@ def forward(self, batch):
             "mask": batch["mask"].clone(),
         }
 
-        trf_result = self.transformer.base_model(
+        max_windows = self.max_tokens_per_device // batch["input_ids"].size(1)
+        kwargs = dict(
             input_ids=batch["input_ids"].as_tensor(),
             attention_mask=batch["input_ids"].mask,
         )
-        wordpiece_embeddings = trf_result.last_hidden_state
+        wordpiece_embeddings = [
+            self.transformer.base_model(
+                **{
+                    k: None if v is None else v[offset : offset + max_windows]
+                    for k, v in kwargs.items()
+                }
+            ).last_hidden_state
+            for offset in range(0, batch["input_ids"].size(0), max_windows)
+        ]
+        wordpiece_embeddings = (
+            torch.cat(wordpiece_embeddings, dim=0)
+            if len(wordpiece_embeddings) > 1
+            else wordpiece_embeddings[0]
+        )
 
         mask = batch["mask"].clone()
         word_embeddings = torch.zeros(
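
For context, here is a minimal, standalone sketch of the batching strategy the patch applies inside the transformer's forward pass: the number of windows per forward pass is derived from a token budget divided by the padded window length, and the per-slice hidden states are concatenated back together. It assumes a Hugging Face-style encoder whose output exposes `last_hidden_state` and plain dense tensors rather than edsnlp's masked tensors; the function name `chunked_encode` and the default budget are hypothetical, not part of the patch.

```python
import torch
from torch import nn


def chunked_encode(
    encoder: nn.Module,
    input_ids: torch.Tensor,       # shape: (n_windows, window_size)
    attention_mask: torch.Tensor,  # shape: (n_windows, window_size)
    max_tokens_per_device: int = 32 * 128,  # illustrative budget
) -> torch.Tensor:
    """Run `encoder` over `input_ids` in slices small enough to fit in GPU memory."""
    n_windows, window_size = input_ids.shape
    # How many windows fit in one forward pass under the token budget
    # (at least one, even if a single window exceeds the budget).
    max_windows = max(1, max_tokens_per_device // window_size)

    outputs = []
    for offset in range(0, n_windows, max_windows):
        out = encoder(
            input_ids=input_ids[offset : offset + max_windows],
            attention_mask=attention_mask[offset : offset + max_windows],
        )
        outputs.append(out.last_hidden_state)

    # With a single slice, return it directly to avoid the copy torch.cat would make,
    # mirroring the `len(...) > 1` branch in the patch.
    return outputs[0] if len(outputs) == 1 else torch.cat(outputs, dim=0)
```

Memory usage per forward pass is thus bounded by roughly `max_windows * window_size` tokens regardless of document length, at the cost of running several smaller forward passes for very long documents.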