From 56395ba3e9e66a0fa1137f397628d4aea1c1c477 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Tue, 26 Sep 2023 10:55:52 +0200 Subject: [PATCH] replace no_grad with inference_mode (#1323) --- doctr/models/classification/predictor/pytorch.py | 2 +- doctr/models/detection/predictor/pytorch.py | 2 +- doctr/models/kie_predictor/pytorch.py | 2 +- doctr/models/predictor/pytorch.py | 2 +- doctr/models/recognition/predictor/pytorch.py | 2 +- references/classification/latency_pytorch.py | 2 +- references/detection/evaluate_pytorch.py | 2 +- references/detection/latency_pytorch.py | 2 +- references/obj_detection/latency_pytorch.py | 2 +- references/recognition/evaluate_pytorch.py | 2 +- references/recognition/latency_pytorch.py | 2 +- scripts/detect_artefacts.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index f94c861943..0e38754fe0 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -33,7 +33,7 @@ def __init__( self.pre_processor = pre_processor self.model = model.eval() - @torch.no_grad() + @torch.inference_mode() def forward( self, crops: List[Union[np.ndarray, torch.Tensor]], diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py index 93844713c7..78fc1c0c79 100644 --- a/doctr/models/detection/predictor/pytorch.py +++ b/doctr/models/detection/predictor/pytorch.py @@ -32,7 +32,7 @@ def __init__( self.pre_processor = pre_processor self.model = model.eval() - @torch.no_grad() + @torch.inference_mode() def forward( self, pages: List[Union[np.ndarray, torch.Tensor]], diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 3a4ca5740b..bd953ec34f 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -59,7 +59,7 @@ def __init__( self.detect_orientation = detect_orientation self.detect_language = detect_language - @torch.no_grad() + @torch.inference_mode() def forward( self, pages: List[Union[np.ndarray, torch.Tensor]], diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index aa8a878a93..1ca0ac5845 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -59,7 +59,7 @@ def __init__( self.detect_orientation = detect_orientation self.detect_language = detect_language - @torch.no_grad() + @torch.inference_mode() def forward( self, pages: List[Union[np.ndarray, torch.Tensor]], diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py index 9fb7ac04c3..e267f92f8c 100644 --- a/doctr/models/recognition/predictor/pytorch.py +++ b/doctr/models/recognition/predictor/pytorch.py @@ -40,7 +40,7 @@ def __init__( self.dil_factor = 1.4 # Dilation factor to overlap the crops self.target_ar = 6 # Target aspect ratio - @torch.no_grad() + @torch.inference_mode() def forward( self, crops: Sequence[Union[np.ndarray, torch.Tensor]], diff --git a/references/classification/latency_pytorch.py b/references/classification/latency_pytorch.py index 954236f45f..06d9e0f9a2 100644 --- a/references/classification/latency_pytorch.py +++ b/references/classification/latency_pytorch.py @@ -19,7 +19,7 @@ from doctr.models import classification -@torch.no_grad() +@torch.inference_mode() def main(args): device = torch.device("cuda:0" if args.gpu else "cpu") diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py index e72a767d93..ae15989ebb 100644 --- a/references/detection/evaluate_pytorch.py +++ b/references/detection/evaluate_pytorch.py @@ -26,7 +26,7 @@ from doctr.utils.metrics import LocalizationConfusion -@torch.no_grad() +@torch.inference_mode() def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): # Model in eval mode model.eval() diff --git a/references/detection/latency_pytorch.py b/references/detection/latency_pytorch.py index 47d33494db..ea017056e1 100644 --- a/references/detection/latency_pytorch.py +++ b/references/detection/latency_pytorch.py @@ -19,7 +19,7 @@ from doctr.models import detection -@torch.no_grad() +@torch.inference_mode() def main(args): device = torch.device("cuda:0" if args.gpu else "cpu") diff --git a/references/obj_detection/latency_pytorch.py b/references/obj_detection/latency_pytorch.py index 6dd57a5e85..a0a609d0e5 100644 --- a/references/obj_detection/latency_pytorch.py +++ b/references/obj_detection/latency_pytorch.py @@ -19,7 +19,7 @@ from doctr.models import obj_detection -@torch.no_grad() +@torch.inference_mode() def main(args): device = torch.device("cuda:0" if args.gpu else "cpu") diff --git a/references/recognition/evaluate_pytorch.py b/references/recognition/evaluate_pytorch.py index ccd9e8b2e2..92a91b1bfe 100644 --- a/references/recognition/evaluate_pytorch.py +++ b/references/recognition/evaluate_pytorch.py @@ -22,7 +22,7 @@ from doctr.utils.metrics import TextMatch -@torch.no_grad() +@torch.inference_mode() def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): # Model in eval mode model.eval() diff --git a/references/recognition/latency_pytorch.py b/references/recognition/latency_pytorch.py index 436bd82333..cd3f0afa3a 100644 --- a/references/recognition/latency_pytorch.py +++ b/references/recognition/latency_pytorch.py @@ -19,7 +19,7 @@ from doctr.models import recognition -@torch.no_grad() +@torch.inference_mode() def main(args): device = torch.device("cuda:0" if args.gpu else "cpu") diff --git a/scripts/detect_artefacts.py b/scripts/detect_artefacts.py index e8f4c518c8..abfa7211c8 100644 --- a/scripts/detect_artefacts.py +++ b/scripts/detect_artefacts.py @@ -37,7 +37,7 @@ def plot_predictions(image, boxes, labels): plt.show() -@torch.no_grad() +@torch.inference_mode() def main(args): print(args)