Filling in missing comments/links and removing unused class members.
Bryan Plummer committed Feb 3, 2018
1 parent e711008 commit 73b3ded
Showing 5 changed files with 14 additions and 8 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -31,5 +31,7 @@ Along with our example data processing script in `data_processing_example` you c

Our best CITE model on Flickr30K Entities using these precomputed features can be found [here](https://drive.google.com/open?id=1rmeIqYTCIduNc2QWUEdXLHFGrlOzz2xO).

You can download the raw Flickr30K Entities data [here](http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/), but it isn't necessary to use our precomputed features.

Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for access to their Similarity Network code that was used as the basis for this implementation.

Many thanks to [Kevin Shih](https://scholar.google.com/citations?user=4x3DhzAAAAAJ&hl=en) and [Liwei Wang](https://scholar.google.com/citations?user=qnbdnZEAAAAJ&hl=en) for access to their [Similarity Network](https://arxiv.org/abs/1704.03470) code that was used as the basis for this implementation.
8 changes: 3 additions & 5 deletions data_loader.py
@@ -17,11 +17,11 @@ def __init__(self, args, region_dim, phrase_dim, plh, split):
datafn = os.path.join('data', 'flickr', '%s_imfeats.h5' % split)
self.data = h5py.File(datafn, 'r')
vecs = np.array(self.data['phrase_features'], np.float32)
uniquePhrases = list(self.data['phrases'])
assert(vecs.shape[0] == len(uniquePhrases))
phrases = list(self.data['phrases'])
assert(vecs.shape[0] == len(phrases))

w2v_dict = {}
for index, phrase in enumerate(uniquePhrases):
for index, phrase in enumerate(phrases):
w2v_dict[phrase] = vecs[index, :]

# mapping from uniquePhrase to w2v
Expand All @@ -30,8 +30,6 @@ def __init__(self, args, region_dim, phrase_dim, plh, split):
self.n_pairs = len(self.pairs[0])
self.pair_index = range(self.n_pairs)

self.uniquePhrases = uniquePhrases # set of unique phrases

self.split = split
self.plh = plh
self.is_train = split == 'train'
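The loader change above builds a phrase-to-vector lookup from two parallel datasets. A minimal sketch of the same pattern with toy values (the real features are 6000-dimensional, and the phrase strings here are made up for illustration):

```python
import numpy as np

# Toy stand-ins for the `phrase_features` and `phrases` datasets
# read from <split>_imfeats.h5.
vecs = np.array([[0.1, 0.2],
                 [0.3, 0.4],
                 [0.5, 0.6]], np.float32)
phrases = [b'a dog', b'a red car', b'two people']
assert vecs.shape[0] == len(phrases)

# Same pattern as data_loader.py: map each phrase to its feature row.
w2v_dict = {phrase: vecs[index, :] for index, phrase in enumerate(phrases)}
```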
2 changes: 1 addition & 1 deletion data_processing_example/README.md
@@ -1,6 +1,6 @@
# Conditional Image-Text Embedding Networks

The code currently assumes datasets are divided into three hdf5 files named `<split>_imfeats.h5` where `split` takes on the value train, test, or val. It assumes it has the following items:
The code currently assumes datasets are divided into three hdf5 files named `<split>_imfeats.h5` where `split` takes on the value train, test, or val. It assumes the hdf5 files contain the following items:

1. phrase_features: #num_phrase X 6000 dimensional matrix of phrase features
2. phrases: array of #num_phrase strings corresponding to the phrase features
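A minimal sketch (not from the repo) of writing and reading an hdf5 file with the layout described above, using `h5py`; the phrase count and phrase strings are invented for illustration:

```python
import h5py
import numpy as np

num_phrase = 3
# Write a minimal train_imfeats.h5 containing the two datasets listed above.
with h5py.File('train_imfeats.h5', 'w') as f:
    f.create_dataset('phrase_features',
                     data=np.zeros((num_phrase, 6000), np.float32))
    f.create_dataset('phrases',
                     data=np.array([b'a dog', b'a red car', b'two people']))

# Read it back the same way data_loader.py does.
with h5py.File('train_imfeats.h5', 'r') as f:
    vecs = np.array(f['phrase_features'], np.float32)
    phrases = list(f['phrases'])
```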
2 changes: 1 addition & 1 deletion main.py
@@ -28,7 +28,7 @@
parser.add_argument('--embed_l1', type=float, default=5e-5,
help='weight of the L1 regularization term used on the concept weight branch (default: 5e-5)')
parser.add_argument('--max_epoch', type=int, default=0,
help='maximum number of epochs, <1 indicates no limit (default: 0)')
help='maximum number of epochs, less than 1 indicates no limit (default: 0)')
parser.add_argument('--no_gain_stop', type=int, default=5,
help='number of epochs used to perform early stopping based on validation performance (default: 5)')
parser.add_argument('--neg_to_pos_ratio', type=int, default=2,
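To show how these flags parse, here is a small self-contained sketch re-creating two of them; the flag names and defaults are taken from the snippet above, while the override value is illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max_epoch', type=int, default=0,
                    help='maximum number of epochs, less than 1 indicates no limit (default: 0)')
parser.add_argument('--no_gain_stop', type=int, default=5,
                    help='number of epochs used to perform early stopping (default: 5)')

# Override one flag; the other keeps its default.
args = parser.parse_args(['--max_epoch', '10'])
```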
6 changes: 6 additions & 0 deletions model.py
@@ -81,6 +81,12 @@ def setup_model(args, phrase_plh, region_plh, train_phase_plh, labels_plh, num_b
labels_plh -- indicates positive (1), negative (-1), or ignore (0)
num_boxes_plh -- number of boxes per example in the batch
region_feature_dim -- dimensions of the region features
Returns:
total_loss -- weighted combination of the region and concept loss
region_loss -- logistic loss for phrase-region prediction
concept_loss -- L1 loss for the output of the concept weight branch
region_prob -- each row contains the probability a region is associated with a phrase
"""
labels_plh = tf.reshape(labels_plh, [-1, num_boxes_plh])
eb_fea_plh = tf.reshape(region_plh, [-1, num_boxes_plh, region_feature_dim])
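The `tf.reshape` calls at the end of the snippet regroup flat per-region tensors into `[batch, num_boxes, ...]` blocks. The same idea in numpy, with toy label values (1 positive, -1 negative, 0 ignore, as documented above):

```python
import numpy as np

num_boxes = 4
# Flattened labels for two examples, num_boxes regions each.
labels = np.array([1, -1, 0, -1,    # boxes for example 1
                   -1, 1, 0, 0])    # boxes for example 2
# Equivalent of tf.reshape(labels_plh, [-1, num_boxes_plh]).
labels = labels.reshape(-1, num_boxes)
```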
