Commit 7b2c54a (1 parent: b9f5843)
Showing 6 changed files with 129 additions and 0 deletions.
Empty file.
Empty file.
poli/core/util/tokenizers/abstract_tokenizer.py (new file)
@@ -0,0 +1,49 @@
from __future__ import annotations

import numpy as np


def _pad(
    texts: list[list[str]], max_sequence_length: int, padding_token: str = ""
) -> list[list[str]]:
    # Right-pad every token list with the padding token up to max_sequence_length.
    if max_sequence_length < max(map(len, texts)):
        raise ValueError(
            f"max_sequence_length ({max_sequence_length}) must be greater than or equal to "
            f"the length of the longest text ({max(map(len, texts))})"
        )
    return [
        text + [padding_token] * (max_sequence_length - len(text)) for text in texts
    ]


class AbstractTokenizer:
    def __init__(
        self, max_sequence_length: int | float | None, padding_token: str = ""
    ) -> None:
        # A float max_sequence_length is only allowed as np.inf ("pad to the longest input").
        if isinstance(max_sequence_length, float):
            assert (
                max_sequence_length == np.inf
            ), "max_sequence_length must be np.inf if float"

        self.max_sequence_length = max_sequence_length
        self.padding_token = padding_token

    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        # Subclasses implement the actual tokenization of a single text or a list of texts.
        raise NotImplementedError

    def tokenize(self, texts: str | list[str]) -> np.ndarray:
        unpadded_tokens = self._tokenize(texts)

        # Normalize to a list of token lists, even when a single text was passed.
        if isinstance(unpadded_tokens[0], list):
            tokens = unpadded_tokens
        else:
            tokens = [unpadded_tokens]

        # Pad to the longest sequence (None / np.inf) or to the fixed maximum length.
        if self.max_sequence_length is None or np.isinf(self.max_sequence_length):
            max_length = max(map(len, tokens))
            tokens = _pad(tokens, max_length, padding_token=self.padding_token)
        else:
            tokens = _pad(
                tokens, self.max_sequence_length, padding_token=self.padding_token
            )

        return np.array(tokens)
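For context, a subclass only has to implement _tokenize; padding and conversion to a NumPy array come from the inherited tokenize method. The WhitespaceTokenizer below is a hypothetical sketch, not part of this commit, and assumes the AbstractTokenizer defined above is in scope:

from __future__ import annotations

# Hypothetical example subclass: splits on whitespace instead of into characters.
class WhitespaceTokenizer(AbstractTokenizer):
    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        if isinstance(texts, str):
            return texts.split()
        return [t.split() for t in texts]


tokenizer = WhitespaceTokenizer(max_sequence_length=None)
tokens = tokenizer.tokenize(["one two three", "four"])
# tokens.shape == (2, 3); the second row is right-padded with "" to the longest sequence.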
poli/core/util/tokenizers/list_tokenizer.py (new file)
@@ -0,0 +1,13 @@
from __future__ import annotations

from poli.core.util.tokenizers.abstract_tokenizer import AbstractTokenizer


class ListTokenizer(AbstractTokenizer):
    # Character-level tokenizer: every character of the input string becomes one token.
    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        if isinstance(texts, str):
            return list(texts)
        elif isinstance(texts, list):
            return [list(t) for t in texts]
        else:
            raise ValueError(f"Expected str or list, got {type(texts)}")
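A short usage sketch (assumed, not part of the commit): character-level tokenization with padding up to the configured maximum length.

from poli.core.util.tokenizers.list_tokenizer import ListTokenizer

# Hypothetical usage: two sequences of different lengths, padded to max_sequence_length=5.
tokenizer = ListTokenizer(max_sequence_length=5)
tokens = tokenizer.tokenize(["ACD", "ACDEF"])
# tokens is a (2, 5) array of single characters; "ACD" is right-padded with two "" tokens.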
Empty file.
Tests for ListTokenizer (new file)
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from poli.core.util.proteins.defaults import AMINO_ACIDS
from poli.core.util.tokenizers.list_tokenizer import ListTokenizer


def test_list_tokenizer_on_single_sequence():
    one_amino_acid = AMINO_ACIDS[0]

    text = one_amino_acid * 10

    tokenizer = ListTokenizer(max_sequence_length=10)

    tokens = tokenizer.tokenize(text)
    assert (tokens == np.array(list(one_amino_acid * 10))).all()


def test_list_tokenizer_on_multiple_sequences_of_same_length():
    one_amino_acid = AMINO_ACIDS[0]

    texts = [one_amino_acid * 10] * 5

    tokenizer = ListTokenizer(max_sequence_length=10)

    tokens = tokenizer.tokenize(texts)
    assert (tokens == np.array([list(one_amino_acid * 10)] * 5)).all()


@pytest.mark.parametrize("max_sequence_length", [np.inf, 15, None])
def test_list_tokenizer_on_multiple_sequences_of_varying_length(max_sequence_length):
    one_amino_acid = AMINO_ACIDS[0]

    texts = [
        one_amino_acid * 10,
        one_amino_acid * 5,
        one_amino_acid * 15,
        one_amino_acid * 3,
        one_amino_acid * 8,
    ]

    tokenizer = ListTokenizer(max_sequence_length=max_sequence_length)

    tokens = tokenizer.tokenize(texts)
    assert tokens.shape == (5, 15)
    assert tokens[0, 11] == tokenizer.padding_token


def test_list_tokenizer_outputs_error_on_wrong_max_sequence():
    one_amino_acid = AMINO_ACIDS[0]

    texts = [
        one_amino_acid * 10,
        one_amino_acid * 5,
        one_amino_acid * 15,
        one_amino_acid * 3,
        one_amino_acid * 8,
    ]

    # max_sequence_length is shorter than the longest sequence (15), so padding must fail.
    tokenizer = ListTokenizer(max_sequence_length=10)

    with pytest.raises(ValueError):
        _ = tokenizer.tokenize(texts)


if __name__ == "__main__":
    # The parametrized test needs an explicit argument when run as a script.
    test_list_tokenizer_on_multiple_sequences_of_varying_length(max_sequence_length=15)