From cfd89e16c9868c6ab5c2708f702b4040d4ece5e7 Mon Sep 17 00:00:00 2001 From: Ayano Clarke Date: Wed, 25 Oct 2023 11:55:55 +0800 Subject: [PATCH] Fix transformer model output shape error --- .../modules/python/models/simple_model_transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pepper_variant/modules/python/models/simple_model_transformers.py b/pepper_variant/modules/python/models/simple_model_transformers.py index 19c9880c..1b9d54b1 100644 --- a/pepper_variant/modules/python/models/simple_model_transformers.py +++ b/pepper_variant/modules/python/models/simple_model_transformers.py @@ -37,7 +37,7 @@ def __init__(self, image_features, gru_layers, hidden_size, num_classes, num_cla self.linear_4 = nn.Linear(self.linear_3_size, self.linear_4_size) self.linear_5 = nn.Linear(self.linear_4_size, self.linear_5_size) - self.output_layer = nn.Linear(self.linear_5_size, self.num_classes) + self.output_layer = nn.Linear(self.linear_5_size, self.num_classes_type) def forward(self, x, hidden, cell_state, train_mode=False): # Reshape for CNN @@ -94,4 +94,4 @@ def init_hidden(self, batch_size, num_layers, bidirectional=True): if bidirectional: num_directions = 2 - return torch.zeros(batch_size, num_directions * num_layers, self.hidden_size) \ No newline at end of file + return torch.zeros(batch_size, num_directions * num_layers, self.hidden_size)