feat: slice & batch over transformer windows to avoid GPU OOM errors
percevalw committed Jan 15, 2024
1 parent b2d61e3 commit 9e3c61b
Showing 2 changed files with 17 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -11,6 +11,7 @@
 - Support doc -> list converters with parquet files writer
 - Fixed some OOM errors when writing many outputs to parquet files
 - Both edsnlp & spacy factories are now listed when a factory lookup fails
+- Fixed some GPU OOM errors with the `eds.transformer` pipe when processing really long documents
 
 ## v0.10.3

18 changes: 16 additions & 2 deletions edsnlp/pipes/trainable/embeddings/transformer/transformer.py
@@ -325,11 +325,25 @@ def forward(self, batch):
"mask": batch["mask"].clone(),
}

trf_result = self.transformer.base_model(
max_windows = self.max_tokens_per_device // batch["input_ids"].size(1)
kwargs = dict(
input_ids=batch["input_ids"].as_tensor(),
attention_mask=batch["input_ids"].mask,
)
wordpiece_embeddings = trf_result.last_hidden_state
wordpiece_embeddings = [
self.transformer.base_model(
**{
k: None if v is None else v[offset : offset + max_windows]
for k, v in kwargs.items()
}
).last_hidden_state
for offset in range(0, batch["input_ids"].size(0), max_windows)
]
wordpiece_embeddings = (
torch.cat(wordpiece_embeddings, dim=0)
if len(wordpiece_embeddings) > 1
else wordpiece_embeddings[0]
)

mask = batch["mask"].clone()
word_embeddings = torch.zeros(
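The change above can be read as a general recipe: when a document yields more transformer windows than fit on the GPU at once, slice the (n_windows, window_size) batch along the window dimension, run the encoder on each slice, and concatenate the resulting hidden states. Below is a minimal, self-contained sketch of that recipe, assuming a Hugging Face-style encoder; the names `embed_in_chunks`, `encoder`, and the default token budget are illustrative and not part of the edsnlp API.

```python
import torch


def embed_in_chunks(encoder, input_ids, attention_mask, max_tokens_per_device=32_768):
    """Encode a (n_windows, window_size) wordpiece batch in GPU-sized chunks."""
    n_windows, window_size = input_ids.shape
    # Number of windows that fit within the per-forward token budget (at least 1).
    max_windows = max(1, max_tokens_per_device // window_size)
    chunks = []
    for offset in range(0, n_windows, max_windows):
        out = encoder(
            input_ids=input_ids[offset : offset + max_windows],
            attention_mask=attention_mask[offset : offset + max_windows],
        )
        chunks.append(out.last_hidden_state)
    # Concatenate along the window dimension to restore the original batch order.
    return torch.cat(chunks, dim=0) if len(chunks) > 1 else chunks[0]
```

With a Hugging Face model, `encoder` would typically be an `AutoModel.from_pretrained(...)` instance and `input_ids` / `attention_mask` the padded windows produced by its tokenizer. In the actual pipe, the per-forward budget comes from the transformer component's `max_tokens_per_device` attribute, as seen in the diff above.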
