diff --git a/README.md b/README.md index 0181c24..d6e5675 100644 --- a/README.md +++ b/README.md @@ -13,15 +13,22 @@ Original paper and github: [YOLO9000: Better, Faster, Stronger](https://arxiv.or ### Dependancies --- - [keras](https://github.com/fchollet/keras) -- [tensorflow](https://www.tensorflow.org/) +- [tensorflow 1.9](https://www.tensorflow.org/api_docs/) - [numpy](http://www.numpy.org/) - [h5py](http://www.h5py.org/) - [opencv](https://pypi.org/project/opencv-python/) -- [python3.5](https://www.python.org/) +- [python 3.5](https://www.python.org/) - [moviepy](https://zulko.github.io/moviepy/) (optional, gifs) ### Simple use --- + +0. Clone repository: +```bash +git clone https://github.com/ksanjeevan/dourflow.git +``` + 1. Download pretrained [model](https://drive.google.com/open?id=1khOgS8VD-paUD8KhjOLOzEXkzaXNAAMq) and place it in **dourflow/**. + 2. Predict on an [image](https://images.pexels.com/photos/349758/hummingbird-bird-birds-349758.jpeg?auto=compress&cs=tinysrgb&h=350): ```bash @@ -38,13 +45,14 @@ Running `python3 dourflow.py --help`: ```bash usage: dourflow.py [-h] [-m MODEL] [-c CONF] [-t THRESHOLD] [-w WEIGHT_FILE] + [--gif] action dourflow: a keras YOLO V2 implementation. positional arguments: - action what to do: 'train', 'validate', 'cam' or pass a video, image - file/dir. + action what to do: 'train', 'validate', 'cam' or pass a + video, image file/dir. optional arguments: -h, --help show this help message and exit @@ -55,7 +63,7 @@ optional arguments: detection threshold -w WEIGHT_FILE, --weight_file WEIGHT_FILE path to weight file - + --gif video output stored as gif also ``` #### *action* [positional] @@ -151,28 +159,28 @@ Output: ```bash Batch Processed: 100%|████████████████████████████████████████████| 4282/4282 [01:53<00:00, 37.84it/s] -AP( bus ): 0.806 -AP( tvmonitor ): 0.716 -AP( motorbike ): 0.666 -AP( dog ): 0.811 -AP( horse ): 0.574 -AP( boat ): 0.618 -AP( sofa ): 0.625 -AP( sheep ): 0.718 -AP( bicycle ): 0.557 -AP( cow ): 0.725 -AP( pottedplant ): 0.565 -AP( train ): 0.907 -AP( bird ): 0.813 -AP( person ): 0.665 -AP( car ): 0.580 AP( cat ): 0.908 -AP( bottle ): 0.429 -AP( diningtable ): 0.593 -AP( chair ): 0.475 -AP( aeroplane ): 0.724 +AP( train ): 0.907 +AP( dog ): 0.899 +AP( bird ): 0.814 +AP( aeroplane ): 0.810 +AP( cow ): 0.810 +AP( bus ): 0.806 +AP( motorbike ): 0.792 +AP( person ): 0.737 +AP( sheep ): 0.719 +AP( tvmonitor ): 0.718 +AP( sofa ): 0.701 +AP( bicycle ): 0.683 +AP( diningtable ): 0.665 +AP( car ): 0.641 +AP( boat ): 0.617 +AP( horse ): 0.575 +AP( pottedplant ): 0.568 +AP( chair ): 0.528 +AP( bottle ): 0.487 ------------------------------- -mAP: 0.674 +mAP: 0.719 ``` @@ -198,7 +206,7 @@ Will store the custom bounding box priors wherever the path indicates in the con Training will create directory **logs/** which will store metrics and checkpoints for all the different training runs. -Model passed is used for [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning) (TRAINING FROM SCRATCH / TRAINING ONLY LAST LAYER SHOULD BE ADDED SOON). +Model passed is used for [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning). 
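+(its weights are used to initialize the network before training).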
Example: `python3 dourflow.py train -m models/logo/coco_model.h5 -c confs/config_custom.json` @@ -214,11 +222,11 @@ Then, in another terminal tab you can run `tensorboard --logdir=logs/run_X` and #### To Do -- [ ] cfg parser +- [x] Multiclass NMS - [x] Anchor generation for custom datasets - [ ] mAP write up - [x] Add webcam support -- [ ] Data Augmentation +- [x] Data Augmentation - [x] TensorBoard metrics #### Inspired from @@ -226,4 +234,4 @@ Then, in another terminal tab you can run `tensorboard --logdir=logs/run_X` and - [Darknet](https://github.com/pjreddie/darknet) - [Darkflow](https://github.com/thtrieu/darkflow) - [keras-yolo2](https://github.com/experiencor/keras-yolo2) -- [YAD2K](https://github.com/allanzelener/YAD2K) \ No newline at end of file +- [YAD2K](https://github.com/allanzelener/YAD2K) diff --git a/confs/config_coco.json b/confs/config_coco.json index 60eaf3c..a082b7d 100755 --- a/confs/config_coco.json +++ b/confs/config_coco.json @@ -2,9 +2,9 @@ "model" : { "input_size": 416, "grid_size": 13, - "true_box_buffer": 30, + "true_box_buffer": 10, "iou_threshold": 0.5, - "nms_threshold": 0.3 + "nms_threshold": 0.4 }, "config_path" : { "labels": "models/coco/labels_coco.txt", diff --git a/confs/config_voc.json b/confs/config_voc.json index 5f40ee6..509facd 100644 --- a/confs/config_voc.json +++ b/confs/config_voc.json @@ -4,7 +4,7 @@ "grid_size": 13, "true_box_buffer": 10, "iou_threshold": 0.5, - "nms_threshold": 0.3 + "nms_threshold": 0.45 }, "config_path" : { "labels": "models/voc/labels_voc.txt", @@ -17,7 +17,7 @@ "annot_folder": "/home/kiran/Documents/DATA/VOC/train/anns", "batch_size": 16, "learning_rate": 1e-4, - "num_epochs": 20, + "num_epochs": 50, "object_scale": 5.0 , "no_object_scale": 1.0, "coord_scale": 1.0, diff --git a/kmeans_anchors.py b/kmeans_anchors.py index e675950..380573e 100644 --- a/kmeans_anchors.py +++ b/kmeans_anchors.py @@ -105,7 +105,6 @@ def IoU_dist(x, c): - def exrtract_wh(img): result = [] pixel_height = img['height'] diff --git a/net/__init__.py b/net/__init__.py index e69de29..8a87e95 100644 --- a/net/__init__.py +++ b/net/__init__.py @@ -0,0 +1,7 @@ +#from . import netarch +#from . import netdecode +#from . import neteval +#from . import netgen +#from . import netloss +#from . import netparams +#from . import utils \ No newline at end of file diff --git a/net/netarch.py b/net/netarch.py index f260456..07b7232 100644 --- a/net/netarch.py +++ b/net/netarch.py @@ -1,17 +1,20 @@ +""" +Set up keras model with Yolo v2 architecture, for both training +and inference. 
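+Inference wraps the network with a decoding Lambda layer (YoloOutProcess).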
+""" +import tensorflow as tf +import numpy as np +import pickle, argparse, json, os, cv2 + from keras.models import Model, load_model from keras.layers import Reshape, Conv2D, Input, MaxPooling2D, BatchNormalization, Lambda from keras.layers.advanced_activations import LeakyReLU from keras.layers.merge import concatenate - -import tensorflow as tf -import numpy as np -import pickle, argparse, json, os, cv2 - from keras.utils.vis_utils import plot_model -from net.netparams import YoloParams -from net.netdecode import YoloOutProcess +from .netparams import YoloParams +from .netdecode import YoloOutProcess @@ -20,6 +23,7 @@ class YoloInferenceModel(object): def __init__(self, model): self._yolo_out = YoloOutProcess() self._inf_model = self._extend_processing(model) + self._model = model def _extend_processing(self, model): output = Lambda(self._yolo_out, name='lambda_2')(model.output) @@ -40,6 +44,8 @@ def _prepro_single_image(self, image): def predict(self, image): image = self._prepro_single_image(image) + + #np.save('person.npy', self._model.predict(image)) output = self._inf_model.predict(image)[0] if output.size == 0: diff --git a/net/netdecode.py b/net/netdecode.py index bac0b64..cea658d 100644 --- a/net/netdecode.py +++ b/net/netdecode.py @@ -1,8 +1,13 @@ +""" +Process [GRID x GRID x BOXES x (4 + 1 + CLASSES)]. Filter low confidence +boxes, apply NMS and return boxes, scores, classes. +""" + import tensorflow as tf from keras import backend as K import numpy as np -from net.netparams import YoloParams +from .netparams import YoloParams def process_outs(b, s, c): @@ -21,17 +26,126 @@ def process_outs(b, s, c): return K.expand_dims(output_stack, axis=0) +class YoloOutProcess(object): + + + def __init__(self): + # thresholds + self.max_boxes = YoloParams.TRUE_BOX_BUFFER + self.nms_threshold = YoloParams.NMS_THRESHOLD + self.detection_threshold = YoloParams.DETECTION_THRESHOLD + + self.num_classes = YoloParams.NUM_CLASSES + + def __call__(self, y_sing_pred): + + # need to convert b's from GRID_SIZE units into IMG coords. Divide by grid here. + b_xy = (K.sigmoid(y_sing_pred[..., 0:2]) + YoloParams.c_grid[0]) / YoloParams.GRID_SIZE + b_wh = (K.exp(y_sing_pred[..., 2:4])*YoloParams.anchors[0]) / YoloParams.GRID_SIZE + b_xy1 = b_xy - b_wh / 2. + b_xy2 = b_xy + b_wh / 2. 
+ boxes = K.concatenate([b_xy1, b_xy2], axis=-1) + + # filter out scores below detection threshold + scores_all = K.expand_dims(K.sigmoid(y_sing_pred[..., 4]), axis=-1) * K.softmax(y_sing_pred[...,5:]) + indicator_detection = scores_all > self.detection_threshold + scores_all = scores_all * K.cast(indicator_detection, np.float32) + + # compute detected classes and scores + classes = K.argmax(scores_all, axis=-1) + scores = K.max(scores_all, axis=-1) + + # flattened tensor length + S2B = YoloParams.GRID_SIZE*YoloParams.GRID_SIZE*YoloParams.NUM_BOUNDING_BOXES + + # flatten boxes, scores for NMS + flatten_boxes = K.reshape(boxes, shape=(S2B, 4)) + flatten_scores = K.reshape(scores, shape=(S2B, )) + flatten_classes = K.reshape(classes, shape=(S2B, )) + + inds = [] + + # apply multiclass NMS + for c in range(self.num_classes): + + # only include boxes of the current class, with > 0 confidence + class_mask = K.cast(K.equal(flatten_classes, c), np.float32) + score_mask = K.cast(flatten_scores > 0, np.float32) + mask = class_mask * score_mask + + # compute class NMS + nms_inds = tf.image.non_max_suppression( + flatten_boxes, + flatten_scores*mask, + max_output_size=self.max_boxes, + iou_threshold=self.nms_threshold, + score_threshold=0. + ) + + inds.append(nms_inds) + + # combine winning box indices of all classes + selected_indices = K.concatenate(inds, axis=-1) + + # gather corresponding boxes, scores, class indices + selected_boxes = K.gather(flatten_boxes, selected_indices) + selected_scores = K.gather(flatten_scores, selected_indices) + selected_classes = K.gather(flatten_classes, selected_indices) + + return process_outs(selected_boxes, selected_scores, K.cast(selected_classes, np.float32)) -class YoloOutProcess(object): +class YoloOutProcessOther(object): + """ + [UNUSED] Ignore. + """ + def __init__(self): self.max_boxes = YoloParams.TRUE_BOX_BUFFER self.nms_threshold = YoloParams.NMS_THRESHOLD self.detection_threshold = YoloParams.DETECTION_THRESHOLD + self.num_classes = YoloParams.NUM_CLASSES + + + def _class_nms(self, boxes, scores, c_mask): + #c_mask = K.equal(classes, i) + c_mask = c_mask*K.cast(scores > 0, np.float32) + c_boxes = boxes * K.expand_dims(c_mask, axis=-1) + c_scores = scores * c_mask + inds = tf.image.non_max_suppression(c_boxes, c_scores, max_output_size=10, iou_threshold=0.2) + # tf.pad(inds, tf.Variable([[0,10-tf.shape(inds)[0]]]), "CONSTANT") + return self._pad_tensor(inds, 10, value=-1) + + + def _pad_tensor(self, t, length, value=0): + """Pads the input tensor with 0s along the first dimension up to the length. + Args: + t: the input tensor, assuming the rank is at least 1. + length: a tensor of shape [1] or an integer, indicating the first dimension + of the input tensor t after padding, assuming length <= t.shape[0]. + Returns: + padded_t: the padded tensor, whose first dimension is length. If the length + is an integer, the first dimension of padded_t is set to length + statically. 
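+        Padded slots are filled with `value` (-1 in `_class_nms`, filtered out downstream).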
+ """ + t_rank = tf.rank(t) + t_shape = tf.shape(t) + t_d0 = t_shape[0] + pad_d0 = tf.expand_dims(length - t_d0, 0) + pad_shape = tf.cond( + tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), + lambda: tf.expand_dims(length - t_d0, 0)) + padded_t = tf.concat([t, value+tf.zeros(pad_shape, dtype=t.dtype)], 0) + + t_shape = padded_t.get_shape().as_list() + t_shape[0] = length + padded_t.set_shape(t_shape) + + return padded_t def __call__(self, y_sing_pred): @@ -55,12 +169,17 @@ def __call__(self, y_sing_pred): flatten_scores = K.reshape(scores, shape=(S2B, )) flatten_classes = K.reshape(classes, shape=(S2B, )) - selected_indices = tf.image.non_max_suppression( - flatten_boxes, - flatten_scores, - max_output_size=self.max_boxes, - iou_threshold=self.nms_threshold) - + + c_masks = K.map_fn(lambda c: K.cast(K.equal(flatten_classes, c), np.float32), np.arange(self.num_classes), dtype=np.float32) + resu_stacked = tf.map_fn( + lambda c: self._class_nms(flatten_boxes, flatten_scores, c), + c_masks, + dtype=np.int32, + infer_shape=True) + + resu_flat = K.reshape(resu_stacked, shape=(-1,)) + selected_indices = tf.boolean_mask(resu_flat, ~K.equal(resu_flat, -1)) + selected_boxes = K.gather(flatten_boxes, selected_indices) selected_scores = K.gather(flatten_scores, selected_indices) selected_classes = K.gather(flatten_classes, selected_indices) @@ -73,28 +192,26 @@ def __call__(self, y_sing_pred): return process_outs(selected_boxes, selected_scores, K.cast(selected_classes, np.float32)) - -if __name__ == '__main__': - sess = tf.InteractiveSession() - max_boxes = 10 - nms_threshold = 0.1 - boxes = tf.convert_to_tensor(np.random.rand(10,4), np.float32) - scores = tf.convert_to_tensor(np.random.rand(10,), np.float32) - classes = tf.convert_to_tensor((10.*np.random.rand(10,)%3).astype(int), np.float32) +if __name__ == '__main__': + + tf.InteractiveSession() - s,b,c = yolo_non_max_suppression(scores, boxes, classes, max_boxes, nms_threshold) + a = tf.convert_to_tensor(np.load('ocell.npy'), np.float32) + + yolo_out = YoloOutProcess() - print(boxes.eval().shape) - print(scores.eval().shape) - print(classes.eval().shape) + resu = yolo_out(a).eval()[0] - print('-----------------------') + b = resu[:,:4] + s = resu[:,4] + c = resu[:,5] - print(b.eval().shape) - print(s.eval().shape) - print(c.eval().shape) + print('---------------------') + print(c) + print(s) + print(b) diff --git a/net/neteval.py b/net/neteval.py index 2771ce6..5c2c404 100644 --- a/net/neteval.py +++ b/net/neteval.py @@ -1,5 +1,11 @@ +""" +Evaluate results. mAP validation, tensorboard metrics: +mAP callback, recall. 
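+Used by the 'validate' action and as Keras callbacks/metrics during training.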
+""" -from net.netparams import YoloParams +from .netparams import YoloParams +from .utils import draw_boxes, compute_iou, mkdir_p, handle_empty_indexing +from .netloss import _transform_netout, calculate_ious import tensorflow as tf from tqdm import tqdm @@ -7,8 +13,6 @@ import numpy as np import cv2, os import keras -from net.utils import draw_boxes, compute_iou, mkdir_p, \ -mkdir_p, handle_empty_indexing from keras import backend as K @@ -16,143 +20,6 @@ -class YoloDataGenerator(keras.utils.Sequence): - 'Generates data for Keras' - def __init__(self, images, shuffle=True): - - self.images = self._prune_ann_labels(images) - self.input_size = YoloParams.INPUT_SIZE - self.anchors = YoloParams.anchors - - self.generator = None - - self.batch_size = YoloParams.BATCH_SIZE - - self.shuffle = shuffle - self.on_epoch_end() - - def __len__(self): - 'Denotes the number of batches per epoch' - # return int(np.ceil(float(len(self.images))/self.config['BATCH_SIZE'])) - return int(np.floor(len(self.images) / self.batch_size)) - - def __getitem__(self, index): - 'Generate one batch of data' - bound_l = index*self.batch_size - bound_r = (index+1)*self.batch_size - - return self._data_to_yolo_output(self.images[bound_l:bound_r]) - - def load_image_name(self, i): - return self.images[i]['filename'] - - - def load_image(self, i): - return cv2.imread(self.images[i]['filename']) - - def load_annotation(self, i): - labels = [] - bboxes = [] - - height = self.images[i]['height'] - width = self.images[i]['width'] - - for obj in self.images[i]['object']: - #if obj['name'] in YoloParams.CLASS_LABELS: - labels.append( obj['name'] ) - bboxes.append( - [obj['xmin'] / width, obj['ymin'] / height, obj['xmax'] / width, obj['ymax'] / height] ) - - - class_inds = [YoloParams.CLASS_TO_INDEX[l] for l in labels] - - return np.array(bboxes), np.array(class_inds) - - def on_epoch_end(self): - 'Updates indexes after each epoch' - if self.shuffle: np.random.shuffle(self.images) - - def _prune_ann_labels(self, images): - clean_images = [] - for im in images: - clean_im = im.copy() - clean_objs = [] - for obj in clean_im['object']: - if obj['name'] in YoloParams.CLASS_LABELS: - clean_objs.append( obj ) - - clean_im.update({'object' : clean_objs}) - clean_images.append(clean_im) - - return clean_images - - - def _data_to_yolo_output(self, batch_images): - - # INPUT IMAGES READY FOR TRAINING - x_batch = np.zeros((len(batch_images), self.input_size, self.input_size, 3)) - - # GET DESIRED NETWORK OUTPUT - y_batch = np.zeros((len(batch_images), YoloParams.GRID_SIZE, - YoloParams.GRID_SIZE, YoloParams.NUM_BOUNDING_BOXES, 4+1+len(YoloParams.CLASS_LABELS))) - - grid_factor = YoloParams.GRID_SIZE / self.input_size - - for j, train_instance in enumerate(batch_images): - - img_raw = cv2.imread(train_instance['filename']) - - h_factor_resize = img_raw.shape[0] / self.input_size - w_factor_resize = img_raw.shape[1] / self.input_size - - img = cv2.resize(img_raw, (self.input_size, self.input_size)) - - for obj_box_idx, label in enumerate(train_instance['object']): - - xmin_resized = int(round(label['xmin'] / w_factor_resize)) - xmax_resized = int(round(label['xmax'] / w_factor_resize)) - ymin_resized = int(round(label['ymin'] / h_factor_resize)) - ymax_resized = int(round(label['ymax'] / h_factor_resize)) - - bbox_center_x = .5*(xmin_resized + xmax_resized) * grid_factor - grid_x = int(bbox_center_x) - - bbox_center_y = .5*(ymin_resized + ymax_resized) * grid_factor - grid_y = int(bbox_center_y) - - obj_indx = 
YoloParams.CLASS_LABELS.index(label['name']) - - bbox_w = (xmax_resized - xmin_resized) * grid_factor - bbox_h = (ymax_resized - ymin_resized) * grid_factor - - shifted_wh = np.array([0,0,bbox_w, bbox_h]) - - func = lambda prior: compute_iou((0,0,prior[0],prior[1]), shifted_wh) - - anchor_winner = np.argmax(np.apply_along_axis(func, -1, self.anchors)) - - # assign ground truth x, y, w, h, confidence and class probs to y_batch - - # ASSIGN CLASS CONFIDENCE - y_batch[j, grid_y, grid_x, anchor_winner, 0:4] = [bbox_center_x, bbox_center_y, bbox_w, bbox_h] - - # ASSIGN OBJECTNESS CONF - y_batch[j, grid_y, grid_x, anchor_winner, 4 ] = 1. - - # ASSIGN CORRECT CLASS TO - y_batch[j, grid_y, grid_x, anchor_winner, 4+1+obj_indx] = 1 - - # number of labels per instance !> than true_box_buffer, add check in processing (?) - x_batch[j] = img / 255. - - ############################################################ - # x_batch -> list of input images - # y_batch -> list of network ouput gt values for each image - ############################################################ - return x_batch, y_batch - - - class YoloEvaluate(object): @@ -285,13 +152,9 @@ def compute_ap(self, detections, num_gts): def comp_map(self): detection_results = [] - detection_labels = np.array([0]*YoloParams.NUM_CLASSES) - - num_annotations = 0 - counter = 0 + detection_labels = np.array([0]*YoloParams.NUM_CLASSES) for i in tqdm(range(len(self.generator.images)), desc='Batch Processed'): - counter += 1 image_name = os.path.basename( self.generator.load_image_name(i) ) @@ -307,8 +170,9 @@ def comp_map(self): ap_dic = {} for class_ind, num_gts in enumerate(detection_labels): + class_detections = detection_results[detection_results[:,0]==class_ind] - + ap = self.compute_ap(class_detections, num_gts) ap_dic[self.class_labels[class_ind]] = ap @@ -342,15 +206,17 @@ def on_epoch_end(self, epoch, logs={}): - -def yolo_recall(y_true, y_pred): - +def yolo_recall(y_true, y_pred_raw): truth = y_true[...,4] - pred_scores = K.expand_dims(K.sigmoid(y_pred[..., 4]), axis=-1) * K.softmax(y_pred[...,5:]) - preds = K.cast(K.max(pred_scores, axis=-1) > YoloParams.DETECTION_THRESHOLD, np.float32) + y_pred = _transform_netout(y_pred_raw) + ious = calculate_ious(y_true, y_pred, use_iou=True) + pred_ious = K.cast(ious > YoloParams.IOU_THRESHOLD, np.float32) + + scores = y_pred[..., 4:5] * y_pred[...,5:] + pred_scores = K.cast(K.max(scores, axis=-1) > YoloParams.DETECTION_THRESHOLD, np.float32) - tp = K.sum(truth * preds) + tp = K.sum(pred_ious * pred_scores) tpfn = K.sum(truth) return tp / (tpfn + 1e-8) diff --git a/net/netgen.py b/net/netgen.py new file mode 100644 index 0000000..5c42ef4 --- /dev/null +++ b/net/netgen.py @@ -0,0 +1,229 @@ +""" +Data generator augmentation. 
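+YoloDataGenerator encodes images and annotations into YOLO training tensors,
+with optional random scale/translation and exposure/saturation (HSV) jitter.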
+""" + +from .netparams import YoloParams +from .utils import compute_iou + +import numpy as np +import cv2, os +import keras +import copy + +PERC_LIMIT = 0.2 +HSV_FACT = 1.5 + +class YoloDataGenerator(keras.utils.Sequence): + 'Generates data for Keras' + def __init__(self, images, shuffle=True, augment=False): + + self.images = self._prune_ann_labels(images) + self.input_size = YoloParams.INPUT_SIZE + self.anchors = YoloParams.anchors + + self.generator = None + + self.batch_size = YoloParams.BATCH_SIZE + + self.shuffle = shuffle + self.perc = PERC_LIMIT + self.hsvf = HSV_FACT + self.augment = augment + + self.on_epoch_end() + + def __len__(self): + 'Denotes the number of batches per epoch' + # return int(np.ceil(float(len(self.images))/self.config['BATCH_SIZE'])) + return int(np.floor(len(self.images) / self.batch_size)) + + def __getitem__(self, index): + 'Generate one batch of data' + bound_l = index*self.batch_size + bound_r = (index+1)*self.batch_size + + return self._data_to_yolo_output(self.images[bound_l:bound_r]) + + def load_image_name(self, i): + return self.images[i]['filename'] + + + def load_image(self, i): + return cv2.imread(self.images[i]['filename']) + + def load_annotation(self, i): + labels = [] + bboxes = [] + + height = self.images[i]['height'] + width = self.images[i]['width'] + + for obj in self.images[i]['object']: + #if obj['name'] in YoloParams.CLASS_LABELS: + labels.append( obj['name'] ) + bboxes.append( + [obj['xmin'] / width, obj['ymin'] / height, obj['xmax'] / width, obj['ymax'] / height] ) + + class_inds = [YoloParams.CLASS_TO_INDEX[l] for l in labels] + + return np.array(bboxes), np.array(class_inds) + + def on_epoch_end(self): + 'Updates indexes after each epoch' + if self.shuffle: np.random.shuffle(self.images) + + def _prune_ann_labels(self, images): + clean_images = [] + for im in images: + clean_im = im.copy() + clean_objs = [] + for obj in clean_im['object']: + if obj['name'] in YoloParams.CLASS_LABELS: + clean_objs.append( obj ) + + clean_im.update({'object' : clean_objs}) + clean_images.append(clean_im) + + return clean_images + + + def _data_to_yolo_output(self, batch_images): + + # INPUT IMAGES READY FOR TRAINING + x_batch = np.zeros((len(batch_images), self.input_size, self.input_size, 3)) + + # GET DESIRED NETWORK OUTPUT + y_batch = np.zeros((len(batch_images), YoloParams.GRID_SIZE, + YoloParams.GRID_SIZE, YoloParams.NUM_BOUNDING_BOXES, 4+1+len(YoloParams.CLASS_LABELS))) + + grid_factor = YoloParams.GRID_SIZE / self.input_size + + for j, inst in enumerate(batch_images): + + img_raw, new_inst = data_augmentation(inst, self.perc, self.hsvf, self.augment) + + h_factor_resize = img_raw.shape[0] / self.input_size + w_factor_resize = img_raw.shape[1] / self.input_size + + img = cv2.resize(img_raw, (self.input_size, self.input_size)) + + for label in new_inst['object']: + + xmin_resized = int(round(label['xmin'] / w_factor_resize)) + xmax_resized = int(round(label['xmax'] / w_factor_resize)) + ymin_resized = int(round(label['ymin'] / h_factor_resize)) + ymax_resized = int(round(label['ymax'] / h_factor_resize)) + + bbox_center_x = .5*(xmin_resized + xmax_resized) * grid_factor + grid_x = int(bbox_center_x) + + bbox_center_y = .5*(ymin_resized + ymax_resized) * grid_factor + grid_y = int(bbox_center_y) + + + obj_indx = YoloParams.CLASS_LABELS.index(label['name']) + + bbox_w = (xmax_resized - xmin_resized) * grid_factor + bbox_h = (ymax_resized - ymin_resized) * grid_factor + + shifted_wh = np.array([0,0,bbox_w, bbox_h]) + + func = lambda prior: 
compute_iou((0,0,prior[0],prior[1]), shifted_wh)
+
+                anchor_winner = np.argmax(np.apply_along_axis(func, -1, self.anchors))
+
+                # assign ground truth x, y, w, h, confidence and class probs to y_batch
+
+                # ASSIGN CLASS CONFIDENCE
+                y_batch[j, grid_y, grid_x, anchor_winner, 0:4] = [bbox_center_x, bbox_center_y, bbox_w, bbox_h]
+                # ASSIGN OBJECTNESS CONF
+                y_batch[j, grid_y, grid_x, anchor_winner, 4  ] = 1.
+
+                # ASSIGN CORRECT CLASS
+                y_batch[j, grid_y, grid_x, anchor_winner, 4+1+obj_indx] = 1
+
+            # number of labels per instance !> than true_box_buffer, add check in processing (?)
+
+            x_batch[j] = img / 255.
+
+        ############################################################
+        # x_batch -> list of input images
+        # y_batch -> list of network output gt values for each image
+        ############################################################
+        return x_batch, y_batch
+
+
+
+
+def _scale_translation(inst, fact):
+
+    height, width = inst['height'], inst['width']
+    # what % of the increased size will
+    # contribute to the offset position
+    pos_fact = fact * np.random.rand()
+    off_x = int(round(pos_fact * width))
+    off_y = int(round(pos_fact * height))
+
+    fields = {
+            'xmin':(off_x, width),
+            'xmax':(off_x, width),
+            'ymin':(off_y, height),
+            'ymax':(off_y, height)}
+
+    final_objs = []
+    for label in inst['object']:
+
+        for coord,v in fields.items():
+            offset, lim = v
+            label[coord] = label[coord] * (1+fact)
+            label[coord] = max(min(int(round(label[coord]-offset)), lim),0)
+
+        # if an object was cropped out by the transform, don't include it
+        # for amin, amax in [('xmin', 'xmax'),('ymin','ymax')]:
+        xcond = label['xmax'] - label['xmin'] > 5
+        ycond = label['ymax'] - label['ymin'] > 5
+
+        if xcond and ycond:
+            final_objs.append(label)
+
+    return off_x, off_y, final_objs
+
+def _exposure_saturation(img, hsvf):
+    sfact = np.random.uniform(1,hsvf)
+    vfact = np.random.uniform(1,hsvf)
+
+    hsv = cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
+
+    s = (hsv[...,1]*sfact).astype(np.int)
+    v = (hsv[...,2]*vfact).astype(np.int)
+
+    hsv[...,1] = np.where(s < 255, s, 255)
+    hsv[...,2] = np.where(v < 255, v, 255)
+
+    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
+
+
+
+def data_augmentation(inst, perc, hsvf, augment):
+    img_raw = cv2.imread(inst['filename']).copy()
+    new_inst = copy.deepcopy(inst)
+
+    if not augment: return img_raw, new_inst
+
+
+    fact = perc * np.random.rand()
+
+    off_x, off_y, objs = _scale_translation(new_inst, fact=fact)
+
+    if len(objs) == 0:
+        return img_raw, inst
+
+    new_inst['object'] = objs
+
+    img_resized = cv2.resize(img_raw, (0,0), fx=(1+fact), fy=(1+fact))
+    # preserve original size after scaling & translating
+    img_off = img_resized[off_y:off_y+img_raw.shape[0], off_x:off_x+img_raw.shape[1]]
+
+    img_final = _exposure_saturation(img_off, hsvf=hsvf)
+
+    return img_final, new_inst
\ No newline at end of file
diff --git a/net/netloss.py b/net/netloss.py
index e7646e3..35b29ae 100644
--- a/net/netloss.py
+++ b/net/netloss.py
@@ -1,8 +1,12 @@
+"""
+Yolo v2 loss function.
+"""
 
-import tensorflow as tf
 import numpy as np
+import tensorflow as tf
 from keras import backend as K
-from net.netparams import YoloParams
+
+from .netparams import YoloParams
 
 
 EPSILON = 1e-7
@@ -10,7 +14,7 @@ def calculate_ious(A1, A2, use_iou=True):
 
     if not use_iou:
-        return 1.
+        return A1[..., 4]
 
     A1_xy = A1[..., 0:2]
     A1_wh = A1[..., 2:4]
 
     A2_xy = A2[..., 0:2]
     A2_wh = A2[..., 2:4]
 
     A1_wh_half = A1_wh / 2.
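+    # corner coordinates from box centers and half-sizes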
- A1_mins = A1_xy - A1_wh_half - A1_maxes = A1_xy + A1_wh_half + A1_mins = A1_xy - A1_wh_half + A1_maxes = A1_xy + A1_wh_half A2_wh_half = A2_wh / 2. A2_mins = A2_xy - A2_wh_half - A2_maxes = A2_xy + A2_wh_half + A2_maxes = A2_xy + A2_wh_half intersect_mins = K.maximum(A2_mins, A1_mins) intersect_maxes = K.minimum(A2_maxes, A1_maxes) - intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.) + intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.) intersect_areas = intersect_wh[..., 0] * intersect_wh[..., 1] true_areas = A1_wh[..., 0] * A1_wh[..., 1] pred_areas = A2_wh[..., 0] * A2_wh[..., 1] union_areas = pred_areas + true_areas - intersect_areas - iou_scores = intersect_areas / union_areas + return intersect_areas / union_areas + - return iou_scores +def _transform_netout(y_pred_raw): + y_pred_xy = K.sigmoid(y_pred_raw[..., :2]) + YoloParams.c_grid + y_pred_wh = K.exp(y_pred_raw[..., 2:4]) * YoloParams.anchors + y_pred_conf = K.sigmoid(y_pred_raw[..., 4:5]) + y_pred_class = y_pred_raw[...,5:] + return K.concatenate([y_pred_xy, y_pred_wh, y_pred_conf, y_pred_class], axis=-1) class YoloLoss(object): - # WARM UP def __init__(self): self.__name__ = 'yolo_loss' - self.iou_threshold = YoloParams.IOU_THRESHOLD - self.readjust_obj_score = True + self.iou_filter = 0.6 + self.readjust_obj_score = False self.lambda_coord = YoloParams.COORD_SCALE self.lambda_noobj = YoloParams.NO_OBJECT_SCALE self.lambda_obj = YoloParams.OBJECT_SCALE self.lambda_class = YoloParams.CLASS_SCALE - self.norm = False - def coord_loss(self, y_true, y_pred): @@ -67,20 +74,16 @@ def coord_loss(self, y_true, y_pred): indicator_coord = K.expand_dims(y_true[..., 4], axis=-1) * self.lambda_coord - norm_coord = 1 - if self.norm: - norm_coord = K.sum(K.cast(indicator_coord > 0.0, np.float32)) - - loss_xy = K.sum(K.square(b_xy - b_xy_pred) * indicator_coord, axis=[1,2,3,4]) - #loss_wh = K.sum(K.square(b_wh - b_wh_pred) * indicator_coord, axis=[1,2,3,4]) - loss_wh = K.sum(K.square(K.sqrt(b_wh) - K.sqrt(b_wh_pred)) * indicator_coord, axis=[1,2,3,4]) + loss_xy = K.sum(K.square(b_xy - b_xy_pred) * indicator_coord)#, axis=[1,2,3,4]) + loss_wh = K.sum(K.square(b_wh - b_wh_pred) * indicator_coord)#, axis=[1,2,3,4]) + #loss_wh = K.sum(K.square(K.sqrt(b_wh) - K.sqrt(b_wh_pred)) * indicator_coord)#, axis=[1,2,3,4]) - return (loss_wh + loss_xy) / (norm_coord + EPSILON) / 2 + return (loss_wh + loss_xy) / 2 def obj_loss(self, y_true, y_pred): - b_o = calculate_ious(y_true, y_pred, use_iou=self.readjust_obj_score) * y_true[..., 4] + b_o = calculate_ious(y_true, y_pred, use_iou=self.readjust_obj_score) b_o_pred = y_pred[..., 4] num_true_labels = YoloParams.GRID_SIZE*YoloParams.GRID_SIZE*YoloParams.NUM_BOUNDING_BOXES @@ -88,18 +91,13 @@ def obj_loss(self, y_true, y_pred): iou_scores_buff = calculate_ious(y_true_p, K.expand_dims(y_pred, axis=4)) best_ious = K.max(iou_scores_buff, axis=4) - indicator_noobj = K.cast(best_ious < self.iou_threshold, np.float32) * (1 - y_true[..., 4]) * self.lambda_noobj + indicator_noobj = K.cast(best_ious < self.iou_filter, np.float32) * (1 - y_true[..., 4]) * self.lambda_noobj indicator_obj = y_true[..., 4] * self.lambda_obj - - - norm_conf = 1 - if self.norm: - norm_conf = K.sum(K.cast((indicator_obj + indicator_noobj) > 0.0, np.float32)) - indicator_o = indicator_obj + indicator_noobj - loss_obj = K.sum(K.square(b_o-b_o_pred) * indicator_o, axis=[1,2,3]) - return loss_obj / (norm_conf + EPSILON) / 2 + loss_obj = K.sum(K.square(b_o-b_o_pred) * indicator_o)#, axis=[1,2,3]) + + return loss_obj / 2 def 
class_loss(self, y_true, y_pred):
@@ -110,43 +108,33 @@ def class_loss(self, y_true, y_pred):
 
         #b_class = K.argmax(y_true[..., 5:], axis=-1)
         #b_class_pred = y_pred[..., 5:]
 
         #loss_class_arg = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=b_class, logits=b_class_pred)
 
         indicator_class = y_true[..., 4] * self.lambda_class
 
-        norm_class = 1
-        if self.norm:
-            norm_class = K.sum(K.cast(indicator_class > 0.0, np.float32))
-
-        loss_class = K.sum(loss_class_arg * indicator_class, axis=[1,2,3])
-
-        return loss_class / (norm_class + EPSILON)
+        loss_class = K.sum(loss_class_arg * indicator_class)#, axis=[1,2,3])
 
+        return loss_class
 
-    def _transform_netout(self, y_pred_raw):
-        y_pred_xy = K.sigmoid(y_pred_raw[..., :2]) + YoloParams.c_grid
-        y_pred_wh = K.exp(y_pred_raw[..., 2:4]) * YoloParams.anchors
-        y_pred_conf = K.sigmoid(y_pred_raw[..., 4:5])
-        y_pred_class = y_pred_raw[...,5:]
-        return K.concatenate([y_pred_xy, y_pred_wh, y_pred_conf, y_pred_class], axis=-1)
 
+    def l_coord(self, y_true, y_pred_raw):
-        return self.coord_loss(y_true, self._transform_netout(y_pred_raw))
+        return self.coord_loss(y_true, _transform_netout(y_pred_raw))
 
     def l_obj(self, y_true, y_pred_raw):
-        return self.obj_loss(y_true, self._transform_netout(y_pred_raw))
+        return self.obj_loss(y_true, _transform_netout(y_pred_raw))
 
     def l_class(self, y_true, y_pred_raw):
-        return self.class_loss(y_true, self._transform_netout(y_pred_raw))
+        return self.class_loss(y_true, _transform_netout(y_pred_raw))
 
     def __call__(self, y_true, y_pred_raw):
 
-        y_pred = self._transform_netout(y_pred_raw)
-
+        y_pred = _transform_netout(y_pred_raw)
+
         total_coord_loss = self.coord_loss(y_true, y_pred)
         total_obj_loss = self.obj_loss(y_true, y_pred)
         total_class_loss = self.class_loss(y_true, y_pred)
diff --git a/net/netparams.py b/net/netparams.py
index 8e2e578..d2321a3 100644
--- a/net/netparams.py
+++ b/net/netparams.py
@@ -119,7 +119,9 @@ class YoloParams(object):
 
     # Model
     IN_MODEL = args.model
-    assert os.path.isfile(IN_MODEL), "Pass valid input keras model."
+    if not os.path.isfile(IN_MODEL):
+        raise ValueError("Pass a valid input Keras model.")
+
     OUT_MODEL_NAME = config['train']['out_model_name']
 
     ARCH_FNAME = config['config_path']['arch_plotname']
diff --git a/net/utils.py b/net/utils.py
index 1466bbb..1ad340a 100755
--- a/net/utils.py
+++ b/net/utils.py
@@ -112,7 +112,8 @@ def space_to_depth_x2(x):
     return tf.space_to_depth(x, block_size=2)
 
 
-def draw_boxes(image, info):
+def draw_boxes(image_in, info):
+    image = image_in.copy()
     image_h, image_w, _ = image.shape
 
     boxes, scores, labels = info
diff --git a/result_plots/drivingsf.gif b/result_plots/drivingsf.gif
index 8308dff..a5dcd4e 100644
Binary files a/result_plots/drivingsf.gif and b/result_plots/drivingsf.gif differ
diff --git a/result_plots/tbexam.png b/result_plots/tbexam.png
index 119dc6c..df7b755 100644
Binary files a/result_plots/tbexam.png and b/result_plots/tbexam.png differ
diff --git a/split_dataset.py b/split_dataset.py
index 400d67d..dbe9c5b 100644
--- a/split_dataset.py
+++ b/split_dataset.py
@@ -6,7 +6,7 @@ import argparse
 
 from tqdm import tqdm
 
-from utils import mkdir_p
+from net.utils import mkdir_p
 
 
 argparser = argparse.ArgumentParser(
@@ -40,15 +40,13 @@ def sample_from_dir(paths, train_p):
     img_path, ann_path, out_path = paths
 
     imgs = os.listdir(img_path)
-
+    total_num = len(imgs)
 
-    train_num = int(len(imgs)*train_p)
+    train_num = int(len(imgs)*float(train_p))
 
     img_fmt = '.' 
+ imgs[0].split('.')[1] fns = [im.split('.')[0] for im in imgs] - shuffle(fns) - fn_train = fns[:train_num] fn_val = fns[train_num:] diff --git a/yolov2.py b/yolov2.py index 7c7d2c0..7e122b1 100644 --- a/yolov2.py +++ b/yolov2.py @@ -1,5 +1,4 @@ - -import pickle, argparse, json, cv2, os +import cv2, os import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt @@ -11,17 +10,18 @@ from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard from keras.optimizers import SGD, Adam, RMSprop -from net.utils import parse_annotation, mkdir_p, \ -setup_logging, draw_boxes, generate_gif - from net.netparams import YoloParams -from net.netloss import YoloLoss -from net.neteval import YoloDataGenerator, YoloEvaluate, \ -YoloTensorBoard, Callback_MAP, yolo_recall + +from net.utils import parse_annotation, mkdir_p, setup_logging, draw_boxes, generate_gif from net.netarch import YoloArchitecture, YoloInferenceModel +from net.netloss import YoloLoss +from net.neteval import YoloEvaluate, YoloTensorBoard, Callback_MAP, yolo_recall +from net.netgen import YoloDataGenerator +CAM_WIDTH = 1038 +CAM_HEIGHT = 576 class YoloV2(object): @@ -85,7 +85,7 @@ def inference(self, path): plt.figure(figsize=(10,10)) boxes, scores, _, labels = self.inf_model.predict(image.copy()) - #print(f, labels) + print(f, list(zip(labels, scores))) image = draw_boxes(image, (boxes, scores, labels)) out_name = os.path.join(out_path, os.path.basename(f).split('.')[0] + out_fname_mod) cv2.imwrite(out_name, image) @@ -94,10 +94,17 @@ def inference(self, path): def _video_params(self, name): cap = cv2.VideoCapture(name) - video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + if name == 0: + cap.set(cv2.CAP_PROP_FRAME_HEIGHT, CAM_HEIGHT) + cap.set(cv2.CAP_PROP_FRAME_WIDTH, CAM_WIDTH) + size = (CAM_WIDTH, CAM_HEIGHT) + else: + video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + size = (video_width, video_height) + video_len = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - size = (video_width, video_height) + fps = round(cap.get(cv2.CAP_PROP_FPS)) return cap, size, video_len, fps @@ -134,7 +141,7 @@ def cam_inference(self, fname): cap, size, _, fps = self._video_params(0) fourcc = cv2.VideoWriter_fourcc('m','p','4','v') - if fname: writer = cv2.VideoWriter(fname, fourcc, 40, size) + if fname: writer = cv2.VideoWriter(fname, fourcc, fps, size) while(cap.isOpened()): @@ -144,9 +151,9 @@ def cam_inference(self, fname): boxes, scores, _, labels = self.inf_model.predict(frame) frame_pred = draw_boxes(frame, (boxes, scores, labels)) - if fname: writer.write(frame) + if fname: writer.write(frame_pred) - cv2.imshow('Yolo Output',frame) + cv2.imshow('Yolo Output',frame_pred) if cv2.waitKey(1) & 0xFF == ord('q'): break @@ -168,10 +175,12 @@ def validation(self): yolo_eval = YoloEvaluate(generator=generator, model=self.inf_model) AP = yolo_eval.comp_map() - mAP_values = [] - for class_label, ap in AP.items(): + + _AP_items = [[class_label, ap] for class_label, ap in AP.items()] + AP_items = sorted(_AP_items, key=lambda x: x[1], reverse=True) + + for class_label, ap in AP_items: print("AP( %s ): %.3f"%(class_label, ap)) - mAP_values.append( ap ) # Store AP results as csv #df_ap = pd.DataFrame.from_dict(AP, orient='index') @@ -179,7 +188,7 @@ def validation(self): #df_ap.to_csv('validation_maP.csv', header=False) print('-------------------------------') - print("mAP: %.3f"%(np.mean(mAP_values))) + print("mAP: 
%.3f"%(np.mean(list(AP.values())))) return AP @@ -191,13 +200,12 @@ def training(self): valid_data = parse_annotation( YoloParams.VALIDATION_ANN_PATH, YoloParams.VALIDATION_IMG_PATH) - train_gen = YoloDataGenerator(train_data, shuffle=True) - valid_gen = YoloDataGenerator(valid_data, shuffle=True) - - + train_gen = YoloDataGenerator(train_data, shuffle=True, augment=True) + valid_gen = YoloDataGenerator(valid_data, shuffle=True) + early_stop = EarlyStopping(monitor='val_loss', min_delta=0.001, - patience=3, + patience=7, mode='min', verbose=1) @@ -220,13 +228,20 @@ def training(self): write_graph=True, write_images=False) + """ + optimizer = SGD( + lr=YoloParams.L_RATE, + momentum=0.9, + decay=0.0005 + ) + """ + optimizer = Adam( lr=YoloParams.L_RATE, beta_1=0.9, beta_2=0.999, - epsilon=1e-08, decay=0.0) - + map_cbck = Callback_MAP(generator=valid_gen, @@ -234,7 +249,7 @@ def training(self): tensorboard=tensorboard) - # add metrics.. + yolo_recall.__name__ = 'recall' metrics = [