diff --git a/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_network.py b/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_network.py
index 6ee19cb8f8a..1a96a116eff 100644
--- a/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_network.py
+++ b/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_network.py
@@ -19,7 +19,7 @@ class SeedTransformer(nn.Module):
     - embedding_encoder: Embedding layer for the encoder
     - embedding_decoder: Embedding layer for the decoder
     - pos_encoding_decoder: Positional encoding layer for the decoder
-    - nb_seeds_from_encoded: Linear layer to extract the number of seeds from the encoded information
+    - classify_seed: Linear layer to classify the seed by particle
     - seed_vertex: Linear layer to extract the vertex position from the decoder output
     - seed_momentum: Linear layer to extract the seed momentum from the decoder output
     - keep_iterating: Linear layer to determine whether to keep iterating or not based on the decoder output
@@ -45,7 +45,7 @@ def __init__(
         nb_head: int,
         device_acc: str,
         embedding_network,
-        dropout: float = 0.1,
+        dropout: float = 0.0,
         dim_seed: int = 100,
     ):
         super(SeedTransformer, self).__init__()
@@ -74,9 +74,9 @@ def __init__(
         )
 
         # Linear layer to extract the expected number of seed from the encoded information
-        self.nb_seeds_from_encoded = nn.Linear(
-            dim_hits,
-            1,
+        self.classify_seed = nn.Linear(
+            dim_embedding,
+            dim_seed,
             device=device_acc,
         )
         # Linear layer to extract the seed Z0 and momentum from the decoder output
@@ -85,6 +85,7 @@ def __init__(
         self.keep_iterating = nn.Linear(dim_embedding, 1, device=device_acc)
         self.keep_sigmoide = nn.Sigmoid().to(device_acc)
+        self.class_softMax = nn.Softmax(dim=2).to(device_acc)
 
         # # First token as a learnable parameter <= THING ABOUT THIS AT A LATER POINT !!!
         # self.first_token = nn.Parameter(torch.randn(1, 6))
 
@@ -114,7 +115,7 @@ def encode(
 
         Returns:
             - encoded (Tensor): Encoded memory.
-            - nb_seeds_from_encoded (Tensor): Number of seeds from the encoded information.
+            - classify_seed (Tensor): Classification of the seed by particle.
         """
         # Loop over the entry in the batch and run the embedding layer
         embedded_src = self.embedding_encoder(hits)
@@ -122,7 +123,7 @@ def encode(
         encoded = self.transformer.encoder(
             src=embedded_src, mask=mask, src_key_padding_mask=padding_mask
         )
-        return encoded, self.nb_seeds_from_encoded(encoded[:, :, 0])
+        return encoded, self.class_softMax(self.classify_seed(encoded))
 
     def decode(
         self,
@@ -157,6 +158,7 @@ def decode(
             tgt_key_padding_mask=padding_mask,
             memory_key_padding_mask=None,
         )
+
         return (
             self.seed_momentum(reconstructed_seeds),
             self.keep_sigmoide(self.keep_iterating(reconstructed_seeds)),
@@ -189,7 +191,7 @@ def forward(
         iter_threshold = 0.1
 
         # Encode the source sequence
-        encoded, nb_seeds_encoder = self.encode(hits, mask_hits, padding_mask_hits)
+        encoded, seed_class = self.encode(hits, mask_hits, padding_mask_hits)
         nb_loop = 0
         keep_iteration = True
         nb_seeds = Tensor(seed.size(0)).to(seed.device)
@@ -211,4 +213,4 @@ def forward(
                     nb_seeds[batch] = hits
                 break
 
-        return nb_seeds_encoder, nb_seeds, seed_momentum
+        return seed_class, nb_seeds, seed_momentum
diff --git a/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_train.py b/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_train.py
index 404f4cecc70..ad886cb0abe 100644
--- a/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_train.py
+++ b/Examples/Scripts/Python/MLAmbiguityResolution/transformer/transformer_train.py
@@ -58,14 +58,9 @@ def plot_loss(
     plt.xlabel("Epoch")
     plt.ylabel(loss_name)
     plt.legend()
+    plt.savefig(loss_name + ".png")
     if interactive:
-        plt.show(block=True)
-        plt.savefig(loss_name + ".png")
-
-    if not interactive:
-        plt.show(block=False)
-        plt.savefig(loss_name + ".png")
-
+        plt.show()
     pd.DataFrame(metrics_train).to_csv(loss_name + "_train.csv")
     pd.DataFrame(metrics_val).to_csv(loss_name + "_val.csv")
 
@@ -86,8 +81,9 @@ def __init__(self):
         - max_particle_input: int: The maximum number of particle in the decoder input
         - vertex_cuts: list[int]: The cuts to apply on the vertex position to only keep the primary vertex
         - event_test: int: Number of event in the test will be run on
-        - encoder_only: bool: If set to True only the encoder will be run (can be use to pretrain the network)
         - interactive: bool: If set to True the plot will be displayed if not it will only be saved as a png file
+        - encoder_only: bool: If set to True only the encoder will be run (can be used to pretrain the network)
+        - input_type: str: The type of input to use (csv or tensor)
         - device_acc: str: The device to use (cpu/gpu)
         """
         self.embedding = "ID+Position"
@@ -99,6 +95,7 @@ def __init__(self):
         self.event_test = 1
         self.interactive = True
         self.encoder_only = False
+        self.input_type = "csv"
         self.device_acc = torch.device("cpu")
 
     def parse_args(self):
@@ -111,7 +108,7 @@ def parse_args(self):
         parser.add_argument(
             "--vertex_cuts", type=list, default=[10, 10, 200], help="Vertex cuts"
         )
-        parser.add_argument("--epoch_nb", type=int, default=1, help="Number of epoch")
+        parser.add_argument("--epoch_nb", type=int, default=10, help="Number of epoch")
         parser.add_argument("--batch_size", type=int, default=10, help="Batch size")
         parser.add_argument(
             "--max_hit_input",
@@ -135,6 +132,9 @@ def parse_args(self):
         parser.add_argument(
             "--embedding", type=str, default="ID", help="Type of embedding"
         )
+        parser.add_argument(
+            "--input_type", type=str, default="csv", help="Type of input"
+        )
         args = parser.parse_args()
         self.vertex_cuts = args.vertex_cuts
         self.epoch_nb = args.epoch_nb
@@ -144,6 +144,8 @@ def parse_args(self):
         self.event_test = args.event_test
         self.interactive = args.interactive
         self.embedding = args.embedding
+        self.encoder_only = args.encoder_only
+        self.input_type = args.input_type
 
     def print_config(self):
         """
@@ -173,7 +175,7 @@ def __init__(self, epoch_nb: int = 1):
         - loss_iter: list[float]: the loss for the iter variable computed by the decoder for each epoch
         """
         self.loss = [0] * epoch_nb
-        self.loss_nb = [0] * epoch_nb
+        self.loss_class = [0] * epoch_nb
         self.loss_momentum = [0] * epoch_nb
         self.loss_iter = [0] * epoch_nb
 
@@ -181,7 +183,7 @@ def add_loss(
         self,
         epoch: int,
         loss: float,
-        loss_nb: float,
+        loss_class: float,
         loss_momentum: float,
         loss_iter: float,
     ):
@@ -195,7 +197,7 @@ def add_loss(
         - loss_iter: the loss for the iter variable computed by the decoder
         """
         self.loss[epoch] += loss
-        self.loss_nb[epoch] += loss_nb
+        self.loss_class[epoch] += loss_class
         self.loss_momentum[epoch] += loss_momentum
         self.loss_iter[epoch] += loss_iter
 
@@ -206,7 +208,7 @@ def print_loss(self, epoch: int):
         - epoch: the epoch number
         """
         print("Epoch", epoch, "loss is", self.loss[epoch])
-        print("Epoch", epoch, "loss_nb is", self.loss_nb[epoch])
+        print("Epoch", epoch, "loss_class is", self.loss_class[epoch])
         print("Epoch", epoch, "loss_momentum is", self.loss_momentum[epoch])
         print("Epoch", epoch, "loss_iter is", self.loss_iter[epoch])
 
@@ -217,7 +219,7 @@ def prepare_input_tensor(
     nb_events: int,
     cfg: config,
     embedding,
-) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """
     Prepare the input tensor and padding mask for a given event.
 
@@ -232,11 +234,25 @@ def prepare_input_tensor(
     # Ininitalise the padding masks
     padding_mask_hit = torch.zeros(nb_events, cfg.max_hit_input)
     padding_mask_particle = torch.zeros(nb_events, cfg.max_particle_input)
+    particle_class = torch.zeros(nb_events, cfg.max_hit_input, cfg.max_particle_input)
    for i in range(nb_events):
         # Select the hits and particles for the event i
         hits_event = hits[hits["event_id"] == i]
         particles_event = particles[particles["event_id"] == i]
-        particles_event = particles_event[["vx", "vy", "vz", "eta", "phi", "pT"]]
+
+        # Fill particle_class, which contains for each event and each hit a value of 1
+        # if the hit belongs to the particle, and 0 otherwise
+        # The tensor is of size (max_hit_input, max_particle_input)
+
+        particle_class[
+            i,
+            : hits_event.shape[0],
+            : min(len(particles_event["particle_id"].unique()), cfg.max_particle_input),
+        ] = torch.tensor(
+            (hits_event["particle_id"].values.reshape(-1, 1))
+            == (particles_event["particle_id"].values[: cfg.max_particle_input])
+        )
+        particles_event = particles_event[["vz", "eta", "phi", "pT"]]
         # Add one column to the particles DataFrame to indicate if a particle is the last one in the events
         particles_event["iter"] = 1
         particles_event.iloc[-1, -1] = 0
@@ -318,41 +334,24 @@ def prepare_input_tensor(
             tensor_particle = torch.cat(
                 (
                     tensor_particle,
-                    torch.zeros(cfg.max_particle_input - input_size_particle, 7),
+                    torch.zeros(cfg.max_particle_input - input_size_particle, 5),
                 )
             )
 
         if i == 0:
             input_tensor_hits = tensor_hit.unsqueeze(0)
             input_tensor_particles = tensor_particle.unsqueeze(0)
-            tensor_nb_particle = torch.tensor(
-                [len(particles_event)], dtype=torch.float32
-            ).unsqueeze(0)
+
         else:
             input_tensor_hits = torch.cat((input_tensor_hits, tensor_hit.unsqueeze(0)))
             input_tensor_particles = torch.cat(
                 (input_tensor_particles, tensor_particle.unsqueeze(0))
             )
-            tensor_nb_particle = torch.cat(
-                (
-                    tensor_nb_particle,
-                    torch.tensor([len(particles_event)], dtype=torch.float32).unsqueeze(
-                        0
-                    ),
-                )
-            )
-
-    # Move the tensor to the right device
-    # input_tensor_hits.to(cfg.device_acc)
-    # input_tensor_particles.to(cfg.device_acc)
-    # padding_mask_hit.to(cfg.device_acc)
-    # padding_mask_particle.to(cfg.device_acc)
-    # tensor_nb_particle.to(cfg.device_acc)
 
     return (
         input_tensor_hits,
         input_tensor_particles,
-        tensor_nb_particle,
+        particle_class,
         padding_mask_hit,
         padding_mask_particle,
     )
@@ -415,12 +414,12 @@ def decoder_loss(
         exit()
 
     # Compute the loss by comparing the decoded value to the particles in the event
-    loss_momentum = F.mse_loss(momentum, particles[:, :, 2:6])
-    loss_iter = F.binary_cross_entropy(keep_iterating, particles[:, :, 6].unsqueeze(-1))
+    loss_momentum = F.mse_loss(momentum, particles[:, :, 0:4])
+    loss_iter = F.binary_cross_entropy(keep_iterating, particles[:, :, 4].unsqueeze(-1))
     # Extract the keep_iterating value for the last particle in the event (it should be 0)
     stop_iterating = keep_iterating - (
-        keep_iterating * particles[:, :, 6].unsqueeze(-1)
+        keep_iterating * particles[:, :, 4].unsqueeze(-1)
     )
     # Compute the corresponding loss, enrich_iter will be used to balance the case iter=0 and iter=1
     loss_iter = (1 - enrich_iter) * loss_iter + enrich_iter * 100 * F.mse_loss(
@@ -435,11 +434,12 @@ def compute_loss(
     particles: Tensor,
     padding_mask_hits: Tensor,
     padding_mask_particle: Tensor,
-    nb_particles: Tensor,
+    particle_class: Tensor,
     mask_hits: Tensor,
     mask_particle: Tensor,
     initial_seed: Tensor,
     model: SeedTransformer,
+    device: str,
     encoder_only: bool = False,
 ) -> Tensor:
     """
@@ -459,22 +459,21 @@
     """
     # Run the encoder on the hits
     # Print the device of the encoder input
-    print("The device of the encoder input is", hits.device)
+    print("The device of the encoder input is", device)
     # Add all the input tensor to the device
-    hits = hits.to(model.device)
-    padding_mask_hits = padding_mask_hits.to(model.device)
-    padding_mask_particle = padding_mask_particle.to(model.device)
-    nb_particles = nb_particles.to(model.device)
+    hits = hits.to(device)
+    padding_mask_hits = padding_mask_hits.to(device)
+    padding_mask_particle = padding_mask_particle.to(device)
+    particle_class = particle_class.to(device)
 
-    encoded, nb_seed = model.encode(hits, mask_hits, padding_mask_hits)
+    encoded, seed_class = model.encode(hits, mask_hits, padding_mask_hits)
     # Compute the loss for nb_seed by comparing its value to the number of particles in the events of the batch
-    loss_nb = F.mse_loss(nb_seed, nb_particles)
+    loss_class = F.cross_entropy(seed_class, particle_class)
 
     if encoder_only == True:
         loss_momentum, loss_iter = (
             torch.zeros(1),
             torch.zeros(1),
-            torch.zeros(1),
         )
     else:
         # We will now run the decoder on the particles
@@ -499,7 +498,7 @@
         particles,
         0.5,
     )
-    return loss_nb, loss_momentum, loss_iter
+    return loss_class, loss_momentum, loss_iter
 
 
 def run_model(
@@ -507,7 +506,7 @@ def run_model(
     cfg: config,
     input_tensor_hits: Tensor,
     input_tensor_particles: Tensor,
-    nb_particles: Tensor,
+    particle_class: Tensor,
     padding_mask_hit: Tensor,
     padding_mask_particle: Tensor,
     model: SeedTransformer,
@@ -529,7 +528,7 @@ def run_model(
     nb_batches = input_tensor_hits.size(0) // cfg.batch_size
     # Loop over the event batches
     for i in range(nb_batches):
-        if optimiser is not None and i % 100 == 0:
+        if optimiser is not None:
            print("Training batch:", i, "/", nb_batches)
         else:
            print("Validation batch:", i, "/", nb_batches)
@@ -547,30 +546,38 @@ def run_model(
         batch_padding_particle = padding_mask_particle[
             i * cfg.batch_size : (i + 1) * cfg.batch_size
         ]
-        batch_nb_particles = nb_particles[i * cfg.batch_size : (i + 1) * cfg.batch_size]
+        batch_particle_class = particle_class[
+            i * cfg.batch_size : (i + 1) * cfg.batch_size
+        ]
         # Create the lookahead mask for the hit (encoder) and particle (decoder)
         mask_hits = build_look_ahead_mask(cfg.max_hit_input, cfg.device_acc)
         mask_particle = build_look_ahead_mask(cfg.max_particle_input, cfg.device_acc)
         # Create an initialisation seed for the transformer
         initial_seed = init_seed(cfg.batch_size, cfg.device_acc)
         # Compute the loos for the batch
-        loss_nb, loss_momentum, loss_iter = compute_loss(
+        loss_class, loss_momentum, loss_iter = compute_loss(
             batch_tensor_hits,
             batch_tensor_particles,
             batch_padding_hit,
             batch_padding_particle,
-            batch_nb_particles,
+            batch_particle_class,
             mask_hits,
             mask_particle,
             initial_seed,
             model,
+            cfg.device_acc,
             cfg.encoder_only,
         )
-        loss = loss_nb + loss_momentum + loss_iter
+        print("The loss class: ", loss_class.item())
+        print("The loss momentum: ", loss_momentum.item())
+        print("The loss iter: ", loss_iter.item())
+
+        # Add the losses together with their weights
+        loss = 100 * loss_class + 0.1 * loss_momentum + 10 * loss_iter
         met.add_loss(
             epoch,
             loss.item(),
-            loss_nb.item(),
+            loss_class.item(),
             loss_momentum.item(),
             loss_iter.item(),
         )
@@ -606,7 +613,7 @@ def test_model(
     (
         input_tensor_hits,
         input_tensor_particles,
-        nb_particles,
+        particle_class,
         padding_mask_hit,
         _,
     ) = prepare_input_tensor(hits, particles, nb_events, cfg, model.embedding_encoder)
@@ -619,14 +626,14 @@ def test_model(
     initial_seed = torch.cat(
         (
             initial_seed,
-            torch.zeros(1, cfg.max_particle_input - 1, 7, device=cfg.device_acc),
+            torch.zeros(1, cfg.max_particle_input - 1, 5, device=cfg.device_acc),
         ),
         dim=1,
     )
 
     for event in range(nb_events):
         # Run the encoder on the hits
-        nb_seeds_encoded, nb_seeds, seed_momentum = model(
+        seed_class, nb_seeds, seed_momentum = model(
             input_tensor_hits[event].unsqueeze(0),
             initial_seed,
             mask_hits,
@@ -636,15 +643,24 @@
         )
 
         print("testing event ", nb_events + 1)
-        print("The number of track is :", nb_particles[event])
-        print("The number of seed evaluated by the encoder is ", nb_seeds_encoded[0])
+        print(
+            "The number of tracks is:", len(particles[particles["event_id"] == event])
+        )
         print("The number of seed found is ", nb_seeds[0])
+        # Compare the seed class and the particle class: for each hit, print the non-zero bin of particle_class and the largest bin of seed_class
+        for hit in range(seed_class.size(1)):
+            particle_bin = torch.argmax(particle_class[0, hit])
+            seed_bin = torch.argmax(seed_class[0, hit])
+            print("Hit", hit)
+            print("Particle class bin:", particle_bin.item())
+            print("Seed class bin:", seed_bin.item())
+
         # Print loop over the seed, print them and print the corresponding particle
         seed_i = 0
         for seed in range(seed_momentum.size(1)):
             # Compare the particle and the seed
             print("Seed", seed_i)
-            print("Particle", input_tensor_particles[0, seed][2:])
+            print("Particle", input_tensor_particles[0, seed])
             print("Seed vertex Z", seed_momentum[0, seed][0])
             print("Seed momentum", seed_momentum[0, seed][1:])
             seed_i += 1
@@ -667,62 +683,33 @@ def main():
     cfg.print_config()
     print("Using device:", cfg.device_acc)
 
-    # Open the hits and particles csv files
-    hits_train, particles_train, nb_events_train = read_data(
-        "train/hits.csv", "train/particles.csv", 0, cfg.vertex_cuts
"train/particles.csv", 0, cfg.vertex_cuts - ) - # hits_train = pd.DataFrame() - # particles_train = pd.DataFrame() - # val_fraction = 0.1 - # dir_path = "ODD_data_mu" - # nb_files = len([name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]) - # nb_events = 0 - # for i in range(math.floor(nb_files*(1-val_fraction))): - # hits, particles, nb_events = read_data("train/odd_full_chain_" + str(i) + "/hits.csv", "train/odd_full_chain_" + str(i) + "/particles.csv", nb_events, cfg.vertex_cuts) - # hits_train = pd.concat([hits_train, hits]) - # particles_train = pd.concat([particles_train, particles]) - - hits_val, particles_val, nb_events_val = read_data( - "val/hits.csv", "val/particles.csv", 0, cfg.vertex_cuts - ) - # hits_train = pd.DataFrame() - # particles_train = pd.DataFrame() - # nb_events = 0 - # for i in range(math.floor(nb_files * (1 - val_fraction)), nb_files): - # hits, particles, nb_events = read_data( - # "train/odd_full_chain_" + str(i) + "/hits.csv", - # "train/odd_full_chain_" + str(i) + "/particles.csv", - # nb_events, - # cfg.vertex_cuts, - # ) - # hits_train = pd.concat([hits_train, hits]) - # particles_train = pd.concat([particles_train, particles]) + emb = 512 if cfg.embedding == "ID": embedding_encoder = EmbeddingGeoID( - emb_size=512, + emb_size=emb, max_volume=100, max_layer=100, max_sensitive=10000, max_extra=100, - dropout=0.1, + dropout=0.0, device=cfg.device_acc, ) elif cfg.embedding == "Position": embedding_encoder = EmbeddingHitPosition( - emb_size=512, + emb_size=emb, range_x=[-200, 200], range_y=[-200, 200], range_z=[-3000, 3000], bins_x=100, bins_y=100, bins_z=100, - dropout=0.1, + dropout=0.0, device=cfg.device_acc, ) elif cfg.embedding == "ID+Position": embedding_encoder = EmbeddingHitIDPosition( - emb_size=512, + emb_size=emb, max_volume=100, max_layer=100, range_x=[-200, 200], @@ -731,7 +718,7 @@ def main(): bins_x=100, bins_y=100, bins_z=100, - dropout=0.1, + dropout=0.0, device=cfg.device_acc, ) else: @@ -739,9 +726,9 @@ def main(): # Create the transformer model model = SeedTransformer( - 6, - 6, - 512, + 3, + 3, + emb, cfg.max_hit_input, 4, cfg.device_acc, @@ -757,32 +744,97 @@ def main(): # if param.requires_grad: # print(name, param.data) + if cfg.input_type == "csv": + + # Open the hits and particles csv files + hits_train, particles_train, nb_events_train = read_data( + "train/hits.csv", "train/particles.csv", 0, cfg.vertex_cuts + ) + # hits_train = pd.DataFrame() + # particles_train = pd.DataFrame() + # val_fraction = 0.1 + # dir_path = "ODD_data_mu" + # nb_files = len([name for name in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, name))]) + # nb_events = 0 + # for i in range(math.floor(nb_files*(1-val_fraction))): + # hits, particles, nb_events = read_data("train/odd_full_chain_" + str(i) + "/hits.csv", "train/odd_full_chain_" + str(i) + "/particles.csv", nb_events, cfg.vertex_cuts) + # hits_train = pd.concat([hits_train, hits]) + # particles_train = pd.concat([particles_train, particles]) + + ( + input_tensor_hits_train, + input_tensor_particles_train, + particle_class_train, + padding_mask_hit_train, + padding_mask_particle_train, + ) = prepare_input_tensor( + hits_train, particles_train, nb_events_train, cfg, model.embedding_encoder + ) + # Save the new tensor so we can reuse them for later training + torch.save(input_tensor_hits_train, "train/input_tensor_hits_train.pt") + torch.save( + input_tensor_particles_train, "train/input_tensor_particles_train.pt" + ) + torch.save(particle_class_train, 
"train/particle_class_train.pt") + torch.save(padding_mask_hit_train, "train/padding_mask_hit_train.pt") + torch.save(padding_mask_particle_train, "train/padding_mask_particle_train.pt") + + hits_val, particles_val, nb_events_val = read_data( + "val/hits.csv", "val/particles.csv", 0, cfg.vertex_cuts + ) + # hits_train = pd.DataFrame() + # particles_train = pd.DataFrame() + # nb_events = 0 + # for i in range(math.floor(nb_files * (1 - val_fraction)), nb_files): + # hits, particles, nb_events = read_data( + # "train/odd_full_chain_" + str(i) + "/hits.csv", + # "train/odd_full_chain_" + str(i) + "/particles.csv", + # nb_events, + # cfg.vertex_cuts, + # ) + # hits_train = pd.concat([hits_train, hits]) + # particles_train = pd.concat([particles_train, particles]) + + # Prepare the input tensor and padding mask + ( + input_tensor_hits_val, + input_tensor_particles_val, + particle_class_val, + padding_mask_hit_val, + padding_mask_particle_val, + ) = prepare_input_tensor( + hits_val, particles_val, nb_events_val, cfg, model.embedding_encoder + ) + # Save the new tensor so we can reuse them for later training + torch.save(input_tensor_hits_val, "val/input_tensor_hits_val.pt") + torch.save(input_tensor_particles_val, "val/input_tensor_particles_val.pt") + torch.save(particle_class_val, "val/particle_class_val.pt") + torch.save(padding_mask_hit_val, "val/padding_mask_hit_val.pt") + torch.save(padding_mask_particle_val, "val/padding_mask_particle_val.pt") + + elif cfg.input_type == "tensor": + input_tensor_hits_train = torch.load("train/input_tensor_hits_train.pt") + input_tensor_particles_train = torch.load( + "train/input_tensor_particles_train.pt" + ) + particle_class_train = torch.load("train/particle_class_train.pt") + padding_mask_hit_train = torch.load("train/padding_mask_hit_train.pt") + padding_mask_particle_train = torch.load("train/padding_mask_particle_train.pt") + + input_tensor_hits_val = torch.load("val/input_tensor_hits_val.pt") + input_tensor_particles_val = torch.load("val/input_tensor_particles_val.pt") + particle_class_val = torch.load("val/particle_class_val.pt") + padding_mask_hit_val = torch.load("val/padding_mask_hit_val.pt") + padding_mask_particle_val = torch.load("val/padding_mask_particle_val.pt") + + else: + print("The input type is not recognised") + exit() + # Initialise the metrics metrics_train = metrics(cfg.epoch_nb) metrics_val = metrics(cfg.epoch_nb) - # Prepare the input tensor and padding mask - ( - input_tensor_hits_train, - input_tensor_particles_train, - nb_particles_train, - padding_mask_hit_train, - padding_mask_particle_train, - ) = prepare_input_tensor( - hits_train, particles_train, nb_events_train, cfg, model.embedding_encoder - ) - - # Prepare the input tensor and padding mask - ( - input_tensor_hits_val, - input_tensor_particles_val, - nb_particles_val, - padding_mask_hit_val, - padding_mask_particle_val, - ) = prepare_input_tensor( - hits_val, particles_val, nb_events_val, cfg, model.embedding_encoder - ) - for epoch in range(cfg.epoch_nb): print("Epoch: ", epoch) @@ -793,7 +845,7 @@ def main(): cfg, input_tensor_hits_train, input_tensor_particles_train, - nb_particles_train, + particle_class_train, padding_mask_hit_train, padding_mask_particle_train, model, @@ -804,13 +856,13 @@ def main(): # Validate the model with torch.no_grad(): # Perform the validation of the model - model.eval() + # model.eval() _, metrics_val = run_model( epoch, cfg, input_tensor_hits_val, input_tensor_particles_val, - nb_particles_val, + particle_class_val, 
                 padding_mask_hit_val,
                 padding_mask_particle_val,
                 model,
@@ -832,7 +884,9 @@ def main():
 
     # Display plot of the loss of the training and validation as a function of the epoch
     plot_loss(metrics_train.loss, metrics_val.loss, "Loss", cfg.interactive)
-    plot_loss(metrics_train.loss_nb, metrics_val.loss_nb, "Loss_nb", cfg.interactive)
+    plot_loss(
+        metrics_train.loss_class, metrics_val.loss_class, "loss_class", cfg.interactive
+    )
     plot_loss(
         metrics_train.loss_momentum,
         metrics_val.loss_momentum,
@@ -846,7 +900,6 @@ def main():
     # Delete all the variable to free some memory
     del hits_val
     del particles_val
-    del nb_events
     del metrics_train
    del metrics_val
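
For reference, the snippet below is a minimal sketch, not part of the patch, of the per-hit classification target that prepare_input_tensor now builds and that compute_loss compares against the output of classify_seed with F.cross_entropy. It keeps the tensor shape (events, max_hit_input, max_particle_input) and the particle_id matching from the diff; the toy sizes and ID values are illustrative assumptions.

# Illustrative sketch only -- toy sizes and IDs are assumptions, not project data.
import torch

max_hit_input = 6
max_particle_input = 4

# particle_id of each hit and of each particle in one toy event
hit_particle_id = torch.tensor([11, 11, 22, 22, 33, 33])
particle_ids = torch.tensor([11, 22, 33])

# One row per hit, one column per particle slot: 1 where the hit belongs to
# that particle, 0 otherwise (same construction as in prepare_input_tensor)
particle_class = torch.zeros(1, max_hit_input, max_particle_input)
particle_class[
    0, : hit_particle_id.shape[0], : min(len(particle_ids), max_particle_input)
] = (hit_particle_id.reshape(-1, 1) == particle_ids[:max_particle_input]).float()

print(particle_class[0])
# tensor([[1., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 1., 0.]])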