
Commit cafb64b

odulcy-mindee committed Feb 1, 2024
1 parent b3c0ded commit cafb64b
Showing 3 changed files with 77 additions and 53 deletions.
40 changes: 22 additions & 18 deletions doctr/models/detection/differentiable_binarization/pytorch.py
@@ -57,24 +57,28 @@ def __init__(
 
         conv_layer = DeformConv2d if deform_conv else nn.Conv2d
 
-        self.in_branches = nn.ModuleList([
-            nn.Sequential(
-                conv_layer(chans, out_channels, 1, bias=False),
-                nn.BatchNorm2d(out_channels),
-                nn.ReLU(inplace=True),
-            )
-            for idx, chans in enumerate(in_channels)
-        ])
+        self.in_branches = nn.ModuleList(
+            [
+                nn.Sequential(
+                    conv_layer(chans, out_channels, 1, bias=False),
+                    nn.BatchNorm2d(out_channels),
+                    nn.ReLU(inplace=True),
+                )
+                for idx, chans in enumerate(in_channels)
+            ]
+        )
         self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-        self.out_branches = nn.ModuleList([
-            nn.Sequential(
-                conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
-                nn.BatchNorm2d(out_chans),
-                nn.ReLU(inplace=True),
-                nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
-            )
-            for idx, chans in enumerate(in_channels)
-        ])
+        self.out_branches = nn.ModuleList(
+            [
+                nn.Sequential(
+                    conv_layer(out_channels, out_chans, 3, padding=1, bias=False),
+                    nn.BatchNorm2d(out_chans),
+                    nn.ReLU(inplace=True),
+                    nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True),
+                )
+                for idx, chans in enumerate(in_channels)
+            ]
+        )
 
     def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if len(x) != len(self.out_branches):
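A side note on the `2**idx` factors in the out branches above: each branch upsamples a progressively deeper, lower-resolution feature map back to the resolution of the shallowest level so the pyramid outputs can be stacked. A minimal, self-contained sketch of that alignment (the shapes and channel counts below are made up for illustration and are not part of the diff):

import torch
from torch import nn

# Three feature maps as a backbone might emit them, e.g. strides 8/16/32 of a 256x256 input.
feats = [torch.rand(1, 64, 32, 32), torch.rand(1, 64, 16, 16), torch.rand(1, 64, 8, 8)]
upsamplers = [nn.Upsample(scale_factor=2**idx, mode="bilinear", align_corners=True) for idx in range(len(feats))]
aligned = [up(f) for up, f in zip(upsamplers, feats)]
print([tuple(t.shape) for t in aligned])  # every level now has spatial size 32x32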
@@ -268,7 +272,7 @@ def compute_loss(
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
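The expression touched above is DB's approximate binarization, roughly B = 1 / (1 + exp(-k * (P - T))) with amplification factor k = 50: it behaves like thresholding the probability map P against the learned threshold map T, but stays differentiable so the resulting map can enter the dice loss. A minimal sketch (the helper name is hypothetical, not part of the codebase):

import torch

def approx_binary_map(prob_map: torch.Tensor, thresh_map: torch.Tensor, k: float = 50.0) -> torch.Tensor:
    # 1 / (1 + exp(-k * (P - T))) is a steep sigmoid of (P - T), so gradients
    # can flow through the "binarization" during training.
    return torch.sigmoid(k * (prob_map - thresh_map))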
38 changes: 22 additions & 16 deletions references/detection/train_pytorch.py
@@ -17,7 +17,7 @@
 import psutil
 import torch
 import wandb
-from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
 from tqdm.auto import tqdm
@@ -266,15 +266,17 @@ def main(args):
     train_set = DetectionDataset(
         img_folder=os.path.join(args.train_path, "images"),
         label_path=os.path.join(args.train_path, "labels.json"),
-        img_transforms=Compose([
-            # Augmentations
-            T.RandomApply(T.ColorInversion(), 0.1),
-            T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
-            T.RandomApply(T.RandomShadow(), 0.1),
-            T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
-            RandomPhotometricDistort(p=0.05),
-            RandomGrayscale(p=0.05),
-        ]),
+        img_transforms=Compose(
+            [
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+                T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
+                T.RandomApply(T.RandomShadow(), 0.1),
+                T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
+                RandomPhotometricDistort(p=0.05),
+                RandomGrayscale(p=0.05),
+            ]
+        ),
         sample_transforms=T.SampleCompose(
             (
                 [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
@@ -335,6 +337,8 @@ def main(args):
         scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
     elif args.sched == "onecycle":
         scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
+    elif args.sched == "poly":
+        scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))
 
     # Training monitoring
     current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
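For reference, PolynomialLR defaults to power=1.0, so the new `poly` option should decay the learning rate linearly from args.lr down to 0 over args.epochs * len(train_loader) scheduler steps. A small standalone sketch of the schedule shape (toy parameter, optimizer, and step count, purely illustrative):

import torch
from torch.optim.lr_scheduler import PolynomialLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = PolynomialLR(optimizer, total_iters=10)  # power defaults to 1.0, i.e. linear decay

for _ in range(10):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # shrinks linearly, reaching 0 after 10 steps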
@@ -385,12 +389,14 @@ def main(args):
         print(log_msg)
         # W&B
         if args.wb:
-            wandb.log({
-                "val_loss": val_loss,
-                "recall": recall,
-                "precision": precision,
-                "mean_iou": mean_iou,
-            })
+            wandb.log(
+                {
+                    "val_loss": val_loss,
+                    "recall": recall,
+                    "precision": precision,
+                    "mean_iou": mean_iou,
+                }
+            )
         if args.early_stop and early_stopper.early_stop(val_loss):
             print("Training halted early due to reaching patience limit.")
             break
52 changes: 33 additions & 19 deletions references/detection/train_tensorflow.py
@@ -221,18 +221,20 @@ def main(args):
     train_set = DetectionDataset(
         img_folder=os.path.join(args.train_path, "images"),
         label_path=os.path.join(args.train_path, "labels.json"),
-        img_transforms=T.Compose([
-            # Augmentations
-            T.RandomApply(T.ColorInversion(), 0.1),
-            T.RandomJpegQuality(60),
-            #T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
-            #T.RandomApply(T.RandomShadow(), 0.1),
-            #T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
-            T.RandomSaturation(0.3),
-            T.RandomContrast(0.3),
-            T.RandomBrightness(0.3),
-            T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
-        ]),
+        img_transforms=T.Compose(
+            [
+                # Augmentations
+                T.RandomApply(T.ColorInversion(), 0.1),
+                T.RandomJpegQuality(60),
+                # T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
+                # T.RandomApply(T.RandomShadow(), 0.1),
+                # T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
+                T.RandomSaturation(0.3),
+                T.RandomContrast(0.3),
+                T.RandomBrightness(0.3),
+                T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
+            ]
+        ),
         sample_transforms=T.SampleCompose(
             (
                 [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
@@ -270,14 +272,25 @@ def main(args):
         plot_samples(x, target)
         return
 
+    # Scheduler
+    if args.sched == "exponential":
+        scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
+            args.lr,
+            decay_steps=args.epochs * len(train_loader),
+            decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
+            staircase=False,
+            name="ExponentialDecay",
+        )
+    elif args.sched == "poly":
+        scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
+            args.lr,
+            decay_steps=args.epochs * len(train_loader),
+            end_learning_rate=1e-7,
+            power=1.0,
+            cycle=False,
+            name="PolynomialDecay",
+        )
     # Optimizer
-    scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
-        args.lr,
-        decay_steps=args.epochs * len(train_loader),
-        decay_rate=1 / (25e4),  # final lr as a fraction of initial lr
-        staircase=False,
-        name="ExponentialDecay",
-    )
     optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5)
     if args.amp:
         optimizer = mixed_precision.LossScaleOptimizer(optimizer)
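On the TensorFlow side, PolynomialDecay with power=1.0 and cycle=False is a straight-line decay from args.lr to end_learning_rate over decay_steps, after which the rate stays at end_learning_rate. A quick illustrative sketch (the step count is made up):

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-3,
    decay_steps=1000,  # stands in for args.epochs * len(train_loader)
    end_learning_rate=1e-7,
    power=1.0,
    cycle=False,
)
# The schedule object maps a global step to a learning rate:
print(float(schedule(0)), float(schedule(500)), float(schedule(1000)))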
@@ -413,6 +426,7 @@ def parse_args():
         action="store_true",
         help="metrics evaluation with straight boxes instead of polygons to save time + memory",
     )
+    parser.add_argument("--sched", type=str, default="exponential", help="scheduler to use")
     parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
     parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
     parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
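With this flag in place, the TensorFlow reference script can presumably be launched with `--sched poly` to pick the new schedule, mirroring the existing `--sched` option of the PyTorch script; the default remains `exponential`.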
