Finished with updates to integrate CITE and phrase detection repos
Bryan Plummer committed Dec 19, 2019
1 parent 1d12804 commit 0c3e8f0
Showing 5 changed files with 233 additions and 205 deletions.
20 changes: 7 additions & 13 deletions README.md
@@ -5,35 +5,29 @@
@inproceedings{plummerCITE2018,
Author = {Bryan A. Plummer and Paige Kordas and M. Hadi Kiapour and Shuai Zheng and Robinson Piramuthu and Svetlana Lazebnik},
Title = {Conditional Image-Text Embedding Networks},
Booktitle = {ECCV},
Booktitle = {The European Conference on Computer Vision (ECCV)},
Year = {2018}
}

This code was tested on an Ubuntu 16.04 system using TensorFlow 1.2.1.

### Phrase Localization Evaluation
You can test a model using:

python main.py --test --spatial --resume runs/<experiment_name>/model_best

You can test the ReferIt dataset by setting the dataset flag and adjusting the number of embeddings to match the trained model, e.g. to test a model trained with 12 conditional embeddings you would use:
We recommend using the `data/cache_cite_features.sh` script from the [phrase detection repository](https://github.com/BryanPlummer/phrase_detection) to obtain the precomputed features to use with our model. These features give better performance than our original paper reported, as shown in [this paper](https://arxiv.org/pdf/1811.07212.pdf): about 72/54 localization accuracy on Flickr30K Entities and ReferIt, respectively. You can also find an explanation of the dataset format in `data_processing_example`.

python main.py --test --spatial --dataset referit --num_embeddings 12 --resume runs/<experiment_name>/model_best
You can also find precomputed HGLMM features used in our work [here](http://ai.bu.edu/grovle/).

### Training New Models
Our code contains everything required to train or test models using precomputed features. You can train a new model on Flickr30K Entities using:
Our code contains everything required to train or test models using precomputed features. You can train a model using:

python main.py --name <name of experiment>

When training completes, it will report the localization accuracy of the best model on the test and validation sets. Note that the command above does not use the spatial features from our paper (add the `--spatial` flag; an example follows below). You can see a listing and description of the tunable parameters with:

python main.py --help
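
For example, to train with the spatial features used in the paper (the experiment name below is just a placeholder), you could run:

python main.py --name spatial_experiment --spatial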

### Precomputed Features

We recommend using the `data/cache_cite_features.sh` script from the [phrase detection repository](https://github.com/BryanPlummer/phrase_detection) to obtain the precomputed features to use with our model. These features give better performance than our original paper reported, as shown in [this paper](https://arxiv.org/pdf/1811.07212.pdf): about 72/54 localization accuracy on Flickr30K Entities and ReferIt, respectively. You can also find an explanation of the dataset format in `data_processing_example`.
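
If you want to sanity-check a cached split before training, here is a minimal sketch (assuming the default `data/<dataset>/<split>_features.h5` layout read by `data_loader.py`; keys other than `phrases` follow the format described in `data_processing_example`):

    import h5py

    # open one cached split and list what it contains
    with h5py.File('data/flickr/test_features.h5', 'r') as f:
        print(list(f.keys()))             # datasets stored in the cache
        print(len(list(f['phrases'])))    # number of cached phrases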
### Phrase Localization Evaluation
When testing a model you need to use the same settings as during training. For example, after training with spatial features, you would test using:

You can also find precomputed HGLMM features used in our work [here](http://ai.bu.edu/grovle/).
python main.py --test --spatial --resume runs/<experiment_name>/model_best


Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for providing their implementation of the [Similarity Network](https://arxiv.org/abs/1704.03470), which was used as the basis for this repo.
Empty file added __init__.py
Empty file.
19 changes: 11 additions & 8 deletions data_loader.py
@@ -16,7 +16,7 @@ def load_word_embeddings(word_embedding_filename, embedding_length):
if not line:
continue

vec = line.split(',')
vec = line.split()
if len(vec) != embedding_length + 1:
continue

@@ -44,7 +44,11 @@ def __init__(self, args, region_dim, split, tok2idx, train_phrases = None):
plh -- placeholder dictionary containing the tensor inputs
split -- the data split (i.e. 'train', 'test', 'val')
"""
datafn = os.path.join('data', args.dataset, '%s_features.h5' % split)
if args.datadir == 'data':
datafn = os.path.join(args.datadir, args.dataset, '%s_features.h5' % split)
else:
datafn = os.path.join(args.datadir, '%s_features.h5' % split)

self.data = h5py.File(datafn, 'r')

phrases = list(self.data['phrases'])
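
To make the new `--datadir` branch above concrete, here is a hypothetical standalone restatement (the helper name is not repo code):

    import os

    def resolve_split_path(datadir, dataset, split):
        # the default 'data' dir keeps a per-dataset subfolder; a custom dir is used as-is
        if datadir == 'data':
            return os.path.join(datadir, dataset, '%s_features.h5' % split)
        return os.path.join(datadir, '%s_features.h5' % split)

    print(resolve_split_path('data', 'flickr', 'train'))       # data/flickr/train_features.h5
    print(resolve_split_path('/my/cache', 'flickr', 'train'))  # /my/cache/train_features.h5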
@@ -96,7 +100,6 @@ def __init__(self, args, region_dim, split, tok2idx, train_phrases = None):
self.neg_to_pos_ratio = args.neg_to_pos_ratio
self.batch_size = args.batch_size
self.max_boxes = args.max_boxes
self.num_pos = args.num_pos
if self.is_train:
self.success_thresh = args.train_success_thresh
else:
@@ -324,12 +327,12 @@ def get_batch(self, batch_id, plh):
# the logistic loss only counts regions labeled -1 as negatives
gt_labels[pair_id, i, negs[:num_neg]] = -1

feed_dict = {plh['phrase'] : phrase_features,
plh['region'] : region_features,
feed_dict = {plh['phrases'] : phrase_features,
plh['regions'] : region_features,
plh['train_phase'] : self.is_train,
plh['num_boxes'] : self.max_boxes,
plh['num_phrases'] : max_phrases,
plh['phrase_denom'] : np.sum(num_phrases).astype(np.float32) + 1e-6,
plh['boxes_per_image'] : self.max_boxes,
plh['phrases_per_image'] : max_phrases,
plh['phrase_count'] : np.sum(num_phrases).astype(np.float32) + 1e-6,
plh['labels'] : gt_labels
}

79 changes: 29 additions & 50 deletions main.py
@@ -8,28 +8,30 @@
import tensorflow as tf

from collections import Counter
from model import setup_embed_model, setup_cite_model
from model import CITE
from data_loader import DataLoader, load_word_embeddings

parser = argparse.ArgumentParser(description='Conditional Image-Text Similarity Network')
parser.add_argument('--name', default='Conditional_Image-Text_Similarity_Network', type=str,
help='name of experiment')
parser.add_argument('--dataset', default='flickr', type=str,
help='name of the dataset to use')
parser.add_argument('--datadir', default='data', type=str,
help='directory containing the hdf5 data files')
parser.add_argument('--language_model', default='avg', type=str,
help='name of experiment')
help='type of language model to use, types: avg (default), attend, gru')
parser.add_argument('--r_seed', type=int, default=42,
help='random seed (default: 42)')
parser.add_argument('--info_iterval', type=int, default=1000,
help='number of batches to process before outputing training status')
parser.add_argument('--resume', default='', type=str,
help='filename of model to load (default: none)')
parser.add_argument('--cca_parameters', default='', type=str,
help='filename of model to load (default: none)')
help='filename of cca parameters to load (default: none)')
parser.add_argument('--test', dest='test', action='store_true', default=False,
help='Run model on test set')
parser.add_argument('--batch-size', type=int, default=6,
help='input batch size for training (default: 200)')
help='input batch size for training (default: 6)')
parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
help='learning rate (default: 5e-5)')
parser.add_argument('--embed_l1', type=float, default=5e-5,
@@ -38,46 +40,42 @@
help='maximum number of epochs, less than 1 indicates no limit (default: 0)')
parser.add_argument('--no_gain_stop', type=int, default=5,
help='number of epochs used to perform early stopping based on validation performance (default: 5)')
parser.add_argument('--num_pos', type=int, default=2,
help='ratio of negatives to positives used during training (default: 2)')
parser.add_argument('--neg_to_pos_ratio', type=int, default=2,
help='ratio of negatives to positives used during training (default: 2)')
parser.add_argument('--minimum_gain', type=float, default=5e-4, metavar='N',
help='minimum performance gain for a model to be considered better (default: 5e-4)')
parser.add_argument('--cca_weight_reg', type=float, default=1.,
help='learning rate (default: 1)')
parser.add_argument('--phrase_phrase_weight', type=float, default=.1,
help='learning rate (default: 1)')
parser.add_argument('--margin', type=float, default=0.05,
help='learning rate (default: 1)')
parser.add_argument('--embed_topk', type=int, default=5,
parser.add_argument('--cca_weight_reg', type=float, default=5e-5,
help='CCA weight regularization (default: 5e-5)')
parser.add_argument('--train_success_thresh', type=float, default=0.6,
help='minimum training intersection-over-union threshold for success (default: 0.6)')
parser.add_argument('--test_success_thresh', type=float, default=0.5,
help='minimum testing intersection-over-union threshold for success (default: 0.5)')
parser.add_argument('--dim_embed', type=int, default=256,
help='how many dimensions in final embedding (default: 256)')
help='how many dimensions in the final embedding (default: 256)')
parser.add_argument('--max_boxes', type=int, default=300,
help='maximum number of edge boxes per image (default: 300)')
parser.add_argument('--max_phrases', type=int, default=-1,
help='maximum number of phrases per image; values less than 1 use all of them (default: -1)')
parser.add_argument('--max_tokens', type=int, default=10,
help='maximum number of words allowed in a phrase (default: 10)')
parser.add_argument('--num_embeddings', type=int, default=4,
help='number of embeddings to train (default: 4)')
parser.add_argument('--region_norm_axis', type=int, default=1,
help='axis=1 treats all regions like a single image (better for localization-only) and for axis=2 L2 norm is done for each region')
parser.add_argument('--spatial', dest='spatial', action='store_true', default=False,
help='Flag indicating whether to use spatial features')
parser.add_argument('--use_augmented', dest='use_augmented', action='store_true', default=False,
help='Flag indicating whether to use augmented positive phrases (default: use gt only)')
help='flag indicating whether to use spatial features')
parser.add_argument('--npa', action='store_true', default=False,
help='Flag indicates using phrase detection metrics rather than localization-only')
parser.add_argument('--embed_weight', type=float, default=1e-5,
help='learning rate (default: 1)')
help='use hard-negative phrase mining')
parser.add_argument('--use_augmented', dest='use_augmented', action='store_true', default=False,
help='flag indicating whether to use augmented positive phrases (default: use gt only)')
parser.add_argument('--ifs', action='store_true', default=False,
help='uses inverse frequency sampling when training with augmented phrases')
parser.add_argument('--word_embedding', type=str, default='data/hglmm_6kd.txt',
help='full path to space separated language embedding features to load')
parser.add_argument('--region_norm_axis', type=int, default=1,
help='axis on which to perform the L2 norm: 1 treats all regions like a single image (better for localization-only), while 2 normalizes each region separately')
parser.add_argument('--embedding_ft', dest='embedding_ft', action='store_true', default=False,
help='Flag indicating whether to fine-tune the language features')
help='flag indicating whether to fine-tune the language features')
parser.add_argument('--embed_weight', type=float, default=1e-5,
help='L2 regularization weight for fine-tuning language features (default: 1e-5)')
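
As a hedged illustration of how several of the flags above combine (the experiment name and feature directory are placeholders), a ReferIt run with cached features in a custom location and spatial features might look like:

python main.py --name referit_run --dataset referit --datadir /path/to/referit_features --num_embeddings 12 --spatial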

def main():
global args
@@ -93,41 +91,21 @@ def main():
region_feature_dim += 5

test_loader, train_loader, val_loader = get_data_loaders(region_feature_dim, tok2idx)
max_length = test_loader.max_length
# setup placeholders
labels_plh = tf.placeholder(tf.float32, shape=[None, None, None])
phrase_plh = tf.placeholder(tf.int32, shape=[None, None, max_length])
region_plh = tf.placeholder(tf.float32, shape=[None, None, region_feature_dim])
phrase_labels_plh = tf.placeholder(tf.float32, shape=[None, None])
train_phase_plh = tf.placeholder(tf.bool, name='train_phase')
num_boxes_plh = tf.placeholder(tf.int32)
num_phrases_plh = tf.placeholder(tf.int32)
phrase_denom_plh = tf.placeholder(tf.float32)

plh = {}
plh['num_boxes'] = num_boxes_plh
plh['labels'] = labels_plh
plh['phrase'] = phrase_plh
plh['region'] = region_plh
plh['train_phase'] = train_phase_plh
plh['num_phrases'] = num_phrases_plh
plh['phrase_denom'] = phrase_denom_plh
plh['phrase_labels'] = phrase_labels_plh
model = setup_cite_model(args, phrase_plh, region_plh, train_phase_plh, labels_plh,
num_boxes_plh, num_phrases_plh, region_feature_dim,
max_length, phrase_denom_plh, vecs)
model_constructor = CITE(args, vecs, test_loader.max_length, region_feature_dim)
model = model_constructor.setup_model()
plh = model_constructor.get_placeholders()
if args.test:
test(model, test_loader, plh, model_name=args.resume)
sys.exit()

save_model_directory = os.path.join('runs', args.name)
save_model_directory = os.path.join('runs', args.dataset, args.name)
if not os.path.exists(save_model_directory):
os.makedirs(save_model_directory)
# training with Adam
acc, best_adam = train(model, train_loader, val_loader, plh, args.resume)

# finetune with SGD after loading the best model trained with Adam
best_model_filename = os.path.join('runs', args.name, 'model_best')
best_model_filename = os.path.join('runs', args.dataset, args.name, 'model_best')
acc, best_sgd = train(model, train_loader, val_loader, plh,
best_model_filename, False, acc)
best_epoch = best_adam + best_sgd
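
Putting the pieces of this commit together, the new construction flow in `main()` is roughly the following (a sketch; the `CITE` class itself lives in `model.py`, which is not shown in this diff):

    # build the model through the new CITE wrapper introduced by this commit
    model_constructor = CITE(args, vecs, test_loader.max_length, region_feature_dim)
    model = model_constructor.setup_model()      # constructs the TensorFlow graph
    plh = model_constructor.get_placeholders()   # placeholder dict used by the data loader
    # data_loader.get_batch fills a feed_dict keyed by these placeholders:
    # 'phrases', 'regions', 'train_phase', 'boxes_per_image',
    # 'phrases_per_image', 'phrase_count', 'labels'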
@@ -270,17 +248,18 @@ def train(model, train_loader, test_loader, plh, model_weights, use_adam = True,
update_confusion_table(model, test_loader, train_loader, plh, sess)

process_epoch(model, train_loader, plh, sess, train_step, epoch, suffix)
saver.save(sess, os.path.join('runs', args.name, 'checkpoint'),
saver.save(sess, os.path.join('runs', args.dataset, args.name, 'checkpoint'),
global_step = epoch)

acc = test(model, test_loader, plh, sess)

# the first time we update the table localization accuracy may drop a lot
# so let's reset the baseline of what is good
if update_table and epoch - 3 == 0 and use_adam:
best_acc = acc

if acc > best_acc:
saver.save(sess, os.path.join('runs', args.name, 'model_best'))
saver.save(sess, os.path.join('runs', args.dataset, args.name, 'model_best'))
if (acc - args.minimum_gain) > best_acc:
best_epoch = epoch

(diff for the fifth changed file not shown)
