diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index daf0e56e58..51563a2ade 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -128,12 +128,9 @@ def __init__( self.feat_extractor = feature_extractor # Resolve the input_size of the LSTM - self.feat_extractor.eval() - with torch.no_grad(): + with torch.inference_mode(): out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape lstm_in = out_shape[1] * out_shape[2] - # Switch back to original mode - self.feat_extractor.train() self.decoder = nn.LSTM( input_size=lstm_in,