Skip to content

Commit

Permalink
update train scripts (#1328)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 26, 2023
1 parent f0ea666 commit 3271460
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 4 deletions.
2 changes: 1 addition & 1 deletion references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def main(args):
# Backbone freezing
if args.freeze_backbone:
for p in model.feat_extractor.parameters():
- p.reguires_grad_(False)
+ p.requires_grad = False

# Optimizer
optimizer = torch.optim.Adam(
Expand Down
8 changes: 8 additions & 0 deletions references/recognition/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ def main(args):
plot_samples(x, target)
return

+ # Backbone freezing
+ if args.freeze_backbone:
+ for p in model.feat_extractor.parameters():
+ p.requires_grad = False

# Optimizer
optimizer = torch.optim.Adam(
[p for p in model.parameters() if p.requires_grad],
Expand Down Expand Up @@ -457,6 +462,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+ parser.add_argument(
+ "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
+ )
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down
31 changes: 28 additions & 3 deletions references/recognition/train_pytorch_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
- from torchvision.transforms import ColorJitter, Compose, Normalize
+ from torchvision.transforms.v2 import (
+ Compose,
+ GaussianBlur,
+ Normalize,
+ RandomGrayscale,
+ RandomPerspective,
+ RandomPhotometricDistort,
+ )

from doctr import transforms as T
from doctr.datasets import VOCABS, RecognitionDataset, WordGenerator
Expand Down Expand Up @@ -170,6 +177,11 @@ def main(rank: int, world_size: int, args):
checkpoint = torch.load(args.resume, map_location="cpu")
model.load_state_dict(checkpoint)

+ # Backbone freezing
+ if args.freeze_backbone:
+ for p in model.feat_extractor.parameters():
+ p.requires_grad = False

# create default process group
device = torch.device("cuda", args.devices[rank])
dist.init_process_group(args.backend, rank=rank, world_size=world_size)
Expand Down Expand Up @@ -211,7 +223,12 @@ def main(rank: int, world_size: int, args):
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
- ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
+ RandomGrayscale(p=0.1),
+ RandomPhotometricDistort(p=0.1),
+ T.RandomApply(T.RandomShadow(), p=0.4),
+ T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
+ T.RandomApply(GaussianBlur(3), 0.3),
+ RandomPerspective(distortion_scale=0.2, p=0.3),
]
),
)
Expand All @@ -234,7 +251,12 @@ def main(rank: int, world_size: int, args):
T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True),
# Ensure we have a 90% split of white-background images
T.RandomApply(T.ColorInversion(), 0.9),
- ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
+ RandomGrayscale(p=0.1),
+ RandomPhotometricDistort(p=0.1),
+ T.RandomApply(T.RandomShadow(), p=0.4),
+ T.RandomApply(T.GaussianNoise(mean=0, std=0.1), 0.1),
+ T.RandomApply(GaussianBlur(3), 0.3),
+ RandomPerspective(distortion_scale=0.2, p=0.3),
]
),
)
Expand Down Expand Up @@ -376,6 +398,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+ parser.add_argument(
+ "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
+ )
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down
8 changes: 8 additions & 0 deletions references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def main(args):
task = Task.init(project_name="docTR/text-recognition", task_name=exp_name, reuse_last_task_id=False)
task.upload_artifact("config", config)

+ # Backbone freezing
+ if args.freeze_backbone:
+ for layer in model.feat_extractor.layers:
+ layer.trainable = False

min_loss = np.inf

# Training loop
Expand Down Expand Up @@ -413,6 +418,9 @@ def parse_args():
parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint")
parser.add_argument("--vocab", type=str, default="french", help="Vocab to be used for training")
parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop")
+ parser.add_argument(
+ "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning"
+ )
parser.add_argument(
"--show-samples", dest="show_samples", action="store_true", help="Display unormalized training samples"
)
Expand Down

0 comments on commit 3271460

Please sign in to comment.