[TF] Move model building & unify train scripts (#1744)
felixdittrich92 authored Oct 9, 2024
1 parent 59f1c30 commit 2f9f50e
Showing 30 changed files with 93 additions and 78 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/references.yml
@@ -114,16 +114,16 @@ jobs:
           unzip toy_recogition_set-036a4d80.zip -d reco_set
       - if: matrix.framework == 'tensorflow'
         name: Train for a short epoch (TF) (document orientation)
-        run: python references/classification/train_tensorflow_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1
+        run: python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
       - if: matrix.framework == 'pytorch'
         name: Train for a short epoch (PT) (document orientation)
-        run: python references/classification/train_pytorch_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1
+        run: python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
       - if: matrix.framework == 'tensorflow'
         name: Train for a short epoch (TF) (crop orientation)
-        run: python references/classification/train_tensorflow_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1
+        run: python references/classification/train_tensorflow_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1
       - if: matrix.framework == 'pytorch'
         name: Train for a short epoch (PT) (crop orientation)
-        run: python references/classification/train_pytorch_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1
+        run: python references/classification/train_pytorch_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1

   train-text-recognition:
     runs-on: ${{ matrix.os }}

@@ -318,10 +318,10 @@ jobs:
           unzip toy_detection_set-bbbb4243.zip -d det_set
       - if: matrix.framework == 'tensorflow'
         name: Train for a short epoch (TF)
-        run: python references/detection/train_tensorflow.py --train_path ./det_set --val_path ./det_set linknet_resnet18 -b 2 --epochs 1
+        run: python references/detection/train_tensorflow.py linknet_resnet18 --train_path ./det_set --val_path ./det_set -b 2 --epochs 1
       - if: matrix.framework == 'pytorch'
         name: Train for a short epoch (PT)
-        run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1
+        run: python references/detection/train_pytorch.py db_mobilenet_v3_large --train_path ./det_set --val_path ./det_set -b 2 --epochs 1

   evaluate-text-detection:
     runs-on: ${{ matrix.os }}
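The unified CLI makes the model architecture the sole positional argument and moves dataset locations behind explicit --train_path/--val_path flags, so the TF and PT scripts share one calling convention. A minimal sketch of the argument parsing this implies — flag names are taken from the commands above, while the help text and defaults are assumptions, not the scripts' actual code:

import argparse

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="docTR orientation training (sketch)")
    # The architecture is now the only positional argument
    parser.add_argument("arch", type=str, help="classification architecture, e.g. resnet18")
    parser.add_argument("--type", choices=["page", "crop"], required=True, help="orientation task")
    parser.add_argument("--train_path", type=str, required=True, help="path to the training set")
    parser.add_argument("--val_path", type=str, required=True, help="path to the validation set")
    parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs")
    return parser.parse_args()

# e.g.: python train_tensorflow_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1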
6 changes: 4 additions & 2 deletions doctr/models/classification/magc_resnet/tensorflow.py
@@ -14,7 +14,7 @@

 from doctr.datasets import VOCABS

-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params
 from ..resnet.tensorflow import ResNet

 __all__ = ["magc_resnet31"]

@@ -115,7 +115,7 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
         # Context modeling: B, H, W, C -> B, 1, 1, C
         context = self.context_modeling(inputs)
         # Transform: B, 1, 1, C -> B, 1, 1, C
-        transformed = self.transform(context)
+        transformed = self.transform(context, **kwargs)
         return inputs + transformed


@@ -152,6 +152,8 @@ def _magc_resnet(
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
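The recurring `_build_model(model)` call added across these factories replaces the per-call-site dummy forward passes removed later in this commit (from fast/tensorflow.py and factory/hub.py). A minimal sketch of what the shared helper in doctr/models/utils/tensorflow.py plausibly does — the exact body is not part of this diff:

import tensorflow as tf
from tensorflow.keras import Model

def _build_model(model: Model) -> None:
    # `model.build()` is not an option because it doesn't run in eager mode
    # (the rationale in the NOTE comments removed from hub.py), so the model
    # is called once on a dummy batch to materialize all weights.
    _ = model(tf.zeros((1, *model.cfg["input_shape"])), training=False)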
4 changes: 3 additions & 1 deletion doctr/models/classification/mobilenet/tensorflow.py
@@ -13,7 +13,7 @@
 from tensorflow.keras.models import Sequential

 from ....datasets import VOCABS
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params

 __all__ = [
     "MobileNetV3",

@@ -295,6 +295,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
5 changes: 4 additions & 1 deletion doctr/models/classification/resnet/tensorflow.py
@@ -13,7 +13,7 @@

 from doctr.datasets import VOCABS

-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params

 __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]

@@ -210,6 +210,8 @@ def _resnet(
     model = ResNet(
         num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>

@@ -358,6 +360,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
     )

     model.cfg = _cfg
+    _build_model(model)

     # Load pretrained parameters
     if pretrained:
4 changes: 3 additions & 1 deletion doctr/models/classification/textnet/tensorflow.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS

 from ...modules.layers.tensorflow import FASTConvLayer
-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params

 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

@@ -111,6 +111,8 @@ def _textnet(

     # Build the model
     model = TextNet(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
4 changes: 3 additions & 1 deletion doctr/models/classification/vgg/tensorflow.py
@@ -11,7 +11,7 @@

 from doctr.datasets import VOCABS

-from ...utils import conv_sequence, load_pretrained_params
+from ...utils import _build_model, conv_sequence, load_pretrained_params

 __all__ = ["VGG", "vgg16_bn_r"]

@@ -81,6 +81,8 @@ def _vgg(

     # Build the model
     model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
4 changes: 3 additions & 1 deletion doctr/models/classification/vit/tensorflow.py
@@ -14,7 +14,7 @@
 from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
 from doctr.utils.repr import NestedObject

-from ...utils import load_pretrained_params
+from ...utils import _build_model, load_pretrained_params

 __all__ = ["vit_s", "vit_b"]

@@ -121,6 +121,8 @@ def _vit(

     # Build the model
     model = VisionTransformer(cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
11 changes: 10 additions & 1 deletion doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -14,7 +14,13 @@
 from tensorflow.keras.applications import ResNet50

 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject

 from ...classification import mobilenet_v3_large

@@ -304,6 +310,8 @@ def _db_resnet(

     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning

@@ -347,6 +355,7 @@ def _db_mobilenet(

     # Build the model
     model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
7 changes: 3 additions & 4 deletions doctr/models/detection/fast/tensorflow.py
@@ -13,7 +13,7 @@
 from tensorflow.keras import Model, Sequential, layers

 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
 from doctr.utils.repr import NestedObject

 from ...classification import textnet_base, textnet_small, textnet_tiny

@@ -333,6 +333,8 @@ def _fast(

     # Build the model
     model = FAST(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
@@ -342,9 +344,6 @@ def _fast(
             skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
         )

-    # Build the model for reparameterization to access the layers
-    _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
-
     return model
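The manual forward pass removed above existed so that reparameterization could access built layers; with _build_model() now running inside _fast(), layers are materialized at construction time. A hedged usage sketch, assuming docTR's reparameterize helper (which fuses FAST's conv/BN branches for inference):

from doctr.models import fast_tiny
from doctr.models.utils import reparameterize

model = fast_tiny(pretrained=False)  # _build_model() runs inside _fast()
model = reparameterize(model)        # layers are already built, so fusion can access them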
14 changes: 11 additions & 3 deletions doctr/models/detection/linknet/tensorflow.py
@@ -14,7 +14,13 @@

 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification import resnet18, resnet34, resnet50
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject

 from .base import LinkNetPostProcessor, _LinkNet

@@ -79,10 +85,10 @@ def __init__(
             for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
         ]

-    def call(self, x: List[tf.Tensor]) -> tf.Tensor:
+    def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
         out = 0
         for decoder, fmap in zip(self.decoders, x[::-1]):
-            out = decoder(out + fmap)
+            out = decoder(out + fmap, **kwargs)
         return out

     def extra_repr(self) -> str:

@@ -274,6 +280,8 @@ def _linknet(

     # Build the model
     model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
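Threading **kwargs through the decoder's call() lets flags such as training — which the dummy build pass sets explicitly — reach nested sub-layers. A standalone sketch (not docTR code) of the pattern; Keras can propagate the training flag implicitly in some code paths, but explicit forwarding guarantees nested layers receive it:

import tensorflow as tf
from tensorflow.keras import layers

class Block(layers.Layer):
    def __init__(self) -> None:
        super().__init__()
        self.bn = layers.BatchNormalization()

    def call(self, x: tf.Tensor, **kwargs) -> tf.Tensor:
        # Forward kwargs so `training=...` reaches the nested BatchNormalization
        return self.bn(x, **kwargs)

out = Block()(tf.zeros((4, 8)), training=False)  # inference-mode batch norm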
6 changes: 0 additions & 6 deletions doctr/models/factory/hub.py
@@ -27,8 +27,6 @@

 if is_torch_available():
     import torch
-elif is_tf_available():
-    import tensorflow as tf

 __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

@@ -76,8 +74,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
         torch.save(model.state_dict(), weights_path)
     elif is_tf_available():
         weights_path = save_directory / "tf_model.weights.h5"
-        # NOTE: `model.build` is not an option because it doesn't runs in eager mode
-        _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
         model.save_weights(str(weights_path))

     config_path = save_directory / "config.json"

@@ -229,8 +225,6 @@ def from_hub(repo_id: str, **kwargs: Any):
         model.load_state_dict(state_dict)
     else:  # tf
         weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-        # NOTE: `model.build` is not an option because it doesn't runs in eager mode
-        _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False)
         model.load_weights(weights)

     return model
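With building centralized in the model factories, the hub helpers no longer need a TensorFlow import or a dummy forward pass before saving or loading weights. A hedged usage sketch — the repo id below is a placeholder, not a real repository:

from doctr.models.factory import from_hub

model = from_hub("username/doctr-my-model")  # weights load directly; no manual build step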
2 changes: 1 addition & 1 deletion doctr/models/preprocessor/tensorflow.py
@@ -41,7 +41,7 @@ def __init__(
         self.resize = Resize(output_size, **kwargs)
         # Perform the division by 255 at the same time
         self.normalize = Normalize(mean, std)
-        self._runs_on_cuda = tf.test.is_gpu_available()
+        self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []

     def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
         """Gather samples into batches for inference purposes
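tf.test.is_gpu_available() has been deprecated since TF 2.1 and initializes a GPU context as a side effect; listing physical devices is the supported, side-effect-free replacement. A quick comparison:

import tensorflow as tf

# Deprecated, emits a warning and touches the GPU:
#   has_gpu = tf.test.is_gpu_available()

# Replacement used in this commit:
has_gpu = tf.config.list_physical_devices("GPU") != []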
3 changes: 2 additions & 1 deletion doctr/models/recognition/crnn/tensorflow.py
@@ -13,7 +13,7 @@
 from doctr.datasets import VOCABS

 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]

@@ -245,6 +245,7 @@ def _crnn(

     # Build the model
     model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
4 changes: 3 additions & 1 deletion doctr/models/recognition/master/tensorflow.py
@@ -13,7 +13,7 @@
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding

-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor

 __all__ = ["MASTER", "master"]

@@ -290,6 +290,8 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
4 changes: 3 additions & 1 deletion doctr/models/recognition/parseq/tensorflow.py
@@ -16,7 +16,7 @@
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

 from ...classification import vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor

 __all__ = ["PARSeq", "parseq"]

@@ -473,6 +473,8 @@ def _parseq(

     # Build the model
     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
3 changes: 2 additions & 1 deletion doctr/models/recognition/sar/tensorflow.py
@@ -13,7 +13,7 @@
 from doctr.utils.repr import NestedObject

 from ...classification import resnet31
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["SAR", "sar_resnet31"]

@@ -392,6 +392,7 @@ def _sar(

     # Build the model
     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
4 changes: 3 additions & 1 deletion doctr/models/recognition/vitstr/tensorflow.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS

 from ...classification import vit_b, vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _ViTSTR, _ViTSTRPostProcessor

 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

@@ -216,6 +216,8 @@ def _vitstr(

     # Build the model
     model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning