diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 25efc23252..294e0236b4 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -403,7 +403,9 @@ def __call__( for encoded_seq in out_idxs.cpu().numpy() ] # compute probabilties for each word up to the EOS token - probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)] + probs = [ + preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + ] return list(zip(word_values, probs)) diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 21a35605f5..28dc625cc9 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -432,7 +432,10 @@ def __call__( word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] # compute probabilties for each word up to the EOS token - probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)] + probs = [ + preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0 + for i, word in enumerate(word_values) + ] return list(zip(word_values, probs)) diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index 0ff05f826a..6f68351d8d 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -167,7 +167,9 @@ def __call__( for encoded_seq in out_idxs.cpu().numpy() ] # compute probabilties for each word up to the EOS token - probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)] + probs = [ + preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + ] return list(zip(word_values, probs)) diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 70c7325b3f..117ba6a7aa 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -175,7 +175,10 @@ def __call__( word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] # compute probabilties for each word up to the EOS token - probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)] + probs = [ + preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0 + for i, word in enumerate(word_values) + ] return list(zip(word_values, probs))