diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8b1863f4c..c3ca2993f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -36,3 +36,4 @@ please create an issue or pull request at https://github.com/fizyr/keras-retinan * Eduardo Ramos * DiegoAgher * Alexander Pacha +* Martin Zlocha \ No newline at end of file diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index 1d1431a25..c6e33c23d 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -21,9 +21,11 @@ import sys import warnings +import random import keras import keras.preprocessing.image import tensorflow as tf +import numpy as np # Allow relative imports when being executed as script. if __name__ == "__main__" and __package__ is None: @@ -223,6 +225,11 @@ def create_generators(args, preprocess_image): 'preprocess_image' : preprocess_image, } + prng = None + + if args.seed: + prng = np.random.RandomState(args.seed) + # create random transform generator for augmenting training data if args.random_transform: transform_generator = random_transform_generator( @@ -236,6 +243,7 @@ def create_generators(args, preprocess_image): max_scaling=(1.1, 1.1), flip_x_chance=0.5, flip_y_chance=0.5, + prng=prng ) visual_effect_generator = random_visual_effect_generator( contrast_range=(0.9, 1.1), @@ -244,7 +252,7 @@ def create_generators(args, preprocess_image): saturation_range=(0.95, 1.05) ) else: - transform_generator = random_transform_generator(flip_x_chance=0.5) + transform_generator = random_transform_generator(flip_x_chance=0.5, prng=prng) visual_effect_generator = None if args.dataset_type == 'coco': @@ -429,6 +437,7 @@ def csv_list(string): parser.add_argument('--config', help='Path to a configuration parameters .ini file.') parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true') parser.add_argument('--compute-val-loss', help='Compute validation loss during training', dest='compute_val_loss', action='store_true') + parser.add_argument('--seed', help='Seed value to use for training.') # Fit generator arguments parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true') @@ -444,6 +453,11 @@ def main(args=None): args = sys.argv[1:] args = parse_args(args) + if args.seed: + np.random.seed(args.seed) + tf.set_random_seed(args.seed) + random.seed(args.seed) + # create object that stores backbone information backbone = models.backbone(args.backbone)