Scaffolds an abstract tokenizer
miguelgondu committed Aug 8, 2024
1 parent b9f5843 commit 7b2c54a
Showing 6 changed files with 129 additions and 0 deletions.
Empty file.
Empty file.
49 changes: 49 additions & 0 deletions src/poli/core/util/tokenizers/abstract_tokenizer.py
@@ -0,0 +1,49 @@
from __future__ import annotations

import numpy as np


def _pad(
    texts: list[list[str]], max_sequence_length: int, padding_token: str = ""
) -> list[list[str]]:
    if max_sequence_length < max(map(len, texts)):
        raise ValueError(
            f"max_sequence_length ({max_sequence_length}) must be greater than or equal to the length of the longest text ({max(map(len, texts))})"
        )
    return [
        text + [padding_token] * (max_sequence_length - len(text)) for text in texts
    ]


class AbstractTokenizer:
    def __init__(
        self, max_sequence_length: int | float | None, padding_token: str = ""
    ) -> None:
        if isinstance(max_sequence_length, float):
            assert (
                max_sequence_length == np.inf
            ), "max_sequence_length must be np.inf if float"

        self.max_sequence_length = max_sequence_length
        self.padding_token = padding_token

    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        # Subclasses implement the tokenization itself; tokenize() below
        # handles batching and padding.
        raise NotImplementedError

    def tokenize(self, texts: str | list[str]) -> np.ndarray:
        unpadded_tokens = self._tokenize(texts)

        # A single text tokenizes to a flat list of tokens; wrap it so that
        # we always work with a batch of token lists.
        if isinstance(unpadded_tokens[0], list):
            tokens = unpadded_tokens
        else:
            tokens = [unpadded_tokens]

        if self.max_sequence_length is None or np.isinf(self.max_sequence_length):
            max_length = max(map(len, tokens))
            tokens = _pad(tokens, max_length, padding_token=self.padding_token)
        else:
            tokens = _pad(
                tokens, self.max_sequence_length, padding_token=self.padding_token
            )

        return np.array(tokens)
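
To illustrate the intended contract, here is a minimal hypothetical sketch of a concrete subclass (the WhitespaceTokenizer class and its splitting rule are illustrations, not part of this commit): only _tokenize needs to be overridden, and the inherited tokenize() takes care of batching and padding.

from __future__ import annotations

import numpy as np

from poli.core.util.tokenizers.abstract_tokenizer import AbstractTokenizer


class WhitespaceTokenizer(AbstractTokenizer):
    # Hypothetical subclass for illustration: splits texts on whitespace.
    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        if isinstance(texts, str):
            return texts.split()
        return [t.split() for t in texts]


tokenizer = WhitespaceTokenizer(max_sequence_length=np.inf)
tokens = tokenizer.tokenize(["one two three", "four five"])
# Rows are padded with "" (the default padding token) up to the longest sequence.
print(tokens.shape)  # (2, 3)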
13 changes: 13 additions & 0 deletions src/poli/core/util/tokenizers/list_tokenizer.py
@@ -0,0 +1,13 @@
from __future__ import annotations

from poli.core.util.tokenizers.abstract_tokenizer import AbstractTokenizer


class ListTokenizer(AbstractTokenizer):
    def _tokenize(self, texts: str | list[str]) -> list[str] | list[list[str]]:
        if isinstance(texts, str):
            return list(texts)
        elif isinstance(texts, list):
            return [list(t) for t in texts]
        else:
            raise ValueError(f"Expected str or list, got {type(texts)}")
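
A quick usage sketch (not part of the diff): ListTokenizer tokenizes at the character level, and with max_sequence_length=None the base class pads every row to the longest sequence in the batch.

from poli.core.util.tokenizers.list_tokenizer import ListTokenizer

tokenizer = ListTokenizer(max_sequence_length=None)
tokens = tokenizer.tokenize(["ACGT", "AC"])
# Character-level tokens padded with "":
# [["A", "C", "G", "T"], ["A", "C", "", ""]]
print(tokens.shape)  # (2, 4)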
Empty file.
67 changes: 67 additions & 0 deletions src/poli/tests/tokenizers/test_tokenizers.py
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from poli.core.util.proteins.defaults import AMINO_ACIDS
from poli.core.util.tokenizers.list_tokenizer import ListTokenizer


def test_list_tokenizer_on_single_sequence():
    one_amino_acid = AMINO_ACIDS[0]

    text = one_amino_acid * 10

    tokenizer = ListTokenizer(max_sequence_length=10)

    tokens = tokenizer.tokenize(text)
    assert (tokens == np.array(list(one_amino_acid * 10))).all()


def test_list_tokenizer_on_multiple_sequences_of_same_length():
    one_amino_acid = AMINO_ACIDS[0]

    texts = [one_amino_acid * 10] * 5

    tokenizer = ListTokenizer(max_sequence_length=10)

    tokens = tokenizer.tokenize(texts)
    assert (tokens == np.array([list(one_amino_acid * 10)] * 5)).all()


@pytest.mark.parametrize("max_sequence_length", [np.inf, 15, None])
def test_list_tokenizer_on_multiple_sequences_of_varying_length(max_sequence_length):
    one_amino_acid = AMINO_ACIDS[0]

    texts = [
        one_amino_acid * 10,
        one_amino_acid * 5,
        one_amino_acid * 15,
        one_amino_acid * 3,
        one_amino_acid * 8,
    ]

    tokenizer = ListTokenizer(max_sequence_length=max_sequence_length)

    tokens = tokenizer.tokenize(texts)
    assert tokens.shape == (5, 15)
    assert tokens[0, 11] == tokenizer.padding_token


def test_list_tokenizer_outputs_error_on_wrong_max_sequence():
    one_amino_acid = AMINO_ACIDS[0]

    texts = [
        one_amino_acid * 10,
        one_amino_acid * 5,
        one_amino_acid * 15,
        one_amino_acid * 3,
        one_amino_acid * 8,
    ]

    tokenizer = ListTokenizer(max_sequence_length=10)

    with pytest.raises(ValueError):
        _ = tokenizer.tokenize(texts)


if __name__ == "__main__":
    # The test is parametrized, so an explicit argument is needed when calling it directly.
    test_list_tokenizer_on_multiple_sequences_of_varying_length(
        max_sequence_length=np.inf
    )
