diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index d6eff53464..6330425de0 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -178,9 +178,11 @@ def main(args): with open(os.path.join(args.val_path, "labels.json"), "rb") as f: val_hash = hashlib.sha256(f.read()).hexdigest() - batch_transforms = T.Compose([ - T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), - ]) + batch_transforms = T.Compose( + [ + T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), + ] + ) # Load doctr model model = detection.__dict__[args.arch]( @@ -219,18 +221,20 @@ def main(args): train_set = DetectionDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), - img_transforms=T.Compose([ - # Augmentations - T.RandomApply(T.ColorInversion(), 0.1), - T.RandomJpegQuality(60), - T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - T.RandomApply(T.RandomShadow(), 0.4), - T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3), - T.RandomSaturation(0.3), - T.RandomContrast(0.3), - T.RandomBrightness(0.3), - T.RandomApply(T.ToGray(num_output_channels=3), 0.1), - ]), + img_transforms=T.Compose( + [ + # Augmentations + T.RandomApply(T.ColorInversion(), 0.1), + T.RandomJpegQuality(60), + # T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + # T.RandomApply(T.RandomShadow(), 0.4), + # T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3), + T.RandomSaturation(0.3), + T.RandomContrast(0.3), + T.RandomBrightness(0.3), + T.RandomApply(T.ToGray(num_output_channels=3), 0.1), + ] + ), sample_transforms=T.SampleCompose( ( [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] @@ -342,12 +346,14 @@ def main(args): print(log_msg) # W&B if args.wb: - wandb.log({ - "val_loss": val_loss, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) + wandb.log( + { + "val_loss": val_loss, + "recall": recall, + "precision": precision, + "mean_iou": mean_iou, + } + ) # ClearML if args.clearml: