diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index a4652affe5..55ea755edf 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -178,9 +178,11 @@ def main(args): with open(os.path.join(args.val_path, "labels.json"), "rb") as f: val_hash = hashlib.sha256(f.read()).hexdigest() - batch_transforms = T.Compose([ - T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), - ]) + batch_transforms = T.Compose( + [ + T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), + ] + ) # Load doctr model model = detection.__dict__[args.arch]( @@ -223,9 +225,9 @@ def main(args): # Augmentations T.RandomApply(T.ColorInversion(), 0.1), T.RandomJpegQuality(60), - T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), - T.RandomApply(T.RandomShadow(), 0.1), - T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), + #T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1), + #T.RandomApply(T.RandomShadow(), 0.1), + #T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1), T.RandomSaturation(0.3), T.RandomContrast(0.3), T.RandomBrightness(0.3), @@ -342,12 +344,14 @@ def main(args): print(log_msg) # W&B if args.wb: - wandb.log({ - "val_loss": val_loss, - "recall": recall, - "precision": precision, - "mean_iou": mean_iou, - }) + wandb.log( + { + "val_loss": val_loss, + "recall": recall, + "precision": precision, + "mean_iou": mean_iou, + } + ) # ClearML if args.clearml: