Commit
Merge pull request #2 from yutanagano/lint
Lint code with black
yutanagano authored Jun 5, 2024
2 parents d48ccaf + c24aab2 commit 005b99d
Showing 17 changed files with 1,876 additions and 109 deletions.
19 changes: 17 additions & 2 deletions pyproject.toml
@@ -4,10 +4,24 @@ build-backend = "setuptools.build_meta"

[project]
name = "libtcrlm"
requires-python = ">=3.9"
authors = [
{name = "Yuta Nagano", email = "[email protected]"}
]
maintainers = [
{name = "Yuta Nagano", email = "[email protected]"}
]
description = "TCR language modelling library using Pytorch."
readme = "README.md"
keywords = ["TCR", "TR", "T cell", "transformer", "bert", "MLM", "immunology", "bioinformatics"]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering",
]
dependencies = [
"blosum~=2.0",
"pandas~=2.2",
@@ -19,8 +33,9 @@ dynamic = ["version"]
[project.optional-dependencies]
dev = [
"pytest",
"pytest-cov"
"pytest-cov",
"tox"
]

[tool.setuptools.dynamic]
version = {attr = "libtcrlm.VERSION"}
version = {attr = "libtcrlm.VERSION"}
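
The [tool.setuptools.dynamic] table above resolves the package version from a module attribute at build time. A minimal sanity check of that wiring, assuming libtcrlm is installed in the current environment, is simply to read the same attribute that setuptools reads:

    # Hedged sketch: confirms the dynamic version declared in pyproject.toml resolves.
    import libtcrlm

    print(libtcrlm.VERSION)  # the value setuptools uses as the package version

With the dev extras added above, an editable install such as pip install -e ".[dev]" would also pull in pytest, pytest-cov and the newly added tox.
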
4 changes: 2 additions & 2 deletions src/libtcrlm/bert.py
@@ -25,7 +25,7 @@ def __init__(
@property
def d_model(self) -> int:
return self._self_attention_stack.d_model

def set_fine_tuning_mode(self, turn_on: bool) -> None:
self._self_attention_stack.set_fine_tuning_mode(turn_on)

@@ -58,4 +58,4 @@ def _embed(self, tokenised_tcrs: LongTensor) -> FloatTensor:
return self._token_embedder.forward(tokenised_tcrs)

def _get_padding_mask(self, tokenised_tcrs: LongTensor) -> BoolTensor:
return tokenised_tcrs[:, :, 0] == DefaultTokenIndex.NULL
return tokenised_tcrs[:, :, 0] == DefaultTokenIndex.NULL
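
The _get_padding_mask hunk above only gains a trailing newline, but it shows the masking convention: feature 0 of each token vector holds the token index, and positions equal to DefaultTokenIndex.NULL count as padding. A self-contained sketch of that check, using 0 as a stand-in for NULL (the actual enum value is not shown in this diff):

    import torch

    NULL = 0  # stand-in for DefaultTokenIndex.NULL; assumed, not confirmed by this diff
    tokenised_tcrs = torch.tensor(
        [[[2, 1, 1], [3, 2, 1], [NULL, 0, 0]]]  # one TCR, three token vectors, last one padded
    )
    padding_mask = tokenised_tcrs[:, :, 0] == NULL
    print(padding_mask)  # tensor([[False, False,  True]])
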
2 changes: 1 addition & 1 deletion src/libtcrlm/schema/pmhc.py
@@ -42,7 +42,7 @@ def __repr__(self) -> str:
"?" if self.epitope_sequence is None else self.epitope_sequence
)
return f"{epitope_representation}/{self.mhc_a}/{self.mhc_b}"

def __hash__(self) -> int:
return hash((self.epitope_sequence, self.mhc_a.symbol, self.mhc_b.symbol))

10 changes: 7 additions & 3 deletions src/libtcrlm/schema/tcr.py
@@ -110,11 +110,15 @@ def cdr1b_sequence(self) -> Optional[str]:
@property
def cdr2b_sequence(self) -> Optional[str]:
return self._trbv.cdr2_sequence

@property
def both_chains_specified(self) -> bool:
tra_specified = (not self._trav._gene_is_unknown()) or (not self.junction_a_sequence is None)
trb_specified = (not self._trbv._gene_is_unknown()) or (not self.junction_b_sequence is None)
tra_specified = (not self._trav._gene_is_unknown()) or (
not self.junction_a_sequence is None
)
trb_specified = (not self._trbv._gene_is_unknown()) or (
not self.junction_b_sequence is None
)
return tra_specified and trb_specified

def copy(self) -> "Tcr":
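
Black only re-wrapped both_chains_specified, but the condition is easier to read as: a chain counts as specified when its V gene is known or its junction sequence is present. A hedged restatement with an illustrative helper (the name and signature are not part of the library):

    from typing import Optional

    def chain_specified(v_gene_known: bool, junction: Optional[str]) -> bool:
        """A chain is specified if its V gene or its junction sequence is given."""
        return v_gene_known or junction is not None

    print(chain_specified(False, "CASSLGTDTQYF"))  # True: a junction alone suffices
    print(chain_specified(False, None))            # False: neither piece is known

Note that the original not ... is None spelling survives the reformat; black only adjusts layout, and a style checker such as flake8 would usually be the tool to suggest is not None here (E714).
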
21 changes: 15 additions & 6 deletions src/libtcrlm/self_attention_stack.py
@@ -29,12 +29,17 @@ class SelfAttentionStackWithBuiltins(SelfAttentionStack):
d_model: int = None

def __init__(
self, num_layers: int, d_model: int, nhead: int, dim_feedforward: Optional[int] = None, dropout: float = 0.1
self,
num_layers: int,
d_model: int,
nhead: int,
dim_feedforward: Optional[int] = None,
dropout: float = 0.1,
) -> None:
super().__init__()

if dim_feedforward is None:
dim_feedforward = d_model * 4 # backwards compatibility
dim_feedforward = d_model * 4 # backwards compatibility

self.d_model = d_model
self._num_layers_in_stack = num_layers
@@ -66,7 +71,7 @@ def get_token_embeddings_at_penultimate_layer(
)

return token_embeddings

def set_fine_tuning_mode(self, turn_on: bool) -> None:
upper_layers_require_grad = not turn_on
penultimate_layer_index = self._num_layers_in_stack - 1
@@ -95,7 +100,11 @@ def __init__(
in_features=embedding_dim, out_features=d_model, bias=False
)
self._standard_stack = SelfAttentionStackWithBuiltins(
num_layers=num_layers, d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
num_layers=num_layers,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)

def forward(self, token_embeddings: Tensor, padding_mask: Tensor) -> Tensor:
@@ -109,6 +118,6 @@ def get_token_embeddings_at_penultimate_layer(
return self._standard_stack.get_token_embeddings_at_penultimate_layer(
projected_embeddings, padding_mask
)

def set_fine_tuning_mode(self, turn_on: bool) -> None:
self._standard_stack.set_fine_tuning_mode(turn_on)
self._standard_stack.set_fine_tuning_mode(turn_on)
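
The constructor above now lists one parameter per line and keeps the dim_feedforward = d_model * 4 fallback for backwards compatibility. Assuming the "builtins" in the class name refer to PyTorch's stock encoder modules (this diff does not show the rest of the class), the default wiring would look roughly like:

    import torch.nn as nn

    d_model, nhead, num_layers = 64, 8, 2
    dim_feedforward = d_model * 4  # the backwards-compatible fallback from the diff

    layer = nn.TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
        dropout=0.1,
        batch_first=True,
    )
    stack = nn.TransformerEncoder(layer, num_layers=num_layers)
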
6 changes: 3 additions & 3 deletions src/libtcrlm/token_embedder/__init__.py
Expand Up @@ -3,13 +3,13 @@
Cdr3Embedder,
Cdr3EmbedderWithRelativePositions,
Cdr3SimpleEmbedder,
SingleChainCdr3SimpleEmbedder
SingleChainCdr3SimpleEmbedder,
)
from .cdr_embedder import (
CdrBlosumEmbedder,
CdrSimpleEmbedder,
CdrEmbedder,
SingleChainCdrEmbedder,
SingleChainCdrEmbedderWithRelativePositions,
SingleChainCdrSimpleEmbedder
)
SingleChainCdrSimpleEmbedder,
)
18 changes: 9 additions & 9 deletions src/libtcrlm/token_embedder/blosum_embedding.py
Expand Up @@ -14,23 +14,21 @@ def __init__(self) -> None:
self._special_token_embeddings = Embedding(
num_embeddings=len(DefaultTokenIndex),
embedding_dim=len(AminoAcid),
padding_idx=DefaultTokenIndex.NULL
padding_idx=DefaultTokenIndex.NULL,
)
self._register_blosum_embeddings()

def _register_blosum_embeddings(self) -> None:
blosum_matrix = blosum.BLOSUM(62)

null_embedding = torch.zeros(size=(len(DefaultTokenIndex),len(AminoAcid)))
aa_embeddings = torch.zeros(size=(len(AminoAcid),len(AminoAcid)))
null_embedding = torch.zeros(size=(len(DefaultTokenIndex), len(AminoAcid)))
aa_embeddings = torch.zeros(size=(len(AminoAcid), len(AminoAcid)))

for row, column in itertools.product(AminoAcid, repeat=2):
blosum_score = blosum_matrix[row.name][column.name]
aa_embeddings[row.value,column.value] = blosum_score
aa_embeddings[row.value, column.value] = blosum_score

blosum_embeddings = torch.concatenate(
[null_embedding, aa_embeddings], dim=0
)
blosum_embeddings = torch.concatenate([null_embedding, aa_embeddings], dim=0)
blosum_embeddings_normalised = blosum_embeddings / blosum_embeddings.abs().max()

self.register_buffer("_blosum_embeddings", blosum_embeddings_normalised)
@@ -39,7 +37,9 @@ def forward(self, token_indices: LongTensor) -> FloatTensor:
special_token_mask = token_indices < len(DefaultTokenIndex)
token_indices_aa_masked_out = token_indices * special_token_mask

special_token_embeddings = self._special_token_embeddings.forward(token_indices_aa_masked_out)
special_token_embeddings = self._special_token_embeddings.forward(
token_indices_aa_masked_out
)
aa_blosum_embeddings = self._blosum_embeddings[token_indices]

return special_token_embeddings + aa_blosum_embeddings
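
The BLOSUM embedding above fills its lookup table from the blosum package (pinned as blosum~=2.0 in pyproject.toml) and normalises by the largest absolute score. A small sketch of the lookup it relies on; the numeric values in the comments are the usual BLOSUM62 entries, quoted from memory rather than from this diff:

    import blosum

    matrix = blosum.BLOSUM(62)
    print(matrix["A"]["A"])  # alanine self-substitution score (4 in BLOSUM62)
    print(matrix["A"]["W"])  # alanine vs tryptophan (-3 in BLOSUM62)
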
4 changes: 2 additions & 2 deletions src/libtcrlm/token_embedder/cdr3_embedder.py
@@ -97,7 +97,7 @@ def forward(self, tokenised_tcrs: LongTensor) -> FloatTensor:
[token_component, position_component, compartment_component], dim=-1
)
return all_components_stacked


class SingleChainCdr3SimpleEmbedder(TokenEmbedder):
def __init__(self) -> None:
@@ -111,4 +111,4 @@ def forward(self, tokenised_tcrs: LongTensor) -> FloatTensor:
all_components_stacked = torch.concatenate(
[token_component, position_component], dim=-1
)
return all_components_stacked
return all_components_stacked
10 changes: 7 additions & 3 deletions src/libtcrlm/token_embedder/cdr_embedder.py
@@ -140,7 +140,9 @@ def __init__(self, embedding_dim: int) -> None:
padding_idx=DefaultTokenIndex.NULL,
)
self._position_embedding = SimpleRelativePositionEmbedding()
self._compartment_embedding = OneHotTokenIndexEmbedding(SingleChainCdrCompartmentIndex)
self._compartment_embedding = OneHotTokenIndexEmbedding(
SingleChainCdrCompartmentIndex
)

def forward(self, tokenised_tcrs: LongTensor) -> FloatTensor:
token_component = self._token_embedding.forward(tokenised_tcrs[:, :, 0])
@@ -159,7 +161,9 @@ def __init__(self) -> None:
super().__init__()
self._token_embedding = OneHotTokenIndexEmbedding(AminoAcidTokenIndex)
self._position_embedding = SimpleRelativePositionEmbedding()
self._compartment_embedding = OneHotTokenIndexEmbedding(SingleChainCdrCompartmentIndex)
self._compartment_embedding = OneHotTokenIndexEmbedding(
SingleChainCdrCompartmentIndex
)

def forward(self, tokenised_tcrs: LongTensor) -> FloatTensor:
token_component = self._token_embedding.forward(tokenised_tcrs[:, :, 0])
@@ -170,4 +174,4 @@ def forward(self, tokenised_tcrs: LongTensor) -> FloatTensor:
all_components_stacked = torch.concatenate(
[token_component, position_component, compartment_component], dim=-1
)
return all_components_stacked
return all_components_stacked
@@ -18,9 +18,9 @@ def forward(self, position_indices: LongTensor) -> FloatTensor:
)

RELATIVE_POSITION_IF_ONLY_ONE_TOKEN_IN_COMPARTMENT = 0.5
relative_token_positions[
relative_token_positions.isnan()
] = RELATIVE_POSITION_IF_ONLY_ONE_TOKEN_IN_COMPARTMENT
relative_token_positions[relative_token_positions.isnan()] = (
RELATIVE_POSITION_IF_ONLY_ONE_TOKEN_IN_COMPARTMENT
)
relative_token_positions[null_mask] = 0

relative_token_positions = relative_token_positions.unsqueeze(dim=-1)
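
This hunk's file header is missing from the page, but it evidently belongs to the relative-position embedding: black merely moved the right-hand side of the indexed assignment into parentheses. Behaviour is unchanged: NaN positions, which the constant's name suggests arise when a compartment holds a single token, are replaced with 0.5, and padded positions are zeroed. A toy illustration of the NaN fallback:

    import torch

    relative_positions = torch.tensor([0.0, 0.5, 1.0, float("nan")])
    relative_positions[relative_positions.isnan()] = 0.5  # single-token compartment fallback
    print(relative_positions)  # tensor([0.0000, 0.5000, 1.0000, 0.5000])
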
2 changes: 1 addition & 1 deletion src/libtcrlm/tokeniser/__init__.py
@@ -1,3 +1,3 @@
from .tokeniser import Tokeniser
from .cdr3_tokeniser import Cdr3Tokeniser, BetaCdr3Tokeniser
from .cdr_tokeniser import CdrTokeniser, AlphaCdrTokeniser, BetaCdrTokeniser
from .cdr_tokeniser import CdrTokeniser, AlphaCdrTokeniser, BetaCdrTokeniser
6 changes: 2 additions & 4 deletions src/libtcrlm/tokeniser/cdr3_tokeniser.py
@@ -28,9 +28,7 @@ def tokenise(self, tcr: Tcr) -> Tensor:
tcr.junction_b_sequence, Cdr3CompartmentIndex.CDR3B
)

all_cdrs_tokenised = (
[initial_cls_vector] + cdr3a + cdr3b
)
all_cdrs_tokenised = [initial_cls_vector] + cdr3a + cdr3b

number_of_tokens_other_than_initial_cls = len(all_cdrs_tokenised) - 1
if number_of_tokens_other_than_initial_cls == 0:
@@ -95,4 +93,4 @@ def _convert_to_numerical_form(

iterator_over_token_vectors = zip(token_indices, token_positions, cdr_length)

return list(iterator_over_token_vectors)
return list(iterator_over_token_vectors)
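
_convert_to_numerical_form above ends by zipping parallel sequences, so each CDR3 residue becomes a small tuple of (token index, token position, CDR length). A rough illustration with made-up values; the real index scheme lives elsewhere in the library:

    token_indices = [5, 9, 13]   # hypothetical amino-acid token indices
    token_positions = [1, 2, 3]  # position within the CDR3 (numbering assumed for illustration)
    cdr_length = [3, 3, 3]       # the CDR3 length, repeated once per token
    token_vectors = list(zip(token_indices, token_positions, cdr_length))
    print(token_vectors)  # [(5, 1, 3), (9, 2, 3), (13, 3, 3)]
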
8 changes: 6 additions & 2 deletions src/libtcrlm/tokeniser/cdr_tokeniser.py
@@ -1,5 +1,9 @@
from libtcrlm.tokeniser.tokeniser import Tokeniser
from libtcrlm.tokeniser.token_indices import AminoAcidTokenIndex, CdrCompartmentIndex, SingleChainCdrCompartmentIndex
from libtcrlm.tokeniser.token_indices import (
AminoAcidTokenIndex,
CdrCompartmentIndex,
SingleChainCdrCompartmentIndex,
)
from libtcrlm.schema import Tcr
import torch
from torch import Tensor
@@ -178,4 +182,4 @@ def _convert_to_numerical_form(
token_indices, token_positions, cdr_length, compartment_index
)

return list(iterator_over_token_vectors)
return list(iterator_over_token_vectors)
2 changes: 1 addition & 1 deletion src/libtcrlm/vector_representation_delegate.py
@@ -75,4 +75,4 @@ def get_vector_representations_of(
final_cls_embeddings = final_token_embeddings[:, LOCATION_OF_CLS_TOKEN, :]
l2_normed_cls_embeddings = F.normalize(final_cls_embeddings, p=2, dim=1)

return l2_normed_cls_embeddings
return l2_normed_cls_embeddings
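
The only change to vector_representation_delegate.py is the trailing newline, but the visible lines summarise how representations leave the model: take the final CLS token embedding and L2-normalise it. A standalone sketch of that last step:

    import torch
    import torch.nn.functional as F

    cls_embeddings = torch.randn(4, 16)               # (batch, d_model) CLS vectors
    normed = F.normalize(cls_embeddings, p=2, dim=1)  # unit length per row
    print(normed.norm(dim=1))                         # all values are approximately 1.0
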
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -49,4 +49,4 @@ def mock_data_path():

@pytest.fixture
def mock_data_df(mock_data_path):
return pd.read_csv(mock_data_path)
return pd.read_csv(mock_data_path)