Skip to content

Commit

Permalink
replace no_grad with inference_mode (#1323)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 26, 2023
1 parent 9fa5e78 commit 56395ba
Show file tree
Hide file tree
Showing 12 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion doctr/models/classification/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
self.pre_processor = pre_processor
self.model = model.eval()

@torch.no_grad()
@torch.inference_mode()
def forward(
self,
crops: List[Union[np.ndarray, torch.Tensor]],
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self.pre_processor = pre_processor
self.model = model.eval()

@torch.no_grad()
@torch.inference_mode()
def forward(
self,
pages: List[Union[np.ndarray, torch.Tensor]],
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.detect_orientation = detect_orientation
self.detect_language = detect_language

@torch.no_grad()
@torch.inference_mode()
def forward(
self,
pages: List[Union[np.ndarray, torch.Tensor]],
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.detect_orientation = detect_orientation
self.detect_language = detect_language

@torch.no_grad()
@torch.inference_mode()
def forward(
self,
pages: List[Union[np.ndarray, torch.Tensor]],
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self.dil_factor = 1.4 # Dilation factor to overlap the crops
self.target_ar = 6 # Target aspect ratio

@torch.no_grad()
@torch.inference_mode()
def forward(
self,
crops: Sequence[Union[np.ndarray, torch.Tensor]],
Expand Down
2 changes: 1 addition & 1 deletion references/classification/latency_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from doctr.models import classification


@torch.no_grad()
@torch.inference_mode()
def main(args):
device = torch.device("cuda:0" if args.gpu else "cpu")

Expand Down
2 changes: 1 addition & 1 deletion references/detection/evaluate_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from doctr.utils.metrics import LocalizationConfusion


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion references/detection/latency_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from doctr.models import detection


@torch.no_grad()
@torch.inference_mode()
def main(args):
device = torch.device("cuda:0" if args.gpu else "cpu")

Expand Down
2 changes: 1 addition & 1 deletion references/obj_detection/latency_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from doctr.models import obj_detection


@torch.no_grad()
@torch.inference_mode()
def main(args):
device = torch.device("cuda:0" if args.gpu else "cpu")

Expand Down
2 changes: 1 addition & 1 deletion references/recognition/evaluate_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from doctr.utils.metrics import TextMatch


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion references/recognition/latency_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from doctr.models import recognition


@torch.no_grad()
@torch.inference_mode()
def main(args):
device = torch.device("cuda:0" if args.gpu else "cpu")

Expand Down
2 changes: 1 addition & 1 deletion scripts/detect_artefacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def plot_predictions(image, boxes, labels):
plt.show()


@torch.no_grad()
@torch.inference_mode()
def main(args):
print(args)

Expand Down

0 comments on commit 56395ba

Please sign in to comment.