Merge pull request #190 from fizyr/cheap-augmenting
Add random_transform_generator for augmenting images.
de-vri-es authored Jan 15, 2018
2 parents 2b27ea7 + ca6f5b6 commit 8f22f35
Showing 11 changed files with 488 additions and 71 deletions.
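For orientation, here is a minimal sketch (not part of the diff) of how the new augmentation hook is wired up after this change, mirroring the CSV path in train.py below; the two CSV file names are placeholder arguments:

    from keras_retinanet.preprocessing.csv_generator import CSVGenerator
    from keras_retinanet.utils.transform import random_transform_generator

    # random_transform_generator lives in keras_retinanet/utils/transform.py,
    # one of the changed files whose diff is not shown on this page.
    transform_generator = random_transform_generator(flip_x_chance=0.5)

    train_generator = CSVGenerator(
        'annotations.csv',  # hypothetical CSV annotations file
        'classes.csv',      # hypothetical CSV class-mapping file
        transform_generator=transform_generator,
        batch_size=1,
    )

Validation generators are constructed without a transform_generator, so no augmentation is applied outside of training.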
4 changes: 0 additions & 4 deletions keras_retinanet/bin/evaluate_coco.py
@@ -72,14 +72,10 @@ def main(args=None):
print('Loading model, this may take a second...')
model = keras.models.load_model(args.model, custom_objects=custom_objects)

# create image data generator object
test_image_data_generator = keras.preprocessing.image.ImageDataGenerator()

# create a generator for testing data
test_generator = CocoGenerator(
args.coco_path,
args.set,
test_image_data_generator,
)

evaluate_coco(test_generator, model, args.score_threshold)
17 changes: 6 additions & 11 deletions keras_retinanet/bin/train.py
@@ -38,6 +38,7 @@
from ..preprocessing.pascal_voc import PascalVocGenerator
from ..preprocessing.csv_generator import CSVGenerator
from ..models.resnet import resnet50_retinanet
from ..utils.transform import random_transform_generator
from ..utils.keras_version import check_keras_version


@@ -109,11 +110,8 @@ def create_callbacks(model, training_model, prediction_model, validation_generat


def create_generators(args):
# create image data generator objects
train_image_data_generator = keras.preprocessing.image.ImageDataGenerator(
horizontal_flip=True,
)
val_image_data_generator = keras.preprocessing.image.ImageDataGenerator()
# create random transform generator for augmenting training data
transform_generator = random_transform_generator(flip_x_chance=0.5)

if args.dataset_type == 'coco':
# import here to prevent unnecessary dependency on cocoapi
@@ -122,43 +120,40 @@
train_generator = CocoGenerator(
args.coco_path,
'train2017',
train_image_data_generator,
transform_generator=transform_generator,
batch_size=args.batch_size
)

validation_generator = CocoGenerator(
args.coco_path,
'val2017',
val_image_data_generator,
batch_size=args.batch_size
)
elif args.dataset_type == 'pascal':
train_generator = PascalVocGenerator(
args.pascal_path,
'trainval',
train_image_data_generator,
transform_generator=transform_generator,
batch_size=args.batch_size
)

validation_generator = PascalVocGenerator(
args.pascal_path,
'test',
val_image_data_generator,
batch_size=args.batch_size
)
elif args.dataset_type == 'csv':
train_generator = CSVGenerator(
args.annotations,
args.classes,
train_image_data_generator,
transform_generator=transform_generator,
batch_size=args.batch_size
)

if args.val_annotations:
validation_generator = CSVGenerator(
args.val_annotations,
args.classes,
val_image_data_generator,
batch_size=args.batch_size
)
else:
4 changes: 2 additions & 2 deletions keras_retinanet/preprocessing/coco.py
@@ -24,15 +24,15 @@


class CocoGenerator(Generator):
def __init__(self, data_dir, set_name, image_data_generator, *args, **kwargs):
def __init__(self, data_dir, set_name, **kwargs):
self.data_dir = data_dir
self.set_name = set_name
self.coco = COCO(os.path.join(data_dir, 'annotations', 'instances_' + set_name + '.json'))
self.image_ids = self.coco.getImgIds()

self.load_classes()

super(CocoGenerator, self).__init__(image_data_generator, **kwargs)
super(CocoGenerator, self).__init__(**kwargs)

def load_classes(self):
# load class names (name -> label)
3 changes: 1 addition & 2 deletions keras_retinanet/preprocessing/csv_generator.py
@@ -108,7 +108,6 @@ def __init__(
self,
csv_data_file,
csv_class_file,
image_data_generator,
base_dir=None,
**kwargs
):
@@ -139,7 +138,7 @@ def __init__(
raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(csv_data_file, e)), None)
self.image_names = list(self.image_data.keys())

super(CSVGenerator, self).__init__(image_data_generator, **kwargs)
super(CSVGenerator, self).__init__(**kwargs)

def size(self):
return len(self.image_names)
53 changes: 35 additions & 18 deletions keras_retinanet/preprocessing/generator.py
@@ -22,31 +22,35 @@

import keras

from ..utils.image import preprocess_image, resize_image, random_transform
from ..utils.anchors import anchor_targets_bbox
from ..utils.image import (
TransformParameters,
adjust_transform_for_image,
apply_transform,
preprocess_image,
resize_image,
)
from ..utils.transform import transform_aabb


class Generator(object):
def __init__(
self,
image_data_generator,
transform_generator = None,
batch_size=1,
group_method='ratio', # one of 'none', 'random', 'ratio'
shuffle_groups=True,
image_min_side=600,
image_max_side=1024,
seed=None
transform_parameters=None,
):
self.image_data_generator = image_data_generator
self.transform_generator = transform_generator
self.batch_size = int(batch_size)
self.group_method = group_method
self.shuffle_groups = shuffle_groups
self.image_min_side = image_min_side
self.image_max_side = image_max_side

if seed is None:
seed = np.uint32((time.time() % 1)) * 1000
np.random.seed(seed)
self.transform_parameters = transform_parameters or TransformParameters()

self.group_index = 0
self.lock = threading.Lock()
@@ -112,19 +116,32 @@ def resize_image(self, image):
def preprocess_image(self, image):
return preprocess_image(image)

def preprocess_group(self, image_group, annotations_group):
for index, (image, annotations) in enumerate(zip(image_group, annotations_group)):
# preprocess the image (subtract imagenet mean)
image = self.preprocess_image(image)
def preprocess_group_entry(self, image, annotations):
# preprocess the image
image = self.preprocess_image(image)

# randomly transform both image and annotations
if self.transform_generator:
transform = adjust_transform_for_image(next(self.transform_generator), image)
apply_transform(transform, image, self.transform_parameters)

# randomly transform both image and annotations
image, annotations = random_transform(image, annotations, self.image_data_generator)
# Transform the bounding boxes in the annotations.
annotations = annotations.copy()
for index in range(annotations.shape[0]):
annotations[index, :4] = transform_aabb(transform, annotations[index, :4])

# resize image
image, image_scale = self.resize_image(image)
# resize image
image, image_scale = self.resize_image(image)

# apply resizing to annotations too
annotations[:, :4] *= image_scale
# apply resizing to annotations too
annotations[:, :4] *= image_scale

return image, annotations

def preprocess_group(self, image_group, annotations_group):
for index, (image, annotations) in enumerate(zip(image_group, annotations_group)):
# preprocess a single group entry
image, annotations = self.preprocess_group_entry(image, annotations)

# copy processed data back to group
image_group[index] = image
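transform_aabb is defined in keras_retinanet/utils/transform.py, whose diff is not shown on this page. A sketch of what it is assumed to do, based on how it is called above: transform all four corners of the axis-aligned box and take the bounds of the result, since a rotated or sheared box is no longer axis-aligned. The helper name and layout here are illustrative only:

    import numpy as np

    def transform_aabb_sketch(transform, aabb):
        # aabb is (x1, y1, x2, y2); transform is a 3x3 homogeneous matrix.
        x1, y1, x2, y2 = aabb
        # All four corners of the box as homogeneous column vectors.
        corners = transform @ np.array([
            [x1, x2, x1, x2],
            [y1, y2, y2, y1],
            [ 1,  1,  1,  1],
        ], dtype=float)
        # The transformed box is the axis-aligned hull of the transformed corners.
        return [corners[0].min(), corners[1].min(), corners[0].max(), corners[1].max()]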
3 changes: 1 addition & 2 deletions keras_retinanet/preprocessing/pascal_voc.py
@@ -71,7 +71,6 @@ def __init__(
self,
data_dir,
set_name,
image_data_generator,
classes=voc_classes,
image_extension='.jpg',
skip_truncated=False,
@@ -90,7 +89,7 @@ def __init__(
for key, value in self.classes.items():
self.labels[value] = key

super(PascalVocGenerator, self).__init__(image_data_generator, **kwargs)
super(PascalVocGenerator, self).__init__(**kwargs)

def size(self):
return len(self.image_names)
79 changes: 49 additions & 30 deletions keras_retinanet/utils/image.py
@@ -21,6 +21,8 @@
import cv2
import PIL

from .transform import change_transform_origin, transform_aabb, colvec


def read_image_bgr(path):
image = np.asarray(PIL.Image.open(path).convert('RGB'))
Expand Down Expand Up @@ -48,44 +50,61 @@ def preprocess_image(x):
return x


def random_transform(
image,
boxes,
image_data_generator,
seed=None
):
if seed is None:
seed = np.random.randint(10000)
def adjust_transform_for_image(transform, image):
""" Adjust a transformation for a specific image.
The translation of the matrix is scaled with the size of the image.
The linear part of the transformation is adjusted so that the origin of the transformation is at the center of the image.
"""
height, width, channels = image.shape

image = image_data_generator.random_transform(image, seed=seed)
# Move the origin of transformation.
result = change_transform_origin(transform, (0.5 * width, 0.5 * height))

# set fill mode so that masks are not enlarged
fill_mode = image_data_generator.fill_mode
image_data_generator.fill_mode = 'constant'
# Scale the translation with the image size.
result[0:2, 2] *= [width, height]

for index in range(boxes.shape[0]):
# generate box mask and randomly transform it
mask = np.zeros_like(image, dtype=np.uint8)
b = boxes[index, :4].astype(int)
return result

assert(b[0] < b[2] and b[1] < b[3]), 'Annotations contain invalid box: {}'.format(b)
assert(b[2] <= image.shape[1] and b[3] <= image.shape[0]), 'Annotation ({}) is outside of image shape ({}).'.format(b, image.shape)

mask[b[1]:b[3], b[0]:b[2], :] = 255
mask = image_data_generator.random_transform(mask, seed=seed)[..., 0]
mask = mask.copy() # to force contiguous arrays
class TransformParameters:
""" Struct holding parameters determining how to transform images.
# find bounding box again in augmented image
[i, j] = np.where(mask == 255)
boxes[index, 0] = float(min(j))
boxes[index, 1] = float(min(i))
boxes[index, 2] = float(max(j)) + 1 # set box to an open interval [min, max)
boxes[index, 3] = float(max(i)) + 1 # set box to an open interval [min, max)
# Arguments
fill_mode: Same as for keras.preprocessing.image.apply_transform
cval: Same as for keras.preprocessing.image.apply_transform
data_format: Same as for keras.preprocessing.image.apply_transform
"""
def __init__(
self,
fill_mode = 'nearest',
cval = 0,
data_format = None,
):
self.fill_mode = fill_mode
self.cval = cval

# restore fill_mode
image_data_generator.fill_mode = fill_mode
if data_format is None:
data_format = keras.backend.image_data_format()
self.data_format = data_format

return image, boxes
if data_format == 'channels_first':
self.channel_axis = 0
elif data_format == 'channels_last':
self.channel_axis = 2
else:
raise ValueError("invalid data_format, expected 'channels_first' or 'channels_last', got '{}'".format(data_format))


def apply_transform(transform, image, params):
""" Wrapper around keras.preprocessing.image.apply_transform using TransformParameters. """
return keras.preprocessing.image.apply_transform(
image,
transform,
channel_axis=params.channel_axis,
fill_mode=params.fill_mode,
cval=params.cval
)


def resize_image(img, min_side=600, max_side=1024):
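A minimal usage sketch (an assumption, not taken from the diff) of the new helpers in this file: draw one random transform, adapt it to a specific image, and apply it; 'example.jpg' is a placeholder path:

    from keras_retinanet.utils.image import (
        TransformParameters,
        adjust_transform_for_image,
        apply_transform,
        read_image_bgr,
    )
    from keras_retinanet.utils.transform import random_transform_generator

    transform_generator = random_transform_generator(flip_x_chance=0.5)

    image = read_image_bgr('example.jpg')  # placeholder image path
    transform = adjust_transform_for_image(next(transform_generator), image)
    # apply_transform wraps keras.preprocessing.image.apply_transform, which
    # returns the transformed array, so capture the result.
    image = apply_transform(transform, image, TransformParameters(fill_mode='constant'))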