From 0b1e3e342db5ac90d392401cd2f056b1dd6a06ac Mon Sep 17 00:00:00 2001 From: "jun.sun" Date: Wed, 27 May 2020 20:31:54 +0800 Subject: [PATCH] support transformer --- README.md | 3 + configs/resnet_fc.py | 8 +- configs/satrn.py | 307 ++++++++++++++++++ configs/small_satrn.py | 302 +++++++++++++++++ configs/srn.py | 298 +++++++++++++++++ configs/tps_resnet_bilstm_attn.py | 9 +- requirements.txt | 4 +- vedastr/assembler/assembler.py | 2 +- vedastr/converter/__init__.py | 3 +- vedastr/converter/base_convert.py | 14 +- vedastr/converter/satrn_converter.py | 50 +++ vedastr/criteria/cross_entropy_loss.py | 1 - vedastr/datasets/transforms/__init__.py | 2 +- vedastr/datasets/transforms/transforms.py | 278 +++++++++++++++- vedastr/lr_schedulers/__init__.py | 1 + vedastr/lr_schedulers/constant_lr.py | 18 + vedastr/lr_schedulers/exponential_lr.py | 4 +- vedastr/lr_schedulers/step_lr.py | 18 +- vedastr/models/bodies/body.py | 12 +- .../decoders/bricks/__init__.py | 1 + .../decoders/bricks/bricks.py | 1 - .../feature_extractors/decoders/bricks/pva.py | 38 +++ .../encoders/backbones/resnet.py | 23 +- .../encoders/enhance_modules/aspp.py | 4 +- vedastr/models/bodies/sequences/__init__.py | 5 +- vedastr/models/bodies/sequences/builder.py | 14 + .../bodies/sequences/decoders/__init__.py | 2 - .../bodies/sequences/decoders/builder.py | 8 - .../bodies/sequences/encoders/__init__.py | 2 - .../bodies/sequences/encoders/builder.py | 8 - .../bodies/sequences/encoders/registry.py | 3 - .../sequences/{decoders => }/registry.py | 1 + .../models/bodies/sequences/rnn/__init__.py | 2 + .../{decoders/rnn_cell.py => rnn/decoder.py} | 5 +- .../{encoders/rnn.py => rnn/encoder.py} | 2 +- .../bodies/sequences/transformer/__init__.py | 2 + .../bodies/sequences/transformer/decoder.py | 38 +++ .../bodies/sequences/transformer/encoder.py | 38 +++ .../transformer/position_encoder/__init__.py | 3 + .../position_encoder/adaptive_2d_encoder.py | 48 +++ .../transformer/position_encoder/builder.py | 8 + .../transformer/position_encoder/encoder.py | 21 ++ .../transformer/position_encoder/registry.py | 3 + .../transformer/position_encoder/utils.py | 14 + .../sequences/transformer/unit/__init__.py | 3 + .../transformer/unit/attention/__init__.py | 2 + .../transformer/unit/attention/builder.py | 8 + .../unit/attention/multihead_attention.py | 62 ++++ .../transformer/unit/attention/registry.py | 3 + .../sequences/transformer/unit/builder.py | 14 + .../sequences/transformer/unit/decoder.py | 46 +++ .../sequences/transformer/unit/encoder.py | 68 ++++ .../transformer/unit/feedforward/__init__.py | 2 + .../transformer/unit/feedforward/builder.py | 8 + .../unit/feedforward/feedforward.py | 18 + .../transformer/unit/feedforward/registry.py | 3 + .../sequences/transformer/unit/registry.py | 4 + vedastr/models/heads/__init__.py | 2 + vedastr/models/heads/fc_head.py | 7 +- vedastr/models/heads/head.py | 35 ++ vedastr/models/heads/transformer_head.py | 83 +++++ vedastr/models/utils/__init__.py | 1 + vedastr/models/utils/conv_module.py | 2 +- vedastr/models/weight_init.py | 2 +- vedastr/runner/runner.py | 98 +++--- vedastr/utils/metrics.py | 31 +- 66 files changed, 1959 insertions(+), 171 deletions(-) create mode 100644 configs/satrn.py create mode 100644 configs/small_satrn.py create mode 100644 configs/srn.py create mode 100644 vedastr/converter/satrn_converter.py create mode 100644 vedastr/lr_schedulers/constant_lr.py create mode 100644 vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py create mode 100644 
vedastr/models/bodies/sequences/builder.py delete mode 100644 vedastr/models/bodies/sequences/decoders/__init__.py delete mode 100644 vedastr/models/bodies/sequences/decoders/builder.py delete mode 100644 vedastr/models/bodies/sequences/encoders/__init__.py delete mode 100644 vedastr/models/bodies/sequences/encoders/builder.py delete mode 100644 vedastr/models/bodies/sequences/encoders/registry.py rename vedastr/models/bodies/sequences/{decoders => }/registry.py (63%) create mode 100644 vedastr/models/bodies/sequences/rnn/__init__.py rename vedastr/models/bodies/sequences/{decoders/rnn_cell.py => rnn/decoder.py} (98%) rename vedastr/models/bodies/sequences/{encoders/rnn.py => rnn/encoder.py} (97%) create mode 100644 vedastr/models/bodies/sequences/transformer/__init__.py create mode 100644 vedastr/models/bodies/sequences/transformer/decoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/encoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/builder.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/registry.py create mode 100644 vedastr/models/bodies/sequences/transformer/position_encoder/utils.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/__init__.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/attention/builder.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/attention/registry.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/builder.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/decoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/encoder.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py create mode 100644 vedastr/models/bodies/sequences/transformer/unit/registry.py create mode 100644 vedastr/models/heads/head.py create mode 100644 vedastr/models/heads/transformer_head.py diff --git a/README.md b/README.md index 2b6d19e..dda625b 100644 --- a/README.md +++ b/README.md @@ -41,9 +41,12 @@ Note: |:----:|:----:| :----: | :----: |:----: |:----: |:----: |:----: |:----: | :----:| |[TPS-ResNet-BiLSTM-Attention](https://drive.google.com/open?id=1b5ykMGwLFyt-tpoWBMyhgjABaqxKBxRU)| False|87.33 | 87.79 | 95.04| 92.61|74.45|81.09|74.91|84.95| |[ResNet-FC](https://drive.google.com/open?id=105kvjvSAwyxv_6VsCI0kWEmKkqQX8jul)| False|85.03 | 86.4 | 94| 91.03|70.29|77.67|71.43|82.38| +|[Small-SATRN]()| False|88.87 | 88.87 | 96.19 | 93.99|79.08|84.81|84.67|87.55| AVERAGE : Average accuracy over all test datasets\ TPS : [Spatial transformer network](https://arxiv.org/abs/1603.03915)\ +Small-SATRN: [On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention](https://arxiv.org/abs/1910.04396), 
+the training phase is case-sensitive while the testing phase is case-insensitive. \ CASE SENSITIVE : If true, the output is case-sensitive and contains common characters. If false, the output is not case-sensitive and contains only numbers and letters. diff --git a/configs/resnet_fc.py b/configs/resnet_fc.py index 08057c1..0a02cc7 100644 --- a/configs/resnet_fc.py +++ b/configs/resnet_fc.py @@ -86,20 +86,20 @@ transforms=transforms, datasets=valid_dataset, loader=dict( - type='TestDataloader', + type='RawDataloader', batch_size=batch_size, num_workers=4, - shuffle=False, + shuffle=True, ), ), test=dict( transforms=transforms, datasets=test_dataset, loader=dict( - type='TestDataloader', + type='RawDataloader', batch_size=batch_size, num_workers=4, - shuffle=False, + shuffle=True, ), ), ) diff --git a/configs/satrn.py b/configs/satrn.py new file mode 100644 index 0000000..1a827fe --- /dev/null +++ b/configs/satrn.py @@ -0,0 +1,307 @@ +# work dir +root_workdir = 'workdir/' + +# seed +seed = 1111 + +# 1. logging +logger = dict( + handlers=( + dict(type='StreamHandler', level='INFO'), + dict(type='FileHandler', level='INFO'), + ), +) + +# 2. data +batch_size = 4 +mean, std = 0.5, 0.5 # normalize mean and std +size = (32, 100) +batch_max_length = 25 +fill = 0 +mode = 'bilinear' +data_filter_off = False +train_sensitive = True +train_character = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' # need character +test_sensitive = False +test_character = '0123456789abcdefghijklmnopqrstuvwxyz' + +# dataset params +train_dataset_params = dict( + batch_max_length=batch_max_length, + data_filter_off=data_filter_off, + character=train_character, +) +test_dataset_params = dict( + batch_max_length=batch_max_length, + data_filter_off=data_filter_off, + character=test_character, +) + +data_root = './data/data_lmdb_release/' + +# train data +train_root = data_root + 'training/' +## MJ dataset +train_root_mj = train_root + 'MJ/' +mj_folder_names = ['/MJ_test', 'MJ_valid', 'MJ_train'] +## ST dataset +train_root_st = train_root + 'ST/' + +train_dataset_mj = [dict(type='LmdbDataset', root=train_root_mj + folder_name) for folder_name in mj_folder_names] +train_dataset_st = [dict(type='LmdbDataset', root=train_root_st)] + +# valid +valid_root = data_root + 'validation/' +valid_dataset = [dict(type='LmdbDataset', root=valid_root, **test_dataset_params)] + +# test +test_root = data_root + 'evaluation/' +test_folder_names = ['CUTE80', 'IC03_867', 'IC13_1015', 'IC15_2077', 'IIIT5k_3000', 'SVT', 'SVTP'] +test_dataset = [dict(type='LmdbDataset', root=test_root + folder_name, **test_dataset_params) for folder_name in + test_folder_names] + +# transforms +train_transforms = [ + dict(type='Sensitive', sensitive=train_sensitive), + dict(type='ColorToGray'), + dict(type='RandomNormalRotation', mean=0, std=34, expand=True, center=None, fill=fill, mode=mode, p=0.5), + dict(type='Resize', size=size), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, std=std), +] +test_transforms = [ + dict(type='Sensitive', sensitive=test_sensitive), + dict(type='ColorToGray'), + dict(type='Resize', size=size), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, std=std), +] + +data = dict( + train=dict( + transforms=train_transforms, + datasets=[ + dict( + type='ConcatDatasets', + datasets=train_dataset_mj, + **train_dataset_params, + ), + dict( + type='ConcatDatasets', + datasets=train_dataset_st, + **train_dataset_params, + ), + ], + loader=dict( + type='BatchBalanceDataloader', +
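+            # Half of every batch is drawn from the MJ ConcatDatasets and half
+            # from ST (each_batch_ratio below), so neither synthetic corpus
+            # dominates a gradient step; each_usage appears to cap the fraction
+            # of each source consumed per pass.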
batch_size=batch_size, + each_batch_ratio=[0.5, 0.5], + each_usage=[1.0, 1.0], + shuffle=True, + ), + ), + val=dict( + transforms=test_transforms, + datasets=valid_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), + test=dict( + transforms=test_transforms, + datasets=test_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), +) + +test_cfg = dict( + sensitive=test_sensitive, + character=test_character, +) + +# 3. converter +converter = dict( + type='SATRNConverter', + character=train_character, + batch_max_length=batch_max_length, + go_last=True, +) + +# 4. model +dropout = 0.1 +n_e = 12 +n_d = 6 +hidden_dim = 512 +n_head = 8 +batch_norm = dict(type='BN') +layer_norm = dict(type='LayerNorm', normalized_shape=hidden_dim) +num_class = len(train_character) + 1 +num_steps = batch_max_length + 1 +model = dict( + type='GModel', + need_text=True, + body=dict( + type='GBody', + pipelines=[ + dict( + type='FeatureExtractorComponent', + from_layer='input', + to_layer='cnn_feat', + arch=dict( + encoder=dict( + backbone=dict( + type='GResNet', + layers=[ + ('conv', dict(type='ConvModule', in_channels=1, out_channels=64, kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('conv', dict(type='ConvModule', in_channels=64, out_channels=128, kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('pool', dict(type='MaxPool2d', kernel_size=2, stride=2, padding=0)), + ('conv', dict(type='ConvModule', in_channels=128, out_channels=256, kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('conv', dict(type='ConvModule', in_channels=256, out_channels=512, kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('pool', dict(type='MaxPool2d', kernel_size=2, stride=2, padding=0)), + ], + ), + ), + collect=dict(type='CollectBlock', from_layer='c2'), + ), + ), + dict( + type='SequenceEncoderComponent', + from_layer='cnn_feat', + to_layer='src', + arch=dict( + type='TransformerEncoder', + position_encoder=dict( + type='Adaptive2DPositionEncoder', + in_channels=hidden_dim, + max_h=100, + max_w=100, + dropout=dropout, + ), + encoder_layer=dict( + type='TransformerEncoderLayer2D', + attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + attention_norm=layer_norm, + feedforward=dict( + type='Feedforward', + layers=[ + dict(type='ConvModule', in_channels=hidden_dim, out_channels=hidden_dim*4, kernel_size=3, padding=1, + activation='relu', dropout=dropout), + dict(type='ConvModule', in_channels=hidden_dim*4, out_channels=hidden_dim, kernel_size=3, padding=1, + activation=None, dropout=dropout), + ] + ), + feedforward_norm=layer_norm, + ), + num_layers=n_e, + ), + ), + ], + ), + head=dict( + type='TransformerHead', + src_from='src', + decoder=dict( + type='TransformerDecoder', + position_encoder=dict( + type='PositionEncoder1D', + in_channels=hidden_dim, + max_len=100, + dropout=dropout, + ), + decoder_layer=dict( + type='TransformerDecoderLayer1D', + self_attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + self_attention_norm=layer_norm, + attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + attention_norm=layer_norm, + 
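+                    # Note: the decoder feedforward below is FC-based and runs
+                    # on the 1D token sequence, whereas the encoder above used
+                    # 3x3 ConvModules (SATRN's "locality-aware feedforward") on
+                    # the 2D feature map.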
feedforward=dict( + type='Feedforward', + layers=[ + dict(type='FCModule', in_channels=hidden_dim, out_channels=hidden_dim * 4, bias=True, + activation='relu', + dropout=dropout), + dict(type='FCModule', in_channels=hidden_dim * 4, out_channels=hidden_dim, bias=True, + activation=None, + dropout=dropout), + ] + ), + feedforward_norm=layer_norm, + ), + num_layers=n_d, + ), + generator=dict( + type='Linear', + in_features=hidden_dim, + out_features=num_class, + ), + embedding=dict( + type='Embedding', + num_embeddings=num_class + 1, + embedding_dim=hidden_dim, + padding_idx=num_class, + ), + num_steps=num_steps, + pad_id=num_class, + ), +) + +## 4.1 resume +resume = None + +# 5. criterion +criterion = dict(type='CrossEntropyLoss', ignore_index=num_class) + +# 6. optim +optimizer = dict(type='Adam', lr=1e-4) + +# 7. lr scheduler +epochs = 6 +decay_epochs = [2,4] +niter_per_epoch = int(55000 * 256 / batch_size) +milestones = [niter_per_epoch * epoch for epoch in decay_epochs] +max_iterations = epochs * niter_per_epoch +lr_scheduler = dict(type='StepLR', niter_per_epoch=niter_per_epoch, max_epochs=epochs, milestones=milestones) + +# 8. runner +runner = dict( + type='Runner', + iterations=max_iterations, + trainval_ratio=2000, + snapshot_interval=20000, +) + +# 9. device +gpu_id = '0' diff --git a/configs/small_satrn.py b/configs/small_satrn.py new file mode 100644 index 0000000..233d61c --- /dev/null +++ b/configs/small_satrn.py @@ -0,0 +1,302 @@ +# work dir +root_workdir = 'workdir/' + +# seed +seed = 1111 + +# 1. logging +logger = dict( + handlers=( + dict(type='StreamHandler', level='INFO'), + dict(type='FileHandler', level='INFO'), + ), +) + +# 2. data +batch_size = 256 +mean, std = 0.5, 0.5 # normalize mean and std +size = (32, 100) +batch_max_length = 25 +fill = 0 +mode = 'nearest' +data_filter_off = False +train_sensitive = True +train_character = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' # need character +test_sensitive = False +test_character = '0123456789abcdefghijklmnopqrstuvwxyz' + +# dataset params +train_dataset_params = dict( + batch_max_length=batch_max_length, + data_filter_off=data_filter_off, + character=train_character, +) +test_dataset_params = dict( + batch_max_length=batch_max_length, + data_filter_off=data_filter_off, + character=test_character, +) + +data_root = './data/data_lmdb_release/' + +# train data +train_root = data_root + 'training/' +## MJ dataset +train_root_mj = train_root + 'MJ/' +mj_folder_names = ['/MJ_test', 'MJ_valid', 'MJ_train'] +## ST dataset +train_root_st = train_root + 'ST/' + +train_dataset_mj = [dict(type='LmdbDataset', root=train_root_mj + folder_name) for folder_name in mj_folder_names] +train_dataset_st = [dict(type='LmdbDataset', root=train_root_st)] + +# valid +valid_root = data_root + 'validation/' +valid_dataset = [dict(type='LmdbDataset', root=valid_root, **test_dataset_params)] + +# test +test_root = data_root + 'evaluation/' +test_folder_names = ['CUTE80', 'IC03_867', 'IC13_1015', 'IC15_2077', 'IIIT5k_3000', 'SVT', 'SVTP'] +test_dataset = [dict(type='LmdbDataset', root=test_root + folder_name, **test_dataset_params) for folder_name in + test_folder_names] + +# transforms +train_transforms = [ + dict(type='Sensitive', sensitive=train_sensitive), + dict(type='ColorToGray'), + dict(type='RandomNormalRotation', mean=0, std=34, expand=True, center=None, fill=fill, mode=mode, p=0.5), + dict(type='Resize', size=size), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, 
std=std), +] +test_transforms = [ + dict(type='Sensitive', sensitive=test_sensitive), + dict(type='ColorToGray'), + dict(type='Resize', size=size), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, std=std), +] + +data = dict( + train=dict( + transforms=train_transforms, + datasets=[ + dict( + type='ConcatDatasets', + datasets=train_dataset_mj, + **train_dataset_params, + ), + dict( + type='ConcatDatasets', + datasets=train_dataset_st, + **train_dataset_params, + ), + ], + loader=dict( + type='BatchBalanceDataloader', + batch_size=batch_size, + each_batch_ratio=[0.5, 0.5], + each_usage=[1.0, 1.0], + shuffle=True, + num_workers=4, + ), + ), + val=dict( + transforms=test_transforms, + datasets=valid_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), + test=dict( + transforms=test_transforms, + datasets=test_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), +) + +test_cfg = dict( + sensitive=test_sensitive, + character=test_character, +) + +# 3. converter +converter = dict( + type='SATRNConverter', + character=train_character, + batch_max_length=batch_max_length, + go_last=True, +) + +# 4. model +dropout = 0.1 +n_e = 9 +n_d = 3 +hidden_dim = 256 +n_head = 8 +batch_norm = dict(type='BN') +layer_norm = dict(type='LayerNorm', normalized_shape=hidden_dim) +num_class = len(train_character) + 1 +num_steps = batch_max_length + 1 +model = dict( + type='GModel', + need_text=True, + body=dict( + type='GBody', + pipelines=[ + dict( + type='FeatureExtractorComponent', + from_layer='input', + to_layer='cnn_feat', + arch=dict( + encoder=dict( + backbone=dict( + type='GResNet', + layers=[ + ('conv', dict(type='ConvModule', in_channels=1, out_channels=int(hidden_dim/2), kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('pool', dict(type='MaxPool2d', kernel_size=2, stride=2, padding=0)), + ('conv', dict(type='ConvModule', in_channels=int(hidden_dim/2), out_channels=hidden_dim, kernel_size=3, + stride=1, padding=1, norm_cfg=batch_norm)), + ('pool', dict(type='MaxPool2d', kernel_size=2, stride=2, padding=0)), + ], + ), + ), + collect=dict(type='CollectBlock', from_layer='c2'), + ), + ), + dict( + type='SequenceEncoderComponent', + from_layer='cnn_feat', + to_layer='src', + arch=dict( + type='TransformerEncoder', + position_encoder=dict( + type='Adaptive2DPositionEncoder', + in_channels=hidden_dim, + max_h=100, + max_w=100, + dropout=dropout, + ), + encoder_layer=dict( + type='TransformerEncoderLayer2D', + attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + attention_norm=layer_norm, + feedforward=dict( + type='Feedforward', + layers=[ + dict(type='ConvModule', in_channels=hidden_dim, out_channels=hidden_dim*4, kernel_size=3, padding=1, + bias=True, norm_cfg=None, activation='relu', dropout=dropout), + dict(type='ConvModule', in_channels=hidden_dim*4, out_channels=hidden_dim, kernel_size=3, padding=1, + bias=True, norm_cfg=None, activation=None, dropout=dropout), + ], + ), + feedforward_norm=layer_norm, + ), + num_layers=n_e, + ), + ), + ], + ), + head=dict( + type='TransformerHead', + src_from='src', + num_steps=num_steps, + pad_id=num_class, + decoder=dict( + type='TransformerDecoder', + position_encoder=dict( + type='PositionEncoder1D', + in_channels=hidden_dim, + max_len=100, + dropout=dropout, + ), + decoder_layer=dict( + 
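+                # TransformerDecoder instantiates num_layers (n_d) independent
+                # copies of this layer spec in an nn.ModuleList; weights are
+                # not shared across layers.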
type='TransformerDecoderLayer1D', + self_attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + self_attention_norm=layer_norm, + attention=dict( + type='MultiHeadAttention', + in_channels=hidden_dim, + k_channels=hidden_dim, + v_channels=hidden_dim, + n_head=n_head, + dropout=dropout, + ), + attention_norm=layer_norm, + feedforward=dict( + type='Feedforward', + layers=[ + dict(type='FCModule', in_channels=hidden_dim, out_channels=hidden_dim * 4, bias=True, + activation='relu', dropout=dropout), + dict(type='FCModule', in_channels=hidden_dim * 4, out_channels=hidden_dim, bias=True, + activation=None, dropout=dropout), + ], + ), + feedforward_norm=layer_norm, + ), + num_layers=n_d, + ), + generator=dict( + type='Linear', + in_features=hidden_dim, + out_features=num_class, + ), + embedding=dict( + type='Embedding', + num_embeddings=num_class + 1, + embedding_dim=hidden_dim, + padding_idx=num_class, + ), + ), +) + +## 4.1 resume +resume = None + +# 5. criterion +criterion = dict(type='CrossEntropyLoss', ignore_index=num_class) + +# 6. optim +optimizer = dict(type='Adam', lr=1e-4) + +# 7. lr scheduler +epochs = 6 +milestones = [2, 4] +niter_per_epoch = int(55000 * 256 / batch_size) +max_iterations = epochs * niter_per_epoch +milestones = [niter_per_epoch * epoch for epoch in milestones] +lr_scheduler = dict(type='StepLR', niter_per_epoch=niter_per_epoch, max_epochs=epochs, milestones=milestones, gamma=0.1, warmup_epochs=0.1) + +# 8. runner +runner = dict( + type='Runner', + iterations=max_iterations, + trainval_ratio=2000, + snapshot_interval=20000, +) + +# 9. device +gpu_id = '3' diff --git a/configs/srn.py b/configs/srn.py new file mode 100644 index 0000000..aebb414 --- /dev/null +++ b/configs/srn.py @@ -0,0 +1,298 @@ +# work dir +root_workdir = 'workdir/' + +# seed +seed = 6 + +# 1. logging +logger = dict( + handlers=( + dict(type='StreamHandler', level='INFO'), + dict(type='FileHandler', level='INFO'), + ), +) + +# 2. 
data +batch_size = 256 +mean, std = 0.5, 0.5 # normalize mean and std +size = (64, 256) +mode = 'nearest' +fill = 0 +batch_max_length = 25 +data_filter_off = False +sensitive = False +character = 'abcdefghijklmnopqrstuvwxyz0123456789' # need character + +# dataset params +dataset_params = dict( + batch_max_length=batch_max_length, + data_filter_off=data_filter_off, +) + +data_root = './data/data_lmdb_release/' + +# train data +train_root = data_root + 'training/' +## MJ dataset +train_root_mj = train_root + 'MJ/' +mj_folder_names = ['/MJ_test', 'MJ_valid', 'MJ_train'] +## ST dataset +train_root_st = train_root + 'ST/' + +train_dataset_mj = [dict(type='LmdbDataset', root=train_root_mj + folder_name) for folder_name in mj_folder_names] +train_dataset_st = [dict(type='LmdbDataset', root=train_root_st)] + +# valid +valid_root = data_root + 'validation/' +valid_dataset = [dict(type='LmdbDataset', root=valid_root, **dataset_params)] + +# test +test_root = data_root + 'evaluation/' +# test_folder_names = ['CUTE80', 'IC03_867', 'IC13_1015', 'IC15_2077', 'IIIT5k_3000', 'SVT', 'SVTP'] +test_folder_names = ['IIIT5k_3000'] +test_dataset = [dict(type='LmdbDataset', root=test_root + folder_name, **dataset_params) for folder_name in + test_folder_names] + +# transforms +train_transforms = [ + dict(type='Sensitive', sensitive=sensitive), + dict(type='KeepHorizontal', clockwise=False), + dict(type='Resize', size=size, keep_ratio=True, mode=mode), + dict(type='RandomScale', scales=(0.25, 1.0), step=0.25, mode=mode, p=0.5), + dict(type='ColorJitter', brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.5), + dict(type='MotionBlur', blur_limit=5, p=0.5), + dict(type='GaussianNoise', var_limit=(10, 50), mean=0, p=0.5), + dict(type='RandomPerspective', distortion_scale=0.3, mode=mode, p=0.5), + dict(type='RandomRotation', degrees=10, expand=False, fill=fill, mode=mode, p=1.0), + dict(type='PadIfNeeded', size=size, fill=fill), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, std=std), +] +test_transforms = [ + dict(type='Sensitive', sensitive=sensitive), + dict(type='KeepHorizontal', clockwise=False), + dict(type='Resize', size=size, keep_ratio=True, mode=mode), + dict(type='PadIfNeeded', size=size, fill=fill), + dict(type='ToTensor'), + dict(type='Normalize', mean=mean, std=std), +] + +data = dict( + train=dict( + transforms=train_transforms, + datasets=[ + dict( + type='ConcatDatasets', + datasets=train_dataset_mj, + **dataset_params, + ), + dict( + type='ConcatDatasets', + datasets=train_dataset_st, + **dataset_params, + ), + ], + loader=dict( + type='BatchRandomDataloader', + batch_size=batch_size, + each_batch_ratio=[1], + each_usage=[1.0], + shuffle=True, + ), + ), + val=dict( + transforms=test_transforms, + datasets=valid_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), + test=dict( + transforms=test_transforms, + datasets=test_dataset, + loader=dict( + type='TestDataloader', + batch_size=batch_size, + num_workers=4, + shuffle=False, + ), + ), +) + +# 3. converter +converter = dict( + type='FCConverter', + character=character, + batch_max_length=batch_max_length, +) + +# 4. 
model +num_class = 37 +model = dict( + type='GModel', + need_text=False, + body=dict( + type='GBody', + pipelines=[ + dict( + type='FeatureExtractorComponent', + from_layer='input', + to_layer='cnn_feat', + arch=dict( + encoder=dict( + backbone=dict( + type='ResNet', + arch='resnet50', + ), + ), + decoder=dict( + type='GFPN', + neck=[ + dict( + type='JunctionBlock', + top_down=None, + lateral=dict( + from_layer='c5', + type='ConvModule', + in_channels=2048, + out_channels=512, + kernel_size=1, + norm_cfg=None, + activation=None, + ), + post=None, + to_layer='p5', + ), # 32 + dict( + type='JunctionBlock', + fusion_method='add', + top_down=dict( + from_layer='p5', + upsample=dict( + type='Upsample', + scale_factor=2, + mode=mode, + ), + ), + lateral=dict( + from_layer='c4', + type='ConvModule', + in_channels=1024, + out_channels=512, + kernel_size=1, + norm_cfg=None, + activation=None, + ), + post=None, + to_layer='p4', + ), # 16 + dict( + type='JunctionBlock', + fusion_method='add', + top_down=dict( + from_layer='p4', + upsample=dict( + type='Upsample', + scale_factor=2, + mode=mode, + ), + ), + lateral=dict( + from_layer='c3', + type='ConvModule', + in_channels=512, + out_channels=512, + kernel_size=1, + norm_cfg=None, + activation=None, + ), + post=dict( + type='ConvModule', + in_channels=512, + out_channels=512, + kernel_size=3, + padding=1, + norm_cfg=None, + activation=None, + ), + to_layer='p3', + ), # 8 + ], + ), + collect=dict(type='CollectBlock', from_layer='p3'), + ), + ), + dict( + type='SequenceEncoderComponent', + from_layer='cnn_feat', + to_layer='tf_feat', + arch=dict( + type='Transformer', + num_layers=2, + d_model=512, + nhead=8, + dim_feedforward=512, + dropout=0.1, + activation='relu', + norm_cfg=None, + use_pos_encode=True, + pos_encode_len=256, + ), + ), + dict( + type='BrickComponent', + from_layer='tf_feat', + to_layer='seq_feat', + arch=dict( + type='PVABlock', + num_steps=batch_max_length+1, + in_channels=512, + embedding_channels=512, + inner_channels=512, + ), + ), + ], + ), + head=dict( + type='Head', + from_layer='seq_feat', + generator=dict( + type='FCModule', + in_channels=512, + out_channels=num_class, + bias=True, + activation=None, + ), + ), +) + +## 4.1 resume +resume = None + +# 5. criterion +criterion = dict(type='CrossEntropyLoss', ignore_index=num_class) + +# 6. optim +optimizer = dict(type='Adam', lr=1e-4) + +# 7. lr scheduler +epochs = 7 +decay_epochs = [3,5] +niter_per_epoch = int(55000 * 256 / batch_size) +milestones = [niter_per_epoch * epoch for epoch in decay_epochs] +max_iterations = epochs * niter_per_epoch +lr_scheduler = dict(type='StepLR', niter_per_epoch=niter_per_epoch, max_epochs=epochs, milestones=milestones, gamma=0.1, warmup_epochs=1) + +# 8. runner +runner = dict( + type='Runner', + iterations=max_iterations, + trainval_ratio=2000, + snapshot_interval=niter_per_epoch, +) + +# 9. 
device +gpu_id = '0,5,6,9' diff --git a/configs/tps_resnet_bilstm_attn.py b/configs/tps_resnet_bilstm_attn.py index 2fedd6d..406d903 100644 --- a/configs/tps_resnet_bilstm_attn.py +++ b/configs/tps_resnet_bilstm_attn.py @@ -86,20 +86,20 @@ transforms=transforms, datasets=valid_dataset, loader=dict( - type='TestDataloader', + type='RawDataloader', batch_size=batch_size, num_workers=4, - shuffle=False, + shuffle=True, ), ), test=dict( transforms=transforms, datasets=test_dataset, loader=dict( - type='TestDataloader', + type='RawDataloader', batch_size=batch_size, num_workers=4, - shuffle=False, + shuffle=True, ), ), ) @@ -273,7 +273,6 @@ optimizer = dict(type='Adadelta', lr=1.0, rho=0.95, eps=1e-8) # 7. lr scheduler - lr_scheduler = dict(type='StepLR', niter_per_epoch=100000, max_epochs=3, milestones=[150000, 250000]) # 8. runner diff --git a/requirements.txt b/requirements.txt index 6f1b0b6..7a566c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,9 @@ numpy opencv-python addict six -torch>=1.1.0 +torch>=1.2.0 torchvision>=0.3.0 -pillow<7.0 +6.0 w: + if self.clockwise: + image = image.transpose(Image.ROTATE_270) + else: + image = image.transpose(Image.ROTATE_90) + + return image, label diff --git a/vedastr/lr_schedulers/__init__.py b/vedastr/lr_schedulers/__init__.py index ca249b2..bfa14e1 100644 --- a/vedastr/lr_schedulers/__init__.py +++ b/vedastr/lr_schedulers/__init__.py @@ -3,3 +3,4 @@ from .cosine_lr import CosineLR from .exponential_lr import ExponentialLR from .step_lr import StepLR +from .constant_lr import ConstantLR diff --git a/vedastr/lr_schedulers/constant_lr.py b/vedastr/lr_schedulers/constant_lr.py new file mode 100644 index 0000000..227208f --- /dev/null +++ b/vedastr/lr_schedulers/constant_lr.py @@ -0,0 +1,18 @@ +from .base import _Iter_LRScheduler +from .registry import LR_SCHEDULERS + + +@LR_SCHEDULERS.register_module +class ConstantLR(_Iter_LRScheduler): + """ConstantLR + """ + def __init__(self, optimizer, niter_per_epoch, last_iter=-1, warmup_epochs=0): + self.warmup_iters = niter_per_epoch * warmup_epochs + super().__init__(optimizer, niter_per_epoch, last_iter) + + def get_lr(self): + if self.last_iter < self.warmup_iters: + multiplier = self.last_iter / float(self.warmup_iters) + else: + multiplier = 1.0 + return [base_lr * multiplier for base_lr in self.base_lrs] diff --git a/vedastr/lr_schedulers/exponential_lr.py b/vedastr/lr_schedulers/exponential_lr.py index efab972..2744c5e 100644 --- a/vedastr/lr_schedulers/exponential_lr.py +++ b/vedastr/lr_schedulers/exponential_lr.py @@ -10,11 +10,11 @@ def __init__(self, optimizer, niter_per_epoch, max_epochs, gamma, step, last_ite self.max_iters = niter_per_epoch * max_epochs self.gamma = gamma self.step_iters = niter_per_epoch * step - self.warmup_iters = niter_per_epoch * warmup_epochs + self.warmup_iters = int(niter_per_epoch * warmup_epochs) super().__init__(optimizer, niter_per_epoch, last_iter) def get_lr(self): - if self.last_iter < self.warm_up: + if self.last_iter < self.warmup_iters: multiplier = self.last_iter / float(self.warmup_iters) else: multiplier = self.gamma ** ((self.last_iter-self.warmup_iters) / float(self.step_iters)) diff --git a/vedastr/lr_schedulers/step_lr.py b/vedastr/lr_schedulers/step_lr.py index 50f285a..14af52f 100644 --- a/vedastr/lr_schedulers/step_lr.py +++ b/vedastr/lr_schedulers/step_lr.py @@ -6,17 +6,21 @@ @LR_SCHEDULERS.register_module class StepLR(_Iter_LRScheduler): - def __init__(self, optimizer, niter_per_epoch, max_epochs, milestones, gamma=0.1, last_iter=-1, 
warmup_epochs=0): self.max_iters = niter_per_epoch * max_epochs - self.milestones = Counter(milestones) + self.milestones = milestones + self.count = 0 self.gamma = gamma - self.warmup_iters = niter_per_epoch * warmup_epochs + self.warmup_iters = int(niter_per_epoch * warmup_epochs) super(StepLR, self).__init__(optimizer, niter_per_epoch, last_iter) def get_lr(self): + if self.last_iter in self.milestones: + self.count += 1 + + if self.last_iter < self.warmup_iters: + multiplier = self.last_iter / float(self.warmup_iters) + else: + multiplier = self.gamma ** self.count + return [base_lr * multiplier for base_lr in self.base_lrs] - if self.last_iter not in self.milestones: - return [group['lr'] for group in self.optimizer.param_groups] - return [group['lr'] * self.gamma ** self.milestones[self.last_iter] - for group in self.optimizer.param_groups] diff --git a/vedastr/models/bodies/body.py b/vedastr/models/bodies/body.py index 920ee19..e445684 100644 --- a/vedastr/models/bodies/body.py +++ b/vedastr/models/bodies/body.py @@ -11,9 +11,7 @@ def __init__(self, pipelines, collect=None): super(GBody, self).__init__() self.input_to_layer = 'input' - self.components = nn.ModuleList() - for component in pipelines: - self.components.append(build_component(component)) + self.components = nn.ModuleList([build_component(component) for component in pipelines]) if collect is not None: self.collect = build_brick(collect) @@ -28,7 +26,13 @@ def forward(self, x): for component in self.components: component_from = component.from_layer component_to = component.to_layer - out = component(feats[component_from]) + + if isinstance(component_from, list): + inp = {key: feats[key] for key in component_from} + out = component(**inp) + else: + inp = feats[component_from] + out = component(inp) feats[component_to] = out if self.with_collect: diff --git a/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py b/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py index e43926c..11c6a1a 100644 --- a/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py +++ b/vedastr/models/bodies/feature_extractors/decoders/bricks/__init__.py @@ -1,2 +1,3 @@ from .bricks import JunctionBlock, FusionBlock, CollectBlock, CellAttentionBlock +from .pva import PVABlock from .builder import build_brick, build_bricks diff --git a/vedastr/models/bodies/feature_extractors/decoders/bricks/bricks.py b/vedastr/models/bodies/feature_extractors/decoders/bricks/bricks.py index 812ccfd..db967ce 100644 --- a/vedastr/models/bodies/feature_extractors/decoders/bricks/bricks.py +++ b/vedastr/models/bodies/feature_extractors/decoders/bricks/bricks.py @@ -171,7 +171,6 @@ def forward(self, feats): feats[self.to_layer] = {f_layer: feats[f_layer] for f_layer in self.from_layer} - @BRICKS.register_module class CellAttentionBlock(nn.Module): def __init__(self, feat, hidden, fusion_method='add', post=None, post_activation='softmax'): diff --git a/vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py b/vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py new file mode 100644 index 0000000..3018a2d --- /dev/null +++ b/vedastr/models/bodies/feature_extractors/decoders/bricks/pva.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + +from vedastr.models.weight_init import init_weights +from .registry import BRICKS + + +@BRICKS.register_module +class PVABlock(nn.Module): + def __init__(self, num_steps, in_channels, embedding_channels=512, inner_channels=512): + super(PVABlock, self).__init__() + + 
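+        # Parallel visual attention: each of the `num_steps` learned order
+        # embeddings queries the flattened feature map independently, so the
+        # features for all output characters are computed in parallel rather
+        # than recurrently.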
self.num_steps = num_steps + self.in_channels = in_channels + self.inner_channels = inner_channels + self.embedding_channels = embedding_channels + + self.order_embeddings = nn.Parameter(torch.randn(self.num_steps, self.embedding_channels), requires_grad=True) + + self.v_linear = nn.Linear(self.in_channels, self.inner_channels, bias=False) + self.o_linear = nn.Linear(self.embedding_channels, self.inner_channels, bias=False) + self.e_linear = nn.Linear(self.inner_channels, 1, bias=False) + + init_weights(self.modules()) + + def forward(self, x): + b, c, h, w = x.size() + + x = x.reshape(b, c, h*w).permute(0, 2, 1) + + o_out = self.o_linear(self.order_embeddings).view(1, self.num_steps, 1, self.inner_channels) + v_out = self.v_linear(x).unsqueeze(1) + att = self.e_linear(torch.tanh(o_out + v_out)).squeeze(3) + att = torch.softmax(att, dim=2) + + out = torch.bmm(att, x) + + return out diff --git a/vedastr/models/bodies/feature_extractors/encoders/backbones/resnet.py b/vedastr/models/bodies/feature_extractors/encoders/backbones/resnet.py index eded684..2780f7d 100644 --- a/vedastr/models/bodies/feature_extractors/encoders/backbones/resnet.py +++ b/vedastr/models/bodies/feature_extractors/encoders/backbones/resnet.py @@ -31,6 +31,11 @@ 'layer': [3, 4, 6, 3], 'weights_url': model_urls['resnet50'], }, + 'resnet34': { + 'block': BasicBlock, + 'layer': [3, 4, 6, 3], + 'weights_url': model_urls['resnet34'], + }, 'resnet18': { 'block': BasicBlock, 'layer': [2, 2, 2, 2], @@ -220,22 +225,8 @@ def __init__(self, layers, zero_init_residual=False, stage_layers.append(layer) self.layers.append(nn.Sequential(*stage_layers)) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') - elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) + logger.info('GResNet init weights') + init_weights(self.modules()) def _make_layer(self, block_name, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer diff --git a/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py b/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py index b0c8f06..a3591db 100644 --- a/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py +++ b/vedastr/models/bodies/feature_extractors/encoders/enhance_modules/aspp.py @@ -37,7 +37,7 @@ def forward(self, x): size = x.shape[-2:] x = super(ASPPPooling, self).forward(x) - return F.interpolate(x, size=size, mode='bilinear', align_corners=True) + return F.interpolate(x, size=size, mode='nearest') @ENHANCE_MODULES.register_module @@ -66,7 +66,7 @@ def __init__(self, in_channels, out_channels, atrous_rates, from_layer, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) self.with_dropout = dropout is not None if self.with_dropout: - self.dropout = nn.Dropout(dropout) + self.dropout = nn.Dropout(p=dropout) logger.info('ASPP init weights') init_weights(self.modules()) diff --git a/vedastr/models/bodies/sequences/__init__.py b/vedastr/models/bodies/sequences/__init__.py index fd32bec..aa0394a 100644 --- a/vedastr/models/bodies/sequences/__init__.py +++ b/vedastr/models/bodies/sequences/__init__.py @@ -1,2 +1,3 @@ -from .encoders import build_sequence_encoder -from .decoders import build_sequence_decoder +from .builder import build_sequence_encoder, build_sequence_decoder +from .rnn import RNN, GRUCell +from .transformer import TransformerEncoder diff --git a/vedastr/models/bodies/sequences/builder.py b/vedastr/models/bodies/sequences/builder.py new file mode 100644 index 0000000..299a3ba --- /dev/null +++ b/vedastr/models/bodies/sequences/builder.py @@ -0,0 +1,14 @@ +from vedastr.utils import build_from_cfg +from .registry import SEQUENCE_ENCODERS, SEQUENCE_DECODERS + + +def build_sequence_encoder(cfg, default_args=None): + sequence_encoder = build_from_cfg(cfg, SEQUENCE_ENCODERS, default_args) + + return sequence_encoder + + +def build_sequence_decoder(cfg, default_args=None): + sequence_decoder = build_from_cfg(cfg, SEQUENCE_DECODERS, default_args) + + return sequence_decoder diff --git a/vedastr/models/bodies/sequences/decoders/__init__.py b/vedastr/models/bodies/sequences/decoders/__init__.py deleted file mode 100644 index e3be86f..0000000 --- a/vedastr/models/bodies/sequences/decoders/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .rnn_cell import LSTMCell -from .builder import build_sequence_decoder diff --git a/vedastr/models/bodies/sequences/decoders/builder.py b/vedastr/models/bodies/sequences/decoders/builder.py deleted file mode 100644 index 3a323bd..0000000 --- a/vedastr/models/bodies/sequences/decoders/builder.py +++ /dev/null @@ -1,8 +0,0 @@ -from vedastr.utils import build_from_cfg -from .registry import SEQUENCE_DECODERS - - -def build_sequence_decoder(cfg, default_args=None): - sequence_encoder = build_from_cfg(cfg, SEQUENCE_DECODERS, default_args) - - return sequence_encoder diff --git a/vedastr/models/bodies/sequences/encoders/__init__.py b/vedastr/models/bodies/sequences/encoders/__init__.py deleted file 
mode 100644 index 2a37559..0000000 --- a/vedastr/models/bodies/sequences/encoders/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .rnn import RNN -from .builder import build_sequence_encoder diff --git a/vedastr/models/bodies/sequences/encoders/builder.py b/vedastr/models/bodies/sequences/encoders/builder.py deleted file mode 100644 index cee3354..0000000 --- a/vedastr/models/bodies/sequences/encoders/builder.py +++ /dev/null @@ -1,8 +0,0 @@ -from vedastr.utils import build_from_cfg -from .registry import SEQUENCE_ENCODERS - - -def build_sequence_encoder(cfg, default_args=None): - sequence_encoder = build_from_cfg(cfg, SEQUENCE_ENCODERS, default_args) - - return sequence_encoder diff --git a/vedastr/models/bodies/sequences/encoders/registry.py b/vedastr/models/bodies/sequences/encoders/registry.py deleted file mode 100644 index 4689131..0000000 --- a/vedastr/models/bodies/sequences/encoders/registry.py +++ /dev/null @@ -1,3 +0,0 @@ -from vedastr.utils import Registry - -SEQUENCE_ENCODERS = Registry('sequence_encoder') diff --git a/vedastr/models/bodies/sequences/decoders/registry.py b/vedastr/models/bodies/sequences/registry.py similarity index 63% rename from vedastr/models/bodies/sequences/decoders/registry.py rename to vedastr/models/bodies/sequences/registry.py index ff3f0e5..e6526eb 100644 --- a/vedastr/models/bodies/sequences/decoders/registry.py +++ b/vedastr/models/bodies/sequences/registry.py @@ -1,3 +1,4 @@ from vedastr.utils import Registry +SEQUENCE_ENCODERS = Registry('sequence_encoder') SEQUENCE_DECODERS = Registry('sequence_decoder') diff --git a/vedastr/models/bodies/sequences/rnn/__init__.py b/vedastr/models/bodies/sequences/rnn/__init__.py new file mode 100644 index 0000000..f9c7230 --- /dev/null +++ b/vedastr/models/bodies/sequences/rnn/__init__.py @@ -0,0 +1,2 @@ +from .encoder import RNN +from .decoder import LSTMCell, GRUCell diff --git a/vedastr/models/bodies/sequences/decoders/rnn_cell.py b/vedastr/models/bodies/sequences/rnn/decoder.py similarity index 98% rename from vedastr/models/bodies/sequences/decoders/rnn_cell.py rename to vedastr/models/bodies/sequences/rnn/decoder.py index 77f394a..7f570e7 100644 --- a/vedastr/models/bodies/sequences/decoders/rnn_cell.py +++ b/vedastr/models/bodies/sequences/rnn/decoder.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from vedastr.models.weight_init import init_weights -from .registry import SEQUENCE_DECODERS +from vedastr.models.weight_init import init_weights +from ..registry import SEQUENCE_DECODERS class BaseCell(nn.Module): @@ -22,7 +22,6 @@ def __init__(self, basic_cell, input_size, hidden_size, bias=True, num_layers=1) self.cells.append(basic_cell(input_size=hidden_size, hidden_size=hidden_size, bias=bias)) init_weights(self.modules()) - def init_hidden(self, batch_size, device=None, value=0): raise NotImplementedError() diff --git a/vedastr/models/bodies/sequences/encoders/rnn.py b/vedastr/models/bodies/sequences/rnn/encoder.py similarity index 97% rename from vedastr/models/bodies/sequences/encoders/rnn.py rename to vedastr/models/bodies/sequences/rnn/encoder.py index 489a346..612020b 100644 --- a/vedastr/models/bodies/sequences/encoders/rnn.py +++ b/vedastr/models/bodies/sequences/rnn/encoder.py @@ -2,7 +2,7 @@ from vedastr.models.utils import build_torch_nn from vedastr.models.weight_init import init_weights -from .registry import SEQUENCE_ENCODERS +from ..registry import SEQUENCE_ENCODERS @SEQUENCE_ENCODERS.register_module diff --git a/vedastr/models/bodies/sequences/transformer/__init__.py 
b/vedastr/models/bodies/sequences/transformer/__init__.py new file mode 100644 index 0000000..5bba66d --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/__init__.py @@ -0,0 +1,2 @@ +from .encoder import TransformerEncoder +from .decoder import TransformerDecoder diff --git a/vedastr/models/bodies/sequences/transformer/decoder.py b/vedastr/models/bodies/sequences/transformer/decoder.py new file mode 100644 index 0000000..5d7daa8 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/decoder.py @@ -0,0 +1,38 @@ +import logging + +import torch.nn as nn + +from .position_encoder import build_position_encoder +from .unit import build_decoder_layer +from ..registry import SEQUENCE_DECODERS +from vedastr.models.weight_init import init_weights + + +logger = logging.getLogger() + + +@SEQUENCE_DECODERS.register_module +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, position_encoder=None): + super(TransformerDecoder, self).__init__() + + if position_encoder is not None: + self.pos_encoder = build_position_encoder(position_encoder) + + self.layers = nn.ModuleList([build_decoder_layer(decoder_layer) for _ in range(num_layers)]) + + logger.info('TransformerDecoder init weights') + init_weights(self.modules()) + + @property + def with_position_encoder(self): + return hasattr(self, 'pos_encoder') and self.pos_encoder is not None + + def forward(self, tgt, src, tgt_mask=None, src_mask=None): + if self.with_position_encoder: + tgt = self.pos_encoder(tgt) + + for layer in self.layers: + tgt = layer(tgt, src, tgt_mask, src_mask) + + return tgt diff --git a/vedastr/models/bodies/sequences/transformer/encoder.py b/vedastr/models/bodies/sequences/transformer/encoder.py new file mode 100644 index 0000000..25340c6 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/encoder.py @@ -0,0 +1,38 @@ +import logging + +import torch.nn as nn + +from .position_encoder import build_position_encoder +from .unit import build_encoder_layer +from ..registry import SEQUENCE_ENCODERS +from vedastr.models.weight_init import init_weights + + +logger = logging.getLogger() + + +@SEQUENCE_ENCODERS.register_module +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, position_encoder=None): + super(TransformerEncoder, self).__init__() + + if position_encoder is not None: + self.pos_encoder = build_position_encoder(position_encoder) + + self.layers = nn.ModuleList([build_encoder_layer(encoder_layer) for _ in range(num_layers)]) + + logger.info('TransformerEncoder init weights') + init_weights(self.modules()) + + @property + def with_position_encoder(self): + return hasattr(self, 'pos_encoder') and self.pos_encoder is not None + + def forward(self, src, src_mask=None): + if self.with_position_encoder: + src = self.pos_encoder(src) + + for layer in self.layers: + src = layer(src, src_mask) + + return src diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py b/vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py new file mode 100644 index 0000000..c927dfa --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_position_encoder +from .encoder import PositionEncoder1D +from .adaptive_2d_encoder import Adaptive2DPositionEncoder diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py b/vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py new 
file mode 100644 index 0000000..2425f2f --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/adaptive_2d_encoder.py @@ -0,0 +1,48 @@ +import torch.nn as nn + +from .utils import generate_encoder +from .registry import POSITION_ENCODERS + + +@POSITION_ENCODERS.register_module +class Adaptive2DPositionEncoder(nn.Module): + def __init__(self, in_channels, max_h=200, max_w=200, dropout=0.1): + super(Adaptive2DPositionEncoder, self).__init__() + + h_position_encoder = generate_encoder(in_channels, max_h) + h_position_encoder = h_position_encoder.transpose(0, 1).view(1, in_channels, max_h, 1) + + w_position_encoder = generate_encoder(in_channels, max_w) + w_position_encoder = w_position_encoder.transpose(0, 1).view(1, in_channels, 1, max_w) + + self.register_buffer('h_position_encoder', h_position_encoder) + self.register_buffer('w_position_encoder', w_position_encoder) + + self.h_scale = self.scale_factor_generate(in_channels) + self.w_scale = self.scale_factor_generate(in_channels) + self.pool = nn.AdaptiveAvgPool2d(1) + self.dropout = nn.Dropout(p=dropout) + + def scale_factor_generate(self, in_channels): + scale_factor = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels, in_channels, kernel_size=1), + nn.Sigmoid() + ) + + return scale_factor + + def forward(self, x): + b, c, h, w = x.size() + + avg_pool = self.pool(x) + + h_pos_encoding = self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :] + w_pos_encoding = self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w] + + out = x + h_pos_encoding + w_pos_encoding + + out = self.dropout(out) + + return out diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/builder.py b/vedastr/models/bodies/sequences/transformer/position_encoder/builder.py new file mode 100644 index 0000000..ad79edf --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/builder.py @@ -0,0 +1,8 @@ +from vedastr.utils import build_from_cfg +from .registry import POSITION_ENCODERS + + +def build_position_encoder(cfg, default_args=None): + position_encoder = build_from_cfg(cfg, POSITION_ENCODERS, default_args) + + return position_encoder diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py b/vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py new file mode 100644 index 0000000..3497857 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/encoder.py @@ -0,0 +1,21 @@ +import torch.nn as nn + +from .utils import generate_encoder +from .registry import POSITION_ENCODERS + + +@POSITION_ENCODERS.register_module +class PositionEncoder1D(nn.Module): + def __init__(self, in_channels, max_len=2000, dropout=0.1): + super(PositionEncoder1D, self).__init__() + + position_encoder = generate_encoder(in_channels, max_len) + position_encoder = position_encoder.unsqueeze(0) + self.register_buffer('position_encoder', position_encoder) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x): + out = x + self.position_encoder[:, :x.size(1), :] + out = self.dropout(out) + + return out diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/registry.py b/vedastr/models/bodies/sequences/transformer/position_encoder/registry.py new file mode 100644 index 0000000..209abdc --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/registry.py @@ -0,0 +1,3 @@ +from vedastr.utils import Registry + +POSITION_ENCODERS = 
Registry('position_encoder') diff --git a/vedastr/models/bodies/sequences/transformer/position_encoder/utils.py b/vedastr/models/bodies/sequences/transformer/position_encoder/utils.py new file mode 100644 index 0000000..c6c503c --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/position_encoder/utils.py @@ -0,0 +1,14 @@ +import torch + + +def generate_encoder(in_channels, max_len): + pos = torch.arange(max_len).float().unsqueeze(1) + + i = torch.arange(in_channels).float().unsqueeze(0) + angle_rates = 1 / torch.pow(10000, (2 * (i//2)) / in_channels) + + position_encoder = pos * angle_rates + position_encoder[:, 0::2] = torch.sin(position_encoder[:, 0::2]) + position_encoder[:, 1::2] = torch.cos(position_encoder[:, 1::2]) + + return position_encoder diff --git a/vedastr/models/bodies/sequences/transformer/unit/__init__.py b/vedastr/models/bodies/sequences/transformer/unit/__init__.py new file mode 100644 index 0000000..55950b8 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/__init__.py @@ -0,0 +1,3 @@ +from .builder import build_encoder_layer, build_decoder_layer +from .encoder import TransformerEncoderLayer1D, TransformerEncoderLayer2D +from .decoder import TransformerDecoderLayer1D diff --git a/vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py b/vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py new file mode 100644 index 0000000..d2b270a --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/attention/__init__.py @@ -0,0 +1,2 @@ +from .builder import build_attention +from .multihead_attention import MultiHeadAttention diff --git a/vedastr/models/bodies/sequences/transformer/unit/attention/builder.py b/vedastr/models/bodies/sequences/transformer/unit/attention/builder.py new file mode 100644 index 0000000..08a3bd4 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/attention/builder.py @@ -0,0 +1,8 @@ +from vedastr.utils import build_from_cfg +from .registry import TRANSFORMER_ATTENTIONS + + +def build_attention(cfg, default_args=None): + attention = build_from_cfg(cfg, TRANSFORMER_ATTENTIONS, default_args) + + return attention diff --git a/vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py b/vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py new file mode 100644 index 0000000..db03057 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/attention/multihead_attention.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +from .registry import TRANSFORMER_ATTENTIONS + + +class ScaledDotProductAttention(nn.Module): + def __init__(self, temperature, dropout=0.1): + super(ScaledDotProductAttention, self).__init__() + + self.temperature = temperature + self.dropout = nn.Dropout(p=dropout) + + def forward(self, q, k, v, mask=None): + attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature + + if mask is not None: + attn = attn.masked_fill(mask=mask, value=float('-inf')) + + attn = torch.softmax(attn, dim=-1) + attn = self.dropout(attn) + + out = torch.matmul(attn, v) + + return out, attn + + +@TRANSFORMER_ATTENTIONS.register_module +class MultiHeadAttention(nn.Module): + def __init__(self, in_channels, k_channels, v_channels, n_head=8, dropout=0.1): + super(MultiHeadAttention, self).__init__() + + self.in_channels = in_channels + self.k_channels = k_channels + self.v_channels = v_channels + self.n_head = n_head + + self.q_linear = nn.Linear(in_channels, n_head * k_channels) + self.k_linear = 
nn.Linear(in_channels, n_head * k_channels) + self.v_linear = nn.Linear(in_channels, n_head * v_channels) + self.attention = ScaledDotProductAttention(temperature=k_channels ** 0.5, dropout=dropout) + self.out_linear = nn.Linear(n_head * v_channels, in_channels) + + self.dropout = nn.Dropout(p=dropout) + + def forward(self, q, k, v, mask=None): + b, q_len, k_len, v_len = q.size(0), q.size(1), k.size(1), v.size(1) + + q = self.q_linear(q).view(b, q_len, self.n_head, self.k_channels).transpose(1, 2) + k = self.k_linear(k).view(b, k_len, self.n_head, self.k_channels).transpose(1, 2) + v = self.v_linear(v).view(b, v_len, self.n_head, self.v_channels).transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) + + out, attn = self.attention(q, k, v, mask=mask) + + out = out.transpose(1, 2).contiguous().view(b, q_len, self.n_head * self.v_channels) + out = self.out_linear(out) + out = self.dropout(out) + + return out, attn diff --git a/vedastr/models/bodies/sequences/transformer/unit/attention/registry.py b/vedastr/models/bodies/sequences/transformer/unit/attention/registry.py new file mode 100644 index 0000000..0a122b8 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/attention/registry.py @@ -0,0 +1,3 @@ +from vedastr.utils import Registry + +TRANSFORMER_ATTENTIONS = Registry('transformer_attention') diff --git a/vedastr/models/bodies/sequences/transformer/unit/builder.py b/vedastr/models/bodies/sequences/transformer/unit/builder.py new file mode 100644 index 0000000..2e20c27 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/builder.py @@ -0,0 +1,14 @@ +from vedastr.utils import build_from_cfg +from .registry import TRANSFORMER_ENCODER_LAYERS, TRANSFORMER_DECODER_LAYERS + + +def build_encoder_layer(cfg, default_args=None): + encoder_layer = build_from_cfg(cfg, TRANSFORMER_ENCODER_LAYERS, default_args) + + return encoder_layer + + +def build_decoder_layer(cfg, default_args=None): + decoder_layer = build_from_cfg(cfg, TRANSFORMER_DECODER_LAYERS, default_args) + + return decoder_layer diff --git a/vedastr/models/bodies/sequences/transformer/unit/decoder.py b/vedastr/models/bodies/sequences/transformer/unit/decoder.py new file mode 100644 index 0000000..3436e31 --- /dev/null +++ b/vedastr/models/bodies/sequences/transformer/unit/decoder.py @@ -0,0 +1,46 @@ +import torch.nn as nn + +from vedastr.models.utils import build_torch_nn +from .attention import build_attention +from .feedforward import build_feedforward +from .registry import TRANSFORMER_DECODER_LAYERS + + +@TRANSFORMER_DECODER_LAYERS.register_module +class TransformerDecoderLayer1D(nn.Module): + def __init__(self, + self_attention, + self_attention_norm, + attention, + attention_norm, + feedforward, + feedforward_norm): + super(TransformerDecoderLayer1D, self).__init__() + + self.self_attention = build_attention(self_attention) + self.self_attention_norm = build_torch_nn(self_attention_norm) + + self.attention = build_attention(attention) + self.attention_norm = build_torch_nn(attention_norm) + + self.feedforward = build_feedforward(feedforward) + self.feedforward_norm = build_torch_nn(feedforward_norm) + + def forward(self, tgt, src, tgt_mask=None, src_mask=None): + attn1, _ = self.self_attention(tgt, tgt, tgt, tgt_mask) + out1 = self.self_attention_norm(tgt+attn1) + + size = src.size() + if len(size) == 4: + b, c, h, w = size + src = src.view(b, c, h * w).transpose(1, 2) + if src_mask is not None: + src_mask = src_mask.view(b, 1, h * w) + + attn2, _ = self.attention(out1, src, src, src_mask) 
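+        # Post-norm residual wiring, here and below: out = norm(x + sublayer(x)).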
+        out2 = self.attention_norm(out1 + attn2)
+
+        ffn_out = self.feedforward(out2)
+        out3 = self.feedforward_norm(out2 + ffn_out)
+
+        return out3
diff --git a/vedastr/models/bodies/sequences/transformer/unit/encoder.py b/vedastr/models/bodies/sequences/transformer/unit/encoder.py
new file mode 100644
index 0000000..eb69482
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/encoder.py
@@ -0,0 +1,68 @@
+import torch.nn as nn
+
+from vedastr.models.utils import build_torch_nn
+from .attention import build_attention
+from .feedforward import build_feedforward
+from .registry import TRANSFORMER_ENCODER_LAYERS
+
+
+@TRANSFORMER_ENCODER_LAYERS.register_module
+class TransformerEncoderLayer1D(nn.Module):
+    def __init__(self, attention, attention_norm, feedforward, feedforward_norm):
+        super(TransformerEncoderLayer1D, self).__init__()
+
+        self.attention = build_attention(attention)
+        self.attention_norm = build_torch_nn(attention_norm)
+
+        self.feedforward = build_feedforward(feedforward)
+        self.feedforward_norm = build_torch_nn(feedforward_norm)
+
+    def forward(self, src, src_mask=None):
+        attn_out, _ = self.attention(src, src, src, src_mask)
+        out1 = self.attention_norm(src + attn_out)
+
+        ffn_out = self.feedforward(out1)
+        out2 = self.feedforward_norm(out1 + ffn_out)
+
+        return out2
+
+
+@TRANSFORMER_ENCODER_LAYERS.register_module
+class TransformerEncoderLayer2D(nn.Module):
+    def __init__(self, attention, attention_norm, feedforward, feedforward_norm):
+        super(TransformerEncoderLayer2D, self).__init__()
+
+        self.attention = build_attention(attention)
+        self.attention_norm = build_torch_nn(attention_norm)
+
+        self.feedforward = build_feedforward(feedforward)
+        self.feedforward_norm = build_torch_nn(feedforward_norm)
+
+    def norm(self, norm_layer, x):
+        b, c, h, w = x.size()
+
+        if isinstance(norm_layer, nn.LayerNorm):
+            out = x.view(b, c, h * w).transpose(1, 2)
+            out = norm_layer(out)
+            out = out.transpose(1, 2).contiguous().view(b, c, h, w)
+        else:
+            out = norm_layer(x)
+
+        return out
+
+    def forward(self, src, src_mask=None):
+        b, c, h, w = src.size()
+
+        src = src.view(b, c, h * w).transpose(1, 2)
+        if src_mask is not None:
+            src_mask = src_mask.view(b, 1, h * w)
+
+        attn_out, _ = self.attention(src, src, src, src_mask)
+        out1 = src + attn_out
+        out1 = out1.transpose(1, 2).contiguous().view(b, c, h, w)
+        out1 = self.norm(self.attention_norm, out1)
+
+        ffn_out = self.feedforward(out1)
+        out2 = self.norm(self.feedforward_norm, out1 + ffn_out)
+
+        return out2
diff --git a/vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py b/vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py
new file mode 100644
index 0000000..26fb713
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/feedforward/__init__.py
@@ -0,0 +1,2 @@
+from .builder import build_feedforward
+from .feedforward import Feedforward
diff --git a/vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py b/vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py
new file mode 100644
index 0000000..a598d70
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/feedforward/builder.py
@@ -0,0 +1,8 @@
+from vedastr.utils import build_from_cfg
+from .registry import TRANSFORMER_FEEDFORWARDS
+
+
+def build_feedforward(cfg, default_args=None):
+    feedforward = build_from_cfg(cfg, TRANSFORMER_FEEDFORWARDS, default_args)
+
+    return feedforward
diff --git a/vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py b/vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py
new file mode 100644
index 0000000..02c4de2
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/feedforward/feedforward.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+
+from vedastr.models.utils import build_module
+from .registry import TRANSFORMER_FEEDFORWARDS
+
+
+@TRANSFORMER_FEEDFORWARDS.register_module
+class Feedforward(nn.Module):
+    def __init__(self, layers):
+        super(Feedforward, self).__init__()
+
+        self.layers = [build_module(layer) for layer in layers]
+        self.layers = nn.Sequential(*self.layers)
+
+    def forward(self, x):
+        out = self.layers(x)
+
+        return out
diff --git a/vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py b/vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py
new file mode 100644
index 0000000..9960907
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/feedforward/registry.py
@@ -0,0 +1,3 @@
+from vedastr.utils import Registry
+
+TRANSFORMER_FEEDFORWARDS = Registry('transformer_feedforward')
diff --git a/vedastr/models/bodies/sequences/transformer/unit/registry.py b/vedastr/models/bodies/sequences/transformer/unit/registry.py
new file mode 100644
index 0000000..d83b0a1
--- /dev/null
+++ b/vedastr/models/bodies/sequences/transformer/unit/registry.py
@@ -0,0 +1,4 @@
+from vedastr.utils import Registry
+
+TRANSFORMER_ENCODER_LAYERS = Registry('transformer_encoder_layer')
+TRANSFORMER_DECODER_LAYERS = Registry('transformer_decoder_layer')
diff --git a/vedastr/models/heads/__init__.py b/vedastr/models/heads/__init__.py
index 60c1c10..7183512 100644
--- a/vedastr/models/heads/__init__.py
+++ b/vedastr/models/heads/__init__.py
@@ -1,3 +1,5 @@
 from .builder import build_head
 from .att_head import AttHead
 from .fc_head import FCHead
+from .head import Head
+from .transformer_head import TransformerHead
diff --git a/vedastr/models/heads/fc_head.py b/vedastr/models/heads/fc_head.py
index 12e8abf..bfcdac7 100644
--- a/vedastr/models/heads/fc_head.py
+++ b/vedastr/models/heads/fc_head.py
@@ -22,7 +22,7 @@ def __init__(self,
                  num_class,
                  batch_max_length,
                  from_layer,
-                 inter_channels=None,
+                 inner_channels=None,
                  bias=True,
                  activation='relu',
                  inplace=True,
@@ -30,13 +30,14 @@ def __init__(self,
                  num_fcs=0,
                  pool=None):
         super(FCHead, self).__init__()
+        self.num_class = num_class
         self.batch_max_length = batch_max_length
         self.from_layer = from_layer
 
         if num_fcs > 0:
-            inter_fc = FCModules(in_channels, inter_channels, bias, activation, inplace, dropouts, num_fcs)
-            fc = nn.Linear(inter_channels, out_channels)
+            inter_fc = FCModules(in_channels, inner_channels, bias, activation, inplace, dropouts, num_fcs)
+            fc = nn.Linear(inner_channels, out_channels)
         else:
             inter_fc = nn.Sequential()
             fc = nn.Linear(in_channels, out_channels)
 
diff --git a/vedastr/models/heads/head.py b/vedastr/models/heads/head.py
new file mode 100644
index 0000000..5b0b560
--- /dev/null
+++ b/vedastr/models/heads/head.py
@@ -0,0 +1,37 @@
+import logging
+
+import torch.nn as nn
+
+from vedastr.models.utils import build_module
+from vedastr.models.weight_init import init_weights
+from .registry import HEADS
+
+logger = logging.getLogger()
+
+
+@HEADS.register_module
+class Head(nn.Module):
+    """Head
+
+    Args:
+        from_layer (str): key of the input feature map in the feats dict.
+        generator (dict): config of the prediction module to build.
+    """
+
+    def __init__(self,
+                 from_layer,
+                 generator,
+                 ):
+        super(Head, self).__init__()
+
+        self.from_layer = from_layer
+        self.generator = build_module(generator)
+
+        logger.info('Head init weights')
+        init_weights(self.modules())
+
+    def forward(self, feats):
+        x = feats[self.from_layer]
+        out = self.generator(x)
+
+        return out
diff --git a/vedastr/models/heads/transformer_head.py b/vedastr/models/heads/transformer_head.py
new file mode 100644
index 0000000..9586b25
--- /dev/null
+++ b/vedastr/models/heads/transformer_head.py
@@ -0,0 +1,83 @@
+import logging
+import math
+
+import torch
+import torch.nn as nn
+
+from vedastr.models.bodies import build_sequence_decoder
+from vedastr.models.utils import build_torch_nn
+from vedastr.models.weight_init import init_weights
+from .registry import HEADS
+
+logger = logging.getLogger()
+
+
+@HEADS.register_module
+class TransformerHead(nn.Module):
+    def __init__(self,
+                 decoder,
+                 generator,
+                 embedding,
+                 num_steps,
+                 pad_id,
+                 src_from,
+                 src_mask_from=None,
+                 ):
+        super(TransformerHead, self).__init__()
+
+        self.decoder = build_sequence_decoder(decoder)
+        self.generator = build_torch_nn(generator)
+        self.embedding = build_torch_nn(embedding)
+        self.num_steps = num_steps
+        self.pad_id = pad_id
+        self.src_from = src_from
+        self.src_mask_from = src_mask_from
+
+        logger.info('TransformerHead init weights')
+        init_weights(self.modules())
+
+    def pad_mask(self, text):
+        pad_mask = (text == self.pad_id)
+        pad_mask[:, 0] = False
+        pad_mask = pad_mask.unsqueeze(1)
+
+        return pad_mask
+
+    def order_mask(self, text):
+        t = text.size(1)
+        order_mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
+        order_mask = order_mask.unsqueeze(0).to(text.device)
+
+        return order_mask
+
+    def text_embedding(self, texts):
+        tgt = self.embedding(texts)
+        tgt *= math.sqrt(tgt.size(2))
+
+        return tgt
+
+    def forward(self, feats, texts):
+        src = feats[self.src_from]
+        if self.src_mask_from:
+            src_mask = feats[self.src_mask_from]
+        else:
+            src_mask = None
+
+        if self.training:
+            tgt = self.text_embedding(texts)
+            tgt_mask = (self.pad_mask(texts) | self.order_mask(texts))
+
+            out = self.decoder(tgt, src, tgt_mask, src_mask)
+            out = self.generator(out)
+        else:
+            out = None
+            for _ in range(self.num_steps):
+                tgt = self.text_embedding(texts)
+                tgt_mask = self.order_mask(texts)
+                out = self.decoder(tgt, src, tgt_mask, src_mask)
+                out = self.generator(out)
+                next_text = torch.argmax(out[:, -1:, :], dim=-1)
+
+                texts = torch.cat([texts, next_text], dim=-1)
+
+        return out
diff --git a/vedastr/models/utils/__init__.py b/vedastr/models/utils/__init__.py
index 6a0dd30..95ad12c 100644
--- a/vedastr/models/utils/__init__.py
+++ b/vedastr/models/utils/__init__.py
@@ -2,3 +2,4 @@
 from .fc_module import FCModule, FCModules
 from .upsample import Upsample
 from .builder import build_module, build_torch_nn
+from .norm import build_norm_layer
diff --git a/vedastr/models/utils/conv_module.py b/vedastr/models/utils/conv_module.py
index 8306acc..4b354cf 100644
--- a/vedastr/models/utils/conv_module.py
+++ b/vedastr/models/utils/conv_module.py
@@ -141,7 +141,7 @@ def __init__(self,
             self.activate = nn.Tanh()
 
         if self.with_dropout:
-            self.dropout = nn.Dropout2d(p=dropout)
+            self.dropout = nn.Dropout(p=dropout)
 
     @property
     def norm(self):
diff --git a/vedastr/models/weight_init.py b/vedastr/models/weight_init.py
index 6077db0..fbb6362 100644
--- a/vedastr/models/weight_init.py
+++ b/vedastr/models/weight_init.py
@@ -92,6 +92,6 @@ def init_weights(modules):
         elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
             constant_init(m, 1)
         elif isinstance(m, nn.Linear):
-            kaiming_init(m)
+            xavier_init(m)
         elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
             kaiming_init(m, is_rnn=True)
diff --git a/vedastr/runner/runner.py b/vedastr/runner/runner.py
index 16fca00..75879d1 100644
--- a/vedastr/runner/runner.py
+++ b/vedastr/runner/runner.py
@@ -1,9 +1,11 @@
-import torch
 import logging
 import os.path as osp
-import torch.nn.functional as F
-import numpy as np
 from collections.abc import Iterable
+import re
+
+import numpy as np
+import torch
+from torch.nn import functional as F
 
 from vedastr.utils.checkpoint import load_checkpoint, save_checkpoint
@@ -30,7 +32,6 @@ def __init__(self,
                  lr_scheduler,
                  iterations,
                  workdir,
-                 start_iters=0,
                  trainval_ratio=1,
                  snapshot_interval=1,
                  gpu=True,
@@ -45,7 +46,6 @@ def __init__(self,
         self.metric = metric
         self.optim = optim
         self.lr_scheduler = lr_scheduler
-        self.start_iters = start_iters
         self.iterations = iterations
         self.workdir = workdir
         self.trainval_ratio = trainval_ratio
@@ -65,25 +65,22 @@ def __call__(self):
         else:
             self.metric.reset()
         logger.info('Start train...')
-        for iteration in range(self.start_iters, self.iterations):
+        for iteration in range(self.iterations):
             img, label = self.loader['train'].get_batch
             self.train_batch(img, label)
             if self.lr_scheduler:
                 self.lr_scheduler.step()
-                self.c_iter = self.iter + 1
+                self.c_iter = self.iter
             else:
-                self.c_iter = iteration + 1
+                self.c_iter = iteration
             if self.trainval_ratio > 0 \
                     and (iteration + 1) % self.trainval_ratio == 0 \
                     and self.loader.get('val'):
                 self.validate_epoch()
                 self.metric.reset()
-            if (iteration + 1) % self.snapshot_interval == 0:
-                self.save_model(out_dir=self.workdir,
-                                filename=f'iter{iteration + 1}.pth',
-                                iteration=iteration,
-                                )
+            if (iteration + 1) % self.snapshot_interval == 0:
+                self.save_model(out_dir=self.workdir, filename=f'iter{iteration + 1}.pth', iteration=iteration)
 
     def validate_epoch(self):
         logger.info('Iteration %d, Start validating' % self.c_iter)
@@ -92,17 +89,12 @@ def validate_epoch(self):
             self.validate_batch(img, label)
         if self.metric.avg['acc']['true'] >= self.best_acc:
             self.best_acc = self.metric.avg['acc']['true']
-            self.save_model(out_dir=self.workdir,
-                            filename='best_acc.pth',
-                            iteration=self.c_iter)
+            self.save_model(out_dir=self.workdir, filename='best_acc.pth', iteration=self.c_iter)
         if self.metric.avg['edit'] >= self.best_norm:
             self.best_norm = self.metric.avg['edit']
-            self.save_model(out_dir=self.workdir,
-                            filename='best_norm.pth',
-                            iteration=self.c_iter)
+            self.save_model(out_dir=self.workdir, filename='best_norm.pth', iteration=self.c_iter)
         logger.info('Validate, best_acc %.4f, best_edit %s' % (self.best_acc, self.best_norm))
-        logger.info('Validate, acc %.4f, edit %s' % (self.metric.avg['acc']['true'],
-                                                     self.metric.avg['edit']))
+        logger.info('Validate, acc %.4f, edit %s' % (self.metric.avg['acc']['true'], self.metric.avg['edit']))
         logger.info(f'\n{self.metric.predict_example_log}')
 
     def test_epoch(self):
@@ -115,6 +107,35 @@
         logger.info('Test, acc %.4f, edit %s' % (self.metric.avg['acc']['true'],
                                                  self.metric.avg['edit']))
 
+    def postprocess(self, preds, test_cfg=None):
+        if test_cfg is not None:
+            sensitive = test_cfg.get('sensitive', True)
+            character = test_cfg.get('character', '')
+        else:
+            sensitive = True
+            character = ''
+
+        probs = F.softmax(preds, dim=2)
+        max_probs, indexes = probs.max(dim=2)
+        preds_str = []
+        preds_prob = []
+        for i, pstr in enumerate(self.converter.decode(indexes)):
+            str_len = len(pstr)
+            if str_len == 0:
+                prob = 0
+            else:
+                prob = max_probs[i, :str_len].cumprod(dim=0)[-1]
+            preds_prob.append(prob)
+
+            if test_cfg:
+                if not sensitive:
+                    pstr = pstr.lower()
+                if character:
+                    pstr = re.sub('[^{}]'.format(character), '', pstr)
+            preds_str.append(pstr)
+
+        return preds_str, preds_prob
+
 
     def train_batch(self, img, label):
         self.model.train()
@@ -137,13 +158,11 @@ def train_batch(self, img, label):
             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
         self.optim.step()
 
-        preds_prob = F.softmax(pred, dim=2)
-        preds_prob, pred_index = preds_prob.max(dim=2)
-        pred_str = self.converter.decode(pred_index)
-
-        self.metric.measure(pred_str, label, preds_prob)
+        with torch.no_grad():
+            pred, prob = self.postprocess(pred)
+            self.metric.measure(pred, prob, label)
 
-        if self.c_iter % 10 == 0:
+        if self.c_iter != 0 and self.c_iter % 10 == 0:
             logger.info(
                 'Train, Iter %d, LR %s, Loss %.4f, acc %.4f, edit_distance %s' %
                 (self.c_iter, self.lr, loss.item(), self.metric.avg['acc']['true'],
@@ -162,11 +181,9 @@ def validate_batch(self, img, label):
                 pred = self.model(img, label_input)
             else:
                 pred = self.model(img)
-            preds_prob = F.softmax(pred, dim=2)
-            preds_prob, pred_index = preds_prob.max(dim=2)
-            pred_str = self.converter.decode(pred_index)
-            self.metric.measure(pred_str, label, preds_prob)
+            pred, prob = self.postprocess(pred, self.test_cfg)
+            self.metric.measure(pred, prob, label)
 
     def test_batch(self, img, label):
         self.model.eval()
@@ -180,11 +197,9 @@ def test_batch(self, img, label):
                 pred = self.model(img, label_input)
             else:
                 pred = self.model(img)
-            preds_prob = F.softmax(pred, dim=2)
-            preds_prob, pred_index = preds_prob.max(dim=2)
-            pred_str = self.converter.decode(pred_index)
-            self.metric.measure(pred_str, label, preds_prob)
+            pred, prob = self.postprocess(pred, self.test_cfg)
+            self.metric.measure(pred, prob, label)
 
     def save_model(self,
                    out_dir,
@@ -193,9 +208,9 @@ def save_model(self,
                    save_optimizer=True,
                    meta=None):
         if meta is None:
-            meta = dict(iter=iteration + 1, lr=self.lr, iters=self.iterations)
+            meta = dict(iter=iteration, lr=self.lr)
         else:
-            meta.update(iter=iteration + 1, lr=self.lr, iters=self.iterations)
+            meta.update(iter=iteration, lr=self.lr)
 
         filepath = osp.join(out_dir, filename)
         optimizer = self.optim if save_optimizer else None
@@ -235,8 +250,6 @@ def lr(self, val):
 
     def resume(self,
                checkpoint,
-               resume_lr=True,
-               resume_iters=True,
               resume_optimizer=False,
               map_location='default'):
         if map_location == 'default':
@@ -248,10 +261,3 @@ def resume(self,
         checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)
         if 'optimizer' in checkpoint and resume_optimizer:
             self.optim.load_state_dict(checkpoint['optimizer'])
-        if 'meta' in checkpoint and resume_iters:
-            self.iterations = checkpoint['meta']['iters']
-            self.start_iters = checkpoint['meta']['iter']
-            self.iter = checkpoint['meta']['iter']
-            self.c_iter = self.start_iters + 1
-        if 'meta' in checkpoint and resume_lr:
-            self.lr = checkpoint['meta']['lr']
diff --git a/vedastr/utils/metrics.py b/vedastr/utils/metrics.py
index a82758b..657ad3a 100644
--- a/vedastr/utils/metrics.py
+++ b/vedastr/utils/metrics.py
@@ -1,40 +1,29 @@
 # modify from clovaai
-import torch.nn.functional as F
 from nltk.metrics.distance import edit_distance
 
 
 class STRMeters(object):
-    def __init__(self, converter):
+    def __init__(self):
         self.reset()
-        self.converter = converter
         self.predict_example_log = None
-        self.sample = []
 
-    def measure(self, pred, gt, probs):
+    def measure(self, preds, preds_prob, gts):
+        batch_size = len(gts)
         true_num = 0
         norm_ED = 0
-        batch_size = len(gt)
-        sample_list = []
-        confidence_list = []
 
-        for pstr, gstr, pred_prob in zip(pred, gt, probs):
-            try:
-                confidence_score = pred[:len(pstr)].cumprod(dim=0)[-1]
-            except:
-                confidence_score = 0  # for empty pred case, when prune after "end of sentence" token ([s])
"end of sentence" token ([s]) - sample_list.append([pstr, gstr, pred_prob]) + for pstr, gstr in zip(preds, gts): if pstr == gstr: true_num += 1 + if len(pstr) == 0 or len(gstr) == 0: norm_ED += 0 elif len(gstr) > len(pstr): norm_ED += 1 - edit_distance(pstr, gstr) / len(gstr) else: norm_ED += 1 - edit_distance(pstr, gstr) / len(pstr) - confidence_list.append(confidence_score) - self.show_example(pred, gt, confidence_list) - self.sample = sample_list + self.show_example(preds, preds_prob, gts) self.all['acc']['true'] += true_num self.all['acc']['false'] += (batch_size - true_num) self.all['edit'] += norm_ED @@ -60,17 +49,17 @@ def reset(self): ) self.count = 0 - def show_example(self, preds, labels, confidence_score): + def show_example(self, preds, preds_prob, gts): count = 0 self.predict_example_log = None dashed_line = '-' * 80 head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' self.predict_example_log = f'{dashed_line}\n{head}\n{dashed_line}\n' - for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): - - self.predict_example_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' + for gt, pred, prob in zip(gts[:5], preds[:5], preds_prob[:5]): + self.predict_example_log += f'{gt:25s} | {pred:25s} | {prob:0.4f}\t{str(pred == gt)}\n' count += 1 if count > 4: break self.predict_example_log += f'{dashed_line}' + return self.predict_example_log