Skip to content

Commit

Permalink
update type check + mypy torch decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 21, 2024
1 parent 93b4a54 commit 096a3a9
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 17 deletions.
7 changes: 4 additions & 3 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ def _orientation_predictor(
# Load directly classifier from backbone
_model = classification.__dict__[arch](pretrained=pretrained)
else:
allowed_archs = (classification.MobileNetV3,)
allowed_archs = [classification.MobileNetV3]
if is_torch_available():
# The following is required for torch compiled models
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)
allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)

if not isinstance(arch, allowed_archs):
if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
Expand Down
8 changes: 5 additions & 3 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
if isinstance(_model, detection.FAST):
_model = reparameterize(_model)
else:
allowed_archs = (detection.DBNet, detection.LinkNet, detection.FAST)
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
if is_torch_available():
# The following is required for torch compiled models
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)
if not isinstance(arch, allowed_archs):
allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")

_model = arch
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/crnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def forward(

if return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(logits)

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(decoded_features)

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def forward(

if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
return self.postprocessor(decoded_features)

Expand Down
8 changes: 5 additions & 3 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
)
else:
allowed_archs = (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
if is_torch_available():
# The following is required for torch compiled models
import torch

allowed_archs += (torch._dynamo.eval_frame.OptimizedModule,)
if not isinstance(arch, allowed_archs):
allowed_archs.append(torch._dynamo.eval_frame.OptimizedModule)

if not isinstance(arch, tuple(allowed_archs)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

Expand Down

0 comments on commit 096a3a9

Please sign in to comment.