From cafb64bc611a59b48ab39601b044a0ee544010ed Mon Sep 17 00:00:00 2001 From: Olivier Dulcy Date: Thu, 1 Feb 2024 23:19:57 +0100 Subject: [PATCH] from https://github.com/mindee/doctr/pull/1444 --- .../differentiable_binarization/pytorch.py | 40 +++++++------- references/detection/train_pytorch.py | 38 ++++++++------ references/detection/train_tensorflow.py | 52 ++++++++++++------- 3 files changed, 77 insertions(+), 53 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index f11408bd3d..e3022f0a1b 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -57,24 +57,28 @@ def __init__( conv_layer = DeformConv2d if deform_conv else nn.Conv2d - self.in_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(chans, out_channels, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.in_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(chans, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.out_branches = nn.ModuleList([ - nn.Sequential( - conv_layer(out_channels, out_chans, 3, padding=1, bias=False), - nn.BatchNorm2d(out_chans), - nn.ReLU(inplace=True), - nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), - ) - for idx, chans in enumerate(in_channels) - ]) + self.out_branches = nn.ModuleList( + [ + nn.Sequential( + conv_layer(out_channels, out_chans, 3, padding=1, bias=False), + nn.BatchNorm2d(out_chans), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True), + ) + for idx, chans in enumerate(in_channels) + ] + ) def forward(self, x: List[torch.Tensor]) -> torch.Tensor: if len(x) != len(self.out_branches): @@ -268,7 +272,7 @@ def compute_loss( dice_map = torch.softmax(out_map, dim=1) else: # compute binary map instead - dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) + dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment] # Class reduced inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3)) cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3)) diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 6f01f6e9e3..c7914beaff 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -17,7 +17,7 @@ import psutil import torch import wandb -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort from tqdm.auto import tqdm @@ -266,15 +266,17 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - T.RandomApply(T.RandomShadow(), 0.1), - T.RandomApply(GaussianBlur(kernel_size=3), 0.1), - RandomPhotometricDistort(p=0.05), - RandomGrayscale(p=0.05), - ]), + img_transforms=Compose( + [ + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + T.RandomApply(T.RandomShadow(), 0.1), + T.RandomApply(GaussianBlur(kernel_size=3), 0.1), + RandomPhotometricDistort(p=0.05), + RandomGrayscale(p=0.05), + ] + ), sample_transforms=T.SampleCompose( ( [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] @@ -335,6 +337,8 @@ def main(args): scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) elif args.sched == "onecycle": scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + elif args.sched == "poly": + scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") @@ -385,12 +389,14 @@ def main(args): print(log_msg) # W&B if args.wb: - wandb.log({ - "val_loss": val_loss, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) + wandb.log( + { + "val_loss": val_loss, + "recall": recall, + "precision": precision, + "mean_iou": mean_iou, + } + ) if args.early_stop and early_stopper.early_stop(val_loss): print("Training halted early due to reaching patience limit.") break diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 55ea755edf..14475aff4d 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -221,18 +221,20 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=T.Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - T.RandomJpegQuality(60), - #T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - #T.RandomApply(T.RandomShadow(), 0.1), - #T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), - T.RandomSaturation(0.3), - T.RandomContrast(0.3), - T.RandomBrightness(0.3), - T.RandomApply(T.ToGray(num_output_channels=3), 0.05), - ]), + img_transforms=T.Compose( + [ + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + T.RandomJpegQuality(60), + # T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + # T.RandomApply(T.RandomShadow(), 0.1), + # T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), + T.RandomSaturation(0.3), + T.RandomContrast(0.3), + T.RandomBrightness(0.3), + T.RandomApply(T.ToGray(num_output_channels=3), 0.05), + ] + ), sample_transforms=T.SampleCompose( ( [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] @@ -270,14 +272,25 @@ def main(args): plot_samples(x, target) return + # Scheduler + if args.sched == "exponential": + scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + decay_rate=1 / (25e4), # final lr as a fraction of initial lr + staircase=False, + name="ExponentialDecay", + ) + elif args.sched == "poly": + scheduler = tf.keras.optimizers.schedules.PolynomialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + end_learning_rate=1e-7, + power=1.0, + cycle=False, + name="PolynomialDecay", + ) # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( - args.lr, - decay_steps=args.epochs * len(train_loader), - decay_rate=1 / (25e4), # final lr as a fraction of initial lr - staircase=False, - name="ExponentialDecay", - ) optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) @@ -413,6 +426,7 @@ def parse_args(): action="store_true", help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) + parser.add_argument("--sched", type=str, default="exponential", help="scheduler to use") parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")