Skip to content

Commit

Permalink
Bounding Box prior for custom dataset generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ksanjeevan committed Jul 18, 2018
1 parent 0b29d94 commit 7c40910
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 66 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,14 @@ Script to generate training/testing splits.
`python3 split_dataset.py -p 0.75 --in_ann VOC2012/Annotations/ --in_img VOC2012/JPEGImages/ --output ~/Documents/DATA/VOC`


##### Anchor Generation

Running:

`python3 dourflow.py genp -c config.json`

This will store the bounding box priors next to the anchors file given in the config under **config['config_path']['anchors']**, using the same file name with the prefix 'custom_' (so as not to accidentally overwrite the existing anchors).

##### Tensorboard

Training will create directory **logs/** which will store metrics and checkpoints for all the different training runs.
Expand All @@ -207,7 +215,7 @@ Then, in another terminal tab you can run `tensorboard --logdir=logs/run_X` and
#### To Do

- [ ] cfg parser
- [ ] Anchor generation for custom datasets
- [x] Anchor generation for custom datasets
- [ ] mAP write up
- [x] Add webcam support
- [ ] Data Augmentation
Expand Down
5 changes: 2 additions & 3 deletions dourflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from yolov2 import YoloV2, YoloInferenceModel
import os

#from net.neteval import gen_anchors
from kmeans_anchors import gen_anchors


# Add CPU option
Expand All @@ -17,8 +17,7 @@
if YoloParams.WEIGHT_FILE:
generate_model()
elif YoloParams.GEN_ANCHORS_PATH:
pass
#gen_anchors(YoloParams.GEN_ANCHORS_PATH)
gen_anchors(YoloParams.GEN_ANCHORS_PATH)
else:
YoloV2().run()

Expand Down
188 changes: 188 additions & 0 deletions kmeans_anchors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@

from net.netparams import YoloParams
from net.utils import compute_iou, parse_annotation
import numpy as np
from scipy.spatial.distance import cdist

# Number of clusters (one bounding box prior per centroid) that gen_anchors()
# asks KMeans for.
# See https://arxiv.org/abs/1612.08242
NUM_CENTROIDS = 5


def weighted_choice(choices):
    """Draw a random index, where index ``i`` is weighted by ``choices[i]``.

    A threshold is sampled uniformly between 0 and the sum of the weights;
    the result is the first index whose running total of weights reaches
    that threshold. Index 0 is returned if no running total reaches it.
    """
    threshold = np.random.uniform(0, np.sum(choices))
    running_totals = np.cumsum(choices)
    reached = np.flatnonzero(running_totals >= threshold)
    if reached.size == 0:
        return 0
    return int(reached[0])

class KMeans:
    """K-means clustering with k-means++ seeding and a pluggable distance.

    The distance defaults to ``IoU_dist``, which reads the first two entries
    of each row, so the data passed to :meth:`fit` is expected to be a 2-D
    array with one sample per row.
    """

    def __init__(self, k):
        """Create a clusterer that partitions data into ``k`` clusters."""
        self.k = k
        # Convergence: stop once at most this many points change cluster
        # between two consecutive iterations.
        self.diff_thresh = 1
        # Any callable ``f(x, c) -> float`` can be used instead, e.g. squared
        # Euclidean distance: lambda x, y: (x[0]-y[0])**2 + (x[1]-y[1])**2
        self.distf = IoU_dist

    def fit(self, data):
        """Cluster ``data`` and return ``(centroids, clusters)``.

        ``centroids`` has one row per cluster; ``clusters`` holds, for every
        row of ``data``, the index of the centroid it was assigned to. Both
        are also stored on the instance.
        """
        initial_centroids = self.init_centroids_kpp(data)

        self.centroids, self.clusters = self.cluster_data(data, initial_centroids)
        return self.centroids, self.clusters

    def init_centroids_kpp(self, data):
        """Pick ``k`` starting centroids from ``data`` (k-means++ seeding).

        The first centroid is a uniformly random row; each further one is
        drawn with probability proportional to the squared distance to the
        nearest centroid chosen so far.
        """
        centroids = [data[np.random.randint(len(data))]]

        while len(centroids) < self.k:
            # One cdist call evaluates every point-to-centroid distance,
            # instead of a Python-level call per point and per centroid.
            dists = cdist(data, np.asarray(centroids), self.distf)
            nearest = np.min(dists, axis=1)
            prob_array = nearest * nearest
            prob_array /= (np.sum(prob_array) + 1e-8)

            centroids.append(data[weighted_choice(prob_array)])

        return np.array(centroids)

    def mindist2(self, x, centroids):
        """Return the squared distance from ``x`` to its nearest centroid."""
        nearest = np.min([self.distf(x, c) for c in np.asarray(centroids)])
        return nearest * nearest

    def cluster_data(self, data, initial_centroids):
        """Alternate assignment and centroid update until assignments settle.

        Iteration stops when the number of points that changed cluster since
        the previous iteration is at most ``self.diff_thresh``.

        Returns ``(centroids, clusters)``.
        """
        centroids = initial_centroids
        clusters = None
        counter = 0
        while True:
            old_clusters = clusters

            clusters = self.clusterfy(data, centroids)
            # Pass the current centroids so an emptied cluster keeps its
            # position instead of turning into NaN.
            centroids = self.recalc_centroids(data, clusters, centroids)

            # There is nothing to compare against on the first pass.
            if old_clusters is not None:
                num_diffs = np.sum(old_clusters != clusters)
                print("Iteration = %d, Delta = %d"%(counter, num_diffs), flush=True)

                if num_diffs <= self.diff_thresh:
                    break
            counter += 1

        return centroids, clusters

    def clusterfy(self, data, centroids):
        """Return, for each row of ``data``, the index of its nearest centroid."""
        return np.argmin(cdist(data, centroids, self.distf), axis=1)

    def recalc_centroids(self, data, clusters, old_centroids=None):
        """Return the mean of the members of each of the ``k`` clusters.

        Args:
            data: 2-D array, one sample per row.
            clusters: per-row cluster index, as produced by ``clusterfy``.
            old_centroids: optional previous centroids. When a cluster has
                no members its previous centroid is reused; without this
                argument an empty cluster yields a row of NaN.
        """
        labels = np.asarray(clusters)
        new_centroids = []

        for centroid_index in range(self.k):
            members = data[labels == centroid_index]
            if len(members) > 0:
                new_centroids.append(np.mean(members, axis=0))
            elif old_centroids is not None:
                # Taking the mean of an empty slice would give NaN and
                # poison every later distance computation.
                new_centroids.append(
                    np.asarray(old_centroids[centroid_index], dtype=float))
            else:
                new_centroids.append(np.full(data.shape[1:], np.nan))

        return np.array(new_centroids)


def IoU_dist(x, c):
    """Return ``1 - compute_iou(...)`` for a sample ``x`` and a centroid ``c``.

    Each argument is turned into a 4-element box whose first two entries are
    0 and whose last two entries are the argument's first two values, so the
    boxes are compared by their extents only.
    """
    sample_box = [0, 0, x[0], x[1]]
    centroid_box = [0, 0, c[0], c[1]]
    overlap = compute_iou(sample_box, centroid_box)
    return 1. - overlap




def exrtract_wh(img):
    """Return ``[width, height]`` in grid units for every object of ``img``.

    ``img`` is a dict with pixel ``'height'`` and ``'width'`` and a list of
    ``'object'`` dicts carrying ``xmin``/``xmax``/``ymin``/``ymax``. Box
    sizes are rescaled from pixels by ``YoloParams.GRID_SIZE`` divided by
    the image size along the matching axis.
    """
    pixel_height, pixel_width = img['height'], img['width']

    # Conversion factors from pixels to grid units, one per axis.
    scale_h = YoloParams.GRID_SIZE / pixel_height
    scale_w = YoloParams.GRID_SIZE / pixel_width

    return [
        [(box['xmax'] - box['xmin']) * scale_w,
         (box['ymax'] - box['ymin']) * scale_h]
        for box in img['object']
    ]

def gen_anchors(fname):
    """Cluster the training-set box sizes into anchors and write them to disk.

    Box sizes of every annotated training object are clustered into
    ``NUM_CENTROIDS`` groups; the flattened centroids are written to
    ``fname`` as a single ``', '``-separated line with five decimals each,
    and are also printed.

    Args:
        fname: output path. The bare prefix ``'custom_'`` is replaced by
            ``'custom_anchors.txt'``.

    Returns:
        The flattened centroid coordinates as a list.

    Raises:
        ValueError: if the training annotations contain no objects.
    """
    imgs = parse_annotation(YoloParams.TRAIN_ANN_PATH, YoloParams.TRAIN_IMG_PATH)

    data_wh = []
    for img in imgs:
        data_wh += exrtract_wh(img)

    # Fail with a clear message instead of an obscure error from the
    # random seeding step inside KMeans.
    if not data_wh:
        raise ValueError(
            "No bounding boxes found in the training annotations; "
            "cannot generate anchors.")

    clustering = KMeans(NUM_CENTROIDS)

    centroids, _ = clustering.fit(np.array(data_wh))
    anchors = list(centroids.flatten())

    anchors_text = ", ".join("%.5f" % a for a in anchors)

    fname = fname if fname != 'custom_' else 'custom_anchors.txt'

    with open(fname, 'w') as f:
        f.write(anchors_text)

    print("\nAnchors: \n")
    print(anchors_text)
    print("\n\tStored at: %s\n" % fname)

    return anchors


def test():
    """Cluster three synthetic 2-D Gaussian blobs and save a plot to 'test.png'.

    Each cluster is drawn in its own colour and the centroids in black.
    """
    import matplotlib.pyplot as plt

    # (mean, covariance, number of samples) for each blob, sampled in order.
    blob_specs = [
        ([0,0], [[5,0],[0,5]], 1000),
        ([0,10], [[5,0],[0,3]], 700),
        ([10,0], [[2,0],[0,5]], 900),
    ]
    blobs = [np.random.multivariate_normal(mean, cov, size=count)
             for mean, cov, count in blob_specs]
    data = np.concatenate(blobs, axis=0)

    centroids, clusters = KMeans(3).fit(data)
    colors = ['c', 'g', 'r']

    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111)

    for k, _ in enumerate(centroids):
        member_rows = np.where(clusters==k)[0]
        xs, ys = data[member_rows].T
        ax.scatter(xs, ys, color=colors[k])

    cx, cy = centroids.T
    ax.scatter(cx, cy, color='k')

    ax.set_title('Test')

    fig.savefig('test.png', format='png')
    plt.close()




if __name__ == '__main__':
    # gen_anchors() requires the output path; calling it with no argument
    # raised TypeError. Use the path YoloParams derives from the config, the
    # same value dourflow.py passes.
    # NOTE(review): assumes YoloParams.GEN_ANCHORS_PATH is always defined and
    # falsy when unset -- confirm against net/netparams.py.
    gen_anchors(YoloParams.GEN_ANCHORS_PATH or 'custom_anchors.txt')

31 changes: 1 addition & 30 deletions net/neteval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,11 @@
import cv2, os
import keras
from net.utils import draw_boxes, compute_iou, mkdir_p, \
mkdir_p, handle_empty_indexing, parse_annotation
mkdir_p, handle_empty_indexing

from keras import backend as K


'''
def exrtract_wh(img):
result = []
pixel_height = img['height']
pixel_width = img['width']
fact_pixel_grid_h = YoloParams.GRID_SIZE / pixel_height
fact_pixel_grid_w = YoloParams.GRID_SIZE / pixel_width
for obj in img['object']:
grid_h = (obj['ymax'] - obj['ymin']) * fact_pixel_grid_h
grid_w = (obj['xmax'] - obj['xmin']) * fact_pixel_grid_w
result.append( np.array(grid_h, grid_w) )
return result
def gen_anchors(fname):
imgs, _ = parse_annotation(ann_dir, img_dir)
data_wh = []
for img in imgs:
data_wh += exrtract_wh(img)
c = AgglomerativeClustering(self.num_clusters, affinity='precomputed', linkage=self.c_type)
'''




Expand Down
49 changes: 26 additions & 23 deletions net/netparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ class YoloParams(object):
WEBCAM_OUT = 'cam_out.mp4'
YOLO_MODE = 'cam'
elif action in ['genp', 'generate_priors']:
GEN_ANCHORS_PATH = 'new_anchors.png'
current_anchors_path = config['config_path']['anchors']
GEN_ANCHORS_PATH = os.path.join(os.path.dirname(current_anchors_path),
'custom_'+os.path.basename(current_anchors_path))
YOLO_MODE = 'genp'
else:
if action in ['validate', 'train', 'cam']:
Expand Down Expand Up @@ -123,26 +125,27 @@ class YoloParams(object):
INPUT_SIZE = config['model']['input_size']
GRID_SIZE = config['model']['grid_size']
TRUE_BOX_BUFFER = config['model']['true_box_buffer']
ANCHORS = [float(a) for a in open(config['config_path']['anchors']).read().split(', ')]

NUM_BOUNDING_BOXES = len(ANCHORS) // 2
OBJECT_SCALE = 5.0
NO_OBJECT_SCALE = 1.0
CLASS_SCALE = 1.0
COORD_SCALE = 1.0

# Train params
BATCH_SIZE = config['train']['batch_size']
L_RATE = config['train']['learning_rate']
NUM_EPOCHS = config['train']['num_epochs']
TRAIN_VERBOSE = config['train']['verbose']

# Thresholding
IOU_THRESHOLD = get_threshold(config['model']['iou_threshold'])
NMS_THRESHOLD = get_threshold(config['model']['nms_threshold'])
DETECTION_THRESHOLD = get_threshold(args.threshold)

# Additional / Precomputing
c_grid = generate_yolo_grid(BATCH_SIZE, GRID_SIZE, NUM_BOUNDING_BOXES)
anchors = np.reshape(ANCHORS, [1,1,1,NUM_BOUNDING_BOXES,2])

if config['config_path']['anchors']:
ANCHORS = [float(a) for a in open(config['config_path']['anchors']).read().split(', ')]
NUM_BOUNDING_BOXES = len(ANCHORS) // 2
OBJECT_SCALE = 5.0
NO_OBJECT_SCALE = 1.0
CLASS_SCALE = 1.0
COORD_SCALE = 1.0

# Train params
BATCH_SIZE = config['train']['batch_size']
L_RATE = config['train']['learning_rate']
NUM_EPOCHS = config['train']['num_epochs']
TRAIN_VERBOSE = config['train']['verbose']

# Thresholding
IOU_THRESHOLD = get_threshold(config['model']['iou_threshold'])
NMS_THRESHOLD = get_threshold(config['model']['nms_threshold'])
DETECTION_THRESHOLD = get_threshold(args.threshold)

# Additional / Precomputing
c_grid = generate_yolo_grid(BATCH_SIZE, GRID_SIZE, NUM_BOUNDING_BOXES)
anchors = np.reshape(ANCHORS, [1,1,1,NUM_BOUNDING_BOXES,2])

9 changes: 3 additions & 6 deletions net/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,7 @@ def parse_annotation(ann_dir, img_dir, labels=[]):
}
"""
# seen_labels: {'classname': count}
return all_imgs, seen_labels





return all_imgs


def setup_logging(logging_path='logs'):
Expand All @@ -227,6 +222,8 @@ def setup_logging(logging_path='logs'):
return run_path




def handle_empty_indexing(arr, idx):
if idx.size > 0:
return arr[idx]
Expand Down
6 changes: 3 additions & 3 deletions yolov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def cam_inference(self, fname):

def validation(self):

valid_data, _ = parse_annotation(
valid_data = parse_annotation(
YoloParams.VALIDATION_ANN_PATH, YoloParams.VALIDATION_IMG_PATH)

generator = YoloDataGenerator(valid_data, shuffle=True)
Expand All @@ -183,9 +183,9 @@ def validation(self):

def training(self):

train_data, _ = parse_annotation(
train_data = parse_annotation(
YoloParams.TRAIN_ANN_PATH, YoloParams.TRAIN_IMG_PATH)
valid_data, _ = parse_annotation(
valid_data = parse_annotation(
YoloParams.VALIDATION_ANN_PATH, YoloParams.VALIDATION_IMG_PATH)

train_gen = YoloDataGenerator(train_data, shuffle=True)
Expand Down

0 comments on commit 7c40910

Please sign in to comment.