diff --git a/keras_retinanet/bin/evaluate_coco.py b/keras_retinanet/bin/evaluate_coco.py index cb0203184..fcc88efa0 100755 --- a/keras_retinanet/bin/evaluate_coco.py +++ b/keras_retinanet/bin/evaluate_coco.py @@ -72,14 +72,10 @@ def main(args=None): print('Loading model, this may take a second...') model = keras.models.load_model(args.model, custom_objects=custom_objects) - # create image data generator object - test_image_data_generator = keras.preprocessing.image.ImageDataGenerator() - # create a generator for testing data test_generator = CocoGenerator( args.coco_path, args.set, - test_image_data_generator, ) evaluate_coco(test_generator, model, args.score_threshold) diff --git a/keras_retinanet/bin/train.py b/keras_retinanet/bin/train.py index 06d87c494..0fdcdff36 100755 --- a/keras_retinanet/bin/train.py +++ b/keras_retinanet/bin/train.py @@ -38,6 +38,7 @@ from ..preprocessing.pascal_voc import PascalVocGenerator from ..preprocessing.csv_generator import CSVGenerator from ..models.resnet import resnet50_retinanet +from ..utils.transform import random_transform_generator from ..utils.keras_version import check_keras_version @@ -109,11 +110,8 @@ def create_callbacks(model, training_model, prediction_model, validation_generat def create_generators(args): - # create image data generator objects - train_image_data_generator = keras.preprocessing.image.ImageDataGenerator( - horizontal_flip=True, - ) - val_image_data_generator = keras.preprocessing.image.ImageDataGenerator() + # create random transform generator for augmenting training data + transform_generator = random_transform_generator(flip_x_chance=0.5) if args.dataset_type == 'coco': # import here to prevent unnecessary dependency on cocoapi @@ -122,35 +120,33 @@ def create_generators(args): train_generator = CocoGenerator( args.coco_path, 'train2017', - train_image_data_generator, + transform_generator=transform_generator, batch_size=args.batch_size ) validation_generator = CocoGenerator( args.coco_path, 'val2017', - val_image_data_generator, batch_size=args.batch_size ) elif args.dataset_type == 'pascal': train_generator = PascalVocGenerator( args.pascal_path, 'trainval', - train_image_data_generator, + transform_generator=transform_generator, batch_size=args.batch_size ) validation_generator = PascalVocGenerator( args.pascal_path, 'test', - val_image_data_generator, batch_size=args.batch_size ) elif args.dataset_type == 'csv': train_generator = CSVGenerator( args.annotations, args.classes, - train_image_data_generator, + transform_generator=transform_generator, batch_size=args.batch_size ) @@ -158,7 +154,6 @@ def create_generators(args): validation_generator = CSVGenerator( args.val_annotations, args.classes, - val_image_data_generator, batch_size=args.batch_size ) else: diff --git a/keras_retinanet/preprocessing/coco.py b/keras_retinanet/preprocessing/coco.py index 94e5195fb..fc0637100 100644 --- a/keras_retinanet/preprocessing/coco.py +++ b/keras_retinanet/preprocessing/coco.py @@ -24,7 +24,7 @@ class CocoGenerator(Generator): - def __init__(self, data_dir, set_name, image_data_generator, *args, **kwargs): + def __init__(self, data_dir, set_name, **kwargs): self.data_dir = data_dir self.set_name = set_name self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json')) @@ -32,7 +32,7 @@ def __init__(self, data_dir, set_name, image_data_generator, *args, **kwargs): self.load_classes() - super(CocoGenerator, self).__init__(image_data_generator, **kwargs) + super(CocoGenerator, self).__init__(**kwargs) def load_classes(self): # load class names (name -> label) diff --git a/keras_retinanet/preprocessing/csv_generator.py b/keras_retinanet/preprocessing/csv_generator.py index abb02c414..0e2c22df9 100644 --- a/keras_retinanet/preprocessing/csv_generator.py +++ b/keras_retinanet/preprocessing/csv_generator.py @@ -108,7 +108,6 @@ def __init__( self, csv_data_file, csv_class_file, - image_data_generator, base_dir=None, **kwargs ): @@ -139,7 +138,7 @@ def __init__( raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None) self.image_names = list(self.image_data.keys()) - super(CSVGenerator, self).__init__(image_data_generator, **kwargs) + super(CSVGenerator, self).__init__(**kwargs) def size(self): return len(self.image_names) diff --git a/keras_retinanet/preprocessing/generator.py b/keras_retinanet/preprocessing/generator.py index d94d86b59..8a7a95604 100644 --- a/keras_retinanet/preprocessing/generator.py +++ b/keras_retinanet/preprocessing/generator.py @@ -22,31 +22,35 @@ import keras -from ..utils.image import preprocess_image, resize_image, random_transform from ..utils.anchors import anchor_targets_bbox +from ..utils.image import ( + TransformParameters, + adjust_transform_for_image, + apply_transform, + preprocess_image, + resize_image, +) +from ..utils.transform import transform_aabb class Generator(object): def __init__( self, - image_data_generator, + transform_generator = None, batch_size=1, group_method='ratio', # one of 'none', 'random', 'ratio' shuffle_groups=True, image_min_side=600, image_max_side=1024, - seed=None + transform_parameters=None, ): - self.image_data_generator = image_data_generator + self.transform_generator = transform_generator self.batch_size = int(batch_size) self.group_method = group_method self.shuffle_groups = shuffle_groups self.image_min_side = image_min_side self.image_max_side = image_max_side - - if seed is None: - seed = np.uint32((time.time() % 1)) * 1000 - np.random.seed(seed) + self.transform_parameters = transform_parameters or TransformParameters() self.group_index = 0 self.lock = threading.Lock() @@ -112,19 +116,32 @@ def resize_image(self, image): def preprocess_image(self, image): return preprocess_image(image) - def preprocess_group(self, image_group, annotations_group): - for index, (image, annotations) in enumerate(zip(image_group, annotations_group)): - # preprocess the image (subtract imagenet mean) - image = self.preprocess_image(image) + def preprocess_group_entry(self, image, annotations): + # preprocess the image + image = self.preprocess_image(image) + + # randomly transform both image and annotations + if self.transform_generator: + transform = adjust_transform_for_image(next(self.transform_generator), image) + apply_transform(transform, image, self.transform_parameters) - # randomly transform both image and annotations - image, annotations = random_transform(image, annotations, self.image_data_generator) + # Transform the bounding boxes in the annotations. + annotations = annotations.copy() + for index in range(annotations.shape[0]): + annotations[index, :4] = transform_aabb(transform, annotations[index, :4]) - # resize image - image, image_scale = self.resize_image(image) + # resize image + image, image_scale = self.resize_image(image) - # apply resizing to annotations too - annotations[:, :4] *= image_scale + # apply resizing to annotations too + annotations[:, :4] *= image_scale + + return image, annotations + + def preprocess_group(self, image_group, annotations_group): + for index, (image, annotations) in enumerate(zip(image_group, annotations_group)): + # preprocess a single group entry + image, annotations = self.preprocess_group_entry(image, annotations) # copy processed data back to group image_group[index] = image diff --git a/keras_retinanet/preprocessing/pascal_voc.py b/keras_retinanet/preprocessing/pascal_voc.py index 8db5f8f14..e273aacb9 100644 --- a/keras_retinanet/preprocessing/pascal_voc.py +++ b/keras_retinanet/preprocessing/pascal_voc.py @@ -71,7 +71,6 @@ def __init__( self, data_dir, set_name, - image_data_generator, classes=voc_classes, image_extension='.jpg', skip_truncated=False, @@ -90,7 +89,7 @@ def __init__( for key, value in self.classes.items(): self.labels[value] = key - super(PascalVocGenerator, self).__init__(image_data_generator, **kwargs) + super(PascalVocGenerator, self).__init__(**kwargs) def size(self): return len(self.image_names) diff --git a/keras_retinanet/utils/image.py b/keras_retinanet/utils/image.py index c523e8c82..89e9c7a36 100644 --- a/keras_retinanet/utils/image.py +++ b/keras_retinanet/utils/image.py @@ -21,6 +21,8 @@ import cv2 import PIL +from .transform import change_transform_origin, transform_aabb, colvec + def read_image_bgr(path): image = np.asarray(PIL.Image.open(path).convert('RGB')) @@ -48,44 +50,61 @@ def preprocess_image(x): return x -def random_transform( - image, - boxes, - image_data_generator, - seed=None -): - if seed is None: - seed = np.random.randint(10000) +def adjust_transform_for_image(transform, image): + """ Adjust a transformation for a specific image. + + The translation of the matrix will be scaled with the size of the image. + The linear part of the transformation will adjusted so that the origin of the transformation will be at the center of the image. + """ + height, width, channels = image.shape - image = image_data_generator.random_transform(image, seed=seed) + # Move the origin of transformation. + result = change_transform_origin(transform, (0.5 * width, 0.5 * height)) - # set fill mode so that masks are not enlarged - fill_mode = image_data_generator.fill_mode - image_data_generator.fill_mode = 'constant' + # Scale the translation with the image size. + result[0:2, 2] *= [width, height] - for index in range(boxes.shape[0]): - # generate box mask and randomly transform it - mask = np.zeros_like(image, dtype=np.uint8) - b = boxes[index, :4].astype(int) + return result - assert(b[0] < b[2] and b[1] < b[3]), 'Annotations contain invalid box: {}'.format(b) - assert(b[2] <= image.shape[1] and b[3] <= image.shape[0]), 'Annotation ({}) is outside of image shape ({}).'.format(b, image.shape) - mask[b[1]:b[3], b[0]:b[2], :] = 255 - mask = image_data_generator.random_transform(mask, seed=seed)[..., 0] - mask = mask.copy() # to force contiguous arrays +class TransformParameters: + """ Struct holding parameters determining how to transform images. - # find bounding box again in augmented image - [i, j] = np.where(mask == 255) - boxes[index, 0] = float(min(j)) - boxes[index, 1] = float(min(i)) - boxes[index, 2] = float(max(j)) + 1 # set box to an open interval [min, max) - boxes[index, 3] = float(max(i)) + 1 # set box to an open interval [min, max) + # Arguments + fill_mode: Same as for keras.preprocessing.image.apply_transform + cval: Same as for keras.preprocessing.image.apply_transform + data_format: Same as for keras.preprocessing.image.apply_transform + """ + def __init__( + self, + fill_mode = 'nearest', + cval = 0, + data_format = None, + ): + self.fill_mode = fill_mode + self.cval = cval - # restore fill_mode - image_data_generator.fill_mode = fill_mode + if data_format is None: + data_format = keras.backend.image_data_format() + self.data_format = data_format - return image, boxes + if data_format == 'channels_first': + self.channel_axis = 0 + elif data_format == 'channels_last': + self.channel_axis = 2 + else: + raise ValueError("invalid data_format, expected 'channels_first' or 'channels_last', got '{}'".format(data_format)) + + +def apply_transform(transform, image, params): + """ Wrapper around keras.preprocessing.image.apply_transform using TransformParameters. """ + return keras.preprocessing.image.apply_transform( + image, + transform, + channel_axis=params.channel_axis, + fill_mode=params.fill_mode, + cval=params.cval + ) def resize_image(img, min_side=600, max_side=1024): diff --git a/keras_retinanet/utils/transform.py b/keras_retinanet/utils/transform.py new file mode 100644 index 000000000..96473600d --- /dev/null +++ b/keras_retinanet/utils/transform.py @@ -0,0 +1,241 @@ +import numpy as np + +DEFAULT_PRNG = np.random + + +def colvec(*args): + """ Create a numpy array representing a column vector. """ + return np.array([args]).T + + +def transform_aabb(transform, aabb): + """ Apply a transformation to an axis aligned bounding box. + + The result is a new AABB in the same coordinate system as the original AABB. + The new AABB contains all corner points of the original AABB after applying the given transformation. + + # Arguments + transform: The transormation to apply. + x1: The minimum X value of the AABB. + y1: The minimum y value of the AABB. + x2: The maximum X value of the AABB. + y2: The maximum y value of the AABB. + # Returns + The new AABB as tuple (x1, y1, x2, y2) + """ + x1, y1, x2, y2 = aabb + # Transform all 4 corners of the AABB. + points = transform.dot([ + [x1, x2, x1, x2], + [y1, y2, y2, y1], + [1, 1, 1, 1 ], + ]) + + # Extract the min and max corners again. + min_corner = points.min(axis=1) + max_corner = points.max(axis=1) + + return [min_corner[0], min_corner[1], max_corner[0], max_corner[1]] + + +def _random_vector(min, max, prng=DEFAULT_PRNG): + """ Construct a random vector between min and max. + # Arguments + min: the minimum value for each component + max: the maximum value for each component + """ + min = np.array(min) + max = np.array(max) + assert min.shape == max.shape + assert len(min.shape) == 1 + return prng.uniform(min, max) + + +def rotation(angle): + """ Construct a homogeneous 2D rotation matrix. + # Arguments + angle: the angle in radians + # Returns + the rotation matrix as 3 by 3 numpy array + """ + return np.array([ + [np.cos(angle), -np.sin(angle), 0], + [np.sin(angle), np.cos(angle), 0], + [0, 0, 1] + ]) + + +def random_rotation(min, max, prng=DEFAULT_PRNG): + """ Construct a random rotation between -max and max. + # Arguments + min: a scalar for the minumum absolute angle in radians + max: a scalar for the maximum absolute angle in radians + prng: the pseudo-random number generator to use. + # Returns + a homogeneous 3 by 3 rotation matrix + """ + return rotation(prng.uniform(min, max)) + + +def translation(translation): + """ Construct a homogeneous 2D translation matrix. + # Arguments + translation: the translation 2D vector + # Returns + the translation matrix as 3 by 3 numpy array + """ + return np.array([ + [1, 0, translation[0]], + [0, 1, translation[1]], + [0, 0, 1] + ]) + + +def random_translation(min, max, prng=DEFAULT_PRNG): + """ Construct a random 2D translation between min and max. + # Arguments + min: a 2D vector with the minumum translation for each dimension + max: a 2D vector with the maximum translation for each dimension + prng: the pseudo-random number generator to use. + # Returns + a homogeneous 3 by 3 translation matrix + """ + return translation(_random_vector(min, max, prng)) + + +def shear(amount): + """ Construct a homogeneous 2D shear matrix. + # Arguments + amount: the shear amount + # Returns + the shear matrix as 3 by 3 numpy array + """ + return np.array([ + [1, -np.sin(amount), 0], + [0, np.cos(amount), 0], + [0, 0, 1] + ]) + + +def random_shear(min, max, prng=DEFAULT_PRNG): + """ Construct a random 2D shear matrix with shear angle between -max and max. + # Arguments + min: the minumum shear factor. + max: the maximum shear factor. + prng: the pseudo-random number generator to use. + # Returns + a homogeneous 3 by 3 shear matrix + """ + return shear(prng.uniform(min, max)) + + +def scaling(factor): + """ Construct a homogeneous 2D scaling matrix. + # Arguments + factor: a 2D vector for X and Y scaling + # Returns + the zoom matrix as 3 by 3 numpy array + """ + return np.array([ + [factor[0], 0, 0], + [0, factor[1], 0], + [0, 0, 1] + ]) + + +def random_scaling(min, max, prng=DEFAULT_PRNG): + """ Construct a random 2D scale matrix between -max and max. + # Arguments + min: a 2D vector containing the minimum scaling factor for X and Y. + min: a 2D vector containing The maximum scaling factor for X and Y. + prng: the pseudo-random number generator to use. + # Returns + a homogeneous 3 by 3 scaling matrix + """ + return scaling(_random_vector(min, max, prng)) + + +def random_flip(flip_x_chance, flip_y_chance, prng=DEFAULT_PRNG): + """ Construct a transformation randomly containing X/Y flips (or not). + # Arguments + flip_x_chance: The chance that the result will contain a flip along the X axis. + flip_y_chance: The chance that the result will contain a flip along the Y axis. + prng: The pseudo-random number generator to use. + # Returns + a homogeneous 3 by 3 transformation matrix + """ + flip_x = prng.uniform(0, 1) < flip_x_chance + flip_y = prng.uniform(0, 1) < flip_y_chance + # 1 - 2 * bool gives 1 for False and -1 for True. + return scaling((1 - 2 * flip_x, 1 - 2 * flip_y)) + + +def change_transform_origin(transform, center): + """ Create a new transform with the origin at a different location. + # Arguments: + transform: the transformation matrix + center: the new origin of the transformation + # Return: + translate(center) * transform * translate(-center) + """ + center = np.array(center) + return np.dot(np.dot(translation(center), transform), translation(-center)) + + +def random_transform( + min_rotation=0, + max_rotation=0, + min_translation=(0, 0), + max_translation=(0, 0), + min_shear=0, + max_shear=0, + min_scaling=(1, 1), + max_scaling=(1, 1), + flip_x_chance=0, + flip_y_chance=0, + prng=DEFAULT_PRNG +): + """ Create a random transformation. + + The transformation consists of the following operations in this order (from left to right): + * rotation + * translation + * shear + * scaling + * flip x (if applied) + * flip y (if applied) + + # Arguments + min_rotation: The minimum rotation for the transform as scalar. + max_rotation: The maximum rotation for the transform as scalar. + min_translation: The minimum translation for the transform as 2D column vector. + max_translation: The maximum translation for the transform as 2D column vector. + min_shear: The minimum shear for the transform as scalar. + max_shear: The maximum shear for the transform as scalar. + min_scaling: The minimum scaling for the transform as 2D column vector. + max_scaling: The maximum scaling for the transform as 2D column vector. + flip_x_chance: The chance (0 to 1) that a transform will contain a flip along X direction. + flip_y_chance: The chance (0 to 1) that a transform will contain a flip along Y direction. + prng: The pseudo-random number generator to use. + """ + return np.linalg.multi_dot([ + random_rotation(min_rotation, max_rotation, prng), + random_translation(min_translation, max_translation, prng), + random_shear(min_shear, max_shear, prng), + random_scaling(min_scaling, max_scaling, prng), + random_flip(flip_x_chance, flip_x_chance) + ]) + + +def random_transform_generator(prng=None, **kwargs): + """ Create a random transform generator with the same arugments as `random_transform`. + + Uses a dedicated, newly created, properly seeded PRNG by default instead of the global DEFAULT_PRNG. + """ + + if prng is None: + # RandomState automatically seeds using the best available method. + prng = np.random.RandomState() + + while True: + yield random_transform(prng=prng, **kwargs) diff --git a/tests/preprocessing/test_generator.py b/tests/preprocessing/test_generator.py index d28e0212b..2740c4300 100644 --- a/tests/preprocessing/test_generator.py +++ b/tests/preprocessing/test_generator.py @@ -15,7 +15,6 @@ """ import keras.backend -from keras.preprocessing.image import ImageDataGenerator from keras_retinanet.preprocessing.generator import Generator import numpy as np @@ -27,7 +26,7 @@ def __init__(self, annotations_group, num_classes=0, image=None): self.annotations_group = annotations_group self.num_classes_ = num_classes self.image = image - super(SimpleGenerator, self).__init__(ImageDataGenerator(), group_method='none', shuffle_groups=False) + super(SimpleGenerator, self).__init__(group_method='none', shuffle_groups=False) def num_classes(self): return self.num_classes_ diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/test_transform.py b/tests/utils/test_transform.py new file mode 100644 index 000000000..f8b94e788 --- /dev/null +++ b/tests/utils/test_transform.py @@ -0,0 +1,152 @@ +import numpy as np +from numpy.testing import assert_almost_equal +from math import pi + +from keras_retinanet.utils.transform import ( + colvec, + transform_aabb, + rotation, random_rotation, + translation, random_translation, + scaling, random_scaling, + shear, random_shear, + random_flip, + random_transform, + random_transform_generator, + change_transform_origin, +) + + +def test_colvec(): + assert np.array_equal(colvec(0), np.array([[0]])) + assert np.array_equal(colvec(1, 2, 3), np.array([[1], [2], [3]])) + assert np.array_equal(colvec(-1, -2), np.array([[-1], [-2]])) + + +def test_rotation(): + assert_almost_equal(colvec( 1, 0, 1), rotation(0.0 * pi).dot(colvec(1, 0, 1))) + assert_almost_equal(colvec( 0, 1, 1), rotation(0.5 * pi).dot(colvec(1, 0, 1))) + assert_almost_equal(colvec(-1, 0, 1), rotation(1.0 * pi).dot(colvec(1, 0, 1))) + assert_almost_equal(colvec( 0, -1, 1), rotation(1.5 * pi).dot(colvec(1, 0, 1))) + assert_almost_equal(colvec( 1, 0, 1), rotation(2.0 * pi).dot(colvec(1, 0, 1))) + + assert_almost_equal(colvec( 0, 1, 1), rotation(0.0 * pi).dot(colvec(0, 1, 1))) + assert_almost_equal(colvec(-1, 0, 1), rotation(0.5 * pi).dot(colvec(0, 1, 1))) + assert_almost_equal(colvec( 0, -1, 1), rotation(1.0 * pi).dot(colvec(0, 1, 1))) + assert_almost_equal(colvec( 1, 0, 1), rotation(1.5 * pi).dot(colvec(0, 1, 1))) + assert_almost_equal(colvec( 0, 1, 1), rotation(2.0 * pi).dot(colvec(0, 1, 1))) + + +def test_random_rotation(): + prng = np.random.RandomState(0) + for i in range(100): + assert_almost_equal(1, np.linalg.det(random_rotation(-i, i, prng))) + + +def test_translation(): + assert_almost_equal(colvec( 1, 2, 1), translation(colvec( 0, 0)).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec( 4, 6, 1), translation(colvec( 3, 4)).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(-2, -2, 1), translation(colvec(-3, -4)).dot(colvec(1, 2, 1))) + + +def assert_is_translation(transform, min, max): + assert transform.shape == (3, 3) + assert np.array_equal(transform[:, 0:2], np.eye(3, 2)) + assert transform[2, 2] == 1 + assert np.greater_equal(transform[0:2, 2], min).all() + assert np.less( transform[0:2, 2], max).all() + + +def test_random_translation(): + prng = np.random.RandomState(0) + min = (-10, -20) + max = (20, 10) + for i in range(100): + assert_is_translation(random_translation(min, max, prng), min, max) + + +def test_shear(): + assert_almost_equal(colvec( 1, 2, 1), shear(0.0 * pi).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(-1, 0, 1), shear(0.5 * pi).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec( 1, -2, 1), shear(1.0 * pi).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec( 3, 0, 1), shear(1.5 * pi).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec( 1, 2, 1), shear(2.0 * pi).dot(colvec(1, 2, 1))) + + +def assert_is_shear(transform): + assert transform.shape == (3, 3) + assert np.array_equal(transform[:, 0], [1, 0, 0]) + assert np.array_equal(transform[:, 2], [0, 0, 1]) + assert transform[2, 1] == 0 + # sin^2 + cos^2 == 1 + assert_almost_equal(1, transform[0, 1] ** 2 + transform[1, 1] ** 2) + + +def test_random_shear(): + prng = np.random.RandomState(0) + for i in range(100): + assert_is_shear(random_shear(-pi, pi, prng)) + + +def test_scaling(): + assert_almost_equal(colvec(1.0, 2, 1), scaling(colvec(1.0, 1.0)).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(0.0, 2, 1), scaling(colvec(0.0, 1.0)).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(1.0, 0, 1), scaling(colvec(1.0, 0.0)).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(0.5, 4, 1), scaling(colvec(0.5, 2.0)).dot(colvec(1, 2, 1))) + + +def assert_is_scaling(transform, min, max): + assert transform.shape == (3, 3) + assert np.array_equal(transform[2, :], [0, 0, 1]) + assert np.array_equal(transform[:, 2], [0, 0, 1]) + assert transform[1, 0] == 0 + assert transform[0, 1] == 0 + assert np.greater_equal(np.diagonal(transform)[:2], min).all() + assert np.less( np.diagonal(transform)[:2], max).all() + + +def test_random_scaling(): + prng = np.random.RandomState(0) + min = (0.1, 0.2) + max = (20, 10) + for i in range(100): + assert_is_scaling(random_scaling(min, max, prng), min, max) + + +def assert_is_flip(transform): + assert transform.shape == (3, 3) + assert np.array_equal(transform[2, :], [0, 0, 1]) + assert np.array_equal(transform[:, 2], [0, 0, 1]) + assert transform[1, 0] == 0 + assert transform[0, 1] == 0 + assert abs(transform[0, 0]) == 1 + assert abs(transform[1, 1]) == 1 + + +def test_random_flip(): + prng = np.random.RandomState(0) + for i in range(100): + assert_is_flip(random_flip(0.5, 0.5, prng)) + + +def test_random_transform(): + prng = np.random.RandomState(0) + for i in range(100): + transform = random_transform(prng=prng) + assert np.array_equal(transform, np.identity(3)) + + for i, transform in zip(range(100), random_transform_generator(prng=np.random.RandomState())): + assert np.array_equal(transform, np.identity(3)) + + +def test_transform_aabb(): + assert np.array_equal([1, 2, 3, 4], transform_aabb(np.identity(3), [1, 2, 3, 4])) + assert_almost_equal([-3, -4, -1, -2], transform_aabb(rotation(pi), [1, 2, 3, 4])) + assert_almost_equal([ 2, 4, 4, 6], transform_aabb(translation([1, 2]), [1, 2, 3, 4])) + + +def test_change_transform_origin(): + prng = np.random.RandomState(0) + assert np.array_equal(change_transform_origin(translation([3, 4]), [1, 2]), translation([3, 4])) + assert_almost_equal(colvec(1, 2, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(1, 2, 1))) + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(rotation(pi), [1, 2]).dot(colvec(2, 4, 1))) + assert_almost_equal(colvec(0, 0, 1), change_transform_origin(scaling([0.5, 0.5]), [-2, -4]).dot(colvec(2, 4, 1)))