Skip to content

Commit

Permalink
classification loss encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Corentin-Allaire committed Aug 26, 2024
1 parent 8b993a9 commit a0bc138
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SeedTransformer(nn.Module):
- embedding_encoder: Embedding layer for the encoder
- embedding_decoder: Embedding layer for the decoder
- pos_encoding_decoder: Positional encoding layer for the decoder
- nb_seeds_from_encoded: Linear layer to extract the number of seeds from the encoded information
- classify_seed: Linear layer to classify the seed by particle
- seed_vertex: Linear layer to extract the vertex position from the decoder output
- seed_momentum: Linear layer to extract the seed momentum from the decoder output
- keep_iterating: Linear layer to determine whether to keep iterating or not based on the decoder output
Expand All @@ -45,7 +45,7 @@ def __init__(
nb_head: int,
device_acc: str,
embedding_network,
dropout: float = 0.1,
dropout: float = 0.0,
dim_seed: int = 100,
):
super(SeedTransformer, self).__init__()
Expand Down Expand Up @@ -74,9 +74,9 @@ def __init__(
)

# Linear layer to extract the expected number of seed from the encoded information
self.nb_seeds_from_encoded = nn.Linear(
dim_hits,
1,
self.classify_seed = nn.Linear(
dim_embedding,
dim_seed,
device=device_acc,
)
# Linear layer to extract the seed Z0 and momentum from the decoder output
Expand All @@ -85,6 +85,7 @@ def __init__(
self.keep_iterating = nn.Linear(dim_embedding, 1, device=device_acc)

self.keep_sigmoide = nn.Sigmoid().to(device_acc)
self.class_softMax = nn.Softmax(dim=2).to(device_acc)
# # First token as a learnable parameter <= THING ABOUT THIS AT A LATER POINT !!!
# self.first_token = nn.Parameter(torch.randn(1, 6))

Expand Down Expand Up @@ -114,15 +115,15 @@ def encode(
Returns:
- encoded (Tensor): Encoded memory.
- nb_seeds_from_encoded (Tensor): Number of seeds from the encoded information.
- classify_seed (Tensor): Attempt to classify the seed by particle
"""
# Loop over the entry in the batch and run the embedding layer
embedded_src = self.embedding_encoder(hits)
# Encode the source sequence
encoded = self.transformer.encoder(
src=embedded_src, mask=mask, src_key_padding_mask=padding_mask
)
return encoded, self.nb_seeds_from_encoded(encoded[:, :, 0])
return encoded, self.class_softMax(self.classify_seed(encoded))

def decode(
self,
Expand Down Expand Up @@ -157,6 +158,7 @@ def decode(
tgt_key_padding_mask=padding_mask,
memory_key_padding_mask=None,
)

return (
self.seed_momentum(reconstructed_seeds),
self.keep_sigmoide(self.keep_iterating(reconstructed_seeds)),
Expand Down Expand Up @@ -189,7 +191,7 @@ def forward(

iter_threshold = 0.1
# Encode the source sequence
encoded, nb_seeds_encoder = self.encode(hits, mask_hits, padding_mask_hits)
encoded, seed_class = self.encode(hits, mask_hits, padding_mask_hits)
nb_loop = 0
keep_iteration = True
nb_seeds = Tensor(seed.size(0)).to(seed.device)
Expand All @@ -211,4 +213,4 @@ def forward(
nb_seeds[batch] = hits
break

return nb_seeds_encoder, nb_seeds, seed_momentum
return seed_class, nb_seeds, seed_momentum
Loading

0 comments on commit a0bc138

Please sign in to comment.