From 0a226a835138a2bd6eefa7ef298dd29b4812f65b Mon Sep 17 00:00:00 2001
From: felix
Date: Fri, 22 Sep 2023 11:09:24 +0200
Subject: [PATCH] fix prob computation for parseq and vitstr models

---
 doctr/models/recognition/parseq/pytorch.py    |  9 ++++-----
 doctr/models/recognition/parseq/tensorflow.py | 10 +++++-----
 doctr/models/recognition/vitstr/pytorch.py    |  9 ++++-----
 doctr/models/recognition/vitstr/tensorflow.py | 10 +++++-----
 4 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py
index 5410e0f7e5..3a847b9f4a 100644
--- a/doctr/models/recognition/parseq/pytorch.py
+++ b/doctr/models/recognition/parseq/pytorch.py
@@ -393,18 +393,17 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-        # N x L
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))
 
 
 def _parseq(
diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py
index f72344fb9f..4acf259f7d 100644
--- a/doctr/models/recognition/parseq/tensorflow.py
+++ b/doctr/models/recognition/parseq/tensorflow.py
@@ -421,10 +421,7 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -434,7 +431,10 @@ def __call__(
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
+
+        return list(zip(word_values, probs))
 
 
 def _parseq(
diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py
index 5f68800125..f79cc2d71d 100644
--- a/doctr/models/recognition/vitstr/pytorch.py
+++ b/doctr/models/recognition/vitstr/pytorch.py
@@ -159,18 +159,17 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
-        # N x L
-        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
-        # Take the minimum confidence of the sequence
-        probs = probs.min(dim=1).values.detach().cpu()
+        preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
 
         # Manual decoding
         word_values = [
             "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
             for encoded_seq in out_idxs.cpu().numpy()
         ]
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].mean().item() for i, word in enumerate(word_values)]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        return list(zip(word_values, probs))
 
 
 def _vitstr(
diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py
index ecf4101209..3536eb721a 100644
--- a/doctr/models/recognition/vitstr/tensorflow.py
+++ b/doctr/models/recognition/vitstr/tensorflow.py
@@ -164,10 +164,7 @@ def __call__(
     ) -> List[Tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
-        # N x L
-        probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2)
-        # Take the minimum confidence of the sequence
-        probs = tf.math.reduce_min(probs, axis=1)
+        preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
 
         # decode raw output of the model with tf_label_to_idx
         out_idxs = tf.cast(out_idxs, dtype="int32")
@@ -177,7 +174,10 @@ def __call__(
         decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value="not valid")[:, 0]
         word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
 
-        return list(zip(word_values, probs.numpy().tolist()))
+        # compute probabilities for each word up to the EOS token
+        probs = [preds_prob[i, : len(word)].numpy().mean().item() for i, word in enumerate(word_values)]
+
+        return list(zip(word_values, probs))
 
 
 def _vitstr(
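For reference, a minimal standalone sketch contrasting the old and the new confidence computation on the PyTorch side. The `logits` and `word_values` values below are dummy stand-ins; the "old" and "new" lines mirror the patch, everything else is illustrative:

    import torch

    logits = torch.randn(2, 10, 30)  # dummy batch: N=2, seq length L=10, vocab=30
    word_values = ["hello", "hi"]    # dummy decoded words (already cut at <eos>)

    # old: probability of each argmax character, then the minimum over the
    # whole padded sequence, so positions after <eos> can drag the score down
    out_idxs = logits.argmax(-1)
    old_probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
    old_conf = old_probs.min(dim=1).values  # shape (N,)

    # new: max softmax probability per position, averaged only over the
    # positions that belong to the decoded word (up to the EOS token)
    preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]  # shape (N, L)
    new_conf = [preds_prob[i, : len(w)].mean().item() for i, w in enumerate(word_values)]

The TensorFlow hunks make the same change with `tf.math.reduce_max` over `tf.nn.softmax` instead of the Torch calls.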