Commit 19d2b36
Bryan Plummer committed on Feb 3, 2018
1 parent: 544f5b2
Showing 6 changed files with 606 additions and 2 deletions.

@@ -1,2 +1,41 @@
# cite
Implementation for our paper "Conditional Image-Text Embedding Networks" (coming soon)
# Conditional Image-Text Embedding Networks

**cite** contains a Tensorflow implementation for our [paper](https://arxiv.org/abs/1711.08389). If you find this code useful in your research, please consider citing:

    @article{plummerCITE2017,
        Author = {Bryan A. Plummer and Paige Kordas and M. Hadi Kiapour and Shuai Zheng and Robinson Piramuthu and Svetlana Lazebnik},
        Title = {Conditional Image-Text Embedding Networks},
        Journal = {arXiv:1711.08389},
        Year = {2017}
    }

This code was tested on an Ubuntu 16.04 system using Tensorflow 1.2.1.

### Phrase Localization Evaluation Demo
After you download our precomputed features/model, you can test it using:

```Shell
python main.py --test --spatial --name runs/cite_spatial_k4/model_best
```

### Training New Models
Our code contains everything required to train or test models using precomputed features. You can train a new model on Flickr30K Entities using:

```Shell
python main.py --name <name of experiment>
```

When training completes, it outputs the localization accuracy of the best model on the test and validation sets. You can see a listing and description of the many tunable parameters with:

```Shell
python main.py --help
```

### Precomputed Features

Along with our example data processing script in `data_processing_example`, you can download our precomputed (PASCAL) features for the Flickr30K Entities dataset [here](https://drive.google.com/file/d/1m5DQ3kh2rCkPremgM91chQgJYZxnEbZw/view?usp=sharing) (52G). Unpack the features in a folder named `data` or update the path in the data loader class.

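After unpacking, a quick sanity check that the archive landed where the data loader expects it can look like the following (a minimal sketch, assuming `h5py` is installed and the `data/flickr` layout used by the data loader):

```Python
# Minimal sanity check of the unpacked precomputed features; the dataset
# keys follow the layout described in data_processing_example.
import h5py

with h5py.File('data/flickr/train_imfeats.h5', 'r') as f:
    print(f['phrase_features'].shape)  # (#num_phrase, 6000)
    print(len(f['phrases']))           # one string per phrase feature row
    print(len(f['pairs'][0]))          # number of (image, phrase, id) pairs
```
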
Our best CITE model on Flickr30K Entities using these precomputed features can be found [here](https://drive.google.com/open?id=1rmeIqYTCIduNc2QWUEdXLHFGrlOzz2xO).

Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for access to their Similarity Network code that was used as the basis for this implementation.

@@ -0,0 +1,123 @@
import numpy as np
import h5py
import os


class DataLoader:
    """Constructs minibatches from data stored on disk in HDF5 format."""
    def __init__(self, args, region_dim, phrase_dim, plh, split):
        """Constructor

        Arguments:
        args -- command line arguments passed into the main function
        region_dim -- dimensionality of the region features
        phrase_dim -- dimensionality of the phrase features
        plh -- placeholder dictionary containing the tensor inputs
        split -- the data split (i.e. 'train', 'test', 'val')
        """
        datafn = os.path.join('data', 'flickr', '%s_imfeats.h5' % split)
        self.data = h5py.File(datafn, 'r')
        vecs = np.array(self.data['phrase_features'], np.float32)
        uniquePhrases = list(self.data['phrases'])
        assert(vecs.shape[0] == len(uniquePhrases))

        # mapping from each unique phrase to its precomputed feature vector
        w2v_dict = {}
        for index, phrase in enumerate(uniquePhrases):
            w2v_dict[phrase] = vecs[index, :]

        self.w2v_dict = w2v_dict
        self.pairs = list(self.data['pairs'])
        self.n_pairs = len(self.pairs[0])
        # wrap in list() so np.random.shuffle also works under Python 3
        self.pair_index = list(range(self.n_pairs))

        self.uniquePhrases = uniquePhrases  # set of unique phrases

        self.split = split
        self.plh = plh
        self.is_train = split == 'train'
        self.neg_to_pos_ratio = args.neg_to_pos_ratio
        self.batch_size = args.batch_size
        self.max_boxes = args.max_boxes
        if self.is_train:
            self.success_thresh = args.train_success_thresh
        else:
            self.success_thresh = args.test_success_thresh

        self.region_feature_dim = region_dim
        self.phrase_feature_dim = phrase_dim

    def __len__(self):
        return self.n_pairs

    def shuffle(self):
        """Shuffles the order in which pairs are sampled."""
        np.random.shuffle(self.pair_index)

    def num_batches(self):
        return int(np.ceil(float(len(self)) / self.batch_size))

    def get_batch(self, batch_id):
        """Returns a minibatch given a valid id for it

        Arguments:
        batch_id -- integer in the range [0, self.num_batches())

        Returns:
        feed_dict -- dictionary containing minibatch data
        gt_labels -- indicates positive/negative regions
        num_pairs -- number of pairs without padding
        """
        region_features = np.zeros((self.batch_size, self.max_boxes,
                                    self.region_feature_dim), dtype=np.float32)
        start_pair = batch_id * self.batch_size
        end_pair = min(start_pair + self.batch_size, len(self))
        num_pairs = end_pair - start_pair
        gt_labels = np.zeros((self.batch_size, self.max_boxes),
                             dtype=np.float32)
        phrase_features = np.zeros((self.batch_size, self.phrase_feature_dim),
                                   dtype=np.float32)
        for pair_id in range(num_pairs):
            sample_id = self.pair_index[start_pair + pair_id]

            # paired image
            im_id = self.pairs[0][sample_id]

            # paired phrase
            phrase = self.pairs[1][sample_id]

            # phrase instance identifier
            p_id = self.pairs[2][sample_id]

            # gets region features
            features = np.array(self.data[im_id], np.float32)
            num_boxes = min(len(features), self.max_boxes)
            features = features[:num_boxes, :self.region_feature_dim]
            overlaps = np.array(self.data['%s_%s_%s' % (im_id, phrase, p_id)])

            # the last 4 dimensions of overlaps are the ground truth box coordinates
            assert(num_boxes <= len(overlaps) - 4)
            overlaps = overlaps[:num_boxes]
            region_features[pair_id, :num_boxes, :] = features
            phrase_features[pair_id, :] = self.w2v_dict[phrase]
            gt_labels[pair_id, :num_boxes] = overlaps >= self.success_thresh
            if self.is_train:
                num_pos = int(np.sum(gt_labels[pair_id, :]))
                num_neg = num_pos * self.neg_to_pos_ratio
                negs = np.random.permutation(np.where(overlaps < 0.3)[0])

                # relax the overlap threshold if there are not enough negatives
                if len(negs) < num_neg:
                    negs = np.random.permutation(np.where(overlaps < 0.4)[0])

                # the logistic loss counts only regions labeled -1 as negatives
                gt_labels[pair_id, negs[:num_neg]] = -1

        feed_dict = {self.plh['phrase'] : phrase_features,
                     self.plh['region'] : region_features,
                     self.plh['train_phase'] : self.is_train,
                     self.plh['num_boxes'] : self.max_boxes,
                     self.plh['labels'] : gt_labels
                     }

        return feed_dict, gt_labels, num_pairs

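For reference, here is a minimal sketch of how this loader might be driven, assuming the file above is saved as `data_loader.py`. It is not the repository's actual `main.py`: the placeholder shapes (4105-dimensional regions, 6000-dimensional phrases) follow the precomputed features described in `data_processing_example`, the argument defaults are hypothetical, and random scores stand in for the CITE network's per-region predictions.

```Python
# Hypothetical driver: builds the placeholder dict the loader expects and
# computes phrase localization accuracy with stand-in region scores.
import argparse
import numpy as np
import tensorflow as tf

from data_loader import DataLoader

# placeholder dictionary matching the feed_dict keys used in get_batch
plh = {
    'phrase': tf.placeholder(tf.float32, [None, 6000], name='phrase'),
    'region': tf.placeholder(tf.float32, [None, None, 4105], name='region'),
    'train_phase': tf.placeholder(tf.bool, name='train_phase'),
    'num_boxes': tf.placeholder(tf.int32, name='num_boxes'),
    'labels': tf.placeholder(tf.float32, [None, None], name='labels'),
}

# hypothetical defaults; the real values come from main.py's argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=200)
parser.add_argument('--max_boxes', type=int, default=200)
parser.add_argument('--neg_to_pos_ratio', type=int, default=2)
parser.add_argument('--train_success_thresh', type=float, default=0.6)
parser.add_argument('--test_success_thresh', type=float, default=0.5)
args = parser.parse_args()

test_loader = DataLoader(args, 4105, 6000, plh, 'test')
num_correct, num_total = 0.0, 0
for batch_id in range(test_loader.num_batches()):
    feed_dict, gt_labels, num_pairs = test_loader.get_batch(batch_id)
    # a real driver would run the CITE network on feed_dict here;
    # random scores are a stand-in for the per-region predictions
    scores = np.random.rand(args.batch_size, args.max_boxes)
    best_box = np.argmax(scores, axis=1)
    for pair_id in range(num_pairs):
        num_correct += float(gt_labels[pair_id, best_box[pair_id]] > 0)
        num_total += 1

print('localization accuracy: %.4f' % (num_correct / num_total))
```

A phrase counts as correctly localized when its top-scoring box is a positive region under the IoU success threshold, which is how `gt_labels` is used above.
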
@@ -0,0 +1,12 @@
# Conditional Image-Text Embedding Networks

The code currently assumes each dataset is divided into three HDF5 files named `<split>_imfeats.h5`, where `split` takes on the values train, test, and val. Each file is assumed to contain the following items (a toy example of this layout is sketched after the list):

1. `phrase_features`: a #num_phrase x 6000 matrix of phrase features
2. `phrases`: an array of #num_phrase strings corresponding to the phrase features
3. `pairs`: a 3 x M matrix where each column contains the string representation of an `[image name, phrase, pair identifier]` pair in the split.
4. Each `<image name>` key should return a #num_boxes x feature_dim matrix of the visual features. The features should contain the visual representation as well as the spatial features for the box followed by its coordinates (i.e. the precomputed features we released are 4096 (VGG) + 5 (spatial) + 4 (box coordinates) = 4105 dimensional).
5. Each `<image name>_<phrase>_<pair identifier>` key should contain a vector with the intersection over union of every box with the ground truth box, followed by the ground truth box's coordinates (i.e. for N boxes the vector should be N + 4 dimensional).
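
For concreteness, a toy file with this layout could be written as follows (all names and values are hypothetical: two phrases, one image named `123456` with 3 boxes, and a single pair):

```Python
# Toy instance of the expected HDF5 layout (hypothetical names/values).
import h5py
import numpy as np

with h5py.File('train_imfeats.h5', 'w') as f:
    f['phrase_features'] = np.zeros((2, 6000), np.float32)  # num_phrase x 6000
    f['phrases'] = np.array(['a+man', 'a+dog'], dtype='S')
    # each column is [image name, phrase, pair identifier]
    f['pairs'] = np.array([['123456'], ['a+man'], ['1_1']], dtype='S')
    # per-image region features: num_boxes x (4096 VGG + 5 spatial + 4 coords)
    f['123456'] = np.zeros((3, 4105), np.float32)
    # per-pair vector: IoU of each of the 3 boxes, then the ground truth box
    f['123456_a+man_1_1'] = np.array([0.1, 0.7, 0.2, 10, 20, 50, 80], np.float32)
```
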
The example script uses the [pl-clc](https://github.com/BryanPlummer/pl-clc) repo for parsing and computing features of the Flickr30K Entities dataset. It assumes the built-in MATLAB PCA function is used, and not the one in the `toolbox` external module.

@@ -0,0 +1,117 @@
% This script should be placed in the pl-clc directory to run, and it assumes
% the dataset has been downloaded as specified in that repo.
net = fullfile('models', 'voc_2007_trainvaltest_2012_trainval', 'vgg16_fast_rcnn_iter_100000.caffemodel');
def = fullfile('models', 'fastrcnn_feat.prototxt');
image_dir = fullfile('datasets', 'Flickr30kEntities', 'Images');
output_dir = '.';

% the code assumes train is the first split to be processed, since the PCA
% parameters are estimated from it and reused for the other splits
splits = {'train', 'test', 'val'};
phraseCoeff = [];
phraseMean = [];
for i = 1:length(splits)
    load(fullfile('data', 'flickr30k', sprintf('%sData.mat', splits{i})), 'imData');

    imData.filterNonvisualPhrases();
    if strcmp(splits{i}, 'train')
        imData.concatenateGTBoxes();
    end

    % getPhraseWords returns a nested image x sentence x phrase cell array;
    % flatten it into a list of unique, cleaned phrase strings
    uniquePhrases = imData.getPhraseWords();
    uniquePhrases = vertcat(uniquePhrases{:});
    maxPhrases = max(cellfun(@length, uniquePhrases));
    uniquePhrases = vertcat(uniquePhrases{:});
    uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
    uniquePhrases = cellfun(@(f)strrep(f, '+', ' '), uniquePhrases, 'UniformOutput', false);
    uniquePhrases = cellfun(@(f)strtrim(removePunctuation(f)), uniquePhrases, 'UniformOutput', false);
    uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
    uniquePhrases = unique(uniquePhrases);
    uniquePhrases = cellfun(@(f)strrep(f, ' ', '+'), uniquePhrases, 'UniformOutput', false);
    phraseFeatures = single(getHGLMMFeatures(uniquePhrases))';
    uniquePhrases = [uniquePhrases; {'unk'}];

    % compute the PCA parameters on the first iteration (train split)
    if isempty(phraseCoeff)
        phraseMean = mean(phraseFeatures, 1);
        phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
        phraseCoeff = pca(phraseFeatures, 'NumComponents', 6000);
    else
        phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
    end

    % project to 6000 dimensions and append an all-zero feature for 'unk'
    phraseFeatures = phraseFeatures * phraseCoeff;
    phraseFeatures = [phraseFeatures; zeros(1, 6000, 'single')];
    phraseFeatures = phraseFeatures';

    hdf5fn = fullfile(output_dir, sprintf('%s_imfeats.h5', splits{i}));
    hdf5write(hdf5fn, 'max_phrases', maxPhrases, 'phrase_features', phraseFeatures, 'phrases', uniquePhrases);
    clear phraseFeatures

    pairs = cell(imData.nImages, 1);
    for x = 1:imData.nImages
        boxes = imData.getBoxes(x);
        pairs{x} = cell(imData.nSentences(x), 1);
        for y = 1:imData.nSentences(x)
            pairs{x}{y} = cell(imData.nPhrases(x, y), 1);
            for z = 1:imData.nPhrases(x, y)
                box = imData.getPhraseGT(x, y, z);
                phrase = imData.getPhrase(x, y, z).getPhraseString(imData.stopwords);
                if ~isempty(phrase)
                    phrase = strtrim(removePunctuation(strrep(phrase, '+', ' ')));
                end

                if isempty(phrase)
                    phrase = 'unk';
                end
                phrase = strrep(phrase, ' ', '+');

                p_id = sprintf('%i_%i', y, z);
                pairs{x}{y}{z} = {imData.imagefns{x}, phrase, p_id};

                % store each box's IoU with the ground truth box, followed by
                % the ground truth box's coordinates
                overlaps = getIOU(box, boxes);
                phrase_id = sprintf('%s_%s_%s', imData.imagefns{x}, phrase, p_id);
                overlaps = [overlaps; box'];
                hdf5write(hdf5fn, phrase_id, overlaps, 'WriteMode', 'append');
            end
        end
    end

    pairs = vertcat(pairs{:});
    pairs = vertcat(pairs{:});
    pairs = vertcat(pairs{:});
    hdf5write(hdf5fn, 'pairs', pairs, 'WriteMode', 'append');

    % process images in batches since getFastRCNNFeatures operates on a
    % batch level
    batchSize = 1000;
    for batch = 1:batchSize:imData.nImages
        batchEnd = min(batch + batchSize - 1, imData.nImages);
        imagefns = imData.imagefns(batch:batchEnd);
        imagedir = imData.imagedir;
        ext = imData.ext;
        stopwords = imData.stopwords;
        batchData = ImageSetData(imData.split, imagefns, imagedir, ext, stopwords);
        batchData.phrase = imData.phrase(batch:batchEnd);
        batchData.relationship = imData.relationship(batch:batchEnd);
        batchData.annotations = imData.annotations(batch:batchEnd);
        batchData.boxes = imData.boxes(batch:batchEnd);
        regionFeatures = getFastRCNNFeatures(batchData, net, def);
        for j = 1:batchData.nImages
            boxes = batchData.getBoxes(j);
            imDims = batchData.imSize(j);

            % 5-d spatial feature: box coordinates normalized by the image
            % dimensions, plus the box's fraction of the image area
            boxFeatures = boxes;
            boxFeatures(:, 1) = boxFeatures(:, 1) / imDims(2);
            boxFeatures(:, 2) = boxFeatures(:, 2) / imDims(1);
            boxFeatures(:, 3) = boxFeatures(:, 3) / imDims(2);
            boxFeatures(:, 4) = boxFeatures(:, 4) / imDims(1);
            boxWidth = boxes(:, 3) - boxes(:, 1);
            boxHeight = boxes(:, 4) - boxes(:, 2);
            boxFeatures = [boxFeatures, (boxWidth.*boxHeight)./(imDims(1)*imDims(2))];

            % final layout per image: 4096-d VGG, 5-d spatial, 4 box coordinates
            features = [regionFeatures{j}; boxFeatures'; boxes'];
            hdf5write(hdf5fn, batchData.imagefns{j}, features, 'WriteMode', 'append');
        end
    end
end
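
The 5-dimensional spatial feature built in the inner loop above (box coordinates normalized by the image size, plus the box's fraction of the image area) is straightforward to mirror outside MATLAB; a minimal sketch, assuming boxes are given as `[x1, y1, x2, y2]` in pixels:

```Python
# Sketch of the 5-d spatial feature from the MATLAB loop above:
# normalized [x1, y1, x2, y2] plus box area as a fraction of image area.
import numpy as np

def spatial_features(boxes, im_height, im_width):
    """boxes: N x 4 array of [x1, y1, x2, y2] pixel coordinates."""
    boxes = np.asarray(boxes, np.float32)
    feats = boxes / [im_width, im_height, im_width, im_height]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    feats = np.hstack([feats, areas[:, None] / (im_height * im_width)])
    return feats  # N x 5, concatenated after the 4096-d VGG features

# e.g. spatial_features([[10, 20, 50, 80]], 300, 400) -> 1 x 5 matrix
```
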