Skip to content

Commit

Permalink
Problems with augmentations involving _gaussian_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee committed Jan 16, 2024
1 parent ea17ffb commit 8318940
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,11 @@ def main(args):
with open(os.path.join(args.val_path, "labels.json"), "rb") as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

batch_transforms = T.Compose([
T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
])
batch_transforms = T.Compose(
[
T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)),
]
)

# Load doctr model
model = detection.__dict__[args.arch](
Expand Down Expand Up @@ -219,18 +221,20 @@ def main(args):
train_set = DetectionDataset(
img_folder=os.path.join(args.train_path, "images"),
label_path=os.path.join(args.train_path, "labels.json"),
img_transforms=T.Compose([
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.4),
T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
]),
img_transforms=T.Compose(
[
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
# T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
# T.RandomApply(T.RandomShadow(), 0.4),
# T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
]
),
sample_transforms=T.SampleCompose(
(
[T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)]
Expand Down Expand Up @@ -342,12 +346,14 @@ def main(args):
print(log_msg)
# W&B
if args.wb:
wandb.log({
"val_loss": val_loss,
"recall": recall,
"precision": precision,
"mean_iou": mean_iou,
})
wandb.log(
{
"val_loss": val_loss,
"recall": recall,
"precision": precision,
"mean_iou": mean_iou,
}
)

# ClearML
if args.clearml:
Expand Down

0 comments on commit 8318940

Please sign in to comment.