From fb61c3fbdbcfad0bca4232be45dd5e3ff00d5bc4 Mon Sep 17 00:00:00 2001 From: Martin Zlocha Date: Wed, 24 Jul 2019 19:41:43 -0700 Subject: [PATCH 1/2] Add the option to seed training. --- CONTRIBUTORS.md | 1 + keras_retinanet/bin/train.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8b1863f4c..c3ca2993f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -36,3 +36,4 @@ please create an issue or pull request at https://github.com/fizyr/keras-retinan * Eduardo Ramos * DiegoAgher * Alexander Pacha +* Martin Zlocha \ No newline at end of file diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index 1d1431a25..c5bc5db5e 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -21,9 +21,11 @@ import sys import warnings +import random import keras import keras.preprocessing.image import tensorflow as tf +import numpy as np # Allow relative imports when being executed as script. if __name__ == "__main__" and __package__ is None: @@ -429,6 +431,7 @@ def csv_list(string): parser.add_argument('--config', help='Path to a configuration parameters .ini file.') parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true') parser.add_argument('--compute-val-loss', help='Compute validation loss during training', dest='compute_val_loss', action='store_true') + parser.add_argument('--seed', help='Seed value to use for training.') # Fit generator arguments parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true') @@ -444,6 +447,11 @@ def main(args=None): args = sys.argv[1:] args = parse_args(args) + if args.seed: + np.random.seed(args.seed) + tf.set_random_seed(args.seed) + random.seed(args.seed) + # create object that stores backbone information backbone = models.backbone(args.backbone) From 6928f9494e2a7a39112d2614d3ac61ae571c12eb Mon Sep 17 00:00:00 2001 From: Martin Zlocha Date: Fri, 9 Aug 2019 16:40:27 +0200 Subject: [PATCH 2/2] Pass seed to random_transform_generator. --- keras_retinanet/bin/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index c5bc5db5e..c6e33c23d 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -225,6 +225,11 @@ def create_generators(args, preprocess_image): 'preprocess_image' : preprocess_image, } + prng = None + + if args.seed: + prng = np.random.RandomState(args.seed) + # create random transform generator for augmenting training data if args.random_transform: transform_generator = random_transform_generator( @@ -238,6 +243,7 @@ def create_generators(args, preprocess_image): max_scaling=(1.1, 1.1), flip_x_chance=0.5, flip_y_chance=0.5, + prng=prng ) visual_effect_generator = random_visual_effect_generator( contrast_range=(0.9, 1.1), @@ -246,7 +252,7 @@ def create_generators(args, preprocess_image): saturation_range=(0.95, 1.05) ) else: - transform_generator = random_transform_generator(flip_x_chance=0.5) + transform_generator = random_transform_generator(flip_x_chance=0.5, prng=prng) visual_effect_generator = None if args.dataset_type == 'coco':