Initial release version.
Bryan Plummer committed Feb 3, 2018
1 parent 544f5b2 commit 19d2b36
Showing 6 changed files with 606 additions and 2 deletions.
43 changes: 41 additions & 2 deletions README.md
@@ -1,2 +1,41 @@
# cite
Implementation for our paper "Conditional Image-Text Embedding Networks" (coming soon)
# Conditional Image-Text Embedding Networks

**cite** contains a Tensorflow implementation for our [paper](https://arxiv.org/abs/1711.08389). If you find this code useful in your research, please consider citing:

@article{plummerCITE2017,
Author = {Bryan A. Plummer and Paige Kordas and M. Hadi Kiapour and Shuai Zheng and Robinson Piramuthu and Svetlana Lazebnik},
Title = {Conditional Image-Text Embedding Networks},
Journal = {arXiv:1711.08389},
Year = {2017}
}

This code was tested on an Ubuntu 16.04 system using Tensorflow 1.2.1.

### Phrase Localization Evaluation Demo
After you download our precomputed features and model, you can test them using:

```Shell
python main.py --test --spatial --name runs/cite_spatial_k4/model_best
```

### Training New Models
Our code contains everything required to train or test models using precomputed features. You can train a new model on Flickr30K Entities using:

```Shell
python main.py --name <name of experiment>
```

When training completes, it will report the localization accuracy of the best model on the test and validation sets. You can see a listing and description of the tunable parameters with:

```Shell
python main.py --help
```

### Precomputed Features

Along with our example data processing script in `data_processing_example`, you can download our precomputed (PASCAL) features for the Flickr30K Entities dataset [here](https://drive.google.com/file/d/1m5DQ3kh2rCkPremgM91chQgJYZxnEbZw/view?usp=sharing) (52G). Unpack the features into a folder named `data` (as sketched below) or update the path in the data loader class.
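
The data loader in `data_loader.py` looks for the split files at `data/flickr/<split>_imfeats.h5`, so after unpacking, the layout should look roughly like the following sketch (assuming the archive contains one file per split):

```Shell
data/
└── flickr/
    ├── train_imfeats.h5
    ├── val_imfeats.h5
    └── test_imfeats.h5
```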

Our best CITE model on Flickr30K Entities using these precomputed features can be found [here](https://drive.google.com/open?id=1rmeIqYTCIduNc2QWUEdXLHFGrlOzz2xO).


Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for access to their Similarity Network code that was used as the basis for this implementation.
123 changes: 123 additions & 0 deletions data_loader.py
@@ -0,0 +1,123 @@
import numpy as np
import h5py
import os

class DataLoader:
"""Class minibatches from data on disk in HDF5 format"""
def __init__(self, args, region_dim, phrase_dim, plh, split):
"""Constructor
Arguments:
args -- command line arguments passed into the main function
region_dim -- dimensions of the region features
phrase_dim -- dimensions of the phrase features
plh -- placeholder dictionary containing the tensor inputs
split -- the data split (i.e. 'train', 'test', 'val')
"""
datafn = os.path.join('data', 'flickr', '%s_imfeats.h5' % split)
self.data = h5py.File(datafn, 'r')
vecs = np.array(self.data['phrase_features'], np.float32)
uniquePhrases = list(self.data['phrases'])
assert(vecs.shape[0] == len(uniquePhrases))

w2v_dict = {}
for index, phrase in enumerate(uniquePhrases):
w2v_dict[phrase] = vecs[index, :]

# mapping from uniquePhrase to w2v
self.w2v_dict = w2v_dict
self.pairs = list(self.data['pairs'])
self.n_pairs = len(self.pairs[0])
self.pair_index = range(self.n_pairs)

self.uniquePhrases = uniquePhrases # set of unique phrases

self.split = split
self.plh = plh
self.is_train = split == 'train'
self.neg_to_pos_ratio = args.neg_to_pos_ratio
self.batch_size = args.batch_size
self.max_boxes = args.max_boxes
if self.is_train:
self.success_thresh = args.train_success_thresh
else:
self.success_thresh = args.test_success_thresh

self.region_feature_dim = region_dim
self.phrase_feature_dim = phrase_dim

def __len__(self):
return self.n_pairs

def shuffle(self):
''' Shuffles the order of the pairs being sampled
'''
np.random.shuffle(self.pair_index)

def num_batches(self):
return int(np.ceil(float(len(self)) / self.batch_size))

def get_batch(self, batch_id):
"""Returns a minibatch given a valid id for it
Arguments:
batch_id -- number between 0 and self.num_batches()
Returns:
feed_dict -- dictionary containing minibatch data
gt_labels -- indicates positive/negative regions
num_pairs -- number of pairs without padding
"""
region_features = np.zeros((self.batch_size, self.max_boxes,
self.region_feature_dim), dtype=np.float32)
start_pair = batch_id * self.batch_size
end_pair = min(start_pair + self.batch_size, len(self))
num_pairs = end_pair - start_pair
gt_labels = np.zeros((self.batch_size, self.max_boxes),
dtype=np.float32)
phrase_features = np.zeros((self.batch_size, self.phrase_feature_dim),
dtype=np.float32)
for pair_id in range(num_pairs):
sample_id = self.pair_index[start_pair + pair_id]

# paired image
im_id = self.pairs[0][sample_id]

# paired phrase
phrase = self.pairs[1][sample_id]

# phrase instance identifier
p_id = self.pairs[2][sample_id]

# gets region features
features = np.array(self.data[im_id], np.float32)
num_boxes = min(len(features), self.max_boxes)
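# keep only the first region_feature_dim columns; the released features also store the box coordinates after the visual and spatial features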
features = features[:num_boxes, :self.region_feature_dim]
overlaps = np.array(self.data['%s_%s_%s' % (im_id, phrase, p_id)])

# last 4 dimensions of overlaps are ground truth box coordinates
assert(num_boxes <= len(overlaps) - 4)
overlaps = overlaps[:num_boxes]
region_features[pair_id, :num_boxes,:] = features
phrase_features[pair_id, :] = self.w2v_dict[phrase]
gt_labels[pair_id, :num_boxes] = overlaps >= self.success_thresh
if self.is_train:
num_pos = int(np.sum(gt_labels[pair_id, :]))
num_neg = num_pos * self.neg_to_pos_ratio
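# candidate negatives are regions whose overlap with the ground truth is below 0.3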
negs = np.random.permutation(np.where(overlaps < 0.3)[0])

if len(negs) < num_neg: # if not enough negatives
negs = np.random.permutation(np.where(overlaps < 0.4)[0])

# logistic loss only counts a region labeled as -1 negative
gt_labels[pair_id, negs[:num_neg]] = -1

feed_dict = {self.plh['phrase'] : phrase_features,
self.plh['region'] : region_features,
self.plh['train_phase'] : self.is_train,
self.plh['num_boxes'] : self.max_boxes,
self.plh['labels'] : gt_labels
}

return feed_dict, gt_labels, num_pairs

12 changes: 12 additions & 0 deletions data_processing_example/README.md
@@ -0,0 +1,12 @@
# Conditional Image-Text Embedding Networks

The code currently assumes each dataset is divided into three HDF5 files named `<split>_imfeats.h5`, where `split` takes the value train, test, or val. Each file is assumed to contain the following items (a minimal example of writing such a file appears at the end of this README):

1. `phrase_features`: a #num_phrases x 6000 matrix of phrase features
2. `phrases`: an array of #num_phrases strings corresponding to the rows of `phrase_features`
3. `pairs`: a 3 x M matrix where each column contains the string representation of an `[image name, phrase, pair identifier]` pair in the split.
4. Each `<image name>` should return a #num_boxes x feature_dim matrix of visual features. Each row should contain the visual representation followed by the spatial features for the box and its coordinates (i.e. the precomputed features we released are 4096 (VGG) + 5 (spatial) + 4 (box coordinates) = 4105 dimensional).
5. Each `<image name>_<phrase>_<pair identifier>` should contain a vector of the intersection over union of each box with the ground truth box, followed by the ground truth box's coordinates (i.e. for N boxes the vector should be N + 4 dimensional).


The example script uses the [pl-clc](https://github.com/BryanPlummer/pl-clc) repo for parsing and computing features of the Flickr30K Entities dataset. It assumes the built-in MATLAB `pca` function is used, and not the one in the `toolbox` external module.
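
For illustration, here is a minimal `h5py` sketch of a file with this layout; the image name, phrase, and feature values are placeholders rather than real data:

```python
import h5py
import numpy as np

num_phrases, num_boxes, feat_dim = 2, 10, 4105  # placeholder sizes

with h5py.File('train_imfeats.h5', 'w') as f:
    # phrase-level items
    f['phrase_features'] = np.zeros((num_phrases, 6000), dtype=np.float32)
    f['phrases'] = np.array(['a+man', 'unk'], dtype='S')

    # 3 x M pairs: each column is [image name, phrase, pair identifier]
    f['pairs'] = np.array([['123456'], ['a+man'], ['1_1']], dtype='S')

    # per-image region features: num_boxes x feat_dim
    f['123456'] = np.zeros((num_boxes, feat_dim), dtype=np.float32)

    # per-pair vector: IoU of each region with the ground truth box,
    # followed by the 4 ground truth box coordinates (N + 4 values)
    f['123456_a+man_1_1'] = np.zeros(num_boxes + 4, dtype=np.float32)
```
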
117 changes: 117 additions & 0 deletions data_processing_example/process_data.m
@@ -0,0 +1,117 @@
% Code should be placed in the pl-clc directory to run and assumes
% the dataset has been downloaded as specified in that repo.
net = fullfile('models', 'voc_2007_trainvaltest_2012_trainval', 'vgg16_fast_rcnn_iter_100000.caffemodel');
def = fullfile('models', 'fastrcnn_feat.prototxt');
image_dir = fullfile('datasets', 'Flickr30kEntities', 'Images');
output_dir = '.';

% code assumes train is the first to be processed
splits = {'train', 'test', 'val'};
phraseCoeff = [];
phraseMean = [];
for i = 1:length(splits)
load(fullfile('data', 'flickr30k', sprintf('%sData.mat', splits{i})), 'imData');

imData.filterNonvisualPhrases();
if strcmp(splits{i},'train')
imData.concatenateGTBoxes();
end

% getPhraseWords returns a nested image x sentence x phrase
% cell array
uniquePhrases = imData.getPhraseWords();
uniquePhrases = vertcat(uniquePhrases{:});
maxPhrases = max(cellfun(@length, uniquePhrases));
uniquePhrases = vertcat(uniquePhrases{:});
uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
uniquePhrases = cellfun(@(f)strrep(f,'+',' '), uniquePhrases, 'UniformOutput',false);
uniquePhrases = cellfun(@(f)strtrim(removePunctuation(f)), uniquePhrases, 'UniformOutput',false);
uniquePhrases(cellfun(@isempty, uniquePhrases)) = [];
uniquePhrases = unique(uniquePhrases);
uniquePhrases = cellfun(@(f)strrep(f,' ','+'), uniquePhrases, 'UniformOutput',false);
phraseFeatures = single(getHGLMMFeatures(uniquePhrases))';
uniquePhrases = [uniquePhrases;{'unk'}];

% compute PCA parameters on first iteration (train split)
if isempty(phraseCoeff)
phraseMean = mean(phraseFeatures, 1);
phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
phraseCoeff = pca(phraseFeatures, 'NumComponents', 6000);
else
phraseFeatures = bsxfun(@minus, phraseFeatures, phraseMean);
end

phraseFeatures = phraseFeatures * phraseCoeff;
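% append an all-zero row as the feature for the 'unk' placeholder phrase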
phraseFeatures = [phraseFeatures; zeros(1, 6000, 'single')];
phraseFeatures = phraseFeatures';

hdf5fn = fullfile(output_dir, sprintf('%s_imfeats.h5', splits{i}));
hdf5write(hdf5fn, 'max_phrases', maxPhrases, 'phrase_features', phraseFeatures, 'phrases', uniquePhrases);
clear phraseFeatures

pairs = cell(imData.nImages, 1);
for x = 1:imData.nImages
boxes = imData.getBoxes(x);
pairs{x} = cell(imData.nSentences(x), 1);
for y = 1:imData.nSentences(x)
pairs{x}{y} = cell(imData.nPhrases(x, y), 1);
for z = 1:imData.nPhrases(x, y)
box = imData.getPhraseGT(x, y, z);
phrase = imData.getPhrase(x,y,z).getPhraseString(imData.stopwords);
if ~isempty(phrase)
phrase = strtrim(removePunctuation(strrep(phrase,'+',' ')));
end

if isempty(phrase)
phrase = 'unk';
end
phrase = strrep(phrase,' ','+');

p_id = sprintf('%i_%i', y, z);
pairs{x}{y}{z} = {imData.imagefns{x}, phrase, p_id};
overlaps = getIOU(box, boxes);
phrase_id = sprintf('%s_%s_%s', imData.imagefns{x}, phrase, p_id);
overlaps = [overlaps; box'];
hdf5write(hdf5fn, phrase_id, overlaps, 'WriteMode', 'append');
end
end
end

pairs = vertcat(pairs{:});
pairs = vertcat(pairs{:});
pairs = vertcat(pairs{:});
hdf5write(hdf5fn, 'pairs', pairs, 'WriteMode', 'append');

% separate into batches since getFastRCNNFeatures operates on a
% batch level
batchSize = 1000;
nBatches = ceil(imData.nImages/batchSize);
for batch = 1:batchSize:imData.nImages
batchEnd = min(batch+batchSize-1,imData.nImages);
imagefns = imData.imagefns(batch:batchEnd);
imagedir = imData.imagedir;
ext = imData.ext;
stopwords = imData.stopwords;
batchData = ImageSetData(imData.split,imagefns,imagedir,ext,stopwords);
batchData.phrase = imData.phrase(batch:batchEnd);
batchData.relationship = imData.relationship(batch:batchEnd);
batchData.annotations = imData.annotations(batch:batchEnd);
batchData.boxes = imData.boxes(batch:batchEnd);
regionFeatures = getFastRCNNFeatures(batchData, net, def);
for j = 1:batchData.nImages
boxes = batchData.getBoxes(j);
imDims = batchData.imSize(j);
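% spatial features: box coordinates normalized by image width/height, plus the box area relative to the image area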
boxFeatures = boxes;
boxFeatures(:, 1) = boxFeatures(:,1) / imDims(2);
boxFeatures(:, 2) = boxFeatures(:,2) / imDims(1);
boxFeatures(:, 3) = boxFeatures(:,3) / imDims(2);
boxFeatures(:, 4) = boxFeatures(:,4) / imDims(1);
boxWidth = boxes(:,3) - boxes(:,1);
boxHeight = boxes(:,4) - boxes(:,2);
boxFeatures = [boxFeatures, (boxWidth.*boxHeight)./(imDims(1)*imDims(2))];
features = [regionFeatures{j}; boxFeatures'; boxes'];
hdf5write(hdf5fn, batchData.imagefns{j}, features, 'WriteMode', 'append');
end
end
end

