Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[models] Change Resize kwargs to args for each zoo predictor #1765

Merged
merged 2 commits on Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Install all additional dependencies with the following command:

```shell
python -m pip install --upgrade pip
pip install -e .[dev]
pip install -e '.[dev]'
pre-commit install
```

Expand Down
10 changes: 6 additions & 4 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _orientation_predictor(


def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
) -> OrientationPredictor:
"""Crop orientation classification architecture.

Expand All @@ -77,17 +77,18 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
batch_size: number of samples the model processes in parallel
**kwargs: keyword arguments to be passed to the OrientationPredictor

Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
) -> OrientationPredictor:
"""Page orientation classification architecture.

Expand All @@ -101,10 +102,11 @@ def page_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
batch_size: number of samples the model processes in parallel
**kwargs: keyword arguments to be passed to the OrientationPredictor

Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
17 changes: 16 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def detection_predictor(
arch: Any = "fast_base",
pretrained: bool = False,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
batch_size: int = 2,
**kwargs: Any,
) -> DetectionPredictor:
"""Text detection architecture.
Expand All @@ -94,10 +97,22 @@ def detection_predictor(
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
pretrained: If True, returns a model pre-trained on our text detection dataset
assume_straight_pages: If True, fit straight boxes to the page
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
running the detection model on it
symmetric_pad: If True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
**kwargs: optional keyword arguments passed to the architecture

Returns:
-------
Detection predictor
"""
return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
return _predictor(
arch=arch,
pretrained=pretrained,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
**kwargs,
)
1 change: 1 addition & 0 deletions doctr/models/preprocessor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PreProcessor(nn.Module):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions doctr/models/preprocessor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PreProcessor(NestedObject):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

_children_names: List[str] = ["resize", "normalize"]
Expand Down
12 changes: 10 additions & 2 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
return predictor


def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
def recognition_predictor(
arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
symmetric_pad: bool = False,
batch_size: int = 128,
**kwargs: Any,
) -> RecognitionPredictor:
"""Text recognition architecture.

Example::
Expand All @@ -66,10 +72,12 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
----
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
pretrained: If True, returns a model pre-trained on our text recognition dataset
symmetric_pad: If True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
**kwargs: optional parameters to be passed to the architecture

Returns:
-------
Recognition predictor
"""
return _predictor(arch, pretrained, **kwargs)
return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
Loading