diff --git a/README.md b/README.md index 2622aec..d3605ea 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ # dourflow: Keras implementation of YOLO v2 +
+ +
+ + **dourflow** is a keras/tensorflow implementation of the state-of-the-art object detection system [You only look once](https://pjreddie.com/darknet/yolo/). -- Original paper: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242) -- Github repo: [Darknet](https://github.com/pjreddie/darknet) +Original paper and github: [YOLO9000: Better, Faster, Stronger](https://arxiv.org/abs/1612.08242) & [Darknet](https://github.com/pjreddie/darknet). - -- -
### Dependancies --- @@ -17,7 +17,7 @@ - [numpy](http://www.numpy.org/) - [h5py](http://www.h5py.org/) - [opencv](https://pypi.org/project/opencv-python/) -- [python 3](https://www.python.org/) +- [python3.5](https://www.python.org/) ### Usage --- @@ -131,7 +131,7 @@ Example: ```bash python3 dourflow.py validate -m voc_model.h5 -c voc_config.json ``` -Terminal output: +Output: ```bash Batch Processed: 100%|████████████████████████████████████████████| 4282/4282 [01:53<00:00, 37.84it/s] AP( bus ): 0.806 diff --git a/net/netloss.py b/net/netloss.py index 5f33079..93c145f 100644 --- a/net/netloss.py +++ b/net/netloss.py @@ -1,10 +1,10 @@ import tensorflow as tf import numpy as np - +from keras import backend as K from net.netparams import YoloParams -EPSILON = 1e-6 +EPSILON = 1e-7 def calculate_ious(A1, A2, use_iou=True): @@ -26,22 +26,22 @@ def calculate_ious(A1, A2, use_iou=True): A2_mins = A2_xy - A2_wh_half A2_maxes = A2_xy + A2_wh_half - intersect_mins = tf.maximum(A2_mins, A1_mins) - intersect_maxes = tf.minimum(A2_maxes, A1_maxes) - intersect_wh = tf.maximum(intersect_maxes - intersect_mins, 0.) + intersect_mins = K.maximum(A2_mins, A1_mins) + intersect_maxes = K.minimum(A2_maxes, A1_maxes) + intersect_wh = K.maximum(intersect_maxes - intersect_mins, 0.) 
intersect_areas = intersect_wh[..., 0] * intersect_wh[..., 1] true_areas = A1_wh[..., 0] * A1_wh[..., 1] pred_areas = A2_wh[..., 0] * A2_wh[..., 1] union_areas = pred_areas + true_areas - intersect_areas - iou_scores = tf.truediv(intersect_areas, union_areas) + iou_scores = intersect_areas / union_areas return iou_scores class YoloLoss(object): - # ADD WARM UP CONDITIONS + # WARM UP def __init__(self): @@ -64,16 +64,16 @@ def coord_loss(self, y_true, y_pred): b_xy = y_true[..., 0:2] b_wh = y_true[..., 2:4] - indicator_coord = tf.expand_dims(y_true[..., 4], axis=-1) * self.lambda_coord + indicator_coord = K.expand_dims(y_true[..., 4], axis=-1) * self.lambda_coord norm_coord = 1 if self.norm: - norm_coord = tf.reduce_sum(tf.to_float(indicator_coord > 0.0)) + norm_coord = K.sum(K.cast(indicator_coord > 0.0, np.float32)) - loss_xy = tf.reduce_sum(tf.square(b_xy - b_xy_pred) * indicator_coord, axis=[1,2,3,4]) - #loss_wh = tf.reduce_sum(tf.square(b_wh - b_wh_pred) * indicator_coord, axis=[1,2,3,4]) - loss_wh = tf.reduce_sum(tf.square(tf.sqrt(b_wh) - tf.sqrt(b_wh_pred)) * indicator_coord, axis=[1,2,3,4]) + loss_xy = K.sum(K.square(b_xy - b_xy_pred) * indicator_coord, axis=[1,2,3,4]) + #loss_wh = K.sum(K.square(b_wh - b_wh_pred) * indicator_coord, axis=[1,2,3,4]) + loss_wh = K.sum(K.square(K.sqrt(b_wh) - K.sqrt(b_wh_pred)) * indicator_coord, axis=[1,2,3,4]) return (loss_wh + loss_xy) / (norm_coord + EPSILON) / 2 @@ -84,50 +84,50 @@ def obj_loss(self, y_true, y_pred): b_o_pred = y_pred[..., 4] num_true_labels = YoloParams.GRID_SIZE*YoloParams.GRID_SIZE*YoloParams.NUM_BOUNDING_BOXES - y_true_p = tf.reshape(y_true[..., :4], shape=(YoloParams.BATCH_SIZE, 1, 1, 1, num_true_labels, 4)) - iou_scores_buff = calculate_ious(y_true_p, tf.expand_dims(y_pred, axis=4)) - best_ious = tf.reduce_max(iou_scores_buff, axis=4) + y_true_p = K.reshape(y_true[..., :4], shape=(YoloParams.BATCH_SIZE, 1, 1, 1, num_true_labels, 4)) + iou_scores_buff = calculate_ious(y_true_p, K.expand_dims(y_pred, 
axis=4)) + best_ious = K.max(iou_scores_buff, axis=4) - indicator_noobj = tf.to_float(best_ious < self.iou_threshold) * (1 - y_true[..., 4]) * self.lambda_noobj + indicator_noobj = K.cast(best_ious < self.iou_threshold, np.float32) * (1 - y_true[..., 4]) * self.lambda_noobj indicator_obj = y_true[..., 4] * self.lambda_obj norm_conf = 1 if self.norm: - norm_conf = tf.reduce_sum(tf.to_float((indicator_obj + indicator_noobj) > 0.0)) + norm_conf = K.sum(K.cast((indicator_obj + indicator_noobj) > 0.0, np.float32)) - loss_obj = tf.reduce_sum(tf.square(b_o-b_o_pred) * (indicator_obj + indicator_noobj), axis=[1,2,3]) + loss_obj = K.sum(K.square(b_o-b_o_pred) * (indicator_obj + indicator_noobj), axis=[1,2,3]) return loss_obj / (norm_conf + EPSILON) / 2 def class_loss(self, y_true, y_pred): - b_class = tf.argmax(y_true[..., 5:], axis=-1) + b_class = K.argmax(y_true[..., 5:], axis=-1) b_class_pred = y_pred[..., 5:] - indicator_class = y_true[..., 4] * tf.gather( + indicator_class = y_true[..., 4] * K.gather( YoloParams.CLASS_WEIGHTS, b_class) * self.lambda_class norm_class = 1 if self.norm: - norm_class = tf.reduce_sum(tf.to_float(indicator_class > 0.0)) + norm_class = K.sum(K.cast(indicator_class > 0.0, np.float32)) loss_class_arg = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=b_class, logits=b_class_pred) - loss_class = tf.reduce_sum(loss_class_arg * indicator_class, axis=[1,2,3]) + loss_class = K.sum(loss_class_arg * indicator_class, axis=[1,2,3]) return loss_class / (norm_class + EPSILON) def _transform_netout(self, y_pred_raw): - y_pred_xy = tf.sigmoid(y_pred_raw[..., :2]) + YoloParams.c_grid - y_pred_wh = tf.exp(y_pred_raw[..., 2:4]) * YoloParams.anchors - y_pred_conf = tf.sigmoid(y_pred_raw[..., 4:5]) + y_pred_xy = K.sigmoid(y_pred_raw[..., :2]) + YoloParams.c_grid + y_pred_wh = K.exp(y_pred_raw[..., 2:4]) * YoloParams.anchors + y_pred_conf = K.sigmoid(y_pred_raw[..., 4:5]) y_pred_class = y_pred_raw[...,5:] - return tf.concat([y_pred_xy, y_pred_wh, 
y_pred_conf, y_pred_class], axis=-1) + return K.concatenate([y_pred_xy, y_pred_wh, y_pred_conf, y_pred_class], axis=-1) @@ -149,8 +149,6 @@ def __call__(self, y_true, y_pred_raw): return loss - - if __name__ == '__main__': sess = tf.InteractiveSession() @@ -160,26 +158,4 @@ def __call__(self, y_true, y_pred_raw): var = YoloLoss() - print( var(y_true, y_pred).eval() ) - - - - - - - - - - - - - - - - - - - - - - + print( var(y_true, y_pred).eval() ) \ No newline at end of file diff --git a/result_plots/the_office.png b/result_plots/the_office.png deleted file mode 100644 index e6204ee..0000000 Binary files a/result_plots/the_office.png and /dev/null differ