Change detokenize behavior (#54)
* Now trims the sequence at the first stop token

* Bump changelog

* Fixed a bug where `trim_stop_token` was passed the value of `trim_start_token`

* Improve test coverage
wfondrie authored May 10, 2024
1 parent 1b53a35 commit 486221c
Showing 4 changed files with 38 additions and 9 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
+
+## [v0.4.8]
+### Changed
+- `Tokenizer.detokenize()` now truncates the output to the first stop token it finds, if `trim_stop_token=True`.
+
 ## [v0.4.7]
 ### Fixed
 - Add stop and start tokens for `AnnotatedSpectrumDataset`, when available.
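To illustrate the changed behavior, here is a minimal sketch based on the unit test added in this commit; the `start_token="?"` setup and the token indices come from that test, while the import path is an assumption, not something this commit establishes:

```python
import torch

from depthcharge.tokenizers import PeptideTokenizer  # assumed import path

# Setup from the new unit test: "?" is the start token, and this row of
# token indices decodes to "?ACD$E" when nothing is trimmed.
tokenizer = PeptideTokenizer(start_token="?")
tokens = torch.tensor([[0, 2, 3, 4, 5, 1, 6]])

# With trim_stop_token=True, detokenize() now truncates at the first stop
# token, so the trailing "E" is dropped along with the "$" stop token.
tokenizer.detokenize(tokens, trim_start_token=False, trim_stop_token=True)   # ["?ACD"]
tokenizer.detokenize(tokens, trim_start_token=False, trim_stop_token=False)  # ["?ACD$E"]
```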
2 changes: 1 addition & 1 deletion depthcharge/tokenizers/peptides.py
@@ -194,7 +194,7 @@ def detokenize(
             tokens=tokens,
             join=join,
             trim_start_token=trim_start_token,
-            trim_stop_token=trim_start_token,
+            trim_stop_token=trim_stop_token,
         )
 
         if self.reverse:
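This one-line fix is the "Fixed bug" item in the commit message: `trim_stop_token` was previously forwarded the value of `trim_start_token`, so stop-token trimming silently followed the start-token flag rather than its own.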
18 changes: 10 additions & 8 deletions depthcharge/tokenizers/tokenizer.py
@@ -133,7 +133,7 @@ def detokenize(
         trim_start_token : bool, optional
             Remove the start token from the beginning of a sequence.
         trim_stop_token : bool, optional
-            Remove the stop token from the end of a sequence.
+            Remove the stop token and anything following it from the sequence.
 
         Returns
         -------
@@ -143,16 +143,18 @@
         """
         decoded = []
         for row in tokens:
-            seq = [
-                self.reverse_index[i]
-                for i in row
-                if self.reverse_index[i] is not None
-            ]
+            seq = []
+            for idx in row:
+                if self.reverse_index[idx] is None:
+                    continue
+
+                if trim_stop_token and idx == self.stop_int:
+                    break
+
+                seq.append(self.reverse_index[idx])
 
             if trim_start_token and seq[0] == self.start_token:
                 seq.pop(0)
-            if trim_stop_token and seq[-1] == self.stop_token:
-                seq.pop(-1)
 
             if join:
                 seq = "".join(seq)
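The substantive change is in the decode loop: instead of filtering padding with a list comprehension and then popping a single trailing stop token, the loop now breaks at the first stop token, discarding it and everything after it. Below is a self-contained sketch of the same logic; the function name, signature, and vocabulary are illustrative (the vocabulary is inferred from the new unit test), not depthcharge internals:

```python
def detokenize_row(
    row: list[int],
    reverse_index: list[str | None],
    stop_int: int,
    start_token: str = "?",
    trim_start_token: bool = True,
    trim_stop_token: bool = True,
) -> str:
    """Decode one row of token indices, mirroring the new loop."""
    seq = []
    for idx in row:
        # Skip padding, which maps to None.
        if reverse_index[idx] is None:
            continue

        # Stop decoding at the first stop token: it and everything
        # after it are dropped.
        if trim_stop_token and idx == stop_int:
            break

        seq.append(reverse_index[idx])

    # The start token is still trimmed after the fact.
    if trim_start_token and seq[0] == start_token:
        seq.pop(0)

    return "".join(seq)


# Vocabulary inferred from the new unit test: [0, 2, 3, 4, 5, 1, 6]
# decodes to "?ACD$E" with no trimming.
reverse_index = [None, "$", "?", "A", "C", "D", "E"]
row = [0, 2, 3, 4, 5, 1, 6]
assert detokenize_row(row, reverse_index, stop_int=1) == "ACD"
assert (
    detokenize_row(
        row, reverse_index, stop_int=1, trim_start_token=False, trim_stop_token=False
    )
    == "?ACD$E"
)
```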
22 changes: 22 additions & 0 deletions tests/unit_tests/test_tokenizers/test_peptides.py
@@ -115,3 +115,25 @@ def test_almost_compliant_proform():
     """Test initializing with a peptide without an explicit mass sign."""
     tokenizer = PeptideTokenizer.from_proforma("[10]-EDITHR")
     assert "[+10.000000]-" in tokenizer.residues
+
+
+@pytest.mark.parametrize(
+    ("start", "stop", "expected"),
+    [
+        (True, True, "ACD"),
+        (True, False, "ACD$E"),
+        (False, True, "?ACD"),
+        (False, False, "?ACD$E"),
+    ],
+)
+def test_trim(start, stop, expected):
+    """Test that the start and stop tokens can be trimmed."""
+    tokenizer = PeptideTokenizer(start_token="?")
+    tokens = torch.tensor([[0, 2, 3, 4, 5, 1, 6]])
+    out = tokenizer.detokenize(
+        tokens,
+        trim_start_token=start,
+        trim_stop_token=stop,
+    )
+
+    assert out[0] == expected
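
For reference, the indices in this test decode as 0 → padding, 2 → "?" (start), 3, 4, 5 → "A", "C", "D", 1 → "$" (stop), and 6 → "E". Placing an "E" after the stop token is what exercises the new truncation: it survives only when `trim_stop_token=False`.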
