Skip to content

Commit

Permalink
update typings
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Dec 4, 2024
1 parent 75e233c commit 768ec80
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(prob_map: torch.Tensor) -> List[Dict[str, Any]]:
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
return [
dict(zip(self.class_names, preds))
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/crnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/master/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def forward(
if return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(logits: torch.Tensor) -> List[Tuple[str, float]]:
def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
return self.postprocessor(logits)

# Post-process boxes
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
return self.postprocessor(decoded_features)

# Post-process boxes
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def forward(
if target is None or return_preds:
# Disable for torch.compile compatibility
@torch.compiler.disable # type: ignore[attr-defined]
def _postprocess(decoded_features: torch.Tensor) -> List[Tuple[str, float]]:
def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
return self.postprocessor(decoded_features)

# Post-process boxes
Expand Down

0 comments on commit 768ec80

Please sign in to comment.